## import modules

In [4]:
import argparse
import os
import time
from datetime import datetime

import torch
import torch.optim as optim
from tqdm import tqdm

from load_data import nyu2_dataloaders
from loss import compute_loss
from model.model import get_model
from utils import load_param, save_param

## Define `check_loss_on_set`
Purpose: compute the average loss on the validation set or the test set.

In [5]:
def check_loss_on_set(dataloader, model, device, set_name='val'):
    """Evaluate the model on a held-out set and report the mean batch loss.

    Args:
        dataloader: iterable yielding ``(input, target)`` batches.
        model: network to evaluate. It is switched to eval mode here and is
            NOT switched back, so callers must call ``model.train()`` before
            resuming training.
        device: device the batches are moved to before the forward pass.
        set_name: label shown in the report line. The default keeps the
            historical ``[val]`` output.

    Returns:
        The loss averaged over the batches as a Python float, or ``nan``
        when the dataloader yields no batches.

    Note:
        The loss weights are read from the module-level ``loss_params`` dict,
        which must be defined before this function is called.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x_val, y_val in dataloader:
            x_val = x_val.to(device=device)
            y_val = y_val.to(device=device)
            y_pred = model(x_val)

            _loss = compute_loss(pred=y_pred,
                                 truth=y_val,
                                 device=device,
                                 _alpha=loss_params['_alpha'],
                                 _lambda=loss_params['_lambda'],
                                 _mu=loss_params['_mu'])
            # compute_loss is expected to return a scalar; keeping a plain
            # float avoids holding a tensor alive across the whole loop.
            total_loss += float(_loss)
            # Count the batches actually seen instead of relying on len().
            num_batches += 1

    if num_batches == 0:
        # Previously this case crashed with ZeroDivisionError.
        print("Test on [%s]: no batches to evaluate" % set_name)
        return float('nan')

    avg_loss = total_loss / num_batches
    print("Test on [%s]: loss avg: %.4f" % (set_name, avg_loss))
    return avg_loss

## hyperparams

In [6]:
# Optimisation settings read by main():
#   epochs     - number of passes over the training set
#   lr         - learning rate given to Adam
#   L2         - passed to Adam as weight_decay
#   batch_size - batch size requested from nyu2_dataloaders()
hparams = dict(
    epochs=25,
    lr=1e-4,
    L2=1e-4,
    batch_size=32,
)

# Weights forwarded unchanged to compute_loss() as _alpha / _lambda / _mu.
loss_params = dict(
    _alpha=0.5,
    _lambda=1,
    _mu=1,
)

## Define train session

In [7]:
def train(train_dataloader,
          val_dataloader,
          model,
          optimizer,
          epochs,
          device,
          print_every=5):
    """Run the full training loop, validating after every epoch.

    Args:
        train_dataloader: iterable of ``(input, target)`` training batches.
        val_dataloader: iterable of validation batches, handed to
            ``check_loss_on_set`` at the end of each epoch.
        model: network to optimise; it is moved to ``device`` here.
        optimizer: optimiser already bound to ``model``'s parameters.
        epochs: number of passes over ``train_dataloader``.
        device: device used for the model and every batch.
        print_every: log the running loss every ``print_every`` iterations.
            Must be a positive integer; defaults to the previous fixed
            value of 5.

    Raises:
        ValueError: if ``print_every`` is smaller than 1.

    Note:
        The loss weights are read from the module-level ``loss_params`` dict.
    """
    if print_every < 1:
        raise ValueError(
            "print_every must be a positive integer, got {}".format(print_every))

    model = model.to(device=device)

    start_time = time.time()

    print("train(): Training on {}".format(device))

    for epoch in range(epochs):
        # check_loss_on_set() leaves the model in eval mode at the end of
        # every epoch, so train mode is re-enabled once per epoch here
        # rather than on every batch.
        model.train()
        num_batches = len(train_dataloader)

        # batched_image_size: (batch_size, C, H, W)
        for i, (x_tr, y_tr) in enumerate(tqdm(train_dataloader)):
            x_tr = x_tr.to(device=device)
            y_tr = y_tr.to(device=device)
            y_pred = model(x_tr)

            loss = compute_loss(pred=y_pred,
                                truth=y_tr,
                                device=device,
                                _alpha=loss_params['_alpha'],
                                _lambda=loss_params['_lambda'],
                                _mu=loss_params['_mu'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % print_every == 0:
                # tqdm.write prints the message without corrupting the
                # progress bar; .item() turns the loss tensor into a float.
                tqdm.write(
                    "[Epoch]: %d/%d [Iteration]: %d/%d, [loss]: %.4f, [Time Spent]: %.3f"
                    % (epoch, epochs,
                       i, num_batches,
                       loss.item(),
                       time.time() - start_time))

        # check on validation set each epoch
        check_loss_on_set(dataloader=val_dataloader,
                          model=model,
                          device=device)

## Define main session
`main()` is the function that does the setup and then calls `train()`.

In [8]:
def main():
    """Set up the model, optimiser and data, then train, test and save.

    Steps, in order:
      1. read the settings from the module-level ``hparams`` dict;
      2. pick CUDA when available, otherwise CPU;
      3. build the model (``resnet50`` encoder) and an Adam optimiser;
      4. obtain the train/val/test dataloaders from ``./nyu2_train``;
      5. run ``train()`` for all epochs;
      6. report the loss on the test set;
      7. save the parameters to ``./model_pth/<timestamp>.pth``.
    """
    # ---------------- params ---------------- #
    epochs = hparams['epochs']
    lr = hparams['lr']
    weight_decay = hparams['L2']
    # Original author's note: the batch size should preferably be 32 or
    # more, because batch normalisation is used frequently in the model.
    batchsize = hparams['batch_size']
    # ---------------- params ---------------- #

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("main(): Getting model......")
    model = get_model(encoder='resnet50')

    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           weight_decay=weight_decay)

    print("main(): Getting dataloaders......")
    train_set, val_set, test_set = nyu2_dataloaders(batchsize=batchsize,
                                                    nyu2_path='./nyu2_train')

    print("main(): start training......")
    # all epochs wrapped in train()
    train(train_dataloader=train_set,
          val_dataloader=val_set,
          model=model,
          optimizer=optimizer,
          epochs=epochs,
          device=device)

    print("Training Session is over, test the model on testset")
    # after training, test it on testset
    check_loss_on_set(dataloader=test_set,
                      model=model,
                      device=device)

    # SAVE THE PARAMETERS
    # Make sure the output directory exists before writing the checkpoint.
    os.makedirs('./model_pth', exist_ok=True)
    # The default label is the current time; change it to whatever you like.
    # Only '-' is used as a separator because ':' is not allowed in Windows
    # file names.
    filelabel = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
    save_param(model=model,
               pth_path='./model_pth/{}.pth'.format(filelabel))

## RUN!

In [9]:
main()

main(): Getting model......


Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to pretrained_model/resnet50/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:27<00:00, 3.69MB/s]


main(): Getting dataloaders......
Entering nyu2_dataloaders()
---------------- Loading Dataloaders ----------------
-------- Datasets are ready, preparing Dataloaders --------




ValueError: num_samples should be a positive integer value, but got num_samples=0