In [None]:
import os, sys
# Locations that must be importable from this notebook: the project root
# (two levels above the current working directory) and two module folders
# underneath it.
project_dir = os.path.join(os.getcwd(), '../..')
attention_dir = os.path.join(project_dir, 'modules/AttentionMap')
sparse_dir = os.path.join(project_dir, 'modules/Sparse')

# Register each location once; a path that is already present is left alone,
# so re-running this cell does not pile up duplicate entries.
for _import_root in (project_dir, attention_dir, sparse_dir):
    if _import_root in sys.path:
        continue
    sys.path.append(_import_root)

import numpy as np
import torch, config
from torch import nn
import os

# CIFAR Dataset

In [None]:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor, Normalize, Compose, Resize

# Settings a reader is most likely to tune.
VAL_FRACTION = 0.1  # share of the CIFAR-100 training split held out for validation
SPLIT_SEED = 0      # fixes which samples land in train vs. validation

# Per-image preprocessing: resize to 128x128, convert to a tensor, then
# normalize each channel with the given mean/std.
# NOTE(review): these mean/std values look like the commonly used ImageNet
# statistics rather than CIFAR-100 ones -- confirm this is intended.
transform = Compose([
    Resize((128, 128)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training split of CIFAR-100, stored under ./data (download=True is set, so
# the files are fetched when not already present).
dataset = CIFAR100('data', train=True, download=True, transform=transform)

val_size = int(VAL_FRACTION * len(dataset))
train_size = len(dataset) - val_size

# Use a dedicated, seeded generator for the split. Without it the partition
# changes on every kernel restart, so validation metrics from different runs
# would be computed on different samples.
train_set, val_set = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=512, shuffle=False)

# Model

Applying a custom Inverted Residual block with a coordinate attention mechanism.

In [None]:
from torchvision.models import MobileNetV2
from derma.architecture import MobileNetDecoder, InvertedResidual

# Stage table handed to both the encoder and the decoder, one row per stage.
# Columns use the MobileNetV2 "t, c, n, s" layout (expansion factor, output
# channels, number of repeated blocks, stride of the first block).
inverted_residual_setting = [
    # t,   c, n, s
    [1,  16, 1, 1],
    [6,  24, 2, 2],
    [6,  32, 3, 2],
    [6,  64, 4, 2],
    [6,  96, 3, 1],
    [6, 160, 3, 2],
    [6, 320, 1, 1],
]

# Encoder: a MobileNetV2 assembled from the project's InvertedResidual block;
# only its `.features` sub-module is kept.
encoder = MobileNetV2(
    block=InvertedResidual,
    inverted_residual_setting=inverted_residual_setting,
).features

# Decoder: configured from the very same stage table.
decoder = MobileNetDecoder(inverted_residual_setting)

# Full network: the encoder output is fed straight into the decoder.
model = nn.Sequential(encoder, decoder)

# Training

In [None]:
from torch.utils.tensorboard import SummaryWriter
from derma.utils import train

# Training setup: number of epochs, mean-squared-error loss, and Adam over
# all parameters of the encoder + decoder model.
n_epoch = 10
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-6,
)

# TensorBoard event files are written beneath the configured results folder.
tb_writer = SummaryWriter(
    log_dir=os.path.join(config.RESULT_DIR, 'log/reconstruction/cifar'),
)

# Hand everything to the project's training helper; the loaders are passed as
# a [train, validation] pair.
# NOTE(review): `reconstruction=True` presumably makes the helper use the
# input images as targets -- confirm in derma.utils.train.
train(
    model,
    [train_loader, val_loader],
    optimizer,
    criterion,
    n_epoch,
    tb_writer,
    reconstruction=True,
)

# Save encoder weights

In [None]:
# Destination folder for the encoder weights, beneath the configured
# results directory.
save_dir = os.path.join(config.RESULT_DIR, 'weights/encoder/cifar')

# exist_ok=True creates the folder (and any missing parents) when absent and
# does nothing when it is already there, avoiding the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(save_dir, exist_ok=True)

# Only the encoder's state dict is written; the decoder is not saved here.
torch.save(encoder.state_dict(), os.path.join(save_dir, 'encoder.pth'))