In [1]:
from torchvision import transforms
import torchvision.transforms.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch

import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

import random

In [2]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
class ImageDataset(Dataset):
    def __init__(self, image_data, labels, transform=None, num_triplets=None):
        """
        Args:
            image_data (torch.Tensor): A tensor of shape (N, 1, 128, 128) containing the images.
            labels (torch.Tensor): A tensor of shape (N,) containing the labels.
            transform (callable, optional): Optional transform to be applied on a sample.
            num_triplets (int, optional): Number of triplets to generate. If None, generates all possible triplets.
        """
        self.image_data = torch.tensor(image_data, dtype=torch.float32).repeat(1, 3, 1, 1)
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform
        self.num_triplets = num_triplets if num_triplets else self.image_data.size(0)

        # Generate triplets
        self.triplets = self._generate_triplets(num_triplets)

    def _generate_triplets(self, num_triplets):
        triplets = []
        unique_labels = torch.unique(self.labels).tolist()
        
        for _ in range(num_triplets):
            label1, label2 = random.sample(unique_labels, 2)
            positive_indices = torch.where(self.labels == label1)[0].tolist()
            negative_indices = torch.where(self.labels == label2)[0].tolist()

            anchor, positive = random.sample(positive_indices, 2)
            negative = random.choice(negative_indices)
            triplets.append((anchor, positive, negative))

        return triplets

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        anchor_idx, positive_idx, negative_idx = self.triplets[idx]

        anchor = self.image_data[anchor_idx]
        positive = self.image_data[positive_idx]
        negative = self.image_data[negative_idx]

        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        return anchor, positive, negative

class RandomRotation:
    def __init__(self, degrees, p=0.5):
        self.degrees = degrees
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            angle = random.uniform(-self.degrees, self.degrees)
            img = F.rotate(img, angle)
        return img

class RandomGaussianBlur:
    def __init__(self, kernel_size, sigma=(0.1, 2.0), p=0.5):
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            img = F.gaussian_blur(img, self.kernel_size, self.sigma)
        return img

class RandomNoise:
    def __init__(self, mean=0, std=0.1, p=0.5):
        self.mean = mean
        self.std = std
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            noise = torch.randn_like(img) * self.std + self.mean
            img = img + noise
            img = torch.clamp(img, 0, 1)  # Clamp values to [0, 1] range
        return img

# Define the transformations
transform = transforms.Compose([
    RandomRotation(degrees=30, p=0.5),
    RandomGaussianBlur(kernel_size=3, sigma=(0.1, 2.0), p=1),
    RandomNoise(mean=0, std=0.1, p=1),
])


def triplet_loss(anchor, positive, negative, margin: float = 1.0):
    """
    Computes Triplet Loss
    """
    distance_positive = nn.functional.pairwise_distance(anchor, positive)
    distance_negative = nn.functional.pairwise_distance(anchor, negative)
    loss = torch.clamp(distance_positive - distance_negative + margin, min=0.0)
    return torch.mean(loss)

In [4]:
model =  torch.load("models/resnet50_basemodel", weights_only=False)
resnet50 = torch.nn.Sequential(*list(model.children())[:-1], nn.Linear(2048, 256))

# Load data
train_data = np.load("data/train.npz")
test_data = np.load("data/test.npz")
register_data = np.load("data/register.npz")

# Create datasets and dataloaders
train_dataset = ImageDataset(train_data["data"], train_data["labels"], transform=transform, num_triplets = 1000)
test_dataset = ImageDataset(test_data["data"], test_data["labels"], transform=transform, num_triplets=1000)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

# Define loss function and optimizer
criterion = triplet_loss  # For classification tasks
optimizer = optim.Adam(resnet50.parameters(), lr=1e-6)

# Training loop
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)  # Move model to GPU if available

train_losses = list()
test_losses = list()
test_accs = list()

