<a href="https://colab.research.google.com/github/ClaireZixiWang/SimCLR/blob/master/CIFAR100_Supervised_MTL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data: CIFAR-100 with fine and coarse labels



In [1]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms 
import torch.nn as nn
import torchvision.models as models
import torch

In [2]:
import os

# Make CUDA kernel launches synchronous so device-side errors are reported at the
# Python line that triggered them. This is a debugging aid and slows training;
# consider dropping it for full-length runs.
os.environ.update({'CUDA_LAUNCH_BLOCKING': "1"})

In [3]:
# Device configuration: train on the GPU whenever PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: weight initialization, DataLoader shuffling and the random
# augmentations all draw from PyTorch's RNG. torch.manual_seed seeds it on every
# device. Placed before the hyper-parameters so the cell's displayed output is unchanged.
SEED = 42
torch.manual_seed(SEED)

# Hyper-parameters
num_epochs = 100
learning_rate = 0.0003
BATCH_SIZE = 256
weight_decay = 1e-4
# Checkpoint directory. Looks like a Google Drive mount in Colab; the drive must
# be mounted (or the path changed) before training saves checkpoints here.
PATH = '/content/drive/MyDrive/resnet18'

In [4]:
# Show which torch.device the model and batches will be moved to.
device

device(type='cuda')

In [5]:
# Borrowed from https://github.com/ryanchankh/cifar100coarse
import numpy as np
from torchvision.datasets import CIFAR100


class CIFAR100TwoLabels(CIFAR100):
    """CIFAR-100 whose targets carry both the fine and the coarse label.

    After construction, ``self.targets`` is an integer array with one row per
    sample, ``[fine_label, coarse_label]``. ``fine_label`` is the original
    CIFAR-100 target and ``coarse_label`` is looked up in ``_FINE_TO_COARSE``.

    Note: ``self.classes`` is replaced by 20 groups of five fine-class names
    (one list per group). It is no longer the flat list of fine class names.
    """

    # _FINE_TO_COARSE[i] is the coarse (superclass) index assigned to fine label i.
    _FINE_TO_COARSE = np.array([ 4,  1, 14,  8,  0,  6,  7,  7, 18,  3,
                                 3, 14,  9, 18,  7, 11,  3,  9,  7, 11,
                                 6, 11,  5, 10,  7,  6, 13, 15,  3, 15,
                                 0, 11,  1, 10, 12, 14, 16,  9, 11,  5,
                                 5, 19,  8,  8, 15, 13, 14, 17, 18, 10,
                                16,  4, 17,  4,  2,  0, 17,  4, 18, 17,
                                10,  3,  2, 12, 12, 16, 12,  1,  9, 19,
                                 2, 10,  0,  1, 16, 12,  9, 13, 15, 13,
                                16, 19,  2,  4,  6, 19,  5,  5,  8, 19,
                                18,  1,  2, 15,  6,  0, 17,  8, 14, 13])

    # Superclass groups of fine-class names. They appear to be ordered by coarse
    # label index (group k <-> coarse label k); verify before relying on it.
    _COARSE_GROUPS = (
        ('beaver', 'dolphin', 'otter', 'seal', 'whale'),
        ('aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'),
        ('orchid', 'poppy', 'rose', 'sunflower', 'tulip'),
        ('bottle', 'bowl', 'can', 'cup', 'plate'),
        ('apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'),
        ('clock', 'keyboard', 'lamp', 'telephone', 'television'),
        ('bed', 'chair', 'couch', 'table', 'wardrobe'),
        ('bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'),
        ('bear', 'leopard', 'lion', 'tiger', 'wolf'),
        ('bridge', 'castle', 'house', 'road', 'skyscraper'),
        ('cloud', 'forest', 'mountain', 'plain', 'sea'),
        ('camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'),
        ('fox', 'porcupine', 'possum', 'raccoon', 'skunk'),
        ('crab', 'lobster', 'snail', 'spider', 'worm'),
        ('baby', 'boy', 'girl', 'man', 'woman'),
        ('crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'),
        ('hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'),
        ('maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'),
        ('bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'),
        ('lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor'),
    )

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)

        # Pair every fine label with its coarse label: one row per sample.
        fine = np.asarray(self.targets)
        self.targets = np.stack([fine, self._FINE_TO_COARSE[fine]], axis=1)

        # Fresh nested lists per instance, so mutating one dataset's classes
        # cannot affect another instance.
        self.classes = [list(group) for group in self._COARSE_GROUPS]

