In [1]:
from torchvision import transforms
import torchvision.transforms.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch

import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

import random

In [2]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
class TripletImageDataset(Dataset):
    def __init__(self, image_data, labels, transform=None, num_triplets=None):
        """
        Args:
            image_data (torch.Tensor): A tensor of shape (N, 1, 128, 128) containing the images.
            labels (torch.Tensor): A tensor of shape (N,) containing the labels.
            transform (callable, optional): Optional transform to be applied on a sample.
            num_triplets (int, optional): Number of triplets to generate. If None, generates all possible triplets.
        """
        self.image_data = torch.tensor(image_data, dtype=torch.float32).repeat(1, 3, 1, 1)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform
        self.num_triplets = num_triplets if num_triplets else self.image_data.size(0)

        # Generate triplets
        self.triplets = self._generate_triplets(num_triplets)

    def _generate_triplets(self, num_triplets):
        triplets = []
        unique_labels = torch.unique(self.labels).tolist()
        
        for _ in range(num_triplets):
            label1, label2 = random.sample(unique_labels, 2)
            positive_indices = torch.where(self.labels == label1)[0].tolist()
            negative_indices = torch.where(self.labels == label2)[0].tolist()

            anchor, positive = random.sample(positive_indices, 2)
            negative = random.choice(negative_indices)
            triplets.append((anchor, positive, negative))

        return triplets

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        anchor_idx, positive_idx, negative_idx = self.triplets[idx]

        anchor = self.image_data[anchor_idx]
        positive = self.image_data[positive_idx]
        negative = self.image_data[negative_idx]

        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        return anchor, positive, negative
    

