# Libraries

In [1]:
import os

# Run from the project root (one level above this notebook) so that the
# `functions` / `models` imports and the relative `data/...` paths resolve.
# Step up only once per kernel session: without the guard, re-running this
# cell would keep climbing the directory tree.
if "PROJECT_ROOT" not in globals():
    os.chdir("..")
    PROJECT_ROOT = os.getcwd()

import pandas as pd 
import numpy as np
from tqdm import tqdm
from sklearn_extra.cluster import KMedoids
from sklearn.decomposition import PCA
from functions import cosmic_val
from functions.graph_tools import *
from functions.data_handling import data_augmentation
from models.muse import *
from functions import cosmic_val
from functions import data_handling as dh
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn

# set seed
# Seed the global RNGs once, up front, so Restart & Run All is reproducible.
# Seeding once (rather than per run) should still give each MUSE run a
# different initialisation, because the RNG state advances between runs.
# NOTE(review): this assumes data_augmentation / HybridAutoencoder / KMedoids
# draw from these global RNGs; confirm.
import random

import torch

SEED = 15
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data

In [2]:
# Raw inputs: the ovary SBS mutational catalogue and the COSMIC v3.4 reference.
data_path, cosmic_path = (
    "data/catalogues_Ovary_SBS.tsv",
    "data/COSMIC_v3.4_SBS_GRCh37.txt",
)

# Where the preprocessed (ordered) catalogue is written and later read back from.
output_folder, output_filename = "data/processed", "Ordered_Ovary_SBS.csv"
ordered_data_path = os.path.join(output_folder, output_filename)

In [3]:
# Build the ordered catalogue from the raw catalogue + COSMIC reference (both
# tab-separated). Judging by its printed message, the helper skips the work
# when the output file already exists.
dh.load_preprocess_data(
    data_path,
    cosmic_path,
    sep1="\t",
    sep2="\t",
    output_folder=output_folder,
    output_filename=output_filename,
)

Data already exists in  data/processed/Ordered_Ovary_SBS.csv


In [4]:
# load data: COSMIC reference (tab-separated) and the ordered catalogue;
# the first column of each file is used as the index.
cosmic = pd.read_csv(cosmic_path, sep="\t", index_col=0)
data = pd.read_csv(ordered_data_path, index_col=0)

In [5]:
# HybridAutoencoder settings: L_ONE is passed as `l_1`, CONSTRAINT as `constraint`.
# TOLERANCE is not referenced by any cell visible in this notebook.
L_ONE, TOLERANCE, CONSTRAINT = 128, 1e-10, "identity"

In [None]:
from collections import defaultdict
import pandas as pd
import numpy as np

# --- Extraction settings (named instead of inline magic numbers) -------------
K_MIN = 2           # smallest number of signatures tried
k_range = 9         # exclusive upper bound: k runs over range(K_MIN, k_range) -> 2..8
iterations = 12     # independent MUSE runs per k
AUGMENTATION = 10   # passed to data_augmentation(augmentation=...)
EPOCHS = 1000
BATCH_SIZE = 64
PATIENCE = 30       # early-stopping patience for train_model_for_extraction


def _normalize_signatures(signatures_matrix, exposures_matrix):
    """Rescale signatures so each column sums to 1 and move that scale into the exposures.

    Parameters
    ----------
    signatures_matrix : np.ndarray, shape (n_channels, k)
    exposures_matrix : np.ndarray
        Transposed before rescaling, so its transpose must have k columns
        (presumably shape (k, n_samples); TODO confirm against train_model_for_extraction).

    Returns
    -------
    (normalized_signatures, rescaled_exposures)
        normalized_signatures[:, j] = signatures_matrix[:, j] / column_sum[j]
        rescaled_exposures[:, j]   = exposures_matrix.T[:, j] * column_sum[j]

    A column that sums to 0 is left as-is (instead of turning into NaN/inf),
    and its exposures are scaled by 0, so the product S @ E is unchanged.
    """
    column_sums = signatures_matrix.sum(axis=0)
    safe_sums = np.where(column_sums == 0, 1, column_sums)
    # Broadcasting is equivalent to `@ np.diag(...)` without building a k x k matrix.
    return signatures_matrix * (1 / safe_sums), exposures_matrix.T * column_sums


