## This notebook runs feature-ablation experiments: predictions using all but one feature (leave-one-out), then predictions using only a single feature

In [None]:
from Data.Drosophilla.FlyDataMod import FlyDataModule
from Models.Transformer import PositionalEncoding, TransformerModule
from pytorch_lightning.callbacks import EarlyStopping
from IPython.core.debugger import set_trace
import pytorch_lightning as pl
import os
import matplotlib.pyplot as plt

In [None]:
# Leave-one-out ablation, "gamma" label: train one transformer per excluded
# feature index. exclude_feature runs over 0..28 and every model is built with
# ninp=28 (presumably 29 input features minus the excluded one).
cell_line       = "S2"
data_win_radius = 5
batch_size      = 4
label_type      = "gamma"

rootdir = "Experiments/Transformer_Exclude_Features"
# makedirs (unlike os.mkdir) also creates a missing "Experiments/" parent;
# exist_ok keeps the cell re-runnable.
os.makedirs(rootdir, exist_ok=True)

# NOTE: runs are launched in reverse feature order (28 -> 0). Lightning names
# each run's log directory version_0, version_1, ... in launch order, so with
# an initially empty log directory feature f ends up in version_(28 - f).
for exclude_feature in reversed(range(29)):
    # Fresh callback per run: EarlyStopping carries its own best/wait state.
    early_stop_callback = EarlyStopping(
        monitor="val weighted mse loss",
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min')

    dm = FlyDataModule(cell_line=cell_line,
                       data_win_radius=data_win_radius,
                       batch_size=batch_size,
                       label_type=label_type,
                       exclude_feature=exclude_feature)
    dm.setup()

    # Run configuration handed to the model (presumably stored with its checkpoint).
    hparams = {'cell_line': cell_line,
               'data_win_radius': data_win_radius,
               'label_type': label_type,
               "batch_size": batch_size,
               "exclude_feature": exclude_feature}

    model = TransformerModule(
        ntoken=1,
        ninp=28,
        nhead=7,
        nhid=2048,
        nlayers=1,
        dropout=0,
        optimi="Adam",
        lr=0.001,
        hparams=hparams)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=50,
        default_root_dir=rootdir,
        callbacks=[early_stop_callback])

    trainer.fit(model, dm)

In [None]:
# Leave-one-out ablation, "insulation" label (label_val=3): train one
# transformer per excluded feature index (0..28), each built with ninp=28.
cell_line       = "S2"
data_win_radius = 5
batch_size      = 4
label_type      = "insulation"
label_val       = 3

rootdir = "Experiments/Transformer_Exclude_Features_Insulation"
# makedirs (unlike os.mkdir) also creates a missing "Experiments/" parent;
# exist_ok keeps the cell re-runnable.
os.makedirs(rootdir, exist_ok=True)

# NOTE: runs are launched in reverse feature order (28 -> 0). Lightning names
# each run's log directory version_0, version_1, ... in launch order, so with
# an initially empty log directory feature f ends up in version_(28 - f).
for exclude_feature in reversed(range(29)):
    # Fresh callback per run: EarlyStopping carries its own best/wait state.
    early_stop_callback = EarlyStopping(
        monitor="val weighted mse loss",
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min')

    dm = FlyDataModule(cell_line=cell_line,
                       data_win_radius=data_win_radius,
                       batch_size=batch_size,
                       label_type=label_type,
                       label_val=label_val,
                       exclude_feature=exclude_feature)
    dm.setup()

    # Run configuration handed to the model (presumably stored with its checkpoint).
    hparams = {'cell_line': cell_line,
               'data_win_radius': data_win_radius,
               'label_type': label_type,
               'label_val': label_val,
               "batch_size": batch_size,
               "exclude_feature": exclude_feature}

    model = TransformerModule(
        ntoken=1,
        ninp=28,
        nhead=7,
        nhid=2048,
        nlayers=3,
        dropout=0,
        optimi="Adam",
        lr=0.00001,
        hparams=hparams)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=50,
        default_root_dir=rootdir,
        callbacks=[early_stop_callback])

    trainer.fit(model, dm)

