In [46]:
import sys
import pandas as pd

import matplotlib.pyplot as plt
import numpy as np
import importlib

# project specific
sys.path.append('../src')
import helpers
from utils import benchmark, data_handler, visualisation
from models import VQ_VAE_0, VQ_VAE_1


import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

import plotly.express as px
import plotly.subplots as sp
import plotly.graph_objs as go

import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt



from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster


# Widen pandas' text output so wide frames are not wrapped.
pd.options.display.width = 1000


# NOTE(review): machine-specific absolute path; not referenced by any cell visible in this notebook.
absolute_path = "c:/thesis/data/cancer"
import scipy.cluster.hierarchy as sch

import pickle

# for translation of gene symbols
import mygene
mg = mygene.MyGeneInfo()

# Load the TensorBoard notebook extension, then delete the log directory left by earlier runs.
# NOTE: `rm -rf` is destructive and requires a POSIX shell.
%load_ext tensorboard
!rm -rf ../workfiles/logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


## Experiment 

### Load dataset
(raw)

In [47]:
# Read the pickled (data, metadata) pair for the raw BRCA dataset.
# Alternative input (normalised variant): '../workfiles/normed_BRCA_ds.pkl'
# NOTE: pickle executes arbitrary code on load; only open files you trust.
with open('../workfiles/BRCA_ds.pkl', mode='rb') as f:
    data, metadata = pickle.load(f)

In [None]:
# Shape of the matrix exactly as loaded.
print(data.shape)

# Insert a singleton middle axis: (rows, columns) -> (rows, 1, columns).
dat = data.reshape(-1, 1, data.shape[1])
print(dat.shape)

# Class labels and feature count come from the metadata shipped with the data.
label = metadata["PAM50_labels"]
feature_num = metadata["n_features"]

In [None]:
class Mydatasets(torch.utils.data.Dataset):
    """Map-style dataset serving the items of an indexable container as float tensors.

    Parameters
    ----------
    data1 : indexable supporting ``len()``
        Source samples; ``data1[idx]`` must be convertible by ``torch.tensor``.
    transform : callable, optional
        If given (and truthy), applied to each float tensor before it is returned.
    """

    def __init__(self, data1 ,transform = None):
        self.data1, self.transform = data1, transform
        # Cached once; __len__ reports this value.
        self.datanum = len(data1)

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        sample = torch.tensor(self.data1[idx]).float()
        return self.transform(sample) if self.transform else sample

In [None]:
# Hold out 10% of the samples; the fixed seed keeps the split reproducible.
train_data, test_data = train_test_split(dat, test_size=0.1, random_state=66)
print('train data:', len(train_data))
print('test data:', len(test_data))

# Wrap each split in a Dataset ...
train_data_set, test_data_set = (
    Mydatasets(data1=split) for split in (train_data, test_data)
)

# ... and serve both in shuffled mini-batches of 32.
train_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(split_set, batch_size=32, shuffle=True)
    for split_set in (train_data_set, test_data_set)
)

In [None]:
# Pick the compute device. Apple-silicon MPS is preferred (the original target
# of this notebook), then CUDA, then CPU, so the notebook also runs on machines
# without an MPS backend instead of failing at the first `.to(DEVICE)`.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [None]:
# Re-import the model module so edits to its source are picked up without restarting the kernel.
importlib.reload(VQ_VAE_1)


# new best performer
out_dim = 64   
# Instantiate the model with the hyper-parameters below and move it to DEVICE.
VQ_VAE = VQ_VAE_1.Model(
            dropout = 0.1,
            input_size = feature_num, 
            encoder_dim = out_dim,
            num_embeddings = 512,
            embedding_dim = 32,   
            commitment_cost = 1,
            decay= 0
           ).to(DEVICE)




