#### This notebook contains helper functions related to inference and evaluation of the trained models

In [1]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"

import sys
sys.path.append('..')

from utils.data import *
from utils.loss import *
from utils.model import *
from utils.train import *
from utils.plots import *
from utils.util import *

#import tensorboardX as tb
import torch.optim as optim

#### Load data to get the evaluation set

In [3]:
set_seed(42)

# Build the filtered dataset. The *_threshold arguments are presumably sequence-length
# bounds applied inside utils.data.load_and_process_data — confirm there.
reduced_data = load_and_process_data(
    mode='dna_gpn_esm2',
    lower_threshold=10,
    na_upper_threshold=100,
    protein_upper_threshold=1000,
    dataset_dir='../data/',
)

# Per-record fields pulled out of the processed dataset (one entry per complex).
protein_seqs = [record['protein_seq'] for record in reduced_data]
na_seqs = [record['dna_seq'] for record in reduced_data]
# Stored contact maps are transposed on load.
contact_maps = [torch.tensor(record['complex_contact_map']).T for record in reduced_data]
pdb_ids = [record['pdb_id'] for record in reduced_data]

# Split by sequence similarity so eval complexes are (presumably) dissimilar to training ones.
train_dset, eval_dset = sequence_similarity_split(reduced_data, mode='dna')

# One complex per batch, fixed order, so evaluation is reproducible.
eval_dataloader = DataLoader(
    eval_dset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_sequences,
)

#### Load model (NOTE: replace the placeholder `'path'` with the actual checkpoint file path)


In [4]:
# Model widths for the cross-attention head. Presumably d_model_q matches the ESM protein
# embedding width and d_model_kv the GPN nucleic-acid embedding width — confirm against utils.model.
d_model_q, d_model_kv, d_k = 320, 512, 32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

protein_model = ESMModel().to(device)
na_model = GPNModel().to(device)
binding_model = CustomCrossAttention(d_model_q, d_model_kv, d_k).to(device)

# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load('path', map_location=device)

# Restore each sub-network from its entry in the checkpoint dict (same order as before).
for _target_module, _state_key in (
    (protein_model.model, 'protein_model_state_dict'),
    (na_model.model, 'na_model_state_dict'),
    (binding_model, 'model_state_dict'),
):
    _target_module.load_state_dict(checkpoint[_state_key])

#### Get predictions and metrics


In [19]:
dl = eval_dataloader

protein_model.eval()
na_model.eval()
binding_model.eval()

# Fail with a clear message rather than a ZeroDivisionError when averaging below.
n_eval_batches = len(dl)
if n_eval_batches == 0:
    raise ValueError('eval_dataloader is empty; nothing to evaluate.')

# Per-batch tuples (gt maps, predicted maps, protein lengths, NA lengths, pdb ids) for plotting.
visualizer_data = []
total_loss_eval = 0
total_f1_eval = 0
total_prauc_eval = 0
total_mcc_eval = 0

with torch.no_grad():
    # Loop variables use a `batch_` prefix so they do NOT overwrite the full-dataset
    # `protein_seqs` / `na_seqs` / `contact_maps` lists built in the data-loading cell.
    for batch_protein_seqs, batch_na_seqs, batch_contact_maps, batch_pdb_ids in dl:
        # protein_model is fed (label, sequence) pairs; labels are left empty.
        esm_inputs = [('', seq) for seq in batch_protein_seqs]
        protein_embeddings = protein_model(esm_inputs)  # List of tensors of shape (L, d_model)
        na_embeddings = na_model(batch_na_seqs)  # List of tensors of shape (L, d_model)
        (padded_protein_embeddings, padded_na_embeddings, padded_contact_maps,
         protein_lengths, rna_lengths) = collate_embeddings(protein_embeddings, na_embeddings, batch_contact_maps)
        padded_protein_embeddings = padded_protein_embeddings.to(device)
        padded_na_embeddings = padded_na_embeddings.to(device)
        mask = create_mask(
            protein_lengths, rna_lengths,
            padded_protein_embeddings.size(1), padded_na_embeddings.size(1),
        ).unsqueeze(1).float().to(device)

        output, predicted_attention = binding_model(padded_protein_embeddings, padded_na_embeddings, mask)

        # squeeze() drops the batch dim (batch_size=1 in eval_dataloader).
        loss = weighted_bce_loss(predicted_attention, padded_contact_maps.squeeze(), mask, pos_weight=None)
        total_loss_eval += loss.item()
        batch_f1, batch_prauc, batch_mcc = get_batch_metrics(
            padded_contact_maps.cpu().numpy(),
            predicted_attention.detach().cpu().numpy(),
            protein_lengths,
            rna_lengths,
        )

        total_f1_eval += batch_f1
        total_prauc_eval += batch_prauc
        total_mcc_eval += batch_mcc
        # Keep CPU copies so the collected maps do not pin device memory for the whole eval set.
        visualizer_data.append((
            padded_contact_maps.cpu(),
            predicted_attention.detach().cpu(),
            protein_lengths,
            rna_lengths,
            batch_pdb_ids,
        ))

