In [None]:
import sys
import importlib
import pickle

# data manipulation
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# data analysis
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

# pytorch specific
import torch
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# project specific
sys.path.append('../src')
from utils import data_handler
from utils import visualisation
from models import torch_vanilla_AE

pd.options.display.width = 1000

### Load Dataset

In [None]:
# Read the pickled bundle from disk and unpack it into its two parts
# (`data`, `metadata`).
# NOTE(review): unpickling can execute arbitrary code, so this file must come
# from a trusted source.
with open('../workfiles/BRCA_ds.pkl', mode='rb') as f:
    brca_bundle = pickle.load(f)
data, metadata = brca_bundle

In [None]:
# Shape of the raw matrix before the reshape below.
print(data.shape)

# Insert a singleton middle axis: (rows, features) -> (rows, 1, features).
# assumes `data` is a 2-D array — TODO confirm
feature_num = data.shape[1]
dat = data.reshape(-1, 1, feature_num)
print(dat.shape)

# One entry per sample; used further down as the colour argument of the scatter plots.
label = metadata["PAM50_labels"]

# The metadata value replaces the width measured above.
# NOTE(review): presumably identical to data.shape[1] — verify, since the model
# is built from this number.
feature_num = metadata["n_features"]

In [None]:
class Mydatasets(torch.utils.data.Dataset):
    """Map-style dataset serving the rows of an indexable container as float tensors.

    Parameters
    ----------
    data1 : indexable container with a length
        Source samples; ``data1[idx]`` is converted to a tensor on every access.
    transform : callable, optional
        Applied to the float tensor before it is returned when truthy.
    """

    def __init__(self, data1, transform=None):
        self.data1 = data1
        self.transform = transform
        # The length is captured once here; later changes to `data1` are not reflected.
        self.datanum = len(data1)

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.datanum

    def __getitem__(self, idx):
        """Return sample `idx` as a float32 tensor, passed through `transform` if set."""
        sample = torch.tensor(self.data1[idx]).float()
        return self.transform(sample) if self.transform else sample

In [None]:
# Hold out 10% of the samples; the fixed random_state keeps the split reproducible.
# Only `dat` is split here, so `label` stays aligned with the full dataset,
# not with either subset.
train_data, test_data = train_test_split(dat, test_size=0.1, random_state=66)
print('train data:', len(train_data))
print('test data:', len(test_data))

train_data_set = Mydatasets(data1=train_data)
test_data_set = Mydatasets(data1=test_data)

# Shuffle the training batches only. The held-out loader keeps a fixed order so
# that any evaluation over it is deterministic and maps back to `test_data` rows.
train_dataloader = torch.utils.data.DataLoader(train_data_set, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data_set, batch_size=32, shuffle=False)

In [None]:
# Pick the best available accelerator instead of hard-coding a single backend,
# so the notebook also runs on machines without Apple-silicon MPS support.
_mps_backend = getattr(torch.backends, "mps", None)  # attribute is missing on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    # For m1 Mac
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [None]:
# Re-import the model module so that edits to it are picked up without
# restarting the kernel.
importlib.reload(torch_vanilla_AE)

# Single source of truth for the latent size handed to the model. The value
# matches what the constructor was actually receiving before (64); previously
# this variable said 32 but was never used.
latent_dim = 64
model = torch_vanilla_AE.Autoencoder(
    shape=feature_num,
    dropout=0.1,
    latent_dim=latent_dim,
).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=1e-4, amsgrad=False)

print(DEVICE)

In [None]:
# Containers filled by the training cell. They live in their own cell, so
# re-running the training cell appends to them rather than starting over.
#   train_res_recon_error : mean reconstruction loss, one value per epoch
#   perplexities          : not written to anywhere in this notebook
#   frames                : PCA snapshots of the latent space for the animation
train_res_recon_error, perplexities, frames = [], [], []

# Number of snapshots stored in `frames` so far.
n_frames = 0

In [None]:
EPOCH = 1000

# Reduce the learning rate (never below min_lr) when the monitored loss plateaus.
# The scheduler only acts when `scheduler.step(metric)` is called, which happens
# once per epoch below.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', min_lr=0.000001)
print(optimizer.param_groups[0]['lr'])

