<a href="https://colab.research.google.com/github/YossiAsher/abstract-learning-in-image-processing/blob/main/png_resnet50.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab.patches import cv2_imshow

import numpy as np
import matplotlib.pyplot as plt
import glob

import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

import pathlib
import shutil
from pathlib import Path


In [None]:
# Everything this cell needs is imported up front; the statements below keep
# their original relative order (authenticate first, then build the client).
import io

from google.colab import auth
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload

# Google Drive share link of the zipped PNG dataset (consumed by download_file).
png_zip_link = 'https://drive.google.com/file/d/1yGUqTua6S_7zG4B_XMe5k8cAx2Oz9Kxw/view?usp=sharing'

# Interactive Colab sign-in. NOTE(review): build() below presumably relies on the
# credentials established here -- confirm before reordering.
auth.authenticate_user()

# Despite the name, this is the `files()` resource of the Drive v3 client, not the
# service object itself; download_file calls `get_media` on it.
drive_service = build('drive', 'v3').files()

In [None]:
def download_file(name, link):
  """Download a Google Drive file to local disk.

  Args:
    name: Destination path of the downloaded file.
    link: Drive share link. Both ".../d/<id>/..." and "...?id=<id>" forms are
      recognised; any other shape falls back to taking the second-to-last
      "/"-separated component, as this function originally did.

  Raises:
    Whatever the Drive client raises on a failed request. No partial file is
    left at `name` in that case.
  """
  import os
  import re

  # Pull the file id out of the link instead of relying purely on its position.
  match = re.search(r'/d/([\w-]+)', link) or re.search(r'[?&]id=([\w-]+)', link)
  fileId = match.group(1) if match else link.split('/')[-2]

  request = drive_service.get_media(fileId=fileId)

  # Stream chunks straight to disk rather than buffering the whole archive in
  # memory. A temporary name is used so `name` only ever holds a complete file.
  partial_name = os.fspath(name) + '.part'
  try:
    with open(partial_name, 'wb') as f:
      downloader = MediaIoBaseDownload(f, request)
      done = False
      while not done:
        status, done = downloader.next_chunk()
        print("Download %d%%" % int(status.progress() * 100))
    os.replace(partial_name, name)
  except BaseException:
    # Also covers an interrupted cell (KeyboardInterrupt): drop the fragment.
    if os.path.exists(partial_name):
      os.remove(partial_name)
    raise

In [None]:
# Fetch the dataset archive from Drive into the working directory as png.zip.
download_file(name='png.zip', link=png_zip_link)

In [None]:
# Extract the dataset archive into the working directory.
# -o overwrites existing files without prompting, so re-running this cell does not
# stall waiting for input; -q suppresses the per-file listing.
!unzip -q -o png.zip

In [None]:
# Load the TensorBoard notebook extension; the %tensorboard magic used after
# training depends on it.
%load_ext tensorboard

In [None]:
# Delete any existing ./logs directory so TensorBoard only shows the run
# written by this session.
!rm -rf ./logs/ 

In [None]:
BATCH_SIZE = 32
IMG_SIZE = (256, 256)

# Options shared by both subsets. The split fraction and seed are defined once
# so the training and validation calls cannot drift apart.
split_options = dict(
    validation_split=0.1,
    seed=1337,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)

# 'png' is the directory read by the loader; presumably one sub-folder per
# class, produced by unzipping png.zip -- confirm against the archive layout.
train_dataset = image_dataset_from_directory('png', subset="training", **split_options)

# Named test_dataset, but this is the held-out "validation" subset of the same folder.
test_dataset = image_dataset_from_directory('png', subset="validation", **split_options)

# Read class_names here, while train_dataset is still the object returned by the
# loader (it is rebound to a prefetched dataset afterwards).
num_classes = len(train_dataset.class_names)

In [None]:
# Let tf.data pick the prefetch buffer size for both pipelines.
AUTOTUNE = tf.data.AUTOTUNE

