<a href="https://colab.research.google.com/github/Chen-Terese/CNN-SOM-code/blob/main/Thesis2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# DATA_PATH lives under /content/drive, so Google Drive must be mounted before
# the dataset cell runs. On Colab this mounts (or reports "already mounted");
# outside Colab the import fails and the data must be available locally.
try:
    from google.colab import drive
except ImportError:
    drive = None  # not running on Colab

if drive is not None:
    drive.mount('/content/drive')

# **Required Libraries**

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# **Loading the Data**

In [None]:
# Run configuration.
IMG_SIZE = 224      # every image is resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 8      # images per batch (each yields two augmented views)
EPOCHS = 20         # passes over the dataset during SimCLR pre-training
# Root folder of the MRI images (on the mounted Google Drive).
DATA_PATH = '/content/drive/MyDrive/MRI dataset/Alzheimer_MRI_4_classes_dataset'

In [None]:
# Unlabelled image stream for self-supervised pre-training: label_mode=None
# yields image batches only (no labels), with file order shuffled.
dataset = tf.keras.utils.image_dataset_from_directory(
    directory=DATA_PATH,
    image_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    label_mode=None,
    shuffle=True,
)

Found 6400 files.


# **Data Augmentation**

In [None]:
def augment(image):
    """Produce one randomly flipped, single-channel view of an image.

    Args:
        image: a single image from `image_dataset_from_directory`, expected
            to be (IMG_SIZE, IMG_SIZE, 3).

    Returns:
        A (IMG_SIZE, IMG_SIZE, 1) grayscale tensor, randomly flipped
        left/right and then randomly flipped up/down.
    """
    gray = tf.image.rgb_to_grayscale(image)  # (H, W, 3) -> (H, W, 1)
    flipped = tf.image.random_flip_up_down(
        tf.image.random_flip_left_right(gray)
    )
    # Pin the static shape so batching and the model see fully known dims.
    return tf.reshape(flipped, [IMG_SIZE, IMG_SIZE, 1])

In [None]:
def prepare_simclr_dataset(input_dataset):
    """Build the SimCLR input pipeline of paired augmented views.

    The batched `input_dataset` is split into single images. Each image is
    augmented twice independently to form a positive pair. The pairs are
    re-batched with BATCH_SIZE and prefetched.

    Args:
        input_dataset: a batched `tf.data.Dataset` of images.

    Returns:
        A `tf.data.Dataset` yielding `(view1_batch, view2_batch)` tuples.
    """
    def _two_views(image):
        # Two independent draws of the random augmentation -> one positive pair.
        return augment(image), augment(image)

    pairs = input_dataset.unbatch().map(
        _two_views, num_parallel_calls=tf.data.AUTOTUNE
    )
    return pairs.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# **Building the ResNet50 Model**

In [None]:
# Randomly initialised ResNet50 trunk: no pretrained weights, no classifier
# head, single-channel IMG_SIZE x IMG_SIZE input.
base = tf.keras.applications.ResNet50(
    weights=None,
    include_top=False,
    input_shape=(IMG_SIZE, IMG_SIZE, 1),
)

In [None]:
# Stop aggressive downsampling: run the stem convolution ("conv1_conv") with
# stride 1 instead of 2.
#
# Assigning `layer.strides` on an already-built functional model is not
# reliable. The graph (and its recorded shapes) was traced with stride 2, and
# whether the new value is honoured at call time depends on the Keras version.
# Rebuilding the model from per-layer configs bakes the new stride into the graph.
def _stem_stride_one(layer):
    """Clone `layer` from its config, forcing stride 1 on the stem convolution."""
    config = layer.get_config()
    if layer.name == "conv1_conv":
        config["strides"] = (1, 1)
    return layer.__class__.from_config(config)

_rebuilt = tf.keras.models.clone_model(base, clone_function=_stem_stride_one)
# Changing a stride does not change any kernel shape, so the weights transfer
# one-to-one (random init here, but this also keeps pretrained weights working).
_rebuilt.set_weights(base.get_weights())
base = _rebuilt
assert tuple(base.get_layer("conv1_conv").strides) == (1, 1)

