In [None]:
import os
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import argparse

import torch
import torch.nn as nn

from sklearn.metrics import roc_auc_score,balanced_accuracy_score
from sklearn.preprocessing import OneHotEncoder

# Switch the working directory to the parent folder before importing the
# project-local modules below; the relative './results/...' paths used later in
# this notebook are resolved against this directory as well.
# NOTE(review): this is not idempotent -- re-running this cell without
# restarting the kernel moves one more directory level up each time.
os.chdir("../") #Load from parent directory
from data_utils import Plots,gen_loader,load_datasets,compute_avg,log_data
from models import select_encoder
# Shared plotting helper used by run_sup() and main().
utils_plot=Plots()

In [None]:
def train_supervised(n_cross_val,data_type,model_name,tr_percentage,val_percentage,n_epochs,ablation,classifier_type,batch_size,
                     suffix,device,device_ids,window_size,encoder_type,encoding_size,lr,decay,
                     datasets,show_encodings,n_classes,verbose):
    """Train a classifier on top of a frozen, pre-trained encoder for every fold.

    For each cross-validation fold the encoder weights are loaded from
    ``./results/baselines/<datasets>_<model_name>/<data_type>/``; only the
    classifier's parameters are optimised (the encoder stays in eval mode and
    its outputs carry no gradient). After every epoch the encoder+classifier
    stack is evaluated on the validation and test loaders.

    Args:
        n_cross_val: number of folds to run.
        data_type, datasets, n_classes, tr_percentage, val_percentage,
        window_size, batch_size: forwarded to ``gen_loader`` together with
            the fold index.
        model_name, encoding_size, encoder_type, suffix: used to build the
            checkpoint path.
        n_epochs: number of training epochs per fold.
        ablation: when truthy, the test AUC is not computed (reported as 0).
        classifier_type: forwarded to ``select_encoder``.
        device: torch device the models and batches are moved to.
        lr: Adam learning rate for the classifier.
        verbose: print per-epoch progress and results.
        device_ids, decay, show_encodings: accepted so the function can be
            called with ``**args`` but not used in this function.

    Returns:
        dict with keys ``train_accs``, ``train_losses``, ``val_accs``,
        ``val_losses``, ``test_accs``, ``test_losses`` and ``test_aucs``; each
        value maps fold index -> list with one entry per epoch.
    """
    train_accs, test_accs = {},{}
    train_losses, test_losses = {},{}
    val_accs,val_losses = {},{}
    test_aucs = {}

    for cv in range(n_cross_val):
        #Data
        train_loader,val_loader,test_loader = gen_loader(data_type,datasets,n_classes,tr_percentage,
                                                         val_percentage,window_size,batch_size,cv)

        #Load Location
        load_dir = './results/baselines/%s_%s/%s/'%(datasets,model_name,data_type)

        load_weights = str((load_dir +'encoding_%d_encoder_%d_checkpoint_%d%s.pth.tar')
               %(encoding_size,encoder_type, cv,suffix))

        if verbose:
            print('Loading from: ',load_weights)

        #Models
        # Only the first batch is needed to read the input size; there is no
        # reason to iterate over the whole training set for it.
        first_x, _ = next(iter(train_loader))
        input_size = first_x.shape[1]
        encoder,classifier = select_encoder(device,encoder_type,input_size,encoding_size,n_classes,classifier_type)
        encoder = encoder.to(device)
        classifier = classifier.to(device)

        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(load_weights, map_location=device)
        encoder.load_state_dict(checkpoint['encoder_state_dict'])

        train_accs[cv], test_accs[cv],val_accs[cv] = [],[],[]
        train_losses[cv], test_losses[cv],val_losses[cv] =[],[],[]
        test_aucs[cv] =[]

        #Define Optimizer and Loss
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
        # NOTE(review): with step_size == n_epochs the learning rate is only
        # multiplied by gamma after the final epoch, so it is effectively
        # constant during training; `decay` is never used -- confirm intent.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, n_epochs, gamma=0.99)

        # The wrapper shares the encoder/classifier modules, so it can be
        # built once per fold instead of once per epoch.
        model = torch.nn.Sequential(encoder, classifier).to(device)

        for e in range(n_epochs):
            epoch_loss = 0
            epoch_acc = 0
            batch_count = 0

            # test_supervised() switches the whole stack to eval mode, so the
            # classifier must be put back into train mode every epoch.
            encoder.eval()
            classifier.train()

            for x,y in train_loader:
                optimizer.zero_grad()

                # The encoder is frozen: skip building its autograd graph.
                with torch.no_grad():
                    encodings = encoder(x.to(device))
                prediction = classifier(encodings)

                state_prediction = torch.argmax(prediction, dim=1)

                loss = loss_fn(prediction, y.long().to(device))

                loss.backward()
                optimizer.step()

                epoch_acc += torch.eq(state_prediction.to('cpu'), y).sum().item()/len(x)
                epoch_loss += loss.item()
                batch_count += 1

            scheduler.step()

            if verbose:
                print('CV ',cv,' Epoch ',e,'Train Labels',tr_percentage)

            #Training Results (mean over batches)
            train_loss,train_acc = epoch_loss / batch_count, epoch_acc / batch_count
            train_accs[cv].append(train_acc)
            train_losses[cv].append(train_loss)

            if verbose:
                print('Train Results: ',train_loss,train_acc)

            val_loss,val_acc,_ = test_supervised(model,device,val_loader)
            val_accs[cv].append(val_acc)
            val_losses[cv].append(val_loss)

            if verbose:
                print('Validation Results: ',val_loss,val_acc)

            test_loss,test_acc,test_auc = test_supervised(model,device,test_loader,calc_auc=not(ablation))
            test_accs[cv].append(test_acc)
            test_losses[cv].append(test_loss)
            test_aucs[cv].append(test_auc)

            if verbose:
                print('Test Results: ',test_loss,test_acc,test_auc)
                print('')

    res = {
        'train_accs': train_accs,
        'train_losses': train_losses,
        'val_accs': val_accs,
        'val_losses': val_losses,
        'test_accs': test_accs,
        'test_losses': test_losses,
        'test_aucs': test_aucs,
    }

    return res


