In [None]:
"""
This notebook will allow you to train or run a model on an individual modality
"""
import torch
import numpy as np
import torch.optim as optim
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import pandas as pd
from sklearn.metrics import roc_auc_score
import datetime
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from Models.Base_3DCAE import Base_3DCAE 
from io import StringIO
import os
import ffmpeg
import pdb
import parameters
from functions import create_pytorch_dataset
from functions import get_total_performance_metrics
from functions import get_performance_metrics
from functions import get_global_performance_metrics
from functions import get_window_metrics
from functions import get_frame_metrics
from functions import animate


In [None]:
# Parameters
# Every setting below is read from the project-level `parameters` module;
# edit that module (not this cell) to change a run's configuration.

# Data windowing / dataset construction settings
window_len, stride, fair_comparison = (
    parameters.window_len,
    parameters.stride,
    parameters.fair_comparison,
)

# Device that the model and tensors are moved to
device = parameters.device

# Optimisation settings
dropout, learning_rate, num_epochs = (
    parameters.dropout,
    parameters.learning_rate,
    parameters.num_epochs,
)

# Split sizes: chunk_size (training) and forward_chunk (evaluation) are passed to
# torch.split along dim 1 in full_pipeline; forward_chunk_size is not used by the
# code in this notebook section.
chunk_size, forward_chunk, forward_chunk_size = (
    parameters.chunk_size,
    parameters.forward_chunk,
    parameters.forward_chunk_size,
)

# Reconstruction loss used during training
loss_fn = parameters.loss_fn

