In [1]:
import inspect
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from skimage import io, transform
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from tqdm import trange

In [2]:
root_dir = 'tiny-imagenet-200/train/'
writer = SummaryWriter()
batch_size = 128
# Seed both RNGs: numpy drives the class/image splits and pair sampling,
# torch drives layer initialisation (default nn init and weights_init).
np.random.seed(42)
torch.manual_seed(42)

In [3]:
def weights_init(m):
    """Re-initialise one module; intended to be passed to ``Module.apply``.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Otherwise, modules whose class name contains 'BatchNorm' get weights drawn
    from N(1, 0.02) and a zero bias. Every other module (and every conv bias)
    is left unchanged.
    """
    kind = type(m).__name__
    # In-place init of leaf parameters must not be recorded by autograd.
    with torch.no_grad():
        if 'Conv' in kind:
            m.weight.normal_(0.0, 0.02)
        elif 'BatchNorm' in kind:
            m.weight.normal_(1.0, 0.02)
            m.bias.fill_(0)

In [4]:
class SiameseNetwork(nn.Module):
    """CNN encoder shared by a classification head and a siamese pair head.

    ``forward_once`` maps an image batch to 512-d embeddings (``cnn`` then
    ``fc1``). ``forward`` adds ``fc3`` to give ``classes_n`` class logits.
    ``forward_siamese`` feeds ``|emb1 - emb2|`` through ``fc2`` to give one
    logit per pair. ``forward_one_shot`` picks the support image whose pair
    logit is lowest.

    The flatten size ``256 * 8 * 8`` after three ``MaxPool2d(2)`` layers (all
    convs are size-preserving) means inputs must be 3 x 64 x 64.
    """

    def __init__(self, classes_n):
        super(SiameseNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU())

        # Embedding head: flattened conv features -> 512-d embedding.
        self.fc1 = nn.Sequential(
            nn.Linear(256 * 8 * 8, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 512))

        # Pair head: |emb1 - emb2| -> single logit.
        self.fc2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1))

        # Classification head: embedding -> class logits.
        self.fc3 = nn.Sequential(
            nn.Linear(512, classes_n))

    def forward_once(self, x):
        """Embed an image batch into (batch, 512) vectors."""
        output = self.cnn(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, x):
        """Return (batch, classes_n) class logits for an image batch."""
        out = self.forward_once(x)
        out = self.fc3(out)
        return out

    def forward_siamese(self, input1, input2):
        """Return (batch, 1) pair logits for two aligned image batches."""
        out1 = self.forward_once(input1)
        out2 = self.forward_once(input2)
        out = torch.abs(out1 - out2)
        out = self.fc2(out)
        return out

    def forward_one_shot(self, x, support_set):
        """Return the index of the support image judged closest to ``x``.

        Args:
            x: query image batch; only the first row of each pair score is
                used, so pass a single image (batch of 1).
            support_set: iterable of image batches, one per candidate class.

        Returns:
            numpy integer index of the support entry with the lowest ``fc2``
            logit. Intended for eval mode: the query embedding is computed once
            and reused for every support image.
        """
        # Pure inference: avoid building an autograd graph for every comparison.
        with torch.no_grad():
            # The query embedding does not depend on the support image, so compute it once.
            query = self.forward_once(x)
            preds = []
            for img in support_set:
                logit = self.fc2(torch.abs(query - self.forward_once(img)))
                preds.append(logit.data.cpu().numpy()[0])
        return np.argmin(np.array(preds))

