# Latent space before and after training

This notebook visualizes the model's embedding space (via PCA / t-SNE) before and after training. Only the PCA sections are implemented below; the t-SNE section is still empty.

# init

In [None]:
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import numpy as np
import torch
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from MGraphDTA.regression.preprocessing import GNNDatasetFull, GNNDataset
from xai_dta.utils.models import load_model, load_untrained_model
from xai_dta.config import PROJ_ROOT

# Render inline notebook figures as SVG (vector) output.
%config InlineBackend.figure_format = 'svg'


# NOTE(review): appears redundant with the %config line above — both select SVG
# output for inline figures; either one alone should suffice.
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')

from src.utils.plot_utils import set_plot_style, get_styled_figure_ax, style_legend, adjust_plot_limits
# Categorical palette shared by every plot in this notebook (5 colours, so at most
# 5 distinct hue levels get unique colours).
DATASET_COLORS = ['#95BB63', '#7E7EA5', '#487C7D', '#EA805D', '#4897DC']

In [None]:
# Pick up edits to src.utils.plot_utils without restarting the kernel: reload the
# module, then re-bind the names imported from it (a plain re-import would keep the
# old function objects).
import importlib
import src.utils.plot_utils
importlib.reload(src.utils.plot_utils)
from src.utils.plot_utils import set_plot_style, get_styled_figure_ax, style_legend, adjust_plot_limits

In [None]:
# --- Experiment configuration -------------------------------------------------
# NOTE(review): `device`, `root` and `model_trained` are not used by any cell shown in
# this notebook (trained embeddings are loaded from disk further down).
device = 'cpu'
dataset_name = 'kiba'
split = 'test'

root = os.path.abspath(os.path.join(os.getcwd(), '..'))

# --- Models ---------------------------------------------------------------------
model_trained = load_model(
    PROJ_ROOT,
    model_name='kiba/epoch-2578, loss-0.0122, cindex-0.9741, test_loss-0.1265.pt',
)
model_untrained = load_untrained_model()

# --- Data for the selected split --------------------------------------------------
dataset_path = os.path.join(PROJ_ROOT, "MGraphDTA", "regression", "data", dataset_name)
dataset = GNNDataset(root=dataset_path, train=(split == 'train'), transform_unique=True)

# Despite its name, this is the number of samples in `split` (currently the test split).
max_sample_train = len(dataset)
data_raw = pd.read_csv(os.path.join(dataset_path, 'raw', f"data_{split}.csv"))

In [None]:
import warnings

# Load the trained model's cached embeddings, labels and data_raw row indices.
# NOTE(review): these files are named "train"/"kiba" while the config cell sets
# split = 'test' (so `data_raw` is the test CSV and `max_sample_train` is the size of
# that split). `indices` are used below as row labels of `data_raw`, so the cache and
# the CSV must come from the same split; the checks after loading flag a mismatch.
folder = 'results/latent_space'
filename = f'embeddings_train_kiba_{max_sample_train}.npy'
filenamelabel = f'labels_train_kiba_{max_sample_train}.npy'
filenameindices = f'indices_train_kiba_{max_sample_train}.npy'
embeddings = np.load(os.path.join(folder, filename))
labels = np.load(os.path.join(folder, filenamelabel))
indices = np.load(os.path.join(folder, filenameindices))
print(f"Loaded {embeddings.shape} embeddings; {labels.shape} labels")

# Later cells split each pair's 192 values into protein/ligand halves, read one label
# row per index, and size per-pair columns with `max_sample_train`.
n_idx = len(indices)
if embeddings.size != n_idx * 192 or len(labels) != n_idx or n_idx != max_sample_train:
    raise ValueError(
        f"Inconsistent cache: embeddings {embeddings.shape}, labels {labels.shape}, "
        f"{n_idx} indices, max_sample_train={max_sample_train}"
    )
if n_idx and (indices.min() < 0 or indices.max() >= len(data_raw)):
    raise ValueError(
        f"indices span [{indices.min()}, {indices.max()}] but data_raw has "
        f"{len(data_raw)} rows; were the embeddings computed on another split?"
    )

