# Notebook for evaluating Trained Models

In [None]:
%load_ext autoreload
%autoreload 2 

In [None]:
import sys
from pathlib import Path
from torchmetrics.functional import precision_recall
from torchmetrics import AveragePrecision,Accuracy
import torch
### adding model to path
sys.path.append('/home/jupyter/deepslide')
from src.datamodules.siamese_datamodule import Siamese_Landslide_Datamodule
from src.models.siamese_downstream_module import Segmentation_Model
from src.models.siamese_module import Siamese_Type_1
import numpy as np
import matplotlib.pyplot as plt
import pickle

In [None]:
import scipy

In [None]:
# Scratch check: expit is the logistic sigmoid, so this cell should display 0.5.
scipy.special.expit(0.)

### choose which experiment to evaluate

In [None]:
# settings
# dataset: used as the top-level key of the results dictionary.
dataset         = 'Hokkaido'
# pretraining: selects which pretrained checkpoint path is used ('Hokkaido' or 'Kaikoura').
pretraining     = 'Hokkaido'
# same as dataset but in lowercase; passed to the datamodule as the dataset name
dataset2        = 'hokkaido'  
# loss: matched against the 'loss:' line of each run's config log and passed to the model.
loss            = 'dice'
# experiment_name: name of the run directory to scan and part of the output pickle names.
experiment_name = 'segment_hokk_pretrain_hokk_cnn'
# trainsize: NOTE the evaluation cell loops over its own list of train sizes
# and overwrites this variable.
trainsize       = '5'
# unet: forwarded to Segmentation_Model as its 'unet' argument.
unet            = False

### you will need to adapt the model paths here

In [None]:
# pretrained models
# Siamese network built with the same hyper-parameters that are passed as
# 'pretrain_params' to the downstream segmentation model.
net = Siamese_Type_1(
    input_size=[2, 128, 128],
    embedding_size=32,
    unet=True,
    decoder_depth=1,
    encoder_depth=1,
    cnn=True,
    base_lr=0.001,
    decoder_channels=[32],
)

# Checkpoint of the pretrained Siamese model for the chosen pretraining set.
if pretraining == 'Hokkaido':
    pretrain_path = '/home/jupyter/deepslide/logs/experiments/runs/Siamese_Type1_hokkaido/2022-09-27_12-14-28/checkpoints/epoch_269.ckpt'
elif pretraining == "Kaikoura":
    pretrain_path = '/home/jupyter/deepslide/logs/experiments/runs/Siamese_Type1_kaikora/2022-09-27_11-38-01/checkpoints/last.ckpt'
else:
    # Fail here instead of silently reusing a stale `pretrain_path` from an
    # earlier cell run (or hitting a NameError much later).
    raise ValueError(
        "Unknown pretraining setting {!r}; expected 'Hokkaido' or 'Kaikoura'".format(pretraining)
    )

# Root directory that holds all runs of the selected experiment.
path = Path('/home/jupyter/deepslide/logs/experiments/runs/{}/'.format(experiment_name))

In [None]:
def compute_metrics(preds, targets, threshold=0.5):
    """Compute segmentation scores for a batch of predictions.

    Both tensors are flattened to ``(batch, -1)`` before scoring.

    Args:
        preds: Prediction tensor whose first dimension is the batch. The
            callers in this notebook pass probabilities (after a sigmoid).
        targets: Ground-truth tensor with the same number of elements per
            sample as ``preds``.
        threshold: Cut-off handed to the thresholded metrics (precision,
            recall and accuracy).

    Returns:
        Tuple ``(f1, average_precision, precision, recall, accuracy)`` of
        Python floats. F1 is reported as 0.0 when precision + recall is 0.
    """
    preds = preds.view((preds.size()[0], -1))
    targets = targets.view((targets.size()[0], -1))

    prec, rec = precision_recall(preds, targets, threshold=threshold)
    # With precision + recall == 0 the harmonic mean would be 0/0 = NaN and
    # poison every mean/median computed from it; report 0 instead.
    if (prec + rec).item() > 0:
        f1_score = (2 * (prec * rec) / (prec + rec)).item()
    else:
        f1_score = 0.0

    # Keep the metric objects on the device of the inputs instead of
    # hard-coding CUDA, so CPU tensors can be scored as well.
    average_precision = AveragePrecision().to(preds.device)
    AP_score = average_precision(preds, targets)
    accuracy = Accuracy(threshold=threshold).to(preds.device)
    acc = accuracy(preds, targets)
    return f1_score, AP_score.item(), prec.item(), rec.item(), acc.detach().item()

