<a href="https://colab.research.google.com/github/Amaljayaranga/TensorBoard/blob/master/TensorBoard_Projector.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from tensorboardX import SummaryWriter
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from argparse import ArgumentParser
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from SimaseDataset import SiameseMNIST


# Command-line options as (flag, type, default) rows; order here is the order
# in which they are registered on the parser.
_CLI_OPTIONS = (
    ('--learning_batch_size', int, 16),
    ('--fc_in_features', int, 512),
    ('--fc_out_features', int, 64),
    ('--constractive_loss_margin', float, 1.0),
    ('--learning_rate', float, 1e-3),
    ('--num_epochs', int, 1),
    ('--weight_decay', float, 1e-5),
    ('--validation_split', float, 0.9),
    ('--mode', str, 'train'),
    ('--device', str, 'cuda'),
)

parser = ArgumentParser(description='Simase Network')
for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)

# parse_known_args collects unrecognized tokens in `unknown` instead of
# exiting -- presumably so the script also runs inside a notebook kernel,
# which passes its own argv.
args, unknown = parser.parse_known_args()

# Honor the requested device only when CUDA is actually usable; otherwise
# fall back to the CPU.
DEVICE = args.device if torch.cuda.is_available() else 'cpu'

# Raw MNIST splits; ToTensor is the only transform applied to each image.
train_dataset = MNIST(
    '../data/MNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

test_dataset = MNIST(
    '../data/MNIST',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# SiameseMNIST is a project-local wrapper (imported from SimaseDataset);
# presumably it yields (img1, img2, target) triples, which is how the loops
# further down unpack each batch.
siamese_train_dataset = SiameseMNIST(train_dataset)

# Cut the index range at `validation_split` of its length: everything before
# the cut goes to validation, everything from the cut onward to training.
# NOTE(review): with the default 0.9 this gives validation 90% of the pairs
# and training the remaining 10% -- confirm that is the intended reading of
# `validation_split`.
indices = list(range(len(siamese_train_dataset)))
split = int(np.floor(args.validation_split * len(indices)))
val_indices = indices[:split]
train_indices = indices[split:]
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Both loaders read the same wrapped dataset; the samplers keep them disjoint.
siamese_train_loader = torch.utils.data.DataLoader(
    siamese_train_dataset,
    batch_size=args.learning_batch_size,
    sampler=train_sampler,
)
siamese_val_loader = torch.utils.data.DataLoader(
    siamese_train_dataset,
    batch_size=args.learning_batch_size,
    sampler=valid_sampler,
)

# Held-out test pairs are served one at a time in shuffled order.
simase_test_dataset = SiameseMNIST(test_dataset)
siamese_test_loader = torch.utils.data.DataLoader(
    simase_test_dataset,
    batch_size=1,
    shuffle=True,
)

# Order matters to consumers that iterate this list: training loader first,
# validation loader second.
dataloaders = [siamese_train_loader, siamese_val_loader]

# Separate TensorBoard run directories for the two writers.
writer_train = SummaryWriter('runs/train_0')
writer_test = SummaryWriter('runs/test_0')

class EmbeddingNet(nn.Module):
    """Convolutional encoder producing a 64-dimensional vector per input image.

    Two conv / PReLU / max-pool stages are followed by a three-layer fully
    connected head. The first linear layer expects 64 * 4 * 4 flattened
    features, which is what the conv stack yields for 1x28x28 inputs.
    """

    def __init__(self):
        super().__init__()

        conv_stages = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
            nn.PReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
            nn.PReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.convnet = nn.Sequential(*conv_stages)

        head_layers = [
            nn.Linear(in_features=64 * 4 * 4, out_features=256),
            nn.PReLU(),
            nn.Linear(in_features=256, out_features=256),
            nn.PReLU(),
            nn.Linear(in_features=256, out_features=64),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the dense head."""
        feature_maps = self.convnet(x)
        # Keep the batch dimension and collapse everything else.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.fc(flattened)

    def get_embedding(self, x):
        """Return the embedding for ``x``; delegates directly to ``forward``."""
        return self.forward(x)

class SimaseNet(nn.Module):
    """Siamese wrapper that embeds two inputs with one shared encoder.

    Each input is passed through ``embeddingNet`` and then through a small
    64 -> 32 -> 16 fully connected head; the two 16-d outputs are returned
    as a pair.
    """

    def __init__(self, embeddingNet):
        super().__init__()
        # The same encoder instance serves both branches, so weights are shared.
        self.embeddingNet = embeddingNet
        self.fc = nn.Sequential(
            nn.Linear(in_features=64, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=16),
        )

    def forward_once(self, x):
        """Flatten ``x`` per sample and project it through the dense head."""
        flat = x.view(x.size(0), -1)
        return self.fc(flat)

    def forward(self, in1, in2):
        """Return the head outputs for ``in1`` and ``in2`` as a 2-tuple."""
        first_embedding = self.embeddingNet(in1)
        second_embedding = self.embeddingNet(in2)
        first_out = self.forward_once(first_embedding)
        second_out = self.forward_once(second_embedding)
        return first_out, second_out

    def get_embedding(self, x):
        """Return the encoder output for ``x`` without applying the head."""
        return self.embeddingNet(x)

class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over pairs of output vectors.

    For ``target == 0`` the squared distance between the pair is penalized;
    for ``target == 1`` the penalty is the squared shortfall of the distance
    below ``margin``. The per-pair losses are averaged.

    NOTE(review): this treats label 1 as "push apart" -- confirm it matches
    the label convention of the dataset producing ``target``.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, target):
        """Return the mean contrastive loss for a batch of output pairs."""
        distance = F.pairwise_distance(output1, output2)
        label = target.float()

        # Active when label == 0: grows with the distance itself.
        pull_term = 0.5 * (1 - label) * distance.pow(2)
        # Active when label == 1: non-zero only while the pair is closer
        # than the margin.
        shortfall = (self.margin - distance).clamp(min=0.00)
        push_term = 0.5 * label * shortfall.pow(2)

        per_pair = pull_term + push_term
        return per_pair.mean()

# Build the image encoder first, then hand it to the Siamese wrapper so both
# inputs of a pair are embedded by this single EmbeddingNet instance.
# NOTE(review): DEVICE is computed earlier in the script, but nothing visible
# here moves the model (or the batches) onto it -- confirm whether a
# `.to(DEVICE)` is intended.
embeddingNet = EmbeddingNet()
model = SimaseNet(embeddingNet)

def extract_embeddings(dataloader, model, epoch, writer=None):
    """Log encoder embeddings for every batch of ``dataloader`` to TensorBoard.

    Each batch is written as its own projector entry via
    ``writer.add_embedding``, using the first image of every pair as both the
    encoder input and the thumbnail, and the pair target as metadata.

    Args:
        dataloader: Iterable of ``(image1, image2, target)`` batches.
        model: Module exposing ``get_embedding``; it is switched to eval mode
            for the duration of the call and then returned to the mode it
            was in before.
        epoch: Epoch number used to offset the global step so that entries
            written for different epochs do not share a step.
        writer: Object with an ``add_embedding`` method. Defaults to the
            module-level ``writer_train``.
    """
    if writer is None:
        writer = writer_train

    # The loader length gives a per-epoch stride for the global step. The old
    # ``idx * epoch`` formula collided across epochs (epoch 1 / batch 2 and
    # epoch 2 / batch 1 both produced step 2).
    try:
        batches_per_epoch = len(dataloader)
    except TypeError:
        # Length-less iterables: no stride available, keep the legacy step.
        batches_per_epoch = None

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for idx, (image1, _image2, target) in enumerate(dataloader):
                if batches_per_epoch is None:
                    step = idx * epoch
                else:
                    step = epoch * batches_per_epoch + idx
                writer.add_embedding(
                    mat=model.get_embedding(image1),
                    metadata=target,
                    label_img=image1,
                    global_step=step,
                )
    finally:
        # Do not leak the eval switch into the caller's training loop.
        model.train(was_training)

# Log embeddings of the training subset once, using epoch tag 1. No training
# step has run at this point, so the model still has its initial weights.
extract_embeddings(siamese_train_loader,model,1)
# Everything below is wrapped in a bare triple-quoted string: Python parses
# it as a string expression and never executes it, so the training loop is
# currently disabled.
# NOTE(review): before re-enabling, check the best-model test inside it. It
# compares `minimum_loss` against np.mean(validation_losses), which is the
# mean over every epoch recorded so far, not the latest epoch's loss.
# Validation batches are also run without torch.no_grad().
'''
if args.mode == 'train':

    criterion = ContrastiveLoss(margin=args.constractive_loss_margin)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    print('training started')
    minimum_loss = float('Inf')

    training_losses = []
    validation_losses = []
    epochs = []

    for epoch in range(1, args.num_epochs + 1):
        epochs.append(epoch)

        for dataloader_idx, dataloader in enumerate(dataloaders):

            stage = ''

            if dataloader == siamese_val_loader:
                model = model.eval()
            else:
                model = model.train()

            epoch_loss = []

            for  batch_idx, batch in enumerate(dataloader):

                img1, img2, target = batch
                out1, out2 = model(img1, img2)
                loss = criterion(out1, out2, target)

                if dataloader == siamese_train_loader:
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                epoch_loss.append(loss.item())

            if dataloader == siamese_train_loader:
                training_losses.append(np.mean(epoch_loss))
                stage = 'train'
            else:
                validation_losses.append(np.mean(epoch_loss))
                stage = 'eval'

            if dataloader == siamese_train_loader:
                print(
                    f'epoch: {epoch} stage: {stage} loss: {np.mean(epoch_loss)}')
                writer_train.add_scalar('Loss', np.mean(epoch_loss), epoch)
                extract_embeddings(dataloader, model,epoch)

            else:
                print(f'epoch: {epoch} stage: {stage} loss: {np.mean(epoch_loss)}')
                writer_test.add_scalar('Loss', np.mean(epoch_loss), epoch)


            if dataloader == siamese_val_loader:
                if minimum_loss > np.mean(validation_losses):
                    minimum_loss = np.mean(validation_losses)
                    torch.save(model, 'simase_best.pt')
                    print('Model is saving')


    plt.plot(epochs, training_losses, label="train")
    plt.plot(epochs, validation_losses, label="eval")
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.title('training vs validation loss')
    plt.legend()
    plt.show()
'''