In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter

def imshow(img):
    """Show a normalized image tensor with matplotlib.

    Reverses ``Normalize(mean=0.5, std=0.5)`` (maps [-1, 1] back to [0, 1])
    and moves channels last for ``plt.imshow``; assumes ``img`` is a CPU
    tensor laid out as (C, H, W).
    """
    unnormalized = img * 0.5 + 0.5  # x/2 + 0.5, exact in floating point
    channels_last = unnormalized.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()

In [2]:
#dataset
# Seed the global RNGs used by this notebook (torch: split, weight init,
# shuffling; numpy: any numpy-based randomness) so Restart & Run All is
# repeatable, up to nondeterministic cuDNN kernels on GPU.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# Hold out N_VAL training images (45k/5k for CIFAR-10). A dedicated seeded
# generator keeps the split fixed no matter what else consumes the global RNG.
N_VAL = 5000
trainset, valset = torch.utils.data.random_split(
    trainset, [len(trainset) - N_VAL, N_VAL],
    generator=torch.Generator().manual_seed(SEED))
# NOTE(review): valset is not used by the later cells; periodic accuracy is
# tracked on the test set instead.

In [3]:
#Network
class Net(nn.Module):
    """CIFAR-10 ResNet with 6n+2 weighted layers (He et al., 2016, Sec. 4.2).

    Layout:
    - Stem: 3x3 conv with 16 channels, then BN and ReLU.
    - Three stages of ``n`` basic residual blocks, with 16, 32 and 64 channels.
    - Head: global average pooling and a 10-way linear classifier.

    The first block of stages 2 and 3 halves the spatial size with a stride-2
    conv. Its shortcut has no parameters: it subsamples spatially and appends
    zero-filled extra channels.

    Args:
        n: residual blocks per stage (3 -> ResNet-20, 5 -> ResNet-32).

    Raises:
        ValueError: if ``n < 1``. Stages 2/3 would then never widen to the
            64 channels that ``fc`` expects.
    """

    def __init__(self, n):
        super().__init__()
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, (3, 3), padding=1, bias=False),
            nn.BatchNorm2d(16))
        # Keep the key layout stageK.<block>.<layer>.* (layers 0..4 =
        # conv, BN, ReLU, conv, BN) so saved .pth checkpoints keep loading.
        self.stage1 = self._make_stage(16, 16, n, stride=1)
        self.stage2 = self._make_stage(16, 32, n, stride=2)
        self.stage3 = self._make_stage(32, 64, n, stride=2)
        # For 32x32 inputs the final maps are 8x8, where this equals
        # AvgPool2d(8); unlike a fixed kernel it also handles other sizes.
        self.GAP = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, 10)

    @staticmethod
    def _make_stage(in_channels, out_channels, n, stride):
        """Return ``n`` residual branches (conv-BN-ReLU-conv-BN) as a ModuleList.

        Only the first branch changes shape: it maps ``in_channels`` to
        ``out_channels`` with the given ``stride``. The rest preserve shape.
        """
        stage = nn.ModuleList()
        for block_idx in range(n):
            first = block_idx == 0
            stage.append(nn.Sequential(
                nn.Conv2d(in_channels if first else out_channels, out_channels,
                          (3, 3), stride=stride if first else 1, padding=1,
                          bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, (3, 3), padding=1,
                          bias=False),
                nn.BatchNorm2d(out_channels),
            ))
        return stage

    @staticmethod
    def _downsample_shortcut(x):
        """Shortcut for a stride-2 block that doubles the channel count.

        Keeps every second row and column, giving ceil(H/2) x ceil(W/2). That
        matches the output of a stride-2, padding-1 3x3 conv. On even sizes
        it selects the same pixels as nearest-neighbour ``F.interpolate`` to
        half size. It then appends an equal number of zero channels, using
        the input's dtype and device.
        """
        x_sub = x[:, :, ::2, ::2]
        return torch.cat((x_sub, torch.zeros_like(x_sub)), dim=1)

    def _run_stage(self, stage, x, downsample):
        """Apply each block of ``stage`` as relu(branch(x) + shortcut(x))."""
        for block_idx, block in enumerate(stage):
            if downsample and block_idx == 0:
                shortcut = self._downsample_shortcut(x)
            else:
                shortcut = x
            x = F.relu(block(x) + shortcut)
        return x

    def forward(self, x):
        """Map a batch of 3-channel images (N, 3, H, W) to logits (N, 10)."""
        x = F.relu(self.conv(x))
        x = self._run_stage(self.stage1, x, downsample=False)
        x = self._run_stage(self.stage2, x, downsample=True)
        x = self._run_stage(self.stage3, x, downsample=True)
        x = self.GAP(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

In [4]:
#weight Initialization
def init_weights(m):
    """Re-initialize a single module in place; meant to be passed to ``net.apply``.

    - ``nn.Linear``: Kaiming-normal weights; bias (if present) filled with 0.01.
    - ``nn.Conv2d``: Kaiming-normal weights; bias (if any) left untouched.

    Other module types (e.g. BatchNorm) keep their default initialization.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        # Linear layers built with bias=False have m.bias = None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)

In [5]:
#GPU
# Use the first visible CUDA device when there is one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
#Make ResNet 
# Depth is 6*N + 2: N = 5 builds ResNet-32, N = 3 builds ResNet-20.
N = 5
# nn.Module.apply and nn.Module.to both return the module itself, so the build,
# the custom weight init and the device move chain into one expression.
net = Net(N).apply(init_weights).to(device)

In [7]:
#Criterion, Optimizer, Scheduler
criterion = nn.CrossEntropyLoss()
# SGD: initial lr 0.1, momentum 0.9, weight decay 1e-4.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)
# StepLR(step_size=1, gamma=0.1): each scheduler.step() multiplies the lr by
# 0.1. The training loop calls scheduler.step() only at its chosen iterations,
# so this acts as a manually triggered step decay.
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=1)

In [8]:
#Batch size, Dataloader
batch_size = 128
# Both splits share batch size and worker count; only the training split is
# reshuffled every epoch.
# NOTE(review): the held-out valset from the dataset cell gets no loader here.
trainloader, testloader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size,
                                shuffle=reshuffle, num_workers=4)
    for split, reshuffle in ((trainset, True), (testset, False))
)

In [9]:
#Train
end_iter = 64000
# After each of these iterations the lr is multiplied by 0.1 via
# scheduler.step() and an intermediate checkpoint is written.
lr_milestones = (32000, 48000)


def _random_flip_crop(images, pad=4):
    """Random horizontal flip + zero-pad-and-crop, drawn independently per image.

    Each image in the (N, C, H, W) batch is treated separately:
    - flipped horizontally with probability 0.5;
    - zero-padded by ``pad`` pixels on every side;
    - cropped back to H x W at a uniformly random offset in [0, 2*pad] along
      each axis.

    Returns a new tensor on the same device as ``images``.
    """
    n, _, h, w = images.shape
    dev = images.device
    flip = torch.rand(n, device=dev) < 0.5
    images = torch.where(flip[:, None, None, None], images.flip(3), images)
    padded = F.pad(images, (pad, pad, pad, pad))
    # randint's high is exclusive, so 2*pad + 1 allows every offset 0..2*pad.
    top = torch.randint(0, 2 * pad + 1, (n,), device=dev)
    left = torch.randint(0, 2 * pad + 1, (n,), device=dev)
    rows = (top[:, None] + torch.arange(h, device=dev))[:, :, None]   # (n, h, 1)
    cols = (left[:, None] + torch.arange(w, device=dev))[:, None, :]  # (n, 1, w)
    batch = torch.arange(n, device=dev)[:, None, None]               # (n, 1, 1)
    # Gather the crops in channels-last layout, then go back to (N, C, H, W).
    crops = padded.permute(0, 2, 3, 1)[batch, rows, cols]            # (n, h, w, C)
    return crops.permute(0, 3, 1, 2).contiguous()


def _accuracy(loader, log_every):
    """Return top-1 accuracy (in %) of ``net`` on ``loader``.

    Prints progress every ``log_every`` batches. The caller is responsible for
    putting ``net`` into eval mode.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(loader):
            if i % log_every == 0:
                print(i, "/", len(loader))
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = net(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


it = 0
writer = SummaryWriter()
running_loss = 0.0
net.train()
try:
    while it < end_iter:
        for inputs, labels in trainloader:
            it += 1
            inputs, labels = inputs.to(device), labels.to(device)
            # random flip, constant padding, random crop -- independently per image
            inputs = _random_flip_crop(inputs)

            #training
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            #Optimizer Scheduling
            if it in lr_milestones:
                scheduler.step()
                PATH = './cifar_10_resnet'+str(N*6+2)+'_iter'+str(it)+'.pth'
                torch.save(net.state_dict(), PATH)

            #logging, each 100 iteration
            running_loss += loss.item()
            if it % 100 == 0:
                print("iter:", it, ", loss/100it:", running_loss / 100)
                writer.add_scalar('training loss', running_loss / 100, it)
                running_loss = 0.0

            #accuracy test, each 1000 iteration
            if it % 1000 == 0:
                print("iter:", it, ", test")
                net.eval()
                train_acc = _accuracy(trainloader, log_every=100)
                print('Accuracy on training images: %d %%' % train_acc)
                writer.add_scalar('Acc on training', train_acc, it)
                test_acc = _accuracy(testloader, log_every=50)
                print('Accuracy on test images: %d %%' % test_acc)
                writer.add_scalar('Acc on testing', test_acc, it)
                net.train()

            # Stop exactly at end_iter instead of finishing the current epoch.
            if it >= end_iter:
                break

    print('Finished Training')
    PATH = './cifar_10_resnet'+str(N*6+2)+'_iter'+str(it)+'.pth'
    torch.save(net.state_dict(), PATH)
finally:
    # Flush and close the TensorBoard writer even if training is interrupted.
    writer.close()