In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np

In [3]:
class ChunkSampler(sampler.Sampler):
    """Sampler that walks a contiguous run of dataset indices in order.

    Arguments:
        num_samples: how many consecutive indices to yield
        start: first index of the run (defaults to 0)
    """
    def __init__(self, num_samples, start = 0):
        self.start = start
        self.num_samples = num_samples

    def __len__(self):
        # The run always contains exactly `num_samples` indices.
        return self.num_samples

    def __iter__(self):
        # Yields start, start + 1, ..., start + num_samples - 1.
        first = self.start
        stop = first + self.num_samples
        return iter(range(first, stop))

# Split sizes: the first NUM_TRAIN images of the CIFAR-10 training file are
# used for training and the NUM_VAL images right after them for validation.
NUM_TRAIN = 49000
NUM_VAL = 1000

# Training loader: indices [0, NUM_TRAIN) of the training file, visited in order.
cifar10_train = dset.CIFAR10(
    './cs231n/datasets', train=True, download=True, transform=T.ToTensor()
)
loader_train = DataLoader(
    cifar10_train,
    batch_size=64,
    sampler=ChunkSampler(NUM_TRAIN, 0),
)

# Validation loader: indices [NUM_TRAIN, NUM_TRAIN + NUM_VAL) of the same training file.
cifar10_val = dset.CIFAR10(
    './cs231n/datasets', train=True, download=True, transform=T.ToTensor()
)
loader_val = DataLoader(
    cifar10_val,
    batch_size=64,
    sampler=ChunkSampler(NUM_VAL, NUM_TRAIN),
)

# Test loader: the separate CIFAR-10 test file (train=False), default ordering.
cifar10_test = dset.CIFAR10(
    './cs231n/datasets', train=False, download=True, transform=T.ToTensor()
)
loader_test = DataLoader(cifar10_test, batch_size=64)


Files already downloaded and verified


Files already downloaded and verified


Files already downloaded and verified


In [142]:
# Tensor type used for the model and for every batch: the CUDA float type when
# a GPU is usable, otherwise the plain CPU float type so the notebook still runs.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Constant to control how frequently we print train loss
print_every = 100

# This is a little utility that we'll use to reset the model
# if we want to re-initialize all our parameters
def reset(m):
    """Call `m.reset_parameters()` when `m` has such an attribute; otherwise do nothing."""
    if not hasattr(m, 'reset_parameters'):
        return
    m.reset_parameters()

In [194]:
def train(model, loss_fn, optimizer, num_epochs = 1):
    """Train `model` on `loader_train` and checkpoint it when validation improves.

    Arguments:
        model: the module to optimize; it is called on each input batch.
        loss_fn: callable taking (scores, targets) and returning the loss.
        optimizer: optimizer whose zero_grad()/step() are called once per batch.
        num_epochs: number of full passes over `loader_train`.

    After every epoch the model is evaluated on `loader_val`. Whenever the
    accuracy beats the best value seen so far in this call, the whole model is
    saved with torch.save to 'net_<N>.model', where <N> is the accuracy in
    percent formatted with %d.

    Relies on the notebook globals `loader_train`, `loader_val`, `dtype` and
    `print_every`.
    """
    best_val_acc = 0.

    for epoch in range(num_epochs):
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
        # check_accuracy leaves the model in eval mode, so switch back every epoch.
        model.train()

        for t, (x, y) in enumerate(loader_train):
            x_var = Variable(x.type(dtype))
            y_var = Variable(y.type(dtype).long())

            scores = model(x_var)

            loss = loss_fn(scores, y_var)
            if (t + 1) % print_every == 0:
                # Newer PyTorch returns a 0-dim loss that cannot be indexed;
                # use .item() when it exists and fall back to the old accessor.
                loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]
                print('t = %d, loss = %.4f' % (t + 1, loss_value))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Evaluate the model that was passed in, not a same-named notebook global.
        val_acc = check_accuracy(model, loader_val)

        if val_acc > best_val_acc:
            # Remember the new best so later epochs only save on real improvement.
            best_val_acc = val_acc
            torch.save(model, 'net_%d.model' % (val_acc*100))

