In [None]:
import numpy as np
import pandas as pd
import random
import pickle
from scipy.stats import ttest_ind_from_stats

# Seed NumPy's global RNG and Python's `random` module so a fresh-kernel
# rerun reproduces the same draws (resample_metric samples via np.random).
np.random.seed(42)
random.seed(42)

In [None]:
# Labels that are passed as `true` to the metric helpers below.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
y = pd.read_pickle("../../data/tts/y_test_ord.pkl")

# Blast deduplicates lab labels if multiple hits are from same lab.
# NOTE(review): the BLAST metrics call len() on each entry, so this is
# presumably a ragged object array; if so, np.load needs allow_pickle=True
# on NumPy >= 1.16.3 -- confirm.
blast = np.load("../../data/blast/firstorder_blast_test_predictions.npy")
# Model outputs with and without metadata (judging by the file names -- confirm).
ours = np.load("../../data/results/TEST_100_sub_MLP_full_with_metadata_300300.npy")
ours_no_pheno = np.load("../../data/results/TEST_100_sub_nometadata_201300.npy")

In [None]:
def topkacc(pred, true, k):
    """
    Top-k accuracy: the fraction of rows whose true class index is among the
    k highest-scoring columns of pred.

    Parameters
    ----------
    pred : 2-D array-like, one row per example and one column per class.
    true : 1-D array-like of class indices, one per row of pred.
        May be a pandas Series; it is read positionally, ignoring its index.
    k : int, number of top-scoring classes to consider (must be >= 1).

    Returns
    -------
    numpy float in [0, 1] (nan if there are no examples).

    Raises
    ------
    ValueError if k < 1.
    """
    # k == 0 would make the slice `[:, -0:]` select *every* class, and a
    # negative k would drop the lowest-scoring columns instead of keeping the top ones.
    if k < 1:
        raise ValueError("k must be a positive integer, got %r" % (k,))

    # Plain ndarrays make `[:, None]` and row/column slicing positional;
    # a pandas Series rejects multi-dimensional indexing on recent versions.
    scores = np.asarray(pred)
    labels = np.asarray(true)

    # argsort is ascending, so the last k columns hold the k best classes.
    topkclasses = np.argsort(scores)[:, -k:]
    correct = (topkclasses == labels[:, None]).any(axis=1)
    return np.mean(correct)

def blastacc(predictions, true):
    """
    Top-1 accuracy for BLAST hit lists.

    The first entry of each hit list is taken as the prediction and compared
    element-wise with `true`; the mean of those matches is returned. An empty
    hit list is replaced by the placeholder 9999999, which is assumed never
    to equal a real label, so it counts as a miss.
    """
    no_hit = 9999999
    first_hits = []
    for hits in predictions:
        first_hits.append(hits[0] if len(hits) > 0 else no_hit)
    is_correct = np.array(first_hits) == true
    return np.mean(is_correct)
    
def blastacc10(predictions, true):
    """
    Fraction of examples whose true label appears anywhere in its BLAST hit list.

    Parameters
    ----------
    predictions : indexable sequence; predictions[i] is the (possibly empty)
        collection of labels hit for example i.
    true : 1-D array-like of true labels. Read positionally, so a pandas
        Series with a non-default index is handled correctly.

    Returns
    -------
    numpy float in [0, 1] (nan if `true` is empty). An example with no hits
    counts as incorrect.

    NOTE(review): despite the "10" in the name, the hit list is not truncated
    here; presumably the stored predictions already hold at most 10 labels --
    confirm.
    """
    labels = np.asarray(true)
    correct = []
    for i in range(len(labels)):
        hits = predictions[i]
        # A miss must be a boolean False: these values are averaged, so any
        # numeric placeholder for "no hits" would be added into the accuracy.
        correct.append(len(hits) != 0 and labels[i] in hits)
    return np.mean(correct)