# Wrap the full dataset once; the periodic snapshot pass below reads from it.
snapshot_samples = Mydatasets(data1=dat)

for epoch in tqdm(range(EPOCH)):
    # The snapshot pass switches to eval mode, so training mode is restored every epoch.
    model.train()

    running_loss = 0.0
    count = 0

    # Training loop: one optimisation step per mini-batch, reconstruction MSE as loss.
    for inputs in train_dataloader:
        optimizer.zero_grad()
        inputs = inputs.to(DEVICE)
        data_recon = model(inputs)
        loss = F.mse_loss(data_recon, inputs)
        loss.backward()
        optimizer.step()
        count += 1
        running_loss += loss.item()

    # Calculate and store training loss for this epoch (mean over mini-batches)
    train_loss = running_loss / count
    train_res_recon_error.append(train_loss)

    # Feed the epoch loss to the plateau scheduler so the learning rate can decay.
    scheduler.step(train_loss)

    # Every 10 epochs: encode the whole dataset and store a PCA snapshot for the animation.
    if (epoch + 1) % 10 == 0:
        en_lat = []
        en_reconstruction = []

        model.eval()

        # Inference only: no autograd graph is needed for the snapshot.
        with torch.no_grad():
            for i in range(len(dat)):
                sample = snapshot_samples[i].view(1, 1, feature_num).to(DEVICE)
                en_lat.append(model._encoder(sample).cpu().numpy())
                en_reconstruction.append(model(sample).cpu().numpy())

        # One row per sample, whatever per-sample shape the model produced.
        encode_out = np.array(en_lat).reshape(len(dat), -1)
        reconstruction_out = np.array(en_reconstruction).reshape(len(dat), -1)

        # PCA of the latent space
        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(encode_out)

        # Column 0 carries the snapshot index, columns 1-2 the two principal components.
        index_column = np.full((pca_result.shape[0], 1), n_frames, dtype=int)
        pca_result_with_index = np.hstack((index_column, pca_result))

        frames.append(pca_result_with_index)
        n_frames += 1

        # Every 100 epochs: draw a monitoring figure.
        if (epoch + 1) % 100 == 0:

            # stacking a single observation as well as its reconstruction in order to evaluate the results
            stack = np.vstack([dat[0].reshape(1, -1), reconstruction_out[0].reshape(1, -1)])

            # prepping a 1x4 plot to monitor the model through training
            fig, axs = plt.subplots(1, 4, figsize=(12, 3))

            # Training loss history in the first subplot
            axs[0].plot(train_res_recon_error, label='Training Loss')
            axs[0].set_title('Training Loss Plot')

            # First sample (top row) against its reconstruction (bottom row)
            sns.heatmap(stack, ax=axs[1], cbar=False)
            axs[1].set_title('Stacked heatmap of two samples')
            axs[1].set_xticks([])
            axs[1].set_yticks([])

            # Latent codes of every sample
            sns.heatmap(encode_out, ax=axs[2], cbar=False)
            axs[2].set_title('Heatmap of hole quantized dataset')
            axs[2].set_xticks([])
            axs[2].set_yticks([])

            # 2-D PCA projection of the latent codes, coloured by `label`
            sns.scatterplot(x=pca_result[:, 0], y=pca_result[:, 1], c=label, ax=axs[3])
            axs[3].set_title('PCA')
            axs[3].set_xticks([])
            axs[3].set_yticks([])

            plt.subplots_adjust(wspace=0)
            plt.tight_layout()
            plt.show()
            # Release the figure so repeated snapshots do not accumulate in memory.
            plt.close(fig)


