In [None]:
import pickle
import numpy as np
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from importlib import reload
import MobileNet_v3
import functions
import datafunc
import config
# Re-import the project modules so edits to them are picked up when this cell
# is re-run in a live kernel (a plain `import` keeps the first cached version).
# The `from ... import` lines must come AFTER the reloads; otherwise they would
# bind the stale objects from the old module versions.
config = reload(config)
MobileNet_v3 = reload(MobileNet_v3)
functions = reload(functions)
datafunc = reload(datafunc)
from functions import train, accuracy
from datafunc import MyDataLoader, train_test_split
from MobileNet_v3 import get_model as gm3

device = torch.device(config.DEVICE)

# Seed every RNG the pipeline may draw from so that Restart & Run All yields
# the same run. This covers model weight init, torchvision random augmentations
# and loader shuffling. It also covers the train/val index split, presumably via
# numpy/random in datafunc — confirm. torch.manual_seed seeds all devices,
# including CUDA.
import random
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Release cached-but-unused GPU memory left over from a previous run in this kernel.
torch.cuda.empty_cache()

In [None]:
# Instantiate the 'small' MobileNetV3 variant via MobileNet_v3.get_model and
# place it on the selected device.
# NOTE(review): get_model interprets the positional arguments. The leading 100
# presumably is the number of output classes (CIFAR-100), and (1., 0.8, 3) are
# model hyper-parameters — confirm their meaning in MobileNet_v3.
mobilenet_args = (100, 'small', 1., 0.8, 3)
mobilenet = gm3(*mobilenet_args)
mobilenet = mobilenet.to(device)

In [None]:
# dataset settings
batch_size = 256
IM_SIZE = 224  # resize image (an int makes Resize scale the shorter side to this)
TRAIN_FRACTION = .75  # share of the CIFAR-100 training split used for training; the rest is validation
# Per-channel (mean, std) for Normalize -- the commonly used ImageNet statistics.
NORMALIZE = ([0.485, 0.456, 0.406],
             [0.229, 0.224, 0.225])


# Training pipeline: resize, light random augmentation, then tensor + normalize.
train_transformer = transforms.Compose([
    transforms.Resize(IM_SIZE),
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*NORMALIZE)
])


# Evaluation pipeline: same resize/normalization, no augmentation.
test_transformer = transforms.Compose([
    transforms.Resize(IM_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(*NORMALIZE)
])

cifar_train = datasets.CIFAR100('data/',
                                transform=train_transformer,
                                download=True)

# Save the class names so they can be loaded later without the dataset.
with open('classes_name.pkl', 'wb') as f:
    pickle.dump(cifar_train.classes, f)

# Deliberately the *training* split again (train=True). The validation subset is
# carved out of it by index below. `cifar_val` serves those same indices with
# the non-augmenting transform.
cifar_val = datasets.CIFAR100('data/',
                              transform=test_transformer,
                              train=True)

# The targets are passed along, presumably to stratify the split by class
# (see datafunc.train_test_split).
train_indices, val_indices = \
    train_test_split(np.arange(len(cifar_train)), TRAIN_FRACTION, cifar_train.targets)

train_loader = MyDataLoader(cifar_train, batch_size, train_indices, shuffle=True)
# Validation only measures a metric, so shuffling it buys nothing and makes
# evaluation nondeterministic. Iterate it in a fixed order instead.
val_loader = MyDataLoader(cifar_val, batch_size, val_indices, shuffle=False)

In [None]:
LR = 1e-2  # NOTE(review): 10x PyTorch's Adam default (1e-3) — confirm this is intended
OPTIM_MOMENTUM = 0.9  # Adam's beta1 (decay rate of the running gradient mean)
WEIGHT_DECAY = 1e-5  # l2 weight decay

# Adam's `weight_decay` adds an L2 penalty to the gradient (coupled decay, unlike
# AdamW). beta2=0.999 is PyTorch's default, so with OPTIM_MOMENTUM=0.9 these
# betas equal Adam's defaults.
optimizer = torch.optim.Adam(mobilenet.parameters(),
                             lr=LR,
                             betas=(OPTIM_MOMENTUM, 0.999),
                             weight_decay=WEIGHT_DECAY)
loss_func = nn.CrossEntropyLoss()

# scheduler parameters
factor = 0.5      # multiply the LR by this on a plateau (i.e. halve it)
patience = 4      # scheduler steps without improvement before reducing the LR
threshold = 0.001  # minimum (relative, by default) improvement that counts

# mode='max' means higher metric values are better. NOTE(review): this assumes
# train() calls scheduler.step(<validation accuracy>); confirm in functions.py.
# NOTE: `verbose` is deprecated in recent PyTorch releases and emits a warning.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=factor, patience=patience,
    verbose=True, threshold=threshold
)


In [None]:
EPOCHS = 30
# Run the training loop from functions.py. Besides the model, data and
# optimizer, it receives the metric function, the validation loader and the
# plateau scheduler. It presumably returns a per-epoch history and the best
# weights observed — confirm in functions.train.
train_history, best_parameters = train(
    mobilenet,
    train_loader,
    loss_func,
    optimizer,
    EPOCHS,
    accuracy,
    val_loader,
    scheduler,
)

In [None]:
# Persist the best weights returned by train() under 'model_state_dict'.
# NOTE(review): this assumes best_parameters is a state_dict, restorable with
# model.load_state_dict(torch.load('model.torch')['model_state_dict']).
checkpoint = {'model_state_dict': best_parameters}
torch.save(checkpoint, 'model.torch')

