In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from skimage import io, transform
from torch.nn import functional as F
from torch.autograd import Variable
from tensorboardX import SummaryWriter

In [2]:
# Folder that holds one sub-folder per class. Note the trailing slash: code in
# this notebook that appends class names directly onto this string relies on it.
root_dir = 'tiny-imagenet-200/train_test/'
# Shared tensorboardX writer used by the training loop to log 'train_loss'.
# Created without arguments, so the log location is whatever the library
# chooses by default.
writer = SummaryWriter()

In [3]:
class SiameseNetwork(nn.Module):
    """Siamese similarity network with a shared convolutional encoder.

    Both inputs go through the same encoder (``cnn`` followed by ``fc1``),
    giving a 4096-dimensional embedding each. The element-wise absolute
    difference of the two embeddings is mapped by ``fc2`` (linear + sigmoid)
    to a single score in (0, 1) per pair.

    ``fc1`` expects ``256 * 8 * 8`` flattened features. The convolutions keep
    the spatial size (5x5 kernel, padding 2) and there are three 2x poolings,
    so this corresponds to 3-channel 64x64 inputs.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()

        # Conv(5x5, pad 2) -> ReLU for every consecutive channel pair, with a
        # 2x max-pool after each stage except the last one.
        channels = [3, 128, 128, 256, 256]
        last_stage = len(channels) - 2
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=5, padding=2))
            layers.append(nn.ReLU())
            if stage < last_stage:
                layers.append(nn.MaxPool2d(2))
        self.cnn = nn.Sequential(*layers)

        # Embedding head shared by both branches.
        self.fc1 = nn.Sequential(nn.Linear(256 * 8 * 8, 4096), nn.ReLU())

        # Scores the |embedding1 - embedding2| vector.
        self.fc2 = nn.Sequential(nn.Linear(4096, 1), nn.Sigmoid())

    def forward_once(self, x):
        """Encode one batch of images into embeddings of shape (batch, 4096)."""
        features = self.cnn(x)
        # Flatten everything except the batch dimension.
        flat = features.view(features.size(0), -1)
        return self.fc1(flat)

    def forward(self, input1, input2):
        """Return the sigmoid score, shape (batch, 1), for each input pair."""
        emb1 = self.forward_once(input1)
        emb2 = self.forward_once(input2)
        distance = torch.abs(emb1 - emb2)
        return self.fc2(distance)

In [4]:
class ImageDataset(Dataset):
    """Random image pairs for training a Siamese network.

    Every item is a dict with three keys:

    * ``'img1'`` / ``'img2'`` -- float64 tensors in channel-first layout, each
      read from a randomly chosen file in ``<data root>/<class>/images/``.
      Images that do not have three dimensions (e.g. grayscale) are
      replicated into 3 channels.
    * ``'res'`` -- ``0`` when both images come from the same class, ``1`` when
      they come from two different classes.

    The ``idx`` given to ``__getitem__`` is ignored; sampling uses NumPy's
    global random state. A drawn (class1, class2, res) triple is reused for 4
    consecutive items before a new one is drawn, so the dataset is stateful.

    Args:
        classes: sequence of class folder names to sample from.
        transform: optional callable applied to the finished sample dict.
        data_root: folder that contains the class folders. ``None`` (the
            default) keeps the earlier behaviour of reading the notebook-level
            ``root_dir`` variable each time an image is loaded.
    """

    def __init__(self, classes, transform=None, data_root=None):
        # Historical attribute name: this is the list of class folder names,
        # not a directory path.
        self.root_dir = classes
        self.transform = transform
        self.data_root = data_root
        # Any value >= 4 forces a fresh class draw on the first __getitem__.
        self.iter = 32
        self.class1 = ''
        self.class2 = ''
        self.res = 0

    def __len__(self):
        # Nominal epoch length; items are sampled randomly, not enumerated.
        return 4096

    def _draw_classes(self):
        """Pick the class pair and label used for the next 4 items."""
        if np.random.uniform() < 0.5:
            # Two distinct classes -> label 1 ("different").
            self.class1, self.class2 = np.random.choice(self.root_dir, size = 2, replace = False)
            self.res = 1
        else:
            # Same class on both sides -> label 0 ("same").
            self.class1 = np.random.choice(self.root_dir, size = 1)[0]
            self.class2 = self.class1
            self.res = 0
        self.iter = 0

    def _images_dir(self, class_name):
        """Return the image folder of ``class_name``."""
        # Fall back to the notebook global only when no root was given.
        base = self.data_root if self.data_root is not None else root_dir
        # os.path.join works with or without a trailing slash on ``base``.
        return os.path.join(base, class_name, 'images')

    def _load_random_image(self, class_name):
        """Read a random image of ``class_name`` as a channel-first float64 tensor."""
        images_dir = self._images_dir(class_name)
        img_name = np.random.choice(os.listdir(images_dir), size = 1)[0]
        img = io.imread(os.path.join(images_dir, img_name))
        if len(img.shape) == 3:
            # (H, W, C) -> (C, H, W)
            img = img.transpose((2, 0, 1))
        else:
            # Single-channel image: replicate into 3 channels -> (3, H, W).
            img = np.stack([img, img, img])
        return torch.from_numpy(img.astype('float'))

    def __getitem__(self, idx):
        if self.iter >= 4:
            self._draw_classes()

        sample = {
            'img1' : self._load_random_image(self.class1),
            'img2' : self._load_random_image(self.class2),
            'res' : self.res,
        }
        self.iter += 1
        if self.transform:
            sample = self.transform(sample)

        return sample

In [5]:
# Hold out 20% of the class folders as unseen test classes; the rest are used
# for training. The listing is restricted to directories (stray files are not
# classes) and sorted, and the draw uses its own seeded generator, so the same
# split is produced after every kernel restart.
SPLIT_SEED = 0
all_classes = sorted(
    name for name in os.listdir(root_dir)
    if os.path.isdir(os.path.join(root_dir, name))
)
assert all_classes, 'no class folders found under ' + root_dir
split_rng = np.random.RandomState(SPLIT_SEED)
test_classes = split_rng.choice(all_classes, size = int(0.2 * len(all_classes)), replace = False)
# Set of plain Python strings for O(1) membership tests.
held_out = set(test_classes.tolist())
classes = [name for name in all_classes if name not in held_out]

In [6]:
# Pairs are sampled only from the training classes; no sample transform.
dataset = ImageDataset(classes, transform = None)
# The dataset already samples randomly on every access, so the loader walks
# it in order (no shuffling) in batches of 32.
dataloader = DataLoader(dataset, batch_size = 32, shuffle = False)

In [7]:
# Weight-decay coefficient handed to the optimizer; 0 disables it.
coef = 0
# float64 parameters to match the float64 image tensors produced by
# ImageDataset, then move the whole network to the GPU.
model = SiameseNetwork().double().cuda()

In [8]:
# NOTE: the variable is called `sgd`, but the optimizer is Adam (default
# learning rate) with `coef` as weight decay. The name is kept so other cells
# that refer to it keep working.
sgd = torch.optim.Adam(model.parameters(), weight_decay = coef)
for i in range(1):
    for j, batch in enumerate(dataloader):
        # Dataset labels: 0 = same class, 1 = different classes.
        target = Variable(batch['res']).cuda()
        # Call the module itself rather than .forward() so registered hooks
        # run. The network ends in a Sigmoid, so y_pred is a probability of
        # shape (batch, 1).
        y_pred = model(Variable(batch['img1']).cuda(), Variable(batch['img2']).cuda())
        # F.cross_entropy expects raw logits and applies softmax itself, so it
        # must not be fed sigmoid probabilities. Binary cross-entropy is the
        # matching loss. y_pred is trained as P(same class), which is the
        # class-0 column of [p, 1 - p], hence the BCE target is 1 - res.
        same_target = (1 - target).type_as(y_pred).view_as(y_pred)
        loss = F.binary_cross_entropy(y_pred, same_target)
        # Global step = batches seen so far; len(dataloader) is the number of
        # batches per epoch, so steps stay unique if more epochs are run.
        writer.add_scalar('train_loss', loss, i * len(dataloader) + j)
        sgd.zero_grad()
        loss.backward()
        sgd.step()

KeyboardInterrupt: 