In [5]:
class ImageDataset(Dataset):
    """Pair dataset for training the siamese head.

    ``images`` is a list with one entry per class, each a list of image arrays.
    Index ``idx`` enumerates anchor images class by class: all images of
    class 0, then class 1, and so on. For each anchor, with probability 0.5 the
    partner is a random image from a *different* class (``res = 1``); otherwise
    it is a random image from the anchor's own class (``res = 0``, possibly the
    anchor itself). With a single class every pair is a same-class pair.
    """

    def __init__(self, images, n_classes):
        self.images = images
        self.n = n_classes
        self.size = sum(len(class_imgs) for class_imgs in images)
        # _ends[c] is one past the flat index of the last image of class c.
        self._ends = np.cumsum([len(class_imgs) for class_imgs in images])

    def __len__(self):
        return self.size

    def _locate(self, idx):
        """Map a flat anchor index to ``(class index, index within that class)``."""
        if idx < 0:
            idx += self.size
        if not 0 <= idx < self.size:
            raise IndexError('ImageDataset index out of range')
        class_idx = int(np.searchsorted(self._ends, idx, side='right'))
        start = int(self._ends[class_idx - 1]) if class_idx else 0
        return class_idx, idx - start

    def __getitem__(self, idx):
        # The anchor is derived from idx rather than a mutable cursor, so the
        # dataset behaves correctly with shuffling and multi-worker loaders.
        class1, img_idx = self._locate(idx)
        if np.random.uniform() < 0.5 and self.n > 1:
            # Draw from the n-1 *other* classes so a res=1 pair never shares a class.
            class2 = np.random.randint(0, self.n - 1)
            if class2 >= class1:
                class2 += 1
            res = 1
        else:
            class2 = class1
            res = 0

        img1 = self.images[class1][img_idx]
        partners = self.images[class2]
        img2 = partners[np.random.randint(0, high = len(partners))]

        sample = {'img1' : torch.from_numpy(img1), 'img2' : torch.from_numpy(img2), 'res' : res}
        return sample

In [6]:
class CNNDataset(Dataset):
    """Classification dataset that draws a random (image, class) pair per access.

    ``images`` holds one list of image arrays per class. ``__getitem__`` ignores
    ``idx``: each call picks a class uniformly from ``range(n_classes)``, then a
    random image of that class, so classes are equally likely regardless of how
    many images they hold. ``len()`` is the total image count and only sets how
    many draws make up one DataLoader pass.
    NOTE(review): draws use the global numpy RNG; with num_workers > 0 worker
    processes may start from identical RNG state — confirm before enabling.
    """

    def __init__(self, images,  n_classes):
        self.images = images
        self.n = n_classes
        self.size = sum(len(per_class) for per_class in images)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        label = np.random.randint(0, self.n)
        pool = self.images[label]
        picked = pool[np.random.randint(0, len(pool))]
        return {'img1' : torch.from_numpy(picked), 'res' : label}

In [7]:
# Sort the listing: os.listdir order is filesystem-dependent, so without it the
# seeded np.random.choice below would hold out different classes on different machines.
classes = sorted(os.listdir(root_dir))
# Hold out 10% of the classes (sampled without replacement) for one-shot evaluation.
test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
held_out = set(test_classes)
classes = [c for c in classes if c not in held_out]

In [8]:
# Number of training classes and of held-out one-shot classes.
n, test_n = len(classes), len(test_classes)

In [9]:
# Per-class image lists filled by the loading cell: images[c] / test[c] hold the
# training / held-out arrays for classes[c].
images, test = [], []

In [10]:
for cl in classes:
    images.append([])
    test.append([])
    img_dir = root_dir + cl + '/images/'
    # Sorted so the seeded ~90/10 per-image split below is reproducible across filesystems.
    for img_name in sorted(os.listdir(img_dir)):
        # convert('RGB') replicates single-channel images to 3 channels and also
        # normalises palette/RGBA/CMYK files; `with` closes the file handle.
        with Image.open(img_dir + img_name) as pil_img:
            img = np.array(pil_img.convert('RGB'), dtype = 'float32') / 255
        img = img.transpose((2, 0, 1))  # HWC -> CHW for nn.Conv2d
        if np.random.uniform() < 0.1:
            test[-1].append(img)
        else:
            images[-1].append(img)

