In [None]:
!git clone https://github.com/0601p/Model-Compression.git

In [None]:
import torch
import torch.nn.utils.prune as prune
import torchvision
import random
import sys
sys.path.append("/content/Model-Compression")

from model import *
from prune import *

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Seed Python's and PyTorch's RNGs (plus the CUDA RNG on GPU) with a fixed
# value so random operations, e.g. DataLoader shuffling, repeat between runs.
random.seed(777)
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed(777)

In [None]:
learning_rate = 0.01   # Adam step size. NOTE(review): large for fine-tuning a pretrained net; confirm it is intended
batch_size = 512       # samples per batch for both the train and test DataLoaders
epoch = 20             # number of epochs `train()` runs

In [None]:
# CIFAR-100 train and test splits stored under ./CIFAR100 (fetched if missing).
# Images are only converted to tensors; no normalisation is applied.
# NOTE(review): confirm this matches the preprocessing model.pth was trained with.
root = "./CIFAR100"
transform = torchvision.transforms.ToTensor()
data_train, data_test = (
    torchvision.datasets.CIFAR100(root, train=split_is_train, transform=transform, download=True)
    for split_is_train in (True, False)
)

In [None]:
# The training loader shuffles and drops the ragged final batch.
# The test loader must see every sample: with drop_last=True and
# batch_size=512, 272 of the 10,000 CIFAR-100 test images were never evaluated.
train_loader = torch.utils.data.DataLoader(dataset=data_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=data_test, batch_size=batch_size, shuffle=False, drop_last=False)

In [None]:
# Bottleneck blocks in a [3, 4, 6, 3] layout, i.e. the usual ResNet-50 shape
# (assuming `ResNet` follows torchvision's convention), with a 100-way head.
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=100).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load("/content/Model-Compression/model.pth", map_location=device)
model.load_state_dict(checkpoint)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Try the L1 strategy on `model.conv1` alone before sweeping the whole model.
stt = L1Strategy()
module = model.conv1
print(stt.apply(amount=0.1, weights=module.weight))

In [None]:
# Apply the strategy to every Conv2d / Linear weight in the model.
# Keep each result, keyed by module name, instead of discarding it.
prune_idx = {}
for name, layer in model.named_modules():
    if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
        # NOTE(review): named `idx` originally, presumably the indices chosen
        # for pruning; confirm against L1Strategy.apply.
        idx = stt.apply(weights=layer.weight, amount=0.1)
        prune_idx[name] = idx

# Scratch note kept from an earlier experiment (not executed). It overwrites a
# layer's weight with a hand-written 2x1x2x2 tensor, presumably to check the
# strategy's output on a small known input:
# layer.load_state_dict({'weight': torch.tensor([[[[0.4738, -0.2197],
#                       [-0.3436, -0.0754]]],
#                     [[[0.1662, 0.4098],
#                       [-0.4306, -0.4828]]]])}, strict=False)

In [None]:
def train_one_epoch(print_result = False):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook globals ``model``, ``criterion``, ``optimizer``,
    ``train_loader`` and ``device``.

    Args:
        print_result: if True, print the per-sample mean loss and the accuracy
            over every sample seen during the epoch.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    length = 0

    for X, Y in train_loader:
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()

        pred = model(X)
        loss = criterion(pred, Y)
        n = X.size(0)
        # CrossEntropyLoss defaults to reduction='mean', so `loss` is a batch
        # average. Weight it by batch size so loss_sum / length is a per-sample mean.
        loss_sum += loss.item() * n
        correct += (pred.argmax(dim=1) == Y).sum().item()
        length += n

        loss.backward()
        optimizer.step()

    if print_result:
        if length == 0:
            # An empty loader (e.g. dataset smaller than batch_size with drop_last) would divide by zero.
            print("no training samples seen")
            return
        print("loss :", loss_sum / length)
        print("accuracy:", correct / length)

In [None]:
def eval(loader=None):
    """Print loss and accuracy of ``model`` on held-out data, without gradients.

    Note: this shadows Python's builtin ``eval`` in the notebook. The name is
    kept because ``train()`` and the final cell call it.

    Args:
        loader: DataLoader to evaluate on. Defaults to ``test_loader``.
            Previously this always iterated ``train_loader``, so the reported
            "eval" numbers were training-set numbers. Pass ``train_loader``
            explicitly to get that behaviour.
    """
    if loader is None:
        loader = test_loader

    with torch.no_grad():
        model.eval()
        loss_sum = 0.0
        correct = 0
        length = 0

        for X, Y in loader:
            X = X.to(device)
            Y = Y.to(device)

            pred = model(X)
            loss = criterion(pred, Y)
            n = X.size(0)
            # criterion returns a batch mean (CrossEntropyLoss default), so
            # weight by batch size; this also handles a smaller final batch.
            loss_sum += loss.item() * n
            correct += (pred.argmax(dim=1) == Y).sum().item()
            length += n

        if length == 0:
            print("no evaluation samples seen")
            return
        print("loss :", loss_sum / length)
        print("accuracy:", correct / length)

In [None]:
def train():
    """Run ``epoch`` training epochs, evaluating after each one.

    Every epoch prints a 1-based ``EPOCH[k]`` header, then the training
    metrics from ``train_one_epoch`` and the metrics printed by ``eval``.
    """
    for epoch_no in range(1, epoch + 1):
        print(f"EPOCH[{epoch_no}]")

        print("==== train ====")
        train_one_epoch(print_result=True)

        print("==== eval ====")
        eval()

In [None]:
# Print loss/accuracy via `eval()` for the checkpoint loaded earlier.
# `train()` has not been called before this cell, so these are the
# pretrained model's metrics without any fine-tuning.
eval()