In [1]:
import warnings
warnings.filterwarnings("ignore")
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import random

to_pil_image = transforms.ToPILImage()

# Transform for CIFAR-10
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load CIFAR-10 dataset
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# Organize CIFAR-10 dataset by class
class_to_images_train = {i: [] for i in range(10)}
labels_train = []
for idx, (image, label) in enumerate(cifar10_train):
    if idx == 10000:
        break
    class_to_images_train[label].append((image, label))
    labels_train.append(label)
    
labels_train = torch.tensor(labels_train)

class_to_images_test = {i: [] for i in range(10)}
labels_test = []
for idx, (image, label) in enumerate(cifar10_test):
    if idx == 1000:
        break
    class_to_images_test[label].append((image, label))
    labels_test.append(label)
    
labels_test = torch.tensor(labels_test)


# Custom dataset class
class TripletCIFAR10Dataset(Dataset):
    def __init__(self, labels, class_to_images, num_samples):
        self.train_labels = labels
        self.class_to_images = class_to_images
        self.num_samples = num_samples
        self.classes = list(class_to_images.keys())
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # Select a random class
        same_class = random.choice(self.classes)
        
        # Select two images from the same class
        img1, _ = random.choice(self.class_to_images[same_class])
        img2, _ = random.choice(self.class_to_images[same_class])
        
        # Select a different class
        diff_class = random.choice([c for c in self.classes if c != same_class])
        
        # Select one image from the different class
        img3, _ = random.choice(self.class_to_images[diff_class])
        
        # Return triplet (img1, img2, img3) and label (same_class)
        return (img1, img2, img3), same_class

# Create datasets
train_size, test_size = 5000, 1000
train_dataset = TripletCIFAR10Dataset(labels_train, class_to_images_train, train_size)
test_dataset = TripletCIFAR10Dataset(labels_test, class_to_images_test, test_size)

print("Datasets created successfully!")


Files already downloaded and verified
Files already downloaded and verified
Datasets created successfully!


In [2]:
from datasets_s import BalancedBatchSampler

# We'll create mini batches by sampling labels that will be present in the mini batch and number of examples from each class
train_batch_sampler = BalancedBatchSampler(torch.tensor(cifar10_train.targets), n_classes=10, n_samples=25)
test_batch_sampler = BalancedBatchSampler(torch.tensor(cifar10_test.targets), n_classes=10, n_samples=25)

online_train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_batch_sampler)
online_test_loader = torch.utils.data.DataLoader(test_dataset, batch_sampler=test_batch_sampler)

In [3]:
from torch.optim import lr_scheduler
from torch import optim
from networks import EmbeddingNet, TripletNet
from losses import OnlineTripletLoss
from utils import  SemihardNegativeTripletSelector
from metrics import AverageNonzeroTripletsMetric

margin = 1.
embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)

loss_fn = OnlineTripletLoss(margin, SemihardNegativeTripletSelector(margin))
lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
n_epochs = 50
log_interval = 50

In [4]:
from trainer import fit

fit(online_train_loader, online_test_loader, model, loss_fn, optimizer, scheduler, n_epochs, None, log_interval, metrics=[AverageNonzeroTripletsMetric()])


Epoch: 1/50. Train set: Average loss: 5083.0532	Average nonzero triplets: 366.91959798994975
Epoch: 1/50. Validation set: Average loss: 0.4852	Average nonzero triplets: 137.89743589743588
Epoch: 2/50. Train set: Average loss: 0.5005	Average nonzero triplets: 361.48743718592965
Epoch: 2/50. Validation set: Average loss: 0.4825	Average nonzero triplets: 276.43589743589746
Epoch: 3/50. Train set: Average loss: 0.5011	Average nonzero triplets: 217.98994974874373
Epoch: 3/50. Validation set: Average loss: 0.4828	Average nonzero triplets: 231.15384615384616
Epoch: 4/50. Train set: Average loss: 0.5017	Average nonzero triplets: 259.56281407035175
Epoch: 4/50. Validation set: Average loss: 0.4846	Average nonzero triplets: 240.6153846153846
Epoch: 5/50. Train set: Average loss: 0.4996	Average nonzero triplets: 359.19095477386935
Epoch: 5/50. Validation set: Average loss: 0.4876	Average nonzero triplets: 345.64102564102564
Epoch: 6/50. Train set: Average loss: 0.5024	Average nonzero triplets: 38

In [7]:
model.eval()

TripletNet(
  (embedding_net): EmbeddingNet(
    (convnet): Sequential(
      (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
      (1): PReLU(num_parameters=1)
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
      (4): PReLU(num_parameters=1)
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (fc): Sequential(
      (0): Linear(in_features=1600, out_features=256, bias=True)
      (1): PReLU(num_parameters=1)
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): PReLU(num_parameters=1)
      (4): Linear(in_features=256, out_features=128, bias=True)
    )
  )
)