In [11]:
# Classification data: random (image, class) draws from the training and held-out images.
dataset = CNNDataset(images, n)
test_dataset = CNNDataset(test, n)
dataloader = DataLoader(dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [11]:
# Build, initialise (on CPU, before moving) and move the network to the GPU;
# Module.apply and Module.cuda both return the module itself.
model = SiameseNetwork(n).apply(weights_init).cuda()
coef = 0.02  # weight decay passed to the SGD optimizers

In [13]:
sgd = torch.optim.SGD(model.parameters(), lr = 0.02, momentum = 0.9, weight_decay = coef)
criterion = nn.CrossEntropyLoss()
for i in range(20):
    model.train()
    for j, batch in enumerate(dataloader):
        step = i * len(dataloader) + j
        target = Variable(batch['res']).cuda()
        # Call the module rather than .forward() so registered hooks run.
        y_pred = model(Variable(batch['img1']).cuda())
        loss = criterion(y_pred, target)
        mav = torch.mean(torch.abs(y_pred))
        # .item() logs plain floats instead of graph-attached GPU tensors.
        writer.add_scalar('cross entropy', loss.item(), step)
        writer.add_scalar('absolute value', mav.item(), step)
        sgd.zero_grad()
        loss.backward()
        sgd.step()
    model.eval()
    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for j, batch in enumerate(test_dataloader):
            target = Variable(batch['res']).cuda()
            y_pred = model(Variable(batch['img1']).cuda())
            loss = criterion(y_pred, target)
            writer.add_scalar('test cross entropy', loss.item(), i * len(test_dataloader) + j)

In [12]:
# NOTE(review): this replaces the model trained above with a whole pickled module
# from disk; 'model2' is not written by any visible cell. torch.load unpickles
# arbitrary objects, so only load files you trust.
# Recent PyTorch releases default torch.load to weights_only=True, which refuses
# whole-module pickles; opt out explicitly when the keyword exists (older
# versions without it keep the previous call unchanged).
_load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
model = torch.load('model2', **_load_kwargs)

In [13]:
# Pair data for the siamese head, built from the same per-class image lists.
dataset = ImageDataset(images, n)
test_dataset = ImageDataset(test, n)
dataloader = DataLoader(dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# Freeze the convolutional encoder; parameters outside model.cnn are untouched.
for weight in model.cnn.parameters():
    weight.requires_grad = False

In [None]:
sgd = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr = 0.02, momentum = 0.9, weight_decay = coef)
criterion = nn.BCEWithLogitsLoss()
# If the encoder has been frozen (no trainable parameters in model.cnn), keep it
# in eval mode while training the heads too: otherwise model.train() lets its
# BatchNorm layers keep updating running statistics, so the "frozen" features
# would still drift and would differ between training and evaluation.
cnn_frozen = not any(p.requires_grad for p in model.cnn.parameters())
for i in range(200):
    model.train()
    if cnn_frozen:
        model.cnn.eval()
    for j, batch in enumerate(dataloader):
        step = i * len(dataloader) + j
        target = Variable(batch['res']).float().cuda()
        y_pred = model.forward_siamese(Variable(batch['img1']).cuda(), Variable(batch['img2']).cuda())
        y_pred = y_pred.view(-1)
        loss = criterion(y_pred, target)
        mav = torch.mean(torch.abs(y_pred))
        # .item() logs plain floats instead of graph-attached GPU tensors.
        writer.add_scalar('siamese cross entropy', loss.item(), step)
        writer.add_scalar('siamese absolute value', mav.item(), step)
        sgd.zero_grad()
        loss.backward()
        sgd.step()
    model.eval()
    # Scoring held-out pairs needs no gradients.
    with torch.no_grad():
        for j, batch in enumerate(test_dataloader):
            target = Variable(batch['res']).float().cuda()
            y_pred = model.forward_siamese(Variable(batch['img1']).cuda(), Variable(batch['img2']).cuda())
            y_pred = y_pred.view(-1)
            loss = criterion(y_pred, target)
            writer.add_scalar('test siamese cross entropy', loss.item(), i * len(test_dataloader) + j)

In [None]:
# Persist the whole module (pickled, not just a state_dict) to 'model3'.
torch.save(obj=model, f='model3')

In [None]:
class TestDataset(Dataset):
    """One-shot evaluation set built from whole class directories under ``root_dir``.

    For each class in ``classes``, one image file is chosen at random as that
    class's support example. It is stored as a (1, C, H, W) CUDA tensor in
    ``support_set`` at the class's position. Every other image of the class
    becomes a query in ``images``.

    Queries can be read two ways:
    - ``__getitem__()`` with no argument walks them class by class using an
      internal cursor (classes without queries are skipped).
    - ``dataset[idx]`` reads them by flat index.
    Each sample's ``res`` is the index of the class the query came from.
    """

    def __init__(self, classes):
        self.images = []
        self.support_set = []
        self.n = len(classes)
        self.cur_class = 0
        self.cur_img = 0
        self.size = 0

        for cl in classes:
            img_dir = root_dir + cl + '/images/'
            # List once and sort so the seeded support choice is reproducible.
            names = sorted(os.listdir(img_dir))
            sup = np.random.choice(names)
            self.images.append([])
            for img_name in names:
                img = io.imread(img_dir + img_name) / 255
                if not len(img.shape) == 3:
                    img = np.stack([img, img, img])
                else:
                    img = img.transpose((2, 0, 1))
                if sup == img_name:
                    self.support_set.append(Variable(torch.Tensor(img)).unsqueeze(0).cuda())
                else:
                    self.images[-1].append(img)
                    self.size += 1

    def __len__(self):
        return self.size

    def _locate(self, idx):
        """Map a flat query index to ``(class index, index within that class)``."""
        if idx < 0:
            idx += self.size
        if not 0 <= idx < self.size:
            raise IndexError('TestDataset index out of range')
        for class_idx, queries in enumerate(self.images):
            if idx < len(queries):
                return class_idx, idx
            idx -= len(queries)

    def __getitem__(self, idx=None):
        if idx is None:
            if self.size == 0:
                raise IndexError('TestDataset is empty')
            # Skip classes whose only image became the support example.
            while not self.images[self.cur_class]:
                self.cur_class = (self.cur_class + 1) % self.n
            # Capture the position BEFORE advancing, so the label is the query's own class.
            class_idx, img_idx = self.cur_class, self.cur_img
            self.cur_img += 1
            if self.cur_img >= len(self.images[self.cur_class]):
                self.cur_img = 0
                self.cur_class = (self.cur_class + 1) % self.n
        else:
            class_idx, img_idx = self._locate(idx)

        img1 = self.images[class_idx][img_idx]
        sample = {'img1' : torch.from_numpy(img1.astype('float32')), 'res' : class_idx}
        return sample

In [None]:
# Drop the pair dataset/loader names. The image arrays stay alive through the
# global `images` list, which ImageDataset only referenced.
del dataset, dataloader

In [None]:
# One-shot evaluation set over the held-out classes (note: rebinds `dataset`).
dataset = TestDataset(classes=test_classes)

In [None]:
model.eval()
correct = 0
# Inference only: no autograd graph is needed for the one-shot comparisons.
with torch.no_grad():
    for _ in trange(len(dataset)):
        # __getitem__() with no index reads the next query in class order.
        sample = dataset.__getitem__()
        target = sample['res']
        query = Variable(sample['img1']).unsqueeze(0).cuda()
        y_pred = model.forward_one_shot(query, dataset.support_set)
        if target == y_pred:
            correct += 1
# float() forces true division; max() avoids ZeroDivisionError on an empty set.
print(float(correct) / max(len(dataset), 1))