def test_supervised(model,device,data_loader,calc_auc=False):
    model.eval()
    
    loss_fn = torch.nn.CrossEntropyLoss()
    
    epoch_loss = 0
    epoch_acc = 0
    epoch_auc = 0
    batch_count = 0
    y_all, prediction_all = [], []
    
    for x, y in data_loader:
        #print(x.shape)
        prediction = model(x.to(device))
        state_prediction = torch.argmax(prediction, -1)
        loss = loss_fn(prediction, y.long().to(device))
        
        
        epoch_acc += torch.eq(state_prediction.to('cpu'), y).sum().item()/len(x)
        epoch_loss += loss.item()
        batch_count += 1
        
        y_all.append(y)
        prediction_all.append(prediction.detach().cpu().numpy())

    if calc_auc:
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)

        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all,prediction_all,multi_class='ovo')
        
    return epoch_loss / batch_count, epoch_acc / batch_count , epoch_auc

In [None]:
def save_res(args,res):
    """Write the collected result curves to disk through ``log_data``.

    The output name is built from the dataset, model, data type, encoding
    size, encoder type, classifier type and suffix found in ``args``.
    Ablation runs use the ``labs_ssl`` prefix and the ``*_dict`` entries of
    ``res``; regular runs use the ``log_ssl`` prefix and the plain per-fold
    entries.
    """
    is_ablation = args['ablation']

    # Both file-name templates are filled with the same fields.
    path_fields = (args['datasets'],args['model_name'],args['data_type'],args['encoding_size'],
                   args['encoder_type'],args['classifier_type'],args['suffix'])

    if is_ablation:
        name = './results/baselines/%s_%s/%s/labs_ssl_encoding_%d_encoder_%d_classifier_%d%s'%path_fields
        names = ['train_accs_dict','train_losses_dict','val_accs_dict','val_losses_dict',
                 'test_accs_dict','test_losses_dict']
    else:
        name = './results/baselines/%s_%s/%s/log_ssl_encoding_%d_encoder_%d_classifier_%d%s'%path_fields
        names = ['train_accs','train_losses','val_accs','val_losses','test_accs','test_losses']

    # Values are stored in the same order as their names.
    arr = [res[key] for key in names]

    log_data(name,arr,names)

    return

In [None]:
def run_sup(args):
    """Run ``train_supervised`` with ``args`` and optionally plot the outcome.

    When ``args['show_encodings']`` is truthy, the fold-averaged train/test
    accuracy and loss curves are plotted, followed by one encoding
    distribution plot per cross-validation fold on that fold's test data.

    Returns:
        The result dict produced by ``train_supervised``.
    """
    res = train_supervised(**args)

    # Nothing to plot: hand the results straight back.
    if not args['show_encodings']:
        return res

    #Plot Accuracy/Loss
    curve_keys = ('train_accs','test_accs','train_losses','test_losses')
    averaged = [compute_avg(res[key]) for key in curve_keys]
    utils_plot.plot_acc_loss('Fully Supervised', *averaged)

    #Plot Features
    title = 'Fully Supervised Encoding TSNE for %s'%(args['data_type'])
    for fold in range(args['n_cross_val']):
        # Only the test split is plotted; the train split is loaded but unused.
        _, _, test_data, test_labels = load_datasets(args['data_type'],args['datasets'],fold)
        utils_plot.plot_distribution(test_data, test_labels,args['encoder_type'],
                                     args['encoding_size'],args['window_size'],'sup',
                                     args['datasets'],args['data_type'],args['suffix'],
                                     args['device'], title, fold)

    return res

