In [1]:
import os
import pandas as pd
import datetime

import pathlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import importlib


from tensorflow.keras import layers, losses
from tensorflow.keras.models import Model
from tensorflow import keras
from tensorflow.keras import callbacks  

# project specific
from utils import data_handler
from utils.models import LSTM



import matplotlib.pyplot as plt
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

import plotly.express as px
import plotly.subplots as sp
import plotly.graph_objs as go



%load_ext tensorboard
!rm -rf ../workfiles/logs/

In [2]:
importlib.reload(data_handler) # to allow modification of the script without restarting the whole session

sgdc_params = {
            #'penalty':["elasticnet", "l1", "l2"],
            'penalty':["l1"],
            #'l1_ratio':np.linspace(0.1, 1, 5),
            'alpha':np.linspace(0.1, 0.5, 5),
        }

x_train, filenames, n_genes = data_handler.generate_dataset(feature_selection_threshold = 2, 
                                                   #feature_selection_proceedure = "LASSO", 
                                                   retain_phases="Both", 
                                                   #retain_phases=None, 
                                                   return_id = True,
                                                   sgdc_params = sgdc_params,
                                                   #subsample = 100,
                                                   #class_balancing = "match_smaller_sample")
                                                   class_balancing = "balanced",
                                                   as_time_series = True)


Retaining patients that are included in phases 1 & 2
retaining all patient who have passed all visits...
loading samples...
loaded 1455 samples
selecting genes based on median absolute deviation threshold:  2 ...
number of genes selected :  14876
normalizing data...
normalization done
number of seq in the dataset : 1455
converting samples to time series
number of actual individual to be studied : 291


In [3]:
importlib.reload(LSTM) # to allow modification of the script without restarting the whole session

latent_dim = 32
sequence_length = 5
t_shape = (sequence_length, n_genes)


autoencoder = LSTM.generate_model(t_shape, latent_dim)
autoencoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())

In [4]:
checkpoint_filepath = '../workfiles/simple_ae/checkpoint'
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='loss',
    mode='min',
    save_best_only=True)


reduce_lr = callbacks.ReduceLROnPlateau(monitor='loss', factor=0.5,
                              patience=25, min_lr=0.00001)

early_stopping_callback = callbacks.EarlyStopping(monitor='loss', patience=50)


log_dir = "../workfiles/logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

cb = [model_checkpoint_callback, 
      reduce_lr, 
      early_stopping_callback, 
      tensorboard_callback]

In [None]:
#%tensorboard --logdir ../workfiles/logs/fit

6K so far

In [None]:
hist = autoencoder.fit(x_train, epochs=2000, callbacks=cb)

In [7]:
autoencoder.load_weights(checkpoint_filepath)


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x289240e50>

In [None]:
plt.plot(hist.history['loss'])
print(hist.history['loss'])




### some diagnosis

In [8]:
e = iter(x_train).next()
z = autoencoder.encoder(e)
decoded = autoencoder.decoder(z)

In [9]:
e_ = e[0]  
z_ = z[0].reshape(1, -1) 
decoded_ = decoded[0]  



# Create subplot grid with vertical stacking
fig = sp.make_subplots(rows=3, cols=1, shared_xaxes=False, vertical_spacing=0.1)

# Add the original image as a heatmap-like plot
heatmap_trace1 = go.Heatmap(z=e_, colorscale='viridis')
fig.add_trace(heatmap_trace1, row=1, col=1)

# Add the latent representation as a heatmap-like plot
tensor_reshaped = z.reshape(1, -1)
heatmap_trace2 = go.Heatmap(z=z_, colorscale='viridis')
fig.add_trace(heatmap_trace2, row=2, col=1)

# Add the decoded image as a heatmap-like plot
heatmap_trace3 = go.Heatmap(z=decoded_, colorscale='viridis')
fig.add_trace(heatmap_trace3, row=3, col=1)

# Update layout
fig.update_layout(title='Stacked Graph of Image and Latent Space', showlegend=False)

# Update x-axis labels
#fig.update_xaxes(title_text='genes (normalized)', row=1, col=1)
#fig.update_xaxes(title_text='latent representation', row=2, col=1)
#fig.update_xaxes(title_text='genes (normalized)', row=3, col=1)

# Update y-axis labels
#fig.update_yaxes(title_text='timestamps', row=1, col=1)
#fig.update_yaxes(title_text='latent representation', row=2, col=1)
#fig.update_yaxes(title_text='timestamps', row=3, col=1)