def check_accuracy(model, loader):
    """Print and return the fraction of samples in `loader` that `model` labels correctly.

    Arguments:
        model: module called on each input batch; it is put into eval mode
            and left there.
        loader: iterable of (inputs, labels) batches whose `dataset.train`
            flag selects the banner that is printed.

    Returns:
        Number of correct predictions divided by number of samples, as a float.
    """
    banner = ('Checking accuracy on validation set' if loader.dataset.train
              else 'Checking accuracy on test set')
    print(banner)

    model.eval() # Put the model in test mode (the opposite of model.train(), essentially)
    hits = 0
    seen = 0
    for images, labels in loader:
        # volatile=True: this input is only used for prediction; no backward pass happens here.
        batch = Variable(images.type(dtype), volatile=True)

        # max over dim 1 returns (values, indices); the indices are the predicted classes.
        predicted = model(batch).data.cpu().max(1)[1]
        hits += (predicted == labels).sum()
        seen += predicted.size(0)

    acc = float(hits) / seen
    print('Got %d / %d correct (%.2f)' % (hits, seen, 100 * acc))

    return acc

In [213]:
class BaseNet(nn.Module):
    """Ten-group convolutional classifier with a two-layer fully connected head.

    The feature extractor stacks ten groups of
    `Conv2d(3x3, padding=1) -> ReLU -> BatchNorm2d`, each producing 64
    channels. A `MaxPool2d(2)` follows the 3rd, 6th and 9th group, and when
    `use_dropout` is set a `Dropout2d(p)` closes every group. The classifier is
    `Linear(flat_feats, 1024) -> ReLU -> Linear(1024, 10)`.

    Arguments:
        use_dropout: whether to append `nn.Dropout2d(p)` to every group.
        p: value passed straight to `nn.Dropout2d`.
            NOTE(review): Dropout2d treats this as the probability of zeroing
            a channel, so the default 0.9 is very aggressive - confirm intent.
        input_shape: (channels, height, width) of a single input sample.
    """
    def __init__(self, use_dropout=False, p=0.9, input_shape=(3, 32, 32)):
        super(BaseNet, self).__init__()

        # Echo the configuration so it shows up in the run log.
        print(use_dropout, p, input_shape)
        assert len(input_shape) == 3, 'Input shape must be provided'

        self.input_channels = input_shape[0]
        self.output_classes = 10
        self.input_shape = input_shape
        num_layers = 10

        layers = []

        for i in range(num_layers):
            # Only the first convolution sees the raw input channels.
            inp_channels = self.input_channels if i == 0 else 64

            layers += [
                nn.Conv2d(inp_channels, 64, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(64)
            ]

            # Halve the spatial resolution after the 3rd, 6th and 9th group.
            if i == 2 or i == 5 or i == 8:
                layers.append(nn.MaxPool2d(2))

            if use_dropout:
                layers.append(nn.Dropout2d(p))

        self.feats_extractor = nn.Sequential(*layers)

        # Size of one flattened feature map, measured with a dummy forward pass.
        self.flat_feats = self.get_flat_feats(
            self.input_shape,
            self.feats_extractor
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.flat_feats, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, self.output_classes)
        )

    @staticmethod
    def get_flat_feats(input_shape, feats_model):
        """Return the number of features `feats_model` emits for one sample.

        A single all-ones sample of shape `input_shape` is pushed through
        `feats_model` and the product of the non-batch output dimensions is
        returned as an int.

        The probe runs in eval mode so that it neither updates BatchNorm
        running statistics nor applies dropout, and so that a batch of one
        sample is accepted even when the feature map shrinks to 1x1. The
        module's previous train/eval mode is restored afterwards.
        """
        was_training = feats_model.training
        feats_model.eval()
        try:
            dummy_input = Variable(torch.ones(1, *input_shape))
            feats = feats_model(dummy_input)
        finally:
            feats_model.train(was_training)

        return int(np.prod(feats.size()[1:]))

    def forward(self, x):
        """Return class scores of shape (batch, output_classes) for input batch `x`."""
        x = self.feats_extractor(x)
        # Keep the batch dimension fixed; a wrongly sized input then fails in
        # the first Linear layer instead of silently changing the batch size.
        x = x.view(x.size(0), -1)

        return self.classifier(x)