# {k: [one result dict per run]}
results_dict = defaultdict(list)

losses_train = []   # training-loss history of every run, in run order
signatures = []     # normalized signature matrices of every run, across *all* k

augmented_data = data_augmentation(X=data, augmentation=AUGMENTATION)

for k in tqdm(range(K_MIN, k_range), desc="k"):
    for i in range(iterations):

        muse_model = HybridAutoencoder(input_dim=data.shape[0],  # 96
                                       l_1=L_ONE,
                                       latent_dim=k,
                                       constraint=CONSTRAINT)

        # Training MUSE
        # NOTE(review): `save_to` / `iteration` do not include k, so runs for
        # different k share the same save name; confirm this is intended.
        muse_error, muse_signatures, muse_exposures, muse_train_loss, muse_val_loss = train_model_for_extraction(
            model=muse_model,
            X_aug_multi_scaled=augmented_data.T,
            X_scaled=data.T,
            signatures=k,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            save_to='muse_test',
            iteration=i,
            patience=PATIENCE,
        )

        muse_signatures, muse_exposures = _normalize_signatures(muse_signatures, muse_exposures)

        losses_train.append(muse_train_loss)
        signatures.append(muse_signatures)

        results_dict[k].append({
            "iteration": i,
            "muse_error": muse_error,
            "muse_signatures": muse_signatures,  # NumPy array, columns sum to 1
            "muse_exposures": muse_exposures,    # kept instead of being discarded
        })

# One row per (k, iteration) run.
df_results = pd.DataFrame([
    {"k": k_value, "iteration": entry["iteration"], "muse_error": entry["muse_error"], "muse_signatures": entry["muse_signatures"]}
    for k_value, entries in results_dict.items()
    for entry in entries
])


  model.load_state_dict(torch.load(best_model_path))
  8%|▊         | 1/12 [00:47<08:38, 47.15s/it]

Early stopping at epoch 122


  model.load_state_dict(torch.load(best_model_path))
 17%|█▋        | 2/12 [01:43<08:47, 52.77s/it]

Early stopping at epoch 170


  model.load_state_dict(torch.load(best_model_path))
 25%|██▌       | 3/12 [02:31<07:35, 50.63s/it]

Early stopping at epoch 157


  model.load_state_dict(torch.load(best_model_path))
 33%|███▎      | 4/12 [03:25<06:52, 51.61s/it]

Early stopping at epoch 171


  model.load_state_dict(torch.load(best_model_path))
 42%|████▏     | 5/12 [04:11<05:49, 49.90s/it]

Early stopping at epoch 145


  model.load_state_dict(torch.load(best_model_path))
 50%|█████     | 6/12 [04:56<04:48, 48.03s/it]

Early stopping at epoch 145


  model.load_state_dict(torch.load(best_model_path))
 58%|█████▊    | 7/12 [05:53<04:15, 51.12s/it]

Early stopping at epoch 183


  model.load_state_dict(torch.load(best_model_path))
 67%|██████▋   | 8/12 [06:39<03:17, 49.28s/it]

Early stopping at epoch 137


  model.load_state_dict(torch.load(best_model_path))
 75%|███████▌  | 9/12 [07:24<02:24, 48.11s/it]

Early stopping at epoch 142


  model.load_state_dict(torch.load(best_model_path))
 83%|████████▎ | 10/12 [07:58<01:27, 43.81s/it]

Early stopping at epoch 107


  model.load_state_dict(torch.load(best_model_path))
 92%|█████████▏| 11/12 [08:37<00:42, 42.26s/it]

Early stopping at epoch 118


  model.load_state_dict(torch.load(best_model_path))