# Column 1 of the labels is read as ground truth later; it should match the affinity
# of the referenced data_raw rows (presumably rounded to 2 decimals like the
# untrained labels — TODO confirm), otherwise `indices` point at the wrong rows.
gt_cached = labels[:, 1].astype(float)
gt_csv = data_raw.loc[indices, 'affinity'].to_numpy(dtype=float)
if not np.allclose(gt_cached, gt_csv, atol=0.01):
    warnings.warn(
        "Cached ground-truth labels do not match data_raw['affinity'] at `indices`; "
        "the embeddings may come from a different split than data_raw."
    )

In [None]:


# Embed the same (ligand, protein) pairs as the trained cache with the *untrained*
# model, then cache the result next to the trained-model files.
filename = f'embeddings_train_kiba_{max_sample_train}_untrained.npy'
filenamelabel = f'labels_train_kiba_{max_sample_train}_untrained.npy'

# Inference mode so any dropout / batch-norm layers behave deterministically
# (a no-op if load_untrained_model() already did this).
model_untrained.eval()

# One row per cached pair; sizing by len(indices) avoids silently saving trailing
# all-zero rows if it ever differs from max_sample_train.
n_pairs = len(indices)
embeddings_untrained = np.zeros((n_pairs, 192))
# Column 0: model prediction, column 1: ground-truth affinity, both rounded to 2 decimals.
labels_untrained = np.zeros((n_pairs, 2))

for i, idx in enumerate(tqdm(indices)):
    # `indices` are row labels of data_raw; use .loc for both lookups so they refer to
    # the same row even if data_raw does not have a default RangeIndex.
    smi, seq = data_raw.loc[idx, ['compound_iso_smiles', 'target_sequence']]
    data, _ = dataset.transform_unique(smi, seq)

    with torch.no_grad():
        out, protein_emb, ligand_emb = model_untrained.forward_features(data)
        # each: [[...]] (single-row batch)

    # Concatenate as [protein | ligand] and take the single row.
    embeddings_untrained[i] = torch.hstack((protein_emb, ligand_emb)).cpu().numpy()[0]
    # round(x, 2) yields the same float as the float(f"{x:.2f}") round-trip.
    labels_untrained[i] = (
        round(out.item(), 2),
        round(float(data_raw.loc[idx, 'affinity']), 2),
    )

np.save(os.path.join(folder, filename), embeddings_untrained)
np.save(os.path.join(folder, filenamelabel), labels_untrained)
print(f"Created {embeddings_untrained.shape} untrained embeddings; {labels_untrained.shape} untrained labels")

In [None]:
# Reload the cached untrained-model embeddings and labels (lets the slow embedding
# cell above be skipped on re-runs).
filename = f'embeddings_train_kiba_{max_sample_train}_untrained.npy'
filenamelabel = f'labels_train_kiba_{max_sample_train}_untrained.npy'
embeddings_untrained, labels_untrained = (
    np.load(os.path.join(folder, cached)) for cached in (filename, filenamelabel)
)


In [None]:
def _build_prot_lig_frame(emb, lab):
    """Return one row per (ligand, protein) pair listed in the global `indices`.

    emb: per-pair embeddings laid out row-major as [protein_emb | ligand_emb]
         (the torch.hstack order used for the untrained cache); each pair's vector is
         split into two equal halves.
    lab: (N, 2) labels; column 0 is the prediction, column 1 the ground truth. They are
         cast to float because the stored labels may be strings (TODO confirm for the
         trained cache).
    Rows follow `indices`, which are row labels of the global `data_raw`; the frame's
    index is those labels.
    """
    n = len(indices)
    # Row-major reshape, like the former reshape(N * 2, 96) + even/odd row split.
    emb = np.asarray(emb).reshape(n, -1)
    lab = np.asarray(lab)
    if len(lab) != n:
        raise ValueError(f"Length mismatch: {len(lab)} label rows vs {n} indices")
    half = emb.shape[1] // 2

    rows = data_raw.loc[indices]
    seqs = rows['target_sequence'].to_numpy()
    smiles = rows['compound_iso_smiles']
    smiles_lower = smiles.str.lower()

    # Plain arrays everywhere so no column is realigned on a (possibly duplicated) index.
    columns = {
        'protein_emb': list(emb[:, :half]),  # each element is a half-width (96-dim) array
        'ligand_emb': list(emb[:, half:]),
        'pred': lab[:, 0].astype(float),
        'gt': lab[:, 1].astype(float),
        'target_sequence': seqs,
        'compound_iso_smiles': smiles.to_numpy(),
        'seq_len': rows['target_sequence'].str.len().to_numpy(),
    }
    # Rough, case-insensitive letter counts in the SMILES string, not true atom counts
    # (e.g. the 'C' of 'Cl' also counts towards num_C).
    for atom in ('C', 'O', 'N', 'H', 'F'):
        columns[f'num_{atom}'] = smiles_lower.str.count(atom.lower()).to_numpy()
    columns['seq'] = seqs
    return pd.DataFrame(columns, index=rows.index)