In [None]:
def calc_metrics(args,res):
    """Summarise per-fold results into final accuracy / AUC figures.

    For every fold the epoch with the highest validation accuracy is selected
    and the test accuracy and test AUC of that epoch are taken.

    Returns a dict with:
        final_acc / final_auc: mean of the selected values over folds.
        final_diff / final_auc_diff: largest distance between that mean and
            the smallest or largest selected value.
        max_acc: mean over folds of each fold's best test accuracy.
    """
    def _mean_and_spread(values):
        # Mean of the values and its largest distance to either extreme.
        centre = np.mean(values)
        return centre, max((centre - min(values)),(max(values)-centre))

    best_per_fold = np.max(np.array(list(res['test_accs'].values())),axis=1)
    max_acc = np.mean(best_per_fold.flatten())

    folds = range(args['n_cross_val'])

    #Epoch with the highest validation accuracy, per fold
    best_epoch = [np.argmax(res['val_accs'][fold]) for fold in folds]

    selected_accs = [res['test_accs'][fold][best_epoch[fold]] for fold in folds]
    final_acc, final_diff = _mean_and_spread(selected_accs)

    selected_aucs = [res['test_aucs'][fold][best_epoch[fold]] for fold in folds]
    final_auc, final_auc_diff = _mean_and_spread(selected_aucs)

    metrics = {
        'final_acc': final_acc,
        'final_diff': final_diff,
        'final_auc': final_auc,
        'final_auc_diff': final_auc_diff,
        'max_acc': max_acc,
    }

    return metrics

In [None]:
def main(args):
    """Configure and run the experiment described by ``args``.

    ``args`` is modified in place: device information, the fixed experiment
    parameters and the dataset-specific settings are written into it before it
    is forwarded to ``run_sup``.

    * Regular run (``args['ablation']`` falsy): train once, save the result
      curves and print the summary metrics.
    * Ablation run: train once per training-label percentage, collect the
      curves per percentage, save them and plot the ablation figure.
    """
    #Devices
    args['device'] = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    args['device_ids'] = list(range(torch.cuda.device_count()))
    print('Using', args['device'])

    #Experiment Parameters
    args['window_size'] = 2500
    args['encoder_type'] = 1
    args['encoding_size'] = 128
    args['lr'] = 1e-3
    args['decay'] = 1e-5
    args['datasets'] = args['data_type']

    #Dataset-specific settings
    # The class count used to be overwritten with 10 for every dataset, which
    # discarded the afdb/ims values; it is now taken from this table. Unknown
    # data types keep the former fallback of 10 classes.
    classes_per_dataset = {'afdb': 4, 'ims': 5, 'urban': 10}
    args['n_classes'] = classes_per_dataset.get(args['data_type'], 10)
    if args['data_type']=='urban':
        args['n_cross_val'] = 10

    if not args['ablation']:
        results = run_sup(args)
        save_res(args,results)

        metrics = calc_metrics(args,results)
        print(metrics)

    else:
        train_accs_dict,train_losses_dict = {},{}
        test_accs_dict,test_losses_dict = {},{}
        val_accs_dict,val_losses_dict = {},{}
        args['show_encodings'] = False
        tr = [0.01,0.05,0.1,0.15,0.2,0.3,0.4,0.5,0.6,0.7,0.8]

        for train_per in tr:
            args['tr_percentage']= train_per
            results = run_sup(args)

            train_accs_dict[train_per] = results['train_accs']
            train_losses_dict[train_per] = results['train_losses']
            test_accs_dict[train_per] = results['test_accs']
            test_losses_dict[train_per] = results['test_losses']
            val_accs_dict[train_per] = results['val_accs']
            val_losses_dict[train_per] = results['val_losses']

        # save_res() reads the cumulative dictionaries from `results`, so they
        # only need to be attached once, to the results of the last run.
        results['train_accs_dict'] = train_accs_dict
        results['train_losses_dict'] = train_losses_dict
        results['test_accs_dict'] = test_accs_dict
        results['test_losses_dict'] = test_losses_dict
        results['val_accs_dict'] = val_accs_dict
        results['val_losses_dict'] = val_losses_dict

        save_res(args,results)
        utils_plot.plot_ablation(args['n_cross_val'],tr,val_accs_dict,test_accs_dict,test_losses_dict)
    return

In [None]:
# User-tunable experiment configuration. main() adds further keys to this
# dict (device, window size, encoder settings, learning rate, ...) before
# passing it on, and may overwrite 'n_cross_val' for the 'urban' data type and
# 'tr_percentage' / 'show_encodings' during an ablation run.
args = {'n_cross_val':10,
        'model_name':'cpc', #options: cpc, tloss, tnc, simclr
        'data_type':'afdb', #options: afdb, ims, urban
        'tr_percentage':0.8,
        'val_percentage': 0.2,
        'n_epochs':15,
        'ablation': False, # True: sweep over training-label percentages in main()
        'classifier_type':0,
        'batch_size': 100,
        'suffix':'',
        'verbose':True,
        'show_encodings': False} 

# Runs the whole experiment (training, evaluation, saving results).
main(args)