100%|██████████| 12/12 [09:25<00:00, 47.09s/it]


Early stopping at epoch 143


  model.load_state_dict(torch.load(best_model_path))
  8%|▊         | 1/12 [00:32<06:02, 32.93s/it]

Early stopping at epoch 112


  model.load_state_dict(torch.load(best_model_path))
 17%|█▋        | 2/12 [01:18<06:41, 40.12s/it]

Early stopping at epoch 141


  model.load_state_dict(torch.load(best_model_path))
 25%|██▌       | 3/12 [02:03<06:22, 42.54s/it]

Early stopping at epoch 145


  model.load_state_dict(torch.load(best_model_path))
 33%|███▎      | 4/12 [02:39<05:19, 39.96s/it]

Early stopping at epoch 114


  model.load_state_dict(torch.load(best_model_path))
 42%|████▏     | 5/12 [03:24<04:51, 41.68s/it]

Early stopping at epoch 135


  model.load_state_dict(torch.load(best_model_path))
 50%|█████     | 6/12 [04:01<04:00, 40.14s/it]

Early stopping at epoch 122


  model.load_state_dict(torch.load(best_model_path))
 58%|█████▊    | 7/12 [04:36<03:11, 38.37s/it]

Early stopping at epoch 109


  model.load_state_dict(torch.load(best_model_path))
 67%|██████▋   | 8/12 [05:06<02:23, 35.96s/it]

Early stopping at epoch 98


  model.load_state_dict(torch.load(best_model_path))
 75%|███████▌  | 9/12 [05:52<01:56, 38.87s/it]

Early stopping at epoch 140


  model.load_state_dict(torch.load(best_model_path))
 83%|████████▎ | 10/12 [06:35<01:20, 40.24s/it]

Early stopping at epoch 137


  model.load_state_dict(torch.load(best_model_path))
 92%|█████████▏| 11/12 [07:20<00:41, 41.78s/it]

Early stopping at epoch 138


  model.load_state_dict(torch.load(best_model_path))
100%|██████████| 12/12 [07:59<00:00, 39.95s/it]


Early stopping at epoch 116


  model.load_state_dict(torch.load(best_model_path))
  8%|▊         | 1/12 [00:35<06:25, 35.03s/it]

Early stopping at epoch 114


  model.load_state_dict(torch.load(best_model_path))
 17%|█▋        | 2/12 [01:15<06:24, 38.43s/it]

Early stopping at epoch 122


  model.load_state_dict(torch.load(best_model_path))
 25%|██▌       | 3/12 [01:58<06:03, 40.35s/it]

Early stopping at epoch 128


  model.load_state_dict(torch.load(best_model_path))
 33%|███▎      | 4/12 [02:32<05:04, 38.03s/it]

Early stopping at epoch 111


  model.load_state_dict(torch.load(best_model_path))
 42%|████▏     | 5/12 [03:02<04:05, 35.06s/it]

Early stopping at epoch 94


  model.load_state_dict(torch.load(best_model_path))
 50%|█████     | 6/12 [03:41<03:37, 36.28s/it]

Early stopping at epoch 118


  model.load_state_dict(torch.load(best_model_path))
 58%|█████▊    | 7/12 [04:24<03:12, 38.49s/it]

Early stopping at epoch 133


  model.load_state_dict(torch.load(best_model_path))
 67%|██████▋   | 8/12 [05:06<02:38, 39.75s/it]

Early stopping at epoch 135


  model.load_state_dict(torch.load(best_model_path))
 75%|███████▌  | 9/12 [05:38<01:52, 37.35s/it]

Early stopping at epoch 105


  model.load_state_dict(torch.load(best_model_path))
 83%|████████▎ | 10/12 [06:23<01:19, 39.69s/it]

Early stopping at epoch 132


  model.load_state_dict(torch.load(best_model_path))
 92%|█████████▏| 11/12 [06:55<00:37, 37.12s/it]