# NOTE(review): Classifier_loss is not referenced by any other cell visible in this notebook.
Classifier_loss = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(VQ_VAE.parameters(), lr=1e-4, amsgrad=False)
# Variance over every element of `dat`; the training loop divides the reconstruction MSE by it.
data_variance = np.var(dat)

print(DEVICE)

In [None]:
# Per-epoch training history appended to by the training loop.
train_res_recon_error, perplexities = [], []
# PCA snapshots stored by callbacks(), and the count of snapshots stored so far.
frames, n_frames = [], 0

def callbacks(epoch):
    """Store a PCA snapshot of the encoder output every 10th epoch.

    When ``(epoch + 1) % 10 == 0`` every sample of the global ``dat`` is passed
    through ``VQ_VAE`` one at a time, the encoder outputs are projected onto
    two principal components, and the projection (prefixed with a column that
    holds the current frame index) is appended to the global ``frames`` list.
    ``n_frames`` is incremented for each stored snapshot. On every 500th epoch
    a four-panel diagnostic figure is shown as well.

    Parameters
    ----------
    epoch : int
        Zero-based index of the epoch that has just finished.

    Notes
    -----
    Reads the notebook globals ``VQ_VAE``, ``dat``, ``feature_num``,
    ``DEVICE``, ``label`` and ``train_res_recon_error``. The model is put in
    eval mode for the snapshot and its previous train/eval mode is restored
    before returning, so dropout stays active for the following epochs.
    """
    # `n_frames` is rebound at the end of this function. Without the global
    # declaration Python treats it as a local and the read inside np.full()
    # raises UnboundLocalError on the very first snapshot.
    global n_frames

    # Code to run every 10 epochs
    if (epoch + 1) % 10 != 0:
        return

    en_lat = []
    en_reconstruction = []

    was_training = VQ_VAE.training
    VQ_VAE.eval()
    try:
        data_set = Mydatasets(data1 = dat)

        # Inference only: no autograd graph is needed for the snapshot.
        with torch.no_grad():
            for i in range(len(dat)):
                sample = data_set[i][0].view(1, 1, feature_num).float().to(DEVICE)
                latent_1 = VQ_VAE._encoder(sample)
                _, data_recon, _, _, _ = VQ_VAE(sample)
                en_lat.append(latent_1.cpu().detach().numpy())
                en_reconstruction.append(data_recon.cpu().detach().numpy())
    finally:
        # Hand the model back in the mode the caller left it in.
        VQ_VAE.train(was_training)

    encode_out = np.array(en_lat).reshape(len(dat), -1)
    reconstruction_out = np.array(en_reconstruction).reshape(len(dat), -1)

    # First sample on top of its reconstruction, for the heatmap panel.
    stack = np.vstack([dat[0].reshape(1, -1), reconstruction_out[0].reshape(1, -1)])

    pca = PCA(n_components=2)
    pca.fit(encode_out)
    pca_result = pca.transform(encode_out)

    # Column 0 carries the frame index; columns 1-2 carry the two components.
    index_column = np.full((pca_result.shape[0], 1), n_frames, dtype=int)
    pca_result_with_index = np.hstack((index_column, pca_result))

    frames.append(pca_result_with_index)
    n_frames += 1

    if (epoch + 1) % 500 == 0:

        fig, axs = plt.subplots(1, 4, figsize=(12, 3))

        # Loss history recorded so far
        axs[0].plot(train_res_recon_error, label='Training Loss')
        axs[0].set_title('Training Loss Plot')

        sns.heatmap(stack, ax=axs[1], cbar=False)
        axs[1].set_title('Stacked heatmap of two samples')
        axs[1].set_xticks([])
        axs[1].set_yticks([])

        # This panel shows the encoder output, so the title says so.
        sns.heatmap(encode_out, ax = axs[2], cbar=False)
        axs[2].set_title('Heatmap of whole encoded dataset')
        axs[2].set_xticks([])
        axs[2].set_yticks([])

        sns.scatterplot(x = pca_result[:, 0], y = pca_result[:, 1], c=label, ax=axs[3])
        axs[3].set_title('PCA')
        axs[3].set_xticks([])
        axs[3].set_yticks([])

        plt.subplots_adjust(wspace=0)  
        plt.tight_layout()
        plt.show()



