# ¿How well does a metric learning approach do?

In [2]:
import torch
from functools import partial
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from tqdm import tqdm
import matplotlib.pyplot as plt


# reproducibility
seed = 1993
torch.manual_seed(1)
torch.cuda.manual_seed(1)
torch.cuda.manual_seed_all(1)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_dir = "/home/studio-lab-user/CIL_Survey/data"

## Load and preprocess data

In [3]:
train_dataset_gpu = {}
eval_dataset_gpu = {}

# dataset
train = torchvision.datasets.CIFAR100(root=data_dir, download=True, transform=transforms.ToTensor())
eval = torchvision.datasets.CIFAR100(root=data_dir, train=False, transform=transforms.ToTensor())

# move dataset to gpu
train_dataset_gpu_loader = torch.utils.data.DataLoader(train, batch_size=len(train), drop_last=True,
                                            shuffle=True, num_workers=2, persistent_workers=False)
eval_dataset_gpu_loader = torch.utils.data.DataLoader(eval, batch_size=len(eval), drop_last=True,
                                            shuffle=False, num_workers=1, persistent_workers=False)
train_dataset_gpu['images'], train_dataset_gpu['targets'] = [item.to(device="cuda", non_blocking=True) for item in next(iter(train_dataset_gpu_loader))]
eval_dataset_gpu['images'],  eval_dataset_gpu['targets']  = [item.to(device="cuda", non_blocking=True) for item in next(iter(eval_dataset_gpu_loader)) ]

# # normalize images
# train_cifar_std, train_cifar_mean = torch.std_mean(train_dataset_gpu['images'], dim=(0, 2, 3)) 
# print(f"Mean: {[f'{x:.4f}' for x in train_cifar_mean.tolist()]}")
# print(f"Std: {[f'{x:.4f}' for x in train_cifar_std.tolist()]}")
# def batch_normalize_images(input_images, mean, std):
#     return (input_images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)
# batch_normalize_images = partial(batch_normalize_images, mean=train_cifar_mean, std=train_cifar_std)
# train_dataset_gpu['images'] = batch_normalize_images(train_dataset_gpu['images'])
# eval_dataset_gpu['images']  = batch_normalize_images(eval_dataset_gpu['images'])

data = {
        'train': train_dataset_gpu,
        'eval': eval_dataset_gpu
    }

# pad images for later random cropping
pad_amount = 4
data['train']['images'] = F.pad(data['train']['images'], (pad_amount,)*4, 'reflect')

Files already downloaded and verified


### Metric based data loading

In [4]:
@torch.no_grad()
def old_get_batches(data_dict, key, batchsize, indices=range(100), cutmix=False, cutmix_size=None):
    # select subset of class indices 
    if indices is not None:
        indices = torch.tensor(indices, device=device)
        images, targets = data_dict[key]["images"], data_dict[key]["targets"] 
        samples = torch.isin(targets, indices)
        images, targets = images[samples], targets[samples]
        assert len(images) == len(targets)

    num_epoch_examples = len(images)
    shuffled = torch.randperm(num_epoch_examples, device=device)
    crop_size = 32

    ## Here, we prep the dataset by applying all data augmentations in batches ahead of time before each epoch, then we return an iterator below
    ## that iterates in chunks over with a random derangement (i.e. shuffled indices) of the individual examples. So we get perfectly-shuffled
    ## batches (which skip the last batch if it's not a full batch), but everything seems to be (and hopefully is! :D) properly shuffled. :)
    if key == 'train':
        images = batch_crop(images, crop_size) # TODO: hardcoded image size for now?
        images = batch_flip_lr(images)
        if cutmix:
            images, targets = batch_cutmix(images, targets, patch_size=cutmix_size)

    # # Send the images to an (in beta) channels_last to help improve tensor core occupancy (and reduce NCHW <-> NHWC thrash) during training
    # images = images.to(memory_format=torch.channels_last)
    for idx in range(num_epoch_examples // batchsize):
        if not (idx+1)*batchsize > num_epoch_examples: ## Use the shuffled randperm to assemble individual items into a minibatch
            yield images.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]), \
                  targets.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]) ## Each item is only used/accessed by the network once per epoch. :D


In [30]:
from batch_transforms import batch_crop, batch_flip_lr

@torch.no_grad()
def get_batches(data_dict, key, batchsize, indices=range(100)):
    # select subset of class indices 
    indices = torch.tensor(indices, device=device)
    images, targets = data_dict[key]["images"], data_dict[key]["targets"] 
    samples = torch.isin(targets, indices)
    images, targets = images[samples], targets[samples]
    
    assert len(images) == len(targets)

    # as we are going to pair up the images, we need the size of the dataset to be even
    if len(images) % 2 != 0:
        images = images[:-1]
        targets = targets[:-1]

    num_epoch_examples = len(images)
    shuffled = torch.randperm(num_epoch_examples, device=device)
    crop_size = 32

    images = images[shuffled]
    targets = targets[shuffled]

    if key == 'train':
        images = batch_crop(images, crop_size)
        images = batch_flip_lr(images)

    # pair up the dataset
    targets = targets.reshape(num_epoch_examples // 2, 2)
    pairs = torch.eq(targets[:,0], targets[:,1])
    # TODO: this takes much longer than before 
    # without 50% -> 30 ms
    # with 50% -> 1 seg
    # we need that roughly 50% of the pairs are positive and negative
    if pairs.float().mean() < 0.5:
        # get negative pairs
        targets = targets.reshape(num_epoch_examples)
        pairs = torch.stack([pairs, pairs], 1).reshape(num_epoch_examples)
        neg = pairs == False
        # permute them 
        perm = torch.randperm(len(pairs[neg]))
        
        images[neg] = images[neg][perm]
        targets[neg] = targets[neg][perm]
        targets = targets.reshape(num_epoch_examples // 2, 2)
        pairs = torch.eq(targets[:,0], targets[:,1])
    
    images = images.reshape(num_epoch_examples // 2, 2, 3, images.shape[-1], images.shape[-2])
    num_epoch_examples = len(images)

    for idx in range(num_epoch_examples // batchsize):
        yield images[idx*batchsize: (idx+1)*batchsize], targets[idx*batchsize: (idx+1)*batchsize]

#### ¿How do I make sure that the implementation is correct?

In [39]:
%%time 

for x, y in get_batches(data, "train", 128):
    y
    
torch.eq(y[:,0], y[:,1]).float().mean()

CPU times: user 93.1 ms, sys: 32.4 ms, total: 125 ms
Wall time: 118 ms


In [26]:
%%time 

for x, y in old_get_batches(data, "train", 128):
    y

CPU times: user 20.7 ms, sys: 16.3 ms, total: 37 ms
Wall time: 35.5 ms


In [40]:
# from torchvision.transforms.functional import to_pil_image


# images_pairs = data["train"]["images"].reshape(25000, 2, 3, 40, 40)
# labels_pairs  = data["train"]["targets"].reshape(25000, 2)

# # data["train"]["images"][1].shape
# to_pil_image(data["train"]["images"][0]).show()
# print(data["train"]["targets"][0])
# to_pil_image(data["train"]["images"][1]).show()
# print(data["train"]["targets"][1])

# to_pil_image(images_pairs[0][0]).show()
# to_pil_image(images_pairs[0][1]).show()
# print(labels_pairs[0][0])
# print(labels_pairs[0][1])

## Implement siamese network

### Model definition

### Classification Mechanism

## Incremental trainining

### Loss function