In [2]:
from time import sleep
import pickle
import os
from threading import Thread
import queue
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity

from src.utils import load_model, get_dataloaders, load_images_in_folder, show_images, modify_keys, save_results
from src.train import train_model

In [3]:
# --- Optimisation hyper-parameters -------------------------------------------
lr = 1e-4
batch_size = 32
num_epochs = 30

# --- Dataset locations --------------------------------------------------------
TRAIN_PATH = '/mnt/hdd/1/imageData/train/russianDataCleanAdded'
TEST_PATH = '/mnt/hdd/1/imageData/index/russianDataCleanAdded'
# Every entry found directly under TRAIN_PATH is counted as one class.
num_classes = len(os.listdir(TRAIN_PATH))

# --- Names handed to train_model / load_model and used to reload the state ----
state_path = 'state_resnet.pkl'
model_name = 'landmark'

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [4]:
# Two bounded queues, each holding at most 10 items.
train_q, test_q = queue.Queue(maxsize=10), queue.Queue(maxsize=10)

# get_dataloaders returns a (datasets, dataloaders) pair; later cells index
# both objects with the keys 'train' and 'test'.
datasets, dataloaders = get_dataloaders(TRAIN_PATH, TEST_PATH, batch_size)

In [5]:
class ResnetClassifier(nn.Module):
    """ResNet-50 backbone with a 512-d embedding head and a linear classifier.

    The backbone's ``fc`` layer is replaced by ``Linear -> ReLU -> Dropout(0.4)``
    producing a 512-dimensional feature vector; ``self.last`` maps that vector
    to ``num_classes`` logits (``num_classes`` is read from the notebook's
    module scope when the model is constructed).

    Calling the model returns the pair ``(features, logits)``.

    The evaluation helpers below run under ``torch.no_grad()`` but do not
    change the train/eval mode; call ``model.eval()`` first so dropout is off.
    """

    def __init__(self):
        super(ResnetClassifier, self).__init__()
        self.model = torchvision.models.resnet50(pretrained=True)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(nn.Linear(in_features, 512),
                                      nn.ReLU(),
                                      nn.Dropout(0.4),
                                      )
        self.last = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return ``(features, logits)`` for a batch of inputs ``x``.

        Implemented as ``forward`` (not ``__call__``) so that ``model(x)`` goes
        through ``nn.Module.__call__`` and registered hooks keep working.
        """
        features = self.model(x)
        return features, self.last(features)

    def _device(self):
        """Return the device the model's parameters currently live on."""
        return next(self.parameters()).device

    def _collect_predictions(self, dataloader, progress=False):
        """Run the model over ``dataloader`` and gather labels and predictions.

        Args:
            dataloader: iterable yielding ``(inputs, labels)`` batches.
            progress: wrap the loader in ``tqdm`` when true.

        Returns:
            ``(labels, predictions)`` as two 1-D CPU tensors of equal length;
            predictions are the argmax of the logits for each sample.
        """
        target_device = self._device()
        batches = tqdm(dataloader) if progress else dataloader
        labels, predictions = [], []
        with torch.no_grad():
            for x, y in batches:
                # The model returns (features, logits); only logits are needed.
                _, logits = self(x.to(target_device))
                predictions.append(torch.argmax(logits, dim=1).cpu())
                labels.append(torch.as_tensor(y).cpu())
        return torch.cat(labels), torch.cat(predictions)

    def check_predictions(self, dataloader):
        """Compute overall accuracy and per-class hit counts.

        Returns:
            ``(accuracy, correct)`` where ``accuracy`` comes from
            ``sklearn.metrics.accuracy_score`` and ``correct`` maps each label
            index seen in the data to ``np.array([n_correct, n_total])``.
        """
        labels, predictions = self._collect_predictions(dataloader, progress=True)

        correct = {}
        for label, predicted in zip(labels.tolist(), predictions.tolist()):
            stats = correct.get(label, np.array([0, 0]))
            correct[label] = stats + np.array([int(label == predicted), 1])
        return accuracy_score(labels.numpy(), predictions.numpy()), correct

    def confusion_matrix(self, dataloader):
        """Return the confusion matrix of the model over ``dataloader``.

        Inside this method the bare name ``confusion_matrix`` resolves to the
        module-level ``sklearn.metrics.confusion_matrix``, not to this method.
        """
        labels, predictions = self._collect_predictions(dataloader)
        return confusion_matrix(labels.numpy(), predictions.numpy())

    def predictions_for_class(self, x):
        """Return the softmax probabilities of ``x`` sorted ascending per row.

        Returns:
            The result of ``torch.sort(..., dim=1)``: ``(values, indices)``, so
            the most probable class of each sample is in the last column.
        """
        with torch.no_grad():
            _, logits = self(x.to(self._device()))
            return torch.sort(torch.softmax(logits.cpu(), dim=1), dim=1)
    

In [6]:
model = ResnetClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# StepLR with step_size=1: the learning rate is multiplied by 0.9 on every scheduler step.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

# Train from scratch (continue_train=False); train_model returns the two loss histories.
train_loss, val_loss = train_model(
    dataloaders, device, model, criterion, optimizer, state_path, model_name,
    num_epochs=num_epochs,
    continue_train=False,
    scheduler=scheduler,
)

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