In [None]:
def resample_metric(metric_func, pred, true, n=30, frac=50):
    """
    Returns the average and standard deviation of metric_func
    evaluated on n random subsamples of the test set.

    Each subsample contains frac percent of the examples, drawn without
    replacement from NumPy's global RNG (so np.random.seed controls it).
    pred and true are subsampled at the same positions so they stay aligned.

    metric_func must have the signature
    metric_func(np.array(predictions), np.array(true))

    Parameters
    ----------
    metric_func : callable returning a scalar score.
    pred : array-like of predictions, one entry (or row) per example.
    true : array-like of true labels, same length as pred.
    n : int, number of subsamples to draw.
    frac : number in (0, 100], percentage of the test set in each subsample.
        With frac=100 every subsample is a permutation of the full set.

    Returns
    -------
    (mean, std) of the n scores; std is np.std's default population standard
    deviation (ddof=0). Returns (0, 0) when the inputs are empty.

    Raises
    ------
    AssertionError if pred and true differ in length.
    ValueError if frac is outside (0, 100].
    """
    length = len(true)
    assert len(pred) == length, "pred and true must have the same length"
    if length == 0:
        return (0, 0)
    if not 0 < frac <= 100:
        raise ValueError("frac must be a percentage in (0, 100], got %r" % (frac,))

    # Index plain ndarrays so integer-array indexing is positional; a pandas
    # Series would look the positions up as index labels instead.
    pred_arr = np.asarray(pred)
    true_arr = np.asarray(true)

    # At least one example per subsample, even for tiny inputs.
    sample_size = max(1, int(round(length * frac / 100)))
    index = np.arange(length)
    samples = []
    for _ in range(n):
        subidx = np.random.choice(index, size=sample_size, replace=False)
        # Score only the drawn subset; scoring the full arrays would make
        # every sample identical and the standard deviation zero.
        samples.append(metric_func(pred_arr[subidx], true_arr[subidx]))
    return np.mean(samples), np.std(samples)

In [None]:
# Top-1 accuracy statistics from resample_metric: our model vs. BLAST.
ours_mean, ours_std = resample_metric(
    lambda scores, labels: topkacc(scores, labels, k=1), ours, y
)
blast_mean, blast_std = resample_metric(blastacc, blast, y)

In [None]:
# Welch's t-test (equal_var=False) on the summary statistics computed above.
# nobs=30 mirrors resample_metric's default n=30; keep the two in sync.
# NOTE(review): SciPy documents std1/std2 as ddof=1 standard deviations, while
# resample_metric returns np.std's default (ddof=0) -- confirm this is intended.
ttest_ind_from_stats(
    mean1=ours_mean,
    std1=ours_std,
    nobs1=30,
    mean2=blast_mean,
    std2=blast_std,
    nobs2=30,
    equal_var=False,
)

In [None]:
# Top-10 accuracy statistics from resample_metric: our model vs. BLAST.
ours_mean, ours_std = resample_metric(
    lambda scores, labels: topkacc(scores, labels, k=10), ours, y
)
blast_mean, blast_std = resample_metric(blastacc10, blast, y)

In [None]:
# Welch's t-test (equal_var=False) on the top-10 summary statistics above.
# nobs=30 mirrors resample_metric's default n=30; keep the two in sync.
# NOTE(review): SciPy documents std1/std2 as ddof=1 standard deviations, while
# resample_metric returns np.std's default (ddof=0) -- confirm this is intended.
ttest_ind_from_stats(
    mean1=ours_mean,
    std1=ours_std,
    nobs1=30,
    mean2=blast_mean,
    std2=blast_std,
    nobs2=30,
    equal_var=False,
)

# Now compare with and without phenotype

In [None]:
# Top-1 accuracy statistics: full model vs. the model loaded as ours_no_pheno.
ours_mean, ours_std = resample_metric(
    lambda scores, labels: topkacc(scores, labels, k=1), ours, y
)
ours_seq_mean, ours_seq_std = resample_metric(
    lambda scores, labels: topkacc(scores, labels, k=1), ours_no_pheno, y
)