# The reported values are means over batches (one complex per batch here).
print('Total Loss: ', total_loss_eval / n_eval_batches)
print('Total F1: ', total_f1_eval / n_eval_batches)
print('Total PRAUC: ', total_prauc_eval / n_eval_batches)
print('Total MCC: ', total_mcc_eval / n_eval_batches)

#### Plot ground-truth (GT) and predicted contact maps, plus their overlay


In [25]:
def overlay_attention_maps(gt_attention, predicted_attention, protein_len, rna_len, pdb_id, plot=True, threshold=0.5):
    """Plot ground-truth, overlay and predicted contact maps side by side.

    Both maps have a leading size-1 dimension squeezed off, are converted to numpy
    and transposed so that rows index the NA sequence and columns the protein
    sequence, then cropped to ``[:rna_len, :protein_len]`` to drop padding.
    Assumes each map is (1, padded_protein_len, padded_na_len) — TODO confirm.

    Args:
        gt_attention: ground-truth contact map tensor.
        predicted_attention: predicted contact map tensor, same layout as gt_attention.
        protein_len: unpadded protein length used for cropping.
        rna_len: unpadded nucleic-acid length used for cropping.
        pdb_id: identifier shown in the overlay panel title.
        plot: if True, show the figure; otherwise close it and return it.
        threshold: a value counts as a contact when it is ``> threshold`` (default 0.5).

    Returns:
        The matplotlib Figure when ``plot`` is False, otherwise None.
    """
    fig, axes = plt.subplots(1, 3, figsize=(20, 40), constrained_layout=True)

    # Transpose so rows are NA positions, columns protein positions; crop away padding.
    gt_map = gt_attention.squeeze(0).cpu().numpy().T[:rna_len, :protein_len]
    pred_map = predicted_attention.squeeze(0).cpu().numpy().T[:rna_len, :protein_len]

    # Ground Truth Contact Map
    ax = axes[0]
    ax.matshow(gt_map, cmap='viridis')
    ax.set_title('GT')
    ax.set_xlabel('Protein Seq')
    ax.set_ylabel('NA Seq')

    # Overlay Plot. Positive and negative masks are computed separately (> vs <=) so that
    # NaN entries, which satisfy neither, stay white.
    gt_pos, gt_neg = gt_map > threshold, gt_map <= threshold
    pred_pos, pred_neg = pred_map > threshold, pred_map <= threshold
    overlay = np.ones((rna_len, protein_len, 3))  # white: no contact in either map
    overlay[gt_pos & pred_pos] = (0, 0, 0)  # black: correctly predicted contact
    overlay[gt_pos & pred_neg] = (0, 0, 1)  # blue: ground-truth contact that was missed
    overlay[gt_neg & pred_pos] = (1, 0, 0)  # red: predicted contact absent from ground truth
    ax = axes[1]
    ax.imshow(overlay)
    ax.set_title(f'Overlay:{pdb_id}')
    ax.set_xlabel('Protein Seq')
    ax.set_ylabel('NA Seq')

    # Predicted Contact Map
    ax = axes[2]
    ax.matshow(pred_map, cmap='viridis')
    ax.set_title('Pred')
    ax.set_xlabel('Protein Seq')
    ax.set_ylabel('NA Seq')

    # Either display interactively, or close (so it is not auto-rendered) and hand back the figure.
    if plot:
        plt.show()
        return None
    plt.close(fig)
    return fig

#### Parse and plot metrics from TensorBoard logs

In [13]:
import os
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def extract_scalars_from_event_file(event_file):
    """Load every scalar series from a TensorBoard event file/directory.

    Args:
        event_file: path passed straight to ``EventAccumulator``.

    Returns:
        dict mapping each scalar tag to the list returned by ``EventAccumulator.Scalars``.
        Each entry exposes ``.step`` and ``.value``, which ``plot_all`` reads.
    """
    accumulator = EventAccumulator(event_file)
    accumulator.Reload()  # actually read the events from disk
    return {tag: accumulator.Scalars(tag) for tag in accumulator.Tags()['scalars']}