In [218]:
# Build a BaseNet with default arguments and cast it, and the loss, to `dtype`.
net = BaseNet().type(dtype)
loss_fn = nn.CrossEntropyLoss().type(dtype)
# Adam over all network parameters with lr 1e-3 and weight decay 1e-4.
optimizer = optim.Adam(net.parameters(), lr=1e-3,
                       weight_decay=0.0001)
# Seed the CUDA random number generator before the parameters are re-initialized.
# NOTE(review): only the CUDA generator is seeded here - confirm that is enough
# for the run to be reproducible.
torch.cuda.random.manual_seed(12345)
# Apply `reset` to every submodule, re-running reset_parameters where it exists.
net.apply(reset)
train(net, loss_fn, optimizer, num_epochs=30)

False 0.9 (3, 32, 32)
Starting epoch 1 / 30


t = 100, loss = 1.5450


t = 200, loss = 1.4660


t = 300, loss = 1.2622


t = 400, loss = 1.0758


t = 500, loss = 0.9555


t = 600, loss = 0.9243


t = 700, loss = 0.9055


Checking accuracy on validation set


Got 700 / 1000 correct (70.00)
Starting epoch 2 / 30


  "type " + obj.__name__ + ". It won't be checked "


t = 100, loss = 0.7096


t = 200, loss = 0.7663


t = 300, loss = 0.9276


t = 400, loss = 0.5987


t = 500, loss = 0.6709


t = 600, loss = 0.6824


t = 700, loss = 0.6262


Checking accuracy on validation set


Got 755 / 1000 correct (75.50)
Starting epoch 3 / 30


t = 100, loss = 0.6147


t = 200, loss = 0.6962


t = 300, loss = 0.6591


t = 400, loss = 0.4807


t = 500, loss = 0.5525


t = 600, loss = 0.4688


t = 700, loss = 0.4778


Checking accuracy on validation set


Got 760 / 1000 correct (76.00)
Starting epoch 4 / 30


t = 100, loss = 0.4528


t = 200, loss = 0.5659


t = 300, loss = 0.5496


t = 400, loss = 0.3107


t = 500, loss = 0.4124


t = 600, loss = 0.4940


t = 700, loss = 0.4587


Checking accuracy on validation set


Got 780 / 1000 correct (78.00)
Starting epoch 5 / 30


t = 100, loss = 0.2662


t = 200, loss = 0.3729


t = 300, loss = 0.3886


t = 400, loss = 0.3926


t = 500, loss = 0.3666


t = 600, loss = 0.2735


t = 700, loss = 0.4188


Checking accuracy on validation set


Got 782 / 1000 correct (78.20)
Starting epoch 6 / 30


t = 100, loss = 0.1376


t = 200, loss = 0.3584


t = 300, loss = 0.3399


t = 400, loss = 0.2288


t = 500, loss = 0.3297


t = 600, loss = 0.1912


t = 700, loss = 0.3947


Checking accuracy on validation set


Got 772 / 1000 correct (77.20)
Starting epoch 7 / 30


t = 100, loss = 0.1320


t = 200, loss = 0.2650


t = 300, loss = 0.2997


t = 400, loss = 0.2215


t = 500, loss = 0.2131


t = 600, loss = 0.1780


t = 700, loss = 0.2099


Checking accuracy on validation set


Got 795 / 1000 correct (79.50)
Starting epoch 8 / 30


t = 100, loss = 0.1154


t = 200, loss = 0.3119


t = 300, loss = 0.2222


t = 400, loss = 0.1736


t = 500, loss = 0.1858


t = 600, loss = 0.1830


t = 700, loss = 0.3353


Checking accuracy on validation set


Got 794 / 1000 correct (79.40)
Starting epoch 9 / 30


t = 100, loss = 0.1445


t = 200, loss = 0.0741


t = 300, loss = 0.1901


t = 400, loss = 0.1597


t = 500, loss = 0.3229


t = 600, loss = 0.2991


t = 700, loss = 0.1218


Checking accuracy on validation set


Got 780 / 1000 correct (78.00)
Starting epoch 10 / 30


t = 100, loss = 0.1179


t = 200, loss = 0.1394


t = 300, loss = 0.0886


t = 400, loss = 0.2738


t = 500, loss = 0.0780


t = 600, loss = 0.0814


t = 700, loss = 0.1931


Checking accuracy on validation set


Got 804 / 1000 correct (80.40)
Starting epoch 11 / 30