# Encoder = ResNet trunk + global average pooling -> one feature vector per image.
x = base.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
encoder = tf.keras.Model(base.input, x, name="MRI_Encoder")

# **SimCLR Projection Head**

In [None]:
# Projection MLP applied on top of the encoder features: a 512-unit ReLU
# hidden layer followed by a linear 128-d output (the space the loss uses).
proj_head = tf.keras.Sequential(name="projection")
proj_head.add(tf.keras.layers.Dense(512, activation="relu"))
proj_head.add(tf.keras.layers.Dense(128))

In [None]:
# End-to-end SimCLR network: grayscale image -> encoder features -> projection.
inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 1))
features = encoder(inputs)
projections = proj_head(features)
simclr_model = tf.keras.Model(inputs=inputs, outputs=projections)

# **Loss**

In [None]:
# NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR.
def simclr_loss(z_i, z_j, temperature=0.1):
    """NT-Xent contrastive loss for a batch of paired projections.

    Row k of `z_i` and row k of `z_j` are the two views of the same image
    (the positive pair). Every other row of the combined 2N-row batch acts
    as a negative.

    Args:
        z_i: projections of the first views, assumed (N, d).
        z_j: projections of the second views, same shape as `z_i`.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar mean cross-entropy over all 2N anchors.
    """
    z_i = tf.math.l2_normalize(z_i, axis=1)
    z_j = tf.math.l2_normalize(z_j, axis=1)
    z = tf.concat([z_i, z_j], axis=0)                             # (2N, d)
    similarity = tf.matmul(z, z, transpose_b=True) / temperature  # (2N, 2N)

    # Exclude each anchor's similarity with itself. Otherwise the diagonal
    # entry (cosine 1, the largest possible logit) competes with the positive,
    # and the loss can never fall below ln(2) ~= 0.693.
    two_n = tf.shape(similarity)[0]
    similarity = tf.linalg.set_diag(
        similarity, tf.fill([two_n], tf.cast(-1e9, similarity.dtype))
    )

    # Anchor k (first view) has its positive at column k + N, and vice versa.
    batch_size = tf.shape(z_i)[0]
    labels = tf.range(batch_size)
    labels = tf.concat([labels + batch_size, labels], axis=0)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(labels, similarity)
    return loss

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)  # Adam, lr = 3e-4

In [None]:
@tf.function
def train_step(view1, view2):
    """Run one SimCLR optimisation step on a batch of paired views.

    Both views go through the shared `simclr_model` in training mode. The
    contrastive loss is computed on the two projection batches, and
    `optimizer` applies the gradients to every trainable variable.

    Returns:
        The scalar loss for this batch (computed before the update).
    """
    with tf.GradientTape() as tape:
        projections_1 = simclr_model(view1, training=True)
        projections_2 = simclr_model(view2, training=True)
        loss = simclr_loss(projections_1, projections_2)
    variables = simclr_model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

# **Training**

In [None]:
# ================== TRAINING STARTS HERE ==================
print("Training in progress...")

# Turn the raw image stream into batches of (view1, view2) positive pairs.
simclr_training_dataset = prepare_simclr_dataset(dataset)

for epoch in range(EPOCHS):
    # Running mean of the per-step loss, reset at the start of every epoch.
    epoch_loss_avg = tf.keras.metrics.Mean()
    for view1, view2 in simclr_training_dataset:
        loss = train_step(view1, view2)
        epoch_loss_avg.update_state(loss)
    print(f"Epoch: {epoch + 1}, loss: {epoch_loss_avg.result().numpy()}")

Training in progress...
Epoch: 1, loss: 0.8377034068107605
Epoch: 2, loss: 0.738017737865448
Epoch: 3, loss: 0.7256376147270203
Epoch: 4, loss: 0.7225548624992371
Epoch: 5, loss: 0.7152161598205566
Epoch: 6, loss: 0.7131702303886414
Epoch: 7, loss: 0.7148420810699463
Epoch: 8, loss: 0.7138302326202393
Epoch: 9, loss: 0.7267529368400574
