In [None]:
import gc
import io
import math
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from PIL import Image
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import models
from torchvision.io import read_image
from torchvision.transforms import ToTensor, ToPILImage

from src.model import *
from src.loss import *

# Demo Using a VGG16 Model and CIFAR-10

In [None]:
# Step 1: fine-tune VGG16 on CIFAR-10.
#
# Every image is converted to a tensor and normalized per channel.
# NOTE(review): these std values are the widely copied ones; the per-channel
# std of the CIFAR-10 training set is reportedly closer to (0.247, 0.243, 0.261).
# Kept unchanged here to preserve the original preprocessing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Mini-batch size used when the DataLoaders are built in the next cell.
batch_size = 100

# Both splits share the same root directory, download flag and preprocessing.
dataset_kwargs = dict(root='./data', download=True, transform=transform)
trainset = torchvision.datasets.CIFAR10(train=True, **dataset_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **dataset_kwargs)

# Human-readable label names, indexed by CIFAR-10 class id.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Wrap the datasets in mini-batch loaders. The training data is reshuffled every
# epoch; the test data keeps a fixed order so any evaluation over it is reproducible.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

# Number of mini-batches per training epoch (shown in the training progress messages).
n_total_step = len(train_loader)
print(n_total_step)

In [None]:
# Load an ImageNet-pretrained VGG16 and swap its last classifier layer for a
# freshly initialized linear layer with one output per CIFAR-10 class (10).
vgg16 = models.vgg16(weights='IMAGENET1K_V1')
old_head = vgg16.classifier[6]
input_lastLayer = old_head.in_features
vgg16.classifier[6] = nn.Linear(in_features=input_lastLayer, out_features=10)

In [None]:
# Fine-tune the modified VGG16 on CIFAR-10 with SGD + momentum.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_epochs = 5
learning_rate = 0.001
# NOTE: the mini-batch size is fixed by `train_loader`, which was built with
# `batch_size` from the dataset cell. Re-assigning `batch_size` here would have
# no effect on training, so it is intentionally not done.
log_every = 250  # print progress every this many mini-batches

model = vgg16.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

# Ensure train-mode behavior (e.g. dropout) is active, even if the cell is re-run
# after the model was switched to eval mode elsewhere.
model.train()
for epoch in range(num_epochs):
    for i, (imgs, labels) in enumerate(train_loader):
        imgs = imgs.to(device)
        labels = labels.to(device)

        labels_hat = model(imgs)
        loss_value = criterion(labels_hat, labels)
        loss_value.backward()
        optimizer.step()
        # Clearing after the step (not before backward) also leaves the gradients
        # cleared when the loop ends, so they don't linger into the DECORE cells.
        optimizer.zero_grad()

        if (i + 1) % log_every == 0:
            # Accuracy of the current mini-batch only. It is computed just when
            # logging, so the .item() host/device sync is not paid on every step.
            n_corrects = (labels_hat.argmax(dim=1) == labels).sum().item()
            print(f'epoch {epoch+1}/{num_epochs}, step: {i+1}/{n_total_step}: loss = {loss_value.item():.5f}, acc = {100*(n_corrects/labels.size(0)):.2f}%')
    print()

# After fine-tuning the target model, DECORE can be used to compress it

In [None]:
# Step 2: wrap the fine-tuned model with DECOR and build an optimizer over the
# agent parameters only.

# Release unreferenced objects and cached GPU memory before building DECOR.
gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
    # mem_get_info() raises on machines without CUDA, so only report it there.
    print(torch.cuda.mem_get_info())

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_epochs = 1
lr = 0.01
net = DECOR(model)
net = net.to(device)

# Collect trainable parameters whose names end in ".S_0", ".S_1", ... .
# The counter advances only on a match, so this assumes named_parameters()
# yields them in ascending index order. TODO: confirm against DECOR in src.model.
# (`agent_idx` replaces the original `id`, which shadowed the builtin.)
agent_idx = 0
param = []
for n, p in net.named_parameters():
    if n.endswith(f".S_{agent_idx}") and p.requires_grad:
        param.append(p)
        agent_idx += 1

optimizer = Adam(param, lr=lr)
# NOTE(review): -200 is passed straight to CustomLoss (src.loss). Its meaning
# (presumably the penalty applied to wrong predictions) is not visible here; confirm.
criterion = CustomLoss(-200)

In [None]:
# Step 2 (continued): train the DECOR agents. Only the parameters handed to
# `optimizer` (the collected agent parameters) are updated by optimizer.step().

# Release unreferenced objects and cached GPU memory before training.
gc.collect()
torch.cuda.empty_cache()
if torch.cuda.is_available():
    # mem_get_info() raises on machines without CUDA, so only report it there.
    print(torch.cuda.mem_get_info())

log_every = 250  # print progress every this many mini-batches

for epoch in range(num_epochs):
    for i, (imgs, labels) in enumerate(train_loader):
        imgs = imgs.to(device)
        labels = labels.to(device)
        # Forward through the wrapped target model. CustomLoss also receives the
        # agents list; what it reads from the agents is defined in src.loss.
        labels_hat = net.target_model(imgs)
        loss_value = criterion(net.agents_list, labels_hat, labels)

        # NOTE(review): zero_grad() only clears the optimizer's (agent) params.
        # If the target-model weights still require grad, their .grad accumulates
        # across steps. Consider freezing them if DECOR does not already.
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()

        if (i + 1) % log_every == 0:
            print(f'epoch {epoch+1}/{num_epochs}, step: {i+1}/{n_total_step}: loss = {loss_value.item():.5f}')
    print()