In [1]:
import torch
import torch.nn as nn
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
from skimage import io, transform
from torch.nn import functional as F
from torch.autograd import Variable
from tensorboardX import SummaryWriter

In [2]:
# Dataset root; later cells build paths by appending class folder names to it.
root_dir = 'tiny-imagenet-200/train/'
# No log dir is given, so tensorboardX writes to its default run directory.
writer = SummaryWriter()
batch_size = 32
# The class split, the dataset's random sampling (np.random) and weight init (torch)
# are all stochastic; seed both global RNGs so a Restart & Run All is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
class CNN(nn.Module):
    """Small VGG-style convolutional classifier.

    The trunk is six 3x3 convolutions (128 channels, padding 1, each followed
    by ReLU) with three 2x2 max-pools interleaved. The head is a three-layer
    MLP whose first layer expects a flattened 128 x 8 x 8 feature map. Three
    halvings down to 8 x 8 imply inputs of roughly 64 x 64 pixels.

    Args:
        classes_n: number of output logits.
    """

    def __init__(self, classes_n):
        super(CNN, self).__init__()
        width = 128
        # 'C' = conv3x3 + ReLU, 'M' = 2x2 max-pool. This order fixes the
        # Sequential indices (and therefore the state_dict keys).
        plan = ['C', 'C', 'M', 'C', 'C', 'M', 'C', 'M', 'C']
        trunk = []
        in_channels = 3
        for step in plan:
            if step == 'M':
                trunk.append(nn.MaxPool2d(2))
                continue
            trunk.append(nn.Conv2d(in_channels, width, kernel_size=3, padding=1))
            trunk.append(nn.ReLU())
            in_channels = width
        self.cnn = nn.Sequential(*trunk)

        flat_features = width * 8 * 8
        hidden = 512
        self.fc1 = nn.Sequential(
            nn.Linear(flat_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, classes_n),
        )

    def forward(self, x):
        """Map an image batch to raw (un-normalised) logits of shape (batch, classes_n)."""
        features = self.cnn(x)
        flattened = features.view(features.size(0), -1)
        return self.fc1(flattened)

In [4]:
# Hold out a random 99% of the class folders; the remaining classes are used for training.
# Sorting first means the split depends only on the RNG state, not on the filesystem's
# listing order, and non-directory entries are ignored.
classes = sorted(d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)))
test_classes = np.random.choice(classes, size = int(0.99 * len(classes)), replace = False)
held_out = set(test_classes)  # O(1) membership instead of scanning the array per class
classes = [c for c in classes if c not in held_out]

In [5]:
len(classes)  # number of training classes left after holding out 99% of the class folders

2

In [6]:
class CNNDataset(Dataset):
    """Randomly sampled image-classification dataset over ``<data_dir>/<class>/images/``.

    Every image of every requested class is loaded eagerly. Each one is stored as
    a float64 array scaled by 1/255 (assumes 8-bit source images; TODO confirm)
    and laid out channel-first as (C, H, W). Grayscale images are replicated
    into three channels.

    ``__getitem__`` ignores ``idx``. Each call draws a uniformly random class, then
    a uniformly random image of that class, from NumPy's global RNG. An "epoch" is
    therefore just ``epoch_size`` random draws.

    Args:
        classes: class folder names; the label of ``classes[k]`` is ``k``.
        transform: optional callable applied to the image tensor before it is returned.
        data_dir: dataset root; defaults to the notebook-level ``root_dir``.
        epoch_size: length reported by ``__len__`` (default 120 * 500, as before).

    Raises:
        ValueError: if a class folder contains no images.
    """

    def __init__(self, classes, transform=None, data_dir=None, epoch_size=120 * 500):
        if data_dir is None:
            data_dir = root_dir
        # Previously this stored the skimage ``transform`` module by mistake.
        self.transform = transform
        self.epoch_size = epoch_size
        self.n = len(classes)
        self.images = []
        for cl in classes:
            img_dir = os.path.join(data_dir, cl, 'images')
            # Sorted so the per-class image order is independent of the filesystem.
            loaded = [self._load_image(os.path.join(img_dir, name))
                      for name in sorted(os.listdir(img_dir))]
            if not loaded:
                # Fail here rather than with an obscure randint/IndexError while training.
                raise ValueError('no images found in %s' % img_dir)
            self.images.append(loaded)

    @staticmethod
    def _load_image(path):
        """Read one image, scale it by 1/255 and return it channel-first."""
        img = io.imread(path) / 255
        if img.ndim != 3:
            # 2-D (grayscale) image: replicate into three identical channels.
            img = np.stack([img, img, img])
        else:
            # (H, W, C) -> (C, H, W).
            img = img.transpose((2, 0, 1))
        return img

    def __len__(self):
        return self.epoch_size

    def __getitem__(self, idx):
        # idx is deliberately unused: samples are drawn at random (see class docstring).
        class1 = np.random.randint(0, self.n)
        pool = self.images[class1]
        # Sample from the whole class. The old fixed bound of 120 ignored the remaining
        # images and raised IndexError for classes with fewer than 120 images.
        img1 = pool[np.random.randint(0, len(pool))]
        img_tensor = torch.from_numpy(img1.astype('float'))
        if self.transform is not None:
            img_tensor = self.transform(img_tensor)
        return {'img1' : img_tensor, 'res' : class1}

In [7]:
dataset = CNNDataset(classes)
# No shuffle requested: CNNDataset ignores the index and already samples at random.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
)

In [8]:
model = CNN(len(classes))
# float64 weights to match the float64 image tensors built by CNNDataset (astype('float')).
model.double()
# Use the GPU when one is present; fall back to CPU instead of failing inside .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Weight-decay coefficient handed to the SGD optimizer (0 disables weight decay).
coef = 0

In [None]:
sgd = torch.optim.SGD(model.parameters(), lr = 0.1, momentum = 0.6, weight_decay = coef)
criterion = nn.CrossEntropyLoss()
n_epochs = 10
# Follow whatever device the model lives on instead of hard-coding .cuda().
device = next(model.parameters()).device
model.train()
for epoch in range(n_epochs):
    for step, batch in enumerate(dataloader):
        # Plain tensors carry autograd state; Variable wrappers are obsolete since PyTorch 0.4.
        target = batch['res'].to(device)
        images = batch['img1'].to(device)
        # Call the module rather than .forward() so any registered hooks run.
        y_pred = model(images)
        sgd.zero_grad()
        loss = criterion(y_pred, target)
        loss.backward()
        sgd.step()
        # Log a plain Python float, not the graph-attached loss tensor.
        writer.add_scalar('cross entropy', loss.item(), epoch * len(dataloader) + step)