def _scalar_series(scalars, tag):
    """Split the TensorBoard events stored under ``tag`` into (steps, values) lists."""
    events = scalars[tag]
    return [event.step for event in events], [event.value for event in events]


def _finish_figure(fig, output_dir, filename, save):
    """Save ``fig`` to ``output_dir/filename`` and close it when ``save``; otherwise show it."""
    if save:
        if output_dir:  # '' means "current directory", which already exists
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, filename))
        plt.close(fig)
    else:
        plt.show()


def plot_all(scalars, output_dir, save=False,
             gradient_ylims=((0.2, 0.5), (0.125, 0.2), (0.01, 0.02))):
    """Plot metric, loss and gradient curves parsed from TensorBoard scalars.

    Produces three figures: a 2x3 grid of F1 / PR_AUC / MCC (train row, eval row),
    a 1x2 loss figure (Train, Eval), and a 1x3 gradient figure. Tags that are
    missing from ``scalars`` leave their panel empty.

    Args:
        scalars: dict of tag -> events, as returned by ``extract_scalars_from_event_file``.
        output_dir: directory for the PNGs when ``save`` is True (created if missing).
        save: if True write metrics_plot.png / loss_plot.png / gradients_plot.png
            and close the figures; otherwise show them.
        gradient_ylims: one (lower, upper) y-range per gradient panel, in the order
            Binding_Model, Protein_Model, NA_Model. The default keeps the previously
            hard-coded ranges. A ``None`` entry autoscales that panel; passing
            ``None`` autoscales all of them.
    """
    # Plot CL Metrics
    metrics = ['Metrics/F1', 'Metrics/PR_AUC', 'Metrics/MCC']
    phases = ['', 'Eval']  # '' for train phase, 'Eval' for evaluation phase
    fig, axes = plt.subplots(2, 3, figsize=(9, 5), constrained_layout=True)
    for j, metric in enumerate(metrics):
        metric_name = metric.split('/')[-1]
        for i, phase in enumerate(phases):
            tag = f"{metric}/{phase}" if phase else metric
            if tag not in scalars:
                continue
            steps, values = _scalar_series(scalars, tag)
            ax = axes[i, j]
            ax.plot(steps, values, label=tag)
            ax.set_xlabel('Epochs')
            ax.set_ylabel(metric_name)
            ax.set_title(f'{metric_name} (Eval)' if phase else f'{metric_name} (Train)')
            ax.legend()
    _finish_figure(fig, output_dir, "metrics_plot.png", save)

    # Plot Loss
    fig, axes = plt.subplots(1, 2, figsize=(9, 5), constrained_layout=True)
    for i, phase in enumerate(['Train', 'Eval']):
        tag = f"Loss/{phase}"
        if tag not in scalars:
            continue
        steps, values = _scalar_series(scalars, tag)
        ax = axes[i]
        ax.plot(steps, values, label=tag)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.set_title(f'Loss ({phase})')
        ax.legend()
    _finish_figure(fig, output_dir, "loss_plot.png", save)

    # Plot Gradients
    gradients = ['Gradients/Binding_Model', 'Gradients/Protein_Model', 'Gradients/NA_Model']
    fig, axes = plt.subplots(1, 3, figsize=(9, 5), constrained_layout=True)
    for i, gradient in enumerate(gradients):
        if gradient not in scalars:
            continue
        steps, values = _scalar_series(scalars, gradient)
        # Filter out epoch 0 (step 0)
        kept = [(step, value) for step, value in zip(steps, values) if step > 0]
        ax = axes[i]
        ax.plot([step for step, _ in kept], [value for _, value in kept], label=gradient)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Gradient')
        ax.set_title(f'{gradient.split("/")[-1]}')
        ylim = gradient_ylims[i] if gradient_ylims is not None else None
        if ylim is not None:
            ax.set_ylim(ylim)
        ax.legend()
    _finish_figure(fig, output_dir, "gradients_plot.png", save)



In [17]:
event_file_path = 'path'  # TODO: replace with the TensorBoard event file / log directory
scalars = extract_scalars_from_event_file(event_file_path)
plot_all(scalars, output_dir='path', save=False)