In [None]:
! pip install pytorch-metric-learning
! pip install MulticoreTSNE

In [None]:
import timm
import torch
from torch import optim
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from cycler import cycler
from torch import optim
import os

import matplotlib.pyplot as plt
#from sklearn.manifold import TSNE
from MulticoreTSNE import MulticoreTSNE as TSNE
from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.reducers import ThresholdReducer

import glob
import cv2
from PIL import Image

In [None]:
class CFG:
    # --- paths / experiment bookkeeping ---
    data_path = "./dataset/"          # root passed to torchvision's FashionMNIST
    exp = "ex1"                       # experiment name; becomes the output directory

    # --- model ---
    # Candidate timm backbones; the training loop indexes this list.
    model_name = ["vgg16", "efficientnet_b0", "resnet34d"]
    pretrained = True
    inp_channels = 1                  # number of input channels given to the backbone
    out_features = 128                # embedding dimension produced by the model

    # --- optimisation ---
    epoch = 20                        # number of training epochs
    batch_size = 64
    lr = 0.00001

    # --- reproducibility / hardware ---
    seed = 42
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and every visible CUDA device). It also configures cuDNN so that the
    same algorithms are chosen on every run.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    # Only influences subprocesses started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers all GPUs, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking on, cuDNN may pick different algorithms from run to
    # run, which undermines deterministic=True.
    torch.backends.cudnn.benchmark = False
    
seed_everything(seed=CFG.seed)

# Directory for this experiment's artifacts: per-epoch checkpoints,
# per-epoch t-SNE frames and the final GIF.
OUTPUT_DIR = f"./{CFG.exp}/"
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=True fetches FashionMNIST into CFG.data_path only when it is not
# already there, so Restart & Run All also works on a fresh checkout.
train_dataset = datasets.FashionMNIST(CFG.data_path, train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(CFG.data_path, train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True)
# Fixed order so the t-SNE input is the same sample sequence every epoch.
test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False)

In [None]:
class CustomModel(nn.Module):
    """Wrapper around a timm backbone whose output layer yields embeddings.

    The backbone is created with ``num_classes=n_class``. In this notebook
    that width is used as the embedding dimension for metric learning, not
    as a number of classes.

    Args:
        model_name: A single timm architecture name. Defaults to
            ``CFG.model_name[2]``, the backbone the training loop uses.
        n_class: Output (embedding) dimension.
        pretrained: Whether timm should load pretrained weights.
        in_chans: Number of input channels passed to timm.

    Raises:
        TypeError: if ``model_name`` is a list/tuple instead of one name.
    """

    def __init__(
        self, model_name=CFG.model_name[2], n_class=CFG.out_features, pretrained=CFG.pretrained, in_chans=CFG.inp_channels):
        super().__init__()
        # CFG.model_name is a list of candidates; timm needs exactly one name.
        if isinstance(model_name, (list, tuple)):
            raise TypeError(
                "model_name must be a single timm model name, got {!r}".format(model_name))
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=n_class, in_chans=in_chans)

    def forward(self, x):
        """Return the backbone output (the embedding batch) for ``x``."""
        return self.backbone(x)

In [None]:
def train(model, loss_func, mining_func, device, dataloader, optimizer, epoch):
    """Run one epoch of metric-learning training.

    For each batch, the miner selects tuples from the current embeddings, the
    loss is computed on those tuples, and one optimizer step is taken.

    Args:
        model: Network mapping a batch of inputs to embeddings.
        loss_func: Callable ``loss_func(embeddings, labels, indices_tuple)``
            returning a scalar loss tensor.
        mining_func: Callable ``mining_func(embeddings, labels)`` whose result
            is forwarded to ``loss_func``.
        device: Device the batches are moved to.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number, used only for logging.

    Returns:
        Mean batch loss over the epoch as a Python float, or ``nan`` if the
        dataloader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for idx, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(inputs)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()

        # .item() gives a plain float. Formatting the tensor itself would also
        # print its device and grad_fn.
        batch_loss = loss.item()
        total_loss += batch_loss
        n_batches += 1
        if idx % 100 == 0:
            print('Epoch {} Iteration {}: Loss = {}'.format(epoch, idx, batch_loss))
    return total_loss / n_batches if n_batches else float('nan')

In [None]:
def test(model, dataloader, device, epoch):
    _predicted_metrics = []
    _true_labels = []
    with torch.no_grad():    
        for i, (inputs,  labels) in enumerate(dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            metric = model(inputs).detach().cpu().numpy()
            metric = metric.reshape(metric.shape[0], metric.shape[1])
            _predicted_metrics.append(metric)
            _true_labels.append(labels.detach().cpu().numpy())
    return np.concatenate(_predicted_metrics), np.concatenate(_true_labels)

In [None]:
def tsne(epoch, model_name=CFG.model_name[2], lim=40):
    """Plot a 2-D t-SNE of test-set embeddings from the checkpoint of ``epoch``.

    Loads ``epoch{epoch}_model.pth`` from OUTPUT_DIR into a fresh
    CustomModel and embeds the whole test set. The embeddings are projected
    with t-SNE, and the scatter plot is saved as
    ``output_epoch{epoch}.jpg`` in OUTPUT_DIR before being shown.

    Args:
        epoch: Epoch whose checkpoint is visualised.
        model_name: timm architecture of the checkpoint. Must match the
            trained model; the default is the backbone the training loop uses.
        lim: Half-width of the fixed square axis range, so all frames share
            one scale.
    """
    inf_model = CustomModel(model_name=model_name).to(CFG.device)
    ckpt_path = os.path.join(OUTPUT_DIR, f'epoch{epoch}_model.pth')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    inf_model.load_state_dict(torch.load(ckpt_path, map_location=CFG.device))
    # Inference mode: BatchNorm must use its running stats, not batch stats.
    inf_model.eval()
    test_predicted_metrics, test_true_labels = test(inf_model, test_loader, CFG.device, epoch)
    tSNE_metrics = TSNE(n_components=2, random_state=0, n_jobs=4).fit_transform(test_predicted_metrics)

    fig, ax = plt.subplots()
    points = ax.scatter(tSNE_metrics[:, 0], tSNE_metrics[:, 1], c=test_true_labels)
    fig.colorbar(points, ax=ax)
    # Limits must be set before savefig, otherwise the saved frame ignores them.
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_title(f't-SNE of test embeddings, epoch {epoch}')
    fig.savefig(os.path.join(OUTPUT_DIR, f"output_epoch{epoch}.jpg"))
    plt.show()
    # Close explicitly so figures don't pile up (or get drawn over) when
    # plt.show() does not close them, e.g. on non-interactive backends.
    plt.close(fig)

In [None]:
def create_gif():
    """Stitch the per-epoch t-SNE frames in OUTPUT_DIR into ``anime.gif``.

    Frames ``output_epoch1.jpg`` through ``output_epoch{CFG.epoch}.jpg`` are
    used in epoch order. Epochs without a saved frame are skipped, so a GIF
    can still be built when plots exist only for some epochs. Each frame is
    shown for 500 ms and the animation loops forever.

    Raises:
        FileNotFoundError: if none of the expected frames exists.
    """
    frames = []
    for i in range(1, CFG.epoch + 1):
        pic_name = os.path.join(OUTPUT_DIR, f"output_epoch{i}.jpg")
        if not os.path.exists(pic_name):
            continue
        # copy() loads the pixels into memory, so the file handle can be
        # closed right away instead of leaking one per frame.
        with Image.open(pic_name) as img:
            frames.append(img.copy())

    if not frames:
        raise FileNotFoundError(
            "no output_epoch*.jpg frames found in {}".format(OUTPUT_DIR))

    frames[0].save(os.path.join(OUTPUT_DIR, 'anime.gif'), save_all=True,
                   append_images=frames[1:], optimize=False, duration=500, loop=0)

In [None]:
if __name__ == '__main__':
    # Train the embedding network with a cosine-distance triplet loss. After
    # every epoch, checkpoint it and plot a t-SNE of the test set.
    model = CustomModel(model_name=CFG.model_name[2]).to(CFG.device)
    optimizer = optim.Adam(model.parameters(), lr=CFG.lr)

    distance = CosineSimilarity()
    # NOTE(review): presumably averages only the losses above `low`, i.e.
    # ignores triplets that already satisfy the margin. Confirm in the
    # pytorch-metric-learning reducer docs.
    reducer = ThresholdReducer(low=0)
    loss_func = TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
    mining_func = TripletMarginMiner(margin=0.2, distance=distance)

    for epoch in range(1, CFG.epoch + 1):
        print('Epoch {}/{}'.format(epoch, CFG.epoch))
        print('-' * 10)
        train(model, loss_func, mining_func, CFG.device, train_loader, optimizer, epoch)
        # To checkpoint/plot less often, guard the next two lines with e.g.
        # `if epoch % 5 == 0 or epoch == 1:`.
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f'epoch{epoch}_model.pth'))
        tsne(epoch)

    # Stitch the per-epoch t-SNE frames into an animated GIF.
    create_gif()

In [None]:
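# Legacy one-off t-SNE plot, kept for reference only; tsne(epoch) now produces
# the per-epoch plots. To revive it, first obtain test_predicted_metrics and
# test_true_labels from test(...).
# tSNE_metrics = TSNE(n_components=2, random_state=0).fit_transform(test_predicted_metrics)

# plt.scatter(tSNE_metrics[:, 0], tSNE_metrics[:, 1], c=test_true_labels)
# plt.colorbar()
# plt.savefig("output.png")
# plt.show()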