In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils, models
from torch.nn import functional as F
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
from PIL import Image
import torch.utils.model_zoo as model_zoo
import pandas as pd
import bcolz
import pickle

In [2]:
root_dir = 'tiny-imagenet-200/train/'
writer = SummaryWriter()
batch_size = 512
np.random.seed(42)
omniglot_dir = 'omniglot/'
rtsd_dir = 'rtsd-r1/'

In [3]:
if True:
    classes = os.listdir(root_dir)
    test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
    classes = [i for i in classes if i not in test_classes]
    def fin_dir(cl):
        return root_dir + cl + '/images/'
elif False:
    classes = []
    for i in os.listdir(omniglot_dir):
        for j in os.listdir(omniglot_dir + i):
            classes.append(i + '/' + j)
    test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
    classes = [i for i in classes if i not in test_classes]
    def fin_dir(cl):
        return omniglot_dir + cl + '/'
else:
    classes = os.listdir(rtsd_dir)
    test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
    classes = [i for i in classes if i not in test_classes]
    def fin_dir(cl):
        return rtsd_dir + cl + '/'

In [4]:
with open('tiny-imagenet-200/words.txt') as file:
    a = file.readlines()
a = [i[:-1].split('\t') for i in a]
clss = [i[0] for i in a]
a = [i[1].replace(',', '').split() for i in a]

In [5]:
words = []
idx = 0
word2idx = {}
vectors = bcolz.carray(np.zeros(1), rootdir=f'glove.6B/6B.50.dat', mode='w')

with open(f'glove.6B/glove.6B.50d.txt', 'rb') as f:
    for l in f:
        line = l.decode().split()
        word = line[0]
        words.append(word)
        word2idx[word] = idx
        idx += 1
        vect = np.array(line[1:]).astype(np.float)
        vectors.append(vect)
    
vectors = bcolz.carray(vectors[1:].reshape((400001, 50)), rootdir=f'glove.6B/6B.50.dat', mode='w')
vectors.flush()
pickle.dump(words, open(f'glove.6B/6B.50_words.pkl', 'wb'))
pickle.dump(word2idx, open(f'glove.6B/6B.50_idx.pkl', 'wb'))

In [6]:
vectors = bcolz.open(f'glove.6B/6B.50.dat')[:]
words = pickle.load(open(f'glove.6B/6B.50_words.pkl', 'rb'))
word2idx = pickle.load(open(f'glove.6B/6B.50_idx.pkl', 'rb'))

glove = {w: vectors[word2idx[w]] for w in words}

In [7]:
a = [[glove[j] for j in i if j in glove] for i in a]
a = [np.mean(i, axis = 0) for i in a]
names = {}
for i in range(len(clss)):
    names[clss[i]] = a[i]

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


In [8]:
classes = os.listdir(root_dir)
classes = [i for i in classes if not type(names[i]) == np.float64]
test_classes = np.random.choice(classes, size = 20, replace = False)
classes = [i for i in classes if i not in test_classes]
att = 1
def fin_dir(cl):
    return root_dir + cl + '/images/'

In [9]:
class SiameseNetwork(nn.Module):
    def __init__(self, classes_n):
        super(SiameseNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2))

        self.fc = nn.Sequential(
            nn.Linear(257, 128),
            nn.Sigmoid(),
            nn.Linear(128, 1))
        
        self.dist = nn.CosineSimilarity()
        
    def forward_once(self, x):
        output = self.cnn(x)
        output = output.view(output.size()[0], -1)
        return output
    
    def forward(self, input1, input2):
        out1 = self.forward_once(input1)
        out2 = self.forward_once(input2)
        out = torch.cat([torch.abs(out1 - out2), (out1 - out2) ** 2, self.dist(out1, out2).view(-1, 1)], dim = 1)
        out = self.fc(out)
        return out
    
    def forward_one_shot(self, x, support_set):
        preds = []
        for i in range(5):
            preds.append(self.forward(x, support_set[:, i]).data.cpu().numpy()[0])
        res = np.argmin(np.array(preds))
        return res

