In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm

In [None]:
# Log in to Weights & Biases (run once per environment; requires a W&B API key)
!wandb login

In [3]:
class BasicBlock(nn.Module):
  """Residual block made of two 3x3 convolutions plus a shortcut branch.

  Args:
    in_channels: channel count of the input feature map.
    out_channels: channel count produced by both convolutions (and by the
      block itself, since ``expansion`` is 1).
    stride: stride of the first convolution; the projection shortcut, when
      present, uses the same stride so both branches keep matching shapes.
  """
  expansion = 1

  def __init__(self, in_channels, out_channels, stride=1):
    super().__init__()
    # options shared by both bias-free 3x3 convolutions
    conv3x3 = dict(kernel_size=3, padding=1, bias=False)

    # main path: only the first convolution may change the spatial size
    self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
    self.bn2 = nn.BatchNorm2d(out_channels)

    # shortcut path: pass-through when shapes already agree, otherwise a
    # strided 1x1 convolution + BatchNorm that projects x to the output shape
    if stride == 1 and in_channels == out_channels:
      self.shortcut = nn.Sequential()
    else:
      projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
      self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_channels))

    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
    residual = self.shortcut(x)

    y = self.relu(self.bn1(self.conv1(x)))
    y = self.bn2(self.conv2(y))

    # in-place add on the fresh main-path tensor; x itself is not modified
    y += residual
    return self.relu(y)

In [4]:
class Bottleneck(nn.Module):
  """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

  The block emits ``out_channels * expansion`` channels (``expansion`` is 4).

  Args:
    in_channels: channel count of the input feature map.
    out_channels: channel count used by the first two convolutions.
    stride: stride of the 3x3 convolution; the projection shortcut, when
      present, uses the same stride.
  """
  expansion = 4

  def __init__(self, in_channels, out_channels, stride=1):
    super().__init__()
    expanded_channels = out_channels * self.expansion

    # 1x1 convolution: acts on the channel dimension only
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
    self.bn1 = nn.BatchNorm2d(out_channels)

    # 3x3 convolution carries the variable stride, so any spatial reduction
    # happens together with spatial feature extraction
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(out_channels)

    # 1x1 convolution: widens the channels to expansion x out_channels
    self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False)
    self.bn3 = nn.BatchNorm2d(expanded_channels)

    # shortcut path: pass-through when shapes already agree, otherwise a
    # strided 1x1 convolution + BatchNorm that projects x to the output shape
    if stride == 1 and in_channels == expanded_channels:
      self.shortcut = nn.Sequential()
    else:
      projection = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride, bias=False)
      self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(expanded_channels))

    # in-place ReLU overwrites its input instead of allocating a new tensor
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    """Run the three conv/BN stages, add the shortcut branch, apply ReLU."""
    residual = self.shortcut(x)

    y = self.relu(self.bn1(self.conv1(x)))
    y = self.relu(self.bn2(self.conv2(y)))
    y = self.bn3(self.conv3(y))

    # in-place add on the fresh main-path tensor; x itself is not modified
    y += residual
    return self.relu(y)

In [5]:
class ResNet(nn.Module):
  """ResNet backbone: a 3x3 stem, four residual stages, pooling and a linear head.

  Args:
    block: residual block class; it is called as ``block(in_channels,
      out_channels, stride)`` and must expose an ``expansion`` attribute.
    num_blocks: sequence whose first four entries give the number of blocks
      in stages 1-4.
    num_classes: number of outputs of the final linear layer.
  """

  def __init__(self, block, num_blocks, num_classes=10):
    super().__init__()
    # running channel count fed to the next block; updated by _make_layer
    self.in_channels = 64

    # initial convolution modified for CIFAR-10 (3x3 kernel, stride 1)
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)

    # (base width, stride of the first block) for layer1..layer4
    stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
    for index, (width, first_stride) in enumerate(stage_specs):
      stage = self._make_layer(block, width, num_blocks[index], stride=first_stride)
      setattr(self, f'layer{index + 1}', stage)

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, out_channels, num_blocks, stride):
    """Build one stage: the first block uses ``stride``, the rest use stride 1.

    For example ``stride=2, num_blocks=4`` yields block strides ``[2, 1, 1, 1]``.
    ``self.in_channels`` is advanced after every block so the next block (and
    the next stage) receives the right input width.
    """
    stage_blocks = []
    for block_stride in [stride] + [1] * (num_blocks - 1):
      stage_blocks.append(block(self.in_channels, out_channels, block_stride))
      self.in_channels = out_channels * block.expansion
    # unpack the list so each block becomes a child of the Sequential
    return nn.Sequential(*stage_blocks)

  def forward(self, x):
    """Return the ``num_classes`` outputs of the final linear layer for ``x``."""
    out = self.relu(self.bn1(self.conv1(x)))

    for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
      out = stage(out)

    out = torch.flatten(self.avgpool(out), 1)
    return self.fc(out)

