In [None]:
!pip install wandb
!pip install torch torchvision

In [None]:
import wandb
import argparse  # NOTE(review): argparse is not referenced in any cell of this notebook; candidate for removal
# Authenticate with Weights & Biases before the run is started with wandb.init() in the next cell.
wandb.login()

In [None]:
# Start a W&B run. The config records the hyperparameters this notebook actually
# uses (optim.SGD with lr=0.01 / momentum=0.9 in the model cell, batch_size=16 in
# the data-loading cell) so the logged run metadata is not misleading.
wandb.init(
    project="CIFAR-100-deep-learning",
    config={
        "learning_rate": 0.01,  # keep in sync with the lr passed to optim.SGD
        "momentum": 0.9,
        "optimizer": "SGD",
        "batch_size": 16,
        "architecture": "ResNet18",
        "dataset": "CIFAR-100",
        "epochs": 1,
    },
)

In [68]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')  # first GPU if present, else CPU

# Preprocessing applied to every image: convert to a tensor, then normalize each
# of the three channels with mean 0.5 and std 0.5 (maps values in [0, 1] to [-1, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# load train, valid, test data
SPLIT_SEED = 42  # fixed seed so the 45,000 / 5,000 train/validation split is reproducible

full_trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
trainset, validset = torch.utils.data.random_split(
    full_trainset,
    lengths=[45000, 5000],
    generator=torch.Generator().manual_seed(SPLIT_SEED))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=2)
# Accuracy does not depend on sample order, so the validation set is not shuffled.
validloader = torch.utils.data.DataLoader(validset, batch_size=16, shuffle=False, num_workers=2)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False, num_workers=2)

# Class names indexed by label id, read from the dataset's own metadata so that
# classes[label] is the name of that label. A hand-typed tuple must list all 100
# names in exactly the dataset's label order, which is easy to get wrong
# (e.g. a missing comma silently concatenates two names).
classes = tuple(testset.classes)
assert len(classes) == 100, len(classes)

In [70]:
class ResBlock(nn.Module):
  """Basic two-convolution residual block.

  Main path: 3x3 conv (with ``stride``) -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm.
  The block input -- or ``downsample(input)`` when a downsample module is given --
  is added to the main path, and the sum goes through a final ReLU.

  Args:
    in_channels: number of channels in the input tensor.
    out_channels: number of channels produced by both convolutions.
    stride: stride of the first convolution.
    downsample: optional module applied to the input so the shortcut matches the
      main path's shape; ``None`` means an identity shortcut.
  """
  def __init__(self, in_channels, out_channels, stride=1, downsample=None):
    super().__init__()
    # Submodules are registered in forward-pass order.
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channels)
    self.downsample = downsample

  def forward(self, x):
    main = self.relu(self.bn1(self.conv1(x)))
    main = self.bn2(self.conv2(main))
    # Shortcut: projected when a downsample module is present, identity otherwise.
    shortcut = self.downsample(x) if self.downsample else x
    main += shortcut
    return self.relu(main)

In [71]:
# create network
class Net(nn.Module):
  """ResNet-18-style classifier assembled from a residual block class.

  Stem (3x3 conv -> BatchNorm -> ReLU), then four stages of two blocks each with
  64, 128, 256 and 512 channels (stages 2-4 are built with stride 2), global
  average pooling, and a linear layer producing ``num_classes`` logits.

  Args:
    ResBlock: block class, called as ``ResBlock(in_channels, out_channels, stride,
      downsample)`` for the first block of a stage and
      ``ResBlock(out_channels, out_channels)`` for the remaining ones.
    num_classes: number of output logits (default 100, for CIFAR-100).
  """
  def __init__(self, ResBlock, num_classes=100):
    super().__init__()
    # Channel count entering the next stage; advanced by _make_layers.
    self.in_channels = 64
    # 3x3 stem conv without padding: a 32x32 input becomes 30x30.
    self.conv = nn.Conv2d(3, 64, 3)
    self.bn = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.layer1 = self._make_layers(ResBlock, 64, 2)
    self.layer2 = self._make_layers(ResBlock, 128, 2, 2)
    self.layer3 = self._make_layers(ResBlock, 256, 2, 2)
    self.layer4 = self._make_layers(ResBlock, 512, 2, 2)
    self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512, num_classes)

  def _make_layers(self, ResBlock, out_channels, blocks, stride=1):
    """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

    Only the first block receives ``stride`` and, when the stride or channel
    count changes, a 1x1-conv + BatchNorm ``downsample`` so its shortcut matches
    the main path's shape. Updates ``self.in_channels`` to ``out_channels``.
    """
    downsample = None
    if stride != 1 or self.in_channels != out_channels:
      downsample = nn.Sequential(
          nn.Conv2d(self.in_channels, out_channels, 1, stride=stride),
          nn.BatchNorm2d(out_channels),
      )
    layers = [ResBlock(self.in_channels, out_channels, stride, downsample)]
    self.in_channels = out_channels
    layers.extend(ResBlock(out_channels, out_channels) for _ in range(1, blocks))
    return nn.Sequential(*layers)

  def forward(self, x):
    x = self.relu(self.bn(self.conv(x)))
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avg_pool(x)
    x = torch.flatten(x, 1)  # (N, 512, 1, 1) -> (N, 512)
    return self.fc(x)

