In [None]:
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import numpy as np
import torch
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from MGraphDTA.regression.preprocessing import GNNDatasetFull, GNNDataset
from xai_dta.utils.models import load_model
from xai_dta.config import PROJ_ROOT
from src.utils.plot_utils import get_styled_figure_ax, style_legend

# Render inline matplotlib figures as SVG.
# NOTE(review): the `%config` magic and the `set_matplotlib_formats('svg')`
# call below both select SVG output; either one alone should suffice.
%config InlineBackend.figure_format = 'svg'


from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')

In [None]:
# Reload the local plotting helpers so edits to src/utils/plot_utils.py are
# picked up without restarting the kernel, then re-bind the imported names
# (names bound by the earlier `from ... import` still point at the old
# function objects until they are imported again).
import importlib
import src.utils.plot_utils
importlib.reload(src.utils.plot_utils)
from src.utils.plot_utils import get_styled_figure_ax, style_legend

In [None]:
# --- Run configuration -------------------------------------------------------
device = 'cpu'
dataset_name = 'kiba'
split = 'test'  # selects data_<split>.csv and the matching GNNDataset split

# Parent directory of the notebook (not used by the paths below, which are
# built from PROJ_ROOT).
root = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Checkpoint to load (KIBA run; epoch 2578 according to the file name).
checkpoint_name = 'kiba/epoch-2578, loss-0.0122, cindex-0.9741, test_loss-0.1265.pt'
model = load_model(PROJ_ROOT, model_name=checkpoint_name)

# Graph dataset for the chosen split. `transform_unique=True` appears to enable
# `dataset.transform_unique(smi, seq)`, which the embedding loop calls — confirm.
dataset_path = os.path.join(PROJ_ROOT, "MGraphDTA", "regression", "data", dataset_name)
is_train_split = split == 'train'
dataset = GNNDataset(root=dataset_path, train=is_train_split, transform_unique=True)

# Number of samples to embed: the whole loaded split (the name is historical —
# it holds the size of whichever split was loaded, not only 'train').
max_sample_train = len(dataset)

In [None]:
# Raw (SMILES, sequence, affinity) rows for the split.
# NOTE(review): row positions are assumed to match the order of `dataset`
# (the loop below indexes `data_raw` with positions drawn from len(dataset)).
data_raw = pd.read_csv(os.path.join(dataset_path, 'raw', f"data_{split}.csv"))

# Result buffers, one row per sample.
EMB_DIM = 192  # protein + ligand embeddings concatenated (96 each, per the split analysis)
embeddings = np.zeros((max_sample_train, EMB_DIM))
labels = np.zeros((max_sample_train, 2))   # columns: [prediction, ground truth]
preds = np.zeros((max_sample_train, 1))    # NOTE(review): not filled anywhere in this notebook section
columns = ['x', 'y', 'pred', 'gt', 'compound_iso_smiles', 'target_sequence']

# Seeded so the sample order — and therefore every saved artefact — is
# reproducible across kernel restarts.
SEED = 0
rng = np.random.default_rng(SEED)
indices = rng.choice(len(dataset), max_sample_train, replace=False)

In [None]:
# Display the loaded model (its repr) for a quick look at the architecture.
model

In [None]:
# Peek at the first rows of the raw split CSV.
data_raw.head()

In [None]:
# Embed every sampled (ligand, protein) pair and record [prediction, label].
for i, idx in enumerate(tqdm(indices)):
    # `indices` are row positions (drawn from range(len(dataset))), so use
    # positional lookup consistently for inputs and the affinity label.
    row = data_raw.iloc[idx]
    smi, seq = row['compound_iso_smiles'], row['target_sequence']
    data, _ = dataset.transform_unique(smi, seq)

    with torch.no_grad():
        out, protein_emb, ligand_emb = model.forward_features(data)
        # each: [[...]]  (a batch of one)

    # Concatenate the two modality embeddings -> [1, 192]; keep the single row.
    embeddings[i] = torch.hstack((protein_emb, ligand_emb)).cpu().numpy()[0]
    # Store full-precision floats. Formatting the values as 2-decimal strings
    # and letting NumPy parse them back into the float array silently rounded
    # both the prediction and the label.
    labels[i] = (out.item(), row['affinity'])

In [None]:
# Sanity check: first few keys of the dataset's cached graph dictionary
# (presumably SMILES / sequence identifiers — confirm).
list(dataset.graph_dict.keys())[:10]

In [None]:
# Sanity check: all three buffers should have one row per sampled index.
embeddings.shape, labels.shape, indices.shape

In [None]:
# Output location for the cached embedding matrix. The file name encodes the
# split, dataset, sample count and checkpoint epoch (2578), so embeddings of
# the test split are no longer saved under a "train" name.
filename = f'embeddings_{split}_{dataset_name}_{max_sample_train}_2578.npy'
folder = 'results/latent_space'
os.makedirs(folder, exist_ok=True)

In [None]:
# Cache the (N, 192) embedding matrix so later cells can reload it from disk
# instead of re-running the model loop.
np.save(os.path.join(folder, filename), embeddings)

In [None]:
# File names for the [pred, gt] labels and the sampled row indices. They encode
# the actual split and dataset (previously hard-coded as "train"/"kiba" even
# when `split == 'test'`), plus sample count and checkpoint epoch.
filenamelabel = f'labels_{split}_{dataset_name}_{max_sample_train}_2578.npy'
filenameindices = f'indices_{split}_{dataset_name}_{max_sample_train}_2578.npy'
folder = 'results/latent_space'

In [None]:
# Persist the [pred, gt] labels and the sampled row indices next to the
# embeddings, so the three arrays can be reloaded together.
np.save(os.path.join(folder, filenamelabel), labels)
np.save(os.path.join(folder, filenameindices), indices)

In [None]:
# Reload the cached artefacts. `indices` is reloaded as well (it was saved
# alongside), so the rows of `data_raw` selected later stay aligned with the
# cached embeddings and labels. (The labels were previously loaded twice.)
embeddings = np.load(os.path.join(folder, filename))
labels = np.load(os.path.join(folder, filenamelabel))
indices = np.load(os.path.join(folder, filenameindices))

In [None]:
# Histograms of the predicted (column 0) and true (column 1) affinities.
for label_col, hist_title in ((0, 'Distribution of Predicted affinity'),
                              (1, 'Distribution of True affinity')):
    sns.histplot(labels[:, label_col], bins=50)
    plt.title(hist_title)
    plt.xlabel('Affinity')
    plt.ylabel('Count')
    plt.show()
    plt.close()

In [None]:
# 2-D t-SNE of the concatenated protein+ligand embeddings (seeded for repeatability).
tsne_params = {
    'n_components': 2,
    'perplexity': 30,
    'init': "random",
    'max_iter': 250,
    'random_state': 0,
}
t_sne = TSNE(**tsne_params)
pts_t_sne = t_sne.fit_transform(embeddings)

In [None]:
# Per-sample t-SNE table: coordinates, [pred, gt] and the raw inputs.
# Built from a dict rather than `np.column_stack` so that
#   * numeric columns keep a float dtype (column_stack with strings yields an
#     object array), and
#   * each column is named explicitly — the positional stack placed the target
#     sequence under 'compound_iso_smiles' and the SMILES under
#     'target_sequence', and that mislabelled table was written to CSV.
df = pd.DataFrame({
    'x': pts_t_sne[:, 0],
    'y': pts_t_sne[:, 1],
    'pred': labels[:, 0],
    'gt': labels[:, 1],
    'compound_iso_smiles': data_raw.loc[indices, 'compound_iso_smiles'].to_numpy(),
    'target_sequence': data_raw.loc[indices, 'target_sequence'].to_numpy(),
}, columns=columns)
print(columns)

In [None]:
# Persist the t-SNE table (coordinates, labels and inputs) for use outside the notebook.
df.to_csv(os.path.join(folder, 'tsne_embeddings.csv'), index=False)

In [None]:
# t-SNE map of the joint embedding, coloured by predicted affinity.
fig, ax = plt.subplots()
sns.scatterplot(data=df, x='x', y='y', hue='pred', ax=ax)
plt.show()
plt.close(fig)

In [None]:
# Linear 2-D projection of the concatenated embeddings, for comparison with t-SNE.
pca = PCA(n_components=2)
pca_result = pca.fit_transform(embeddings)

In [None]:
# PCA scatter coloured by predicted affinity; the title reports the share of
# variance captured by the two components.
explained_share = sum(pca.explained_variance_ratio_)
fig, ax = plt.subplots()
ax.scatter(pca_result[:, 0], pca_result[:, 1], c=labels[:, 0], cmap='viridis')
ax.set_title(f"PCA - var explained: {explained_share:.2f}")
plt.show()
plt.close(fig)

In [None]:
import umap.umap_ as umap_learn

# Non-linear 2-D projection of the concatenated embedding space.
print("Running UMAP on the 192D interaction space...")
umap_params = {'n_components': 2, 'random_state': 42, 'n_neighbors': 15, 'min_dist': 0.1}
reducer = umap_learn.UMAP(**umap_params)
embedding_2d = reducer.fit_transform(embeddings)
print("UMAP complete.")

In [None]:

# UMAP projection coloured by ligand SMILES (one hue level per compound).
plot_df = pd.DataFrame(embedding_2d, columns=['UMAP 1', 'UMAP 2'])
plot_df['label'] = data_raw.loc[indices, 'compound_iso_smiles'].to_numpy()

print("Generating plot...")
plt.figure(figsize=(12, 8))
# With many distinct SMILES the legend is unreadable, so it is disabled at the
# source via `legend=False`; `plt.legend('')` passed a string where a list of
# labels is expected, which is not a supported way to hide a legend.
sns.scatterplot(
    data=plot_df,
    x='UMAP 1',
    y='UMAP 2',
    hue='label',
    s=5,
    alpha=0.6,
    legend=False,
)

plt.title('UMAP', fontsize=16)
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.show()

In [None]:

# UMAP projection coloured by ground-truth affinity (column 1 of `labels`).
plot_df = pd.DataFrame(embedding_2d, columns=['UMAP 1', 'UMAP 2'])
plot_df['label'] = labels[:, 1]  # gt

print("Generating plot...")
plt.figure(figsize=(12, 8))
# `legend=False` replaces `plt.legend('')`, which passed a string where a list
# of labels is expected. A continuous hue is better explained by a colorbar.
sns.scatterplot(
    data=plot_df,
    x='UMAP 1',
    y='UMAP 2',
    hue='label',
    palette='viridis',
    s=5,
    alpha=0.6,
    legend=False,
)
# Colorbar matching seaborn's min-max normalisation of a numeric hue.
hue_norm = plt.Normalize(plot_df['label'].min(), plot_df['label'].max())
hue_mappable = plt.cm.ScalarMappable(norm=hue_norm, cmap='viridis')
hue_mappable.set_array([])
plt.colorbar(hue_mappable, ax=plt.gca(), label='Ground-truth affinity')

plt.title('UMAP', fontsize=16)
plt.xlabel('UMAP Dimension 1')
plt.ylabel('UMAP Dimension 2')
plt.show()

# Modality-specific embeddings

In [None]:
print(max_sample_train, embeddings.shape, embeddings[0].shape)
# Each row of `embeddings` is [protein_emb | ligand_emb] (the hstack order in
# the embedding loop), so the two halves are the per-modality embeddings.
# Slicing the halves is equivalent to the former reshape-to-96 + every-other-row
# trick, but states the layout directly and checks the width.
assert embeddings.shape[1] % 2 == 0, "expected [protein | ligand] rows of even width"
half_dim = embeddings.shape[1] // 2
protein_embeddings = embeddings[:, :half_dim]
ligand_embeddings = embeddings[:, half_dim:]

# Column 0 of `labels` is the model prediction, column 1 the ground truth.
predictions = labels[:, 0].astype(float)
ground_truth = labels[:, 1].astype(float)

ts = data_raw.loc[indices, 'target_sequence'].to_numpy()
cis = data_raw.loc[indices, 'compound_iso_smiles']
# NOTE: the num_* columns are character counts in the lower-cased SMILES, not
# atom counts. Two-letter symbols leak in (e.g. 'Cl' adds to num_C, 'Na' to
# num_N), and implicit hydrogens are never written in SMILES, so num_H only
# sees explicit 'H' characters. Treat them as rough proxies.
# `seq_len` and the num_* Series carry `indices` as their index, so the frame
# is indexed by the original row labels of `data_raw`.
df_prot_lig = pd.DataFrame({
    'protein_emb': list(protein_embeddings),  # each element: one protein embedding row
    'ligand_emb': list(ligand_embeddings),    # each element: one ligand embedding row
    'pred': predictions,
    'gt': ground_truth,
    'target_sequence': ts,
    'compound_iso_smiles': cis.to_numpy(),
    'seq_len': data_raw.loc[indices, 'target_sequence'].str.len(),
    'num_C': cis.apply(lambda s: s.lower().count('c')),
    'num_O': cis.apply(lambda s: s.lower().count('o')),
    'num_N': cis.apply(lambda s: s.lower().count('n')),
    'num_H': cis.apply(lambda s: s.lower().count('h')),
    'num_F': cis.apply(lambda s: s.lower().count('f')),
    'seq': ts
})

print(df_prot_lig.shape)
print(df_prot_lig.head())
print("Unique sequences:", len(df_prot_lig.target_sequence.unique()))
print("Unique smiles:", len(df_prot_lig.compound_iso_smiles.unique()))

In [None]:

# Diversity of the sampled pairs: distinct proteins, distinct ligands, total rows.
n_unique_seqs = len(df_prot_lig.target_sequence.unique())
n_unique_smiles = len(df_prot_lig.compound_iso_smiles.unique())
print("Unique sequences:", n_unique_seqs)
print("Unique smiles:", n_unique_smiles)
print("Total df len:", len(df_prot_lig))

## t-SNE

In [None]:
# Treat protein and ligand embeddings as separate points: the row-major reshape
# of [protein | ligand] rows yields protein_0, ligand_0, protein_1, ligand_1, ...
embeddings_split = embeddings.reshape(max_sample_train * 2, 96)
split_tsne_params = dict(n_components=2, perplexity=30, init="random", max_iter=250, random_state=0)
t_sne = TSNE(**split_tsne_params)
pts_t_sne = t_sne.fit_transform(embeddings_split)

In [None]:
# Long-form table for the per-modality t-SNE. `embeddings_split` interleaves
# rows (protein_0, ligand_0, protein_1, ...), so every per-sample attribute is
# repeated twice to stay aligned with `pts_t_sne`.
labels_repeated = np.repeat(labels, 2, axis=0)           # shape (2*n_samples, 2)
seq_id_repeated = np.repeat(data_raw.loc[indices, 'target_sequence'].to_numpy(), 2)  # shape (2*n_samples,)
smi_id_repeated = np.repeat(data_raw.loc[indices, 'compound_iso_smiles'].to_numpy(), 2)  # shape (2*n_samples,)
# Built from a dict so numeric columns stay float and each column gets the
# right name (the positional column_stack labelled sequences as SMILES and
# SMILES as sequences).
df_split_tsne = pd.DataFrame({
    'x': pts_t_sne[:, 0],
    'y': pts_t_sne[:, 1],
    'pred': labels_repeated[:, 0],
    'gt': labels_repeated[:, 1],
    'compound_iso_smiles': smi_id_repeated,
    'target_sequence': seq_id_repeated,
}, columns=columns)
df_split_tsne["type"] = np.tile(["protein", "ligand"], max_sample_train)

print(df_split_tsne.describe())

# Same map coloured by prediction, ground truth and modality.
for hue_col in ('pred', 'gt', 'type'):
    plt.figure()
    sns.scatterplot(df_split_tsne, x='x', y='y', hue=hue_col)
    plt.show()
    plt.close()

## PCA

In [None]:
# Joint PCA of protein and ligand embeddings. Rows are stacked block-wise:
# the first N rows are proteins and the next N rows are ligands (same sample
# order in each block), so per-sample attributes are tiled twice to align.
stacked_emb = np.vstack([
    np.stack(df_prot_lig['protein_emb'].to_numpy()),
    np.stack(df_prot_lig['ligand_emb'].to_numpy()),
])
pca = PCA(n_components=2)
pts_pca = pca.fit_transform(stacked_emb)

n_rows = len(df_prot_lig)
df_prot_lig_pca = pd.DataFrame({
    'x': pts_pca[:, 0],
    'y': pts_pca[:, 1],
    'pred': np.tile(df_prot_lig['pred'].to_numpy(), 2),
    'gt': np.tile(df_prot_lig['gt'].to_numpy(), 2),
    'type': ['protein_emb'] * n_rows + ['ligand_emb'] * n_rows,
})

# Coloured by prediction, then by modality.
for hue_col in ('pred', 'type'):
    sns.scatterplot(df_prot_lig_pca, x='x', y='y', hue=hue_col)
    plt.show()

### Explaining PCA

In [None]:
# PC1 loadings of the joint protein+ligand PCA: which embedding dimensions push
# points toward either end of the first component?
# NOTE(review): the "x≈±20" in the headers look like PC1 positions from one
# particular run; they are not recomputed here.
n_features = pca.components_.shape[1]  # embedding width; no longer hard-coded to 96
loadings = pd.Series(pca.components_[0], index=[f"f{i}" for i in range(n_features)])
top_pos = loadings.sort_values(ascending=False).head(10)
top_neg = loadings.sort_values(ascending=True).head(10)
print("Top positive features (x≈+20):")
print(top_pos)
print("\nTop negative features (x≈–20):")
print(top_neg)


In [None]:
# Pearson correlation between joint-PCA PC1 and the predicted affinity.
# `pts_pca` rows are block-ordered (all proteins, then all ligands), so the
# aligned predictions are the 'pred' column of `df_prot_lig_pca`. Using
# `labels_repeated` (interleaved protein/ligand order) paired almost every
# point with the wrong sample.
print(np.corrcoef(pts_pca[:, 0].astype(float), df_prot_lig_pca['pred'].to_numpy(dtype=float))[0, 1])


If the correlation is small (≈0), PC1 does not separate samples by predicted affinity; the main axis of variation is most likely structural (protein vs. ligand embeddings).

If the correlation is large, the embeddings encode predicted affinity along that direction.

## Analysis

Plan

- Separate the analyses — treat protein and ligand embeddings independently.

- Run PCA/UMAP/t-SNE per-type to expose meaningful structure inside each modality.

- Inspect PCA loadings to find which embedding dims matter.

- Probe embeddings with simple supervised probes (linear regression/classifier) to see which dimensions predict affinity/pred.

- Use attribution/perturbation on the base model (saliency, integrated gradients, SHAP on downstream probes) to tie embedding dims back to input features (sequence tokens / SMILES substructures).

- Report & visualize: per-sequence and per-compound results, correlation tables, and representative examples.

In [None]:
#@title Split and run pca separately

from sklearn.decomposition import PCA
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Per-modality embedding matrices, one row per sample.
prot = np.stack(df_prot_lig['protein_emb'].to_numpy())
lig = np.stack(df_prot_lig['ligand_emb'].to_numpy())
print(prot.shape, lig.shape)

# Independent 2-component PCA for each modality.
pca_prot = PCA(n_components=2).fit_transform(prot)
pca_lig = PCA(n_components=2).fit_transform(lig)

# Side-by-side scatter plots, both coloured by the model's prediction.
fig, (ax_prot, ax_lig) = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (ax_prot, pca_prot, "Protein PCA (colored by prediction)"),
    (ax_lig, pca_lig, "Ligand PCA (colored by prediction)"),
)
for panel_ax, panel_pts, panel_title in panels:
    sns.scatterplot(x=panel_pts[:, 0], y=panel_pts[:, 1], hue=df_prot_lig['pred'],
                    palette="viridis", ax=panel_ax)
    panel_ax.set_title(panel_title)
plt.show()


In [None]:
# PC1 loadings of a 5-component PCA fitted on the protein embeddings only.
pca = PCA(n_components=5).fit(prot)
protein_pc1 = pca.components_[0]
loadings = pd.Series(protein_pc1, index=[f"p{i}" for i in range(96)])
for header, extreme in (
    ("Top positive contributors to PC1 (protein):\n", loadings.nlargest(10)),
    ("\nTop negative contributors to PC1 (protein):\n", loadings.nsmallest(10)),
):
    print(header, extreme)

In [None]:
# PC1 loadings of a 5-component PCA fitted on the ligand embeddings only.
# Dimensions are labelled "l{i}" (ligand): the "p{i}" prefix copied from the
# protein cell made ligand dimensions read as protein dimensions.
pca = PCA(n_components=5).fit(lig)
n_lig_dims = pca.components_.shape[1]
loadings = pd.Series(pca.components_[0], index=[f"l{i}" for i in range(n_lig_dims)])
print("Top positive contributors to PC1 (ligand):\n", loadings.nlargest(10))
print("\nTop negative contributors to PC1 (ligand):\n", loadings.nsmallest(10))

In [None]:
import scipy.stats as ss

# Does the first protein principal component track prediction / ground truth?
pc1_prot = pca_prot[:, 0]
r_pred_prot, r_gt_prot = (ss.pearsonr(pc1_prot, df_prot_lig[target])[0] for target in ('pred', 'gt'))
print("Protein PC1 corr with pred, gt:", r_pred_prot, r_gt_prot)

In [None]:
import scipy.stats as ss

# Does the first ligand principal component track prediction / ground truth?
pc1_lig = pca_lig[:, 0]
r_pred_lig, r_gt_lig = (ss.pearsonr(pc1_lig, df_prot_lig[target])[0] for target in ('pred', 'gt'))
print("Ligand PC1 corr with pred, gt:", r_pred_lig, r_gt_lig)

In [None]:
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import cross_val_score

# Linear probe: how well can the ground-truth affinity be read out of the
# protein embedding alone? RidgeCV picks the regularisation strength
# internally; the outer 5-fold CV estimates the probe's error.
# NOTE(review): rows sharing a protein presumably share an identical protein
# embedding, so plain k-fold puts the same proteins in train and test folds;
# a grouped split (e.g. GroupKFold on target_sequence) would measure transfer
# to unseen proteins instead.
alpha_grid = np.logspace(-6, 6, 13)
ridge = RidgeCV(alphas=alpha_grid, cv=5)
gt_target = df_prot_lig['gt']
scores = cross_val_score(ridge, prot, gt_target, scoring="neg_mean_squared_error", cv=5)
print("Protein probe MSE (CV):", -scores.mean())

# Refit on all rows and rank embedding dimensions by |coefficient|.
ridge.fit(prot, gt_target)
coef = pd.Series(ridge.coef_, index=[f"p{i}" for i in range(96)])
coef = coef.sort_values(key=abs, ascending=False)
print("Top dims by absolute coefficient:\n", coef.head(10))


# Embedding space

In [None]:
# Lean per-sample table (embeddings, labels, identifiers) for the correlation
# analyses below. This replaces the richer frame built earlier; every column is
# a plain array, so the index is reset to 0..N-1.
sampled_rows = data_raw.loc[indices]
df_prot_lig = pd.DataFrame({
    'protein_emb': list(protein_embeddings),  # one protein embedding row per element
    'ligand_emb': list(ligand_embeddings),    # one ligand embedding row per element
    'pred': predictions,
    'gt': ground_truth,
    'target_sequence': sampled_rows['target_sequence'].to_numpy(),
    'compound_iso_smiles': sampled_rows['compound_iso_smiles'].to_numpy(),
})
df_prot_lig.describe()

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Per-modality embedding matrices, shape (N, D).
prot_emb_arr = np.stack(df_prot_lig['protein_emb'].to_numpy())
lig_emb_arr = np.stack(df_prot_lig['ligand_emb'].to_numpy())


def _dim_target_corr(emb, target):
    """Pearson r of every column of `emb` (N, D) with `target` (N,); returns shape (D,)."""
    # np.corrcoef treats rows as variables: D embedding rows plus one target row.
    # The last column, minus the target's self-correlation, holds the answers —
    # no hard-coded dimension count needed.
    return np.corrcoef(emb.T, target)[:-1, -1]


prot_corr_pred = _dim_target_corr(prot_emb_arr, df_prot_lig['pred'].to_numpy())
prot_corr_gt = _dim_target_corr(prot_emb_arr, df_prot_lig['gt'].to_numpy())
lig_corr_pred = _dim_target_corr(lig_emb_arr, df_prot_lig['pred'].to_numpy())
lig_corr_gt = _dim_target_corr(lig_emb_arr, df_prot_lig['gt'].to_numpy())

# Only the prediction correlations are plotted; the gt versions stay available
# as variables for inspection.
heatmap_df = pd.DataFrame({'Protein': prot_corr_pred, 'Ligand': lig_corr_pred})

# Compact, unannotated version saved to disk.
out_dir = 'results/latent_space'
os.makedirs(out_dir, exist_ok=True)
plt.figure(figsize=(18, 6))
sns.heatmap(heatmap_df.T, cmap='coolwarm', center=0, annot=False)
plt.xlabel("Embedding dimension")
plt.tight_layout()
plt.savefig(os.path.join(out_dir, 'embeddings_correlations.svg'), bbox_inches='tight')
plt.show()

# Wide, annotated version for reading off individual values. The title now
# matches what is shown (prediction only; ground truth is not plotted).
plt.figure(figsize=(50, 6))
sns.heatmap(heatmap_df.T, cmap='coolwarm', center=0, annot=True, fmt=".2f")
plt.title("Embedding dimension correlation with prediction")
plt.xlabel("Embedding dimension")
plt.show()


In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Helper function to calculate correlations for a subgroup
def get_seq_correlations(group_df, emb_col, target_col, n_dims=96):
    """Correlate each embedding dimension with a target inside one group.

    Args:
        group_df: rows of one group (e.g. one target sequence).
        emb_col: column whose cells are 1-D embedding arrays.
        target_col: numeric column to correlate against.
        n_dims: number of leading embedding dimensions to report.

    Returns:
        Series indexed 'dim_0'..'dim_{n_dims-1}' with Pearson r per dimension.
        All values are NaN when fewer than two usable rows exist or the target
        is constant; individual dimensions that are constant come out as NaN.
    """
    dim_names = [f'dim_{k}' for k in range(n_dims)]

    # Rows where both the embedding and the target are present.
    usable = group_df[emb_col].notna() & group_df[target_col].notna()
    if usable.sum() < 2:
        return pd.Series([np.nan] * n_dims, index=dim_names)

    emb_matrix = np.stack(group_df.loc[usable, emb_col].to_numpy())
    targets = group_df.loc[usable, target_col].to_numpy()

    # A constant target has no defined correlation.
    if np.std(targets) == 0:
        return pd.Series([np.nan] * n_dims, index=dim_names)

    corr_matrix = np.corrcoef(emb_matrix.T, targets)
    return pd.Series(corr_matrix[:n_dims, -1], index=dim_names)

# --- 1. Per-sequence correlations ---------------------------------------------
# For each target sequence, correlate every embedding dimension (protein or
# ligand) with the prediction and with the ground truth across that
# sequence's rows.
# NOTE(review): within one sequence the protein embedding is presumably
# identical on every row, so the protein_* tables are expected to be mostly NaN.
print("Calculating correlations for each sequence... This may take a moment.")

rows_by_sequence = df_prot_lig.groupby('target_sequence')


def _corr_table(emb_col, target_col):
    """Apply get_seq_correlations to every sequence group -> (sequences x dims)."""
    return rows_by_sequence.apply(
        get_seq_correlations, emb_col=emb_col, target_col=target_col,
        n_dims=96, include_groups=False,
    )


prot_pred_corr_by_seq = _corr_table('protein_emb', 'pred')
prot_gt_corr_by_seq = _corr_table('protein_emb', 'gt')
lig_pred_corr_by_seq = _corr_table('ligand_emb', 'pred')
lig_gt_corr_by_seq = _corr_table('ligand_emb', 'gt')

print("Calculations complete. Generating heatmaps...")

# --- 2. Plot the new heatmaps with a safer plotting block ---

def plot_correlation_clustermap(corr_df, title):
    """Draw a clustered heatmap of a (sequence x dimension) correlation table.

    Rows and columns that are entirely NaN are dropped first. The plot is
    skipped, with a message, when nothing remains or when fewer than two rows
    or columns remain. Errors raised while clustering or plotting are reported
    and swallowed, so the remaining plots in the cell still run.
    """
    divider = "-" * 30
    cleaned = corr_df.dropna(how='all', axis=0).dropna(how='all', axis=1)

    if cleaned.empty:
        print(f"SKIPPING PLOT: '{title}'")
        print("After cleaning, no valid (non-NaN) data remained to plot.")
        print(divider)
        return

    n_rows, n_cols = cleaned.shape
    if n_rows < 2 or n_cols < 2:
        print(f"SKIPPING PLOT: '{title}'")
        print(f"Not enough data to cluster (shape: {cleaned.shape}). Need at least (2, 2).")
        print(divider)
        return

    print(f"Plotting '{title}' with shape {cleaned.shape}...")
    try:
        grid = sns.clustermap(
            cleaned,
            cmap='coolwarm',
            center=0,
            figsize=(12, 10),
            cbar_kws={'label': 'Correlation'},
            xticklabels=False,
            yticklabels=False,
        )
        grid.fig.suptitle(title, y=1.02)
        plt.show()
    except Exception as err:
        # Best-effort: remaining NaNs can make clustering fail for sparse tables.
        print(f"ERROR plotting '{title}': {err}")
        print("This can happen if data is still sparse. Skipping.")
        plt.close()
    print(divider)


# --- Generate plots: one clustermap per (modality, target) table ---
clustermap_jobs = (
    (prot_pred_corr_by_seq, 'Protein Embedding vs. Prediction (Clustered by Sequence)'),
    (prot_gt_corr_by_seq, 'Protein Embedding vs. Ground Truth (Clustered by Sequence)'),
    (lig_pred_corr_by_seq, 'Ligand Embedding vs. Prediction (Clustered by Sequence)'),
    (lig_gt_corr_by_seq, 'Ligand Embedding vs. Ground Truth (Clustered by Sequence)'),
)
for corr_table, map_title in clustermap_jobs:
    plot_correlation_clustermap(corr_table, map_title)

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# How much does each ligand dimension's correlation with ground truth vary
# across sequences? A large spread means the dimension behaves differently
# depending on the protein context.
has_corr_table = 'lig_gt_corr_by_seq' in locals() and not lig_gt_corr_by_seq.empty

if not has_corr_table:
    print("Could not find the 'lig_gt_corr_by_seq' DataFrame to analyze.")
else:
    # 1. Column-wise std across sequences, most variable dimension first.
    gt_variability = lig_gt_corr_by_seq.std(axis=0)
    gt_variability_sorted = gt_variability.sort_values(ascending=False)

    print("--- Top 10 Most 'Context-Dependent' Ligand Dimensions (vs. Ground Truth) ---")
    print(gt_variability_sorted.head(10))
    print("\n(Dimension index is on the left, Std. Dev. of correlation is on the right)")

    # 2. Bar chart of every dimension's variability (bars nearly touch, histogram-like).
    fig, ax = plt.subplots(figsize=(18, 6))
    gt_variability_sorted.plot(kind='bar', color='mediumpurple', width=0.8,
                               edgecolor='black', ax=ax)
    ax.set_title('Variability of Ligand Embedding Dimensions (vs. Ground Truth)', fontsize=16)
    ax.set_ylabel('Standard Deviation of Correlation', fontsize=12)
    ax.set_xlabel('Embedding Dimension (Sorted by Variability)', fontsize=12)
    ax.set_xticks([])
    ax.text(0.98, 0.95,
            'High Bar = High Variability\n(Dim is context-dependent)',
            ha='right', va='top', transform=ax.transAxes,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

if ('gt_variability_sorted' not in locals() or
    'lig_gt_corr_by_seq' not in locals() or
    'df_prot_lig' not in locals()):

    print("Error: Required DataFrames ('gt_variability_sorted', 'lig_gt_corr_by_seq', 'df_prot_lig') not found.")
    print("Please re-run the previous code cells.")
else:
    # 1. The ligand dimension whose ground-truth correlation varies most across sequences.
    top_dim_id = gt_variability_sorted.index[0]
    top_dim_index = int(top_dim_id.split('_')[1])  # 'dim_<k>' -> k

    print(f"--- Analyzing Top Variable Dimension: {top_dim_id} (index {top_dim_index}) ---")

    # Minimum number of pairs for a per-sequence correlation to be trusted.
    MIN_SAMPLES_FOR_STATS = 10

    # 2. Keep sequences with enough pairs AND a defined correlation for this dimension.
    sequence_counts = df_prot_lig['target_sequence'].value_counts()
    valid_sequences = sequence_counts[sequence_counts >= MIN_SAMPLES_FOR_STATS].index

    print(f"Found {len(valid_sequences)} sequences (out of {len(sequence_counts)}) with at least {MIN_SAMPLES_FOR_STATS} samples.")

    corrs_for_top_dim = lig_gt_corr_by_seq[top_dim_id].dropna()
    corrs_valid = corrs_for_top_dim.loc[corrs_for_top_dim.index.isin(valid_sequences)]

    if corrs_valid.empty:
        print("\n*** ERROR ***")
        print(f"No sequences found with >= {MIN_SAMPLES_FOR_STATS} samples that also have valid correlation data.")
        print("Try lowering 'MIN_SAMPLES_FOR_STATS' (e.g., to 5) and re-running.")
    else:
        # 3. Representative sequences: strongest positive, strongest negative
        #    and closest-to-zero correlation.
        seq_pos_corr = corrs_valid.idxmax()
        val_pos_corr = corrs_valid.max()
        print(f"Positive sequence: {seq_pos_corr[:20]}... (Corr: {val_pos_corr:.2f})")

        seq_neg_corr = corrs_valid.idxmin()
        val_neg_corr = corrs_valid.min()
        print(f"Negative sequence: {seq_neg_corr[:20]}... (Corr: {val_neg_corr:.2f})")

        seq_zero_corr = corrs_valid.abs().idxmin()
        val_zero_corr = corrs_valid[seq_zero_corr]
        print(f"Neutral sequence:  {seq_zero_corr[:20]}... (Corr: {val_zero_corr:.2f})")

        selected_sequences = [seq_pos_corr, seq_neg_corr, seq_zero_corr]

        # 4. Legend label per selected sequence. With fewer than three valid
        #    sequences the picks can coincide; the first role is kept instead of
        #    a later role silently overwriting it in the mapping, and the overlap
        #    is reported.
        role_labels = (
            (seq_pos_corr, f'Positive Corr (r={val_pos_corr:.2f}, N={sequence_counts[seq_pos_corr]})'),
            (seq_neg_corr, f'Negative Corr (r={val_neg_corr:.2f}, N={sequence_counts[seq_neg_corr]})'),
            (seq_zero_corr, f'Neutral Corr (r={val_zero_corr:.2f}, N={sequence_counts[seq_zero_corr]})'),
        )
        type_map = {}
        for role_seq, role_label in role_labels:
            type_map.setdefault(role_seq, role_label)
        if len(type_map) < len(selected_sequences):
            print(f"Note: only {len(type_map)} distinct sequence(s) selected; overlapping roles were merged.")

        plot_df = df_prot_lig[df_prot_lig['target_sequence'].isin(selected_sequences)].copy()
        # Value of the chosen ligand dimension on every pair of the selected sequences.
        plot_df[top_dim_id] = plot_df['ligand_emb'].apply(lambda x: x[top_dim_index] if isinstance(x, np.ndarray) else np.nan)
        plot_df['Correlation Type'] = plot_df['target_sequence'].map(type_map)

        # 5. Scatter with one linear fit per selected sequence.
        print("\nGenerating new, filtered plot...")

        g = sns.lmplot(
            data=plot_df,
            x=top_dim_id,
            y='gt',
            hue='Correlation Type',
            height=6,
            aspect=1.5,
            palette='Set1',
            scatter_kws={'alpha': 0.6, 's': 50},
            legend_out=False
        )

        g.set_axis_labels(f'Value of Ligand Embedding Dimension {top_dim_index}', 'Ground Truth Value')
        plt.title(f'Context-Dependent Behavior of Dimension {top_dim_index} (Filtered)', fontsize=16)
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.show()

The plan:

- **Cluster the sequences:** use `AgglomerativeClustering` to group sequences into *k* "families" based on their per-dimension correlation profiles.
- **Calculate average rules:** compute the mean correlation vector for each family.
- **Plot the averages:** show the family means as a heatmap.

In [None]:
from sklearn.cluster import AgglomerativeClustering
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


if ('lig_pred_corr_by_seq' not in locals() or
    'df_prot_lig' not in locals()):
    print("Error: Required DataFrames not found. Please re-run earlier code.")
else:

    # 1. Only sequences with enough pairs give trustworthy correlation profiles.
    MIN_SAMPLES_FOR_CLUSTERING = 10
    sequence_counts = df_prot_lig['target_sequence'].value_counts()
    valid_sequences = sequence_counts[sequence_counts >= MIN_SAMPLES_FOR_CLUSTERING].index

    print(f"Total sequences before filtering: {len(lig_pred_corr_by_seq)}")
    print(f"Found {len(valid_sequences)} sequences with at least {MIN_SAMPLES_FOR_CLUSTERING} samples.")

    # 2. Restrict the per-sequence (ligand embedding vs. prediction) table.
    valid_corr_data = lig_pred_corr_by_seq.loc[lig_pred_corr_by_seq.index.isin(valid_sequences)]

    # 3. Drop sequences with no defined correlation at all; remaining NaNs
    #    (e.g. dimensions constant within a sequence) are treated as r = 0.
    data_to_cluster_clean = valid_corr_data.dropna(how='all', axis=0).fillna(0)
    print(f"Clustering {len(data_to_cluster_clean)} valid sequences.")

    # 4-5. Agglomerative clustering into N_CLUSTERS "families".
    #      Named `clusterer` / `cluster_ids` on purpose: reusing `model` and
    #      `labels` overwrote the DTA model and the [pred, gt] array from
    #      earlier cells, breaking any re-run of those cells.
    N_CLUSTERS = 3
    clusterer = AgglomerativeClustering(n_clusters=N_CLUSTERS)
    cluster_ids = clusterer.fit_predict(data_to_cluster_clean)

    # 6. Family id per sequence.
    cluster_labels = pd.Series(cluster_ids, index=data_to_cluster_clean.index, name='cluster')

    # 7. Family sizes.
    print("\n--- Cluster (Family) Sizes ---")
    print(cluster_labels.value_counts().sort_index())

    # 8. Mean correlation profile ("average rule") per family.
    data_with_labels_clean = data_to_cluster_clean.join(cluster_labels)
    cluster_means_clean = data_with_labels_clean.groupby('cluster').mean()

    # 9. Print the family means in full width.
    print("\n--- Average Correlation 'Rule' for each family (DataFrame) ---")
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', 1000)
    print(cluster_means_clean.round(2))

    # 10. Heatmaps: an annotated version saved to disk and a compact overview.
    #     The output folder is created here; it is not created anywhere else.
    families_dir = 'results/latent_space_families'
    os.makedirs(families_dir, exist_ok=True)
    print("\n--- Plotting the 'Average Rule' for each family ---")
    plt.figure(figsize=(40, 3))
    sns.heatmap(
        cluster_means_clean,
        cmap='coolwarm',
        center=0,
        annot=True,
        fmt=".2f",
        annot_kws={"size": 5}
    )
    plt.xlabel('Embedding Dimension')
    plt.ylabel('Cluster')
    plt.savefig(os.path.join(families_dir, 'average_correlation_rule_heatmap.svg'), bbox_inches='tight')
    plt.show()

    plt.figure(figsize=(15, 3))  # only a few families, so a short figure suffices
    sns.heatmap(
        cluster_means_clean,
        cmap='coolwarm',
        center=0,
        annot=False,
    )
    plt.title(f'Average Correlation "Rule" for each of the {N_CLUSTERS} Families (Filtered)', fontsize=16)
    plt.xlabel('Embedding Dimension')
    plt.ylabel('Cluster (Family) ID')
    plt.show()

Other runs:

    2 Clusters:
        cluster
    0    167
    1     26

    3 Clusters: 
    cluster
    0    160
    1     26
    2      7

    4 Clusters: 
    cluster
    0      7
    1      3
    2    157
    3     26

    5 Clusters: 
    cluster
    0    157
    1      3
    2      4
    3     26
    4      3

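There is a consistent pattern: two large families (~157–167 and 26 sequences) persist at every *k*, and additional clusters only split off a handful of sequences (3–7 each). Going beyond 3 clusters therefore over-segments the data.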

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Bar colours for the cluster/family comparison plot.
DATASET_COLORS = ['#95BB63', '#BCBCE0', '#77b5b6', '#EA805D']

def format_label_family_name(label):
    """Map a zero-based cluster id (int or int-like string) to 'Cluster <id + 1>'."""
    family_number = int(label) + 1
    return 'Cluster ' + str(family_number)

if 'cluster_means_clean' in locals():
    # 1. Dimensions singled out for the family comparison.
    #    NOTE(review): hand-picked list; update it if the clustering changes.
    key_dimensions = [
        f'dim_{d}' for d in [0, 11, 19, 20, 23, 31, 35, 39, 43, 60, 74, 77]
    ]

    # 2. Family means for just those dimensions.
    key_dim_data = cluster_means_clean[key_dimensions]

    # 3. Long format: one row per (cluster, dimension) mean correlation.
    key_dim_data_long = key_dim_data.reset_index().melt(
        id_vars='cluster',
        var_name='Dimension',
        value_name='Correlation'
    )

    # 4. Grouped bars: one group per dimension, one bar per family. All drawing
    #    targets the returned axes explicitly instead of "the current axes".
    print("--- Plotting a comparison of the key 'rule' dimensions ---")
    fig, ax = get_styled_figure_ax(figsize=(18, 7), aspect='none')
    sns.barplot(
        data=key_dim_data_long,
        x='Dimension',
        y='Correlation',
        hue='cluster',
        palette=DATASET_COLORS,
        ax=ax,
    )

    ax.set_xlabel('Embedding Dimension', fontsize=12)
    ax.set_ylabel('Average Correlation', fontsize=12)

    # Zero reference line.
    ax.axhline(0, color='black', linestyle='--', linewidth=0.8)

    style_legend(ax, ncol=3, bbox_to_anchor=(0.5, 1.1), format_labels=format_label_family_name)

    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # The output folder is created here so the save works on a fresh checkout.
    families_dir = 'results/latent_space_families'
    os.makedirs(families_dir, exist_ok=True)
    out_path = os.path.join(families_dir, 'key_dimensions_comparison.svg')
    fig.savefig(out_path, dpi=300, bbox_inches='tight')

    # Report the file that was actually written (it is an .svg, not a .png).
    print(f"Generated plot '{out_path}'")

else:
    print("Could not find 'cluster_means_clean' to plot. Please run the clustering code.")

### The Test

We can test that this is not random by simulating the "random" hypothesis:

1.  **Calculate our "Observed" metric:** We'll take our `cluster_means_clean` table (the 3 "rules") and reduce it to a single number that represents how different they are. For each embedding dimension, take the standard deviation of the family means across the families, then average these values over all dimensions: `cluster_means_clean.std().mean()`. A high value means the rules are very different.
2.  **Run a Simulation (e.g., 1000 times):**
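    a. Take the same "valid" sequences that were clustered (the rows of `data_to_cluster_clean`). There were 166 when this text was written; the count depends on the run, e.g. 193 in the 3-cluster run listed above.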
    b. **Shuffle the labels.** Instead of the real families, we'll randomly assign sequences to "fake" families of the same size.
    c. Calculate the "average rules" for these **fake, random** families.
    d. Calculate the metric (the variability) for these fake rules.
3.  **Compare:** We'll end up with 1 "Observed" value and 1000 "Random" values.
      * **If our analysis is real:** Our "Observed" value will be a massive outlier. The 1000 "Random" values will be tiny.
      * **If our analysis is random:** Our "Observed" value will be lost in the middle of the 1000 "Random" values.

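The **p-value** will be the fraction of "Random" values that were *at least as large as* our "Observed" value (the code uses `>=`). A p-value of `0.0` means none of the 1000 shuffles reached the observed separation (p < 0.001), i.e. such separated rules are very unlikely under random grouping. Caveat: the families were found by clustering this same data, so shuffled labels are expected to separate less well. A small p-value therefore shows that the clustering found structure, but it is not independent validation of that structure.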


### Output

We will get two things:

1.  **A p-value** 
2.  **A Histogram**


In [None]:
# Permutation test: is the separation between the cluster "rules" larger than what
# random groupings of the same sizes would produce?
# numpy (np), pandas (pd), seaborn (sns), matplotlib.pyplot (plt) and os are
# imported in the first cell of the notebook.

N_PERMUTATIONS = 1000
PERMUTATION_SEED = 42  # fixed so the null distribution and the p-value are reproducible


def variability_metric(cluster_means):
    """Summarise how different the per-cluster mean profiles are.

    Takes the standard deviation of every column across the rows (one row per
    cluster) and averages those standard deviations into a single number.
    Higher values mean the clusters' average profiles differ more.

    Args:
        cluster_means: DataFrame with one row per cluster and one numeric column
            per feature (e.g. the output of ``data.groupby(labels).mean()``).

    Returns:
        The mean over columns of the column-wise standard deviation, as a float.
    """
    return float(cluster_means.std(axis=0).mean())


def permutation_test(data, labels, observed_metric,
                     n_permutations=N_PERMUTATIONS, seed=PERMUTATION_SEED):
    """Permutation test for the between-cluster variability of ``data``.

    Randomly shuffles ``labels`` (which keeps every cluster's size unchanged),
    recomputes the per-cluster means of ``data`` and their ``variability_metric``,
    and compares that null distribution with ``observed_metric``.

    Args:
        data: DataFrame with one row per item; every column is averaged per cluster.
        labels: cluster label of each row of ``data`` (array-like or Series, matched
            to ``data`` by position).
        observed_metric: the statistic computed with the real labels.
        n_permutations: number of random relabelings to draw (>= 1).
        seed: seed passed to ``numpy.random.default_rng``.

    Returns:
        ``(p_value, random_metrics, n_exceeded)``: ``random_metrics`` is a 1-D array
        of length ``n_permutations``; ``n_exceeded`` counts random metrics that are
        >= ``observed_metric``; ``p_value = (n_exceeded + 1) / (n_permutations + 1)``.

    Raises:
        ValueError: if ``labels`` and ``data`` differ in length, or ``n_permutations < 1``.
    """
    label_values = np.asarray(labels).reshape(-1)
    if len(label_values) != len(data):
        raise ValueError(
            f"labels has {len(label_values)} entries but data has {len(data)} rows"
        )
    if n_permutations < 1:
        raise ValueError("n_permutations must be >= 1")

    rng = np.random.default_rng(seed)
    # Group by an external shuffled array instead of writing it into `data`,
    # so every column of `data` is averaged and `data` itself is never mutated.
    random_metrics = np.array([
        variability_metric(data.groupby(rng.permutation(label_values)).mean())
        for _ in range(n_permutations)
    ])

    n_exceeded = int(np.sum(random_metrics >= observed_metric))
    # Add-one estimator: the observed labelling is itself one possible permutation,
    # so the p-value is never reported as exactly 0.
    p_value = (n_exceeded + 1) / (n_permutations + 1)
    return p_value, random_metrics, n_exceeded


if ('data_to_cluster_clean' not in locals() or 
    'cluster_means_clean' not in locals() or
    'cluster_labels' not in locals()):
    
    print("Error: Key data (e.g., 'data_to_cluster_clean') is missing.")
    print("This likely means the session restarted. Please re-run the clustering code first.")

else:
    print("--- Running Permutation Test for Statistical Significance ---")

    # 1. The ONE "Observed" metric, from the real cluster assignment.
    observed_metric = variability_metric(cluster_means_clean)
    print(f"Observed Metric (Variability): {observed_metric:.4f}")

    # 2. Null distribution from shuffled labels + p-value.
    p_value, random_metrics, n_exceeded = permutation_test(
        data_to_cluster_clean,
        cluster_labels,
        observed_metric,
        n_permutations=N_PERMUTATIONS,
        seed=PERMUTATION_SEED,
    )
    print(f"Simulation complete. (Ran {N_PERMUTATIONS} permutations)")

    print("\n--- Results ---")
    print(f"P-value: {p_value:.4g} ({n_exceeded} of {N_PERMUTATIONS} random permutations >= observed)")
    print("(Probability, under random relabelling, of a variability at least as large as the observed one)")

    # 3. Plot the null distribution with the observed value marked.
    fig, ax = get_styled_figure_ax(figsize=(15, 8), aspect='none', grid=True)
    sns.histplot(
        random_metrics,
        label='Random Permutations (Null Hypothesis)',
        bins=30,
        kde=True,
        color=DATASET_COLORS[2],
        ax=ax,
    )
    ax.axvline(
        observed_metric,
        color=DATASET_COLORS[-1],
        linewidth=3,
        linestyle='--',
        label=f'Observed Value: {observed_metric:.4f}',
    )
    ax.set_xlabel('Variability Metric (Higher = More Different)')
    ax.set_ylabel('Frequency')
    style_legend(ax, ncol=3, bbox_to_anchor=(0.5, 1.1))
    os.makedirs('results/latent_space_families', exist_ok=True)
    plt.savefig('results/latent_space_families/permutation_test.svg')
    plt.show()

# Feature space of proteins

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import umap.umap_ as umap_learn

NOT_IN_TEST_LABEL = 'Not in Test'


def format_cluster_label(label):
    """Return a legend-friendly string for one cluster id.

    Missing ids (NaN, i.e. proteins without a cluster) become 'Not in Test'.
    ``DataFrame.join`` introduces NaN for such proteins, which upcasts integer
    cluster ids to float; whole-number floats are therefore printed as integers
    ('1' rather than '1.0'). Any other value is passed through ``str``.
    """
    if pd.isna(label):
        return NOT_IN_TEST_LABEL
    if isinstance(label, (float, np.floating)) and float(label).is_integer():
        return str(int(label))
    return str(label)


# --- Check if we have the necessary data ---
if ('df_prot_lig' not in locals() or 
    'cluster_labels' not in locals()):
    
    print("Error: Key data ('df_prot_lig' or 'cluster_labels') is missing.")
    print("This likely means the session restarted.")
    print("Please re-run the analysis from the beginning (creating df_prot_lig) ")
    print("and the *filtered* clustering step (creating cluster_labels).")
else:

    # --- 1. Unique protein embeddings, one row per target sequence ---
    unique_prots_df = df_prot_lig.drop_duplicates(
        subset=['target_sequence']
    ).set_index('target_sequence')

    print(f"Found {len(unique_prots_df)} unique protein sequences.")

    # --- 2. UMAP to 2D ---
    # Stack the per-row embedding arrays into one (n_proteins, emb_dim) matrix.
    prot_emb_array = np.stack(unique_prots_df['protein_emb'].values)

    reducer = umap_learn.UMAP(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)

    print(f"Running UMAP to reduce {prot_emb_array.shape[1]}D protein embeddings to 2D...")
    embedding_2d = reducer.fit_transform(prot_emb_array)
    print("UMAP complete.")

    # --- 3. Plotting frame, indexed by sequence so it can be joined with the labels ---
    plot_df = pd.DataFrame(
        embedding_2d,
        columns=['UMAP 1', 'UMAP 2'],
        index=unique_prots_df.index,
    )

    # --- 4. Attach the 'Family' (cluster) ids ---
    # join() needs a named Series; force the column name used below.
    cluster_info = (
        cluster_labels.rename('cluster')
        if isinstance(cluster_labels, pd.Series)
        else cluster_labels
    )
    plot_df = plot_df.join(cluster_info)

    # Sequences without a cluster (not enough data to test) get 'Not in Test';
    # ids are converted to clean strings for categorical colouring.
    plot_df['cluster'] = plot_df['cluster'].map(format_cluster_label)

    # --- 5. Plot ---
    print("Generating plot...")
    fig, ax = get_styled_figure_ax(figsize=(15, 8), aspect='none', grid=True)
    sns.scatterplot(
        data=plot_df,
        x='UMAP 1',
        y='UMAP 2',
        hue='cluster',  # colour by "family"
        palette='Set1',
        s=50,
        alpha=0.8,
        ax=ax,
    )
    ax.set_xlabel('UMAP Dimension 1')
    ax.set_ylabel('UMAP Dimension 2')
    style_legend(ax, ncol=4, bbox_to_anchor=(0.5, 1.1))
    os.makedirs('results/latent_space_families', exist_ok=True)
    plt.savefig('results/latent_space_families/protein_emb_umap_col-rulefamily.svg')
    plt.show()
    plt.close()

# t-SNE

In [None]:
# Prediction-weighted average of the per-sample protein x ligand outer products:
#     interaction_matrix[j, k] = (1/N) * sum_i pred_i * prot_emb_arr[i, j] * lig_emb_arr[i, k]
# Computed as one matrix product P^T diag(pred) L rather than a Python loop over the
# N samples; the matrix shape follows the embedding widths instead of a hard-coded 96.
prot_mat = np.asarray(prot_emb_arr, dtype=float)
lig_mat = np.asarray(lig_emb_arr, dtype=float)
pred_weights = df_prot_lig['pred'].to_numpy(dtype=float)
assert prot_mat.shape[0] == lig_mat.shape[0] == len(pred_weights), \
    "prot_emb_arr, lig_emb_arr and df_prot_lig must have one row per sample"

# Divide by the number of samples: this is a mean, not a rescaling to a fixed range.
interaction_matrix = (prot_mat * pred_weights[:, None]).T @ lig_mat / len(df_prot_lig)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(interaction_matrix, cmap='coolwarm', center=0, ax=ax)
ax.set_title("Protein × Ligand embedding interaction heatmap (weighted by prediction)")
ax.set_xlabel("Ligand embedding dim")
ax.set_ylabel("Protein embedding dim")
plt.show()


In [None]:
# Same interaction matrix, but seaborn's robust=True derives the colour range from
# quantiles (2nd / 98th percentile) rather than the absolute min / max, so a few
# extreme cells do not wash out the rest of the map.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    interaction_matrix,
    cmap='coolwarm',
    center=0,
    robust=True,
    ax=ax,
)
ax.set_title("Protein × Ligand interaction heatmap (robust scaling)")
ax.set_xlabel("Ligand embedding dim")
ax.set_ylabel("Protein embedding dim")
plt.show()

In [None]:
# One row per sample: protein embedding columns first, ligand embedding columns after.
sample_heatmap = np.hstack([prot_emb_arr, lig_emb_arr])  # shape (N, 192)

fig, ax = plt.subplots()
sns.heatmap(sample_heatmap, cmap='coolwarm', center=0, ax=ax)
ax.set_title("All samples embedding heatmap")
ax.set_xlabel("Embedding dimensions (protein + ligand)")
ax.set_ylabel("Samples")
plt.show()


> It seems that only a few embedding dimensions are informative.

In [None]:
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Select the 10 protein dims with the largest |correlation with the prediction|.
# NaN correlations (e.g. from zero-variance dims) would be sorted *last* by
# np.argsort and thus picked as "top" dims, so they are ranked lowest instead.
prot_abs_corr = np.abs(np.asarray(prot_corr_pred, dtype=float))
prot_abs_corr = np.where(np.isnan(prot_abs_corr), -np.inf, prot_abs_corr)
top_dims = np.argsort(prot_abs_corr)[-10:]
prot_selected = prot_emb_arr[:, top_dims]

# Append the prediction itself as an extra feature column.
prot_tsne_input = np.hstack([prot_selected, df_prot_lig['pred'].to_numpy()[:, None]])

tsne = TSNE(n_components=2, perplexity=30, random_state=42)
pts_tsne = tsne.fit_transform(prot_tsne_input)

plt.figure(figsize=(6, 5))
sns.scatterplot(x=pts_tsne[:, 0], y=pts_tsne[:, 1], hue=df_prot_lig['pred'], palette='viridis')
plt.title("t-SNE: Protein embeddings + prediction")
plt.show()


In [None]:
# Same protein t-SNE points, coloured by target sequence. There are too many distinct
# sequences for a readable legend, so the legend is switched off (legend=False).
# 'target_sequence' is used directly: the 'seq' copy of it is only added to
# df_prot_lig in a later cell.
fig, ax = plt.subplots(figsize=(6, 5))
sns.scatterplot(
    x=pts_tsne[:, 0],
    y=pts_tsne[:, 1],
    hue=df_prot_lig['target_sequence'],
    palette='coolwarm',
    legend=False,
    ax=ax,
)
ax.set_title("t-SNE: Protein embeddings + prediction, colored by target sequence")
plt.show()

In [None]:
# Ligand counterpart of the protein t-SNE: the 10 ligand dims with the largest
# |correlation with the prediction|, plus the prediction as an extra feature.
# NaN correlations (e.g. from zero-variance dims) would be sorted *last* by
# np.argsort and thus picked as "top" dims, so they are ranked lowest instead.
lig_abs_corr = np.abs(np.asarray(lig_corr_pred, dtype=float))
lig_abs_corr = np.where(np.isnan(lig_abs_corr), -np.inf, lig_abs_corr)
top_dims_lig = np.argsort(lig_abs_corr)[-10:]
lig_selected = lig_emb_arr[:, top_dims_lig]
lig_tsne_input = np.hstack([lig_selected, df_prot_lig['pred'].to_numpy()[:, None]])

pts_tsne_lig = TSNE(n_components=2, perplexity=30, random_state=42).fit_transform(lig_tsne_input)

sns.scatterplot(x=pts_tsne_lig[:, 0], y=pts_tsne_lig[:, 1], hue=df_prot_lig['pred'], palette='viridis')
plt.title("t-SNE: Ligand embeddings + prediction")
plt.show()

In [None]:
# Same ligand t-SNE, coloured by a rough carbon count taken from the SMILES string:
# aliphatic 'C' plus aromatic 'c', without the 'C' of chlorine ('Cl'). This is a
# character-count heuristic, not a true atom count. It is computed here so the cell
# does not rely on a 'num_C' column that df_prot_lig only gains in a later cell.
lig_smiles = df_prot_lig['compound_iso_smiles']
num_C_lig = (
    lig_smiles.str.count('C') - lig_smiles.str.count('Cl') + lig_smiles.str.count('c')
).rename('num_C')

sns.scatterplot(x=pts_tsne_lig[:, 0], y=pts_tsne_lig[:, 1], hue=num_C_lig, palette='viridis')
plt.title("t-SNE: Ligand embeddings + prediction, colored by carbon count")
plt.show()

In [None]:
# Joint t-SNE input: selected protein dims | selected ligand dims | prediction column.
pred_column = df_prot_lig['pred'].to_numpy()[:, None]
combined_top_dims = np.hstack([prot_selected, lig_selected, pred_column])

tsne_comb = TSNE(n_components=2, perplexity=30, random_state=42)
pts_tsne_comb = tsne_comb.fit_transform(combined_top_dims)

In [None]:
# The combined t-SNE embedding, coloured first by the model prediction and then by
# the ground truth; the title states which one is shown.
for hue_col, hue_desc in [('pred', 'prediction'), ('gt', 'ground truth')]:
    sns.scatterplot(x=pts_tsne_comb[:, 0], y=pts_tsne_comb[:, 1], hue=df_prot_lig[hue_col], palette='viridis')
    plt.title(f"t-SNE: Top protein+ligand dims + prediction (color: {hue_desc})")
    plt.show()

In [None]:
# Combine the joint t-SNE coordinates with simple per-sample descriptors for colouring.
# Columns taken from df_prot_lig are converted to arrays so they line up with the
# t-SNE rows by position, whatever df_prot_lig's index is.
tsne_pts = pts_tsne_comb   # shape (N,2)
smiles = df_prot_lig['compound_iso_smiles']
smiles_lower = smiles.str.lower()

df_tsne = pd.DataFrame({
    'x': tsne_pts[:, 0],
    'y': tsne_pts[:, 1],
    'pred': df_prot_lig['pred'].to_numpy(),
    'gt': df_prot_lig['gt'].to_numpy(),
    'seq_len': df_prot_lig['target_sequence'].str.len().to_numpy(),
    # Character-count heuristics on the SMILES string, not true atom counts.
    # Carbon = aliphatic 'C' + aromatic 'c', excluding the 'C' of chlorine ('Cl').
    'num_C': (smiles.str.count('C') - smiles.str.count('Cl') + smiles.str.count('c')).to_numpy(),
    'num_O': smiles_lower.str.count('o').to_numpy(),
    'num_N': smiles_lower.str.count('n').to_numpy(),
    # NOTE(review): SMILES usually leaves hydrogens implicit, so this mostly counts
    # explicitly written H (e.g. inside brackets) — treat it as a rough proxy only.
    'num_H': smiles_lower.str.count('h').to_numpy(),
    'num_F': smiles_lower.str.count('f').to_numpy(),
    'seq': df_prot_lig['target_sequence'].to_numpy(),
})

plt.figure(figsize=(6, 5))
sns.scatterplot(data=df_tsne, x='x', y='y', hue='pred', palette='viridis', size='seq_len', sizes=(20, 200))
plt.title("t-SNE colored by prediction, size = sequence length")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.show()

for a in 'CON':
    plt.figure(figsize=(6, 5))
    sns.scatterplot(data=df_tsne, x='x', y='y', hue=f'num_{a}', palette='coolwarm')
    plt.title(f"t-SNE colored by number of {a} in ligand")
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.show()

# Too many distinct sequences for a readable legend: hide it entirely.
plt.figure(figsize=(6, 5))
sns.scatterplot(data=df_tsne, x='x', y='y', hue='seq', palette='coolwarm', legend=False)
plt.title("t-SNE colored by sequence")
plt.show()


In [None]:
# 2-component PCA of the same joint features used for the combined t-SNE
# (selected protein dims | selected ligand dims | prediction), shown with three colourings.
combined_top_dims = np.hstack([prot_selected, lig_selected, df_prot_lig['pred'].to_numpy()[:, None]])
pca = PCA(n_components=2)
pts_pca_comb = pca.fit_transform(combined_top_dims)

# Sequence length is derived here so this cell does not depend on df_prot_lig already
# having a 'seq_len' column (that column is added in a later cell).
pca_colorings = [
    (df_prot_lig['pred'], 'viridis', 'prediction'),
    (df_prot_lig['gt'], 'viridis', 'ground truth'),
    (df_prot_lig['target_sequence'].str.len().rename('seq_len'), 'coolwarm', 'sequence length'),
]
for hue_values, palette, hue_desc in pca_colorings:
    sns.scatterplot(x=pts_pca_comb[:, 0], y=pts_pca_comb[:, 1], hue=hue_values, palette=palette)
    plt.title(f"PCA: Top protein+ligand dims + prediction (color: {hue_desc})")
    plt.show()

In [None]:
from pandas import Series
import numpy as np

pca = PCA(n_components=2)
pts_pca_comb = pca.fit_transform(combined_top_dims)

# Loadings: one row per principal component, one column per input feature.
loadings = pca.components_  # shape (2, n_features)
n_prot = prot_selected.shape[1]
n_lig  = lig_selected.shape[1]

# Name every feature after the embedding dimension it came from (e.g. "P37" is protein
# dim 37), not after its position inside the selection, so the names can be traced back.
assert len(top_dims) == n_prot and len(top_dims_lig) == n_lig, \
    "top_dims / top_dims_lig do not match prot_selected / lig_selected"
feature_names = [f"P{d}" for d in top_dims] + [f"L{d}" for d in top_dims_lig] + ['pred']
assert len(feature_names) == loadings.shape[1]

pc1_loadings = Series(loadings[0], index=feature_names)
pc2_loadings = Series(loadings[1], index=feature_names)

print("Explained variance ratio (PC1, PC2):", np.round(pca.explained_variance_ratio_, 3))
print("Top positive PC1 dims:")
print(pc1_loadings.nlargest(10))
print("\nTop negative PC1 dims:")
print(pc1_loadings.nsmallest(10))


In [None]:
# Correlate the two principal components with simple per-sample descriptors.
pc1 = pts_pca_comb[:, 0]
pc2 = pts_pca_comb[:, 1]

# Rebuild the per-sample frame from the raw arrays. Derived columns are computed from
# this *new* frame, not from a previously defined df_prot_lig, so the cell works on a
# fresh kernel and gives the same result when re-run.
df_prot_lig = pd.DataFrame({
    'protein_emb': list(protein_embeddings),  # Each element is a 96-dim array
    'ligand_emb': list(ligand_embeddings),    # Each element is a 96-dim array
    'pred': predictions,
    'gt': ground_truth,
    'target_sequence': data_raw.loc[indices, 'target_sequence'].to_numpy(),
    'compound_iso_smiles': data_raw.loc[indices, 'compound_iso_smiles'].to_numpy(),
})

smiles = df_prot_lig['compound_iso_smiles']
smiles_lower = smiles.str.lower()
df_prot_lig = df_prot_lig.assign(
    seq_len=df_prot_lig['target_sequence'].str.len(),
    # Character-count heuristics on the SMILES string, not true atom counts.
    # Carbon = aliphatic 'C' + aromatic 'c', excluding the 'C' of chlorine ('Cl').
    num_C=smiles.str.count('C') - smiles.str.count('Cl') + smiles.str.count('c'),
    num_O=smiles_lower.str.count('o'),
    num_N=smiles_lower.str.count('n'),
    num_H=smiles_lower.str.count('h'),
    num_F=smiles_lower.str.count('f'),
    seq=df_prot_lig['target_sequence'],
)

assert len(df_prot_lig) == len(pc1), \
    "PCA points and df_prot_lig rows are not aligned (different sample counts)"

for feature in ['pred', 'gt', 'seq_len', 'num_C', 'num_O', 'num_N']:
    corr_x = np.corrcoef(pc1, df_prot_lig[feature])[0, 1]
    corr_y = np.corrcoef(pc2, df_prot_lig[feature])[0, 1]
    print(f"{feature}: corr with PC1={corr_x:.2f}, PC2={corr_y:.2f}")


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# One bar chart per principal component: the loading of every input feature.
for pc_name, pc_series in (("PC1", pc1_loadings), ("PC2", pc2_loadings)):
    plt.figure(figsize=(12, 4))
    sns.barplot(x=pc_series.index, y=pc_series.values)
    plt.xticks(rotation=90)
    plt.title(f"PCA {pc_name} loadings (contribution of each dimension)")
    plt.show()