In [None]:
# Leave-one-out ablation, "directionality" label (label_val=10): train one
# transformer per excluded feature index (0..28), each built with ninp=28.
cell_line       = "S2"
data_win_radius = 5
batch_size      = 1
label_type      = "directionality"
label_val       = 10

rootdir = "Experiments/Transformer_Exclude_Features_Direction"
# makedirs (unlike os.mkdir) also creates a missing "Experiments/" parent;
# exist_ok keeps the cell re-runnable.
os.makedirs(rootdir, exist_ok=True)

# NOTE: runs are launched in reverse feature order (28 -> 0). Lightning names
# each run's log directory version_0, version_1, ... in launch order, so with
# an initially empty log directory feature f ends up in version_(28 - f).
for exclude_feature in reversed(range(29)):
    # Fresh callback per run: EarlyStopping carries its own best/wait state.
    early_stop_callback = EarlyStopping(
        monitor="val weighted mse loss",
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min')

    dm = FlyDataModule(cell_line=cell_line,
                       data_win_radius=data_win_radius,
                       batch_size=batch_size,
                       label_type=label_type,
                       label_val=label_val,
                       exclude_feature=exclude_feature)
    dm.setup()

    # Run configuration handed to the model (presumably stored with its checkpoint).
    hparams = {'cell_line': cell_line,
               'data_win_radius': data_win_radius,
               'label_type': label_type,
               'label_val': label_val,
               "batch_size": batch_size,
               "exclude_feature": exclude_feature}

    model = TransformerModule(
        ntoken=1,
        ninp=28,
        nhead=7,
        nhid=2048,
        nlayers=1,
        dropout=0,
        optimi="Adam",
        lr=0.001,
        hparams=hparams)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=50,
        default_root_dir=rootdir,
        callbacks=[early_stop_callback])

    trainer.fit(model, dm)

In [None]:
# Single-feature ("solo") ablation, "gamma" label: train one transformer per
# feature index (0..28), each built with ninp=1.
label_types     = ['gamma']
label_val       = 10
cell_line       = "S2"
data_win_radius = 5
# One value feeds both the data module and the recorded hparams so they agree;
# this cell's data module has always been built with batch_size=64.
batch_size      = 64

for label_type in label_types:
    rootdir = "Experiments/Transformer_Solo_Features_" + str(label_type)
    # makedirs (unlike os.mkdir) also creates a missing "Experiments/" parent;
    # exist_ok keeps the cell re-runnable.
    os.makedirs(rootdir, exist_ok=True)

    # NOTE: runs are launched in reverse feature order (28 -> 0); with an
    # initially empty log directory Lightning stores feature f's run as
    # version_(28 - f).
    for solo_feature in reversed(range(29)):
        # Fresh callback per run: EarlyStopping carries its own best/wait state.
        early_stop_callback = EarlyStopping(
            monitor="val weighted mse loss",
            min_delta=0.00,
            patience=3,
            verbose=False,
            mode='min')

        dm = FlyDataModule(cell_line=cell_line,
                           data_win_radius=data_win_radius,
                           batch_size=batch_size,
                           label_type=label_type,
                           label_val=label_val,
                           solo_feature=solo_feature)
        dm.setup()

        # Run configuration handed to the model (presumably stored with its checkpoint).
        hparams = {'cell_line': cell_line,
                   'data_win_radius': data_win_radius,
                   'label_type': label_type,
                   'label_val': label_val,
                   "batch_size": batch_size,
                   "solo_feature": solo_feature}

        model = TransformerModule(
            ntoken=1,
            ninp=1,
            nhead=7,
            nhid=2048,
            nlayers=5,
            dropout=0,
            optimi="Adam",
            lr=0.01,
            hparams=hparams)

        trainer = pl.Trainer(
            gpus=1,
            max_epochs=50,
            default_root_dir=rootdir,
            callbacks=[early_stop_callback])

        trainer.fit(model, dm)

In [None]:
# Single-feature ("solo") ablation, "insulation" label (label_val=3): train one
# transformer per feature index (0..28), each built with ninp=1.
label_types     = ['insulation']
label_val       = 3
cell_line       = "S2"
data_win_radius = 5
# One value feeds both the data module and the recorded hparams so they agree;
# this cell's data module has always been built with batch_size=64.
batch_size      = 64

