# import

In [1]:
import sys
import os
from tqdm import tqdm

import torch
from torch import nn, optim
from torch.optim import lr_scheduler

import numpy as np

from tensorboardX import SummaryWriter

sys.path.append('../src/')
import utils
import dataloader
from train_utils import LossMetric, LossAccumulator
from sam import SAM
from model.AutoEncoder import AutoEncoder

# set params

In [2]:
# Read the project configuration from ../config.json via the local `utils`
# helper. The result is indexed with string keys by the cells that follow.
config = utils.readConfig('../config.json')

In [3]:
# Hyper-parameters for the pre-training stage.
train_config = config['train_params']['pretrain']
(batch_size, num_epochs, lr,
 val_portion, lr_factor, patience) = (
    train_config[key]
    for key in ('batch_size', 'num_epochs', 'lr',
                'val_portion', 'lr_factor', 'patience')
)

# Architecture settings handed to the AutoEncoder constructor.
model_config = config['model_params']
encoder_layers, decoder_layers, base_unit_num, emb_dim = (
    model_config[key]
    for key in ('encoder_layers', 'decoder_layers',
                'base_unit_num', 'emb_dim')
)

# Output root for this run; the 'log' and 'weights' names are passed to
# utils.makeDirs and are later used as sub-directories of output_dir.
output_dir = '../data/pretrain'
utils.makeDirs(output_dir, ['log', 'weights'])

# load dataset

In [4]:
# Build the MNIST data loaders. `dataloaders` is indexed by split name; the
# training cells use the 'train' and 'validation' entries.
# NOTE(review): presumably `val_portion` is the share of training data held
# out for validation — confirm in dataloader.load_mnist.
dataloaders, classes = dataloader.load_mnist('../data', batch_size, val_portion)

# define model

In [5]:
# Instantiate the autoencoder from the architecture settings read out of
# config['model_params']. The training cell saves `model.encoder` and
# `model.decoder` separately, so both attributes must exist on the instance.
model = AutoEncoder(encoder_layers, decoder_layers,
                    base_unit_num, emb_dim)

# define optimizer

In [6]:
# SAM is given the optimizer *class* (optim.SGD), not an instance, together
# with the model parameters and the learning rate.
# NOTE(review): SAM is presumably Sharpness-Aware Minimization and presumably
# forwards `lr` to the base optimizer it constructs — confirm in ../src/sam.py.
base_optimizer = optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, lr=lr)

# train

## prepare

In [7]:
# define loss
# LossMetric wraps the criterion; the training cell calls it as
# loss_metric(imgs, rec_imgs).
loss_metric = LossMetric(nn.MSELoss())

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# clear gpu cache memory
# Branch on the selected device instead of wrapping the call in a bare
# `except:`, which would also swallow KeyboardInterrupt and real CUDA errors.
if device.type == 'cuda':
    torch.cuda.empty_cache()
else:
    print("no GPUs")

model.to(device)

# set learning rate scheduler
# mode='min': the monitored quantity is a loss, so lower is better. The
# scheduler only acts when scheduler.step(<metric>) is called once per epoch.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                           factor=lr_factor, verbose=True,
                                           patience=patience)

# loss logger
# One accumulator per split; each is cleared at the start of an epoch, fed
# per-batch losses, and read through its `.loss` attribute.
train_loss_accumulator = LossAccumulator(dataloaders['train'])
val_loss_accumulator = LossAccumulator(dataloaders['validation'])

## train

In [None]:
# open tensorboard
writer = SummaryWriter(f'{output_dir}/log')
weights_dir = f'{output_dir}/weights'

# Start from +inf so the first epoch always produces a checkpoint and no
# assumption is made about the scale of the loss.
best_loss = float('inf')

# try/finally guarantees the event file is flushed and closed even when the
# run is interrupted from the notebook or an epoch raises.
try:
    for epoch in tqdm(range(1, num_epochs + 1)):
        # ---- Train ----

        # set model to TRAINING mode (for BatchNorm, Dropout etc...)
        model.train()

        # clear loss of previous epoch
        train_loss_accumulator.clear()

        for imgs, labels in dataloaders['train']:
            imgs = imgs.float().to(device)  # move inputs to the target device

            # 1st step: forward/backward at the current weights, then
            # optimizer.first_step (presumably SAM's weight perturbation —
            # see ../src/sam.py).
            rec_imgs = model(imgs)
            loss = loss_metric(imgs, rec_imgs)
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # Only the first-pass loss is logged for the training curve.
            train_loss_accumulator(loss)

            # 2nd step: forward/backward again after first_step, then
            # optimizer.second_step performs the actual parameter update.
            rec_imgs = model(imgs)
            loss = loss_metric(imgs, rec_imgs)
            loss.backward()
            optimizer.second_step(zero_grad=True)

        # ---- Eval ----

        # set model to EVALUATION mode (for BatchNorm, Dropout etc...)
        model.eval()

        # clear loss of previous epoch
        val_loss_accumulator.clear()

        # No gradients are needed for validation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for imgs, labels in dataloaders['validation']:
                imgs = imgs.float().to(device)  # move inputs to the target device
                rec_imgs = model(imgs)  # forward
                loss = loss_metric(imgs, rec_imgs)  # compute loss
                val_loss_accumulator(loss)  # logging loss

        val_loss = val_loss_accumulator.loss

        # add values to tensorboard
        writer.add_scalar("pretrain/loss", train_loss_accumulator.loss, epoch)
        writer.add_scalar("pretrain/val_loss", val_loss, epoch)

        # Feed the monitored metric to ReduceLROnPlateau; without this call
        # the scheduler never lowers the learning rate.
        scheduler.step(val_loss)

        # save best model
        if best_loss > val_loss:
            torch.save(model.encoder.state_dict(),
                       os.path.join(weights_dir, 'Encoder.pth'))
            torch.save(model.decoder.state_dict(),
                       os.path.join(weights_dir, 'Decoder.pth'))
            # Track the new best so later, worse epochs do not overwrite it.
            best_loss = val_loss
finally:
    # close tensorboard
    writer.close()

 23%|██▎       | 58/256 [03:50<13:15,  4.02s/it]