In [None]:
import platform

from numpy.f2py.auxfuncs import throw_error

from utils import set_seed, create_directory, display_images
import os
import glob
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import timm # For CoAtNet
from tqdm.notebook import tqdm

SEED = 111
set_seed(SEED)  # seeding helper from the local `utils` module

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

In [None]:
# Download dataset if not available
data_path = create_directory("data")

# Point kagglehub's cache at ./data before using it
# (workaround from https://github.com/Kaggle/kagglehub/issues/175).
os.environ['KAGGLEHUB_CACHE'] = data_path
import kagglehub

# Download latest version
utkface_data = kagglehub.dataset_download("jangedoo/utkface-new")

# On Windows the returned path could not be used as-is (path-length limit),
# so first try the expected location inside the cache explicitly ...
utkface_data_path = os.path.join(data_path, "datasets", "jangedoo", "utkface-new", "versions", "1", "UTKFace")
if not os.path.isdir(utkface_data_path):
    # ... and otherwise fall back to the directory kagglehub reported.
    utkface_data_path = os.path.join(utkface_data, "UTKFace")
if not os.path.isdir(utkface_data_path):
    # Fail here with a clear message instead of later with "no images found".
    raise FileNotFoundError(f"UTKFace image directory not found: {utkface_data_path}")


In [None]:
# Pre evaluate dataset
def get_img_ages(dataset_path, image_paths, ages):
    """Collect image paths and the age encoded at the start of each filename.

    Only the leading ``<age>_`` field of a ``*.jpg`` filename is parsed. Files
    with no ``_``, a non-integer age field, or an age outside 0..116 are skipped.

    Args:
        dataset_path: directory holding the ``*.jpg`` files (not searched recursively).
        image_paths: list that accepted file paths are appended to (mutated in place).
        ages: list that the parsed ages are appended to, index-aligned with ``image_paths``.

    Returns:
        ``(parsed_count, skipped_count)``; both are also printed.

    Files are visited in sorted order. ``glob`` order depends on the filesystem,
    and the seeded train/val/test split built from these lists is only
    reproducible if the input order is fixed.
    """
    parsed_files_count = 0
    skipped_files_count = 0
    for filepath in sorted(glob.glob(os.path.join(dataset_path, '*.jpg'))):
        age_field, separator, _ = os.path.basename(filepath).partition('_')
        if not separator:  # no "_" at all -> not in "<age>_..." form
            skipped_files_count += 1
            continue
        try:
            age = int(age_field)
        except ValueError:
            skipped_files_count += 1
            continue
        if not (0 <= age <= 116):
            skipped_files_count += 1
            continue
        image_paths.append(filepath)
        ages.append(age)
        parsed_files_count += 1

    print(f"Successfully parsed {parsed_files_count} image files.")
    print(f"Skipped {skipped_files_count} image files due to parsing issues or invalid data.")
    return parsed_files_count, skipped_files_count

def age_to_class(age):
    if age < 18: return 0
    elif age <= 40: return 1
    elif age <= 60: return 2
    else: return 3

image_paths_all, ages_all = [], []
get_img_ages(utkface_data_path, image_paths_all, ages_all)

if len(image_paths_all) == 0:
    raise FileNotFoundError(f"No valid JPG images found or parsed in {utkface_data_path}.")

# One row per image; age_class buckets the raw age via age_to_class.
df_all = pd.DataFrame({'image_path': image_paths_all, 'age': ages_all})
df_all['age_class'] = df_all['age'].map(age_to_class)
# Position i is the display name of age_class i.
class_names_list = ['<18', '18-40', '41-60', '>60']
class_share = df_all['age_class'].value_counts(normalize=True).sort_index()
print("Class distribution:\n", class_share)

In [None]:
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# Two-stage stratified split: take the test set first, then the validation set
# from the remainder. VAL_RATIO / (1 - TEST_RATIO) rescales the second split so
# that val ends up as VAL_RATIO of the *whole* dataset.
train_val_df, test_df = train_test_split(df_all, test_size=TEST_RATIO, random_state=SEED, stratify=df_all['age_class'])
train_df, val_df = train_test_split(train_val_df, test_size=VAL_RATIO / (1 - TEST_RATIO), random_state=SEED, stratify=train_val_df['age_class'])

total_samples = len(df_all)
count_train = len(train_df)
count_val = len(val_df)
count_test = len(test_df)