In [None]:
EPOCH = 5000

# Lower the learning rate when the epoch loss stops improving (floor: 1e-6).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', min_lr= 0.000001)
print(optimizer.param_groups[0]['lr'])


for epoch in tqdm(range(EPOCH)):

    # callbacks() switches the model to eval mode for its snapshot, so train
    # mode is (re-)enabled at the start of every epoch, not once up front.
    VQ_VAE.train()

    running_loss = 0.0
    count = 0

    # Training loop
    for inputs in train_dataloader:
        optimizer.zero_grad()
        inputs = inputs.to(DEVICE)
        vq_loss, data_recon, perplexity, _, _ = VQ_VAE(inputs)
        # Reconstruction error normalised by the variance of the data.
        recon_error = F.mse_loss(data_recon, inputs) / data_variance
        loss = recon_error + vq_loss #+ perplexity
        loss.backward()
        optimizer.step()
        count += 1
        running_loss += loss.item()

    # Mean total loss (reconstruction + VQ) over the epoch's batches.
    train_loss = running_loss / count
    train_res_recon_error.append(train_loss)
    # Perplexity of the last batch of the epoch.
    perplexities.append(perplexity.cpu().detach().numpy())

    # ReduceLROnPlateau only acts when it is fed the monitored metric.
    scheduler.step(train_loss)

    callbacks(epoch)


# Plot the training loss and perplexity curves. The x-axis is derived from the
# history length so the plot still works if this cell is run more than once.
epochs = np.arange(1, len(train_res_recon_error) + 1)
plt.plot(epochs, train_res_recon_error, label='Training Loss')
#plt.plot(epochs, val_res_recon_error, label='Validation Loss')
plt.plot(epochs, perplexities, label='Perplexity')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

print(optimizer.param_groups[0]['lr'])


In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

#%matplotlib notebook

# One shared canvas that every animation frame is redrawn onto.
fig, ax = plt.subplots()


def update(frame):
    """Redraw the axes with the PCA snapshot stored at index ``frame``.

    Column 0 of each snapshot holds the frame index, so the two principal
    components are read from columns 1 and 2; points are coloured by ``label``.
    """
    ax.clear()
    ax.set_title(f'Frame {frame}')

    snapshot = frames[frame]
    pc1, pc2 = snapshot[:, 1], snapshot[:, 2]
    ax.scatter(pc1, pc2, c=label)


# Step through frame indices 0 .. n_frames - 1, restarting after the last one.
ani = FuncAnimation(fig, update, frames=n_frames, repeat=True)

# Render the animation as an HTML/JavaScript player in the cell output.
HTML(ani.to_jshtml())

In [None]:
#plt.show()
# Write the animation to an MP4 file using matplotlib's ffmpeg writer
# (ffmpeg must be installed and the ../img/ directory must exist).
ani.save('../img/aniPCA_GDS_VQ-VAE-1.mp4', writer='ffmpeg')

In [None]:
#torch.save(VQ_VAE, "../workfiles/torch_temp")
# Serialise the whole model object (not just its state_dict) to disk.
# NOTE(review): whole-object saves are pickles, so loading them needs the same
# model class to be importable under the same module path.
torch.save(VQ_VAE, "../workfiles/torch_VQ-VAE-1-GDS")



In [None]:
# Plot the training loss and perplexity histories against their list index.
# NOTE: `epochs` is computed here but not passed to the plot calls below.
epochs = np.arange(1, EPOCH + 1)
plt.plot(train_res_recon_error, label='Training Loss')
plt.plot(perplexities, label='Perplexity')
plt.xlabel('Epochs')
# Both curves share this axis, although perplexity is not a loss value.
plt.ylabel('Loss')
plt.legend()
plt.show()

