In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils, models
from torch.nn import functional as F
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
from PIL import Image
import torch.utils.model_zoo as model_zoo
import pandas as pd

In [2]:
# Dataset locations (relative to the notebook's working directory).
root_dir = 'tiny-imagenet-200/train/'
omniglot_dir = 'omniglot/'
rtsd_dir = 'rtsd-r1/'

batch_size = 512

# Seed NumPy (class split, pair/episode sampling) AND torch (weight initialisation);
# seeding only NumPy left the model's initial weights different on every run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# TensorBoard writer used by the training loop for loss / accuracy curves.
writer = SummaryWriter()

In [3]:
# Pick the dataset and hold out 10% of its classes for one-shot evaluation.
# Defines `classes` (training class names), `test_classes` (held-out names, numpy array)
# and `fin_dir(cl)` (folder that holds the images of class `cl`).
DATASET = 'omniglot'  # one of 'tiny-imagenet', 'omniglot', 'rtsd'


def _class_dirs(path):
    """Return the sorted names of the non-empty sub-directories of `path`.

    Sorting makes the seeded split independent of the filesystem's listing order.
    Plain files and empty folders are skipped, because an empty class makes the
    random image draws in the datasets fail.
    """
    names = []
    for name in sorted(os.listdir(path)):
        full = os.path.join(path, name)
        if os.path.isdir(full) and os.listdir(full):
            names.append(name)
    return names


if DATASET == 'tiny-imagenet':
    classes = _class_dirs(root_dir)

    def fin_dir(cl):
        return root_dir + cl + '/images/'
elif DATASET == 'omniglot':
    # Omniglot is nested alphabet/character, so a class is "<alphabet>/<character>".
    classes = [alphabet + '/' + char
               for alphabet in _class_dirs(omniglot_dir)
               for char in _class_dirs(omniglot_dir + alphabet)]

    def fin_dir(cl):
        return omniglot_dir + cl + '/'
elif DATASET == 'rtsd':
    classes = _class_dirs(rtsd_dir)

    def fin_dir(cl):
        return rtsd_dir + cl + '/'
else:
    raise ValueError('unknown DATASET: %r' % DATASET)

test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
held_out = set(test_classes)  # set membership instead of scanning the array per class
classes = [cl for cl in classes if cl not in held_out]

In [19]:
# Group the RTSD label table by class. `a[k]` lists the row labels of `train_labels`
# whose class_number is the k-th class (ascending), each wrapped in a one-element list
# because the next cell reads `a[i][j][0]`.
# A single groupby replaces the per-row `to_dict()` calls, which made the old loop O(C * N^2).
# NOTE(review): `train_labels` is never loaded in this notebook (the saved output of this cell
# was a NameError). Presumably it is a DataFrame read from the RTSD ground-truth CSV with
# 'filename' and 'class_number' columns. TODO: load it explicitly above this cell.
a = [[[row_label] for row_label in group.index]
     for _, group in train_labels.groupby('class_number', sort = True)]

NameError: name 'train_labels' is not defined