In [6]:
# Training-time augmentation for 32x32 CIFAR images, applied in this order:
# pad every side by 4 px (torchvision's default constant fill), flip
# horizontally with probability 0.5, randomly crop back to 32x32, then convert
# to a tensor. No normalization is applied.
transform = transforms.Compose(
    [
        transforms.Pad(padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(size=32),
        transforms.ToTensor(),
    ]
)

In [7]:
# Evaluation must be deterministic. The training `transform` randomly pads,
# crops and flips, which would make test metrics vary run to run and be measured
# on augmented images. `transform` applies no normalization, so ToTensor() alone
# is its deterministic counterpart. Keep the two in sync if normalization is added.
test_transform = transforms.ToTensor()

train_dataset = CIFAR100TwoLabels(root='./data',
                                  train=True,
                                  transform=transform,
                                  download=True)
test_dataset = CIFAR100TwoLabels(root='./data',
                                 train=False,
                                 transform=test_transform,
                                 download=True)

# Data loaders: the training set is reshuffled every epoch; test order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


# ResNet-18 with two prediction heads (fine and coarse labels)

In [8]:
class ResNetTwoPrediction(nn.Module):
  """Randomly initialized ResNet-18 backbone shared by two classification heads.

  The backbone's final ``fc`` layer is replaced by ``nn.Identity()``, so it emits
  its feature vector of size ``dim`` (read from ``fc.in_features``). Each head is
  an MLP: Linear(dim, dim) -> ReLU -> Linear(dim, n_classes).

  Args:
    num_fine_classes: output size of the fine-label head (default 100).
    num_coarse_classes: output size of the coarse-label head (default 20).

  ``forward(x)`` returns the tuple ``(fine_logits, coarse_logits)``.
  """

  def __init__(self, num_fine_classes=100, num_coarse_classes=20):
    super(ResNetTwoPrediction, self).__init__()
    # NOTE(review): torchvision's resnet18 uses an ImageNet-style stem (large
    # strided conv1 + max-pool), which shrinks 32x32 CIFAR inputs quickly. A
    # 3x3 stride-1 conv1 without max-pool is the common CIFAR variant to consider.
    self.backbone = models.resnet18(pretrained=False)
    dim = self.backbone.fc.in_features
    self.backbone.fc = nn.Identity()  # drop the classifier; heads below consume the features
    # Derive every head width from `dim`, so the heads stay consistent with the backbone.
    self.fine = self._make_head(dim, num_fine_classes)
    self.coarse = self._make_head(dim, num_coarse_classes)

  @staticmethod
  def _make_head(dim, num_classes):
    """Build Linear(dim, dim) -> ReLU -> Linear(dim, num_classes)."""
    return nn.Sequential(nn.Linear(dim, dim),
                         nn.ReLU(),
                         nn.Linear(in_features=dim, out_features=num_classes, bias=True))

  def forward(self, x):
    """Return ``(fine_logits, coarse_logits)`` computed from shared backbone features."""
    base_representation = self.backbone(x)
    fine_out = self.fine(base_representation)
    coarse_out = self.coarse(base_representation)
    return fine_out, coarse_out


In [9]:
# Build the two-headed network on the selected device for inspection.
# NOTE: the training cell below constructs a fresh instance, so this one is
# replaced once that cell runs.
model = ResNetTwoPrediction()
model = model.to(device)


In [21]:
# Multi-task training: the shared backbone is optimized on a weighted sum of the
# fine-label and coarse-label cross-entropy losses.
FINE_LOSS_WEIGHT = 0.5
COARSE_LOSS_WEIGHT = 0.5
CONSTANT_LR_EPOCHS = 10  # epochs at the initial LR before cosine decay starts
CHECKPOINT_EVERY = 10    # save weights every this many completed epochs

model = ResNetTwoPrediction().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)