df_prot_lig_untrained = _build_prot_lig_frame(embeddings_untrained, labels_untrained)
print(df_prot_lig_untrained.shape)
print("Unique sequences:", len(df_prot_lig_untrained.target_sequence.unique()))
print("Unique smiles:", len(df_prot_lig_untrained.compound_iso_smiles.unique()))

df_prot_lig = _build_prot_lig_frame(embeddings, labels)
print(df_prot_lig.shape)
print("Unique sequences:", len(df_prot_lig.target_sequence.unique()))
print("Unique smiles:", len(df_prot_lig.compound_iso_smiles.unique()))

# Joint protein + ligand embeddings (the two 96-dim halves of each 192-dim vector, pooled into one PCA)

## t-SNE

## PCA

### Untrained

In [None]:
# Fit one PCA on protein and ligand embeddings pooled together (2N rows of 96 dims)
# so both modalities are projected into the same 2-D space.
untrained_stack = (
    df_prot_lig_untrained['protein_emb'].to_list()
    + df_prot_lig_untrained['ligand_emb'].to_list()
)
pca = PCA(n_components=2)
pts_pca_untrained = pca.fit_transform(untrained_stack)

# First N projected rows are proteins, last N ligands; pred/gt repeat accordingly.
untrained_preds = df_prot_lig_untrained['pred'].to_list()
untrained_gts = df_prot_lig_untrained['gt'].to_list()
df_prot_lig_pca_untrained = pd.DataFrame(pts_pca_untrained, columns=['x', 'y']).assign(
    pred=untrained_preds * 2,
    gt=untrained_gts * 2,
    type=['protein_emb'] * max_sample_train + ['ligand_emb'] * max_sample_train,
)

In [None]:
# Quintile codes (0-4) of prediction and ground truth for categorical colouring;
# duplicates='drop' can yield fewer than 5 bins when many values repeat.
for quant_col in ('pred', 'gt'):
    df_prot_lig_pca_untrained[f'{quant_col}_bin'] = pd.qcut(
        df_prot_lig_pca_untrained[quant_col], q=5, labels=False, duplicates='drop'
    )

In [None]:
# Scatter the untrained model's joint PCA projection coloured by prediction quintile,
# ground-truth quintile and embedding type.
# Predictions are coloured by 'pred_bin' (as in the trained-model cell): the raw 'pred'
# column has many distinct values that a 5-colour list palette cannot represent.
for hue, tag, ncol in (('pred_bin', 'pred', 5), ('gt_bin', 'gt', 5), ('type', 'type', 2)):
    fig, ax = get_styled_figure_ax(aspect='none', grid=False)
    # One colour per level actually present (qcut may drop duplicate bin edges).
    n_levels = df_prot_lig_pca_untrained[hue].nunique()
    sns.scatterplot(
        df_prot_lig_pca_untrained, x='x', y='y', hue=hue,
        palette=DATASET_COLORS[:n_levels], ax=ax,
    )
    ax.set_title(f"PCA on untrained model's embeddings [{tag}]")
    style_legend(ax, ncol=ncol, bbox_to_anchor=(0.5, 1.15))
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    plt.show()

### Trained

In [None]:
# Fit one PCA on the trained model's protein and ligand embeddings pooled together
# (2N rows of 96 dims) so both modalities share the same 2-D projection.
trained_stack = df_prot_lig['protein_emb'].to_list() + df_prot_lig['ligand_emb'].to_list()
pca = PCA(n_components=2)
pts_pca = pca.fit_transform(trained_stack)

# First N projected rows are proteins, last N ligands; pred/gt repeat accordingly.
trained_preds = df_prot_lig['pred'].to_list()
trained_gts = df_prot_lig['gt'].to_list()
df_prot_lig_pca = pd.DataFrame(pts_pca, columns=['x', 'y']).assign(
    pred=trained_preds * 2,
    gt=trained_gts * 2,
    type=['protein_emb'] * max_sample_train + ['ligand_emb'] * max_sample_train,
)

