# In [1]:
def drawit(df_temp, structure, predictions):
    """Plot the observed and predicted chain-interaction graphs side by side.

    The left panel shows the observed ("true") graph: one edge per row of
    ``df_temp`` whose ``norm_obs_contact`` is positive.  The right panel shows
    the predicted graph, with every edge classified as

    * TP (solid blue)  - predicted > 0 and observed > 0,
    * FP (solid red)   - predicted > 0 and observed == 0,
    * FN (dashed red)  - predicted == 0 and observed > 0.

    The figure title reports the cosine similarity between the observed and
    predicted edge-weight vectors, the weighted F1 score of the binarised
    edges (weight > 0), and the FP / FN counts.

    Parameters
    ----------
    df_temp : pandas.DataFrame
        Must provide a ``pair`` column of underscore-separated strings whose
        fields 1 and 3 (0-based) are used as the two chain identifiers, and a
        ``norm_obs_contact`` column holding the observed edge weights.
        NOTE(review): presumably one row per chain pair of a single
        structure - confirm against the caller.
    structure : str
        Label shown on the first line of the figure title.
    predictions : sequence of numbers
        Predicted edge weights, assigned to the rows of ``df_temp`` as a new
        column (positionally for lists/arrays, by index for a Series).

    Raises
    ------
    ValueError
        If ``df_temp`` has no rows.

    Notes
    -----
    The function draws on a new matplotlib figure and returns ``None``.  To
    save the result, call ``plt.savefig(structure + '.png', dpi=300)`` after
    it returns.
    """
    import matplotlib.pyplot as plt
    import networkx as nx
    from matplotlib.lines import Line2D
    import pandas as pd
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn.metrics import f1_score

    max_line_width = 9
    fn_colour = (0.95, 0.07, 0.2)

    # One row per chain pair: node ids, predicted weight, observed weight.
    df2 = pd.DataFrame()
    df2['chain1'] = df_temp['pair'].apply(lambda x: x.split('_')[1])
    df2['chain2'] = df_temp['pair'].apply(lambda x: x.split('_')[3])
    df2['pred'] = predictions
    df2['edge'] = list(df_temp.norm_obs_contact)

    if df2.empty:
        # Previously surfaced as "max() arg is an empty sequence".
        raise ValueError('df_temp has no rows; there is nothing to draw')

    observed = df2['edge'].to_numpy()
    predicted = df2['pred'].to_numpy()
    has_contact = observed > 0
    has_prediction = predicted > 0

    ### Agreement metrics between the observed and predicted graphs
    false_pos = int(np.sum((observed == 0) & has_prediction))
    false_neg = int(np.sum(has_contact & (predicted == 0)))
    similarity = cosine_similarity(
        observed.reshape(1, -1), predicted.reshape(1, -1)
    )[0][0]
    weighted_f1 = f1_score(
        has_contact.astype(int), has_prediction.astype(int), average='weighted'
    )

    # Line widths are normalised by the largest observed weight.  When there
    # is no observed contact at all, fall back to the largest prediction so
    # false-positive edges still get a finite, visible width instead of a
    # division by zero.
    width_ref = observed.max()
    if not width_ref > 0:
        width_ref = predicted.max()

    def line_width(weight):
        """Scale an edge weight to a line width in [0, max_line_width]."""
        if not width_ref > 0:
            return 0.0
        return weight / width_ref * max_line_width

    # Sorted so node insertion order - and therefore the seeded layout - is
    # the same in every interpreter session (set order of strings is not).
    nodes = sorted(set(df2['chain1']) | set(df2['chain2']))

    true_graph = nx.Graph()
    true_graph.add_nodes_from(nodes)
    pred_graph = nx.Graph()
    pred_graph.add_nodes_from(nodes)
    # Both panels share one layout so the same chain sits at the same place.
    pos = nx.spring_layout(true_graph, seed=5)

    def draw_edge(graph, ax, u, v, weight, width, **draw_kwargs):
        """Add edge (u, v) to *graph* and draw it on *ax*."""
        # The edge is added before drawing, as networkx may only render a
        # self-loop that exists in the graph; the loop is removed afterwards
        # so the stored graph stays loop-free.
        graph.add_edge(u, v, weight=weight)
        nx.draw_networkx_edges(
            graph, pos, edgelist=[(u, v)], width=width, ax=ax, **draw_kwargs
        )
        if u == v:
            graph.remove_edge(u, v)

    ### Visualize the graphs
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 10))
    fig.suptitle(
        f'{structure}\nCosine Similarity={round(similarity, 2)}'
        f'   Weighted F1 score={round(weighted_f1, 2)}'
        f'   FP= {false_pos}   FN={false_neg}',
        fontsize=20,
        fontweight='bold',
    )

    # Nodes go first so that edges are layered exactly as before.
    for graph, ax in ((true_graph, ax1), (pred_graph, ax2)):
        nx.draw_networkx_nodes(
            graph, pos, node_size=500, node_color='lightblue', alpha=0.8, ax=ax
        )

    for row in df2.itertuples(index=False):
        u, v = row.chain1, row.chain2

        # Left panel: observed contacts only (zero-weight pairs would be
        # drawn with zero width, i.e. invisibly).
        if row.edge > 0:
            draw_edge(true_graph, ax1, u, v, weight=row.edge,
                      width=line_width(row.edge), edge_color='navy')

        # Right panel: classify the pair as TP / FP / FN.
        if row.pred > 0 and row.edge > 0:
            draw_edge(pred_graph, ax2, u, v, weight=11,
                      width=line_width(row.pred), edge_color='blue')
        elif row.pred > 0 and row.edge == 0:
            draw_edge(pred_graph, ax2, u, v, weight=10,
                      width=line_width(row.pred), edge_color='red')
        elif row.pred == 0 and row.edge > 0:
            # A missed contact is drawn with its observed weight.
            draw_edge(pred_graph, ax2, u, v, weight=9,
                      width=line_width(row.edge), edge_color=fn_colour,
                      style='dashed')

    panels = (
        (true_graph, ax1, 'True Interaction Graph'),
        (pred_graph, ax2, 'Predicted Interaction Graph'),
    )
    for graph, ax, panel_title in panels:
        nx.draw_networkx_labels(
            graph, pos, font_size=14, font_family='sans-serif',
            font_weight='bold', ax=ax
        )
        ax.set_title(panel_title, fontsize=17)
        ax.set_xlabel('Chain', fontweight='bold', fontsize=16)
        ax.set_ylabel('Chain', fontweight='bold', fontsize=16)
        ax.axis('off')

    # Legend for the edge classes of the predicted graph.
    legend_handles = [
        Line2D([0], [0], color='blue', lw=3),
        Line2D([0], [0], color='red', lw=3),
        Line2D([0], [0], color=fn_colour, lw=3, linestyle='--'),
    ]
    ax2.legend(legend_handles, ['TP', 'FP', 'FN'], loc='best', fontsize=12)

    # NOTE(review): hspace only sets the vertical gap between subplot rows,
    # so with a single row this does not change the layout; use wspace if a
    # wider gap between the two panels is wanted.
    plt.subplots_adjust(hspace=0.5)