Early stopping at epoch 94


  model.load_state_dict(torch.load(best_model_path))
100%|██████████| 12/12 [07:31<00:00, 37.60s/it]


Early stopping at epoch 107


  model.load_state_dict(torch.load(best_model_path))
  8%|▊         | 1/12 [00:39<07:15, 39.62s/it]

Early stopping at epoch 110


  model.load_state_dict(torch.load(best_model_path))
 17%|█▋        | 2/12 [01:28<07:31, 45.14s/it]

Early stopping at epoch 135


  model.load_state_dict(torch.load(best_model_path))
 25%|██▌       | 3/12 [02:06<06:16, 41.84s/it]

Early stopping at epoch 111


  model.load_state_dict(torch.load(best_model_path))
 33%|███▎      | 4/12 [02:45<05:26, 40.78s/it]

Early stopping at epoch 111


  model.load_state_dict(torch.load(best_model_path))
 42%|████▏     | 5/12 [03:24<04:41, 40.17s/it]

Early stopping at epoch 111


  model.load_state_dict(torch.load(best_model_path))
 50%|█████     | 6/12 [04:00<03:52, 38.82s/it]

Early stopping at epoch 112


  model.load_state_dict(torch.load(best_model_path))
 58%|█████▊    | 7/12 [04:33<03:04, 36.91s/it]

Early stopping at epoch 95


  model.load_state_dict(torch.load(best_model_path))
 67%|██████▋   | 8/12 [05:15<02:33, 38.35s/it]

Early stopping at epoch 119


  model.load_state_dict(torch.load(best_model_path))
 75%|███████▌  | 9/12 [05:54<01:55, 38.61s/it]

Early stopping at epoch 116


  model.load_state_dict(torch.load(best_model_path))
 83%|████████▎ | 10/12 [06:23<01:11, 35.58s/it]

Early stopping at epoch 88


  model.load_state_dict(torch.load(best_model_path))
 92%|█████████▏| 11/12 [07:01<00:36, 36.29s/it]

Early stopping at epoch 109


  model.load_state_dict(torch.load(best_model_path))
100%|██████████| 12/12 [07:53<00:00, 39.49s/it]


Early stopping at epoch 129


  model.load_state_dict(torch.load(best_model_path))
  8%|▊         | 1/12 [00:34<06:17, 34.30s/it]

Early stopping at epoch 111


  model.load_state_dict(torch.load(best_model_path))
 17%|█▋        | 2/12 [01:10<05:51, 35.14s/it]

Early stopping at epoch 103


  model.load_state_dict(torch.load(best_model_path))
 25%|██▌       | 3/12 [01:59<06:16, 41.87s/it]

Early stopping at epoch 137


  model.load_state_dict(torch.load(best_model_path))
 33%|███▎      | 4/12 [02:42<05:37, 42.14s/it]

Early stopping at epoch 134


  model.load_state_dict(torch.load(best_model_path))
 42%|████▏     | 5/12 [03:27<05:02, 43.27s/it]

Early stopping at epoch 129


  model.load_state_dict(torch.load(best_model_path))
 50%|█████     | 6/12 [03:53<03:43, 37.23s/it]

Early stopping at epoch 83


  model.load_state_dict(torch.load(best_model_path))
 58%|█████▊    | 7/12 [04:39<03:21, 40.24s/it]

Early stopping at epoch 147


  model.load_state_dict(torch.load(best_model_path))
 67%|██████▋   | 8/12 [05:20<02:42, 40.54s/it]

Early stopping at epoch 122


  model.load_state_dict(torch.load(best_model_path))
 75%|███████▌  | 9/12 [06:01<02:02, 40.69s/it]

Early stopping at epoch 141


  model.load_state_dict(torch.load(best_model_path))
 83%|████████▎ | 10/12 [06:53<01:28, 44.09s/it]

Early stopping at epoch 156


  model.load_state_dict(torch.load(best_model_path))
 92%|█████████▏| 11/12 [07:28<00:41, 41.15s/it]