In [None]:
# Quintile codes (0-4) of prediction and ground truth for categorical colouring;
# duplicates='drop' can yield fewer than 5 bins when many values repeat.
for quant_col in ('pred', 'gt'):
    df_prot_lig_pca[f'{quant_col}_bin'] = pd.qcut(
        df_prot_lig_pca[quant_col], q=5, labels=False, duplicates='drop'
    )

In [None]:
# Scatter the trained model's joint PCA projection coloured by prediction quintile,
# ground-truth quintile and embedding type.
for hue, tag, ncol in (('pred_bin', 'pred', 5), ('gt_bin', 'gt', 5), ('type', 'type', 2)):
    fig, ax = get_styled_figure_ax(aspect='none', grid=False)
    # Size the list palette to the levels present: a 5-colour list for the 2 'type'
    # levels (or fewer than 5 quintiles) makes seaborn warn or raise.
    n_levels = df_prot_lig_pca[hue].nunique()
    sns.scatterplot(
        df_prot_lig_pca, x='x', y='y', hue=hue,
        palette=DATASET_COLORS[:n_levels], ax=ax,
    )
    ax.set_title(f"PCA on trained model's embeddings [{tag}]")
    style_legend(ax, ncol=ncol, bbox_to_anchor=(0.5, 1.15))
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    plt.show()

# Separate 96-dim embeddings (protein-only and ligand-only PCA)

## Proteins

### Untrained

In [None]:
# PCA on the untrained model's protein embeddings alone (N rows of 96 dims).
protein_rows_untrained = df_prot_lig_untrained['protein_emb'].to_list()
pca = PCA(n_components=2)
pts_pca_untrained = pca.fit_transform(protein_rows_untrained)
df_prot_lig_pca_untrained = pd.DataFrame(pts_pca_untrained, columns=['x', 'y']).assign(
    pred=df_prot_lig_untrained['pred'].to_list(),
    gt=df_prot_lig_untrained['gt'].to_list(),
    type=['protein_emb'] * max_sample_train,
)

In [None]:
# Quintile codes (0-4) of prediction and ground truth; duplicates='drop' can yield
# fewer than 5 bins when many values repeat.
for quant_col in ('pred', 'gt'):
    df_prot_lig_pca_untrained[f'{quant_col}_bin'] = pd.qcut(
        df_prot_lig_pca_untrained[quant_col], q=5, labels=False, duplicates='drop'
    )

In [None]:
# Save the untrained model's protein-only PCA projection coloured by prediction
# quintile, ground-truth quintile and embedding type.
# Predictions are coloured by 'pred_bin' rather than the continuous 'pred' column,
# which a 5-colour list palette cannot represent.
for hue, tag, ncol in (('pred_bin', 'pred', 5), ('gt_bin', 'gt', 5), ('type', 'type', 2)):
    fig, ax = get_styled_figure_ax(aspect='none', grid=False)
    # One colour per level actually present (only 1 'type' level here).
    n_levels = df_prot_lig_pca_untrained[hue].nunique()
    sns.scatterplot(
        df_prot_lig_pca_untrained, x='x', y='y', hue=hue,
        palette=DATASET_COLORS[:n_levels], ax=ax,
    )
    ax.set_title(f"PCA on untrained model's protein embeddings [{tag}]")
    style_legend(ax, ncol=ncol, bbox_to_anchor=(0.5, 1.15))
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    # Tight bbox so the legend anchored above the axes (y=1.15) is not clipped.
    fig.savefig(
        f'results/latent_space/pca_on_untrained_model_96prot-embeddings_{tag}.svg',
        bbox_inches='tight',
    )
    plt.close(fig)

### Trained

In [None]:
# PCA on the trained model's protein embeddings alone (N rows of 96 dims).
pca = PCA(n_components=2)
pts_pca = pca.fit_transform(df_prot_lig['protein_emb'].to_list())
df_prot_lig_pca = pd.DataFrame({
    'x': pts_pca[:, 0],
    'y': pts_pca[:, 1],
    'pred': df_prot_lig['pred'].to_list(),
    'gt': df_prot_lig['gt'].to_list(),
    'type': ['protein_emb'] * max_sample_train,
})