for label_type in label_types:
    rootdir = "Experiments/Transformer_Solo_Features_" + str(label_type)
    # makedirs (unlike os.mkdir) also creates a missing "Experiments/" parent;
    # exist_ok keeps the cell re-runnable.
    os.makedirs(rootdir, exist_ok=True)

    # NOTE: runs are launched in reverse feature order (28 -> 0); with an
    # initially empty log directory Lightning stores feature f's run as
    # version_(28 - f).
    for solo_feature in reversed(range(29)):
        # Fresh callback per run: EarlyStopping carries its own best/wait state.
        early_stop_callback = EarlyStopping(
            monitor="val weighted mse loss",
            min_delta=0.00,
            patience=3,
            verbose=False,
            mode='min')

        dm = FlyDataModule(cell_line=cell_line,
                           data_win_radius=data_win_radius,
                           batch_size=batch_size,
                           label_type=label_type,
                           label_val=label_val,
                           solo_feature=solo_feature)
        dm.setup()

        # Run configuration handed to the model (presumably stored with its checkpoint).
        hparams = {'cell_line': cell_line,
                   'data_win_radius': data_win_radius,
                   'label_type': label_type,
                   'label_val': label_val,
                   "batch_size": batch_size,
                   "solo_feature": solo_feature}

        model = TransformerModule(
            ntoken=1,
            ninp=1,
            nhead=7,
            nhid=2048,
            nlayers=5,
            dropout=0,
            optimi="Adam",
            lr=0.01,
            hparams=hparams)

        trainer = pl.Trainer(
            gpus=1,
            max_epochs=50,
            default_root_dir=rootdir,
            callbacks=[early_stop_callback])

        trainer.fit(model, dm)

In [None]:
# Single-feature ("solo") ablation, "directionality" label (label_val=10): train
# one transformer per feature index (0..28), each built with ninp=1.
label_types     = ['directionality']
label_val       = 10
cell_line       = "S2"
data_win_radius = 5
# One value feeds both the data module and the recorded hparams so they agree;
# this cell's data module has always been built with batch_size=64.
batch_size      = 64

for label_type in label_types:
    rootdir = "Experiments/Transformer_Solo_Features_" + str(label_type)
    # makedirs (unlike os.mkdir) also creates a missing "Experiments/" parent;
    # exist_ok keeps the cell re-runnable.
    os.makedirs(rootdir, exist_ok=True)

    # NOTE: runs are launched in reverse feature order (28 -> 0); with an
    # initially empty log directory Lightning stores feature f's run as
    # version_(28 - f).
    for solo_feature in reversed(range(29)):
        # Fresh callback per run: EarlyStopping carries its own best/wait state.
        early_stop_callback = EarlyStopping(
            monitor="val weighted mse loss",
            min_delta=0.00,
            patience=3,
            verbose=False,
            mode='min')

        dm = FlyDataModule(cell_line=cell_line,
                           data_win_radius=data_win_radius,
                           batch_size=batch_size,
                           label_type=label_type,
                           label_val=label_val,
                           solo_feature=solo_feature)
        dm.setup()

        # Run configuration handed to the model (presumably stored with its checkpoint).
        hparams = {'cell_line': cell_line,
                   'data_win_radius': data_win_radius,
                   'label_type': label_type,
                   'label_val': label_val,
                   "batch_size": batch_size,
                   "solo_feature": solo_feature}

        model = TransformerModule(
            ntoken=1,
            ninp=1,
            nhead=7,
            nhid=2048,
            nlayers=5,
            dropout=0,
            optimi="Adam",
            lr=0.01,
            hparams=hparams)

        trainer = pl.Trainer(
            gpus=1,
            max_epochs=50,
            default_root_dir=rootdir,
            callbacks=[early_stop_callback])

        trainer.fit(model, dm)

## Helper functions for evaluation

In [None]:
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr
from scipy.stats import spearmanr


def getModelPredictions(model,
                   dm,
                   tvt):
    if tvt=="test":
        dataloader = dm.test_dataloader()
    if tvt=="train":
        dataloader = dm.train_dataloader()
    if tvt=="val":
        dataloader =dm.val_dataloader()
        
    full_label_vec  = []
    full_output_vec = []
    for b, batch in enumerate(dataloader):
        feature, label = batch
        feature = feature.to('cuda:0').float()
        label   = label.to('cuda:0').float()
        output  = model(feature)
        label   = label.squeeze()
        output  = output.squeeze()
        full_label_vec.append(label[int(len(label)/2)].item())
        full_output_vec.append(output[int(len(output)/2)].item())

    return full_label_vec, full_output_vec

