In [None]:
from torchvision import transforms
import torchvision.transforms.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch

import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

import random

In [None]:
seed = 42

# Seed every random number generator the notebook draws from (Python's
# `random`, NumPy's legacy global RNG, torch on CPU and on every CUDA device)
# so the random rotations and DataLoader shuffling used later repeat run-to-run.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(seed)

# Ask cuDNN for deterministic kernels and disable its autotuner, trading some
# GPU speed for reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
class ImageDataset(Dataset):
    """Dataset of single-channel images served as 3-channel float tensors.

    Each image is tiled from 1 to 3 channels so it can be fed to an RGB backbone.
    When a ``transform`` is supplied, each item additionally carries a transformed
    copy of the image, giving ``(image, augmented_image, label)``; otherwise the
    item is ``(image, label)``.
    """

    def __init__(self, image_data, labels, transform=None):
        """
        Args:
            image_data: Array-like (e.g. a NumPy array) of images, expected shape
                (N, 1, 128, 128). Copied into a float32 tensor.
            labels: Array-like of shape (N,) with integer class labels. Copied into
                an int64 tensor.
            transform (callable, optional): Applied to the 3-channel image to
                produce the augmented view.
        """
        self.image_data = torch.tensor(image_data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        # Repeat the grey channel three times -> (3, H, W).
        rgb = self.image_data[idx].repeat(3, 1, 1)
        target = self.labels[idx]

        if not self.transform:
            return rgb, target

        return rgb, self.transform(rgb), target

class RandomRotation:
    """Rotate an image by a uniformly random angle, with probability ``p``.

    Args:
        degrees: The angle is drawn from ``[-degrees, degrees]``.
        p: Probability of applying the rotation; otherwise the input is
            returned unchanged (the same object).
    """

    def __init__(self, degrees, p=0.5):
        self.degrees, self.p = degrees, p

    def __call__(self, img):
        # One draw from Python's `random` decides whether to rotate at all.
        should_rotate = random.random() < self.p
        if not should_rotate:
            return img
        angle = random.uniform(-self.degrees, self.degrees)
        return F.rotate(img, angle)


# Augmentation that produces the second "view" of each sample: always (p=1.0)
# rotate by a uniformly random angle in [-30, 30] degrees.
transform = transforms.Compose([RandomRotation(degrees=30, p=1.0)])



class SupConLoss(nn.Module):
    """Supervised Contrastive (SupCon) loss (Khosla et al., 2020).

    For every anchor ``i`` with L2-normalised feature ``z_i``:

        l_i = -(T / T_base) * mean_{p in P(i)} [ z_i.z_p / T
                                                 - log sum_{a != i} exp(z_i.z_a / T) ]

    where ``P(i)`` is the set of *other* samples sharing anchor ``i``'s label.
    The anchor itself is excluded from the denominator. The loss is the mean of
    ``l_i`` over anchors that have at least one positive.
    """

    def __init__(self, temperature=0.07, base_temperature=0.07):
        """
        Args:
            temperature: Scaling parameter for cosine similarity.
            base_temperature: Baseline temperature; the loss is multiplied by
                ``temperature / base_temperature``.
        """
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature

    def forward(self, features, labels):
        """
        Args:
            features: Tensor of shape [batch_size, n_views, ...] or
                [batch_size, dim]. A 2-D input is treated as one view per sample;
                trailing dimensions are flattened.
            labels: Ground truth of shape [batch_size].
        Returns:
            A loss scalar, averaged over anchors with at least one positive.
            If no anchor has a positive, a zero that is still attached to the
            autograd graph is returned instead of NaN.
        """
        device = features.device

        if features.dim() < 3:
            features = features.unsqueeze(1)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        n_views = features.shape[1]
        # [batch_size, n_views, dim] -> [n_views * batch_size, dim], view-major.
        features = torch.cat(torch.unbind(features, dim=1), dim=0)

        # Repeat labels in the same view-major order as the stacked features.
        labels = labels.contiguous().view(-1, 1).to(device).repeat(n_views, 1)

        features = nn.functional.normalize(features, dim=1)
        logits = torch.matmul(features, features.T) / self.temperature

        n = features.shape[0]
        self_mask = torch.eye(n, dtype=torch.bool, device=device)

        # Denominator: log-sum-exp over every sample except the anchor itself.
        # logsumexp is numerically stable, so no manual max-subtraction is needed,
        # and masking self with -inf removes the exp(0) term the original kept.
        log_denom = torch.logsumexp(
            logits.masked_fill(self_mask, float("-inf")), dim=1, keepdim=True
        )
        log_prob = logits - log_denom

        # Positives: same label, excluding the anchor itself.
        mask_pos = (torch.eq(labels, labels.T) & ~self_mask).float()
        pos_counts = mask_pos.sum(1)

        # Anchors without any positive would give 0/0 = NaN; skip them.
        valid = pos_counts > 0
        if not valid.any():
            return features.sum() * 0.0

        mean_log_prob_pos = (mask_pos * log_prob).sum(1)[valid] / pos_counts[valid]

        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss.mean()



In [None]:
# Backbone: ResNet-50 with a 100-way linear head, initialised from the base
# checkpoint. torchvision expects `weights=None` (not False) for no pretrained weights.
model = models.resnet50(weights=None)

resnet50 = torch.nn.Sequential(*list(model.children())[:-1], nn.Flatten(), nn.Linear(2048, 100))
# map_location lets a checkpoint saved on GPU load on a CPU-only machine too.
resnet50.load_state_dict(torch.load("models/resnet50_basemodel.pth", map_location="cpu"))

# Drop the linear head: the contrastive loss works on the 2048-d pooled features.
resnet50 = resnet50[:-1]

# Load data
train_data = np.load("data/train.npz")
test_data = np.load("data/test.npz")
unseen_test_data = np.load("data/test_unseen.npz")

# Create datasets and dataloaders. Train/test items are (image, rotated view, label);
# the unseen set has no transform, so its items are (image, label).
train_dataset = ImageDataset(train_data["data"], train_data["labels"], transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = ImageDataset(test_data["data"], test_data["labels"], transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

unseen_test_dataset = ImageDataset(unseen_test_data["data"], unseen_test_data["labels"])
unseen_test_dataloader = DataLoader(unseen_test_dataset, batch_size=32, shuffle=False)

# Define loss function and optimizer
criterion = SupConLoss()  # supervised contrastive loss on the embeddings
optimizer = optim.Adam(resnet50.parameters(), lr=1e-4)

# Training loop
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)  # Move model to GPU if available

train_losses = list()
test_losses = list()
test_accs = list()  # NOTE(review): never filled in this cell; kept because later cells reference it
unseen_test_accs = list()

patience = 5
best_loss = float('inf')
patience_counter = 0
best_model_path = "models/resnet50_fintune_100_256.pth"
best_state = None

for epoch in range(num_epochs):
    # Training phase
    resnet50.train()  # Set model to training mode
    running_loss = 0.0
    for anchors, augmented, labels in train_dataloader:
        # Both views go through the network in one batch; duplicating the labels
        # makes each view the other's positive in SupConLoss.
        inputs = torch.cat([anchors, augmented]).to(device)
        labels = torch.cat([labels, labels]).to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        embeddings = resnet50(inputs)
        loss = criterion(embeddings, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Weight by samples (anchors), not inputs: `inputs` holds two views per
        # sample, so weighting by inputs.size(0) doubled the reported epoch loss.
        running_loss += loss.item() * anchors.size(0)

    # Calculate training loss for the epoch
    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}")

    # Evaluation phase
    resnet50.eval()  # Set model to evaluation mode
    test_loss = 0.0

    with torch.no_grad():  # Disable gradient computation
        # Evaluate on the *test* loader (it previously iterated the training data
        # while still dividing by len(test_dataset)).
        for anchors, augmented, labels in test_dataloader:
            inputs = torch.cat([anchors, augmented]).to(device)
            labels = torch.cat([labels, labels]).to(device)

            embeddings = resnet50(inputs)
            loss = criterion(embeddings, labels)

            test_loss += loss.item() * anchors.size(0)

    test_loss = test_loss / len(test_dataset)

    # 1-nearest-neighbour accuracy on the unseen set: a sample counts as correct
    # when its closest other sample in embedding space has the same label.
    all_embeds = []
    unseen_labels = []

    with torch.no_grad():
        for inputs, labels in unseen_test_dataloader:
            all_embeds.append(resnet50(inputs.to(device)))
            unseen_labels.append(labels.to(device))

    all_embeds = torch.cat(all_embeds)
    unseen_labels = torch.cat(unseen_labels)

    distances = torch.cdist(all_embeds, all_embeds)
    distances.fill_diagonal_(float("inf"))  # a sample must not match itself
    nearest = distances.argmin(dim=1)
    unseen_test_accuracy = (unseen_labels[nearest] == unseen_labels).float().mean().item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Test Loss: {test_loss:.4f}, Unseen Data Accuracy: {unseen_test_accuracy:.4f}")
    print()

    train_losses.append(epoch_loss)
    test_losses.append(test_loss)
    unseen_test_accs.append(unseen_test_accuracy)

    # Check for early stopping
    if test_loss < best_loss:
        best_loss = test_loss
        patience_counter = 0
        torch.save(resnet50.state_dict(), best_model_path)
        # Keep an in-memory copy so the best weights can be restored below.
        best_state = {k: v.detach().clone() for k, v in resnet50.state_dict().items()}
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered")
            break

# Restore the best-scoring weights; otherwise `resnet50` would hold the weights
# from the last epoch, up to `patience` epochs past the best one.
if best_state is not None:
    resnet50.load_state_dict(best_state)

print("Training complete!")

In [None]:
# Pickle the whole module (architecture + weights). Loading it back requires the
# same class definitions importable. NOTE(review): the file name says "triplet",
# but this model is trained with SupConLoss — confirm downstream consumers before renaming.
torch.save(obj=resnet50, f="models/resnet50_triplet")

In [None]:
# Per-epoch SupCon loss on the train and test sets (the old y-label wrongly said cross-entropy).
fig, ax = plt.subplots()
ax.plot(train_losses, label="Train loss")
ax.plot(test_losses, label="Test loss")
ax.set(
    title="Train and Test loss during training.",
    ylabel="Supervised Contrastive Loss",
    xlabel="Epoch",
)
ax.legend()
plt.show()

In [None]:
# Per-epoch 1-nearest-neighbour accuracy on the unseen set. The training loop only
# fills `unseen_test_accs` (`test_accs` stays empty), so plotting it gave a blank chart
# mislabelled "Train loss".
fig, ax = plt.subplots()
ax.plot(unseen_test_accs, label="Unseen data 1-NN accuracy")
ax.set(
    title="Unseen data accuracy during training.",
    ylabel="Accuracy",
    xlabel="Epoch",
)
ax.legend()
plt.show()