# Learning rate currently set on the optimizer's first parameter group.
print(optimizer.param_groups[0]['lr'])

In [None]:
# Skip the first `index` entries so the early, large losses do not dominate the scale.
index = 50
plt.plot(train_res_recon_error[index:], label='Training Loss')
#plt.plot(epochs[index:], val_res_recon_error[index:], label='Validation Loss')
# The x-axis is the position within the slice, i.e. epoch number minus `index`.
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Training loss history with the first 40 entries dropped.
plt.plot(train_res_recon_error[40:], label='Training Loss')
#plt.plot(val_res_recon_error[40:], label='Training Loss')


In [None]:
# Run every sample through the trained model and collect, per sample, the
# encoder output, the quantized latent and the reconstruction.
en_lat = []
en_quantized = []
en_reconstruction = []

VQ_VAE.eval()

data_set = Mydatasets(data1 = dat)
data_set = torch.utils.data.DataLoader(data_set, batch_size = 256, shuffle=False) 

# Inference only: disabling autograd avoids building a graph for every sample.
with torch.no_grad():
    for i in range(len(dat)):
        # Shape and move the sample once, then reuse it for both calls.
        en_data = data_set.dataset[i][0].view(1, 1, feature_num).float().to(DEVICE)
        latent_1 = VQ_VAE._encoder(en_data)
        _, data_recon, _, _, latent_2 = VQ_VAE(en_data)
        en_lat.append(latent_1.cpu().detach().numpy())
        en_quantized.append(latent_2.cpu().detach().numpy())
        en_reconstruction.append(data_recon.cpu().detach().numpy())

# Flatten each per-sample output to one row: (n_samples, -1).
encode_out = np.array(en_lat).reshape(len(dat), -1)
quantized_out = np.array(en_quantized).reshape(len(dat), -1)
reconstruction_out = np.array(en_reconstruction).reshape(len(dat), -1)

print('encode_out:', encode_out.shape)
print('quantized_out:', quantized_out.shape)

In [None]:
# compatibility between notebooks
# Expose the reconstruction under the name `decoded_data` and wrap the labels in a pandas Series.
decoded_data = reconstruction_out
label = pd.Series(label)

In [None]:
# Sanity check: shapes of the input matrix and of each flattened model output.
print(data.shape)
print(encode_out.shape)
print(quantized_out.shape)
print(reconstruction_out.shape)

In [None]:
# First sample and its reconstruction, each drawn as a single-row heatmap.
heatmap_trace1 = go.Heatmap(z=data[0].reshape(1, -1))
heatmap_trace3 = go.Heatmap(z=reconstruction_out[0].reshape(1, -1))

# Two stacked panels: the input on top, the reconstruction underneath.
fig = sp.make_subplots(rows=2, cols=1, shared_xaxes=False, vertical_spacing=0.1)
fig.add_trace(heatmap_trace1, row=1, col=1)
fig.add_trace(heatmap_trace3, row=2, col=1)

fig.update_layout(title='Stacked Graph of Image and Latent Space', showlegend=False)
fig.show()

In [None]:
# Clustered heatmap of the input matrix; only rows are clustered, column order is kept.
sns.clustermap(data, col_cluster= False)

In [None]:
# Clustered heatmap of the quantized latents (rows and columns are both clustered).
sns.clustermap(quantized_out)

In [None]:
# Clustered heatmap of the encoder output; only rows are clustered, column order is kept.
sns.clustermap(encode_out, col_cluster= False)

In [None]:
#sns.clustermap(decoded_data, col_cluster= False)
# Plain (unclustered) heatmap of the reconstructions, rows in dataset order.
sns.heatmap(decoded_data)

In [None]:
# Plain (unclustered) heatmap of the input matrix, for comparison with the reconstruction heatmap.
sns.heatmap(data)

