In [1]:
import import_ipynb
from CustomDataset import ControlsDataset

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

# Ignore warnings
import warnings
import time
warnings.filterwarnings("ignore")

from tqdm.notebook import tqdm_notebook

importing Jupyter notebook from CustomDataset.ipynb


In [3]:
class Trainer():
    """Training loop that interleaves one validation batch with every training batch.

    Batches coming from the loaders are indexed with the keys ``'image'`` and
    ``'control'``; both values are moved to ``device`` before use.
    """

    def __init__(self, device, model, dataset, optimizer, criterion):
        """Store the objects the training loop works with.

        Args:
            device: Target passed to ``Tensor.to`` for every batch.
            model: Callable network exposing ``train()`` and ``eval()``.
            dataset: Object exposing ``dataloader`` (training batches) and
                ``validloader`` (validation batches); ``label_distribution``
                additionally reads ``dataset.labels.dataframe["Angle"]``.
            optimizer: Object exposing ``zero_grad()`` and ``step()``.
            criterion: Callable ``criterion(prediction, target)`` returning a
                scalar loss tensor.
        """
        self.device = device
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.dataset = dataset

    def initializeEpoch(self):
        """Reset the running loss sums and restart the validation iterator."""
        self.summation = 0
        self.val_summation = 0
        self.validation_training = enumerate(self.dataset.validloader)

    def fit(self, epochs, report_period):
        """Train for ``epochs`` passes over ``dataset.dataloader``.

        After every optimisation step one validation batch is scored. Every
        ``report_period`` batches the running averages of both losses (sum of
        batch losses so far in the epoch divided by batches seen) are recorded,
        and at the end of each epoch the full-epoch averages are printed and
        the recorded curves are plotted.

        Args:
            epochs: Number of passes over the training loader.
            report_period: A curve point is recorded whenever
                ``i_batch % report_period == 0``.

        Raises:
            ValueError: If ``report_period`` is zero, or if the training
                loader yields no batches.
        """
        if report_period == 0:
            # Fail before any work is done instead of with a ZeroDivisionError
            # after the first optimisation step.
            raise ValueError("report_period must be non-zero")

        # TODO: checkpointing is currently disabled; when re-enabling, save to
        # something like "snapshots/{count}_{average_loss:.5f}_weights.pt".
        iters_trained = []
        training_losses = []
        validation_losses = []
        count = 0  # global step counter across epochs (x-axis of the plot)

        for epoch in range(epochs):
            self.initializeEpoch()
            batches_seen = 0
            for i_batch, sampled_batch in tqdm_notebook(enumerate(self.dataset.dataloader),
                                                        total=len(self.dataset.dataloader)):
                # score() leaves the model in eval mode, so switch back on every step.
                self.model.train()

                # inputs
                images = sampled_batch['image'].to(self.device).float()
                controls = sampled_batch['control'].to(self.device).long()
                controls = torch.flatten(controls)

                # forward pass
                self.optimizer.zero_grad()
                prediction = self.model(images)
                # NOTE(review): predictions and targets are flattened here but not
                # in score(); confirm which shape the criterion expects.
                prediction = torch.flatten(prediction)

                # loss and backward pass
                loss = self.criterion(prediction, controls)
                loss.backward()
                self.optimizer.step()

                # accumulate batch losses
                _, batch = self.validationBatch()
                val_loss = self.score(batch)
                self.summation += loss.item()
                self.val_summation += float(val_loss)
                batches_seen = i_batch + 1

                if i_batch % report_period == 0:
                    iters_trained.append(count)
                    training_losses.append(round(self.summation / float(batches_seen), 5))
                    validation_losses.append(round(self.val_summation / float(batches_seen), 5))
                count += 1

            if batches_seen == 0:
                raise ValueError("training dataloader yielded no batches")

            # Average over the whole epoch, not the value left over from the
            # last reporting step (which may be many batches old).
            average_loss = round(self.summation / float(batches_seen), 5)
            average_val_loss = round(self.val_summation / float(batches_seen), 5)

            # Release cached GPU blocks once per epoch rather than on every step.
            torch.cuda.empty_cache()

            print("Epoch: "+str(epoch))
            print("Training Loss: "+str(average_loss))
            print("Validation Loss: "+str(average_val_loss))
            plt.plot(iters_trained, training_losses, label="training")
            plt.plot(iters_trained, validation_losses, label="validation")
            plt.xlabel("Iterations trained")
            plt.ylabel("Running average loss")
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
            plt.show()

    def validationBatch(self):
        """Return the next ``(index, batch)`` pair from the validation loader.

        When the current iterator is exhausted it is restarted from the
        beginning of ``dataset.validloader``. ``initializeEpoch`` must have
        been called first, since it creates the iterator.

        Raises:
            ValueError: If the validation loader yields no batches at all.
        """
        try:
            val_i, batch = next(self.validation_training)
        except StopIteration:
            self.validation_training = enumerate(self.dataset.validloader)
            try:
                val_i, batch = next(self.validation_training)
            except StopIteration:
                # A freshly created iterator that is already exhausted means
                # the loader is empty; report that instead of leaking a bare
                # StopIteration into the caller's loop.
                raise ValueError("validation loader yielded no batches") from None
        return val_i, batch

    def score(self, sampled_batch):
        """Compute the criterion on one batch without tracking gradients.

        Puts the model in eval mode and leaves it there; callers that continue
        training must call ``model.train()`` again.

        Args:
            sampled_batch: Mapping with ``'image'`` and ``'control'`` tensors.

        Returns:
            The loss as a tensor detached from the autograd graph.
        """
        self.model.eval()
        images = sampled_batch['image'].to(self.device).float()
        controls = sampled_batch['control'].to(self.device).long()
        # NOTE(review): unlike fit(), nothing is flattened here; confirm the
        # criterion accepts the model output and targets in this shape.

        # No graph is needed for evaluation; skipping it saves memory and time.
        with torch.no_grad():
            prediction = self.model(images)
            loss = self.criterion(prediction, controls)
        return loss.detach()

    def label_distribution(self):
        """Draw a histogram of ``dataset.labels.dataframe["Angle"]``.

        Returns:
            Tuple ``(count, values)``: the first two items returned by
            ``plt.hist`` (per matplotlib, the bin counts and the bin edges).
        """
        count, values = plt.hist(self.dataset.labels.dataframe["Angle"])[:2]
        return count, values