# Plot the training loss curve. The x axis is derived from the history length,
# not from EPOCH, so the plot still works when this cell is re-run and the
# history holds more than EPOCH entries.
epochs = np.arange(1, len(train_res_recon_error) + 1)
plt.plot(epochs, train_res_recon_error, label='Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

print(optimizer.param_groups[0]['lr'])


In [None]:
print(optimizer.param_groups[0]['lr'])

en_lat = []
en_reconstruction = []

# Evaluation mode for the final pass over the whole dataset.
model.eval()

eval_samples = Mydatasets(data1=dat)

# Inference only: disabling autograd avoids building a graph for every sample.
with torch.no_grad():
    for i in range(len(dat)):
        sample = eval_samples[i].view(1, 1, feature_num).to(DEVICE)
        en_lat.append(model._encoder(sample).cpu().numpy())
        en_reconstruction.append(model(sample).cpu().numpy())

# One row per sample, whatever per-sample shape the model produced.
encode_out = np.array(en_lat).reshape(len(dat), -1)
reconstruction_out = np.array(en_reconstruction).reshape(len(dat), -1)

# PCA of the latent space
pca = PCA(n_components=2)
pca_result = pca.fit_transform(encode_out)

# stacking a single observation as well as its reconstruction in order to evaluate the results
stack = np.vstack([dat[0].reshape(1, -1), reconstruction_out[0].reshape(1, -1)])

# prepping a 2x3 summary figure of the trained model
fig, axs = plt.subplots(2, 3, figsize=(12, 6))

# Top-left: training loss history
axs[0, 0].plot(train_res_recon_error, label='Training Loss')
axs[0, 0].set_title('Training Loss Plot')

# Top-middle: first sample (top row) against its reconstruction (bottom row)
sns.heatmap(stack, ax=axs[0, 1], cbar=False)
axs[0, 1].set_title('Stacked heatmap of two samples')
axs[0, 1].set_xticks([])
axs[0, 1].set_yticks([])

# Top-right: 2-D PCA projection of the latent codes, coloured by `label`
sns.scatterplot(x=pca_result[:, 0], y=pca_result[:, 1], c=label, ax=axs[0, 2])
axs[0, 2].set_title('PCA')
axs[0, 2].set_xticks([])
axs[0, 2].set_yticks([])

# Bottom row, left to right: input matrix, latent codes, reconstructions
sns.heatmap(data, ax=axs[1, 0], cbar=False)
axs[1, 0].set_title('Heatmap of the hole dataset')
axs[1, 0].set_xticks([])
axs[1, 0].set_yticks([])

sns.heatmap(encode_out, ax=axs[1, 1], cbar=False)
axs[1, 1].set_title('Heatmap of the hole latent space')
axs[1, 1].set_xticks([])
axs[1, 1].set_yticks([])

sns.heatmap(reconstruction_out, ax=axs[1, 2], cbar=False)
axs[1, 2].set_title('Heatmap of the hole recontruction')
axs[1, 2].set_xticks([])
axs[1, 2].set_yticks([])

plt.subplots_adjust(wspace=0)
plt.tight_layout()
plt.show()

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

#%matplotlib notebook

# Single-axes figure that the animation callback redraws for every frame
fig, ax = plt.subplots(nrows=1, ncols=1)

def update(frame):
    """Redraw the shared axes with the PCA snapshot stored at position `frame`.

    Each entry of `frames` was stacked as (snapshot index, PC1, PC2), so the
    coordinates to plot live in columns 1 and 2; column 0 is skipped.
    Points are coloured with the module-level `label` values.
    """
    # Wipe the previous frame before drawing the new one.
    ax.clear()
    ax.set_title(f'Frame {frame}')

    snapshot = frames[frame]
    ax.scatter(snapshot[:, 1], snapshot[:, 2], c=label)
    




# Build the animation: `update` is called once per stored snapshot index
# (0 .. n_frames - 1) and playback loops.
ani = FuncAnimation(fig=fig, func=update, frames=n_frames, repeat=True)

# Render to a JavaScript/HTML player; as the last expression of the cell the
# HTML object is what the notebook displays.
animation_markup = ani.to_jshtml()
HTML(animation_markup)

In [None]:
# Write the animation to disk as an MP4 using the ffmpeg writer.
# NOTE(review): the file name says "GDS" while the dataset loaded above is the
# BRCA pickle — confirm the name is intentional.
ani.save(filename='../img/GDS_pca_vanilla_AE_0.mp4', writer='ffmpeg')


In [None]:
# Serialise the whole model object (not just its state_dict) to the workfiles folder.
# NOTE(review): loading a file saved this way presumably requires the model's
# class definition to be importable at load time — verify for downstream users.
torch.save(obj=model, f="../workfiles/torch_AE")