def getModelMetrics(model,
                   dm,
                   tvt):
    """Score ``model`` on split ``tvt`` of ``dm`` with five regression metrics.

    Predictions come from ``getModelPredictions``. The returned dict maps
    ``'mse'``, ``'mae'``, ``'r2'``, ``'pearson'`` and ``'spearman'`` (in that
    order) to scores computed from (truth, prediction). For Pearson and
    Spearman only the correlation coefficient is kept, not the p-value.
    """
    truth, predicted = getModelPredictions(model, dm, tvt)
    scorers = (
        ('mse',      mean_squared_error),
        ('mae',      mean_absolute_error),
        ('r2',       r2_score),
        ('pearson',  lambda a, b: pearsonr(a, b)[0]),
        ('spearman', lambda a, b: spearmanr(a, b)[0]),
    )
    return {name: score(truth, predicted) for name, score in scorers}

In [None]:
# Single-feature ("solo") gamma models: evaluate each on the test split, then
# for every metric plot the features ranked by their score.
import glob
import pandas as pd

metrics    = ['mse','mae','r2','pearson','spearman']
n_features = 29
rootdir    = "Experiments/Transformer_Solo_Features_gamma"
# Training launched the runs in reverse feature order (28 -> 0) and Lightning
# numbers runs version_0, version_1, ... in launch order, so the checkpoint for
# feature f lives in version_(28 - f), not version_f. Assumes the log directory
# was empty before training -- TODO confirm against each run's saved hparams.
training_order = list(reversed(range(n_features)))