# Sanity checks: the splits are disjoint and together cover every sample.
assert count_train + count_val + count_test == total_samples
assert train_df.index.intersection(val_df.index).empty
assert train_df.index.intersection(test_df.index).empty
assert val_df.index.intersection(test_df.index).empty

percent_train = (count_train / total_samples) * 100
percent_val = (count_val / total_samples) * 100
percent_test = (count_test / total_samples) * 100
print(f"Total samples:      {total_samples}")
print(f"Training samples:   {count_train} ({percent_train:.1f}%)")
print(f"Validation samples: {count_val} ({percent_val:.1f}%)")
print(f"Test samples:       {count_test} ({percent_test:.1f}%)")


In [None]:
# Taken from https://github.com/hamkerlab/DL_for_practitioners/blob/main/06_1_SSL_SimCLR/06_1_SSL_SimCLR.ipynb
class LinearLayer(nn.Module):
    """Linear map optionally followed by ``nn.BatchNorm1d``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        use_bias: give the linear map a bias. The bias is omitted whenever
            ``use_bn`` is True (BatchNorm's learnable shift replaces it).
        use_bn: append ``nn.BatchNorm1d(out_features)`` after the linear map.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, in_features, out_features, use_bias=True, use_bn=False, **kwargs):
        super().__init__(**kwargs)
        self.in_features, self.out_features = in_features, out_features
        self.use_bias, self.use_bn = use_bias, use_bn
        linear_has_bias = use_bias and not use_bn
        self.linear = nn.Linear(in_features, out_features, bias=linear_has_bias)
        if use_bn:
            self.bn = nn.BatchNorm1d(out_features)

    def forward(self, x):
        projected = self.linear(x)
        return self.bn(projected) if self.use_bn else projected


# Taken from https://github.com/hamkerlab/DL_for_practitioners/blob/main/06_1_SSL_SimCLR/06_1_SSL_SimCLR.ipynb
class ProjectionHead(nn.Module):
    """SimCLR projection head mapping encoder features ``h`` to projections ``z``.

    Args:
        in_features: dimensionality of the encoder output.
        hidden_features: width of the hidden layer (used only by ``'nonlinear'``).
        out_features: dimensionality of the projection ``z``.
        head_type: ``'linear'`` -> one LinearLayer (no bias, with BatchNorm);
            ``'nonlinear'`` (default) -> LinearLayer(+BN) -> ReLU -> LinearLayer(+BN).
        **kwargs: forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: if ``head_type`` is neither ``'linear'`` nor ``'nonlinear'``.
            Without this check ``self.layers`` would stay undefined and
            ``forward`` would fail later with an AttributeError.
    """

    def __init__(self, in_features, hidden_features, out_features, head_type='nonlinear', **kwargs):
        super().__init__(**kwargs)
        if head_type not in ('linear', 'nonlinear'):
            raise ValueError(f"head_type must be 'linear' or 'nonlinear', got {head_type!r}")
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.head_type = head_type

        if head_type == 'linear':
            self.layers = LinearLayer(in_features, out_features, use_bias=False, use_bn=True)
        else:  # 'nonlinear' (standard for SimCLR)
            self.layers = nn.Sequential(
                LinearLayer(in_features, hidden_features, use_bias=True, use_bn=True),
                nn.ReLU(),
                LinearLayer(hidden_features, out_features, use_bias=False, use_bn=True),
            )

    def forward(self, x):
        return self.layers(x)


# Taken from https://github.com/hamkerlab/DL_for_practitioners/blob/main/06_1_SSL_SimCLR/06_1_SSL_SimCLR.ipynb
class SimCLR_Loss(nn.Module):
    """NT-Xent (normalised temperature-scaled cross-entropy) loss used by SimCLR.

    For a batch of B paired projections ``z_i``/``z_j`` (row k of each comes from
    the same image), every one of the 2B samples treats its partner as the
    positive and the other 2B - 2 samples as negatives. Cosine similarities are
    divided by ``temperature`` before the cross-entropy. The summed loss is
    averaged over the 2B samples.

    ``batch_size`` only sizes the pre-built mask. ``forward`` accepts any batch
    size, so a short trailing batch from a DataLoader without ``drop_last`` no
    longer crashes.
    """

    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature

        # Registered as a (non-persistent) buffer so `.to(device)` moves it too.
        self.register_buffer("mask", self._mask_correlated_samples(batch_size), persistent=False)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def _mask_correlated_samples(self, batch_size, device=None):
        """Boolean (2B, 2B) mask that is False on the diagonal and on positive pairs."""
        N = 2 * batch_size
        mask = ~torch.eye(N, dtype=torch.bool, device=device)
        idx = torch.arange(batch_size, device=device)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def forward(self, z_i, z_j):
        batch_size = z_i.shape[0]
        if z_j.shape[0] != batch_size:
            raise ValueError(f"z_i and z_j must have the same batch size, got {batch_size} and {z_j.shape[0]}")
        N = 2 * batch_size

        z = torch.cat((z_i, z_j), dim=0)
        # (2B, 1, D) vs (1, 2B, D) broadcasts to the full (2B, 2B) similarity matrix.
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        if batch_size == self.batch_size:
            mask = self.mask.to(sim.device)
        else:
            mask = self._mask_correlated_samples(batch_size, device=sim.device)

        # Positives sit on the +B / -B off-diagonals (sample k <-> sample k + B).
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        # Explicit width (not -1) so B == 1 (zero negatives) still reshapes.
        negative_samples = sim[mask].reshape(N, N - 2)

        # The positive is placed in column 0 of every row, so the target class is 0.
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        labels = torch.zeros(N, dtype=torch.long, device=logits.device)
        return self.criterion(logits, labels) / N


In [None]:
# Standard ImageNet RGB normalisation statistics.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

class UTKFaceSimCLRDataset(Dataset):
    """UTKFace dataset yielding two independently augmented views per image.

    Args:
        dataframe: DataFrame with ``'image_path'`` and ``'age_class'`` columns.
        image_size: side length (pixels) of the square output crops.
        s_jitter: colour-jitter strength (only used when ``is_train`` is True).
        is_train: strong SimCLR augmentation if True; otherwise a milder
            pipeline that still yields two different views (the SimCLR loss
            needs a pair even during validation).

    ``__getitem__`` returns ``(x_i, x_j, label)``: two normalised image
    tensors and the integer ``age_class``.
    """

    def __init__(self, dataframe, image_size, s_jitter=0.5, is_train=True):
        self.dataframe = dataframe
        self.image_size = image_size

        normalize = transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
        if is_train:
            # Odd blur kernel of roughly a tenth of the image side (at least 3).
            blur_kernel = max(3, image_size // 20 * 2 + 1)
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(size=image_size, scale=(0.2, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([
                    transforms.ColorJitter(brightness=0.8*s_jitter, contrast=0.8*s_jitter,
                                           saturation=0.8*s_jitter, hue=0.2*s_jitter)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([
                    transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0))
                ], p=0.5),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(size=image_size, scale=(0.8, 1.0)),  # less aggressive crop
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomGrayscale(p=0.1),
                transforms.ToTensor(),
                normalize,
            ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        # Close the underlying file: a bare Image.open keeps the handle open.
        with Image.open(row['image_path']) as img:
            image = img.convert('RGB')
        label = int(row['age_class'])

        x_i = self.transform(image)
        x_j = self.transform(image)
        return x_i, x_j, label


In [None]:
class SimCLRCoAtNetModel(nn.Module):
    """CoAtNet encoder followed by a SimCLR projection head.

    Args:
        coatnet_model_name: timm model name of the encoder (randomly initialised).
        projection_hidden_features: hidden width of the projection head.
        projection_out_features: dimensionality of the projection ``z``.

    ``forward`` returns only ``z``, which is all the contrastive loss needs.
    After pre-training, ``self.encoder`` holds the representation to keep.
    """

    def __init__(self, coatnet_model_name, projection_hidden_features, projection_out_features):
        super().__init__()
        # num_classes=0 asks timm for the model without its classifier, so the
        # encoder outputs feature vectors instead of class logits.
        self.encoder = timm.create_model(coatnet_model_name, pretrained=False, num_classes=0)

        coatnet_feature_dim = self.encoder.num_features
        print(f"CoAtNet ('{coatnet_model_name}') output feature dimension: {coatnet_feature_dim}")

        # The head is applied directly to the encoder features; no extra
        # latent layer is inserted between encoder and projector.
        self.projector = ProjectionHead(
            in_features=coatnet_feature_dim,
            hidden_features=projection_hidden_features,
            out_features=projection_out_features,
        )

    def forward(self, x):
        h = self.encoder(x)  # expected [batch_size, coatnet_feature_dim]
        if h.ndim > 2:
            # flatten (not squeeze) so a batch of size 1 keeps its batch axis
            h = torch.flatten(h, start_dim=1)
        return self.projector(h)  # [batch_size, projection_out_features]

In [None]:
IMAGE_SIZE = 224
BATCH_SIZE = 32

train_dataset = UTKFaceSimCLRDataset(train_df, IMAGE_SIZE, s_jitter=0.5, is_train=True)
val_dataset = UTKFaceSimCLRDataset(val_df, IMAGE_SIZE, is_train=False)
test_dataset = UTKFaceSimCLRDataset(test_df, IMAGE_SIZE, is_train=False)

# Worker processes are disabled on Windows (presumably because spawn-based
# multiprocessing from a notebook is unreliable there).
num_workers = 0 if platform.system() == "Windows" else 2
pin_memory = DEVICE.type == 'cuda'

# drop_last=True keeps every SimCLR batch at exactly BATCH_SIZE. SimCLR_Loss is
# built for that size, and a trailing batch of 1 would make the projection
# head's BatchNorm1d raise in train mode.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers,
                          pin_memory=pin_memory, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers,
                        pin_memory=pin_memory, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers,
                         pin_memory=pin_memory)


In [None]:
# List the CoAtNet variants known to timm. timm registers model names in
# lowercase and matches the pattern glob-style, which is case-sensitive on
# non-Windows platforms, so "CoAtNet*" would match nothing there.
timm.list_models("coatnet*", pretrained=False)

In [None]:
COATNET_MODEL_NAME = "coatnet_0_224"
PROJECTION_HIDDEN_DIM = 512  # hidden width of the projection head
PROJECTION_DIM = 128         # size of the projection z passed to the contrastive loss

model = SimCLRCoAtNetModel(
    coatnet_model_name=COATNET_MODEL_NAME,
    projection_hidden_features=PROJECTION_HIDDEN_DIM,
    projection_out_features=PROJECTION_DIM,
).to(DEVICE)

In [None]:
TEMPERATURE = 0.1
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
MIN_LEARNING_RATE = 0.0  # LR reached at the end of the cosine decay (0 = anneal fully)

EPOCHS = 50
WARMUP_EPOCHS = 5

criterion = SimCLR_Loss(batch_size=BATCH_SIZE, temperature=TEMPERATURE).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# LR schedule, stepped once per epoch by the training loop:
#   1. linear warm-up from LEARNING_RATE / WARMUP_EPOCHS up to LEARNING_RATE
#      over the first WARMUP_EPOCHS epochs;
#   2. cosine decay (no restarts) to MIN_LEARNING_RATE over the remaining
#      EPOCHS - WARMUP_EPOCHS epochs.
# TODO: compare with the recommended warm-up settings in
# https://github.com/hamkerlab/DL_for_practitioners/blob/c80d72b77250a7dee47a9e79182af424faffedea/04_1_ViT/04_1_WarmUpScheduler.ipynb
scheduler_warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0 / WARMUP_EPOCHS if WARMUP_EPOCHS > 0 else 1.0,
    end_factor=1.0,
    total_iters=WARMUP_EPOCHS,
)

scheduler_cosine_decay = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS - WARMUP_EPOCHS,
    eta_min=MIN_LEARNING_RATE,
)

combined_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[scheduler_warmup, scheduler_cosine_decay],
    milestones=[WARMUP_EPOCHS],
)



In [None]:

# --- Training Loop ---
def run_simclr_epoch(loader, train, desc):
    """Run one SimCLR pass over ``loader`` and return the mean per-batch loss.

    With ``train=True`` the model is in train mode and updated through the
    global ``optimizer``. Otherwise it runs in eval mode without gradient
    tracking. Returns 0.0 for an empty loader.
    """
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        pbar = tqdm(loader, desc=desc, leave=False)
        for x_i, x_j, _ in pbar:
            x_i, x_j = x_i.to(DEVICE), x_j.to(DEVICE)
            if train:
                optimizer.zero_grad()
            # Each view gets its own forward pass.
            z_i = model(x_i)
            z_j = model(x_j)
            loss = criterion(z_i, z_j)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            pbar.set_postfix({"Loss": loss.item()})
    return total_loss / len(loader) if len(loader) > 0 else 0.0


print(f"Starting SimCLR pre-training with {COATNET_MODEL_NAME} for {EPOCHS} epochs...")
train_losses, val_losses = [], []

best_val_loss = float('inf')
best_epoch = -1
PATH_TO_BEST_ENCODER_WEIGHTS = f'./{COATNET_MODEL_NAME}_simclr_encoder_best_val_loss.pth'

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    epoch_label = f"Epoch {epoch+1}/{EPOCHS}"

    avg_train_loss = run_simclr_epoch(train_loader, train=True, desc=f"{epoch_label} [Train]")
    train_losses.append(avg_train_loss)

    avg_val_loss = run_simclr_epoch(val_loader, train=False, desc=f"{epoch_label} [Val]")
    val_losses.append(avg_val_loss)

    # The LR schedule is defined in epochs, so step once per epoch.
    combined_scheduler.step()

    current_lr = optimizer.param_groups[0]['lr']
    epoch_time_taken = (time.time() - epoch_start_time) / 60

    print(f"Epoch [{epoch+1}/{EPOCHS}] - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f} - LR: {current_lr:.6f} - Time: {epoch_time_taken:.2f} min")

    # Keep only the encoder: the projection head is discarded after pre-training.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        torch.save(model.encoder.state_dict(), PATH_TO_BEST_ENCODER_WEIGHTS)
        print(f"Epoch {epoch+1}: New best SSL validation loss: {avg_val_loss:.4f}. Saved encoder to {PATH_TO_BEST_ENCODER_WEIGHTS}")

    # Periodic snapshot every 10 epochs and at the end.
    if (epoch + 1) % 10 == 0 or (epoch + 1) == EPOCHS:
        save_path = f'./{COATNET_MODEL_NAME}_simclr_epoch_{epoch+1}.pth'
        torch.save(model.encoder.state_dict(), save_path)
        print(f"Saved pre-trained encoder to {save_path}")

print("SimCLR Pre-training finished!")

# Loss curves; the x-axis is the 1-based epoch number.
fig, ax = plt.subplots(figsize=(10, 5))
epoch_numbers = range(1, len(train_losses) + 1)
ax.plot(epoch_numbers, train_losses, label='Training Loss')
ax.plot(epoch_numbers, val_losses, label='Validation Loss')
ax.set(xlabel='Epochs', ylabel='Loss', title='SimCLR Pre-training Loss')
ax.legend()
ax.grid(True)
fig.savefig(f'{COATNET_MODEL_NAME}_simclr_loss_plot.png')
plt.show()



In [None]:
# taken from https://github.com/hamkerlab/DL_for_practitioners/blob/c80d72b77250a7dee47a9e79182af424faffedea/Utils/plotting.py
def visualize_embeddings_tsne(embeddings: np.ndarray | torch.Tensor,
                              labels: np.ndarray | torch.Tensor,
                              output_dir: str | None,
                              class_names: list[str],
                              n_samples: int = 2000,
                              num_components: int = 2,
                              title_suffix: str = "") -> None:
    """Project embeddings with t-SNE and scatter-plot them coloured by label.

    Args:
        embeddings: (num_points, dim) feature matrix (tensor or array).
        labels: integer class label per point.
        output_dir: directory to save the PNG into; ``None`` shows the plot instead.
        class_names: legend name for each integer label (falls back to "Class k").
        n_samples: at most this many points are used (seeded random subsample).
        num_components: t-SNE output dimensions; only the first two are plotted.
        title_suffix: appended to the title and (spaces -> "_") to the filename.

    Best effort: a missing scikit-learn or any plotting error is reported and
    the function returns without raising.
    """
    try:
        from sklearn.manifold import TSNE
    except ImportError:
        print("scikit-learn not installed, skipping embedding visualization for t-SNE.")
        return

    try:
        if torch.is_tensor(embeddings):
            embeddings = embeddings.cpu().numpy()
        if torch.is_tensor(labels):
            labels = labels.cpu().numpy()

        if len(embeddings) > n_samples:
            print(f"Subsampling {n_samples} out of {len(embeddings)} for t-SNE.")
            # Seeded generator so the plotted subset is reproducible.
            indices = np.random.default_rng(SEED).choice(len(embeddings), n_samples, replace=False)
            embeddings = embeddings[indices]
            labels = labels[indices]

        if len(embeddings) < 2:
            print("Need at least 2 points for t-SNE, skipping.")
            return

        print("Applying t-SNE... (this may take a while)")
        # scikit-learn requires perplexity < n_samples. n_iter is not passed: it
        # was renamed to max_iter in newer releases, and both default to 1000.
        tsne = TSNE(n_components=num_components, random_state=SEED,
                    perplexity=min(30.0, len(embeddings) - 1),
                    init='pca', learning_rate='auto')
        embeddings_2d = tsne.fit_transform(embeddings)
        print("t-SNE done.")

        fig, ax = plt.subplots(figsize=(12, 10))
        for class_label_val in np.unique(labels):
            class_idx = int(class_label_val)
            class_name_str = class_names[class_idx] if class_idx < len(class_names) else f"Class {class_idx}"
            mask = labels == class_label_val
            ax.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1], label=class_name_str, alpha=0.7)

        ax.set_title(f't-SNE Visualization of Encoder Embeddings {title_suffix}')
        ax.set_xlabel("t-SNE Component 1")
        ax.set_ylabel("t-SNE Component 2")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            suffix_slug = title_suffix.replace(" ", "_")
            out_path = os.path.join(output_dir, f"embeddings_tsne{suffix_slug}.png")
            fig.savefig(out_path, dpi=300)
            print(f"t-SNE plot saved to {out_path}")
            plt.close(fig)  # avoid piling up open figures when called in a loop
        else:
            plt.show()

    except Exception as e:
        print(f"An error occurred during t-SNE visualization: {e}")

In [None]:
print("\nProceeding to t-SNE visualization...")

# Same architecture as SimCLRCoAtNetModel.encoder (timm model without classifier).
feature_extractor_for_tsne = timm.create_model(COATNET_MODEL_NAME, pretrained=False, num_classes=0)

if not os.path.exists(PATH_TO_BEST_ENCODER_WEIGHTS):
    print(f"No encoder checkpoint found at {PATH_TO_BEST_ENCODER_WEIGHTS}; skipping t-SNE.")
else:
    feature_extractor_for_tsne.load_state_dict(torch.load(PATH_TO_BEST_ENCODER_WEIGHTS, map_location=DEVICE))
    print(f"Loaded best encoder weights from {PATH_TO_BEST_ENCODER_WEIGHTS} for t-SNE.")
    feature_extractor_for_tsne.to(DEVICE)
    feature_extractor_for_tsne.eval()

    if val_df is not None and not val_df.empty:
        tsne_df = val_df
        tsne_sample_size = len(tsne_df)  # use every validation sample
        print(f"Using {len(tsne_df)} samples from validation set for t-SNE.")

        feature_dataset = UTKFaceSimCLRDataset(tsne_df, IMAGE_SIZE, is_train=False)
        # Swap the random SimCLR views for a deterministic transform so each
        # embedding reflects the image, not a random crop/flip/grayscale.
        feature_dataset.transform = transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
        ])
        feature_loader = DataLoader(feature_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

        all_features = []
        all_labels_for_tsne = []

        with torch.no_grad():
            # The dataset yields (x_i, x_j, label); with the deterministic
            # transform both views are identical, so only x_i is used.
            for images, _, labels_batch in tqdm(feature_loader, desc="Extracting features for t-SNE"):
                features = feature_extractor_for_tsne(images.to(DEVICE))  # encoder features h
                if features.ndim > 2:
                    features = torch.flatten(features, start_dim=1)
                all_features.append(features.cpu())
                all_labels_for_tsne.append(labels_batch.cpu())

        if all_features:
            visualize_embeddings_tsne(
                embeddings=torch.cat(all_features, dim=0),
                labels=torch.cat(all_labels_for_tsne, dim=0),
                output_dir="./tsne_plots",
                class_names=class_names_list,
                n_samples=tsne_sample_size,
                # The loaded checkpoint is the best-validation-loss encoder,
                # not necessarily the one from the final epoch.
                title_suffix=" best-val-loss SimCLR encoder",
            )
        else:
            print("No features extracted for t-SNE.")
    else:
        print("Validation DataFrame (val_df) is empty or None!! ")