## Training on MNIST Dataset with Contrastive Pairs Loss
-------------------------

In [1]:
import os
#import torchsummary
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
from torch import nn
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms, models

# Climb two directory levels so the working directory becomes the project root;
# the `Code` package imported right after this is resolved relative to it.
# NOTE: the move is relative, so re-running this cell climbs two more levels.
os.chdir(os.path.join(os.pardir, os.pardir))
print(os.getcwd())  # Should be the repository root (the folder that holds `Code`)
from Code.trainers import Trainer
#from Code.models import SiameseNet
from Code.losses import form_triplets, ContrastiveLoss
from Code.dataloaders import LabeledContrastiveDataset
from Code.utils import extract_embeddings


# ---- Hyperparameters ----
N = 3000        # number of dataset indices shared between the train and test samplers
EMB_SIZE = 2    # width of the network's final linear layer
DEVICE = 'cuda'
LR = 0.0005
EPOCHS = 10
MARGIN = 1.0
# Run identifier built from the hyperparameter values; used as the weights file name.
NAME = 'MNIST_PAIR_LOSS_' + '_'.join(
    str(value) for value in (N, EMB_SIZE, LR, EPOCHS, MARGIN)
)

# ---- Reproducibility ----
SEED = 911
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

D:\Research\ContrastiveRepresentationLearning


## Create Dataloader and Inspect Data
---------------------

In [2]:
# Folder holding the MNIST images on the machine this notebook was run on.
root = r'D:\Data\Imagery\MNIST\MNIST'

# Per-channel statistics handed to transforms.Normalize (one channel only).
mean = 0.1307
std = 0.3081

# Convert each image to a tensor first, then normalise it.
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(mean,), std=(std,)),
])

lcd = LabeledContrastiveDataset(root, transforms=tfms)


In [3]:
# Pull a single item from the dataset and show the shape of both tensors it carries.
datadict = lcd[4]
print(datadict["x1"].shape, datadict["x2"].shape, sep="\n")

torch.Size([10, 1, 28, 28])
torch.Size([10, 1, 28, 28])


  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()


In [4]:
from torchvision import transforms


# Split the first N dataset indices: the leading 90% feed training, the rest testing.
train_end = int(N * 0.9)
train_sampler = SubsetRandomSampler(range(train_end))
test_sampler = SubsetRandomSampler(range(train_end, N))

# batch_size=None turns off automatic batching, so every dataset item is
# yielded exactly as the dataset returns it.
siamese_train_loader = torch.utils.data.DataLoader(lcd, batch_size=None, sampler=train_sampler)
# The sampler has to be passed as `sampler=`. Passing it as `shuffle=` only
# switches shuffling on over the *entire* dataset, which would let training
# items leak into the test loader.
siamese_test_loader = torch.utils.data.DataLoader(lcd, batch_size=None, sampler=test_sampler)


## Model
------------

In [5]:
# ResNet-18 built with torchvision's default arguments, then adapted below.
embedding_net = models.resnet18()

# Swap the first convolution for one that accepts single-channel input.
embedding_net.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
)
# Swap the final linear layer so the network outputs EMB_SIZE values.
embedding_net.fc = nn.Linear(in_features=512, out_features=EMB_SIZE)

model = embedding_net
model.train()
print()




## Training
-------------------------

In [6]:

# Contrastive loss in 'pairs' mode. The margin comes from the MARGIN
# hyperparameter (rather than a repeated literal) so the value actually used
# always matches the one encoded in NAME.
TL = ContrastiveLoss(margin=MARGIN, mode='pairs')

# Trainer wired to the training loader, the learning rate and the loss above.
t = Trainer(model=model,
            dataloader=siamese_train_loader,
            lr=LR,
            loss_function=TL)

In [None]:
# Run training with EPOCHS as the first argument; the returned value is kept
# in `losses` so it can be plotted and inspected afterwards.
# NOTE(review): presumably one loss value per epoch — confirm against Trainer.train.
losses = t.train(EPOCHS, print_every=1)

  0%|                                                                                         | 0/2700 [00:00<?, ?it/s]

----- Epoch: 0 -----


100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [02:53<00:00, 15.55it/s]
  0%|                                                                                 | 2/2700 [00:00<02:39, 16.94it/s]

Avg train loss: 0.03285384323926539
----- Epoch: 1 -----


100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [02:43<00:00, 16.46it/s]
  0%|                                                                                 | 2/2700 [00:00<02:35, 17.39it/s]

Avg train loss: 0.02067065692502621
----- Epoch: 2 -----


100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [02:42<00:00, 16.59it/s]
  0%|                                                                                 | 2/2700 [00:00<02:33, 17.54it/s]

Avg train loss: 0.015718653033154794
----- Epoch: 3 -----


100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [02:38<00:00, 17.03it/s]
  0%|                                                                                 | 2/2700 [00:00<02:40, 16.81it/s]

Avg train loss: 0.012854573776516678
----- Epoch: 4 -----


100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [02:42<00:00, 16.57it/s]
  0%|                                                                                 | 2/2700 [00:00<02:36, 17.24it/s]

Avg train loss: 0.011053307272608436
----- Epoch: 5 -----


 38%|█████████████████████████████▌                                                | 1022/2700 [01:02<01:44, 16.11it/s]

In [None]:
# Plot the values returned by t.train on an explicit Axes. The trailing
# semicolon stops the notebook from echoing the return value of the last call.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training Loss", xlabel="Epochs", ylabel="Train loss");


In [None]:
# Display the raw values returned by t.train.
losses

## Inspecting Embeddings
-------------------

In [None]:
# Number of indices left over once the 90% training share is removed from N.
EMBS_TO_VISUALIZE = N - int(N*0.9)

# The ten digit classes, 0 through 9.
mnist_classes = list(range(10))

# One hex colour per digit class.
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

# Digit class -> colour lookup.
coloring = dict(zip(mnist_classes, colors))

In [None]:
# The model was last put in training mode, and ResNet contains BatchNorm
# layers; switch to eval mode so inference uses the learned running statistics
# instead of per-batch statistics. Calling eval() is idempotent.
model.eval()
# NOTE(review): 'cpu' is passed as the last argument — presumably the device
# used for extraction; confirm against extract_embeddings.
test_embs = extract_embeddings(siamese_test_loader, model, EMBS_TO_VISUALIZE, 'cpu')

In [None]:
# Preview the first rows of the extracted embeddings; the scatter plot further
# down reads the "X", "Y" and "Label" columns from this object.
test_embs.head()

In [None]:
import seaborn as sns

sns.set_style("darkgrid")

# Scatter the test embeddings, coloured by their "Label" column.
sns.relplot(
    data=test_embs,
    x="X",
    y="Y",
    hue="Label",
    palette="deep",
    s=75,
    alpha=0.7,
)
plt.title("Test Embeddings")

## Saving Model
-------------------------

In [None]:
# Build the target path with os.path.join: a raw string literal cannot end in
# a backslash, so the previous r'\Outputs\Weights\' form was a syntax error,
# and joining components also keeps the path portable across platforms.
weights_dir = os.path.join(os.getcwd(), 'Outputs', 'Weights')
# Make sure the output folder exists before torch.save tries to write into it.
os.makedirs(weights_dir, exist_ok=True)
outpath = os.path.join(weights_dir, NAME)

# Persist only the parameters (state_dict), then echo where they were written.
torch.save(model.state_dict(), outpath)
print(outpath)