In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor, ToPILImage
from torchvision import datasets
from torchvision import models
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
import random
from copy import deepcopy

In [3]:
class cnn(nn.Module):
    """Small convolutional classifier that maps an image batch to 10 logits.

    Two conv/pool/ReLU stages are followed by a two-layer fully connected
    head. The first linear layer takes 7*7*8 features, which is what a
    single-channel 28x28 input produces after the two 2x poolings.
    """

    def __init__(self) -> None:
        super().__init__()
        # Stage 1: 1 -> 8 channels; padding=1 keeps the spatial size for the
        # 3x3 kernel, the pooling halves it.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        # Stage 2: 8 -> 8 channels, same layout as stage 1.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()
        # Fully connected head: 7*7*8 features -> 28 hidden units -> 10 outputs.
        hidden_units = 28
        self.fc1 = nn.Sequential(
            nn.Linear(7 * 7 * 8, hidden_units),
            nn.ReLU(),
        )
        self.fc2 = nn.Linear(hidden_units, 10)

    def forward(self, inp):
        """Run ``inp`` through every stage in order and return the fc2 output."""
        out = inp
        for stage in (self.conv1, self.conv2, self.flatten, self.fc1, self.fc2):
            out = stage(out)
        return out
        


# MNIST as tensors. The first call passes no `train` argument (library default
# split); the second explicitly requests the non-training split.
train_dataset = datasets.MNIST(root='../data/', transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='../data/', transform=transforms.ToTensor(), train=False)

train = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Accuracy does not depend on sample order, so the test loader is not shuffled.
test = DataLoader(test_dataset, batch_size=64, shuffle=False)

model = cnn()
optimizer = optim.Adam(params=model.parameters())

for epoch in tqdm(range(10)):
    # --- one pass over the training set ---
    running_loss = 0.0
    for x, y in train:
        # Call the module (not .forward) so registered hooks are honoured.
        pred = model(x)
        loss = F.cross_entropy(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; adding the tensor itself would keep
        # autograd history alive across the whole epoch.
        running_loss += loss.item()
    running_loss /= len(train)  # mean loss per batch
    print(f'Loss: {running_loss}')

    # --- accuracy on the test split; no gradients are needed here ---
    correct = 0
    with torch.no_grad():
        for x, y in test:
            pred = model(x)
            correct += (pred.argmax(1) == y).sum().item()
    print(f'Testing accuracy: {correct/len(test_dataset)}')

  0%|          | 0/10 [00:00<?, ?it/s]

Loss: 0.4235815405845642


 10%|█         | 1/10 [00:10<01:31, 10.18s/it]

Testing accuracy: 0.9567
Loss: 0.12916457653045654


 20%|██        | 2/10 [00:18<01:12,  9.10s/it]

Testing accuracy: 0.9657
Loss: 0.09494631737470627


 30%|███       | 3/10 [00:27<01:01,  8.83s/it]

Testing accuracy: 0.9743
Loss: 0.07881827652454376


 40%|████      | 4/10 [00:35<00:52,  8.68s/it]

Testing accuracy: 0.9758
Loss: 0.06688041985034943


 50%|█████     | 5/10 [00:43<00:42,  8.58s/it]

Testing accuracy: 0.9775
Loss: 0.05979124829173088


 60%|██████    | 6/10 [00:51<00:33,  8.35s/it]

Testing accuracy: 0.9818
Loss: 0.05225615203380585


 70%|███████   | 7/10 [00:59<00:24,  8.16s/it]

Testing accuracy: 0.9807
Loss: 0.04774440824985504


 80%|████████  | 8/10 [01:07<00:16,  8.01s/it]

Testing accuracy: 0.9774
Loss: 0.0433792769908905


 90%|█████████ | 9/10 [01:14<00:07,  7.92s/it]

Testing accuracy: 0.9835
Loss: 0.03995371609926224


100%|██████████| 10/10 [01:23<00:00,  8.35s/it]

Testing accuracy: 0.9846





In [4]:
# Fashion-MNIST images converted to tensors. No `train` argument is passed,
# so the library's default split is used.
dataset = datasets.FashionMNIST(
    root='../data/',
    transform=transforms.ToTensor(),
)

In [5]:
# Collect the first four samples of each of the ten labels, keyed by label.
images = defaultdict(list)

counts = [4]*10  # remaining quota per label
for img, label in dataset:
    if counts[label] > 0:
        images[label].append(img)
        counts[label] -= 1
    # Stop scanning as soon as every label's quota is exhausted.
    if not any(counts):
        break

In [6]:

class siamese(nn.Module):
    """Embedding network followed by a trainable linear projection head.

    Args:
        embedding: module mapping an input batch to vectors of width
            ``in_features``.
        in_features: width of the embedding output (default 10).
        out_features: width of the projected output (default 5).
    """

    def __init__(self, embedding, in_features=10, out_features=5) -> None:
        super(siamese, self).__init__()
        self.embedding = embedding
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, inp):
        """Embed ``inp`` and return its linear projection."""
        # Call the sub-module itself rather than its .forward method, so any
        # hooks registered on the embedding are executed.
        x = self.embedding(inp)
        return self.fc(x)