# Build the model and move its parameters to the selected device.
net = Net(ResBlock)
net = net.to(device)

# Cross-entropy loss on the logits, optimized with SGD + momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=0.01, momentum=0.9)

In [None]:
# start training
from tqdm.auto import tqdm

VAL_EVERY = 250  # run a validation pass every this many training batches


def _validation_accuracy(model, loader):
  """Return the fraction of correctly classified samples in ``loader``.

  The model is put in eval mode for the pass (BatchNorm uses its running
  statistics and does not update them) and its previous train/eval mode is
  restored before returning. Returns 0.0 for an empty loader.
  """
  was_training = model.training
  model.eval()
  correct = 0
  total = 0
  with torch.no_grad():
    for val_images, val_labels in loader:
      val_images, val_labels = val_images.to(device), val_labels.to(device)
      predicted = model(val_images).argmax(dim=1)
      total += val_labels.size(0)
      correct += (predicted == val_labels).sum().item()
  model.train(was_training)
  return correct / total if total else 0.0


for epoch in range(wandb.config.epochs):
  net.train()  # BatchNorm must be in training mode while optimizing
  running_loss = 0.0
  window = 0  # number of batches accumulated in running_loss since the last validation
  progress_bar = tqdm(trainloader, total=len(trainloader), desc=f'Training epoch {epoch + 1}')
  for i, (inputs, labels) in enumerate(progress_bar):
    inputs, labels = inputs.to(device), labels.to(device)

    # zero gradient
    optimizer.zero_grad()

    # forward, backward, optimize
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # .item() yields a plain float for logging instead of a (possibly GPU) tensor.
    loss_value = loss.item()
    running_loss += loss_value
    window += 1
    metrics = {"loss": loss_value}

    if i % VAL_EVERY == 0:
      accuracy = _validation_accuracy(net, validloader)
      metrics["accuracy"] = accuracy
      progress_bar.set_postfix(loss=running_loss / window, accuracy=100 * accuracy)
      running_loss = 0.0
      window = 0

    # One log call per batch keeps loss and accuracy on the same wandb step.
    wandb.log(metrics)

print('Finished Training')

In [None]:
# test network on test set
# Evaluate in eval mode so BatchNorm uses its running statistics rather than
# per-batch statistics (and does not update them with test data).
net.eval()

# Label ids index the dataset's own class-name list.
class_names = testset.classes
num_classes = len(class_names)
class_correct = torch.zeros(num_classes, dtype=torch.long)
class_total = torch.zeros(num_classes, dtype=torch.long)

# Single pass over the test set, counting every sample of every batch per class.
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        predicted = net(images).argmax(dim=1)
        hits = (predicted == labels).cpu()
        labels = labels.cpu()
        class_total += torch.bincount(labels, minlength=num_classes)
        class_correct += torch.bincount(labels[hits], minlength=num_classes)

correct = int(class_correct.sum())
total = int(class_total.sum())
print('Accuracy of the network on the %d test images: %.2f %%' % (
    total, 100 * correct / total if total else 0.0))

# per-class accuracy
for name, n_correct, n_total in zip(class_names, class_correct.tolist(), class_total.tolist()):
    if n_total == 0:
        print('Accuracy of %5s : n/a (no test images)' % name)
        continue
    print('Accuracy of %5s : %2d %%' % (name, 100 * n_correct / n_total))


In [None]:
# Mark the W&B run opened by wandb.init() as finished; call once all metrics are logged.
wandb.finish()