# In [None]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re

# ======== LOAD AND PARSE ROSETTA JSON RESULTS ========

def parse_rosetta_results(json_filename):
    """Load a Rosetta results JSON file and summarise binding energies per job.

    The top-level JSON object maps job names of the form
    ``"<nanobody>-<antigen>"`` to a list of per-run result dicts. The key is
    split at the FIRST ``-`` only. Antigen names may therefore contain
    hyphens, but nanobody names may not.

    Args:
        json_filename: Path to the JSON results file (read as UTF-8).

    Returns:
        A DataFrame with one row per job and the columns ``job_name``,
        ``nanobody``, ``antigen``, ``mean_binding_energy`` and
        ``std_binding_energy``. ``std_binding_energy`` is the population
        standard deviation (``np.std`` default, ddof=0).

        The following jobs are skipped:
        - jobs whose key contains no ``-``;
        - jobs whose value is not a list;
        - jobs with no usable ``binding_energy`` values.

        The columns are present even when no job qualifies, so downstream
        column access on an empty result does not fail.
    """
    columns = ['job_name', 'nanobody', 'antigen',
               'mean_binding_energy', 'std_binding_energy']

    with open(json_filename, 'r', encoding='utf-8') as f:
        data = json.load(f)

    results = []
    for key, entries in data.items():
        if '-' not in key:
            continue
        # Iterating a non-list value (dict, number, string) would yield
        # keys or characters rather than result records.
        if not isinstance(entries, list):
            continue
        nb, ag = key.split('-', 1)

        # Only dict entries can carry a 'binding_energy' field. A string entry
        # would pass an `in` test as a substring match and then fail on
        # indexing. None values are ignored instead of breaking np.mean.
        binding_energies = [
            float(entry['binding_energy'])
            for entry in entries
            if isinstance(entry, dict) and entry.get('binding_energy') is not None
        ]
        if not binding_energies:
            continue

        results.append({
            'job_name': key,
            'nanobody': nb,
            'antigen': ag,
            'mean_binding_energy': np.mean(binding_energies),
            'std_binding_energy': np.std(binding_energies),
        })

    return pd.DataFrame(results, columns=columns)

# ======== CREATE HEATMAP ========

def create_binding_matrix_visualization_rosetta(df, score_column='mean_binding_energy', title_suffix='Rosetta Binding Energy'):
    if len(df) == 0:
        print("No data to plot")
        return None
        
    # sort the names alphabetically (case insensitive)
    nanobodies = sorted(df['nanobody'].unique(), key=str.lower)
    antigens = sorted(df['antigen'].unique(), key=str.lower)
    
    # make pivot table
    matrix = df.pivot(index='nanobody', columns='antigen', values=score_column)
    matrix = matrix.reindex(index=nanobodies, columns=antigens)
    
    std_matrix = df.pivot(index='nanobody', columns='antigen', values='std_binding_energy')
    std_matrix = std_matrix.reindex(index=nanobodies, columns=antigens)
    
    # Create annotation labels with score ± std
    annot_labels = matrix.round(1).astype(str) + '\n±' + std_matrix.round(1).astype(str)
    
    plt.figure(figsize=(14, 12))
    cmap = 'viridis'
    
    ax = sns.heatmap(
        matrix,
        annot=annot_labels,
        fmt='',
        cmap=cmap,
        square=True,
        linewidths=0.5,
        cbar_kws={'label': f'{title_suffix}'},
        annot_kws={'size': 18}
    )
    
    cbar = ax.collections[0].colorbar
    cbar.set_label(f'{title_suffix}', size=20)
    
    # highlight the diagonal - these should be the real binding pairs
    for i in range(min(len(nanobodies), len(antigens))):
        ax.add_patch(plt.Rectangle((i, i), 1, 1, fill=False, edgecolor='red', lw=5))
    
    plt.title(f'Nanobody-Antigen Binding Matrix ({title_suffix})', fontsize=22, pad=20)
    plt.xlabel('Antigens', fontsize=22)
    plt.ylabel('Nanobodies', fontsize=22)
    plt.xticks(rotation=45, ha='right', fontsize=18)
    plt.yticks(rotation=0, fontsize=18)
    plt.tight_layout()
    
    plt.savefig("binding_matrix_rosetta.png", dpi=300)
    plt.show()
    
    return matrix

# ======== ANALYZE RESULTS ========

def analyze_rosetta_results(df):
    """Print a plain-text report of per-pair Rosetta binding energies.

    *df* is expected to have the columns ``nanobody``, ``antigen``,
    ``mean_binding_energy`` and ``std_binding_energy`` (as produced by
    ``parse_rosetta_results``).

    The report lists, in order:
    - the row count and the number of distinct nanobodies and antigens;
    - the five pairs with the lowest mean binding energy;
    - any name-matched "expected" pairs;
    - the overall mean, sample std and range of the mean energies.

    A pair is "expected" when the nanobody name, with a leading ``nb`` (any
    case) removed, equals the antigen name up to its first ``_``. The
    comparison is case-insensitive.

    Returns None; all output goes to stdout.
    """
    def describe(record):
        # "  <nanobody> + <antigen>: <mean> ± <std>", energies to 1 decimal.
        pair_txt = " + ".join((record['nanobody'], record['antigen']))
        mean_txt = str(round(record['mean_binding_energy'], 1))
        std_txt = str(round(record['std_binding_energy'], 1))
        return "  " + pair_txt + ": " + mean_txt + " ± " + std_txt

    def names_match(record):
        nb_core = re.sub(r'^nb', '', record['nanobody'], flags=re.IGNORECASE)
        ag_core = record['antigen'].split('_')[0]
        return nb_core.upper() == ag_core.upper()

    print("=== Rosetta Results Analysis ===")
    print()

    if len(df) == 0:
        print("ERROR: No data found - check file format")
        return

    print(f"Total experiments: {len(df)}")
    print(f"Number of nanobodies: {df['nanobody'].nunique()}")
    print(f"Number of antigens: {df['antigen'].nunique()}")
    print()

    print("Top 5 binding pairs (lowest mean Rosetta binding energy):")
    for _, record in df.nsmallest(5, 'mean_binding_energy').iterrows():
        print(describe(record))
    print()

    matches = [record for _, record in df.iterrows() if names_match(record)]
    if matches:
        print("Expected binding pairs (diagonal matches):")
        for record in matches:
            print(describe(record))
        print()

    energies = df['mean_binding_energy']
    print("Summary statistics:")
    print("  Mean Binding Energy: {:.2f} ± {:.2f}".format(energies.mean(), energies.std()))
    print("  Energy Range:", round(energies.min(), 1), "to", round(energies.max(), 1))
    print()

# ======== MAIN EXECUTION ========

if __name__ == "__main__":
    # Placeholder path: point this at the Rosetta results JSON to analyse.
    filename = "filename.json"  # Replace with actual JSON file
    df = parse_rosetta_results(filename)

    if len(df) == 0:
        print("No data found - check file format")
    else:
        # Plot and show the heatmap, print the report, then export both the
        # per-pair table and the score matrix as CSV.
        matrix = create_binding_matrix_visualization_rosetta(df)
        analyze_rosetta_results(df)
        df.to_csv('rosetta_results.csv', index=False)
        # The plotting helper returns None when it has nothing to draw.
        if matrix is not None:
            matrix.to_csv('binding_matrix_rosetta.csv')

    print("Done!")