# Freeze the pretrained network; the freshly created projection layer inside
# `sim_model` is then the only part with trainable parameters.
model.requires_grad_(False)
sim_model = siamese(model)
optimizer = optim.Adam(params=sim_model.parameters())
EPOCHS = 10000  # number of optimisation steps, one triplet per step


for i in tqdm(range(1, EPOCHS+1)):
    # Anchor and positive: stored images 0 and 1 of one random class, in
    # random order.
    true_class = random.randint(0, 9)
    true_img1_idx = random.randint(0, 1)
    true_img1 = images[true_class][true_img1_idx]
    true_img2 = images[true_class][1-true_img1_idx]
    # Negative: resample until its class differs from the anchor's class.
    # (The comparison must be against the class, not the anchor's image index.)
    false_img_idx = random.randint(0, 9)
    while false_img_idx == true_class:
        false_img_idx = random.randint(0, 9)
    false_img = images[false_img_idx][random.randint(0, 1)]

    out1 = sim_model(true_img1.unsqueeze(0))
    out2 = sim_model(true_img2.unsqueeze(0))
    out3 = sim_model(false_img.unsqueeze(0))

    # NOTE(review): mse_loss is already a mean of squared differences, so
    # these terms are its square; the evaluation cells print the unsquared
    # value. Confirm the extra power is intended.
    d_plus = F.mse_loss(out1, out2)**2
    d_minus = F.mse_loss(out1, out3)**2
    # Hinge with margin 10: zero loss once d_minus exceeds d_plus by 10.
    d = d_plus + 10 - d_minus
    loss = F.relu(d)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


100%|██████████| 10000/10000 [00:15<00:00, 644.20it/s]


In [7]:
# Compare a same-class pair (class 2) against a cross-class pair (class 2 vs
# class 5). Stored indices 2 and 3 are used here; the training loop only
# samples indices 0 and 1.
img1 = images[2][2].unsqueeze(0)
img2 = images[2][3].unsqueeze(0)
img3 = images[5][2].unsqueeze(0)

with torch.no_grad():
    # Call the module rather than .forward so registered hooks are honoured.
    out1 = sim_model(img1)
    out2 = sim_model(img2)
    out3 = sim_model(img3)
    print(f'distance between True: {F.mse_loss(out1, out2)}')
    print(f'distance between False: {F.mse_loss(out1, out3)}')

distance between True: 2.872373580932617
distance between False: 41.127967834472656


# Fine-Tuning

In [8]:
model_copy = deepcopy(model)
# Freeze the whole copy explicitly, then unfreeze only the last layer, so this
# cell does not rely on `model` having been frozen by an earlier cell.
model_copy.requires_grad_(False)
model_copy.fc2.requires_grad_(True)

# Hand the optimizer only the parameters that are actually trainable.
optimizer = optim.Adam(params=[p for p in model_copy.parameters() if p.requires_grad])

for i in tqdm(range(1, EPOCHS+1)):
    # Anchor and positive: stored images 0 and 1 of one random class, in
    # random order.
    true_class = random.randint(0, 9)
    true_img1_idx = random.randint(0, 1)
    true_img1 = images[true_class][true_img1_idx]
    true_img2 = images[true_class][1-true_img1_idx]
    # Negative: resample until its class differs from the anchor's class.
    # (The comparison must be against the class, not the anchor's image index.)
    false_img_idx = random.randint(0, 9)
    while false_img_idx == true_class:
        false_img_idx = random.randint(0, 9)
    false_img = images[false_img_idx][random.randint(0, 1)]

    out1 = model_copy(true_img1.unsqueeze(0))
    out2 = model_copy(true_img2.unsqueeze(0))
    out3 = model_copy(false_img.unsqueeze(0))

    # NOTE(review): mse_loss is already a mean of squared differences, so
    # these terms are its square; the evaluation cells print the unsquared
    # value. Confirm the extra power is intended.
    d_plus = F.mse_loss(out1, out2)**2
    d_minus = F.mse_loss(out1, out3)**2
    # Hinge with margin 10: zero loss once d_minus exceeds d_plus by 10.
    d = d_plus + 10 - d_minus
    loss = F.relu(d)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


100%|██████████| 10000/10000 [00:10<00:00, 974.60it/s]


In [9]:
# Compare a same-class pair (class 2) against a cross-class pair (class 2 vs
# class 5) using the fine-tuned copy. Stored indices 2 and 3 are used here;
# the training loop only samples indices 0 and 1.
img1 = images[2][2].unsqueeze(0)
img2 = images[2][3].unsqueeze(0)
img3 = images[5][2].unsqueeze(0)

with torch.no_grad():
    # Call the module rather than .forward so registered hooks are honoured.
    out1 = model_copy(img1)
    out2 = model_copy(img2)
    out3 = model_copy(img3)
    print(f'distance between True: {F.mse_loss(out1, out2)}')
    print(f'distance between False: {F.mse_loss(out1, out3)}')

distance between True: 6.500158786773682
distance between False: 38.35658645629883
