In [1]:
import sys
import copy

import torch
from torch import nn, optim
from torch.utils import data
from torchvision import datasets, transforms

In [2]:
# sys.path.append("/home/matthias/Documents/EmbeddedAI/deep-microcompression/")
sys.path.append("../../")

from development import (
    Sequential,
    BatchNorm2d,
    Conv2d,
    Linear,
    ReLU,
    MaxPool2d,
    Flatten
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prefer the GPU when PyTorch can see one; later cells pass this as `device=`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# The checkpoint filename embeds DEVICE, so CPU and CUDA runs read/write different files.
mobilenetv1_file = "mobilenetv1_state_dict_{}.pth".format(DEVICE)

# Seed the torch RNGs so weight initialisation and DataLoader shuffling repeat run to run.
LUCKY_NUMBER = 25
for seed_fn in (torch.manual_seed, torch.random.manual_seed, torch.cuda.manual_seed):
    seed_fn(LUCKY_NUMBER)


In [4]:
# Display which device ("cuda" or "cpu") the rest of the notebook runs on.
DEVICE

'cuda'

In [5]:
# Images are only converted to tensors: no augmentation and no normalization.
data_transform = transforms.Compose([transforms.ToTensor()])


def _cifar_split_loaders(dataset_cls, transform, batch_size=32, root="./datasets"):
    """Download the train/test splits of `dataset_cls` under `root` and wrap them.

    Returns (train_dataset, test_dataset, train_loader, test_loader). Only the
    training loader shuffles; both use `batch_size`.
    """
    train_split = dataset_cls(root, train=True, download=True, transform=transform)
    test_split = dataset_cls(root, train=False, download=True, transform=transform)
    return (
        train_split,
        test_split,
        data.DataLoader(train_split, batch_size=batch_size, shuffle=True),
        data.DataLoader(test_split, batch_size=batch_size),
    )


(cifar10_train_dataset, cifar10_test_dataset,
 cifar10_train_loader, cifar10_test_loader) = _cifar_split_loaders(datasets.CIFAR10, data_transform)

(cifar100_train_dataset, cifar100_test_dataset,
 cifar100_train_loader, cifar100_test_loader) = _cifar_split_loaders(datasets.CIFAR100, data_transform)


In [6]:
def DeepWiseSeperableConv2d(
        in_channel:int,
        out_channels:int,
        kernel_size:int,
        stride:int,
        padding:int,
):
    """Build the six layers of one depthwise-separable convolution block.

    The block is a depthwise stage followed by a pointwise stage, each made of
    Conv2d -> BatchNorm2d -> ReLU:

    * depthwise: ``groups=in_channel`` gives every input channel its own
      ``kernel_size`` filter; ``stride`` and ``padding`` are applied here and
      the channel count stays ``in_channel``.
    * pointwise: a 1x1 convolution (stride 1, no padding) mapping
      ``in_channel`` -> ``out_channels``.

    Returns:
        A tuple of layers, meant to be unpacked into ``Sequential(...)``.
    """
    depthwise_stage = (
        Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=kernel_size,
               stride=stride, padding=padding, groups=in_channel),
        BatchNorm2d(num_features=in_channel),
        ReLU(),
    )
    pointwise_stage = (
        Conv2d(in_channels=in_channel, out_channels=out_channels, kernel_size=1,
               stride=1, padding=0),
        BatchNorm2d(num_features=out_channels),
        ReLU(),
    )
    return depthwise_stage + pointwise_stage

In [7]:
# Small MobileNetV1-style network with 10 outputs (trained on CIFAR-10 below):
# a stem convolution, three depthwise-separable blocks and a linear classifier.
# Assuming 32x32 inputs, the unpadded 3x3 convolutions shrink the feature map
#   32 -> 15 (stem, stride 2) -> 13 -> 11 -> 5 (last block, stride 2),
# which is where the classifier's 32*5*5 input features come from.
mobilenetv1_model = Sequential(
    # stem
    Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=0),
    BatchNorm2d(num_features=32),
    ReLU(),
    # depthwise-separable blocks
    *DeepWiseSeperableConv2d(in_channel=32, out_channels=32, kernel_size=3, stride=1, padding=0),
    *DeepWiseSeperableConv2d(in_channel=32, out_channels=32, kernel_size=3, stride=1, padding=0),
    *DeepWiseSeperableConv2d(in_channel=32, out_channels=32, kernel_size=3, stride=2, padding=0),
    # classifier
    Flatten(),
    Linear(in_features=32 * 5 * 5, out_features=10),
).to(DEVICE)


def top1_acc_fun(y_pred, y_true):
    """Return the NUMBER (not the fraction) of samples whose arg-max prediction matches the label.

    Assumes y_pred is (batch, num_classes) logits and y_true holds class indices — TODO confirm.
    """
    predicted_class = y_pred.argmax(dim=1)
    return (predicted_class == y_true).sum().item()