def full_pipeline(name, dset, window_len, fair_comparison, path, stride, modelpath, train=True):
    """Optionally train, then evaluate, a Base_3DCAE autoencoder on one modality.

    The H5PY dataset at ``path`` is loaded through ``create_pytorch_dataset``.
    When ``train`` is True a fresh ``Base_3DCAE`` is trained on the training split.
    Its weights are saved after every epoch, and the last-epoch training
    reconstruction errors/labels are written to ``Output/Recon_Errors``. A forward
    pass is then run over the test split, and the results are reported through
    ``get_total_performance_metrics`` and ``get_global_performance_metrics``.

    Relies on module-level globals: ``device``, ``learning_rate``, ``num_epochs``,
    ``chunk_size``, ``forward_chunk``, ``loss_fn`` and ``project_directory``.

    Args:
        name: dataset name used when the H5PY file was created (e.g. 'Thermal_T3');
            also the prefix of the run identifier used in output file names.
        dset: raw-dataset directory name, forwarded to ``create_pytorch_dataset``.
        window_len: window length forwarded to the dataset and metric helpers.
        fair_comparison: forwarded to ``create_pytorch_dataset``.
        path: path of the H5PY dataset file.
        stride: window stride, forwarded to ``create_pytorch_dataset``.
        modelpath: path of pre-trained weights; only used when ``train`` is False.
        train: if True (default), train a new model and evaluate it. If False,
            skip training and evaluate the weights stored at ``modelpath``.

    Returns:
        An empty tuple (kept for backward compatibility).
    """
    # Load the H5PY dataset into PyTorch dataloaders. See dataset_creator for how
    # the H5PY file is generated.
    _, test_dataloader, _, train_dataloader = create_pytorch_dataset(
        name, dset, path, window_len, fair_comparison, stride, TOD='Both')
    print('Train Dataloader - {}'.format(len(train_dataloader)))
    print('Test Dataloader - {}'.format(len(test_dataloader)))

    # format() instead of '+' so this also works when device is a torch.device
    print('Device Used - {}'.format(device))
    torch.cuda.empty_cache()

    model = Base_3DCAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Run identifier (dataset name + timestamp) used in every output file name.
    start_time = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    modality = name + start_time
    models_dir = os.path.join(project_directory, "Output", "Models")
    recon_dir = os.path.join(project_directory, "Output", "Recon_Errors")

    def train_model(filepath):
        """Train ``model`` for ``num_epochs`` epochs, saving its weights to ``filepath`` each epoch."""
        print("Training has Begun")
        model_dir = os.path.dirname(filepath)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        model.train()
        loss = None  # stays None if the training dataloader yields no batches
        for epoch in range(num_epochs):
            is_last_epoch = epoch == num_epochs - 1
            frame_stats = []
            for sample, labels in train_dataloader:
                sample = sample.to(device, dtype=torch.float)
                # Split each sample along dim 1 into chunk_size pieces because of GPU
                # memory constraints; every chunk gets its own optimizer step.
                recon_vid = []
                for chunk in torch.split(sample, chunk_size, dim=1):
                    optimizer.zero_grad()
                    output = model(chunk).permute(1, 0, 2, 3, 4)
                    loss = loss_fn(output, chunk)
                    loss.backward()
                    optimizer.step()
                    if is_last_epoch:
                        # Reconstructions are only needed for last-epoch metrics; detach
                        # so they do not keep autograd state alive.
                        recon_vid.append(output.detach())
                    torch.cuda.empty_cache()

                if is_last_epoch:
                    output = torch.cat(recon_vid, dim=1)
                    frame_std, frame_mean, frame_labels, _, _, _ = get_performance_metrics(
                        sample.detach().cpu().numpy(),
                        output.cpu().numpy(),
                        labels.detach().cpu().numpy(),
                        window_len)
                    frame_stats.append([frame_mean, frame_std, frame_labels])

            if is_last_epoch:
                recon_errors = [[frame_mean, frame_std] for frame_mean, frame_std, _ in frame_stats]
                recon_labels = [frame_labels for _, _, frame_labels in frame_stats]
                os.makedirs(recon_dir, exist_ok=True)
                # NOTE(review): if per-video frame arrays differ in length these lists are
                # ragged, and NumPy >= 1.24 refuses to build such arrays implicitly — confirm.
                np.save(os.path.join(recon_dir, "train_recon_errors_{}.npy".format(modality)), recon_errors)
                np.save(os.path.join(recon_dir, "train_recon_labels_{}.npy".format(modality)), recon_labels)

            # ===================log========================
            if loss is None:
                print("epoch [{}/{}], no training batches".format(epoch + 1, num_epochs))
            else:
                # loss of the last chunk processed in this epoch
                print("epoch [{}/{}], loss:{:.4f}".format(epoch + 1, num_epochs, loss.item()))
            torch.save(model.state_dict(), filepath)  # save the model every epoch

        torch.cuda.empty_cache()
        print("Training has Completed")

    def forward_pass(weights_path):
        """Load the weights at ``weights_path`` and reconstruct every test sample.

        Returns:
            (frame_stats, window_stats). Each is a list with one entry per test batch:
            [mean, std, labels] at frame level and at window level respectively.
        """
        # map_location allows weights saved from a GPU run to load on the configured device
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.eval()

        frame_stats = []
        window_stats = []
        with torch.no_grad():
            print("foward pass occuring")
            for sample, labels in test_dataloader:
                torch.cuda.empty_cache()
                sample = sample.to(device, dtype=torch.float)
                recon_vid = []
                for chunk in torch.split(sample, forward_chunk, dim=1):
                    recon_vid.append(model(chunk).permute(1, 0, 2, 3, 4))
                    torch.cuda.empty_cache()
                output = torch.cat(recon_vid, dim=1)

                frame_std, frame_mean, frame_labels, window_std, window_mean, window_labels = get_performance_metrics(
                    sample.cpu().numpy(), output.cpu().numpy(), labels.cpu().numpy(), window_len)
                frame_stats.append([frame_mean, frame_std, frame_labels])
                window_stats.append([window_mean, window_std, window_labels])

        return frame_stats, window_stats

    if train:
        # The run identifier already contains the timestamp, so it is not appended twice.
        weights_path = os.path.join(models_dir, modality)
        train_model(weights_path)
    else:
        weights_path = modelpath

    frame_stats, window_stats = forward_pass(weights_path)

    print(modality)
    get_total_performance_metrics(modality, frame_stats, window_stats, window_len)
    get_global_performance_metrics(modality, frame_stats, window_stats, window_len)

    return ()

# Directory names of the raw datasets in the Fall-Data folder.
# Full set available: ['Thermal', 'ONI_IR', 'IP']
list_of_files = ['Thermal']

# Dataset names used during H5PY file creation (dsets variable in dataset_creator.py),
# in the same order as list_of_files.
# Full set available: ['Thermal_T3', 'ONI_IR_T', 'IP_T']
list_of_datasets = ['Thermal_T3']

# Pre-trained weight locations (one per dataset), passed to full_pipeline as `modelpath`.
# Placeholder 'x' entries are fine when training a new model; training saves the
# new weights in the Output/Models folder.
list_of_models = ['x']

script_directory = os.getcwd()
# Project root is taken to be the parent of the current working directory.
project_directory = os.path.dirname(script_directory)

# The three lists are parallel; a length mismatch is a configuration error.
if not (len(list_of_files) == len(list_of_datasets) == len(list_of_models)):
    raise ValueError(
        "list_of_files, list_of_datasets and list_of_models must have the same length")

for modelpath, name, dset in zip(list_of_models, list_of_datasets, list_of_files):
    # os.path.join instead of a hard-coded backslash string: portable, and avoids
    # invalid escape sequences such as "\D" and "\H".
    path = os.path.join(project_directory, "Dataset", "H5PY",
                        "Data_set-{}-imgdim64x64.h5".format(name))
    full_pipeline(name, dset, window_len, fair_comparison, path, stride, modelpath)