In [7]:
class DataModule:
  """Bundles the CIFAR-10 datasets, their transforms and DataLoader creation.

  Args:
    batch_size: batch size used by both loaders.
    num_workers: worker-process count passed to both loaders.
  """

  def __init__(self, batch_size=128, num_workers=2):
    self.batch_size = batch_size
    self.num_workers = num_workers

    # per-channel normalisation constants shared by both pipelines
    # (presumably CIFAR-10 mean/std -- verify against the dataset)
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # training pipeline: random crop/flip augmentation, then normalise
    self.transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # evaluation pipeline: no augmentation
    self.transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

  def setup(self):
    """Create the train/test CIFAR-10 datasets under ./data (download=True)."""
    cifar10 = torchvision.datasets.CIFAR10
    self.train_dataset = cifar10(
        root='./data', train=True, download=True, transform=self.transform_train
    )
    self.test_dataset = cifar10(
        root='./data', train=False, download=True, transform=self.transform_test
    )

  def _make_loader(self, dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using this module's batch settings."""
    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=shuffle,
        num_workers=self.num_workers
    )

  def train_dataloader(self):
    """Return a new shuffling DataLoader over the training set."""
    return self._make_loader(self.train_dataset, shuffle=True)

  def test_dataloader(self):
    """Return a new non-shuffling DataLoader over the test set."""
    return self._make_loader(self.test_dataset, shuffle=False)

In [18]:
class Trainer:
  def __init__(self, model, data_module, config):
    self.model = model
    self.data_module = data_module
    self.config = config

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.model = self.model.to(self.device)

    self.criterion = nn.CrossEntropyLoss()
    self.optimizer = optim.SGD(
        self.model.parameters(),
        lr=config['learning_rate'],
        momentum=config['momentum'],
        weight_decay=config['weight_decay']
    )

    # see notes for why we use CosineAnnealingLR scheduler
    self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
        self.optimizer,
        T_max=config['epochs']
    )

  def train_epoch(self, epoch):
    self.model.train()
    train_loss = 0
    correct = 0
    total = 0

    # creates a progress bar for training epochs
    pbar = tqdm(self.data_module.train_dataloader(), desc=f'Epoch {epoch}')
    for batch_idx, (inputs, targets) in enumerate(pbar):
      inputs, targets = inputs.to(self.device), targets.to(self.device)

      self.optimizer.zero_grad()
      outputs = self.model(inputs)
      loss = self.criterion(outputs, targets)
      loss.backward()
      self.optimizer.step()

      train_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      # update progress bar
      pbar.set_postfix({
          'loss': train_loss/(batch_idx+1),
          'acc': 100.*correct/total
      })

      # log metrics to wandb
      wandb.log({
          "train_loss": loss.item(),
          "train_acc": 100.*correct/total,
          "learning_rate": self.scheduler.get_last_lr()[0]
      })

  def test_epoch(self):
    self.model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
      for inputs, targets in self.data_module.test_dataloader():
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    accuracy = 100.*correct/total
    avg_loss = test_loss/len(self.data_module.test_dataloader())

    # log metrics to wandb
    wandb.log({
        "test_loss": avg_loss,
        "test_acc": accuracy
    })

    return accuracy

  def fit(self):
    best_acc = 0
    for epoch in range(self.config['epochs']):
      self.train_epoch(epoch)
      accuracy = self.test_epoch()
      self.scheduler.step()

      # save best model
      if accuracy > best_acc:
        best_acc = accuracy
        torch.save(self.model.state_dict(), 'best_model.pth')

In [1]:
# see notes for different block configurations

def create_resnet18(num_classes=10):
  """Build a ResNet-18-style network: ``BasicBlock`` stages of depth [2, 2, 2, 2].

  Args:
    num_classes: number of classifier outputs, forwarded to ``ResNet``.
  """
  return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def create_resnet34(num_classes=10):
  """Build a ResNet-34-style network: ``BasicBlock`` stages of depth [3, 4, 6, 3].

  Args:
    num_classes: number of classifier outputs, forwarded to ``ResNet``.
  """
  return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)

def create_resnet50(num_classes=10):
  """Build a ResNet-50-style network: ``Bottleneck`` stages of depth [3, 4, 6, 3].

  Args:
    num_classes: number of classifier outputs, forwarded to ``ResNet``.
  """
  return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)

def create_resnet101(num_classes=10):
  """Build a ResNet-101-style network: ``Bottleneck`` stages of depth [3, 4, 23, 3].

  Args:
    num_classes: number of classifier outputs, forwarded to ``ResNet``.
  """
  return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)

def create_resnet152(num_classes=10):
  """Build a ResNet-152-style network: ``Bottleneck`` stages of depth [3, 8, 36, 3].

  Args:
    num_classes: number of classifier outputs, forwarded to ``ResNet``.
  """
  return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)

In [17]:
# Hyperparameters for the run; the same dict is handed to wandb.init and Trainer.
config = dict(
  learning_rate=0.1,   # initial learning rate of the SGD optimizer
  momentum=0.9,        # SGD momentum
  weight_decay=5e-4,   # SGD weight decay
  epochs=100,          # number of epochs; also T_max of the cosine LR schedule
  batch_size=128,      # batch size passed to DataModule
)

In [None]:
# Start a Weights & Biases run under the "resnet50-cifar10" project and
# attach the hyperparameter dict as the run's config.
wandb.init(
  project="resnet50-cifar10",
  config=config
)

In [None]:
# Create the ResNet-50 model and the data module (batch size taken from config),
# then call setup(), which builds the CIFAR-10 train/test datasets under ./data
# with download=True.
model = create_resnet50()
data_module = DataModule(batch_size=config['batch_size'])
data_module.setup()

In [None]:
# Train for config['epochs'] epochs, evaluating on the test set after each one;
# the weights with the best test accuracy are written to best_model.pth.
trainer = Trainer(model, data_module, config)
trainer.fit()

In [None]:
# Mark the current W&B run as finished.
wandb.finish()