### Finds the right model files, evaluates them, computes the scores and dumps the results

In [None]:
import json


def _config_lines(config_file):
    """Return the lines of a config log, closing the file afterwards."""
    with open(config_file, 'r') as handle:
        return handle.read().splitlines()


# Stage 1: run directories (anywhere below `path`) whose config log has a
# 'loss:' line mentioning the selected loss. A directory is appended once
# per matching line.
filepaths1 = []
for filepath in path.rglob("*config_tree.log"):
    for line in _config_lines(filepath):
        if 'loss:' in line and " " + loss in str(line):
            filepaths1.append(filepath.parent)

print(len(filepaths1))

# Stage 2: of those, keep the runs whose 'pretrain_path:' line points to the
# selected pretrained checkpoint.
filepaths2 = []
for pp in filepaths1:
    for filepath in pp.glob("config_tree.log"):
        for line in _config_lines(filepath):
            if 'pretrain_path:' in line and pretrain_path in line:
                filepaths2.append(filepath.parent)

import re


def _latest_checkpoints(run_checkpoints, run_timestamps, n_runs):
    """Return the checkpoints of the ``n_runs`` most recent runs.

    Args:
        run_checkpoints: One list of checkpoint paths per run.
        run_timestamps: One timestamp per run, aligned with ``run_checkpoints``.
        n_runs: Number of runs to keep, counted from the latest timestamp.

    Returns:
        1-D object array holding the checkpoint paths of the selected runs,
        ordered by ascending run timestamp.
    """
    newest = np.argsort(np.asarray(run_timestamps))[-n_runs:]
    selected = [ckpt for idx in newest for ckpt in run_checkpoints[idx]]
    return np.asarray(selected, dtype=object)


