# Notebook to demonstrate SumMe label issues and validation split consistency

In [2]:
import torch
from Model import model_dict,params_dict
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import os
from Utils import *
import torch.optim as optim
from Data import VideoData

import sys
import numpy as np
import torch.nn.init as init
import argparse
import h5py
# Root directory under which inference() creates its per-experiment result folders.
results_dir = "Results"
# Root directory from which inference() loads the per-split model checkpoints.
weights_path = 'weights'

In [None]:
def _strip_affixes(text, prefix, suffix):
    """Return *text* without a leading *prefix* and a trailing *suffix*.

    ``str.strip`` removes any of the given *characters* from both ends, so it
    cannot be used to drop a fixed prefix or extension.
    """
    if prefix and text.startswith(prefix):
        text = text[len(prefix):]
    if suffix and text.endswith(suffix):
        text = text[:-len(suffix)]
    return text


def _normalise_labels(labels):
    """Min-max scale *labels* to [0, 1]; constant labels become all zeros."""
    shifted = labels - labels.min()
    peak = shifted.max()
    # Dividing by a zero peak would turn every target, and then the loss, into NaN.
    return shifted / peak if peak > 0 else shifted


def _predict_scores(model, loader, device):
    """Run *model* over *loader* without gradients.

    Returns a ``(scores, names)`` pair of parallel lists: the first item of
    each model output converted to a Python list, and the first name of each
    batch.
    """
    scores, names = [], []
    model.eval()
    # ncols is the width of the progress bar, not the number of items.
    for inputs, batch_names in tqdm(loader, ncols=100):
        with torch.no_grad():
            predicted = model(inputs.to(device))
        # Same squeeze as the training loss path and inference().
        if len(predicted.shape) > 2:
            predicted = predicted.squeeze(-1)
        scores.append(predicted[0].to('cpu').tolist())
        names.append(batch_names[0])
    return scores, names


def train(config_path,save_path = 'weights'):
    """Train one model per data split and keep the best checkpoints.

    The JSON file at *config_path* must provide ``Model``, ``split``,
    ``loss_function``, ``num_epochs``, ``feature_extractor``, ``learning_rate``
    and ``reg``; ``data_aug``, ``total_splits`` (default 5) and
    ``gradnorm_clip`` (default 3) are optional.

    For every split the model is trained for ``num_epochs`` epochs. After each
    epoch the test loader is scored with ``eval_summary`` and
    ``evaluate_correlation``; the state dict is saved as ``best_run_corr.pth``
    when the average Kendall correlation improves and as ``best_run_f1.pth``
    when the mean F1 improves. Checkpoints are written to
    ``<save_path>/<feature_extractor>_<split name without extension>/<dataset>/<Model>/split_<k>/``.

    NOTE: the split name is now derived by removing the dataset prefix and the
    ``.json`` extension exactly. The previous character-set based ``strip``
    could eat extra characters (``summe_splits.json`` -> ``_split``), so
    directory names may differ from those produced by earlier runs.

    Args:
        config_path: Path to the JSON experiment configuration.
        save_path: Root directory for the saved weights.
    """
    with open(config_path, 'r') as config_file:
        config = json.load(config_file)

    assert config['Model'] in model_dict.keys(), "Model is not available, modify dictionary to include them or check spelling"
    dataset_name = config['split'].split("_")[0]
    split_string = _strip_affixes(config['split'], dataset_name, '.json')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    modelclass = model_dict[config['Model']]
    criterion = loss_dict[config['loss_function']]()
    num_epochs = config["num_epochs"]
    feature_extractor = config['feature_extractor']
    save_name = f'{feature_extractor}_{dataset_name}{split_string}'
    save_path = os.path.join(save_path, save_name, dataset_name, config['Model'])
    os.makedirs(save_path, exist_ok=True)
    print(save_name)

    params = params_dict[config['Model']][feature_extractor]

    # Building augmentations from the config is currently disabled; always
    # bind the list so a truthy 'data_aug' entry no longer leaves it undefined.
    data_augmentations = []
    if config.get('data_aug'):
        print("Warning: 'data_aug' is set in the config but augmentations are disabled; training without them")

    splits = config.get('total_splits', 5)
    gradnorm_clip = config.get('gradnorm_clip', 3)
    h5_path = os.path.join('Data', feature_extractor, f'{feature_extractor}_{dataset_name}.h5')
    print(params)

    # Read-only handle that is closed even if training raises.
    with h5py.File(h5_path, 'r') as dataset:
        for split in range(splits):
            print(f"Running Split:  {split+1}  for model: {config['Model']}")
            model = modelclass(**params)
            model.to(device)
            train_set = VideoData('train', config['split'], split, transforms=data_augmentations,
                                  feature_extractor=feature_extractor, trainval=True)
            batchloader = DataLoader(train_set, batch_size=1, shuffle=True)
            testdata = VideoData('test', config['split'], split,
                                 feature_extractor=feature_extractor, trainval=True)
            testloader = DataLoader(testdata, batch_size=1, shuffle=False)
            optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"],
                                   weight_decay=config['reg'])
            best_f1_score = -float('inf')
            best_correlation = -float('inf')

            # One checkpoint directory per split.
            save_path_split = os.path.join(save_path, f'split_{split+1}')
            os.makedirs(save_path_split, exist_ok=True)

            for epoch in range(num_epochs):
                model.train()
                running_loss = 0.0

                for data in tqdm(batchloader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols=100):
                    inputs, labels = data[0].to(device), data[1].to(device)
                    optimizer.zero_grad()
                    labels = _normalise_labels(labels)
                    outputs = model(inputs)
                    if len(outputs.shape) > 2:
                        outputs = outputs.squeeze(-1)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), gradnorm_clip)
                    optimizer.step()
                    running_loss += loss.item()
                # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
                epoch_loss = running_loss / max(len(batchloader), 1)
                print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

                print(f"Compute F1 and Correlation for epoch: {epoch+1}")
                test_datapoints, test_names = _predict_scores(model, testloader, device)
                all_scores = eval_summary(test_datapoints, dataset, test_names, dataset_name)
                correlation_dict = evaluate_correlation(test_datapoints, dataset, test_names, dataset_name)

                if correlation_dict['Average_Kendall'] > best_correlation:
                    print(f"Saving epoch {epoch+1}")
                    best_correlation = correlation_dict['Average_Kendall']
                    print(f"Best Correlation Score:  {epoch+1}: {correlation_dict['Average_Kendall']} ")
                    torch.save(model.state_dict(), os.path.join(save_path_split, "best_run_corr" + ".pth"))

                mean_f1 = np.mean(all_scores).item()
                if mean_f1 > best_f1_score:
                    best_f1_score = mean_f1
                    print(f"Best F1 Score:  {epoch+1}: {best_f1_score} ")
                    torch.save(model.state_dict(), os.path.join(save_path_split, "best_run_f1" + ".pth"))

            print(f'Best F1 score for split {split+1}: {best_f1_score} ')
            print(f'Best Correlation for split {split+1}: {best_correlation} ')
    print('Completed Training')

