In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

In [2]:
!pip install gdown

Collecting gdown
  Downloading gdown-5.0.1-py3-none-any.whl.metadata (5.6 kB)
Downloading gdown-5.0.1-py3-none-any.whl (16 kB)
Installing collected packages: gdown
Successfully installed gdown-5.0.1


In [21]:
def get_file_id_by_model(folder_name):
  file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
             'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',
             'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu'}
  return file_id.get(folder_name, "Model not found.")

In [22]:
folder_name = 'resnet18_100-epochs_stl10'
file_id = get_file_id_by_model(folder_name)
print(folder_name, file_id)

resnet18_100-epochs_stl10 14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF


In [23]:
# download and extract model files
os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))
os.system('unzip {}'.format(folder_name))
!ls

Downloading...
From (original): https://drive.google.com/uc?id=14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF
From (redirected): https://drive.google.com/uc?id=14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF&confirm=t&uuid=2ce8dccc-6dc2-402e-9428-06179075ef31
To: /mnt/2tb/ruyi/DNN-NTL-Pruning/src/ssl/resnet18_100-epochs_stl10.zip
100%|██████████| 116M/116M [00:01<00:00, 92.3MB/s] 


Archive:  resnet18_100-epochs_stl10.zip
  inflating: checkpoint_0100.pth.tar  
  inflating: config.yml              
  inflating: events.out.tfevents.1610901470.4cb2c837708d.2683858.0  
  inflating: training.log            
checkpoint_0100.pth.tar
config.yml
events.out.tfevents.1610901470.4cb2c837708d.2683858.0
prune-ssl-model.py
resnet18_100-epochs_stl10.zip
simclr.ipynb
training.log


In [15]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import Subset

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [5]:
def get_stl10_data_loaders(download, batch_size=256):
  train_dataset = datasets.STL10('../../data', split='train', download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=10, drop_last=False, shuffle=True)
  
  test_dataset = datasets.STL10('../../data', split='test', download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=False)
  return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.CIFAR10('../../data', train=True, download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=10, drop_last=False, shuffle=True)
  
  test_dataset = datasets.CIFAR10('../../data', train=False, download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=False)
  return train_loader, test_loader

In [17]:
def get_cifar_dataloader(ratio=1.0):
    """
    Get the CIFAR10 dataloader
    """
    # Data loading code for cifar10 
    train_transform = transforms.transforms.Compose([
        transforms.Resize(32),
        transforms.transforms.RandomHorizontalFlip(),
        transforms.transforms.ToTensor(),
        transforms.transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        ),
    ])

    val_transform = transforms.transforms.Compose([
        transforms.Resize(32),
        transforms.transforms.ToTensor(),
        transforms.transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        ),
    ])

    train_dataset = datasets.CIFAR10(
        root="../../data/",
        train=True,
        download=True,
        transform=train_transform,
    )

    # Define the size of the subset
    subset_size = int(len(train_dataset) * ratio)
    print(f"Using the sample size of {subset_size}.")

    # Create a random subset for training
    indices = np.random.permutation(len(train_dataset))
    train_indices = indices[:subset_size]
    train_subset = Subset(train_dataset, train_indices)

    val_dataset = datasets.CIFAR10(
        root="../../data/",
        train=False,
        download=True,
        transform=val_transform,
    )

    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=256,
        shuffle=True,
        num_workers=10,
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=10,
    )

    return train_loader, val_loader

In [18]:
model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
checkpoint = torch.load('../../base_models/resnet18-simclr-cifar10.tar', map_location=device)
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):

  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]



In [19]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']
train_loader, test_loader = get_cifar_dataloader()

Files already downloaded and verified
Using the sample size of 50000.
Files already downloaded and verified


In [20]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [21]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [22]:
epochs = 20
for epoch in range(epochs):
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(train_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(test_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
  
    top1, top5 = accuracy(logits, y_batch, topk=(1,5))
    top1_accuracy += top1[0]
    top5_accuracy += top5[0]
  
  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 Train accuracy: {top5_accuracy.item()}\t")

Epoch 0	Top1 Train accuracy 59.30364990234375	Top1 Test accuracy: 68.212890625	Top5 Train accuracy: 97.109375	
Epoch 1	Top1 Train accuracy 72.28834533691406	Top1 Test accuracy: 71.9921875	Top5 Train accuracy: 97.724609375	
Epoch 2	Top1 Train accuracy 77.15800476074219	Top1 Test accuracy: 74.3359375	Top5 Train accuracy: 98.125	
Epoch 3	Top1 Train accuracy 80.27742004394531	Top1 Test accuracy: 74.39453125	Top5 Train accuracy: 98.251953125	
Epoch 4	Top1 Train accuracy 82.39835357666016	Top1 Test accuracy: 75.322265625	Top5 Train accuracy: 98.271484375	
Epoch 5	Top1 Train accuracy 84.78356170654297	Top1 Test accuracy: 76.26953125	Top5 Train accuracy: 98.349609375	
Epoch 6	Top1 Train accuracy 86.57565307617188	Top1 Test accuracy: 76.044921875	Top5 Train accuracy: 98.30078125	
Epoch 7	Top1 Train accuracy 87.97512817382812	Top1 Test accuracy: 75.99609375	Top5 Train accuracy: 98.3203125	
Epoch 8	Top1 Train accuracy 89.07206726074219	Top1 Test accuracy: 77.255859375	Top5 Train accuracy: 98.4277

In [24]:
# Save the model after fine-tuning
model_path = "../../base_models/resnet18-finetune-cifar10.tar" 
# save the model static dict 
torch.save({
    'state_dict': model.state_dict(),
}, model_path)