# Quintile codes (0-4); duplicates='drop' can yield fewer than 5 bins.
for quant_col in ('pred', 'gt'):
    df_prot_lig_pca[f'{quant_col}_bin'] = pd.qcut(
        df_prot_lig_pca[quant_col], q=5, labels=False, duplicates='drop'
    )

# Save the projection coloured by prediction quintile, ground-truth quintile and type.
for hue, tag, ncol in (('pred_bin', 'pred', 5), ('gt_bin', 'gt', 5), ('type', 'type', 2)):
    fig, ax = get_styled_figure_ax(aspect='none', grid=False)
    # Size the list palette to the levels present (a 5-colour list for 1 'type'
    # level makes seaborn warn or raise).
    n_levels = df_prot_lig_pca[hue].nunique()
    sns.scatterplot(
        df_prot_lig_pca, x='x', y='y', hue=hue,
        palette=DATASET_COLORS[:n_levels], ax=ax,
    )
    ax.set_title(f"PCA on trained model's protein embeddings [{tag}]")
    style_legend(ax, ncol=ncol, bbox_to_anchor=(0.5, 1.15))
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    # Tight bbox so the legend anchored above the axes (y=1.15) is not clipped.
    fig.savefig(
        f'results/latent_space/pca_on_trained_model_96prot-embeddings_{tag}.svg',
        bbox_inches='tight',
    )
    plt.close(fig)

## Ligand

### Untrained

In [None]:
# PCA on the untrained model's ligand embeddings alone (N rows of 96 dims).
pca = PCA(n_components=2)
pts_pca_untrained = pca.fit_transform(df_prot_lig_untrained['ligand_emb'].to_list())
df_prot_lig_pca_untrained = pd.DataFrame({
    'x': pts_pca_untrained[:, 0],
    'y': pts_pca_untrained[:, 1],
    'pred': df_prot_lig_untrained['pred'].to_list(),
    'gt': df_prot_lig_untrained['gt'].to_list(),
    'type': ['ligand_emb'] * max_sample_train,
})

# Quintile codes (0-4); duplicates='drop' can yield fewer than 5 bins.
for quant_col in ('pred', 'gt'):
    df_prot_lig_pca_untrained[f'{quant_col}_bin'] = pd.qcut(
        df_prot_lig_pca_untrained[quant_col], q=5, labels=False, duplicates='drop'
    )

# Save the projection coloured by prediction quintile, ground-truth quintile and type.
# Predictions use 'pred_bin' rather than the continuous 'pred' column, which a
# 5-colour list palette cannot represent.
for hue, tag, ncol in (('pred_bin', 'pred', 5), ('gt_bin', 'gt', 5), ('type', 'type', 2)):
    fig, ax = get_styled_figure_ax(aspect='none', grid=False)
    n_levels = df_prot_lig_pca_untrained[hue].nunique()
    sns.scatterplot(
        df_prot_lig_pca_untrained, x='x', y='y', hue=hue,
        palette=DATASET_COLORS[:n_levels], ax=ax,
    )
    ax.set_title(f"PCA on untrained model's ligand embeddings [{tag}]")
    style_legend(ax, ncol=ncol, bbox_to_anchor=(0.5, 1.15))
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    # Tight bbox so the legend anchored above the axes (y=1.15) is not clipped.
    fig.savefig(
        f'results/latent_space/pca_on_untrained_model_96ligand-embeddings_{tag}.svg',
        bbox_inches='tight',
    )
    plt.close(fig)

### Trained

In [None]:
# PCA on the trained model's ligand embeddings alone (N rows of 96 dims).
# The following loadings cells read this `pca` object.
pca = PCA(n_components=2)
pts_pca = pca.fit_transform(df_prot_lig['ligand_emb'].to_list())
df_prot_lig_pca = pd.DataFrame({
    'x': pts_pca[:, 0],
    'y': pts_pca[:, 1],
    'pred': df_prot_lig['pred'].to_list(),
    'gt': df_prot_lig['gt'].to_list(),
    'type': ['ligand_emb'] * max_sample_train,
})

# Quintile codes (0-4); duplicates='drop' can yield fewer than 5 bins.
for quant_col in ('pred', 'gt'):
    df_prot_lig_pca[f'{quant_col}_bin'] = pd.qcut(
        df_prot_lig_pca[quant_col], q=5, labels=False, duplicates='drop'
    )

