In [13]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

def imshow(img):
    """Display a normalized image tensor with matplotlib.

    Args:
        img: image tensor laid out channels-first (it is transposed with
            ``(1, 2, 0)`` before plotting). Values are expected to have been
            normalized with mean 0.5 and std 0.5 per channel, as done by the
            ``transform`` used for the datasets in this notebook.
    """
    img = img / 2 + 0.5     # unnormalize: inverse of (x - 0.5) / 0.5
    # detach() + cpu() so tensors that live on the GPU or track gradients can
    # be converted; a bare .numpy() raises for both.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [14]:
# Preprocessing shared by both splits: convert to a tensor, then normalize
# every channel with mean 0.5 and std 0.5.
# (A ToTensor-only pipeline was tried previously and is no longer used.)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 train / test splits, downloaded into ./data when missing.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

# Human-readable names for the ten class indices.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
# Hold out 5,000 training images as a validation set (45,000 remain for training).
# The generator is seeded so the split is the same on every run.
# NOTE: this rebinds `trainset`; re-running this cell without first re-running
# the dataset cell raises, because the lengths no longer sum to len(trainset).
trainset, valset = torch.utils.data.random_split(
    trainset, [45000, 5000], generator=torch.Generator().manual_seed(0))

In [16]:
class Net(nn.Module):
    """Residual network for 3-channel images with 6*n + 2 weighted layers.

    Layout: one 3x3 conv (16 channels), then three stages of ``n`` residual
    blocks with 16, 32 and 64 channels, global average pooling and a 10-way
    linear classifier. The first block of stages 2 and 3 halves the spatial
    resolution with a stride-2 conv; its shortcut is parameter-free
    (spatial subsampling + zero-padded channels).

    Args:
        n: number of residual blocks per stage.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n

        self.conv = nn.Sequential(nn.Conv2d(3, 16, (3, 3), padding=1),
                                  nn.BatchNorm2d(16))

        # Attribute names and the layer order inside each block are kept as-is
        # so existing state_dict checkpoints still load.
        self.stage1 = self._make_stage(16, 16, n)
        self.stage2 = self._make_stage(16, 32, n)
        self.stage3 = self._make_stage(32, 64, n)

        # Adaptive pooling always reduces to 1x1, so inputs other than 32x32
        # work too (a fixed 8x8 window only matched 32x32 inputs).
        self.GAP = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, 10)

    @staticmethod
    def _make_stage(in_channels, out_channels, n):
        """Build ``n`` conv-BN-ReLU-conv-BN blocks.

        When the channel count changes, the first block's first conv uses
        stride 2 and maps ``in_channels`` to ``out_channels``.
        """
        stage = nn.ModuleList()
        for idx in range(n):
            first = idx == 0
            stride = 2 if first and in_channels != out_channels else 1
            stage.append(nn.Sequential(
                nn.Conv2d(in_channels if first else out_channels, out_channels,
                          (3, 3), stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, (3, 3), padding=1),
                nn.BatchNorm2d(out_channels),
            ))
        return stage

    @staticmethod
    def _shortcut(x):
        """Parameter-free shortcut for a down-sampling block.

        Keeps every second row/column (the same pixels nearest-neighbour
        interpolation to half size selects on even inputs, and the same output
        size a stride-2, padding-1, 3x3 conv produces), then doubles the
        channels by appending zeros.
        """
        x = x[:, :, ::2, ::2]
        # zeros_like allocates directly on x's device with x's dtype.
        return torch.cat((x, torch.zeros_like(x)), 1)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a batch of images."""
        x = F.relu(self.conv(x))
        ###stage_1
        for block in self.stage1:
            x = F.relu(block(x) + x)
        ###stage_2 and stage_3: the first block of each stage down-samples
        for stage in (self.stage2, self.stage3):
            for idx, block in enumerate(stage):
                identity = self._shortcut(x) if idx == 0 else x
                x = F.relu(block(x) + identity)
        x = self.GAP(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

#weight Initialization
def init_weights(m):
    """Initialise one module in place; intended for ``net.apply(init_weights)``.

    ``nn.Linear`` and ``nn.Conv2d`` weights are drawn with Kaiming-normal
    initialisation. Linear biases are set to 0.01 and conv biases to 0.
    Layers created with ``bias=False`` are handled (their bias is skipped),
    and every other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)
    elif isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

In [17]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [6]:
#################################

In [18]:
# ResNet 32: N = 5 residual blocks per stage, i.e. 6 * N + 2 = 32 layers
# (the same expression is used in the checkpoint file names).
N = 5

# Build the model, re-initialise its weights, then move it to `device`.
# Assigning the result of .to() keeps the cell from echoing the module repr.
net = Net(N)
net.apply(init_weights)
net = net.to(device)

In [19]:
# Cross-entropy loss on the logits returned by the network.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and L2 weight decay over all model parameters.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0001,
)

# step_size=1 means each explicit scheduler.step() call scales the learning
# rate by gamma; the training loop calls it at fixed iteration counts.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

In [21]:
# Mini-batch loaders: training batches are reshuffled on every pass,
# test batches keep the dataset order.
batch_size = 128
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=4)
print(len(trainloader))  # number of mini-batches per pass over the training set

# Iteration bookkeeping read by the training loop.
start_iter = 0
end_iter = 64000
it = 0

352


In [22]:
# Train `net` until `it` reaches `end_iter` mini-batch updates.
# Logs the mean loss every 100 iterations and the accuracy on the training and
# test loaders every 1000 iterations (stdout + TensorBoard).
writer = SummaryWriter()
running_loss = 0.0
net.train()
try:
    while it < end_iter:
        for data in trainloader:
            it += 1
            inputs, labels = data[0].to(device), data[1].to(device)

            # random flip, constant padding, random crop
            # NOTE(review): the flip decision and crop offset are drawn once per
            # mini-batch, so every image in a batch gets the same augmentation.
            if np.random.randint(0, 2) == 0:
                inputs = torch.flip(inputs, [3])
            inputs = F.pad(inputs, (4, 4, 4, 4))
            # randint's upper bound is exclusive: the +1 keeps the largest valid
            # offset (padded size - 32) reachable.
            top = np.random.randint(0, inputs.size(2) - 32 + 1)
            left = np.random.randint(0, inputs.size(3) - 32 + 1)
            inputs = inputs[:, :, top:top + 32, left:left + 32]

            # training
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Decay the learning rate and checkpoint at the two milestones.
            if it == 32000 or it == 48000:
                scheduler.step()
                PATH = './cifar_10_resnet'+str(N*6+2)+'_iter'+str(it)+'.pth'
                torch.save(net.state_dict(), PATH)

            running_loss += loss.item()
            if it % 100 == 0:
                print("iter:", it, ", loss/100it:", running_loss / 100)
                writer.add_scalar('training loss', running_loss / 100, it)
                running_loss = 0.0

            if it % 1000 == 0:
                print("iter:", it, ", test")
                net.eval()
                # (loader, progress print interval, stdout message, TensorBoard tag)
                eval_plan = (
                    (trainloader, 100, 'Accuracy on training images: %d %%', 'Acc on training'),
                    (testloader, 50, 'Accuracy on test images: %d %%', 'Acc on testing'),
                )
                with torch.no_grad():
                    for loader, every, message, tag in eval_plan:
                        correct = 0
                        total = 0
                        for i, batch in enumerate(loader):
                            if i % every == 0:
                                print(i, "/", len(loader))
                            images, targets = batch[0].to(device), batch[1].to(device)
                            predicted = net(images).argmax(dim=1)
                            total += targets.size(0)
                            correct += (predicted == targets).sum().item()
                        print(message % (100 * correct / total))
                        writer.add_scalar(tag, (100 * correct / total), it)
                net.train()

            # Stop exactly at end_iter instead of finishing the current epoch,
            # so the final checkpoint name matches the configured budget.
            if it >= end_iter:
                break
finally:
    # Close the TensorBoard writer even when the run is interrupted.
    writer.close()

print('Finished Training')

PATH = './cifar_10_resnet'+str(N*6+2)+'_iter'+str(it)+'.pth'
torch.save(net.state_dict(), PATH)

iter: 100 , loss/100it: 2.776061017513275
iter: 200 , loss/100it: 2.2238227486610413
iter: 300 , loss/100it: 2.0731608951091767
iter: 400 , loss/100it: 1.8699504792690278
iter: 500 , loss/100it: 1.7405461978912353
iter: 600 , loss/100it: 1.625420045852661
iter: 700 , loss/100it: 1.5293361568450927
iter: 800 , loss/100it: 1.4328518640995025
iter: 900 , loss/100it: 1.3345908391475678
iter: 1000 , loss/100it: 1.242227029800415
iter: 1000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 49 %
0 / 79
50 / 79
Accuracy on test images: 49 %
iter: 1100 , loss/100it: 1.1905099976062774
iter: 1200 , loss/100it: 1.1350100088119506
iter: 1300 , loss/100it: 1.0865227764844894
iter: 1400 , loss/100it: 1.040309603214264
iter: 1500 , loss/100it: 1.006677429676056
iter: 1600 , loss/100it: 0.947714912891388
iter: 1700 , loss/100it: 0.9185647308826447
iter: 1800 , loss/100it: 0.8659544199705124
iter: 1900 , loss/100it: 0.8461891549825669
iter: 2000 , loss/100it: 0.8155267602205276


iter: 14400 , loss/100it: 0.28717013269662856
iter: 14500 , loss/100it: 0.25787654653191566
iter: 14600 , loss/100it: 0.2702786347270012
iter: 14700 , loss/100it: 0.2966419172286987
iter: 14800 , loss/100it: 0.28369120985269547
iter: 14900 , loss/100it: 0.2612015686929226
iter: 15000 , loss/100it: 0.25448479637503624
iter: 15000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 88 %
0 / 79
50 / 79
Accuracy on test images: 83 %
iter: 15100 , loss/100it: 0.2920688197016716
iter: 15200 , loss/100it: 0.2611998589336872
iter: 15300 , loss/100it: 0.27193727403879164
iter: 15400 , loss/100it: 0.265555290132761
iter: 15500 , loss/100it: 0.2771045273542404
iter: 15600 , loss/100it: 0.2647376573085785
iter: 15700 , loss/100it: 0.26174304991960523
iter: 15800 , loss/100it: 0.28459170684218404
iter: 15900 , loss/100it: 0.2563937938213348
iter: 16000 , loss/100it: 0.262112153545022
iter: 16000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 91 %
0 /

iter: 28200 , loss/100it: 0.20655210837721824
iter: 28300 , loss/100it: 0.22127913914620875
iter: 28400 , loss/100it: 0.1959347252547741
iter: 28500 , loss/100it: 0.22875115998089313
iter: 28600 , loss/100it: 0.2066959834843874
iter: 28700 , loss/100it: 0.19852043345570564
iter: 28800 , loss/100it: 0.2137391234189272
iter: 28900 , loss/100it: 0.20489991001784802
iter: 29000 , loss/100it: 0.19660663701593875
iter: 29000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 93 %
0 / 79
50 / 79
Accuracy on test images: 87 %
iter: 29100 , loss/100it: 0.19943902157247068
iter: 29200 , loss/100it: 0.22981174290180206
iter: 29300 , loss/100it: 0.20408954314887523
iter: 29400 , loss/100it: 0.2128651310503483
iter: 29500 , loss/100it: 0.22354323744773866
iter: 29600 , loss/100it: 0.20942187681794167
iter: 29700 , loss/100it: 0.20648629210889338
iter: 29800 , loss/100it: 0.22273162223398685
iter: 29900 , loss/100it: 0.21552391290664674
iter: 30000 , loss/100it: 0.197256840914

0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 99 %
0 / 79
50 / 79
Accuracy on test images: 91 %
iter: 42100 , loss/100it: 0.023613581624813376
iter: 42200 , loss/100it: 0.02630859647411853
iter: 42300 , loss/100it: 0.02180626029614359
iter: 42400 , loss/100it: 0.02403525535017252
iter: 42500 , loss/100it: 0.0246002602763474
iter: 42600 , loss/100it: 0.025439336160197855
iter: 42700 , loss/100it: 0.022640308463014663
iter: 42800 , loss/100it: 0.0263079794915393
iter: 42900 , loss/100it: 0.02456608631182462
iter: 43000 , loss/100it: 0.022244249135255814
iter: 43000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 99 %
0 / 79
50 / 79
Accuracy on test images: 91 %
iter: 43100 , loss/100it: 0.02260684992186725
iter: 43200 , loss/100it: 0.02408226082334295
iter: 43300 , loss/100it: 0.023454634512308985
iter: 43400 , loss/100it: 0.021855805637314915
iter: 43500 , loss/100it: 0.021119136072229594
iter: 43600 , loss/100it: 0.022863558735698463
iter:

iter: 55700 , loss/100it: 0.011572132465662435
iter: 55800 , loss/100it: 0.01112701523466967
iter: 55900 , loss/100it: 0.01162298531504348
iter: 56000 , loss/100it: 0.01177012242260389
iter: 56000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 99 %
0 / 79
50 / 79
Accuracy on test images: 91 %
iter: 56100 , loss/100it: 0.01040609851363115
iter: 56200 , loss/100it: 0.01190515401540324
iter: 56300 , loss/100it: 0.01171842470066622
iter: 56400 , loss/100it: 0.011743682480882853
iter: 56500 , loss/100it: 0.010592646312434227
iter: 56600 , loss/100it: 0.010156195313902572
iter: 56700 , loss/100it: 0.012468336666934192
iter: 56800 , loss/100it: 0.010860942915314808
iter: 56900 , loss/100it: 0.012471669958904386
iter: 57000 , loss/100it: 0.012355541879078374
iter: 57000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 99 %
0 / 79
50 / 79
Accuracy on test images: 91 %
iter: 57100 , loss/100it: 0.010857278830371797
iter: 57200 , loss/100it: 0.0

In [23]:
# ResNet 20
# ResNet 20: N = 3 residual blocks per stage (6 * N + 2 = 20).
N = 3

# Fresh model with re-initialised weights, moved to `device`.
net = Net(N)
net.apply(init_weights)
net = net.to(device)

# Same loss / optimizer / scheduler settings as the previous run, bound to the
# new model's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=0.0001,
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

# Reset the iteration bookkeeping read by the training loop.
start_iter = 0
end_iter = 64000
it = 0

In [24]:
# Train `net` until `it` reaches `end_iter` mini-batch updates.
# Logs the mean loss every 100 iterations and the accuracy on the training and
# test loaders every 1000 iterations (stdout + TensorBoard).
writer = SummaryWriter()
running_loss = 0.0
net.train()
try:
    while it < end_iter:
        for data in trainloader:
            it += 1
            inputs, labels = data[0].to(device), data[1].to(device)

            # random flip, constant padding, random crop
            # NOTE(review): the flip decision and crop offset are drawn once per
            # mini-batch, so every image in a batch gets the same augmentation.
            if np.random.randint(0, 2) == 0:
                inputs = torch.flip(inputs, [3])
            inputs = F.pad(inputs, (4, 4, 4, 4))
            # randint's upper bound is exclusive: the +1 keeps the largest valid
            # offset (padded size - 32) reachable.
            top = np.random.randint(0, inputs.size(2) - 32 + 1)
            left = np.random.randint(0, inputs.size(3) - 32 + 1)
            inputs = inputs[:, :, top:top + 32, left:left + 32]

            # training
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Decay the learning rate and checkpoint at the two milestones.
            if it == 32000 or it == 48000:
                scheduler.step()
                PATH = './cifar_10_resnet'+str(N*6+2)+'_iter'+str(it)+'.pth'
                torch.save(net.state_dict(), PATH)

            running_loss += loss.item()
            if it % 100 == 0:
                print("iter:", it, ", loss/100it:", running_loss / 100)
                writer.add_scalar('training loss', running_loss / 100, it)
                running_loss = 0.0

            if it % 1000 == 0:
                print("iter:", it, ", test")
                net.eval()
                # (loader, progress print interval, stdout message, TensorBoard tag)
                eval_plan = (
                    (trainloader, 100, 'Accuracy on training images: %d %%', 'Acc on training'),
                    (testloader, 50, 'Accuracy on test images: %d %%', 'Acc on testing'),
                )
                with torch.no_grad():
                    for loader, every, message, tag in eval_plan:
                        correct = 0
                        total = 0
                        for i, batch in enumerate(loader):
                            if i % every == 0:
                                print(i, "/", len(loader))
                            images, targets = batch[0].to(device), batch[1].to(device)
                            predicted = net(images).argmax(dim=1)
                            total += targets.size(0)
                            correct += (predicted == targets).sum().item()
                        print(message % (100 * correct / total))
                        writer.add_scalar(tag, (100 * correct / total), it)
                net.train()

            # Stop exactly at end_iter instead of finishing the current epoch,
            # so the final checkpoint name matches the configured budget.
            if it >= end_iter:
                break
finally:
    # Close the TensorBoard writer even when the run is interrupted.
    writer.close()

print('Finished Training')

PATH = './cifar_10_resnet'+str(N*6+2)+'_iter'+str(it)+'.pth'
torch.save(net.state_dict(), PATH)

iter: 100 , loss/100it: 2.229663013219833
iter: 200 , loss/100it: 1.7424214959144593
iter: 300 , loss/100it: 1.611091344356537
iter: 400 , loss/100it: 1.4741445863246918
iter: 500 , loss/100it: 1.373933925628662
iter: 600 , loss/100it: 1.2640364730358125
iter: 700 , loss/100it: 1.198943825364113
iter: 800 , loss/100it: 1.123560608625412
iter: 900 , loss/100it: 1.0392798268795014
iter: 1000 , loss/100it: 1.0076363807916642
iter: 1000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 59 %
0 / 79
50 / 79
Accuracy on test images: 59 %
iter: 1100 , loss/100it: 0.9286403733491898
iter: 1200 , loss/100it: 0.9004082947969436
iter: 1300 , loss/100it: 0.8412130254507065
iter: 1400 , loss/100it: 0.813558389544487
iter: 1500 , loss/100it: 0.7527267789840698
iter: 1600 , loss/100it: 0.7688713455200196
iter: 1700 , loss/100it: 0.7281544363498688
iter: 1800 , loss/100it: 0.7010033693909645
iter: 1900 , loss/100it: 0.6805355855822564
iter: 2000 , loss/100it: 0.6827689230442047


iter: 14300 , loss/100it: 0.27625075027346613
iter: 14400 , loss/100it: 0.283586850464344
iter: 14500 , loss/100it: 0.26861775785684583
iter: 14600 , loss/100it: 0.2724536829441786
iter: 14700 , loss/100it: 0.2738576266169548
iter: 14800 , loss/100it: 0.2753096230328083
iter: 14900 , loss/100it: 0.2548686572164297
iter: 15000 , loss/100it: 0.2657020992040634
iter: 15000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 91 %
0 / 79
50 / 79
Accuracy on test images: 86 %
iter: 15100 , loss/100it: 0.27770285099744796
iter: 15200 , loss/100it: 0.2497419771552086
iter: 15300 , loss/100it: 0.2548259747028351
iter: 15400 , loss/100it: 0.26604945585131645
iter: 15500 , loss/100it: 0.2762607631087303
iter: 15600 , loss/100it: 0.25293921642005446
iter: 15700 , loss/100it: 0.2580187866091728
iter: 15800 , loss/100it: 0.2536784516274929
iter: 15900 , loss/100it: 0.2689136382192373
iter: 16000 , loss/100it: 0.2578543266654015
iter: 16000 , test
0 / 352
100 / 352
200 / 352
300

iter: 28100 , loss/100it: 0.23251489624381066
iter: 28200 , loss/100it: 0.21598307326436042
iter: 28300 , loss/100it: 0.21590652890503406
iter: 28400 , loss/100it: 0.22729399777948855
iter: 28500 , loss/100it: 0.24062612377107143
iter: 28600 , loss/100it: 0.20835210859775544
iter: 28700 , loss/100it: 0.22458786115050317
iter: 28800 , loss/100it: 0.22689737893640996
iter: 28900 , loss/100it: 0.22786436170339586
iter: 29000 , loss/100it: 0.20888620771467686
iter: 29000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 90 %
0 / 79
50 / 79
Accuracy on test images: 85 %
iter: 29100 , loss/100it: 0.2297537276148796
iter: 29200 , loss/100it: 0.23325007066130637
iter: 29300 , loss/100it: 0.2140235575288534
iter: 29400 , loss/100it: 0.21089519873261453
iter: 29500 , loss/100it: 0.22236458368599415
iter: 29600 , loss/100it: 0.21225971706211566
iter: 29700 , loss/100it: 0.21956342183053493
iter: 29800 , loss/100it: 0.2217031627893448
iter: 29900 , loss/100it: 0.22979206591

0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 99 %
0 / 79
50 / 79
Accuracy on test images: 91 %
iter: 42100 , loss/100it: 0.0399103246582672
iter: 42200 , loss/100it: 0.03846651874948293
iter: 42300 , loss/100it: 0.0424110515974462
iter: 42400 , loss/100it: 0.04098502713721246
iter: 42500 , loss/100it: 0.03750155781395734
iter: 42600 , loss/100it: 0.039600168988108635
iter: 42700 , loss/100it: 0.03578881136141718
iter: 42800 , loss/100it: 0.04131331570446491
iter: 42900 , loss/100it: 0.03558808712288737
iter: 43000 , loss/100it: 0.03976568982936442
iter: 43000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 99 %
0 / 79
50 / 79
Accuracy on test images: 91 %
iter: 43100 , loss/100it: 0.03506295224651694
iter: 43200 , loss/100it: 0.032540793465450406
iter: 43300 , loss/100it: 0.03529857345391065
iter: 43400 , loss/100it: 0.03633990404196084
iter: 43500 , loss/100it: 0.032164681483991445
iter: 43600 , loss/100it: 0.03817823780700565
iter: 4370

iter: 55700 , loss/100it: 0.020612470677588136
iter: 55800 , loss/100it: 0.021666760239750147
iter: 55900 , loss/100it: 0.01862028471659869
iter: 56000 , loss/100it: 0.020962173724547027
iter: 56000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 99 %
0 / 79
50 / 79
Accuracy on test images: 91 %
iter: 56100 , loss/100it: 0.020784774208441378
iter: 56200 , loss/100it: 0.01843850968987681
iter: 56300 , loss/100it: 0.021061465851962568
iter: 56400 , loss/100it: 0.022172010163776578
iter: 56500 , loss/100it: 0.02111601174576208
iter: 56600 , loss/100it: 0.018930547547060996
iter: 56700 , loss/100it: 0.01884787459857762
iter: 56800 , loss/100it: 0.01955086597474292
iter: 56900 , loss/100it: 0.01926923110615462
iter: 57000 , loss/100it: 0.02012533846544102
iter: 57000 , test
0 / 352
100 / 352
200 / 352
300 / 352
Accuracy on training images: 99 %
0 / 79
50 / 79
Accuracy on test images: 92 %
iter: 57100 , loss/100it: 0.021056120458524674
iter: 57200 , loss/100it: 0.02