after 50 loss is 4.393304347991943
after 100 loss is 4.480538845062256
after 150 loss is 3.8888838291168213
after 200 loss is 4.0484619140625


KeyboardInterrupt: 

In [None]:
# Reload the pickled training state and pull both loss histories out of it.
with open(state_path, 'rb') as state_file:
    state = pickle.load(state_file)
train_loss = state['loss']
val_loss = state['val_losses']

# Draw each history against its own 0..N-1 index and label the curves.
plt.plot(range(len(train_loss)), train_loss, label='train')
plt.plot(range(len(val_loss)), val_loss, label='validation')
plt.legend();

In [None]:
# Display the 'accuracy' entry of the reloaded training state.
# NOTE(review): presumably an accuracy history written by train_model — confirm.
state['accuracy']

In [None]:
# Build a fresh classifier, hand it to load_model together with model_name and
# the literal 25, and switch the returned model to evaluation mode.
# NOTE(review): 25 is presumably the epoch/checkpoint index to restore — confirm
# against src.utils.load_model.
model2 = ResnetClassifier().to(device)
model2 = load_model(model2, model_name, 25)
model2.eval();
# acc, correct = model2.check_predictions(dataloaders['test'])
# correct = modify_keys(correct, datasets['train'])
# print(acc)

In [None]:
# clazz = 126432
# prefix = '/mnt/hdd/1/imageData/index/russianData/' + str(clazz)
# images_np, images_torch = load_images_in_folder(prefix)

# ys = model2.predictions_for_class(images_torch)

# class_ind = datasets['train'].classes.index(str(clazz))
# show_images(images_np, titles=ys[0][:, -3:].numpy().tolist(), correctness=ys[1][:, -1] == class_ind, cols=5)

In [None]:
def calc_centroids(model, loader, path='models/centers_classic.pkl'):
    """Compute, or load from cache, the mean feature vector of each class.

    If ``path`` already exists its pickled content is returned unchanged and
    neither ``model`` nor ``loader`` is used. Otherwise the model is run over
    ``loader``, the 512-d features are averaged per label, the result is
    pickled to ``path`` and returned.

    Args:
        model: callable returning ``(features, logits)`` for a batch; only the
            features are used.
        loader: iterable yielding ``(inputs, labels)`` batches.
        path: cache file location.

    Returns:
        ``numpy.ndarray`` of shape ``(num_classes - 1, 512)``; rows of classes
        that never occur in ``loader`` are all zeros.
    """
    if os.path.exists(path):
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # `path` at cache files written by this notebook.
        with open(path, 'rb') as f:
            return pickle.load(f)

    centers = np.zeros((num_classes, 512))
    cnt = np.zeros(num_classes)
    # no_grad: this is pure inference, so no autograd graph is needed.
    with torch.no_grad():
        for x, y in tqdm(loader):
            feat, _ = model(x.to(device))
            labels = torch.as_tensor(y).cpu().numpy()
            # np.add.at accumulates over repeated labels; `centers[labels] += ...`
            # is buffered and keeps only one sample per class per batch.
            np.add.at(centers, labels, feat.detach().cpu().numpy())
            np.add.at(cnt, labels, 1)

    # The small epsilon keeps classes with zero samples from dividing by zero.
    centers /= (cnt + 0.0000001)[:, None]
    # NOTE(review): the centroid of the last class index is dropped here; the
    # reason is not visible in this notebook — confirm it is intended.
    centers = centers[:-1]

    # Make sure the cache directory exists so the finished work is not lost.
    cache_dir = os.path.dirname(path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(centers, f)

    return centers


# Class centroids obtained by running model2 over the training loader.
# NOTE: calc_centroids returns the pickle at its `path` argument unchanged when
# that file already exists, so remove the cache file after changing model or data.
centers = calc_centroids(model2, dataloaders['train'])

In [None]:
def check_classes(a, b):
    """Unfinished stub: looks up the training dataset's ``classes`` and returns None.

    The arguments ``a`` and ``b`` are accepted but not used, and the looked-up
    value is discarded.
    """
    _unused_class_names = datasets['train'].classes

def centroid_test(loader):
    """Count samples whose most cosine-similar centroid index equals their label.

    Uses the module-level ``model2`` to extract features and the module-level
    ``centers`` array as the centroids.

    Args:
        loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        int: number of samples in ``loader`` classified correctly.
    """
    correct = 0
    # no_grad: evaluation only, so no autograd graph is needed.
    with torch.no_grad():
        for x, y in tqdm(loader):
            features, _ = model2(x.to(device))
            # One call per batch: rows are samples, columns are centroids.
            similarity = cosine_similarity(features.detach().cpu().numpy(), centers)
            nearest = similarity.argmax(axis=1)
            labels = torch.as_tensor(y).cpu().numpy()
            correct += int((nearest == labels).sum())

    return correct

# Number of test samples whose most cosine-similar centroid index equals their label.
res = centroid_test(dataloaders['test'])

In [None]:
# Fraction of the test dataset that centroid_test counted as correct.
res / len(datasets['test'])