# In [5]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import utils
from models import MobileNet

# set model parameters: the tuple is unpacked positionally into the MobileNet
# constructor as (model name, number of output classes, flag).
# NOTE(review): the meaning of the trailing `True` is defined in models.py
# (presumably "use pretrained weights") — confirm there.
model_parameters = ('mobilenet', utils.num_classes, True)

# load model and move it to the device configured in utils
model = MobileNet(*model_parameters).to(utils.device)

# training parameters
LEARNING_RATE = 0.005
EPOCHS = 3
BATCH_SIZE = 128
MOMENTUM = 0.9
GAMMA = 0.2
STEP_SIZE = 1
# Single switch for validation. It decides both which DataLoaders are built and
# the `validation` flag passed to utils.train_model, so the two cannot drift apart.
VALIDATE = False

# All dataset splits available on disk, and the subset that gets a DataLoader.
PHASES = ['training', 'validation']
LOADER_PHASES = PHASES if VALIDATE else ['training']

# Create pytorch datasets for every split. Both are kept in `datasets`, so later
# cells can still use the validation split, e.g. for evaluation.
datasets = {
    phase: utils.HernitiaDataset(
        # os.path.join avoids a doubled or missing separator around dfs_path
        os.path.join(utils.dfs_path, phase + '_no_temp.pkl'),
        is_stage_feature=True,
        transform=utils.data_transforms[phase],
    )
    for phase in PHASES
}

# Instantiate data loaders only for the phases train_model will iterate.
# Only the training split is shuffled; shuffling a held-out split buys nothing
# for aggregate metrics (assuming train_model only aggregates over it — verify).
dataloaders = {
    phase: DataLoader(
        dataset=datasets[phase],
        batch_size=BATCH_SIZE,
        shuffle=(phase == 'training'),
    )
    for phase in LOADER_PHASES
}

# criterion is cross entropy loss
criterion = nn.CrossEntropyLoss()

# observe that all parameters are being optimized
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# StepLR multiplies the LR by GAMMA every STEP_SIZE calls to scheduler.step().
# NOTE(review): this is "every STEP_SIZE epochs" only if train_model steps the
# scheduler once per epoch — confirm in utils.train_model.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

utils.train_model(model=model,
                  model_name=model.model_name,  # name of the saved weights file within /weights
                  dataloaders=dataloaders,
                  criterion=criterion,
                  optimizer=optimizer,
                  scheduler=exp_lr_scheduler,
                  num_epochs=EPOCHS,
                  validation=VALIDATE)