# Rebind both names to their prefetching counterparts in a single step.
train_dataset, test_dataset = (
    dataset.prefetch(buffer_size=AUTOTUNE)
    for dataset in (train_dataset, test_dataset)
)

In [None]:
# Short alias for the ResNet50 input-preprocessing function; it is applied to the
# model inputs when the network is assembled.
preprocess_input = tf.keras.applications.resnet50.preprocess_input

In [None]:
# Rescaling layer configured with scale 1/127.5 and offset -1 (0..255 -> -1..1,
# assuming 8-bit pixel inputs).
# NOTE(review): no other cell visible here references `rescale`; the model uses
# `preprocess_input` instead -- confirm whether this layer is still needed.
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(scale=1 / 127.5,
                                                               offset=-1)

In [None]:
# Height and width come from IMG_SIZE; the trailing 3 is the channel dimension.
IMG_SHAPE = (*IMG_SIZE, 3)

# ImageNet-initialised ResNet50 backbone built with include_top=False; a custom
# pooling + Dense head is attached in later cells.
base_model = tf.keras.applications.ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=IMG_SHAPE,
)

# Leave every backbone weight trainable (full fine-tuning, no frozen layers).
base_model.trainable = True

In [None]:
# Sanity check: push one training batch through the backbone and print the shape
# of the feature map it returns.
# NOTE(review): the batch is fed without preprocess_input here, unlike in the
# assembled model; that does not affect the printed shape.
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)

In [None]:
# Pooling layer that is reused as part of the final model; it is applied to the
# sample feature batch here only to print the resulting shape.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

In [None]:
# Classification head: one output per class and no activation, so it emits raw
# logits (the loss is configured with from_logits=True).
prediction_layer = tf.keras.layers.Dense(units=num_classes)

# Apply it to the pooled sample features and print the resulting shape.
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

In [None]:
# Assemble the classifier: ResNet50 preprocessing -> backbone -> global average
# pooling -> logits.
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = preprocess_input(inputs)
# The backbone is called WITHOUT an explicit `training` flag. Passing
# `training=True` here bakes training mode into the graph, so BatchNormalization
# would keep using batch statistics (and updating its moving averages) during
# validation and prediction as well. Leaving it unset lets fit() run the backbone
# in training mode and evaluate()/predict() run it in inference mode.
x = base_model(x)
x = global_average_layer(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
# Small step size: every backbone weight is trainable, so the pretrained
# features are being fine-tuned rather than learned from scratch.
base_learning_rate = 0.00001

# `learning_rate` is the supported keyword. The old `lr` alias is deprecated and,
# depending on the TensorFlow/Keras version, is warned about, not honoured, or
# rejected outright.
# from_logits=True matches the activation-free Dense head. The sparse loss expects
# integer class labels -- presumably what the dataset loader yields by default;
# confirm if label_mode is ever changed.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
# Run-specific sub-directory under logs/ (the directory the %tensorboard magic
# is pointed at).
log_dir = "logs/png-resnet50"

# TensorBoard logging callback writing to log_dir, with histogram_freq=1.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

In [None]:
initial_epochs = 10

# Train on the training split, validating on the held-out split and logging to
# TensorBoard; the returned History object is kept in `history`.
history = model.fit(
    x=train_dataset,
    epochs=initial_epochs,
    validation_data=test_dataset,
    callbacks=[tensorboard_callback],
)

In [None]:
# Open the TensorBoard dashboard inline, pointed at the logs/ directory that the
# training callback writes into.
%tensorboard --logdir logs

In [None]:
# Upload the logs of this run (logs/png-resnet50) via `tensorboard dev upload`
# under the experiment name given below.
# NOTE(review): the hosted TensorBoard.dev service may no longer be available --
# confirm this command still works before relying on it.
!tensorboard dev upload \
  --logdir logs/png-resnet50 \
  --name "abstract-learning-in-image-processing-png-resnet50" \
  --one_shot