In [None]:
# Mount Google Drive so that the zipped dataset and the weight files under
# /content/drive/MyDrive are reachable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Unpack the dataset archive from Drive onto the runtime's local disk.
# The `with` block guarantees the archive handle is closed even if extraction
# fails part-way. The original manual close() was skipped on an exception.
import zipfile

DATASET_ZIP = "/content/drive/MyDrive/transformed_oct.zip"
EXTRACT_DIR = "/content/"

with zipfile.ZipFile(DATASET_ZIP, 'r') as zip_ref:
    zip_ref.extractall(EXTRACT_DIR)

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Flatten, Dropout, BatchNormalization, ReLU
from tensorflow.keras.models import Model
from tensorflow.keras.applications import InceptionV3, ResNet50
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import Accuracy, Precision, Recall, AUC
from tensorflow.keras.callbacks import ReduceLROnPlateau, Callback

In [None]:
# Split directories inside the extracted archive. flow_from_directory expects
# one sub-folder per class and infers the labels from those folder names.
train_dir = '/content/Downstream Task/train'
val_dir = '/content/Downstream Task/validation'

# Shared loader settings. Validation previously used batch_size=62, which
# looked like a typo for the training value, so both splits now share one constant.
IMG_SIZE = (224, 224)
BATCH_SIZE = 64

# Pixels are only rescaled to [0, 1]; no augmentation is applied to either split.
# NOTE(review): ImageDataGenerator is deprecated in recent TF releases;
# tf.keras.utils.image_dataset_from_directory is the suggested replacement.
train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)

# Training iterator (shuffled every epoch by default).
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Validation iterator. shuffle=False keeps a fixed sample order, so later
# per-sample predictions line up with val_generator.classes.
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False,
)

In [None]:
# Build the downstream classifier. The ResNet50 backbone starts randomly
# initialised (weights=None). Its weights are then overwritten from the
# pretext-task checkpoint below.
# The input shape is taken from the training generator so the network always
# matches the images it is fed. It was hard-coded to (156, 156, 3), while the
# generators yield 224x224 images.
base_model = ResNet50(weights=None, include_top=False,
                      input_shape=train_generator.image_shape)

# Freeze the first N_FROZEN_LAYERS backbone layers and fine-tune the rest.
N_FROZEN_LAYERS = 50
for layer_index, layer in enumerate(base_model.layers):
    layer.trainable = layer_index >= N_FROZEN_LAYERS

# Classification head: pooled features -> 128-unit ReLU -> one softmax unit
# per class discovered by the training generator.
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
output = Dense(train_generator.num_classes, activation='softmax')(x)

downstream_model = Model(inputs=base_model.input, outputs=output)

# Initialise from the pretext-task weights. skip_mismatch=True skips layers
# whose weights do not match this model (presumably the pretext head), so
# only the compatible layers are restored.
downstream_model.load_weights(
    '/content/drive/MyDrive/weights/Full_Model/pretext_task_weights.weights.h5',
    skip_mismatch=True,
)

# Multiply the learning rate by 0.2 when validation accuracy has not improved
# for 2 epochs. mode='max' is explicit because higher accuracy is better.
reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', mode='max', factor=0.2,
                              patience=2, min_lr=1e-6, verbose=1)

# Compile for multi-class classification with one-hot ('categorical') labels.
downstream_model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy', 'precision', 'recall', 'auc'],
)

In [None]:
# Fine-tune on the labelled downstream data. Validation metrics are computed
# each epoch and drive the reduce_lr learning-rate schedule.
EPOCHS = 20
history = downstream_model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=val_generator,
    callbacks=[reduce_lr],
    verbose=1,
)

In [None]:
# Persist the fine-tuned weights to Drive, in the same folder as the pretext
# checkpoint, so they survive the runtime being recycled.
weights_path = '/content/drive/MyDrive/weights/Full_Model/downstream_task_weights.weights.h5'
downstream_model.save_weights(weights_path)
print("Weights saved at:", weights_path)

In [None]:
# Terminate: hard-kill this kernel's own process. SIGKILL cannot be caught or
# ignored. This is presumably done to release the Colab runtime once the
# weights are saved.
# NOTE: "Run All" therefore ends with a dead kernel; keep this as the last cell.
import os
import signal

own_pid = os.getpid()
os.kill(own_pid, signal.SIGKILL)