In [36]:
# Filename of the row labelled 1. `.at` replaces DataFrame.get_value (deprecated, removed in pandas 1.0).
train_labels.at[1, 'filename']

  """Entry point for launching an IPython kernel.


'000001.png'

In [41]:
# Replace every one-element [row_label] entry of `a` with that row's filename.
# Entries that are already filenames are kept, so re-running this cell is harmless.
# `.at` replaces DataFrame.get_value, which was deprecated and then removed in pandas 1.0.
a = [[train_labels.at[row[0], 'filename'] if isinstance(row, list) else row for row in group]
     for group in a]

  


In [46]:
# Move every RTSD training image from rtsd_dir/train/ into a per-class folder
# rtsd_dir/<k>/, so RTSD can be read like the other folder-per-class datasets.
# NOTE: folder names are the group's position k in `a`, not the original class_number.
for class_idx, filenames in enumerate(a):
    class_dir = os.path.join(rtsd_dir, str(class_idx))
    os.makedirs(class_dir, exist_ok = True)  # os.rename does not create missing folders
    for filename in filenames:
        src = os.path.join(rtsd_dir, 'train', filename)
        dst = os.path.join(class_dir, filename)
        if os.path.exists(dst) and not os.path.exists(src):
            continue  # already moved by an earlier run of this cell
        os.rename(src, dst)

In [4]:
class SiameseNetwork(nn.Module):
    """Siamese CNN that emits one logit per image pair.

    Both images pass through the same CNN encoder. The head sees |f1 - f2|,
    (f1 - f2) ** 2 and cos(f1, f2). `forward_one_shot` picks the support image with
    the LOWEST logit as the match.

    Args:
        classes_n: number of training classes. Unused; kept for call compatibility.
    """

    # Encoder output width: 128 channels x 1 x 1 after four 2x max-pools of a 28x28 input
    # (28 -> 14 -> 7 -> 3 -> 1). Other input sizes change the flattened width and break the head.
    FEATURES = 128

    def __init__(self, classes_n):
        super(SiameseNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2))

        # Input = |diff| (128) + diff**2 (128) + cosine similarity (1) = 257.
        self.fc = nn.Sequential(
            nn.Linear(2 * self.FEATURES + 1, 128),
            nn.Sigmoid(),
            nn.Linear(128, 1))

        self.dist = nn.CosineSimilarity()

    def forward_once(self, x):
        """Encode a (B, 3, H, W) batch into flattened (B, features) vectors."""
        output = self.cnn(x)
        output = output.view(output.size()[0], -1)
        return output

    def forward(self, input1, input2):
        """Return (B, 1) logits for the pairs (input1[b], input2[b])."""
        out1 = self.forward_once(input1)
        out2 = self.forward_once(input2)
        out = torch.cat([torch.abs(out1 - out2), (out1 - out2) ** 2, self.dist(out1, out2).view(-1, 1)], dim = 1)
        out = self.fc(out)
        return out

    def forward_one_shot(self, x, support_set):
        """Return the index of the support image that best matches query `x`.

        Args:
            x: one query image, shaped (3, H, W) or (1, 3, H, W).
            support_set: (K, 3, H, W) tensor, or a sequence of K (3, H, W) tensors.

        Returns:
            int index in [0, K) of the support image with the lowest pair logit.
        """
        # The old version looped over single support images and passed unbatched (C, H, W)
        # tensors into Conv2d/BatchNorm2d, which fails. Score all pairs as one batch instead.
        if not torch.is_tensor(support_set):
            support_set = torch.stack(list(support_set))
        if x.dim() == 3:
            x = x.unsqueeze(0)
        queries = x.expand(support_set.size(0), *x.size()[1:])
        logits = self.forward(queries, support_set).detach().view(-1)
        return int(np.argmin(logits.cpu().numpy()))

In [5]:
class ImageDataset(Dataset):
    """Image-pair dataset for Siamese training over per-class image lists.

    Index `idx` enumerates every image of `images` in class order: all of class 0
    first, then class 1, and so on. Each item pairs that image with a random partner.
    With probability 0.5 the partner is a random image of a *different* class
    (`res = 1`); otherwise it is a random image of the same class (`res = 0`).

    Args:
        images: one entry per class; each entry is a list of numpy arrays.
    """

    def __init__(self, images):
        self.images = images
        self.n = len(images)
        # offsets[k] = flat index of the first image of class k; offsets[-1] = total count.
        self.offsets = np.cumsum([0] + [len(imgs) for imgs in images])
        self.size = int(self.offsets[-1])

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Derive the anchor from `idx` instead of a mutable cursor, so shuffling and
        # multi-worker DataLoaders (each with its own dataset copy) see the right images.
        idx = int(idx)
        if idx < 0:
            idx += self.size
        if not 0 <= idx < self.size:
            raise IndexError('index %d out of range for %d images' % (idx, self.size))
        # side='right' skips empty classes (they share an offset with their successor).
        class1 = int(np.searchsorted(self.offsets, idx, side = 'right')) - 1
        img1 = self.images[class1][idx - self.offsets[class1]]

        if self.n > 1 and np.random.uniform() < 0.5:
            # Draw from the OTHER n - 1 classes. Drawing from all n labelled 1/n of the
            # "different" pairs as different even though both images share a class.
            class2 = np.random.randint(0, self.n - 1)
            if class2 >= class1:
                class2 += 1
            res = 1
        else:
            class2 = class1
            res = 0
        img2 = self.images[class2][np.random.randint(0, high = len(self.images[class2]))]

        sample = {'img1' : torch.from_numpy(img1), 'img2' : torch.from_numpy(img2), 'res' : res}
        return sample

In [6]:
class TestDataset(Dataset):
    """N-way one-shot evaluation episodes over held-out classes.

    For every class, one randomly chosen image becomes that class's support
    (reference) image; all other images of the class are queries. Each item is one
    query plus a support set of `n_way` DISTINCT classes (all classes if there are
    fewer), one of which is the query's own class.

    Args:
        classes: class names; images are read from `fin_dir(cl)`.
        n_way: number of classes per support set (default 5).

    Item keys: 'img1' (query tensor), 'support_set' (stacked support tensors),
    'res' (position of the query's class inside 'support_set').
    """

    def __init__(self, classes, n_way = 5):
        self.images = []
        self.support_set = []
        self.n = len(classes)
        self.n_way = n_way
        self.cursor = 0  # flat query index used when __getitem__ is called without idx
        self.size = 0
        for cl in classes:
            folder = fin_dir(cl)
            names = sorted(os.listdir(folder))  # sorted: seeded choice independent of FS order
            sup = np.random.choice(names)
            self.support_set.append(torch.Tensor(self._load(folder + sup)))
            queries = [self._load(folder + name) for name in names if name != sup]
            self.images.append(queries)
            self.size += len(queries)
        self.support_set = torch.stack(self.support_set)

    @staticmethod
    def _load(path):
        """Read an image as a float32 (C, 28, 28) array scaled by 1/255.

        2-D (single-channel) images are replicated to 3 channels.
        """
        # `with` closes the file handle as soon as the pixels are read.
        # NOTE(review): for 1-bit PNGs (PIL mode '1'), np.array gives 0/1 values, so /255
        # leaves them in [0, 1/255]. Kept to match the training-image loader; TODO confirm.
        with Image.open(path) as img:
            arr = np.array(img.resize((28, 28)), dtype = 'float32') / 255
        if arr.ndim != 3:
            arr = np.stack([arr, arr, arr])
        else:
            arr = arr.transpose((2, 0, 1))
        return arr

    def _locate(self, idx):
        """Map a flat query index to (class position, image position), skipping empty classes."""
        if not 0 <= idx < self.size:
            raise IndexError('query index %d out of range for %d queries' % (idx, self.size))
        for cl, imgs in enumerate(self.images):
            if idx < len(imgs):
                return cl, idx
            idx -= len(imgs)
        raise IndexError('query index out of range')

    def __len__(self):
        return self.size

    def __getitem__(self, idx = None):
        """Return one episode.

        With idx=None (how the training loop calls it), queries are visited in order by
        an internal cursor that wraps after len(self) calls. With an integer idx, the
        query is the idx-th image in class order, so DataLoader indexing also works.
        """
        if idx is None:
            cl, pos = self._locate(self.cursor)
            self.cursor = (self.cursor + 1) % self.size
        else:
            cl, pos = self._locate(int(idx))
        img1 = self.images[cl][pos]

        # Distinct distractor classes, never the query's own class. Sampling with replacement
        # used to repeat classes, so episodes were not really n_way-way.
        others = [c for c in range(self.n) if c != cl]
        k = min(self.n_way - 1, len(others))
        ind = [int(c) for c in np.random.choice(others, size = k, replace = False)] if k > 0 else []
        res = np.random.randint(0, k + 1)
        ind.insert(res, cl)

        sup = torch.stack([self.support_set[c] for c in ind])
        sample = {'img1' : torch.from_numpy(img1.astype('float32')), 'res' : res, 'support_set': sup}
        return sample

In [7]:
# Load every training image as a float32 (C, 28, 28) array scaled by 1/255; 2-D
# (single-channel) images are replicated to 3 channels. images[k] holds the arrays of classes[k].
# NOTE(review): for 1-bit PNGs (PIL mode '1'), np.array gives 0/1 values, so /255 leaves them
# in [0, 1/255]. Changing this requires the same change in TestDataset; TODO confirm.
images = []
for cl in tqdm(classes):
    folder = fin_dir(cl)
    class_images = []
    for img_name in sorted(os.listdir(folder)):  # sorted: reproducible image order
        # `with` closes the file handle right after decoding.
        with Image.open(folder + img_name) as pil_img:
            arr = np.array(pil_img.resize((28, 28)), dtype = 'float32') / 255
        if arr.ndim != 3:
            arr = np.stack([arr, arr, arr])
        else:
            arr = arr.transpose((2, 0, 1))
        class_images.append(arr)
    images.append(class_images)

100%|██████████| 868/868 [00:02<00:00, 322.53it/s]


In [8]:
# Build the Siamese model on the GPU. nn.Module.cuda() moves the parameters in place and
# returns the module itself, so the call can be chained.
model = SiameseNetwork(len(classes)).cuda()
# Weight-decay coefficient handed to the optimizer (0.0 disables decay).
coef = 0.0

In [9]:
# Training pairs come from the training-class images; one-shot episodes from the held-out classes.
dataset = ImageDataset(images)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
one_shot_dataset = TestDataset(test_classes)

In [10]:
# Train for N_EPOCHS epochs. After each epoch, log to TensorBoard the mean training loss,
# the mean training accuracy and the one-shot accuracy on the held-out classes.
N_EPOCHS = 800
sgd = torch.optim.Adam(model.parameters(), weight_decay = coef)  # NOTE: Adam, despite the name
criterion = nn.BCEWithLogitsLoss()
sigm = nn.Sigmoid()
for epoch in trange(N_EPOCHS):
    model.train()
    sum_loss = 0.0
    acc = 0.0
    for batch in dataloader:
        target = batch['res'].float().cuda()
        logits = model(batch['img1'].cuda(), batch['img2'].cuda()).view(-1)
        # .item() keeps plain Python floats. Summing the tensors themselves would keep every
        # batch's autograd history alive for the whole epoch.
        acc += (torch.round(sigm(logits)) == target).float().mean().item()
        loss = criterion(logits, target)
        sum_loss += loss.item()
        sgd.zero_grad()
        loss.backward()
        sgd.step()
    writer.add_scalar('cross entropy', sum_loss / len(dataloader), epoch)
    writer.add_scalar('train accuracy', acc / len(dataloader), epoch)

    model.eval()
    correct = 0
    with torch.no_grad():  # evaluation only: do not build autograd graphs
        for _ in range(len(one_shot_dataset)):
            sample = one_shot_dataset.__getitem__()
            logits = model(sample['img1'].unsqueeze(0).cuda(), sample['support_set'].cuda())
            # The lowest logit is taken as the match.
            if np.argmin(logits.cpu().numpy()) == sample['res']:
                correct += 1
    writer.add_scalar('one shot accuracy', correct / len(one_shot_dataset), epoch)

100%|██████████| 800/800 [1:04:06<00:00,  4.81s/it]


In [11]:
# Save the weights as a state_dict, the form PyTorch recommends: unlike pickling the whole
# module, it does not depend on this notebook's class source (see the warning pickling emits).
# The full-module file is still written for existing code that loads 'SiameseNet' directly.
torch.save(model.state_dict(), 'SiameseNet_state_dict.pt')
torch.save(model, 'SiameseNet')

  "type " + obj.__name__ + ". It won't be checked "