In [None]:
def inference(config_path,delete_weights=True):
    """Evaluate the saved checkpoints of every split and write the results.

    Two identical passes are run, one per checkpoint type:
    ``best_run_corr.pth`` (results under ``best_corr``) and
    ``best_run_f1.pth`` (results under ``best_f1``). Each pass loads the
    checkpoint of every split from
    ``<weights_path>/<save_name>/<dataset>/<Model>/split_<k>/``, scores the
    test loader, computes F1 (``generate_f1_results``) and correlation
    (``evaluate_correlation``) and writes, below
    ``<results_dir>/<save_name>/<pass>/<dataset>/<Model>/``:

    * ``F1/split_<k>_noname.json`` and ``Correlation/split_<k>_noname.json``
      with the per-video results as returned by the evaluators,
    * ``F1/split_<k>.json`` and ``Correlation/split_<k>.json`` with the keys
      rewritten by ``change_key_names``,
    * ``F1/results.json`` and ``Correlation/results.json`` with the output of
      ``compute_average_results``,
    * ``Output/outputs.json`` with the raw scores keyed by video name.

    The JSON config must provide ``Model_params`` (with ``Model`` and optional
    ``Params``), ``split``, ``feature_extractor``, ``save_name`` and
    ``datapath``; ``total_splits`` is optional and defaults to 5.

    NOTE(review): train() reads ``config['Model']`` and takes the model
    parameters from ``params_dict``, whereas this function reads
    ``config['Model_params']`` - confirm both are meant to accept the same
    config file.

    Args:
        config_path: Path to the JSON experiment configuration.
        delete_weights: When true, each checkpoint file is removed right after
            its split has been evaluated and its results written.
    """
    with open(config_path, 'r') as config_file:
        config = json.load(config_file)
    if delete_weights:
        print("IMPORTANT NOTE: WEIGHTS ARE DELETED, RERUN TRAINING TO GET THEM BACK")
    model_name = config['Model_params']['Model']
    assert model_name in model_dict.keys(), "Model is not available, modify dictionary to include them or check spelling"

    dataset_name = config['split'].split("_")[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    feature_extractor = config['feature_extractor']
    modelclass = model_dict[model_name]
    params = config['Model_params'].get('Params', {})
    # Same default as train(), instead of a hard-coded 5.
    splits = config.get('total_splits', 5)

    # Weights are laid out as "<weights_path>/<experiment>/<dataset>/<model>".
    weight_path = os.path.join(weights_path, config['save_name'], dataset_name, model_name)
    print(weight_path)
    assert os.path.exists(weight_path), "Model weights do not exist or pathing is incorrect, check path"

    def dump_json(payload, path):
        """Write *payload* to *path* as indented JSON."""
        with open(path, 'w') as json_file:
            json.dump(payload, json_file, indent=4)

    def run_pass(tag, weight_file, banner, dataset):
        """Evaluate *weight_file* of every split and store results under *tag*."""
        result_dir = os.path.join(results_dir, config['save_name'], tag, dataset_name, model_name)
        result_f1_dir = os.path.join(result_dir, 'F1')
        result_corr_dir = os.path.join(result_dir, 'Correlation')
        result_out_dir = os.path.join(result_dir, 'Output')
        for directory in (result_f1_dir, result_corr_dir, result_out_dir):
            os.makedirs(directory, exist_ok=True)

        all_splits_f1_scores = {}
        all_splits_correlations = {}
        output_dict = {}
        for split in range(splits):
            split_tag = f"split_{split+1}"
            model = modelclass(**params)
            testdata = VideoData('test', config['split'], split, feature_extractor=feature_extractor)
            testloader = DataLoader(testdata, batch_size=1, shuffle=False)
            weight_path_split = os.path.join(weight_path, split_tag, weight_file)
            model.load_state_dict(torch.load(weight_path_split, map_location=device))
            model.eval()
            model.to(device)

            name_list = []
            output_list = []
            for inputs, names in tqdm(testloader, ncols=100):
                with torch.no_grad():
                    importance_scores = model(inputs.to(device))
                if len(importance_scores.shape) > 2:
                    importance_scores = importance_scores.squeeze(-1)
                importance_scores = importance_scores[0].to('cpu').tolist()
                output_list.append(importance_scores)
                name_list.append(names[0])
                output_dict[str(names[0])] = importance_scores

            result_f1_dict = generate_f1_results(output_list, dataset, name_list, dataset_name)
            correlation_dict = evaluate_correlation(output_list, dataset, name_list, dataset_name)

            # Move the averages into the cross-split summaries so that the
            # per-split files only hold per-video entries.
            all_splits_f1_scores[split_tag] = result_f1_dict.pop('Average F1')
            all_splits_correlations[split_tag] = {
                'Kendall': correlation_dict.pop('Average_Kendall'),
                'Spearman': correlation_dict.pop('Average_Spearman'),
            }

            dump_json(result_f1_dict, os.path.join(result_f1_dir, f"{split_tag}_noname.json"))
            dump_json(correlation_dict, os.path.join(result_corr_dir, f"{split_tag}_noname.json"))
            result_f1_dict, correlation_dict = change_key_names(result_f1_dict, correlation_dict, dataset_name)
            dump_json(result_f1_dict, os.path.join(result_f1_dir, f"{split_tag}.json"))
            dump_json(correlation_dict, os.path.join(result_corr_dir, f"{split_tag}.json"))

            # Only drop the checkpoint once its results are on disk.
            if delete_weights:
                os.remove(weight_path_split)

        print(banner)
        print(all_splits_correlations)
        result_f1_dict_final, correlation_dict_final = compute_average_results(
            all_splits_f1_scores, all_splits_correlations)
        dump_json(result_f1_dict_final, os.path.join(result_f1_dir, 'results.json'))
        dump_json(correlation_dict_final, os.path.join(result_corr_dir, 'results.json'))
        dump_json(output_dict, os.path.join(result_out_dir, 'outputs.json'))

    # One read-only handle shared by both passes and closed on exit.
    with h5py.File(config['datapath'] + '.h5', 'r') as dataset:
        run_pass(
            'best_corr',
            'best_run_corr.pth',
            "----------------------------------------------- Final set of results from the experiments for the best correlation weights ------------------------------------------------------------------------------------------------",
            dataset,
        )
        run_pass(
            'best_f1',
            'best_run_f1.pth',
            "----------------------------------------------- Final set of results from the experiments for the best f1 score weights ------------------------------------------------------------------------------------------------",
            dataset,
        )
   