In [1]:
import os, sys
# Locations that must be importable: the directory above the current working
# directory (treated as the project root) plus two of its module folders.
project_dir = os.path.join(os.getcwd(), '..')
attention_dir = os.path.join(project_dir, 'modules/AttentionMap')
sparse_dir = os.path.join(project_dir, 'modules/Sparse')

# Register each location at the end of sys.path, skipping any that is already
# listed so that re-running the cell does not pile up duplicates.
sys.path.extend(
    [path for path in (project_dir, attention_dir, sparse_dir) if path not in sys.path]
)

import torch
from torch import nn
from derma.architecture import InvertedResidual
from torchvision.models import MobileNetV2
from torchvision import datasets, transforms

In [2]:
# Random dummy batch (1 sample, 3 channels, 64x64) used only to smoke-test output shapes.
input_test = torch.rand((1, 3, 64, 64))

# Original setting for mobilenet v2 (https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py)
# Column meaning in the torchvision reference implementation:
#   t = expansion ratio, c = output channels, n = number of repeated blocks, s = stride
inverted_residual_setting = [
    # t, c, n, s
    [1, 16, 1, 1],
    [6, 24, 2, 2],
    [6, 32, 3, 2],
    [6, 64, 4, 2],
    [6, 96, 3, 1],
    [6, 160, 3, 2],
    [6, 320, 1, 1],
]

# MobileNetV2 with a 100-way classifier head, built from the project's own
# InvertedResidual block instead of torchvision's default one.
model = MobileNetV2(num_classes=100, inverted_residual_setting=inverted_residual_setting, block=InvertedResidual)

# Shape smoke test. It is run in eval mode and without autograd so that feeding
# random noise does not update any running statistics (e.g. BatchNorm), does not
# apply stochastic layers (e.g. Dropout) and does not build a gradient graph.
# The model's previous train/eval mode is restored afterwards, even on error.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        features_out = model.features(input_test)
        logits_out = model(input_test)
finally:
    model.train(was_training)

print('Features Output shape: {}'.format(features_out.shape))
print('Classifier Output shape: {}'.format(logits_out.shape))

Features Output shape: torch.Size([1, 1280, 2, 2])
Classifier Output shape: torch.Size([1, 100])


# Dataset

In [3]:
# Per-sample preprocessing: ToTensor runs first, so the subsequent Resize to
# 128x128 operates on the tensor rather than on the PIL image.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(size=(128, 128))]
)

# CIFAR-100 training split kept under the relative 'data' folder; download=True
# lets torchvision fetch the archive when it is not present yet.
dataset = datasets.CIFAR100(
    root='data',
    train=True,
    download=True,
    transform=transform,
)

# Mini-batches of 32 samples, reshuffled by the loader.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=32,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data\cifar-100-python.tar.gz


169001984it [00:28, 5882825.73it/s]                                


Extracting data\cifar-100-python.tar.gz to data


# Training

In [4]:
from derma.utils import train
from torch.utils.tensorboard import SummaryWriter

# Training configuration: Adam with a small L2 penalty, cross-entropy loss for
# classification, and a TensorBoard writer logging under log/MobilenetCoordAtt.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
criterion = nn.CrossEntropyLoss()
tb_writer = SummaryWriter('log/MobilenetCoordAtt')
n_epoch = 2

# Pass n_epoch instead of repeating the literal so the setting above is the single
# source of truth.
# NOTE(review): the fifth positional argument is presumably the epoch count, but the
# recorded progress bar shows a total of 10 -- confirm against derma.utils.train.
try:
    train(model, loader, optimizer, criterion, n_epoch, tb_writer)
finally:
    # Push pending events to disk even when training is interrupted
    # (e.g. KeyboardInterrupt); the writer stays usable afterwards.
    tb_writer.flush()

  0%|          | 0/10 [06:57<?, ?epoch/s, tls=4.2275]


KeyboardInterrupt: 