t = 100, loss = 0.0707


t = 200, loss = 0.0953


t = 300, loss = 0.1286


t = 400, loss = 0.0650


t = 500, loss = 0.1492


t = 600, loss = 0.0523


t = 700, loss = 0.1430


Checking accuracy on validation set


Got 814 / 1000 correct (81.40)
Starting epoch 12 / 30


t = 100, loss = 0.0540


t = 200, loss = 0.1845


t = 300, loss = 0.1578


t = 400, loss = 0.0657


t = 500, loss = 0.2467


t = 600, loss = 0.0592


t = 700, loss = 0.0807


Checking accuracy on validation set


Got 810 / 1000 correct (81.00)
Starting epoch 13 / 30


t = 100, loss = 0.0624


t = 200, loss = 0.1648


t = 300, loss = 0.0432


t = 400, loss = 0.1736


t = 500, loss = 0.1391


t = 600, loss = 0.1352


t = 700, loss = 0.2243


Checking accuracy on validation set


Got 813 / 1000 correct (81.30)
Starting epoch 14 / 30


t = 100, loss = 0.1608


t = 200, loss = 0.0245


t = 300, loss = 0.1084


t = 400, loss = 0.0853


t = 500, loss = 0.1518


t = 600, loss = 0.0672


t = 700, loss = 0.1287


Checking accuracy on validation set


Got 801 / 1000 correct (80.10)
Starting epoch 15 / 30


t = 100, loss = 0.0976


t = 200, loss = 0.0819


t = 300, loss = 0.1025


t = 400, loss = 0.0567


t = 500, loss = 0.1009


t = 600, loss = 0.0622


t = 700, loss = 0.0710


Checking accuracy on validation set


Got 819 / 1000 correct (81.90)
Starting epoch 16 / 30


t = 100, loss = 0.0771


t = 200, loss = 0.1116


t = 300, loss = 0.0910


t = 400, loss = 0.0334


t = 500, loss = 0.0320


t = 600, loss = 0.0755


t = 700, loss = 0.1626


Checking accuracy on validation set


Got 818 / 1000 correct (81.80)
Starting epoch 17 / 30


t = 100, loss = 0.0378


t = 200, loss = 0.2067


t = 300, loss = 0.0916


t = 400, loss = 0.0670


t = 500, loss = 0.0496


t = 600, loss = 0.1105


t = 700, loss = 0.0557


Checking accuracy on validation set


Got 827 / 1000 correct (82.70)
Starting epoch 18 / 30


t = 100, loss = 0.1699


t = 200, loss = 0.0766


t = 300, loss = 0.0637


t = 400, loss = 0.0759


t = 500, loss = 0.3143


t = 600, loss = 0.1197


t = 700, loss = 0.0764


Checking accuracy on validation set


Got 794 / 1000 correct (79.40)
Starting epoch 19 / 30


t = 100, loss = 0.0294


t = 200, loss = 0.1019


t = 300, loss = 0.0514


t = 400, loss = 0.0953


t = 500, loss = 0.0486


t = 600, loss = 0.0325


t = 700, loss = 0.0873


Checking accuracy on validation set


Got 821 / 1000 correct (82.10)
Starting epoch 20 / 30


t = 100, loss = 0.0680


t = 200, loss = 0.1328


t = 300, loss = 0.0644


t = 400, loss = 0.0347


t = 500, loss = 0.0682


t = 600, loss = 0.0892


t = 700, loss = 0.0964


Checking accuracy on validation set


Got 825 / 1000 correct (82.50)
Starting epoch 21 / 30


t = 100, loss = 0.0585


t = 200, loss = 0.0573


t = 300, loss = 0.0548


t = 400, loss = 0.0478


t = 500, loss = 0.0969


t = 600, loss = 0.0223


t = 700, loss = 0.0138


Checking accuracy on validation set


Got 819 / 1000 correct (81.90)
Starting epoch 22 / 30


t = 100, loss = 0.0645


t = 200, loss = 0.0388


t = 300, loss = 0.1369


t = 400, loss = 0.0295


t = 500, loss = 0.0421


t = 600, loss = 0.1773


t = 700, loss = 0.0516


Checking accuracy on validation set


Got 819 / 1000 correct (81.90)
Starting epoch 23 / 30


