In [1]:
import torch
import torchvision
import torch.nn as nn

In [2]:
batch_size = 128

In [5]:
trainset = torchvision.datasets.CIFAR10('cifar',
            train=True, download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Resize((32, 32)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ]))

testset = torchvision.datasets.CIFAR10('cifar',
            train=False, download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.Resize((32, 32)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ]))

trainset = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

train_dataset = iter(trainset)
test_dataset = iter(testset)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
def ResidualBlock(channel):
    return nn.Sequential(
        nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, padding=1),
        nn.BatchNorm2d(channel),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, padding=1),
        nn.BatchNorm2d(channel),
        nn.ReLU(inplace=True)
    )

class ResNet(nn.Module):
  def __init__(self):
    super(ResNet, self).__init__()
    self.layer0 = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU()
    )
    self.layerfc = nn.Sequential(
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(in_features=512, out_features=10)
    )
    self.layer1 = ResidualBlock(64)
    self.layer2 = ResidualBlock(128)
    self.layer3 = ResidualBlock(256)
    self.layer4 = ResidualBlock(512)
    self.activation = nn.ReLU()
    self.convert2 = nn.Sequential(
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True))
    self.convert3 = nn.Sequential(
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True))
    self.convert4 = nn.Sequential(
        nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
        nn.BatchNorm2d(512),
        nn.ReLU(inplace=True))

  def forward(self, x):
    x = self.layer0(x)

    x = self.layer1(x) + x

    tmp1 = self.convert2(x)
    tmp2 = self.layer2(tmp1)
    x = tmp1 + tmp2
    
    tmp1 = self.convert3(x)
    tmp2 = self.layer3(tmp1)
    x = tmp1 + tmp2

    tmp1 = self.convert4(x)
    tmp2 = self.layer4(tmp1)
    x = tmp1 + tmp2

    x = self.layerfc(x)
    return x

In [10]:
model = ResNet()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

criterion = nn.CrossEntropyLoss().to(device)

def optim_scheduler(epoch):
  if epoch < 10:
    lr = 1e-3
  elif epoch < 20:
    lr = 1e-4
  else:
    lr = 1e-5
  return torch.optim.Adam(model.parameters(), lr=lr)

loss_list = []
acc_list = []

for epoch in range(50):
  optimizer = optim_scheduler(epoch)
  for idx, (x, label) in enumerate(trainset):
    x, label = x.to(device), label.to(device)
    model.train()
    logits = model(x)
    loss = criterion(logits, label)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
        
  model.eval()
  with torch.no_grad():
      tot_corr = 0
      tot_num = 0
      
      for x, label in testset:
          x, label = x.to(device), label.to(device)
          logits = model(x)
          pred = logits.argmax(dim=1)
          
          tot_corr += torch.eq(pred, label).float().sum().item()
          tot_num += x.size(0)
      acc = tot_corr / tot_num
      
  loss_list.append(loss.item())
  acc_list.append(acc)
    
  print('epoch: {}, loss: {}, acc: {}'.format(epoch+1, loss, acc))

epoch: 1, loss: 1.3981024026870728, acc: 0.4554
epoch: 2, loss: 1.2791693210601807, acc: 0.5377
epoch: 3, loss: 0.7110854387283325, acc: 0.6587
epoch: 4, loss: 0.6633554100990295, acc: 0.7077
epoch: 5, loss: 0.5661798119544983, acc: 0.759
epoch: 6, loss: 0.7514356970787048, acc: 0.7488
epoch: 7, loss: 0.5353270769119263, acc: 0.7792
epoch: 8, loss: 0.5380469560623169, acc: 0.7815
epoch: 9, loss: 0.2961450517177582, acc: 0.7799
epoch: 10, loss: 0.2953492999076843, acc: 0.811
epoch: 11, loss: 0.25534293055534363, acc: 0.8514
epoch: 12, loss: 0.16899806261062622, acc: 0.8503
epoch: 13, loss: 0.20900793373584747, acc: 0.8513
epoch: 14, loss: 0.04431191831827164, acc: 0.8488
epoch: 15, loss: 0.08536839485168457, acc: 0.8484
epoch: 16, loss: 0.16057142615318298, acc: 0.8486
epoch: 17, loss: 0.10420699417591095, acc: 0.8486
epoch: 18, loss: 0.07933694124221802, acc: 0.8426
epoch: 19, loss: 0.04908553510904312, acc: 0.8424
epoch: 20, loss: 0.04525141045451164, acc: 0.8427
epoch: 21, loss: 0.02