In [None]:
# Reload the plotting helpers so edits to the module take effect without restarting the kernel.
importlib.reload(visualisation)

print("######################## OG Groups : ")
# Project helper from utils.visualisation, called with the encoder output and the labels
# (its exact output is defined in that module, not in this notebook).
visualisation.plot_clusters(encode_out, label)

### What happens when I use Keras

In [None]:
from models import vanilla_autoencoder
import tensorflow as tf
# NOTE: this import rebinds the name `callbacks`, which earlier in this notebook
# is the per-epoch snapshot function called by the PyTorch training loop. After
# this cell has run, that training loop cannot be re-run until the cell that
# defines the function has been executed again.
from tensorflow.keras import callbacks
import datetime

In [None]:
# Load the raw BRCA (data, metadata) pair again, under the name `data_original`.
# Alternative input (normalised variant): '../workfiles/normed_BRCA_ds.pkl'
# NOTE: pickle executes arbitrary code on load; only open files you trust.
with open('../workfiles/BRCA_ds.pkl', mode='rb') as f:
    data_original, metadata = pickle.load(f)

In [None]:
# Pull the fields used below out of the metadata dictionary.
# NOTE(review): seq_names and gene_names are not referenced by any later cell visible here.
seq_names = metadata["sequence_names"]
n_genes = metadata["n_features"]
gene_names = metadata["feature_names"] 

In [None]:
importlib.reload(vanilla_autoencoder) # to allow modification of the script without restarting the whole session

latent_dim = 16

# NOTE(review): the parentheses do not create a tuple — t_shape is the plain value n_genes.
# Write (n_genes,) if generate_model expects a shape tuple — TODO confirm against vanilla_autoencoder.
t_shape = (n_genes)


# Build the model via the project helper and compile it with Adam and a mean-squared-error loss.
autoencoder = vanilla_autoencoder.generate_model(t_shape, latent_dim)
autoencoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())

In [None]:
# The Keras callback classes are referenced through `tf.keras.callbacks`
# instead of the bare name `callbacks`: this notebook also defines a function
# called `callbacks`, so the bare name depends on which cell ran last.

# Keep only the weights of the epoch with the lowest training loss.
checkpoint_filepath = '../workfiles/simple_ae/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='loss',
    mode='min',
    save_best_only=True)


# Multiply the learning rate by 0.2 after 20 epochs without improvement (floor: 1e-5).
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.2,
                              patience=20, min_lr=0.00001)

# Stop training after 70 epochs without improvement of the training loss.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=70)


# One timestamped TensorBoard log directory per run.
log_dir = "../workfiles/logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

cb = [model_checkpoint_callback, reduce_lr, early_stopping_callback, tensorboard_callback]

In [None]:
# Restore the weights saved at checkpoint_filepath before training.
# NOTE(review): presumably fails when no checkpoint has been written yet — skip this cell on a first run.
autoencoder.load_weights(checkpoint_filepath)

In [None]:
# An autoencoder is trained to reproduce its input, so the same array is passed
# as both input and target; the compiled MSE loss needs a target to compare with.
hist = autoencoder.fit(data_original, data_original, epochs=200, callbacks=cb)
# Roll back to the best weights written by the checkpoint callback.
autoencoder.load_weights(checkpoint_filepath)

In [None]:
# Per-epoch training loss recorded by Keras during fit().
plt.plot(hist.history["loss"])

In [None]:
# Encode the array loaded for this section (`data_original`), so the Keras
# cells do not depend on `data` from the PyTorch section. Both names are
# loaded from the same pickle file.
compressed_dataframe = autoencoder.encoder.predict(data_original)

In [None]:
# Clustered heatmap of the Keras encoder output; only rows are clustered, column order is kept.
sns.clustermap(compressed_dataframe, col_cluster= False)

In [None]:
# Decode the compressed representation back with the autoencoder's decoder.
recon_data = autoencoder.decoder.predict(compressed_dataframe)