t = 100, loss = 0.0321


t = 200, loss = 0.1027


t = 300, loss = 0.0567


t = 400, loss = 0.0597


t = 500, loss = 0.0598


t = 600, loss = 0.0736


t = 700, loss = 0.0534


Checking accuracy on validation set


Got 821 / 1000 correct (82.10)
Starting epoch 24 / 30


t = 100, loss = 0.1436


t = 200, loss = 0.0222


t = 300, loss = 0.1308


t = 400, loss = 0.0703


t = 500, loss = 0.0668


t = 600, loss = 0.1383


t = 700, loss = 0.1502


Checking accuracy on validation set


Got 818 / 1000 correct (81.80)
Starting epoch 25 / 30


t = 100, loss = 0.0590


t = 200, loss = 0.0646


t = 300, loss = 0.0740


t = 400, loss = 0.0835


t = 500, loss = 0.0538


t = 600, loss = 0.0069


t = 700, loss = 0.0336


Checking accuracy on validation set


Got 824 / 1000 correct (82.40)
Starting epoch 26 / 30


t = 100, loss = 0.0590


t = 200, loss = 0.0365


t = 300, loss = 0.0190


t = 400, loss = 0.1570


t = 500, loss = 0.1375


t = 600, loss = 0.0248


t = 700, loss = 0.0177


Checking accuracy on validation set


Got 802 / 1000 correct (80.20)
Starting epoch 27 / 30


t = 100, loss = 0.0365


t = 200, loss = 0.0839


t = 300, loss = 0.0419


t = 400, loss = 0.0924


t = 500, loss = 0.1395


t = 600, loss = 0.0519


t = 700, loss = 0.0720


Checking accuracy on validation set


Got 828 / 1000 correct (82.80)
Starting epoch 28 / 30


t = 100, loss = 0.0203


t = 200, loss = 0.0570


t = 300, loss = 0.0320


t = 400, loss = 0.0374


t = 500, loss = 0.2242


t = 600, loss = 0.1417


t = 700, loss = 0.1096


Checking accuracy on validation set


Got 810 / 1000 correct (81.00)
Starting epoch 29 / 30


t = 100, loss = 0.1468


t = 200, loss = 0.0429


t = 300, loss = 0.1508


t = 400, loss = 0.0728


t = 500, loss = 0.0155


t = 600, loss = 0.0796


t = 700, loss = 0.0338


Checking accuracy on validation set


Got 824 / 1000 correct (82.40)
Starting epoch 30 / 30


t = 100, loss = 0.0286


t = 200, loss = 0.0645


t = 300, loss = 0.0282


t = 400, loss = 0.1708


t = 500, loss = 0.1016


t = 600, loss = 0.0124


t = 700, loss = 0.1038


Checking accuracy on validation set


Got 820 / 1000 correct (82.00)


In [167]:
# Save the state_dict of whatever `net` currently holds to 'best.pth.tar'.
# NOTE(review): nothing in this cell selects the best epoch - confirm that
# `net` holds the weights this file name implies.
torch.save(net.state_dict(), 'best.pth.tar')

In [149]:
# Build a second BaseNet with default arguments, cast to `dtype`.
new_model = BaseNet().type(dtype)
# Read back the object stored in 'best.pth.tar' (presumably a state_dict - verify).
x = torch.load('best.pth.tar')

In [151]:
# Copy the loaded parameters and buffers in `x` into `new_model`.
new_model.load_state_dict(x)

In [152]:
# Measure the reloaded model on the validation loader.
check_accuracy(new_model, loader_val)

Checking accuracy on validation set
Got 720 / 1000 correct (72.00)


In [181]:
# Save the whole `net` object (not just its state_dict) to the file 'net'.
torch.save(net, 'net')

  "type " + obj.__name__ + ". It won't be checked "


In [182]:
# Read back the object stored in the file 'net'.
m = torch.load('net')

In [183]:
# Measure the object loaded from 'net' on the validation loader.
check_accuracy(m, loader_val)

Checking accuracy on validation set
Got 743 / 1000 correct (74.30)


In [197]:
# NOTE(review): looks like a scratch check of how '%f' renders 0.0 -
# presumably unrelated to the experiment; confirm before deleting.
print('ajinkya.%f' % 0.0)

ajinkya.0.000000