In [10]:
class ImageDataset(Dataset):
    def __init__(self, images):
        self.images = images
        self.n = len(images)
        self.cur_class = 0
        self.cur_img = 0
        self.size = 0
        for i in images:
            self.size += len(i)
        
    def __len__(self):
        return self.size
        
    def __getitem__(self, idx):
        class1 = self.cur_class
        if np.random.uniform() < 0.5:
            class2 = np.random.randint(0, self.n)
            while class1 == class2:
                class2 = np.random.randint(0, self.n)
            res = 1
        else:
            class2 = class1
            res = 0
        
        img1 = self.images[class1][self.cur_img]
        img2 = self.images[class2][np.random.randint(0, high = len(self.images[class2]))]
        self.cur_img += 1
        if self.cur_img >= len(self.images[self.cur_class]):
            self.cur_img = 0
            self.cur_class = (self.cur_class + 1) % self.n
        
        sample = {'img1' : torch.from_numpy(img1), 'img2' : torch.from_numpy(img2), 'res' : res}
        return sample

In [11]:
class TestDataset(Dataset):
    def __init__(self, images):
        self.images = images
        self.n = len(images)
        self.cur_class = 0
        self.cur_img = 0
        self.size = 0
        for i in images:
            self.size += len(i)
        
    def __len__(self):
        return self.size
        
    def __getitem__(self, idx):
        img = self.images[self.cur_class][self.cur_img]
        ind = np.random.randint(0, self.n, size = 5)
        if not self.cur_class in ind:
            res = np.random.randint(0, 5)
            ind[res] = self.cur_class
        else:
            res = np.where(ind == self.cur_class)[0][0]
        sup = np.stack([self.images[a][np.random.randint(0, len(self.images[a]))] for a in ind])
        self.cur_img += 1
        if self.cur_img >= len(self.images[self.cur_class]):
            self.cur_img = 0
            self.cur_class = (self.cur_class + 1) % self.n
        sample = {'img' : torch.from_numpy(img), 'support_set' : torch.from_numpy(sup), 'res' : int(res)}
        return sample

In [12]:
images = []
for cl in tqdm(classes):
    images.append([])
    for img_name in os.listdir(fin_dir(cl)):
        img = np.array(Image.open(fin_dir(cl) + img_name).resize((28, 28)), dtype = 'float32') / 255
        if not len(img.shape) == 3:
            img = np.stack([img, img, img])
        else:
            img = img.transpose((2, 0, 1))
        images[-1].append(img)

100%|██████████| 177/177 [00:18<00:00,  9.33it/s]


In [13]:
one_shot_images = []
for cl in tqdm(test_classes):
    one_shot_images.append([])
    for img_name in os.listdir(fin_dir(cl)):
        img = np.array(Image.open(fin_dir(cl) + img_name).resize((28, 28)), dtype = 'float32') / 255
        if not len(img.shape) == 3:
            img = np.stack([img, img, img])
        else:
            img = img.transpose((2, 0, 1))
        one_shot_images[-1].append(img)

100%|██████████| 20/20 [00:02<00:00,  9.26it/s]


In [14]:
model = SiameseNetwork(len(classes))
model.cuda()
coef = 0.0

In [15]:
dataset = ImageDataset(images)
dataloader = DataLoader(dataset, batch_size = batch_size)
one_shot_dataset = TestDataset(one_shot_images)
one_shot_dataloader = DataLoader(one_shot_dataset, batch_size = batch_size)

In [18]:
sgd = torch.optim.Adam(model.parameters(), weight_decay = coef)
dist = nn.CosineSimilarity()
criterion = nn.BCEWithLogitsLoss()
sigm = nn.Sigmoid()
for i in trange(800):
    model.train()
    sum_loss = 0
    acc = 0
    for j, batch in enumerate(dataloader):
        target = Variable(batch['res']).float().cuda()
        y_pred = model.forward(Variable(batch['img1']).cuda(), Variable(batch['img2']).cuda())
        y_pred = y_pred.view(-1)
        x = torch.round(sigm(y_pred))
        acc += torch.mean((x == target).float())
        loss = criterion(y_pred, target)
        sum_loss += loss
        sgd.zero_grad()
        loss.backward()
        sgd.step()
    writer.add_scalar('cross entropy', sum_loss / len(dataloader), i)
    writer.add_scalar('train accuracy', acc / len(dataloader), i)
    model.eval()
    correct = 0
    for j in range(len(one_shot_dataset)):
        sample = one_shot_dataset.__getitem__(0)
        target = sample['res']
        y_pred = model.forward(Variable(sample['img']).unsqueeze(0).cuda(), Variable(sample['support_set']).cuda())
        pred = np.argmin(y_pred.data.cpu().numpy())
        if target == pred:
            correct += 1
    writer.add_scalar('one shot accuracy', correct / len(one_shot_dataset), i)

 36%|███▋      | 291/800 [2:05:55<3:40:15, 25.96s/it]

KeyboardInterrupt: 

In [None]:
torch.save(model, 'SiameseNet')