In [None]:
# Welch's t-test (equal_var=False): full model vs. no-phenotype model, top 1.
# nobs=30 mirrors resample_metric's default n=30; keep the two in sync.
# NOTE(review): SciPy documents std1/std2 as ddof=1 standard deviations, while
# resample_metric returns np.std's default (ddof=0) -- confirm this is intended.
ttest_ind_from_stats(
    mean1=ours_mean,
    std1=ours_std,
    nobs1=30,
    mean2=ours_seq_mean,
    std2=ours_seq_std,
    nobs2=30,
    equal_var=False,
)

In [None]:
# Top-10 accuracy statistics: full model vs. the model loaded as ours_no_pheno.
ours_mean, ours_std = resample_metric(
    lambda scores, labels: topkacc(scores, labels, k=10), ours, y
)
ours_seq_mean, ours_seq_std = resample_metric(
    lambda scores, labels: topkacc(scores, labels, k=10), ours_no_pheno, y
)

In [None]:
# Welch's t-test (equal_var=False): full model vs. no-phenotype model, top 10.
# nobs=30 mirrors resample_metric's default n=30; keep the two in sync.
# NOTE(review): SciPy documents std1/std2 as ddof=1 standard deviations, while
# resample_metric returns np.std's default (ddof=0) -- confirm this is intended.
ttest_ind_from_stats(
    mean1=ours_mean,
    std1=ours_std,
    nobs1=30,
    mean2=ours_seq_mean,
    std2=ours_seq_std,
    nobs2=30,
    equal_var=False,
)

# Now Fig. 3e (countries)

In [None]:

# Country labels and predictions (per the file names). These rebind y, blast
# and ours, replacing the arrays loaded in the earlier cells.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
y = pd.read_pickle("../../data/tts/y_test_country.pkl")
# NOTE(review): this path sits outside ../../data/, unlike every other input
# loaded in this notebook -- confirm the file location.
blast = np.load("../../country_blast_test_predictions.npy")
ours = np.load("../../data/results/rf/predictions_TEST_countries_seq_meta_nous.npy")

In [None]:
# Drop US examples: keep only rows whose label is not 33
# (presumably the class index assigned to the US -- confirm).
# The same boolean mask is applied to the labels and to both prediction
# arrays so that all three stay aligned row-for-row.
mask = ~(y == 33)
y = y[mask]
blast = blast[mask]
ours = ours[mask]

In [None]:
# Countries, top-1 accuracy statistics from resample_metric: our model vs. BLAST.
ours_mean, ours_std = resample_metric(
    lambda scores, labels: topkacc(scores, labels, k=1), ours, y
)
blast_mean, blast_std = resample_metric(blastacc, blast, y)

In [None]:
# Welch's t-test (equal_var=False) on the country top-1 statistics above.
# nobs=30 mirrors resample_metric's default n=30; keep the two in sync.
# NOTE(review): SciPy documents std1/std2 as ddof=1 standard deviations, while
# resample_metric returns np.std's default (ddof=0) -- confirm this is intended.
ttest_ind_from_stats(
    mean1=ours_mean,
    std1=ours_std,
    nobs1=30,
    mean2=blast_mean,
    std2=blast_std,
    nobs2=30,
    equal_var=False,
)

In [None]:
# Countries, top-10 accuracy statistics from resample_metric: our model vs. BLAST.
ours_mean, ours_std = resample_metric(
    lambda scores, labels: topkacc(scores, labels, k=10), ours, y
)
blast_mean, blast_std = resample_metric(blastacc10, blast, y)

In [None]:
# Welch's t-test (equal_var=False) on the country top-10 statistics above.
# nobs=30 mirrors resample_metric's default n=30; keep the two in sync.
# NOTE(review): SciPy documents std1/std2 as ddof=1 standard deviations, while
# resample_metric returns np.std's default (ddof=0) -- confirm this is intended.
ttest_ind_from_stats(
    mean1=ours_mean,
    std1=ours_std,
    nobs1=30,
    mean2=blast_mean,
    std2=blast_std,
    nobs2=30,
    equal_var=False,
)