In [None]:
# Reuse the cached checkpoint for this device when it loads cleanly; otherwise
# train from scratch (about 31 s/epoch in the recorded GPU run) and cache the result.

# Snapshot the freshly initialised weights. load_state_dict copies every tensor
# that does match before raising on missing/unexpected keys or shape mismatches,
# so a failed load must be rolled back to avoid training from a half-loaded model.
_fresh_init_state = copy.deepcopy(mobilenetv1_model.state_dict())

try:
    mobilenetv1_model.load_state_dict(
        torch.load(mobilenetv1_file, map_location=DEVICE, weights_only=True)
    )
except (FileNotFoundError, RuntimeError) as load_error:
    # FileNotFoundError: no checkpoint cached yet.
    # RuntimeError: unreadable checkpoint, or one that does not match the current architecture.
    print(f"No usable checkpoint at {mobilenetv1_file!r} ({type(load_error).__name__}); training from scratch.")
    mobilenetv1_model.load_state_dict(_fresh_init_state)

    criterion_fun = nn.CrossEntropyLoss()
    optimizion_fun = optim.Adam(mobilenetv1_model.parameters(), lr=1.e-3)
    # Cut the LR by 10x once the monitored value stops improving for more than
    # 2 scheduler steps. mode="min" assumes fit() steps it with a loss — confirm.
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizion_fun, mode="min", patience=2, factor=.1)

    mobilenetv1_model.fit(
        cifar10_train_loader, 150,  # 150 epochs
        criterion_fun, optimizion_fun, lr_scheduler,
        validation_dataloader=cifar10_test_loader,
        metrics={"top1_acc" : top1_acc_fun},
        device=DEVICE,
    )
    torch.save(mobilenetv1_model.state_dict(), mobilenetv1_file)

mobilenetv1_model.evaluate(cifar10_test_loader, top1_acc_fun, device=DEVICE)

  0%|          | 0/150 [00:00<?, ?it/s]

  1%|          | 1/150 [00:31<1:18:49, 31.74s/it]

epoch    0 | train loss 0.0465 | validation loss 0.0404 | train acc 0.4604 | validation acc 0.5335


  1%|▏         | 2/150 [01:03<1:18:03, 31.65s/it]

epoch    1 | train loss 0.0377 | validation loss 0.0370 | train acc 0.5728 | validation acc 0.5791


  2%|▏         | 3/150 [01:34<1:16:27, 31.21s/it]

epoch    2 | train loss 0.0339 | validation loss 0.0335 | train acc 0.6175 | validation acc 0.6172


  3%|▎         | 4/150 [02:04<1:15:20, 30.96s/it]

epoch    3 | train loss 0.0318 | validation loss 0.0319 | train acc 0.6417 | validation acc 0.6470


  3%|▎         | 5/150 [02:37<1:16:13, 31.54s/it]

epoch    4 | train loss 0.0300 | validation loss 0.0330 | train acc 0.6626 | validation acc 0.6293


  4%|▍         | 6/150 [03:10<1:17:17, 32.20s/it]

epoch    5 | train loss 0.0289 | validation loss 0.0300 | train acc 0.6766 | validation acc 0.6661


  5%|▍         | 7/150 [03:41<1:15:39, 31.75s/it]

epoch    6 | train loss 0.0280 | validation loss 0.0300 | train acc 0.6866 | validation acc 0.6602


  5%|▌         | 8/150 [04:11<1:14:12, 31.36s/it]

epoch    7 | train loss 0.0273 | validation loss 0.0292 | train acc 0.6926 | validation acc 0.6717


## Pruning

In [None]:
# Channel-prune the trained model with the sparsity below, report the pruned
# accuracy before any fine-tuning, then fine-tune the pruned model for 15 epochs.
sparsity = .5

# The model is moved to the CPU before pruning (presumably prune_channel expects
# CPU tensors — confirm).
# NOTE(review): this moves mobilenetv1_model off DEVICE; move it back before
# evaluating the unpruned model in a later cell.
mobilenetv1_model.cpu()
mobilenetv1_mcu_model = mobilenetv1_model.prune_channel(sparsity)
mobilenetv1_mcu_model.to(DEVICE)

# Use the same top-1 metric as the unpruned model so the numbers are comparable.
pruned_acc = mobilenetv1_mcu_model.evaluate(cifar10_test_loader, top1_acc_fun, device=DEVICE)
print(f"Pruned with {sparsity}, acc = {pruned_acc}")

criterion_fun = nn.CrossEntropyLoss()
optimizion_fun = optim.Adam(mobilenetv1_mcu_model.parameters(), lr=1.e-3)
# Same schedule as the original training run (factor=.1 is also PyTorch's default).
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizion_fun, mode="min", patience=2, factor=.1)

mobilenetv1_mcu_model.fit(
    cifar10_train_loader, 15,  # 15 fine-tuning epochs
    criterion_fun, optimizion_fun, lr_scheduler,
    validation_dataloader=cifar10_test_loader,
    metrics={"top1_acc" : top1_acc_fun},
    device=DEVICE,
)