In [19]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [20]:
df = pd.read_csv('../data/hpa_new.csv')
df.shape

(12386, 4)

In [21]:
def get_maps(path='../data/mart_infer.csv'):
    """Build a gene-id -> protein-id lookup table from a CSV file.

    Parameters
    ----------
    path : str, default '../data/mart_infer.csv'
        CSV file containing at least ``Gene`` and ``Protein`` columns.

    Returns
    -------
    dict
        Maps each ``Gene`` value to its ``Protein`` value. When a gene occurs
        on several rows, the value from the last row wins.
    """
    mart = pd.read_csv(path)
    # dict(zip(...)) keeps the last value for duplicated genes, exactly like
    # assigning row by row in file order.
    return dict(zip(mart['Gene'], mart['Protein']))

In [22]:
maps = get_maps()

# Look up each gene's protein id. Genes absent from the mapping become NaN
# (pandas treats NaN as null just like the previous None placeholder), and no
# unrelated exception is silently swallowed as the old bare `except:` did.
df['Protein'] = df['Gene'].map(maps)
df.shape

(12386, 5)

In [23]:
# Keep only rows whose gene could be mapped to a protein id.
df = df.dropna(subset=['Protein'])
df.shape

(1372, 5)

In [24]:
# `index` is a boolean flag; pass False explicitly rather than relying on None being falsy.
df.to_csv('../data/lstm-hpa.csv', index=False)

In [25]:
# Proteins whose HPA localisation reliability is 'Enhanced'.
df1 = df.loc[df['Reliability'] == 'Enhanced', ['Protein', 'Locations']]
df1.shape

(77, 2)

In [26]:
# Proteins whose HPA localisation reliability is 'Supported'.
df2 = df.loc[df['Reliability'] == 'Supported', ['Protein', 'Locations']]
df2.shape

(227, 2)

In [27]:
# Proteins whose HPA localisation reliability is 'Approved'.
df3 = df.loc[df['Reliability'] == 'Approved', ['Protein', 'Locations']]
df3.shape

(432, 2)

In [28]:
# Proteins whose HPA localisation reliability is 'Uncertain'.
df4 = df.loc[df['Reliability'] == 'Uncertain', ['Protein', 'Locations']]
df4.shape

(636, 2)

In [29]:
df_new = pd.read_csv('../data/lstm_output.csv')
# Re-select the HPA columns before merging so re-running this cell does not
# merge the prediction columns in a second time (pandas would then add
# `_x`/`_y`-suffixed duplicates). On the first run this selection is a no-op.
df1 = df1[['Protein', 'Locations']].merge(df_new, on='Protein', how='inner')
df1.shape

(77, 30)

In [30]:
# Re-select the HPA columns first so re-running this cell stays idempotent
# (no `_x`/`_y` duplicate prediction columns).
df2 = df2[['Protein', 'Locations']].merge(df_new, on='Protein', how='inner')
df2.shape
(227, 30)

In [31]:
# Re-select the HPA columns first so re-running this cell stays idempotent
# (no `_x`/`_y` duplicate prediction columns).
df3 = df3[['Protein', 'Locations']].merge(df_new, on='Protein', how='inner')
df3.shape
(432, 30)

In [32]:
# Re-select the HPA columns first so re-running this cell stays idempotent
# (no `_x`/`_y` duplicate prediction columns).
df4 = df4[['Protein', 'Locations']].merge(df_new, on='Protein', how='inner')
df4.shape
(636, 30)

In [33]:
# All mapped proteins (every reliability level) joined with their LSTM outputs.
df_all = df[['Protein', 'Locations', 'Reliability']].merge(df_new, on='Protein', how='inner')
df_all.shape

(1372, 31)

In [34]:
# `index` is a boolean flag; pass False explicitly rather than relying on None being falsy.
df1.to_csv('../data/lstm_enhanced.csv', index=False)
df2.to_csv('../data/lstm_supported.csv', index=False)
df3.to_csv('../data/lstm_approved.csv', index=False)
df4.to_csv('../data/lstm_uncertain.csv', index=False)

In [35]:
# Per-location base thresholds (28 values). get_score/binarize pair them by
# position (zip) with the 28 trailing probability columns, so the order must
# match the column order of the LSTM output frame.
thres_list = [
    0.5753, 0.6, 0.56688, 0.6143, 0.5026, 0.5057, 0.471,
    0.5047, 0.49586, 0.409, 0.56156, 0.4413, 0.50756, 0.5062,
    0.4634, 0.501, 0.5088, 0.484, 0.5017, 0.588, 0.516,
    0.4325, 0.412, 0.5456, 0.5186, 0.5005, 0.4496, 0.5993,
]

In [36]:
df1.head(5)

Unnamed: 0,Protein,Locations,Actin filaments,Cell Junctions,Centriolar satellite,Centrosome,Cytokinetic bridge,Cytoplasmic bodies,Cytosol,Endoplasmic reticulum,...,Mitotic spindle,Nuclear bodies,Nuclear membrane,Nuclear speckles,Nucleoli,Nucleoli fibrillar center,Nucleoplasm,Peroxisomes,Plasma membrane,Vesicles
0,ENSP00000406757,{'Endoplasmic reticulum'},0.001722,0.014736,0.002636,0.006245,0.001295,2.5e-05,0.04788,0.094151,...,1.5e-05,0.001577,0.001441,0.006502,0.014051,0.00379,0.156449,9.4e-05,0.08581,0.180249
1,ENSP00000306640,{'Endoplasmic reticulum'},0.005205,0.010469,0.01129,0.010372,0.011835,0.001017,0.370142,0.035072,...,0.001909,0.048036,0.019869,0.010488,0.018155,0.008725,0.192937,0.001909,0.101142,0.126527
2,ENSP00000034275,{'Nucleoplasm'},0.011434,0.010328,0.010303,0.014271,0.022424,0.001772,0.418319,0.027593,...,0.002298,0.036616,0.009879,0.012671,0.026137,0.017352,0.242839,0.003638,0.09177,0.057281
3,ENSP00000378451,{'Nucleoplasm'},0.14284,0.01771,0.080239,0.041999,0.036299,0.012811,0.333657,0.004521,...,0.053426,0.030956,0.021376,0.177079,0.00834,0.031174,0.744344,0.050406,0.06592,0.098864
4,ENSP00000253928,"{'Cytosol', 'Nucleoplasm', 'Nuclear bodies'}",0.013052,0.008326,0.010919,0.008572,0.010574,0.006398,0.334167,0.020617,...,0.007429,0.046129,0.017265,0.019759,0.024112,0.018624,0.345217,0.009463,0.050379,0.127206


In [37]:
# Location names are every column after 'Protein' and 'Locations'.
columns = df1.columns.tolist()
locs = columns[2:]
loc2id = dict(zip(locs, range(len(locs))))

In [38]:
def get_set(locs):
    """Parse a ``Locations`` cell into a set of location names.

    Cells hold the repr of a Python set, e.g. ``"{'Cytosol', 'Nucleoplasm'}"``.
    Parsing with ``ast.literal_eval`` keeps names containing quotes or commas
    intact, and the empty-set repr ``"set()"`` yields an empty set (slicing
    off two characters at each end turned it into ``{'t'}``). Non-string cells
    (e.g. NaN for a missing value) also yield an empty set instead of raising.
    Text that is not a Python literal falls back to comma splitting.

    Returns
    -------
    set of str
        Location names with surrounding whitespace stripped.
    """
    import ast

    if not isinstance(locs, str):
        return set()
    text = locs.strip()
    if text in ('', 'set()'):
        return set()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        parsed = text.strip('{}[]()').replace("'", "").replace('"', '').split(',')
    if not isinstance(parsed, (set, frozenset, list, tuple)):
        # A lone literal such as "'Cytosol'" is a single location.
        parsed = [parsed]
    return {str(loc).strip() for loc in parsed}

def binarize(l, thresh_list, alpha):
    """Convert a row of probabilities into 0/1 flags.

    Element ``i`` is 1 when ``l[i] > alpha * thresh_list[i]`` and 0 otherwise.
    The inputs are paired with ``zip``, so the result is as long as the
    shorter of ``l`` and ``thresh_list``.
    """
    return [1 if value > alpha * cutoff else 0 for cutoff, value in zip(thresh_list, l)]

def get_score(df, thresh, alpha=1.2):
    """Compare thresholded LSTM predictions with HPA location annotations.

    The last ``len(thresh)`` columns of ``df`` are read as per-location
    probabilities, paired by position with ``thresh``. A location counts as
    predicted when its probability exceeds ``alpha * thresh[i]`` (via
    ``binarize``). Each row's reference set is parsed from ``df['Locations']``
    with ``get_set``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a ``Locations`` column and must end with the probability columns.
    thresh : sequence of float
        Base threshold per location, in the same order as those columns.
    alpha : float, default 1.2
        Multiplier applied to every threshold.

    Returns
    -------
    tuple of (float, float)
        ``(correct / total, extra / count)``: the fraction of annotated
        locations that were predicted, and the mean number of predicted
        locations per row that are not annotated. A ratio whose denominator
        is zero (e.g. an empty ``df``) is NaN instead of raising
        ``ZeroDivisionError``.

    Raises
    ------
    ValueError
        If ``thresh`` is empty (there would be no probability columns to read).
    """
    n_locs = len(thresh)
    if n_locs == 0:
        raise ValueError("thresh must contain at least one threshold")
    # Take location names from df itself rather than the notebook-global
    # `locs`, so the score does not depend on which cells ran earlier.
    loc_names = list(df.columns[-n_locs:])
    probabilities = df.iloc[:, -n_locs:].values.tolist()

    total, correct, count, extra = 0, 0, 0, 0
    for probab, raw_locations in zip(probabilities, df['Locations']):
        flags = binarize(probab, thresh, alpha)
        pred = {name for name, flag in zip(loc_names, flags) if flag == 1}
        tar = get_set(raw_locations)
        total += len(tar)
        correct += len(tar & pred)
        extra += len(pred - tar)
        count += 1

    recall = correct / total if total else float('nan')
    mean_extra = extra / count if count else float('nan')
    return recall, mean_extra

In [39]:
# Enhanced score: (fraction of HPA locations recovered, mean extra predicted locations per protein)
score = get_score(df1, thresh=thres_list)
score

(0.13043478260869565, 0.2077922077922078)

In [40]:
# Supported score: (fraction of HPA locations recovered, mean extra predicted locations per protein)
score = get_score(df2, thresh=thres_list)
score

(0.0838150289017341, 0.2775330396475771)

In [41]:
# Approved score: (fraction of HPA locations recovered, mean extra predicted locations per protein)
score = get_score(df3, thresh=thres_list)
score

(0.07774390243902439, 0.17824074074074073)

In [42]:
# Uncertain score: (fraction of HPA locations recovered, mean extra predicted locations per protein)
score = get_score(df4, thresh=thres_list)
score

(0.4391829155060353, 9.536163522012579)

In [43]:
# Overall score across all reliability levels (same two metrics as above)
score = get_score(df_all, thresh=thres_list)
score

(0.26024873330262555, 4.534256559766764)

In [44]:
import os
def match(x, loc):
    """Return "Yes" if ``loc`` is one of the locations parsed from ``x``, else "No".

    ``x`` is a Locations cell, decoded with ``get_set``.
    """
    return "Yes" if loc in get_set(x) else "No"

def generate_plots(df, path=None, n_locs=28):
    """Swarm-plot each location's predicted probability, split by HPA annotation.

    For each of the last ``n_locs`` columns (one per location), every row
    contributes a point at its probability, coloured by whether ``match``
    finds that location in the row's second column (assumed to hold the
    Locations strings — TODO confirm for new inputs).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose trailing ``n_locs`` columns are per-location probabilities.
    path : str or None
        File stem: the figure is written to ``path + '.png'`` and missing
        parent directories are created. When None, the figure is shown
        instead (previously a None path crashed at ``path + '.png'``).
    n_locs : int, default 28
        Number of trailing probability columns.

    Raises
    ------
    ValueError
        If ``df`` has no location columns to plot.
    """
    cols_list = list(df.columns.values)
    locs = cols_list[-n_locs:]
    loc_col = cols_list[1]

    # Build each location's long-format slice, then concatenate once
    # (concatenating inside the loop is quadratic).
    frames = [
        pd.DataFrame({
            "Probability": df[loc].to_list(),
            "HPA": df[loc_col].apply(match, args=(loc,)).to_list(),
            "Location": [loc] * len(df),
        })
        for loc in locs
    ]
    if not frames:
        raise ValueError("df has no location columns to plot")
    final_df = pd.concat(frames, ignore_index=True)

    sns.set(style='whitegrid', rc={'figure.figsize': (15, 12)})
    fig, ax = plt.subplots()
    sns.swarmplot(x="Probability", y="Location", hue="HPA", data=final_df, dodge=True, ax=ax)
    ax.set_xlim(0, 1)
    if path is None:
        plt.show()
    else:
        # Ensure the *parent* directory exists; the old os.mkdir(path) created
        # an unused empty directory named after the figure instead.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        fig.savefig(path + '.png')
    # Close the figure so it is neither leaked nor echoed as an empty figure.
    plt.close(fig)

In [61]:
generate_plots(df=df1, path='../inference/Enhanced')

<Figure size 1080x864 with 0 Axes>

In [62]:
generate_plots(df=df2, path='../inference/Supported')

<Figure size 1080x864 with 0 Axes>

In [63]:
generate_plots(df=df3, path='../inference/Approved')

<Figure size 1080x864 with 0 Axes>

In [64]:
generate_plots(df=df4, path='../inference/Uncertain')

<Figure size 1080x864 with 0 Axes>

In [45]:
def unify_df(df, n_locs=28):
    """Reshape per-location probability columns into one long table.

    For each of the last ``n_locs`` columns (one per location), one row per
    input row is emitted with:
      - Probability: that column's value,
      - HPA: "Yes"/"No" from ``match`` against the second column (assumed to
        hold the Locations strings — TODO confirm for new inputs),
      - Location: the column name,
      - Reliability: the row's ``Reliability`` value.
    Location blocks are stacked in column order with a fresh RangeIndex.

    Raises
    ------
    ValueError
        If ``df`` has no location columns.
    """
    cols_list = list(df.columns.values)
    locs = cols_list[-n_locs:]
    loc_col = cols_list[1]
    # Identical for every location, so compute it once outside the loop.
    rel = df['Reliability'].to_list()

    # Collect the slices and concatenate once instead of growing a frame in a loop.
    frames = [
        pd.DataFrame({
            "Probability": df[loc].to_list(),
            "HPA": df[loc_col].apply(match, args=(loc,)).to_list(),
            "Location": [loc] * len(rel),
            "Reliability": rel,
        })
        for loc in locs
    ]
    if not frames:
        raise ValueError("df has no location columns to unify")
    return pd.concat(frames, ignore_index=True)


In [46]:
df = df_all.pipe(unify_df)
df.shape

(38416, 4)

In [47]:
# `index` is a boolean flag; pass False explicitly rather than relying on None being falsy.
df.to_csv('../data/df_plots.csv', index=False)

In [48]:
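# The cells below form a separate analysis (effectively a new notebook): they re-import their libraries and reload their input from ../data/df_plots.csv.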

In [78]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px

df = pd.read_csv(filepath_or_buffer='../data/df_plots.csv')
df.head()

Unnamed: 0,Probability,HPA,Location,Reliability
0,0.994914,No,Actin filaments,Uncertain
1,0.018757,No,Actin filaments,Approved
2,0.001722,No,Actin filaments,Enhanced
3,0.979168,No,Actin filaments,Uncertain
4,0.948109,No,Actin filaments,Uncertain


In [79]:
def get_window(x):
    """Map a probability to one of ten equal-width bins, 0..9.

    Bin ``k`` covers ``[k/10, (k+1)/10)``. Values >= 0.9 (including 1.0 and
    anything larger) go to bin 9. Values below 0 are clamped into bin 0; the
    previous if-chain let them fall through every ``x >= lower`` test and
    returned 9, the highest-probability bin.

    Raises
    ------
    ValueError
        If ``x`` is NaN (previously NaN was silently placed in bin 9).
    """
    if pd.isna(x):
        raise ValueError("get_window received NaN; expected a probability")
    # Exclusive upper edges of bins 0..8. These are the same float literals the
    # old if-chain compared against, so boundary values land in the same bins.
    upper_edges = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
    for bin_idx, upper in enumerate(upper_edges):
        if x < upper:
            return bin_idx
    return 9

# Keep the raw probabilities so re-running this cell re-bins the original
# values instead of binning the bin indices a second time.
if 'Raw_Probability' not in df.columns:
    df['Raw_Probability'] = df['Probability']
df['Probability'] = df['Raw_Probability'].apply(get_window)

In [81]:
def get_plot(df, rel, loc, loc_idx):
    """Per-bin fraction of HPA-confirmed rows for one reliability/location pair.

    ``df['Probability']`` must already hold integer bin indices 0..9 (from
    ``get_window``). For each bin, Value = (rows with HPA == "Yes") / (rows
    in the bin), or -0.1 as a marker when the bin is empty.

    Returns a 10-row DataFrame with columns Class_id (bin index), Value and
    Loc_id (``loc_idx`` repeated on every row).
    """
    subset = df[df['Reliability'] == rel]
    subset = subset[subset['Location'] == loc]
    rows_per_bin = [0] * 10   # 10 probability bins (not reliability values)
    yes_per_bin = [0] * 10    # HPA-confirmed rows in each bin
    for row in subset.index:
        bin_idx = subset['Probability'][row]
        rows_per_bin[bin_idx] += 1
        yes_per_bin[bin_idx] += 1 if subset['HPA'][row] == "Yes" else 0
    values = [yes / rows if rows > 0 else -0.1 for yes, rows in zip(yes_per_bin, rows_per_bin)]
    return pd.DataFrame({
        'Class_id': list(range(10)),
        'Value': values,
        'Loc_id': [loc_idx] * len(values),
    })

def reliability_plot(df, out_dir='../inference'):
    """Write one interactive line plot per reliability level.

    For each distinct ``Reliability`` value, ``get_plot`` is called for every
    location (sorted by name; Loc_id is the position in that order). The
    results are drawn with one line per location: x = probability bin
    (Class_id), y = Value. Each figure is saved as
    ``<out_dir>/<reliability>.html``; ``out_dir`` is created if missing.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    locs = sorted(set(df.Location))
    for val in df.Reliability.unique():
        # Build all per-location curves, then concatenate once.
        frames = [get_plot(df, val, loc, i) for i, loc in enumerate(locs)]
        final_df = pd.concat(frames, ignore_index=True)
        fig = px.line(final_df, x='Class_id', y='Value', color='Loc_id')
        fig.write_html(os.path.join(out_dir, f'{val}.html'))

reliability_plot(df)

Mitochondria


In [82]:
def get_plot(df, rel, loc, loc_idx=None):
    """Per-bin fraction of HPA-confirmed rows for one reliability/location pair.

    ``df['Probability']`` must already hold integer bin indices 0..9 (from
    ``get_window``). For each bin, Value = (rows with HPA == "Yes") / (rows
    in the bin), or -0.1 as a marker when the bin is empty.

    This definition shadows an earlier 4-argument ``get_plot(df, rel, loc,
    loc_idx)`` in the notebook. ``loc_idx`` is therefore accepted as an
    optional argument so that callers of the 4-argument form keep working
    after this cell has run (previously they raised TypeError).

    Returns
    -------
    pandas.DataFrame
        10 rows with columns Class_id (bin index), Value and Reliability
        (``rel`` on every row), plus Loc_id (``loc_idx``) when it is given.
    """
    subset = df[(df['Reliability'] == rel) & (df['Location'] == loc)]
    n_bins = 10  # probability bins produced by get_window
    rows_per_bin = [0] * n_bins
    yes_per_bin = [0] * n_bins
    for bin_idx, hpa in zip(subset['Probability'], subset['HPA']):
        rows_per_bin[bin_idx] += 1
        yes_per_bin[bin_idx] += int(hpa == "Yes")
    values = [yes / rows if rows > 0 else -0.1 for yes, rows in zip(yes_per_bin, rows_per_bin)]
    out = pd.DataFrame({
        'Class_id': list(range(n_bins)),
        'Value': values,
        'Reliability': [rel] * n_bins,
    })
    if loc_idx is not None:
        out['Loc_id'] = loc_idx
    return out

def location_plot(df, out_dir='../inference'):
    """Write one interactive line plot per location.

    For each location (sorted by name), ``get_plot`` is called for every
    distinct ``Reliability`` value. The results are drawn with one line per
    reliability level: x = probability bin (Class_id), y = Value. Each figure
    is saved as ``<out_dir>/<location>.html``; ``out_dir`` is created if
    missing.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    reliability = df.Reliability.unique()
    for loc in sorted(set(df.Location)):
        # Build all per-reliability curves, then concatenate once.
        frames = [get_plot(df, val, loc) for val in reliability]
        final_df = pd.concat(frames, ignore_index=True)
        fig = px.line(final_df, x='Class_id', y='Value', color='Reliability')
        fig.write_html(os.path.join(out_dir, f'{loc}.html'))

location_plot(df)