Early stopping at epoch 108


  model.load_state_dict(torch.load(best_model_path))
100%|██████████| 12/12 [08:01<00:00, 40.11s/it]


Early stopping at epoch 102


  model.load_state_dict(torch.load(best_model_path))
  8%|▊         | 1/12 [00:39<07:09, 39.04s/it]

Early stopping at epoch 128


  model.load_state_dict(torch.load(best_model_path))
 17%|█▋        | 2/12 [01:22<06:58, 41.84s/it]

Early stopping at epoch 130


  model.load_state_dict(torch.load(best_model_path))
 25%|██▌       | 3/12 [01:56<05:44, 38.26s/it]

Early stopping at epoch 104


  model.load_state_dict(torch.load(best_model_path))
 33%|███▎      | 4/12 [02:29<04:46, 35.86s/it]

Early stopping at epoch 114


  model.load_state_dict(torch.load(best_model_path))
 42%|████▏     | 5/12 [03:07<04:16, 36.70s/it]

Early stopping at epoch 117


  model.load_state_dict(torch.load(best_model_path))
 50%|█████     | 6/12 [03:53<03:58, 39.83s/it]

Early stopping at epoch 147


  model.load_state_dict(torch.load(best_model_path))
 58%|█████▊    | 7/12 [04:15<02:51, 34.24s/it]

Early stopping at epoch 75


  model.load_state_dict(torch.load(best_model_path))
 67%|██████▋   | 8/12 [05:05<02:36, 39.01s/it]

Early stopping at epoch 160


  model.load_state_dict(torch.load(best_model_path))
 75%|███████▌  | 9/12 [05:52<02:04, 41.65s/it]

Early stopping at epoch 153


  model.load_state_dict(torch.load(best_model_path))
 83%|████████▎ | 10/12 [06:26<01:18, 39.34s/it]

Early stopping at epoch 112


  model.load_state_dict(torch.load(best_model_path))
 92%|█████████▏| 11/12 [07:06<00:39, 39.57s/it]

Early stopping at epoch 127


  model.load_state_dict(torch.load(best_model_path))
100%|██████████| 12/12 [07:48<00:00, 39.01s/it]


Early stopping at epoch 125


  model.load_state_dict(torch.load(best_model_path))
  8%|▊         | 1/12 [00:29<05:26, 29.65s/it]

Early stopping at epoch 93


  model.load_state_dict(torch.load(best_model_path))
 17%|█▋        | 2/12 [01:16<06:40, 40.04s/it]

Early stopping at epoch 149


  model.load_state_dict(torch.load(best_model_path))
 25%|██▌       | 3/12 [01:48<05:24, 36.05s/it]

Early stopping at epoch 100


  model.load_state_dict(torch.load(best_model_path))
 33%|███▎      | 4/12 [02:15<04:21, 32.74s/it]

Early stopping at epoch 88


  model.load_state_dict(torch.load(best_model_path))
 42%|████▏     | 5/12 [02:52<03:58, 34.06s/it]

Early stopping at epoch 112


  model.load_state_dict(torch.load(best_model_path))
 50%|█████     | 6/12 [03:26<03:23, 33.99s/it]

Early stopping at epoch 104


  model.load_state_dict(torch.load(best_model_path))
 58%|█████▊    | 7/12 [03:55<02:42, 32.52s/it]

Early stopping at epoch 91


  model.load_state_dict(torch.load(best_model_path))
 67%|██████▋   | 8/12 [04:27<02:09, 32.34s/it]

Early stopping at epoch 97


  model.load_state_dict(torch.load(best_model_path))
 75%|███████▌  | 9/12 [05:01<01:38, 32.69s/it]

Early stopping at epoch 106


  model.load_state_dict(torch.load(best_model_path))
 83%|████████▎ | 10/12 [05:27<01:01, 30.72s/it]

