In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils, models
from torch.nn import functional as F
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from tqdm import trange, tqdm
from PIL import Image
import torch.utils.model_zoo as model_zoo

In [2]:
root_dir = 'tiny-imagenet-200/train/'
writer = SummaryWriter()
batch_size = 128
np.random.seed(42)
num_classes = 5
omniglot_dir = 'omniglot/'
rtsd_dir = 'rtsd-r1/'

In [3]:
if True:
    classes = os.listdir(root_dir)
    test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
    classes = [i for i in classes if i not in test_classes]
    def fin_dir(cl):
        return root_dir + cl + '/images/'
elif False:
    classes = []
    for i in os.listdir(omniglot_dir):
        for j in os.listdir(omniglot_dir + i):
            classes.append(i + '/' + j)
    test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
    classes = [i for i in classes if i not in test_classes]
    def fin_dir(cl):
        return omniglot_dir + cl + '/'
else:
    classes = os.listdir(rtsd_dir)
    test_classes = np.random.choice(classes, size = int(0.1 * len(classes)), replace = False)
    classes = [i for i in classes if i not in test_classes]
    def fin_dir(cl):
        return rtsd_dir + cl + '/'

In [4]:
class MatchingNetwork(nn.Module):
    def __init__(self, fce = True):
        super(MatchingNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fce = fce
        self.size = 128
        self.lstm = nn.LSTMCell(input_size = self.size, hidden_size = 2 * self.size)
        self.sup_lstm = nn.LSTMCell(input_size = self.size, hidden_size = self.size)
        self.reversed_sup_lstm = nn.LSTMCell(input_size = self.size, hidden_size = self.size)
        self.dist = nn.CosineSimilarity(dim = 2)
        self.softmax = nn.Softmax(dim = 1)
        self.fc = nn.Sequential(
            nn.Linear(2 * self.size + 1, self.size),
            nn.Sigmoid(),
            nn.Linear(self.size, 1))

    def forward(self, support_set, inp):
        support_set = support_set.view(-1, inp.shape[1], inp.shape[2], inp.shape[3])
        sup = self.cnn(support_set)
        sup = sup.view(-1, num_classes, self.size)
        new_sup = Variable(torch.Tensor(torch.zeros(sup.shape))).cuda()
        inp = self.cnn(inp).view(-1, self.size)
        h_0 = inp
        if self.fce:
            for i in range(num_classes):
                sup_h_0 = Variable(torch.zeros(sup.shape[0], self.size), requires_grad = False).cuda()
                sup_c_0 = Variable(torch.zeros(sup.shape[0], self.size), requires_grad = False).cuda()
                sup_rev_h_0 = Variable(torch.zeros(sup.shape[0], self.size), requires_grad = False).cuda()
                sup_rev_c_0 = Variable(torch.zeros(sup.shape[0], self.size), requires_grad = False).cuda()
                for j in range(i):
                    sup_h_0, sup_c_0 = self.sup_lstm(sup[:, j], (sup_h_0, sup_c_0))
                for j in range(num_classes - i - 1):
                    sup_rev_h_0, sup_rev_c_0 = self.sup_lstm(sup[:, num_classes - j - 1], (sup_rev_h_0, sup_rev_c_0))
                new_sup[:, i] = sup[:, i] + sup_h_0 + sup_rev_h_0
            sup = new_sup
            h_0 = Variable(torch.zeros(sup.shape[0], self.size), requires_grad = False).cuda()
            c_0 = Variable(torch.zeros(sup.shape[0], 2 * self.size), requires_grad = False).cuda()
            for i in range(20):
                r_0 = self.softmax(torch.sum(sup * h_0.unsqueeze(1), dim = 2))
                r_0 = torch.sum(r_0.unsqueeze(2) * sup, dim = 1)
                h_0, c_0 = self.lstm(inp, (torch.cat([h_0, r_0], dim = 1), c_0))
                h_0, _ = torch.split(h_0, self.size, dim = 1)
                h_0 = h_0 + inp
        diff = (sup - h_0.unsqueeze(1)).view(-1, self.size)
        out = self.fc(torch.cat([torch.abs(diff), (diff) ** 2, self.dist(sup, h_0.unsqueeze(1)).view(-1, 1)], dim = 1))
        out = out.view(-1, num_classes)
        return out

In [5]:
images = []
for cl in tqdm(classes):
    images.append([])
    for img_name in os.listdir(fin_dir(cl)):
        img = np.array(Image.open(fin_dir(cl) + img_name).resize((28, 28)), dtype = 'float32') / 255
        if not len(img.shape) == 3:
            img = np.stack([img, img, img])
        else:
            img = img.transpose((2, 0, 1))
        images[-1].append(img)

100%|██████████| 180/180 [00:34<00:00,  5.16it/s]


In [6]:
one_shot_images = []
for cl in tqdm(test_classes):
    one_shot_images.append([])
    for img_name in os.listdir(fin_dir(cl)):
        img = np.array(Image.open(fin_dir(cl) + img_name).resize((28, 28)), dtype = 'float32') / 255
        if not len(img.shape) == 3:
            img = np.stack([img, img, img])
        else:
            img = img.transpose((2, 0, 1))
        one_shot_images[-1].append(img)

100%|██████████| 20/20 [00:03<00:00,  5.20it/s]


In [7]:
class ImageDataset(Dataset):
    def __init__(self, images):
        self.images = images
        self.n = len(images)
        self.cur_class = 0
        self.cur_img = 0
        self.size = 0
        for i in images:
            self.size += len(i)
        
    def __len__(self):
        return self.size
        
    def __getitem__(self, idx):
        img = self.images[self.cur_class][self.cur_img]
        ind = np.random.randint(0, self.n, size = num_classes)
        if not self.cur_class in ind:
            res = np.random.randint(0, num_classes)
            ind[res] = self.cur_class
        else:
            res = np.where(ind == self.cur_class)[0][0]
        sup = np.stack([self.images[a][np.random.randint(0, len(self.images[a]))] for a in ind])
        self.cur_img += 1
        if self.cur_img >= len(self.images[self.cur_class]):
            self.cur_img = 0
            self.cur_class = (self.cur_class + 1) % self.n
        sample = {'img' : torch.from_numpy(img), 'support_set' : torch.from_numpy(sup), 'res' : int(res)}
        return sample

In [8]:
dataset = ImageDataset(images)
dataloader = DataLoader(dataset, batch_size = batch_size)
one_shot_dataset = ImageDataset(one_shot_images)
one_shot_dataloader = DataLoader(one_shot_dataset, batch_size = batch_size)

In [9]:
model = MatchingNetwork(False)
model.cuda()
coef = 0.0

In [None]:
sgd = torch.optim.Adam(model.parameters(), weight_decay = coef)
criterion = nn.CrossEntropyLoss()
t = trange(800)
for i in t:
    model.train()
    sum_loss = 0
    acc = 0
    for j, batch in enumerate(dataloader):
        target = Variable(batch['res']).cuda()
        img = batch['img']
        support_set = batch['support_set']
        y_pred = model.forward(Variable(support_set).cuda(), Variable(img).cuda())
        loss = criterion(y_pred, target)
        _, x = torch.max(y_pred, dim = 1)
        sum_loss += loss
        acc += torch.sum(x == target).float() / target.shape[0]
        t.set_description('accuracy = %g, loss = %g' % (acc / (j + 1), loss))
        sgd.zero_grad()
        loss.backward()
        sgd.step()
    writer.add_scalar('cross entropy', sum_loss / (j + 1), i)
    writer.add_scalar('train accuracy', acc / (j + 1), i)
    model.eval()
    one_shot_acc = 0
    for j, batch in enumerate(one_shot_dataloader):
        target = Variable(batch['res']).cuda()
        img = batch['img']
        support_set = batch['support_set']
        y_pred = model.forward(Variable(support_set).cuda(), Variable(img).cuda())
        _, x = torch.max(y_pred, dim = 1)
        one_shot_acc += torch.sum(x == target).float() / target.shape[0]
    writer.add_scalar('one shot accuracy', one_shot_acc / (j + 1), i)

accuracy = 0.844424, loss = 0.441546:  28%|██▊       | 227/800 [2:01:29<5:06:41, 32.11s/it]

In [15]:
torch.save(model, 'MatchingNetwork2')

  "type " + obj.__name__ + ". It won't be checked "