fig.show()

# Trouble shooting 

In [10]:
# Counting how many entries has been 0'd out

indexes_original = (sum(e_) == 0).numpy()
indexes_reconstructed = (sum(decoded_) == 0).numpy()

print("sum of series being null in the original patient data:",sum(indexes_original))
print("sum of series being null in the reconstructed patient data:",sum(indexes_reconstructed))
print("number of total genes",len(e_[0]))

print("this represent", (sum(indexes_reconstructed)/len(e_[0]))*100,"% of 0'd out indexes" )

sum of series being null in the original patient data: 28
sum of series being null in the reconstructed patient data: 12516
number of total genes 14876
this represent 84.13552030115623 % of 0'd out indexes


In [11]:
print("sum of total experession on orginal data:", sum(sum(e_)))
print("sum of total experession on reconstructed data:", sum(sum(decoded_)))

sum of total experession on orginal data: tf.Tensor(101.76974, shape=(), dtype=float32)
sum of total experession on reconstructed data: tf.Tensor(68.98187, shape=(), dtype=float32)


On a given run, this represents a 30% delta

In [12]:
indexes = (sum(decoded_) != 0).numpy()

In [13]:
# let's see if the 0'd out entries are creating a bias
print(sum(sum(e_[0:5, indexes])))
print(sum(sum(decoded_[0:5, indexes])))

tf.Tensor(70.802444, shape=(), dtype=float32)
tf.Tensor(68.98187, shape=(), dtype=float32)


In [None]:
autoencoder.encoder.save('../workfiles/LSTM')


In [14]:
compressed_dataframe = autoencoder.encoder.predict(x_train)

1/5 [=====>........................] - ETA: 1s

2023-08-25 11:35:48.355153: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz




# Can we have some visualisation over the entire dataset ?

In [15]:
decoded_dataframe = autoencoder.decoder.predict(compressed_dataframe)



In [17]:
original_dataset_numpy = np.concatenate(list(x_train.as_numpy_iterator()), axis=0)


# Reshape the dataset to a 2D matrix
reshaped_original = original_dataset_numpy.reshape(original_dataset_numpy.shape[0]*original_dataset_numpy.shape[1], -1)
reshaped_reconstruction = decoded_dataframe.reshape(decoded_dataframe.shape[0]*decoded_dataframe.shape[1], -1)

In [None]:
print(original_dataset_numpy.shape)
print(reshaped_original.shape)
print(reshaped_reconstruction.shape)

In [18]:
# Assuming reshaped_dataset is your 2D matrix of shape (total_num_samples, sequence_length * num_features)
# You might need to preprocess your data before visualizing

# Create the heatmap for the original dataset
fig1 = go.Figure(go.Heatmap(z=reshaped_original, colorscale='viridis'))
fig1.update_layout(title='Original Dataset Heatmap')

# Create the heatmap for the reconstructed dataset
# Assuming reconstructed_dataset is another 2D matrix similar to reshaped_dataset
fig2 = go.Figure(go.Heatmap(z=reshaped_reconstruction, colorscale='viridis'))
fig2.update_layout(title='Reconstructed Dataset Heatmap')

# Create a subplot with two rows
fig = sp.make_subplots(rows=2, cols=1, subplot_titles=('Original Dataset', 'Reconstructed Dataset'))
fig.add_trace(fig1.data[0], row=1, col=1)
fig.add_trace(fig2.data[0], row=2, col=1)

# Update subplot layout
fig.update_layout(height=800, width=600, showlegend=False)
fig.update_xaxes(title_text='Flattened Time Series', row=2, col=1)
fig.update_yaxes(title_text='Samples', row=1, col=1)
fig.update_yaxes(title_text='Samples', row=2, col=1)

#fig.show() # showing the figure is way to ressource intensive.
# i should try to save it instead.

ValueError: 
    Invalid value of type 'builtins.dict' received for the 'coloraxis' property of heatmap
        Received value: {'colorscale': 'viridis', 'colorbar': {'len': 0.5, 'y': 0.75}}

    The 'coloraxis' property is an identifier of a particular
    subplot, of type 'coloraxis', that may be specified as the string 'coloraxis'
    optionally followed by an integer >= 1
    (e.g. 'coloraxis', 'coloraxis1', 'coloraxis2', 'coloraxis3', etc.)
        

In [None]:
df = pd.DataFrame(compressed_dataframe)
df["name"] = filenames

In [None]:
df.to_csv("../workfiles/processed_data_lstm.csv", index=False)
