In [None]:
# Load IPython's autoreload extension and set it to mode 2, so imported
# modules are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2
# Show which host this kernel runs on (echoes the shell's HOSTNAME variable).
!echo $HOSTNAME

import sys
# Report the interpreter backing this kernel.
print('Python path: ', sys.executable)

In [None]:
from pathlib import Path
from collections import namedtuple
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sys
import pickle

## Gathering few-shot results

In [None]:
# TCRP few-shot test correlations, stored as the first unnamed array ('arr_0')
# of the archive. Further down it is plotted against 11 x positions, so it is
# presumably (n_tasks, 11) -- TODO confirm.
# np.load on an .npz returns an archive object that keeps the file open; the
# `with` block closes it once the array has been read into memory.
with np.load("tcrp_fewshot-test-correlations.npz") as fewshot_archive:
    fewshot_performance = fewshot_archive['arr_0']

## Gathering baselines

In [None]:
# Root directory of the baseline results. The loading loop in this notebook reads
# <datapath>/<outer directory>/<inner directory>/baseline_performance.npz,
# treating the outer level as the drug and the inner level as the tissue.
datapath = Path("../../output/210803_drug-baseline-models/baseline_performances")

In [None]:
%%time

results = {}
for outer_directory in datapath.glob("*"): 
    drug = outer_directory.stem
    results[drug] = {}
    
    for inner_directory in outer_directory.glob("*"): 
        tissue = inner_directory.stem
        results[drug][tissue] = {}
        
        data = np.load(inner_directory / "baseline_performance.npz")
        
        for model in ['linear', 'KNN', 'RF']: 
            zero = data[f"{model}-zero"]
            zero = np.vstack([zero for _ in range(20)]) # There is only 1 possible zero-shot, so expanding for all trials
            performance = np.median(np.hstack([zero, data[f"{model}-fewshot"]]), axis=0)
            
            results[drug][tissue][model] = performance    

In [None]:
# Regroup the nested results[drug][tissue][model] mapping by model: every
# (drug, tissue) pair contributes one row to its model's stacked array.
baseline_models = ('linear', 'KNN', 'RF')
collected = {name: [] for name in baseline_models}

for per_tissue in results.values():
    for per_model in per_tissue.values():
        for name, performance in per_model.items():
            collected[name].append(performance)

# One 2-D array per model, rows in the iteration order of `results`.
results_by_baseline = {name: np.vstack(rows) for name, rows in collected.items()}

## All performance

In [None]:
def get_statistics(data, n_resamples=1000, confidence=95.0, rng=None):
    """Column-wise median of `data` with a bootstrap confidence interval.

    Rows are resampled with replacement `n_resamples` times; the interval is
    taken from the percentiles of the resampled column medians. NaNs are
    ignored throughout (nanmedian / nanpercentile).

    Parameters
    ----------
    data : array-like, shape (n_rows, n_columns)
        One observation per row.
    n_resamples : int, default 1000
        Number of bootstrap resamples.
    confidence : float, default 95.0
        Width of the confidence interval in percent (95 -> 2.5th/97.5th
        percentiles).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source providing `.choice`. Defaults to the global `np.random`
        state, as before; pass a seeded generator for reproducible intervals.

    Returns
    -------
    median : ndarray, shape (n_columns,)
        NaN-ignoring median of every column.
    ci : ndarray, shape (2, n_columns)
        Non-negative distances from the median to the lower (row 0) and upper
        (row 1) interval bounds -- the layout `Axes.errorbar` takes as `yerr`.
    """
    data = np.asarray(data)
    sampler = np.random if rng is None else rng
    n_rows = data.shape[0]

    median = np.nanmedian(data, axis=0)

    # index[:, j] holds the row indices of bootstrap resample j
    index = sampler.choice(n_rows, size=(n_rows, n_resamples), replace=True)
    # data[index] is (n_rows, n_resamples, n_columns); reduce over the rows
    resampled = np.nanmedian(data[index], axis=0)

    tail = (100.0 - confidence) / 2.0
    low = np.nanpercentile(resampled, tail, axis=0)
    high = np.nanpercentile(resampled, 100.0 - tail, axis=0)

    # The sample median can fall marginally outside the bootstrap interval;
    # clip so the error-bar lengths are never negative (NaNs pass through).
    ci = np.clip(np.vstack([median - low, high - median]), 0, None)

    return median, ci

In [None]:
fig, ax = plt.subplots()

# 11 x positions: 'Pretrained' followed by sample counts 1..10
labels = ['Pretrained', *(str(k) for k in range(1, 11))]
x = np.arange(11)
kwargs = dict(capsize=4)

# Baselines are drawn first and TCRP last, which also fixes the legend order.
series = [*results_by_baseline.items(), ('TCRP', fewshot_performance)]
for name, performance in series:
    median, yerr = get_statistics(performance)
    ax.errorbar(x, median, yerr=yerr, label=name, **kwargs)

ax.legend()
ax.set_xticks(x)
ax.set_xticklabels(labels)

ax.set_ylim(0, 0.3)
ax.set_xlabel("Number of samples, few-shot learning")
ax.set_ylabel("Correlation (predicted, actual)")

ax.set_title("Corrected results")