results = {metric: [] for metric in metrics}
for solo_feature in range(n_features):
    dm = FlyDataModule(cell_line="S2",
                       data_win_radius=5,
                       batch_size=1,
                       label_type='gamma',
                       label_val=10,
                       solo_feature=solo_feature)
    dm.setup()
    version = training_order.index(solo_feature)
    checkpoints = sorted(glob.glob(
        f"{rootdir}/lightning_logs/version_{version}/checkpoints/*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no checkpoint for feature {solo_feature} under "
            f"{rootdir}/lightning_logs/version_{version}/checkpoints")
    layer_weights = checkpoints[0]
    model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    met_ret = getModelMetrics(model, dm, 'test')
    for metric in metrics:
        results[metric].append(met_ret[metric])

exf = pd.read_csv("Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv")
# Feature names: columns from index 6 on, matched to feature indices by position.
excluded_features = exf.columns[6:]
for metric in metrics:
    results_sorted           = sorted(results[metric])
    # Reorder the names by the same scores so each tick label stays with its point.
    excluded_features_sorted = [x for _, x in sorted(zip(results[metric], excluded_features),
                                                     key=lambda pair: pair[0])]
    fig, ax = plt.subplots(1, figsize=(18,4))
    ax.set_xticks(list(range(len(excluded_features_sorted))))
    ax.set_xticklabels(excluded_features_sorted, rotation=45, fontsize=14)
    ax.set_title(metric)
    ax.scatter(list(range(len(results_sorted))), results_sorted)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.show()

In [None]:
# Leave-one-out gamma models: evaluate each on the test split, then for every
# metric plot the excluded features ranked by their score.
import glob
import pandas as pd

metrics    = ['mse','mae','r2','pearson','spearman']
n_features = 29
rootdir    = "Experiments/Transformer_Exclude_Features"
# Training launched the runs in reverse feature order (28 -> 0) and Lightning
# numbers runs version_0, version_1, ... in launch order, so the checkpoint for
# feature f lives in version_(28 - f), not version_f. Assumes the log directory
# was empty before training -- TODO confirm against each run's saved hparams.
training_order = list(reversed(range(n_features)))

results = {metric: [] for metric in metrics}
for exclude_feature in range(n_features):
    dm = FlyDataModule(cell_line="S2",
                       data_win_radius=5,
                       batch_size=1,
                       label_type='gamma',
                       label_val=10,
                       exclude_feature=exclude_feature)
    dm.setup()
    version = training_order.index(exclude_feature)
    checkpoints = sorted(glob.glob(
        f"{rootdir}/lightning_logs/version_{version}/checkpoints/*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no checkpoint for feature {exclude_feature} under "
            f"{rootdir}/lightning_logs/version_{version}/checkpoints")
    layer_weights = checkpoints[0]
    model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    met_ret = getModelMetrics(model, dm, 'test')
    for metric in metrics:
        results[metric].append(met_ret[metric])

exf = pd.read_csv("Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv")
# Feature names: columns from index 6 on, matched to feature indices by position.
excluded_features = exf.columns[6:]
for metric in metrics:
    results_sorted           = sorted(results[metric])
    # Reorder the names by the same scores so each tick label stays with its point.
    excluded_features_sorted = [x for _, x in sorted(zip(results[metric], excluded_features),
                                                     key=lambda pair: pair[0])]
    fig, ax = plt.subplots(1, figsize=(18,4))
    ax.set_xticks(list(range(len(excluded_features_sorted))))
    ax.set_xticklabels(excluded_features_sorted, rotation=45, fontsize=14)
    ax.set_title(metric)
    ax.scatter(list(range(len(results_sorted))), results_sorted)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.show()

In [None]:
# Leave-one-out insulation models: evaluate each on the test split, then for
# every metric plot the excluded features ranked by their score.
import glob
import pandas as pd

metrics    = ['mse','mae','r2','pearson','spearman']
n_features = 29
rootdir    = "Experiments/Transformer_Exclude_Features_Insulation"
# Training launched the runs in reverse feature order (28 -> 0) and Lightning
# numbers runs version_0, version_1, ... in launch order, so the checkpoint for
# feature f lives in version_(28 - f), not version_f. Assumes the log directory
# was empty before training -- TODO confirm against each run's saved hparams.
training_order = list(reversed(range(n_features)))

results = {metric: [] for metric in metrics}
for exclude_feature in range(n_features):
    dm = FlyDataModule(cell_line="S2",
                       data_win_radius=5,
                       batch_size=1,
                       label_type='insulation',
                       label_val=3,
                       exclude_feature=exclude_feature)
    dm.setup()
    version = training_order.index(exclude_feature)
    checkpoints = sorted(glob.glob(
        f"{rootdir}/lightning_logs/version_{version}/checkpoints/*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no checkpoint for feature {exclude_feature} under "
            f"{rootdir}/lightning_logs/version_{version}/checkpoints")
    layer_weights = checkpoints[0]
    model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    met_ret = getModelMetrics(model, dm, 'test')
    for metric in metrics:
        results[metric].append(met_ret[metric])

exf = pd.read_csv("Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv")
# Feature names: columns from index 6 on, matched to feature indices by position.
excluded_features = exf.columns[6:]
for metric in metrics:
    results_sorted           = sorted(results[metric])
    # Reorder the names by the same scores so each tick label stays with its point.
    excluded_features_sorted = [x for _, x in sorted(zip(results[metric], excluded_features),
                                                     key=lambda pair: pair[0])]
    fig, ax = plt.subplots(1, figsize=(18,4))
    ax.set_xticks(list(range(len(excluded_features_sorted))))
    ax.set_xticklabels(excluded_features_sorted, rotation=45, fontsize=14)
    ax.set_title(metric)
    ax.scatter(list(range(len(results_sorted))), results_sorted)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.show()

In [None]:
# Leave-one-out directionality models: evaluate each on the test split, then
# for every metric plot the excluded features ranked by their score.
import glob
import pandas as pd

metrics    = ['mse','mae','r2','pearson','spearman']
n_features = 29
rootdir    = "Experiments/Transformer_Exclude_Features_Direction"
# Training launched the runs in reverse feature order (28 -> 0) and Lightning
# numbers runs version_0, version_1, ... in launch order, so the checkpoint for
# feature f lives in version_(28 - f), not version_f. Assumes the log directory
# was empty before training -- TODO confirm against each run's saved hparams.
training_order = list(reversed(range(n_features)))

results = {metric: [] for metric in metrics}
for exclude_feature in range(n_features):
    dm = FlyDataModule(cell_line="S2",
                       data_win_radius=5,
                       batch_size=1,
                       label_type='directionality',
                       label_val=10,
                       exclude_feature=exclude_feature)
    dm.setup()
    version = training_order.index(exclude_feature)
    checkpoints = sorted(glob.glob(
        f"{rootdir}/lightning_logs/version_{version}/checkpoints/*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no checkpoint for feature {exclude_feature} under "
            f"{rootdir}/lightning_logs/version_{version}/checkpoints")
    layer_weights = checkpoints[0]
    model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    met_ret = getModelMetrics(model, dm, 'test')
    for metric in metrics:
        results[metric].append(met_ret[metric])

exf = pd.read_csv("Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv")
# Feature names: columns from index 6 on, matched to feature indices by position.
excluded_features = exf.columns[6:]
for metric in metrics:
    results_sorted           = sorted(results[metric])
    # Reorder the names by the same scores so each tick label stays with its point.
    excluded_features_sorted = [x for _, x in sorted(zip(results[metric], excluded_features),
                                                     key=lambda pair: pair[0])]
    fig, ax = plt.subplots(1, figsize=(18,4))
    ax.set_xticks(list(range(len(excluded_features_sorted))))
    ax.set_xticklabels(excluded_features_sorted, rotation=45, fontsize=14)
    ax.set_title(metric)
    ax.scatter(list(range(len(results_sorted))), results_sorted)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.show()

In [None]:
# Leave-one-out gamma models: evaluate each on the test split, then draw MSE,
# MAE, Pearson and Spearman on four separate y-axes, with the excluded
# features ordered by Spearman correlation.
import glob
import pandas as pd

metrics    = ['mse','mae','r2','pearson','spearman']
n_features = 29
rootdir    = "Experiments/Transformer_Exclude_Features"
# Training launched the runs in reverse feature order (28 -> 0) and Lightning
# numbers runs version_0, version_1, ... in launch order, so the checkpoint for
# feature f lives in version_(28 - f), not version_f. Assumes the log directory
# was empty before training -- TODO confirm against each run's saved hparams.
training_order = list(reversed(range(n_features)))

results = {metric: [] for metric in metrics}
for exclude_feature in range(n_features):
    dm = FlyDataModule(cell_line="S2",
                       data_win_radius=5,
                       batch_size=1,
                       label_type='gamma',
                       label_val=10,
                       exclude_feature=exclude_feature)
    dm.setup()
    version = training_order.index(exclude_feature)
    checkpoints = sorted(glob.glob(
        f"{rootdir}/lightning_logs/version_{version}/checkpoints/*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no checkpoint for feature {exclude_feature} under "
            f"{rootdir}/lightning_logs/version_{version}/checkpoints")
    layer_weights = checkpoints[0]
    model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    met_ret = getModelMetrics(model, dm, 'test')
    for metric in metrics:
        results[metric].append(met_ret[metric])

# Read the feature names here rather than relying on a variable left over
# from another cell.
exf = pd.read_csv("Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv")
excluded_features = exf.columns[6:]

fig, ax = plt.subplots(figsize=(20,5))
ax.set_title("Gamma")
twin1 = ax.twinx()
twin2 = ax.twinx()
twin3 = ax.twinx()
# Offset the extra axes so their ticks do not overlap: twin2 moves to the
# far left, twin3 to the far right.
twin2.spines['left'].set_position(("axes", -0.1))
twin2.yaxis.set_label_position('left')
twin2.yaxis.set_ticks_position('left')
twin3.spines['right'].set_position(("axes", 1.1))
for axis in (ax, twin1, twin2, twin3):
    axis.spines['top'].set_visible(False)

sort_metric = 'spearman'

def _sorted_by_metric(values):
    """Return ``values`` reordered by ascending ``results[sort_metric]``."""
    return [x for _, x in sorted(zip(results[sort_metric], values),
                                 key=lambda pair: pair[0])]

excluded_features_sorted = _sorted_by_metric(excluded_features)
sorted_mse = _sorted_by_metric(results['mse'])
sorted_mae = _sorted_by_metric(results['mae'])
sorted_pcc = _sorted_by_metric(results['pearson'])
sorted_spc = _sorted_by_metric(results['spearman'])

# Label each line with the metric it actually plots.
p0, = ax.plot(sorted_mse, label='mse', c='green')
p1, = twin1.plot(sorted_mae, label='mae', c='goldenrod')
p2, = twin2.plot(sorted_pcc, label='pearson', c='violet')
p3, = twin3.plot(sorted_spc, label='spearman', c='royalblue')
ax.set_xticks(list(range(len(excluded_features_sorted))))
ax.set_xticklabels(excluded_features_sorted, rotation=45, fontsize=14)

# Name each y-axis after its metric and colour its label and ticks like its line.
for axis, line in ((ax, p0), (twin1, p1), (twin2, p2), (twin3, p3)):
    axis.set_ylabel(line.get_label())
    axis.yaxis.label.set_color(line.get_color())
    axis.tick_params(axis='y', colors=line.get_color())

plt.show()

In [None]:
# Leave-one-out insulation models: evaluate each on the test split, then draw
# MSE, MAE, Pearson and Spearman on four separate y-axes, with the excluded
# features ordered by Spearman correlation.
import glob
import pandas as pd

metrics    = ['mse','mae','r2','pearson','spearman']
n_features = 29
rootdir    = "Experiments/Transformer_Exclude_Features_Insulation"
# Training launched the runs in reverse feature order (28 -> 0) and Lightning
# numbers runs version_0, version_1, ... in launch order, so the checkpoint for
# feature f lives in version_(28 - f), not version_f. Assumes the log directory
# was empty before training -- TODO confirm against each run's saved hparams.
training_order = list(reversed(range(n_features)))

results = {metric: [] for metric in metrics}
for exclude_feature in range(n_features):
    dm = FlyDataModule(cell_line="S2",
                       data_win_radius=5,
                       batch_size=1,
                       label_type='insulation',
                       label_val=3,
                       exclude_feature=exclude_feature)
    dm.setup()
    version = training_order.index(exclude_feature)
    checkpoints = sorted(glob.glob(
        f"{rootdir}/lightning_logs/version_{version}/checkpoints/*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no checkpoint for feature {exclude_feature} under "
            f"{rootdir}/lightning_logs/version_{version}/checkpoints")
    layer_weights = checkpoints[0]
    model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    met_ret = getModelMetrics(model, dm, 'test')
    for metric in metrics:
        results[metric].append(met_ret[metric])

# Read the feature names here rather than relying on a variable left over
# from another cell.
exf = pd.read_csv("Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv")
excluded_features = exf.columns[6:]

fig, ax = plt.subplots(figsize=(10,4))
ax.set_title("Insulation")
twin1 = ax.twinx()
twin2 = ax.twinx()
twin3 = ax.twinx()
# Offset the extra axes so their ticks do not overlap: twin2 moves to the
# far left, twin3 to the far right.
twin2.spines['left'].set_position(("axes", -0.1))
twin2.yaxis.set_label_position('left')
twin2.yaxis.set_ticks_position('left')
twin3.spines['right'].set_position(("axes", 1.1))
for axis in (ax, twin1, twin2, twin3):
    axis.spines['top'].set_visible(False)

sort_metric = 'spearman'

def _sorted_by_metric(values):
    """Return ``values`` reordered by ascending ``results[sort_metric]``."""
    return [x for _, x in sorted(zip(results[sort_metric], values),
                                 key=lambda pair: pair[0])]

excluded_features_sorted = _sorted_by_metric(excluded_features)
sorted_mse = _sorted_by_metric(results['mse'])
sorted_mae = _sorted_by_metric(results['mae'])
sorted_pcc = _sorted_by_metric(results['pearson'])
sorted_spc = _sorted_by_metric(results['spearman'])

# Label each line with the metric it actually plots.
p0, = ax.plot(sorted_mse, label='mse', c='green')
p1, = twin1.plot(sorted_mae, label='mae', c='goldenrod')
p2, = twin2.plot(sorted_pcc, label='pearson', c='violet')
p3, = twin3.plot(sorted_spc, label='spearman', c='royalblue')
ax.set_xticks(list(range(len(excluded_features_sorted))))
ax.set_xticklabels(excluded_features_sorted, rotation=45, fontsize=14)

# Name each y-axis after its metric and colour its label and ticks like its line.
for axis, line in ((ax, p0), (twin1, p1), (twin2, p2), (twin3, p3)):
    axis.set_ylabel(line.get_label())
    axis.yaxis.label.set_color(line.get_color())
    axis.tick_params(axis='y', colors=line.get_color())

plt.show()


In [None]:
# Leave-one-out directionality models: evaluate each on the test split, then
# draw MSE, MAE, Pearson and Spearman on four separate y-axes, with the
# excluded features ordered by Spearman correlation.
import glob
import pandas as pd

metrics    = ['mse','mae','r2','pearson','spearman']
n_features = 29
rootdir    = "Experiments/Transformer_Exclude_Features_Direction"
# Training launched the runs in reverse feature order (28 -> 0) and Lightning
# numbers runs version_0, version_1, ... in launch order, so the checkpoint for
# feature f lives in version_(28 - f), not version_f. Assumes the log directory
# was empty before training -- TODO confirm against each run's saved hparams.
training_order = list(reversed(range(n_features)))

results = {metric: [] for metric in metrics}
for exclude_feature in range(n_features):
    # These checkpoints were trained on directionality labels (label_val=10),
    # so they must be scored against the same labels, not insulation ones.
    dm = FlyDataModule(cell_line="S2",
                       data_win_radius=5,
                       batch_size=1,
                       label_type='directionality',
                       label_val=10,
                       exclude_feature=exclude_feature)
    dm.setup()
    version = training_order.index(exclude_feature)
    checkpoints = sorted(glob.glob(
        f"{rootdir}/lightning_logs/version_{version}/checkpoints/*"))
    if not checkpoints:
        raise FileNotFoundError(
            f"no checkpoint for feature {exclude_feature} under "
            f"{rootdir}/lightning_logs/version_{version}/checkpoints")
    layer_weights = checkpoints[0]
    model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    met_ret = getModelMetrics(model, dm, 'test')
    for metric in metrics:
        results[metric].append(met_ret[metric])

# Read the feature names here rather than relying on a variable left over
# from another cell.
exf = pd.read_csv("Data/Drosophilla/s2_kc_bg_scaled_18_features_2901.csv")
excluded_features = exf.columns[6:]

fig, ax = plt.subplots(figsize=(10,4))
ax.set_title("Direction")
twin1 = ax.twinx()
twin2 = ax.twinx()
twin3 = ax.twinx()
# Offset the extra axes so their ticks do not overlap: twin2 moves to the
# far left, twin3 to the far right.
twin2.spines['left'].set_position(("axes", -0.1))
twin2.yaxis.set_label_position('left')
twin2.yaxis.set_ticks_position('left')
twin3.spines['right'].set_position(("axes", 1.1))
for axis in (ax, twin1, twin2, twin3):
    axis.spines['top'].set_visible(False)

sort_metric = 'spearman'

def _sorted_by_metric(values):
    """Return ``values`` reordered by ascending ``results[sort_metric]``."""
    return [x for _, x in sorted(zip(results[sort_metric], values),
                                 key=lambda pair: pair[0])]

excluded_features_sorted = _sorted_by_metric(excluded_features)
sorted_mse = _sorted_by_metric(results['mse'])
sorted_mae = _sorted_by_metric(results['mae'])
sorted_pcc = _sorted_by_metric(results['pearson'])
sorted_spc = _sorted_by_metric(results['spearman'])

# Label each line with the metric it actually plots.
p0, = ax.plot(sorted_mse, label='mse', c='green')
p1, = twin1.plot(sorted_mae, label='mae', c='goldenrod')
p2, = twin2.plot(sorted_pcc, label='pearson', c='violet')
p3, = twin3.plot(sorted_spc, label='spearman', c='royalblue')
ax.set_xticks(list(range(len(excluded_features_sorted))))
ax.set_xticklabels(excluded_features_sorted, rotation=45, fontsize=14)

# Name each y-axis after its metric and colour its label and ticks like its line.
for axis, line in ((ax, p0), (twin1, p1), (twin2, p2), (twin3, p3)):
    axis.set_ylabel(line.get_label())
    axis.yaxis.label.set_color(line.get_color())
    axis.tick_params(axis='y', colors=line.get_color())

plt.show()