for epoch in range(num_epochs):
    # Training phase
    resnet50.train()  # Set model to training mode
    running_loss = 0.0

    for anchor, positive, negative in train_dataloader:
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        anchor_embedding, pos_embedding, neg_embedding = resnet50(anchor), resnet50(positive), resnet50(negative)
        loss = criterion(anchor_embedding, pos_embedding, neg_embedding)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Accumulate loss
        running_loss += loss.item() * anchor.size(0)

    # Calculate training loss for the epoch
    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss:.4f}")

    # Evaluation phase
    resnet50.eval()  # Set model to evaluation mode
    test_loss = 0.0


    with torch.no_grad():  # Disable gradient computation
        for anchor, positive, negative in test_dataloader:
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

            # Forward pass
            anchor_embedding, pos_embedding, neg_embedding = resnet50(anchor), resnet50(positive), resnet50(negative)
            loss = criterion(anchor_embedding, pos_embedding, neg_embedding)

            # Accumulate test loss
            test_loss += loss.item() * anchor.size(0)
            
        X_register = torch.tensor(register_data["data"],dtype=torch.float32).to(device)
        X_register = X_register.repeat(1,3,1,1)
        X_register = transform(X_register)
        class_embeddings = resnet50(X_register)

        X_test = torch.tensor(test_data["data"],dtype=torch.float32).to(device)
        X_test = X_test.repeat(1,3,1,1)
        X_test = transform(X_test)
        test_embeddings = resnet50(X_test)
        
    # Get predictions
    similarities = euclidean_distances(test_embeddings.cpu().numpy().squeeze(), class_embeddings.cpu().numpy().squeeze())
    preds = np.argmin(similarities, axis=1)

    # Calculate test loss and accuracy
    test_loss = test_loss / len(test_dataset)
    test_accuracy = accuracy_score(test_data["labels"], preds)
    print(f"Epoch [{epoch+1}/{num_epochs}], Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

    train_losses.append(epoch_loss)
    test_losses.append(test_loss)
    test_accs.append(test_accuracy)

print("Training complete!")

Epoch [1/10], Train Loss: 0.0142
Epoch [1/10], Test Loss: 0.0255, Test Accuracy: 0.9200
Epoch [2/10], Train Loss: 0.0140
Epoch [2/10], Test Loss: 0.0240, Test Accuracy: 0.8840
Epoch [3/10], Train Loss: 0.0157
Epoch [3/10], Test Loss: 0.0205, Test Accuracy: 0.9180
Epoch [4/10], Train Loss: 0.0125
Epoch [4/10], Test Loss: 0.0219, Test Accuracy: 0.9200
Epoch [5/10], Train Loss: 0.0173
Epoch [5/10], Test Loss: 0.0197, Test Accuracy: 0.9270
Epoch [6/10], Train Loss: 0.0117
Epoch [6/10], Test Loss: 0.0200, Test Accuracy: 0.8960
Epoch [7/10], Train Loss: 0.0151
Epoch [7/10], Test Loss: 0.0274, Test Accuracy: 0.9190
Epoch [8/10], Train Loss: 0.0158
Epoch [8/10], Test Loss: 0.0272, Test Accuracy: 0.7500
Epoch [9/10], Train Loss: 0.0120
Epoch [9/10], Test Loss: 0.0232, Test Accuracy: 0.9220
Epoch [10/10], Train Loss: 0.0150
Epoch [10/10], Test Loss: 0.0231, Test Accuracy: 0.9170
Training complete!


In [5]:
torch.save(resnet50, "models/resnet50_triplet")

In [None]:
plt.plot(train_losses, label="Train loss")
plt.plot(test_losses, label="Test loss")
plt.legend()
plt.title("Train and Test loss during training.")
plt.ylabel("Cross Entropy Loss")
plt.xlabel("Epoch")
plt.show()

In [None]:
plt.plot(test_accs, label="Train loss")
plt.legend()
plt.title("Test accuracy during training.")
plt.ylabel("Accuracy")
plt.xlabel("Epoch")
plt.show()