In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils, models
from torch.nn import functional as F
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
from PIL import Image
import torch.utils.model_zoo as model_zoo

In [2]:
root_dir = 'tiny-imagenet-200/train/'
writer = SummaryWriter()
batch_size = 128
np.random.seed(42)

In [3]:
class SiameseNetwork(nn.Module):
    def __init__(self, classes_n):
        super(SiameseNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2))

        self.fc1 = nn.Sequential(
            nn.Linear(64, 16),
            nn.Sigmoid())
        
        self.fc2 = nn.Sequential(
            nn.Linear(257, 128),
            nn.Sigmoid(),
            nn.Linear(128, 1))
        
        self.fc3 = nn.Sequential(
            nn.Linear(1024, classes_n))
        self.dist = nn.CosineSimilarity()
        
    def forward_once(self, x):
        output = self.cnn(x)
        output = output.view(output.size()[0], -1)
        return output
    
    def forward(self, x):
        out = self.forward_once(x)
        out = self.fc2(out)
        return out
    
    def forward_siamese(self, input1, input2):
        out1 = self.forward_once(input1)
        out2 = self.forward_once(input2)
        out = torch.cat([torch.abs(out1 - out2), (out1 - out2) ** 2, self.dist(out1, out2).view(-1, 1)], dim = 1)
        out = self.fc2(out)
        return out
    
    def forward_one_shot(self, x, support_set):
        preds = []
        for img in support_set:
            preds.append(self.forward_siamese(x, img).data.cpu().numpy()[0])
        res = np.argmin(np.array(preds))
        return res

In [4]:
class ImageDataset(Dataset):
    def __init__(self, images, n_classes):
        self.images = images
        self.n = n_classes
        self.cur_class = 0
        self.cur_img = 0
        self.size = 0
        for i in images:
            self.size += len(i)
        
    def __len__(self):
        return self.size
        
    def __getitem__(self, idx):
        class1 = self.cur_class
        if np.random.uniform() < 0.5:
            class2 = np.random.randint(0, self.n)
            res = 1
        else:
            class2 = class1
            res = 0
        
        img1 = self.images[class1][self.cur_img]
        img2 = self.images[class2][np.random.randint(0, high = len(self.images[class2]))]
        self.cur_img += 1
        if self.cur_img >= len(self.images[self.cur_class]):
            self.cur_img = 0
            self.cur_class = (self.cur_class + 1) % self.n
        
        sample = {'img1' : torch.from_numpy(img1), 'img2' : torch.from_numpy(img2), 'res' : res}
        return sample

In [13]:
class TestDataset(Dataset):
    def __init__(self, classes):
        self.images = []
        self.support_set = []
        self.n = len(classes)
        self.cur_class = 0
        self.cur_img = 0
        self.size = 0
        for cl in classes:
            self.images.append([])
            sup = np.random.choice(os.listdir(root_dir + cl + '/images/'))
            img = Image.open(root_dir + cl + '/images/' + sup)
            img = img.resize((28, 28))
            img = np.array(img, dtype = 'float32') / 255
            if not len(img.shape) == 3:
                img = np.stack([img, img, img])
            else:
                img = img.transpose((2, 0, 1))
            self.support_set.append(torch.Tensor(img))
            for img_name in os.listdir(root_dir + cl + '/images/'):
                if not sup == img_name:
                    img = Image.open(root_dir + cl + '/images/' + img_name)
                    img = img.resize((28, 28))
                    img = np.array(img, dtype = 'float32') / 255
                    if not len(img.shape) == 3:
                        img = np.stack([img, img, img])
                    else:
                        img = img.transpose((2, 0, 1))
                    self.images[-1].append(img)
                    self.size += 1
        self.support_set = torch.stack(self.support_set)
        
    def __len__(self):
        return self.size
        
    def __getitem__(self):
        img1 = self.images[self.cur_class][self.cur_img]
        ind = np.random.randint(0, self.n, size = 5)
        if not self.cur_class in ind:
            res = np.random.randint(0, 5)
            ind[res] = self.cur_class
        else:
            res = np.where(ind == self.cur_class)[0][0]
        sup = torch.stack([self.support_set[a] for a in ind])
        self.cur_img += 1
        if self.cur_img >= len(self.images[self.cur_class]):
            self.cur_img = 0
            self.cur_class = (self.cur_class + 1) % self.n
        sample = {'img1' : torch.from_numpy(img1.astype('float32')), 'res' : res, 'support_set': sup}
        return sample

In [6]:
class CNNDataset(Dataset):
    def __init__(self, images,  n_classes):
        self.images = images
        self.n = n_classes
        self.size = 0
        for i in images:
            self.size += len(i)
        
    def __len__(self):
        return self.size
        
    def __getitem__(self, idx):
        class1 = np.random.randint(0, self.n)
        img1 = self.images[class1][np.random.randint(0, len(self.images[class1]))]
        sample = {'img1' : torch.from_numpy(img1), 'res' : class1}
        return sample