# For every train size: pick the matching runs, evaluate their checkpoints on
# the test set, score the averaged predictions and pickle the results.
for trainsize in ['2','5','10','20','-1']:
    model_dirs_augmented = []
    model_dirs = []
    timestamps_augmented = []
    timestamps = []

    # Match the train size only as a complete number, so that e.g. '2' does
    # not also select runs configured with a train size of 20.
    trainsize_pattern = re.compile(r'(?<![\d.])' + re.escape(trainsize) + r'(?!\d)')

    filepaths3 = []

    print(len(filepaths2))
    for pp in filepaths2:
        for filepath in pp.rglob("*config_tree.log"):
            with open(filepath, 'r') as config_file:
                lines = config_file.read().splitlines()
            for line in lines:
                if 'trainsize' in line and trainsize_pattern.search(line):
                    filepaths3.append(filepath.parent)

    print(len(filepaths3))

    # Diagnostic counters; they are not used for the scores below.
    count1, count2 = 0, 0
    gss = []

    # Split the runs into those trained with and without the pretrained
    # encoder and remember when each run finished (wandb '_timestamp').
    for pp in filepaths3:
        count1 += 1
        for filepath in pp.glob("*config_tree.log"):
            with open(filepath, 'r') as config_file:
                lines = config_file.read().splitlines()
            run_dir = filepath.parent
            ckpt_dir = run_dir/'checkpoints/'
            summary_path = run_dir/'wandb/latest-run/files/wandb-summary.json'
            for line in lines:
                if '- hokkaido' in line:
                    count2 += 1
                if 'pre_train_augmented' not in line:
                    continue
                checkpoints = list(ckpt_dir.glob("*.ckpt"))
                # Only runs with more than three checkpoints are used.
                if len(checkpoints) <= 3:
                    continue
                # The timestamp is recorded only together with the
                # checkpoints, so both lists stay aligned for the
                # argsort-based selection below.
                with open(summary_path) as user_file:
                    timestamp = json.load(user_file)['_timestamp']
                if 'true' in line:
                    model_dirs_augmented.append(checkpoints)
                    timestamps_augmented.append(timestamp)
                else:
                    model_dirs.append(checkpoints)
                    timestamps.append(timestamp)
    print(len(model_dirs_augmented), len(model_dirs))

    # Keep the five most recent runs for train size '2', otherwise three.
    n_latest = 5 if trainsize == '2' else 3
    model_dirs_augmented = _latest_checkpoints(model_dirs_augmented, timestamps_augmented, n_latest)
    model_dirs = _latest_checkpoints(model_dirs, timestamps, n_latest)
    print(len(model_dirs_augmented), len(model_dirs))

    datadict = {'data_dir': '/home/jupyter/deepslide/data/',
        'dict_dir': '/home/jupyter/deepslide/data/',
        'batch_size': 32,
        'num_workers': 8,
        'pin_memory': False,
        'input_channels': ['vh', 'vv'], #, 'los.rdr_0', 'los.rdr_1', 'topophase.cor_1', 'topophase.flat_imag', 'topophase.flat_real', 'dem']
        'input_transforms': ['Log_transform','Standardize'],
        'num_time_steps': 1,
        'setting': 'downstream',
        'datasets': [dataset2]}

    dloader = Siamese_Landslide_Datamodule(**datadict)
    test_loader = dloader.test_dataloader()

    mean_predictions = {}
    targets = {}
    metrics = {}
    # Number of accumulated forward passes per model group; used to average
    # the summed logits of that same group.
    counts = {}
    for name, pre_train_augmented, paths in zip(['pretrain','no_pretrain'],[True, False],[model_dirs_augmented, model_dirs]):

        metrics[name] = {}
        mean_predictions[name] = {}
        # Accumulators for the summed raw model outputs.
        # NOTE(review): the hard-coded sizes presumably mean the test loader
        # yields one batch of 56 samples (32 with and 24 without landslides)
        # - confirm against the datamodule.
        mean_predictions[name]['all'] = np.zeros((56,128,128))
        mean_predictions[name]['nl'] = np.zeros((24,128,128))
        mean_predictions[name]['l'] = np.zeros((32,128,128))
        metrics[name]['APRC'] = []

        downstream_model_dict = {
            'input_size': [4,128,128],
            'embedding_size': 64,
            'pre_train_augmented': pre_train_augmented,
            'pretrain_path': pretrain_path,
            'unet': unet,
            'base_lr': 0.001,
            'pretrain_params': {'input_size':[2,128,128],'embedding_size':32,'decoder_depth':1,'encoder_depth':1,'unet':True,'cnn':True,
                                'base_lr':0.001,'decoder_channels':[32]},
            'encoder_depth': 1,
            'decoder_channels': [32],
            'loss': loss}

        count = 0
        model = Segmentation_Model(**downstream_model_dict)
        # The loop variable is deliberately not called `path`: that name
        # holds the experiment directory used by the run-discovery cell.
        for ckpt_path in paths:
            if name == 'pretrain':
                # Kept as a notebook global; the weights are loaded in place.
                save_model = model
            # Only the checkpoints with 'ap' in their file name are scored.
            if ckpt_path.name == 'last.ckpt' or 'ap' not in ckpt_path.name:
                continue
            checkpoint = torch.load(ckpt_path)['state_dict']
            model.load_state_dict(checkpoint)
            model.eval()
            model.cuda()
            for pre, post, label, names, weight in test_loader:
                with torch.no_grad():
                    count += 1
                    preds = model.forward(pre.cuda(), post.cuda())
                    label = label.view(preds.size()).cuda()
                    mean_predictions[name]['all'] += np.squeeze(preds.detach().cpu().numpy())
                    targets['all'] = label.detach().cpu().numpy()
                    # Samples that contain at least one landslide pixel.
                    landslide = torch.sum(label, axis=(1,2,3)) > 0

                    preds_nl = preds[~landslide]
                    targets['nl'] = label[~landslide].detach().cpu().numpy()
                    mean_predictions[name]['nl'] += np.squeeze(preds_nl.detach().cpu().numpy())

                    preds_l = preds[landslide]
                    targets['l'] = label[landslide].detach().cpu().numpy()
                    mean_predictions[name]['l'] += np.squeeze(preds_l.detach().cpu().numpy())

                    # Per-checkpoint average precision on probabilities.
                    preds = torch.sigmoid(preds)
                    res = compute_metrics(preds, label.cuda(), threshold=0.5)

                    metrics[name]['APRC'].append(res[1])
        counts[name] = count

    results = {}
    results[dataset] = {}

    for name in ['pretrain','no_pretrain']:
        results[dataset][name] = {}
        for subfix in ['all','nl','l']:
            scores = results[dataset][name]
            # Ensemble prediction: sigmoid of the logits averaged over the
            # forward passes of this model group.
            mean_prob = scipy.special.expit(mean_predictions[name][subfix]/counts[name])

            aprc_per_sample = []
            f1_per_sample = []
            iou_per_sample = []
            for ii in range(len(mean_predictions[name][subfix])):
                preds = mean_prob[ii:ii+1]
                res = compute_metrics(torch.tensor(preds).cuda(), torch.tensor(targets[subfix]).cuda()[ii], threshold=0.5)
                aprc_per_sample.append(res[1])
                f1_per_sample.append(res[0])
                intersection = np.logical_and(preds>0.5, targets[subfix][ii])
                union = np.logical_or(preds>0.5, targets[subfix][ii])
                iou_per_sample.append(np.sum(intersection) / np.sum(union))

            scores['APRC_'+subfix] = np.asarray(aprc_per_sample)
            scores['f1_'+subfix] = np.asarray(f1_per_sample)
            scores['iou_'+subfix] = np.asarray(iou_per_sample)

            # NOTE(review): the APRC summary statistics use the
            # per-checkpoint scores in metrics[name]['APRC'] for every
            # subset, not the per-sample 'APRC_<subset>' array - confirm
            # this is intended.
            scores['APRC_'+subfix+'_mean'] = np.mean(metrics[name]['APRC'])
            scores['f1_'+subfix+'_mean'] = np.mean(scores['f1_'+subfix])
            scores['iou_'+subfix+'_mean'] = np.mean(scores['iou_'+subfix])

            scores['APRC_'+subfix+'_median'] = np.median(metrics[name]['APRC'])
            scores['f1_'+subfix+'_median'] = np.median(scores['f1_'+subfix])
            scores['iou_'+subfix+'_median'] = np.median(scores['iou_'+subfix])

            # NOTE(review): the '_std' entries are std / n, which is neither
            # the standard deviation nor the standard error (std / sqrt(n)).
            scores['APRC_'+subfix+'_std'] = np.std(metrics[name]['APRC'])/len(metrics[name]['APRC'])
            scores['f1_'+subfix+'_std'] = np.std(scores['f1_'+subfix])/len(scores['f1_'+subfix])
            scores['iou_'+subfix+'_std'] = np.std(scores['iou_'+subfix])/len(scores['iou_'+subfix])

            # From here on mean_predictions holds probabilities, not logits.
            mean_predictions[name][subfix] = np.squeeze(mean_prob)
            # Per-sample difference between the summed predicted
            # probabilities and the summed target pixels.
            predicted_area = np.sum(mean_predictions[name][subfix], axis=(1,2))
            target_area = np.sum(targets[subfix], axis=(1,2,3))
            signed_error = predicted_area - target_area
            abs_error = np.abs(signed_error)

            scores['l1_'+subfix+'_mean'] = np.mean(abs_error)
            scores['sl1_'+subfix+'_mean'] = np.mean(signed_error)

            scores['l1_'+subfix+'_median'] = np.median(abs_error)
            scores['sl1_'+subfix+'_median'] = np.median(signed_error)

            scores['l1_'+subfix+'_std'] = np.std(abs_error)/len(targets[subfix])
            scores['sl1_'+subfix+'_std'] = np.std(signed_error)/len(targets[subfix])

    # Write the scores and the averaged predictions for this train size.
    with open('scores_{}_{}.pkl'.format(experiment_name, trainsize), 'wb') as scores_file:
        pickle.dump(results, scores_file)

    with open('label_predictions_{}_{}.pkl'.format(experiment_name, trainsize), 'wb') as predictions_file:
        pickle.dump(mean_predictions, predictions_file)