Early stopping at epoch 84


  model.load_state_dict(torch.load(best_model_path))
 92%|█████████▏| 11/12 [06:02<00:31, 31.98s/it]

Early stopping at epoch 107


  model.load_state_dict(torch.load(best_model_path))
100%|██████████| 12/12 [06:40<00:00, 33.37s/it]

Early stopping at epoch 123





In [10]:
# First rows of the per-run results table (one row per (k, iteration) run).
df_results.head(n=5)

Unnamed: 0,k,iteration,muse_error,muse_signatures
0,2,0,480496.901255,"[[0.0097312, 0.017147673], [0.009513606, 0.012..."
1,2,1,461517.109657,"[[0.013457963, 0.014270343], [0.014063982, 0.0..."
2,2,2,470522.444516,"[[0.016055238, 0.011268243], [0.009904848, 0.0..."
3,2,3,448949.854865,"[[0.016415482, 0.011322818], [0.0137136085, 0...."
4,2,4,470800.976584,"[[0.01290409, 0.014645866], [0.0125399865, 0.0..."


In [8]:
# Place every run's (channels x k) signature matrix side by side:
# one column per extracted signature, pooled over all runs and all k.
all_signatures = np.concatenate(signatures, axis=1)

In [9]:
# (n_channels, total number of extracted signatures)
print(np.shape(all_signatures))

(96, 420)


In [None]:
# Consensus signatures: PAM (k-medoids, cosine distance) over the pooled
# signature columns of every MUSE run; the medoid columns are the consensus.
#
# NOTE(review): `all_signatures` pools runs from *every* k in range(2, k_range),
# so runs with different numbers of signatures are clustered together. A per-k
# consensus would instead cluster, for a single k,
#     np.hstack([run["muse_signatures"] for run in results_dict[k]])
# Confirm which is intended.

# Number of consensus signatures. This used to be the loop variable `k` leaked
# from the training cell (its last value, k_range - 1). It is spelled out here
# so the result no longer depends on hidden kernel state.
N_CONSENSUS = k_range - 1
PAM_SEED = 15  # fixes KMedoids' initialisation when a random `init` is used

pam = KMedoids(n_clusters=N_CONSENSUS, metric='cosine', random_state=PAM_SEED).fit(all_signatures.T)
labels = pam.labels_
medoid_indices = pam.medoid_indices_
consensus_signatures = all_signatures[:, medoid_indices]

NameError: name 'LATENT_DIM' is not defined

In [None]:
# Compare the consensus signatures with the COSMIC reference catalogue;
# returns the matches and their mean similarity.
matched_signatures, mean_similarity = cosmic_val.compute_match(
    consensus_signatures,
    cosmic,
    index=0,
)

In [None]:
# Preview the COSMIC matches and report their average similarity.
matched_preview = matched_signatures.head()
print(matched_preview)
print("\nMean similarity of the matched signatures: ", mean_similarity)

In [None]:
# Project every extracted signature to 2-D with PCA and colour it by PAM cluster.
reduced_signatures = PCA(n_components=2).fit_transform(all_signatures.T)

# Pass the number of clusters PAM actually fitted (one per medoid) instead of a
# hard-coded 4, which disagreed with the clustering cell.
# NOTE(review): assumes the 4th argument of plot_clusters is the number of clusters; confirm.
# NOTE(review): the title says "AENMF", but these signatures come from the MUSE model.
plot_clusters(reduced_signatures, labels, medoid_indices, len(medoid_indices), "AENMF signature clusters")

In [None]:
# Wrap the consensus (medoid) signatures in a DataFrame whose rows carry
# the same labels as the rows of `data`.
row_labels = data.index
df_consensus = pd.DataFrame(consensus_signatures, index=row_labels)

In [None]:
# Plot the consensus signatures.
# NOTE(review): the title says "AENMF", but these signatures come from the MUSE model.
consensus_title = "AENMF consensus signatures"
plot_signature(df_consensus, consensus_title)