In [None]:
# Clustered heatmap of the Keras reconstruction; only rows are clustered, column order is kept.
sns.clustermap(recon_data, col_cluster= False)

# Who gets the best loss?

In [None]:
# torch dataset
# `reconstruction_out` was computed from `dat`, which is `data` with an extra
# singleton axis, so `data` is the matching reference. The previously used name
# `filtered_data` is not defined anywhere in this notebook.
squared_error = np.square(data - reconstruction_out)
mse = np.mean(squared_error)

print("Mean Squared Error:", mse)

In [None]:
# tensorflow
# Element-wise squared difference between the input and the Keras reconstruction,
# averaged over every element.
squared_error = np.square(data_original - recon_data)
mse = np.mean(squared_error)

print("Mean Squared Error:", mse)

In [None]:
# Same project helper as used for the VQ-VAE, here called with the Keras encoder output.
visualisation.plot_clusters(compressed_dataframe, metadata['PAM50_labels'])

# Clustering performance analysis

In [None]:
from sklearn.metrics import confusion_matrix


In [None]:
# Ward-linkage hierarchical clustering of the VQ-VAE encoder output.
linked = linkage(encode_out, 'ward', 'euclidean')  # You can use other linkage methods as well
plt.figure(figsize=(10, 5))
dendrogram(linked, orientation='top', distance_sort='descending')
plt.title('Hierarchical Clustering Dendrogram')
plt.xlabel('Sample Index')
plt.ylabel('Distance')
plt.show()

# Cut the tree at a fixed cophenetic distance (adjust the threshold as needed)
threshold = 50  # Adjust this threshold to identify clusters
cluster_labels = fcluster(linked, threshold, criterion='distance')

# Contingency table of known labels (rows) against cluster ids (columns).
# Cluster ids are arbitrary and their number depends on the threshold, so a
# cross-tabulation is used instead of a confusion matrix with fixed 5x5 names:
# it keeps the two label sets separate and sizes itself to whatever is present.
# The previously used name `filtered_labs` is not defined in this notebook;
# `label` holds the labels of the samples that were encoded.
cm_df = pd.crosstab(
    pd.Series(np.asarray(label), name="PAM50 label"),
    pd.Series(cluster_labels, name="Cluster"),
)
cm = cm_df.to_numpy()

print(cm_df)
sns.heatmap(cm_df)

In [None]:
# Ward-linkage hierarchical clustering of the Keras encoder output.
linked = linkage(compressed_dataframe, 'ward', 'euclidean')  # You can use other linkage methods as well
plt.figure(figsize=(10, 5))
dendrogram(linked, orientation='top', distance_sort='descending')
plt.title('Hierarchical Clustering Dendrogram')
plt.xlabel('Sample Index')
plt.ylabel('Distance')
plt.show()

# Cut the tree at a fixed cophenetic distance (adjust the threshold as needed)
threshold = 290  # Adjust this threshold to identify clusters
cluster_labels = fcluster(linked, threshold, criterion='distance')


# Identify potential outliers (clusters with a small number of points)
# NOTE(review): `threshold` is a linkage distance but is reused here as a
# member-count cut-off; a separate minimum cluster size is probably intended.
unique_labels, counts = np.unique(cluster_labels, return_counts=True)
outlier_clusters = unique_labels[counts < threshold]
print(unique_labels)

# Contingency table of known labels (rows) against cluster ids (columns).
# Cluster ids are arbitrary and their number depends on the threshold, so a
# cross-tabulation is used instead of a confusion matrix with fixed 5x5 names:
# it keeps the two label sets separate and sizes itself to whatever is present.
cm_df = pd.crosstab(
    pd.Series(np.asarray(metadata['PAM50_labels']), name="PAM50 label"),
    pd.Series(cluster_labels, name="Cluster"),
)
cm = cm_df.to_numpy()

print(cm_df)
sns.heatmap(cm_df)