In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from skimage import io, transform
from torch.nn import functional as F
from torch.autograd import Variable
from tensorboardX import SummaryWriter

In [2]:
root_dir = 'tiny-imagenet-200/train/'
writer = SummaryWriter()
batch_size = 32

In [3]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=10, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, kernel_size=7, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 256, kernel_size=3, padding=2),
            nn.ReLU())

        self.fc1 = nn.Sequential(
            nn.Linear(256 * 8 * 8, 4096),
            nn.Sigmoid())
        
        self.fc2 = nn.Sequential(
            nn.Linear(4096, 1),
            nn.Sigmoid())
        
    def forward_once(self, x):
        output = self.cnn(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        out1 = self.forward_once(input1)
        out2 = self.forward_once(input2)
        out = torch.abs(out1 - out2)
        out = self.fc2(out)
        return out

In [4]:
class ImageDataset(Dataset):
    def __init__(self, classes, transform=None):
        self.transform = transform
        self.images = []
        self.n = len(classes)
        
        for cl in classes:
            self.images.append([])
            for img_name in os.listdir(root_dir + cl + '/images/'):
                img = io.imread(root_dir + cl + '/images/' + img_name) / 255
                if not len(img.shape) == 3:
                    img = np.stack([img, img, img])
                else:
                    img = img.transpose((2, 0, 1))
                self.images[-1].append(img)
    
    def __len__(self):
        return 4096
        
    def __getitem__(self, idx):
        if np.random.uniform() < 0.5:
            class1, class2 = np.random.choice(self.n, size = 2, replace = False)
            res = 1
        else:
            class1 = np.random.choice(self.n, size = 1)[0]
            class2 = class1
            res = 0
        
        img1 = self.images[class1][np.random.randint(0, high = len(self.images[class1]))]
        img2 = self.images[class2][np.random.randint(0, high = len(self.images[class2]))]
    
        sample = {'img1' : torch.from_numpy(img1.astype('float')), 'img2' : torch.from_numpy(img2.astype('float')), 'res' : res}
        if self.transform:
            sample = self.transform(sample)

        return sample

In [5]:
classes = os.listdir(root_dir)
test_classes = np.random.choice(classes, size = int(0.4 * len(classes)),replace = False)
classes = [i for i in classes if i not in test_classes]

In [6]:
dataset = ImageDataset(classes)
dataloader = DataLoader(dataset, batch_size = batch_size)

In [7]:
model = SiameseNetwork()
model.double()
model.cuda()
coef = 0

In [9]:
sgd = torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.5, weight_decay = coef)
for i in range(100):
    for j, batch in enumerate(dataloader):
        target = Variable(batch['res']).cuda()
        y_pred = model.forward(Variable(batch['img1']).cuda(), Variable(batch['img2']).cuda())
        y_pred = torch.cat([1 - y_pred, y_pred], dim = 1).cuda()
        loss = F.cross_entropy(y_pred, target)
        print(y_pred, target, loss)
        writer.add_scalar('train_loss', loss, i * 4096 / batch_size + j)
        sgd.zero_grad()
        loss.backward()
        sgd.step()

Variable containing:
 0.5038  0.4962
 0.5080  0.4920
 0.5038  0.4962
 0.5086  0.4914
 0.5069  0.4931
 0.5088  0.4912
 0.5093  0.4907
 0.5067  0.4933
 0.5086  0.4914
 0.5066  0.4934
 0.5073  0.4927
 0.5083  0.4917
 0.5073  0.4927
 0.5085  0.4915
 0.5073  0.4927
 0.5067  0.4933
 0.5087  0.4913
 0.5085  0.4915
 0.5081  0.4919
 0.5091  0.4909
 0.5049  0.4951
 0.5095  0.4905
 0.5055  0.4945
 0.5025  0.4975
 0.5074  0.4926
 0.5084  0.4916
 0.5088  0.4912
 0.5084  0.4916
 0.5089  0.4911
 0.5069  0.4931
 0.5088  0.4912
 0.5086  0.4914
[torch.cuda.DoubleTensor of size 32x2 (GPU 0)]
 Variable containing:
 1
 1
 0
 1
 1
 1
 0
 0
 0
 0
 0
 0
 0
 1
 0
 1
 1
 1
 1
 0
 1
 0
 1
 0
 1
 0
 1
 0
 0
 0
 0
 0
[torch.cuda.LongTensor of size 32 (GPU 0)]
 Variable containing:
 0.6921
[torch.cuda.DoubleTensor of size 1 (GPU 0)]

Variable containing:
 0.5083  0.4917
 0.5023  0.4977
 0.5080  0.4920
 0.5080  0.4920
 0.5080  0.4920
 0.5093  0.4907
 0.5083  0.4917
 0.5081  0.4919
 0.5064  0.4936
 0.5090  0.4910
 0.

KeyboardInterrupt: 