# scheduler.step() is called once per epoch, starting at epoch index
# CONSTANT_LR_EPOCHS, so T_max must be counted in those epoch steps. The
# previous T_max=len(train_loader) counted batches, which left the LR at roughly
# half its starting value instead of annealing it to eta_min.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max(1, num_epochs - CONSTANT_LR_EPOCHS), eta_min=0)

# torch.save does not create missing directories.
os.makedirs(PATH, exist_ok=True)


def save_checkpoint(completed_epochs):
  """Save the model's state_dict under PATH, named by the number of completed epochs."""
  checkpoint_name = 'checkpoint_resnet18-sup-mtl_{:04d}.pth.tar'.format(completed_epochs)
  torch.save(model.state_dict(), os.path.join(PATH, checkpoint_name))


# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
  model.train()  # in case another cell left the model in eval mode
  for i, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    # Each row of `labels` is [fine_label, coarse_label] (see CIFAR100TwoLabels).
    # CrossEntropyLoss expects int64 class indices.
    labels = labels.to(device).long()
    fine_labels, coarse_labels = labels[:, 0], labels[:, 1]

    # Forward pass
    fine_out, coarse_out = model(images)
    fine_loss = criterion(fine_out, fine_labels)
    coarse_loss = criterion(coarse_out, coarse_labels)
    loss = COARSE_LOSS_WEIGHT * coarse_loss + FINE_LOSS_WEIGHT * fine_loss

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i+1) % 100 == 0:
      print ("Epoch [{}/{}], Step [{}/{}], Fine Loss: {:.4f}, Coarse Loss: {:.4f},  Total Loss: {:.4f}"
                .format(epoch+1, num_epochs, i+1, total_step, fine_loss.item(), coarse_loss.item(), loss.item()))

  # Hold the LR constant for the first epochs, then cosine-anneal once per epoch.
  if epoch >= CONSTANT_LR_EPOCHS:
    scheduler.step()

  # Name checkpoints by completed epochs, consistent with the final save below.
  completed_epochs = epoch + 1
  if completed_epochs % CHECKPOINT_EVERY == 0 and completed_epochs < num_epochs:
    save_checkpoint(completed_epochs)

# Final weights are always saved, named by the total number of epochs.
save_checkpoint(num_epochs)

Epoch [1/100], Step [100/196], Fine Loss: 3.6724, Coarse Loss: 2.2666,  Total Loss: 2.9695
Epoch [2/100], Step [100/196], Fine Loss: 3.3861, Coarse Loss: 2.1717,  Total Loss: 2.7789
Epoch [3/100], Step [100/196], Fine Loss: 2.6449, Coarse Loss: 1.7292,  Total Loss: 2.1871
Epoch [4/100], Step [100/196], Fine Loss: 2.7102, Coarse Loss: 1.7171,  Total Loss: 2.2137
Epoch [5/100], Step [100/196], Fine Loss: 2.4895, Coarse Loss: 1.5964,  Total Loss: 2.0430
Epoch [6/100], Step [100/196], Fine Loss: 2.4389, Coarse Loss: 1.6149,  Total Loss: 2.0269
Epoch [7/100], Step [100/196], Fine Loss: 2.5327, Coarse Loss: 1.6674,  Total Loss: 2.1000
Epoch [8/100], Step [100/196], Fine Loss: 2.2877, Coarse Loss: 1.4547,  Total Loss: 1.8712
Epoch [9/100], Step [100/196], Fine Loss: 2.2837, Coarse Loss: 1.4548,  Total Loss: 1.8693
Epoch [10/100], Step [100/196], Fine Loss: 2.0671, Coarse Loss: 1.2988,  Total Loss: 1.6830
Epoch [11/100], Step [100/196], Fine Loss: 2.0979, Coarse Loss: 1.2715,  Total Loss: 1.68