# Save the projection coloured by prediction quintile, ground-truth quintile and type.
# NOTE(review): grid=True here, whereas the other embedding scatter plots use
# grid=False; confirm this difference is intended for side-by-side comparison.
for hue, tag, ncol in (('pred_bin', 'pred', 5), ('gt_bin', 'gt', 5), ('type', 'type', 2)):
    fig, ax = get_styled_figure_ax(aspect='none', grid=True)
    # Size the list palette to the levels present (a 5-colour list for 1 'type'
    # level makes seaborn warn or raise).
    n_levels = df_prot_lig_pca[hue].nunique()
    sns.scatterplot(
        df_prot_lig_pca, x='x', y='y', hue=hue,
        palette=DATASET_COLORS[:n_levels], ax=ax,
    )
    ax.set_title(f"PCA on trained model's ligand embeddings [{tag}]")
    style_legend(ax, ncol=ncol, bbox_to_anchor=(0.5, 1.15))
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")
    # Tight bbox so the legend anchored above the axes (y=1.15) is not clipped.
    fig.savefig(
        f'results/latent_space/pca_on_trained_model_96ligand-embeddings_{tag}.svg',
        bbox_inches='tight',
    )
    plt.close(fig)

#### Which embedding dimensions drive the PCA?

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Inspect which original embedding dimensions carry the most weight in each PC.
# NOTE(review): `pca` is simply the most recently fitted PCA; when the notebook runs
# top to bottom that is the trained model's ligand-only PCA from the cell above.

# components_ is (n_components, n_features); transpose so each row is one dimension.
loadings = pca.components_
n_features = loadings.shape[1]
feature_names = [f"Dim_{i}" for i in range(n_features)]
df_loadings = pd.DataFrame(loadings.T, index=feature_names, columns=['PC1', 'PC2'])

# Rank dimensions by |loading|, largest first: a large negative weight is as
# influential as a large positive one.
pc1_order = np.argsort(df_loadings['PC1'].abs().to_numpy())[::-1]
top_pc1_drivers = df_loadings.iloc[pc1_order]
pc2_order = np.argsort(df_loadings['PC2'].abs().to_numpy())[::-1]
top_pc2_drivers = df_loadings.iloc[pc2_order]

print("Top 5 Dimensions driving PC1:")
print(top_pc1_drivers[['PC1']].head(5))

print("\nTop 5 Dimensions driving PC2:")
print(top_pc2_drivers[['PC2']].head(5))

In [None]:
# Heatmap of absolute PCA loadings (rows: PCs, columns: embedding dimensions) for the
# last-fitted `pca` (the trained ligand-only PCA when run top to bottom).
fig, ax = get_styled_figure_ax(figsize=(12, 7), aspect='none', grid=True)
# Tick labels follow the fitted component count (['PC1', 'PC2'] for 2 components).
pc_labels = [f"PC{k + 1}" for k in range(pca.components_.shape[0])]
# Draw explicitly into the styled axes rather than pyplot's current axes.
sns.heatmap(np.abs(pca.components_), cmap='viridis', yticklabels=pc_labels, ax=ax)
ax.set_title("Heatmap of Feature Importance for Ligand-only embeddings\n(Absolute Loadings)")
ax.set_xlabel("Original Embedding Dimensions")
fig.savefig(
    'results/latent_space/pca_on_trained_model_96ligand-embeddings_heatmap-pca-components.svg',
    bbox_inches='tight',
)
plt.close(fig)

In [None]:
# Get top 10 dimensions based on PC1 importance
top_10_indices = df_loadings['PC1'].abs().nlargest(10).index
top_10_data = df_loadings.loc[top_10_indices]

fig, ax = get_styled_figure_ax(figsize=(12, 12), aspect='none',grid=True)
top_10_data.plot(kind='bar', color=DATASET_COLORS, ax=ax, grid=True, legend=True)
plt.title("Loadings of Top 10 Dimensions (sorted by PC1 impact)")
plt.ylabel("Weight (Loading)")
plt.axhline(0, color='black', linewidth=0.8)
style_legend(ax, ncol=2, bbox_to_anchor=(0.5, 1.15))
plt.savefig('results/latent_space/pca_on_trained_model_96ligand-embeddings_batplot-pca-components.svg')
plt.close();