In [7]:
classes = os.listdir(root_dir)
test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
classes = [i for i in classes if i not in test_classes]

In [8]:
n = len(classes)
test_n = len(test_classes)

In [9]:
images = []
test = []

In [10]:
for cl in tqdm(classes):
    images.append([])
    test.append([])
    for img_name in os.listdir(root_dir + cl + '/images/'):
        img = np.array(Image.open(root_dir + cl + '/images/' + img_name).resize((28, 28)), dtype = 'float32') / 255
        if not len(img.shape) == 3:
            img = np.stack([img, img, img])
        else:
            img = img.transpose((2, 0, 1))
        if np.random.uniform() < 0.1:
            test[-1].append(img)
        else:
            images[-1].append(img)

100%|██████████| 180/180 [00:19<00:00,  9.30it/s]


In [14]:
dataset = CNNDataset(images, n)
dataloader = DataLoader(dataset, batch_size = batch_size)
test_dataset = CNNDataset(test, n)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size)

In [15]:
model = SiameseNetwork(n)
model.cuda()
coef = 0.0

In [13]:
sgd = torch.optim.Adam(model.parameters(), weight_decay = coef)
criterion = nn.CrossEntropyLoss()
for i in range(30):
    model.train()
    for j, batch in enumerate(dataloader):
        target = Variable(batch['res']).cuda()
        y_pred = model.forward(Variable(batch['img1']).cuda())       
        loss = criterion(y_pred, target)
        mav = torch.mean(torch.abs(y_pred))
        writer.add_scalar('cross entropy', loss, i * len(dataloader) + j)
        writer.add_scalar('absolute value', mav, i * len(dataloader) + j)
        sgd.zero_grad()
        loss.backward()
        sgd.step()
    model.eval()
    for j, batch in enumerate(test_dataloader):
        target = Variable(batch['res']).cuda()
        y_pred = model.forward(Variable(batch['img1']).cuda())       
        loss = criterion(y_pred, target)
        writer.add_scalar('test cross entropy', loss, i * len(test_dataloader) + j)

In [16]:
dataset = ImageDataset(images, n)
dataloader = DataLoader(dataset, batch_size = batch_size)
test_dataset = ImageDataset(test, n)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size)
one_shot_dataset = TestDataset(test_classes)

In [None]:
for param in model.cnn.parameters():
    param.requires_grad = True

In [17]:
sgd = torch.optim.Adam(model.parameters(), weight_decay = coef)
dist = nn.CosineSimilarity()
criterion = nn.BCEWithLogitsLoss()
for i in trange(200):
    model.train()
    for j, batch in enumerate(dataloader):
        target = Variable(batch['res']).float().cuda()
        y_pred = model.forward_siamese(Variable(batch['img1']).cuda(), Variable(batch['img2']).cuda())
        y_pred = y_pred.view(-1)
        loss = criterion(y_pred, target)
        mav = torch.mean(torch.abs(y_pred))
        writer.add_scalar('siamese cross entropy', loss, i * len(dataloader) + j)
        writer.add_scalar('siamese absolute value', mav, i * len(dataloader) + j)
        sgd.zero_grad()
        loss.backward()
        sgd.step()
    model.eval()
    for j, batch in enumerate(test_dataloader):
        target = Variable(batch['res']).float().cuda()
        y_pred = model.forward_siamese(Variable(batch['img1']).cuda(), Variable(batch['img2']).cuda())
        y_pred = y_pred.view(-1)
        loss = criterion(y_pred, target)
        writer.add_scalar('test siamese cross entropy', loss, i * len(test_dataloader) + j)
    correct = 0
    for j in range(len(one_shot_dataset)):
        sample = one_shot_dataset.__getitem__()
        target = sample['res']
        y_pred = model.forward_siamese(Variable(sample['img1']).unsqueeze(0).cuda(), Variable(sample['support_set']).cuda())
        pred = np.argmin(y_pred.data.cpu().numpy())
        if target == pred:
            correct += 1
    writer.add_scalar('test accuracy', correct / len(one_shot_dataset), i)

100%|██████████| 200/200 [1:33:50<00:00, 28.15s/it]


In [18]:
torch.save(model, 'SiameseNet')

  "type " + obj.__name__ + ". It won't be checked "


In [19]:
one_shot_dataset = TestDataset(test_classes)

In [16]:
model = models.Inception3().cuda()
model.load_state_dict(torch.load("/home/david/.torch/models/inception_v3_google-1a9a5a14.pth"))

In [22]:
model.eval()
correct = 0
dist = nn.CosineSimilarity()
for i in trange(len(one_shot_dataset)):
    sample = one_shot_dataset.__getitem__()
    target = sample['res']
    y_pred = model.forward_siamese(Variable(sample['img1']).unsqueeze(0).cuda(), Variable(sample['support_set']).cuda())
    pred = np.argmin(y_pred.data.cpu().numpy())
    if target == pred:
        correct += 1
print(correct / len(one_shot_dataset))

100%|██████████| 9980/9980 [00:17<00:00, 567.30it/s]

0.3629258517034068



