In [65]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

# from models import *
# from utils import progress_bar

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # self.fc1   = nn.Linear(16*5*5, 120)
        self.fc1   = nn.Linear(16* 21 * 21, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        # print(out.shape)
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        # print(out.shape)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out


cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

In [60]:
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch


In [61]:

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # transforms.Resize((32,32)),
    # transforms.Resize((96,96)),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Resize((32,32)),
    # transforms.Resize((96,96)),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.STL10(
    root='./data', split = 'train', download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size= 64 , shuffle=True)

testset = torchvision.datasets.STL10(
    root='./data', split = 'test', download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size= 256 , shuffle=False)

# classes = ('plane', 'car', 'bird', 'cat', 'deer',
#            'dog', 'frog', 'horse', 'ship', 'truck')

classes = ('airplane', 'bird', 'car', 'cat','deer', 
           'dog', 'horse', 'monkey', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [66]:
# Model
print('==> Building model..')
device = 'cuda'
net = LeNet()
# net = VGG('VGG16')
net = net.to(device)
net = torch.nn.DataParallel(net)
# cudnn.benchmark = True
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr= 0.01,
                      momentum=0.9, weight_decay=5e-4)
print("check parameters: \n", net.parameters)


==> Building model..
check parameters: 
 <bound method Module.parameters of DataParallel(
  (module): LeNet(
    (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (fc1): Linear(in_features=7056, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
  )
)>


In [63]:

def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

In [67]:
for epoch in range(100):
    train(epoch)
    test(epoch)

print("best acc: ", best_acc)


Epoch: 0
0 79 Loss: 2.323 | Acc: 7.812% (5/64)
1 79 Loss: 2.317 | Acc: 7.812% (10/128)
2 79 Loss: 2.313 | Acc: 7.812% (15/192)
3 79 Loss: 2.317 | Acc: 7.031% (18/256)
4 79 Loss: 2.315 | Acc: 8.438% (27/320)
5 79 Loss: 2.314 | Acc: 8.073% (31/384)
6 79 Loss: 2.316 | Acc: 7.812% (35/448)
7 79 Loss: 2.315 | Acc: 8.398% (43/512)
8 79 Loss: 2.312 | Acc: 9.028% (52/576)
9 79 Loss: 2.308 | Acc: 9.688% (62/640)
10 79 Loss: 2.307 | Acc: 9.233% (65/704)
11 79 Loss: 2.307 | Acc: 9.245% (71/768)
12 79 Loss: 2.307 | Acc: 9.255% (77/832)
13 79 Loss: 2.306 | Acc: 9.710% (87/896)
14 79 Loss: 2.306 | Acc: 9.271% (89/960)
15 79 Loss: 2.305 | Acc: 9.277% (95/1024)
16 79 Loss: 2.305 | Acc: 9.559% (104/1088)
17 79 Loss: 2.306 | Acc: 9.375% (108/1152)
18 79 Loss: 2.304 | Acc: 10.033% (122/1216)
19 79 Loss: 2.303 | Acc: 10.391% (133/1280)
20 79 Loss: 2.302 | Acc: 10.789% (145/1344)
21 79 Loss: 2.302 | Acc: 10.724% (151/1408)
22 79 Loss: 2.302 | Acc: 11.073% (163/1472)
23 79 Loss: 2.302 | Acc: 11.003% (169/1