class ImageDataset(Dataset):
    def __init__(self, image_data, labels, transform=None):
        """
        Args:
            image_data (torch.Tensor): A tensor of shape (N, 1, 128, 128) containing the images.
            labels (torch.Tensor): A tensor of shape (N,) containing the labels.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_data = torch.tensor(image_data, dtype = torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        image = self.image_data[idx]  # Shape: (1, 128, 128)
        label = self.labels[idx]

        # Convert single-channel image to 3-channel by copying the channel
        image = image.repeat(3, 1, 1)  # Shape: (3, 128, 128)

        if self.transform:
            image = self.transform(image)

        return image, label

class RandomRotation:
    def __init__(self, degrees, p=0.5):
        self.degrees = degrees
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            angle = random.uniform(-self.degrees, self.degrees)
            img = F.rotate(img, angle)
        return img

class RandomGaussianBlur:
    def __init__(self, kernel_size, sigma=(0.1, 2.0), p=0.5):
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            img = F.gaussian_blur(img, self.kernel_size, self.sigma)
        return img

class RandomNoise:
    def __init__(self, mean=0, std=0.1, p=0.5):
        self.mean = mean
        self.std = std
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            noise = torch.randn_like(img) * self.std + self.mean
            img = img + noise
            img = torch.clamp(img, 0, 1)  # Clamp values to [0, 1] range
        return img

# Define the transformations
transform = transforms.Compose([
    RandomRotation(degrees=30, p=0.5),
])

def triplet_loss(anchor, positive, negative, margin: float = None): 
    """
    Computes Triplet Loss with a dynamic margin if not provided.
    """
    # Compute distances
    distance_positive = torch.norm(anchor - positive, p=2, dim=1)
    distance_negative = torch.norm(anchor - negative, p=2, dim=1)
    
    # Default margin selection based on statistics
    mean_pos, std_pos = distance_positive.mean(), distance_positive.std()
    mean_neg, std_neg = distance_negative.mean(), distance_negative.std()
    
    margin = (mean_neg - mean_pos) - (std_pos + std_neg)
    
    # Compute loss
    loss = torch.clamp(distance_positive - distance_negative + margin, min=0.0)
    return torch.mean(loss)


In [None]:
model = models.resnet50(weights = False)

resnet50 = torch.nn.Sequential(*list(model.children())[:-1], nn.Flatten(), nn.Linear(2048, 100))
resnet50.load_state_dict(torch.load("models/resnet50_basemodel.pth"))

resnet50 = resnet50[:-1]

# Load data
train_data = np.load("data/train.npz")
test_data = np.load("data/test.npz")
unseen_test_data = np.load("data/test_unseen.npz")

# Create datasets and dataloaders
train_dataset = TripletImageDataset(train_data["data"], train_data["labels"], transform=transform, num_triplets = 1000)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = TripletImageDataset(test_data["data"], test_data["labels"], transform=transform, num_triplets=1000)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

unseen_test_dataset = ImageDataset(unseen_test_data["data"], unseen_test_data["labels"], transform=transform)
unseen_test_dataloader = DataLoader(unseen_test_dataset, batch_size=32, shuffle=False)

# Define loss function and optimizer
criterion = triplet_loss  # For classification tasks
optimizer = optim.Adam(resnet50.parameters(), lr=1e-7)

# Training loop
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)  # Move model to GPU if available

train_losses = list()
test_losses = list()
test_accs = list()
unseen_test_accs = list()

patience = 5
best_loss = float('inf')
patience_counter = 0

for epoch in range(num_epochs):
    # Training phase
    resnet50.train()  # Set model to training mode
    running_loss = 0.0
    batch = 0
    for anchor, positive, negative in train_dataloader:
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        anchor_embedding, pos_embedding, neg_embedding = resnet50(anchor), resnet50(positive), resnet50(negative)
        loss = criterion(anchor_embedding, pos_embedding, neg_embedding, batch)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Accumulate loss
        running_loss += loss.item() * anchor.size(0)
        batch += 1


    # Calculate training loss for the epoch
    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}")

    # Evaluation phase
    resnet50.eval()  # Set model to evaluation mode
    test_loss = 0.0

    with torch.no_grad():  # Disable gradient computation
        for anchor, positive, negative in test_dataloader:
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

            # Forward pass
            anchor_embedding, pos_embedding, neg_embedding = resnet50(anchor), resnet50(positive), resnet50(negative)
            loss = criterion(anchor_embedding, pos_embedding, neg_embedding, batch)

            # Accumulate test loss
            test_loss += loss.item() * anchor.size(0)
            
    
    all_embeds = []
    unseen_labels = []

    with torch.no_grad():  # Disable gradient computation
        for inputs, labels in unseen_test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            embeds = resnet50(inputs)

            all_embeds.append(embeds)
            unseen_labels.append(labels)

    all_embeds = torch.cat(all_embeds)
    unseen_labels = torch.cat(unseen_labels)

    distances = torch.cdist(all_embeds, all_embeds)

    correct_pred = 0
    for i in range(all_embeds.size(0)):
        distances[i,i] = float("inf")
        closest_inx = distances[i].argmin()
        if unseen_labels[i] == unseen_labels[closest_inx]:
            correct_pred += 1

    unseen_test_accuracy = correct_pred / all_embeds.size(0)


    # Calculate test loss and accuracy
    test_loss = test_loss / len(test_dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Test Loss: {test_loss:.4f}, Unseen Data Accuracy: {unseen_test_accuracy:.4f}")
    print()

    train_losses.append(epoch_loss)
    test_losses.append(test_loss)
    unseen_test_accs.append(unseen_test_accuracy)

    # Check for early stopping
    if test_loss < best_loss:
        best_loss = test_loss
        patience_counter = 0
        torch.save(resnet50.state_dict(), "models/resnet50_triplet.pth")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered")
            break


print("Training complete!")

  resnet50.load_state_dict(torch.load("models/resnet50_basemodel.pth"))


Epoch [1/100], Train Loss: 0.1628
Epoch [1/100], Test Loss: 0.1251, Unseen Data Accuracy: 0.5300

Epoch [2/100], Train Loss: 0.1667
Epoch [2/100], Test Loss: 0.1235, Unseen Data Accuracy: 0.5550

Epoch [3/100], Train Loss: 0.1558
Epoch [3/100], Test Loss: 0.1452, Unseen Data Accuracy: 0.5150

Epoch [4/100], Train Loss: 0.1406
Epoch [4/100], Test Loss: 0.1285, Unseen Data Accuracy: 0.5100

Epoch [5/100], Train Loss: 0.1264
Epoch [5/100], Test Loss: 0.1221, Unseen Data Accuracy: 0.4850

Epoch [6/100], Train Loss: 0.1397
Epoch [6/100], Test Loss: 0.1400, Unseen Data Accuracy: 0.5500

Epoch [7/100], Train Loss: 0.1469
Epoch [7/100], Test Loss: 0.1339, Unseen Data Accuracy: 0.5500

Epoch [8/100], Train Loss: 0.1592
Epoch [8/100], Test Loss: 0.1261, Unseen Data Accuracy: 0.5400

Epoch [9/100], Train Loss: 0.1732
Epoch [9/100], Test Loss: 0.1236, Unseen Data Accuracy: 0.5450

Epoch [10/100], Train Loss: 0.1547
Epoch [10/100], Test Loss: 0.1426, Unseen Data Accuracy: 0.5150

Early stopping tri

In [5]:

torch.save(resnet50.state_dict(), "models/resnet50_triplet.pth")

In [None]:
plt.plot(train_losses, label="Train loss")
plt.plot(test_losses, label="Test loss")
plt.legend()
plt.title("Train and Test loss during training.")
plt.ylabel("Cross Entropy Loss")
plt.xlabel("Epoch")
plt.show()

In [None]:
plt.plot(test_accs, label="Train loss")
plt.legend()
plt.title("Test accuracy during training.")
plt.ylabel("Accuracy")
plt.xlabel("Epoch")
plt.show()