In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm
from mpl_toolkits import mplot3d
import matplotlib.patches as mpatches
import umap

%matplotlib qt

### Some functions

##### Function to plot UMAP with differente neighbors

In [54]:
def umap_plot(data,color,components=2):

    n_neighbors = [5,15,22,30]

    for i in n_neighbors:
        reducer = umap.UMAP(n_components=components,n_neighbors=i)
        embedding = reducer.fit_transform(data)
        embedding_df = pd.DataFrame(embedding)
        embedding_df['Subgroups']= color

        X_data= embedding_df[0]
        Y_data = embedding_df[1]
        if components == 3:
            Z_data = embedding_df[2]
        Sbgrp = embedding_df['Subgroups']

        cdict = {'Group4': 'pink', 'SHH': 'blue', 'WNT': 'green', 'Group3': 'yellow', 'Generated': 'black', 'Reference': 'red'}
        c = [cdict[val] for val in Sbgrp]

        plt.figure(figsize=(16,10))
        if components == 3:
            ax = plt.axes(projection='3d')
            ax.scatter3D(X_data, Y_data, Z_data, c=c)
        if components == 2:
            plt.scatter(X_data,Y_data,c=c)
        pink_c = mpatches.Patch(color='pink', label='Group4')
        blue_c = mpatches.Patch(color='blue', label='SHH')
        green_c = mpatches.Patch(color='green', label='WNT')
        yellow_c = mpatches.Patch(color='yellow', label='Group3')
        black_c = mpatches.Patch(color='black', label='Generated')
        red_c = mpatches.Patch(color='red', label='Reference')
        plt.legend(handles=[pink_c,blue_c,green_c,yellow_c,black_c,red_c])
        plt.title('UMAP with n_neighbors %i'%(i))
        plt.show()

##### Function to get embeddings, recontructions, mean and logvar

In [None]:
def get_embeddings(model,dataloader):
    model.eval()
    rec_model = np.zeros(shape=(0,12087))
    embedding_model = np.zeros(shape=(0,512))
    mean_model = np.zeros(shape=(0,512))
    logvar_model = np.zeros(shape=(0,512))
    with torch.no_grad(): # in validation we don't want to update weights
        for i, data in tqdm(enumerate(dataloader), total=int(len(test_dataset)/dataloader.batch_size)):
            data = data.view(data.size(0), -1)
            reconstruction,mean,logvar, coded = model(data)
            rec_model = np.concatenate((rec_model, reconstruction), axis=0)
            mean_model = np.concatenate((mean_model, mean), axis=0)
            logvar_model = np.concatenate((logvar_model, logvar), axis=0)
            embedding_model = np.concatenate((embedding_model,coded),axis=0)
    return rec_model, embedding_model, mean_model, logvar_model

##### Function to generate data from a subgroup

In [None]:
def data_generation(N,subgroup,test_dataset):

    if subgroup == 'G4':
        data_mean = standard_g4
        data_logvar = logvar_g4
    elif subgroup == 'SHH':
        data_mean = standard_shh
        data_logvar = logvar_shh
    elif subgroup == 'G3':
        data_mean = standard_g3
        data_logvar = logvar_g3  
    else:
        print("Incorrect subgroup")
        return    
    
    sample = np.zeros(shape=(512,))

    for i in range(0,N):
        for mean, logvar in zip(data_mean, data_logvar):
            resultado = np.random.normal(mean,np.exp(0.5*logvar))
            sample = np.concatenate((sample, resultado), axis=0)

    sample = sample.reshape(N+1,512)
    z = sample[1:]

    z = torch.from_numpy(z)
    z = z.float()

    with torch.no_grad():                 
        samples = model.decoder(z)   #decode the data
    generated = torch.cat([test_dataset, samples], dim=0) #concat the test data and the generate data to visualize it
    new_colors = np.array(['Generated']*len(samples)) #create the reference to paint black the generate examples
    colors_generated = np.concatenate((colors,new_colors),axis=0) #concat the colors of the test data and the generate data

    return generated, colors_generated

##### Function to get interpolation data

In [50]:
def data_interpolation(N,centroid1_mean, centroid1_logvar,centroid2_mean, centroid2_logvar, colors, test_dataset):

    z1 = torch.from_numpy(centroid1_mean).float() #Means
    z2 = torch.from_numpy(centroid2_logvar).float()

    z3 = torch.from_numpy(centroid1_logvar).float() #Logvar
    z4 = torch.from_numpy(centroid2_logvar).float()

    with torch.no_grad():                 
        A = model.decoder(z1)   #decode the data to plot the references
        B = model.decoder(z2)   #here only mean because I want the particular data

    sample = np.zeros(shape=(0,512))
    for i in range(0,N):
        mean = i / (N - 1) * z2 + (1 - i / (N - 1) ) * z1 #interpolation mean
        logvar = i / (N - 1) * z4 + (1 - i / (N - 1) ) * z3 #interpolation logvar
        resultado = np.random.normal(mean,np.exp(0.5*logvar)) #sample from a normal 
        sample = np.concatenate((sample, resultado), axis=0)
    sample = sample.reshape(N,512)
    z = sample[1:]

    z = torch.from_numpy(z) #preprocessing to introduce samples in NN
    z = z.float()

    
    #GENERATE INTERPOLATION DATA
    with torch.no_grad():                 
        samples = model.decoder(z)   #decode the data
    generated = torch.cat([test_dataset, samples], dim=0) #concat the test data and the generate data to visualize it
    new_colors = np.array(['Generated']*len(samples)) #create the reference to paint black the generate examples
    colors_generated = np.concatenate((colors,new_colors),axis=0) #concat the colors of the test data and the generate data

    # ADD REFERENCES
    generated = torch.cat((generated,A,B),axis=0) # add the centroids to the data to plot them

    colors_reference = np.array(['Reference']*2) #add label generated
    colors_generated = np.concatenate((colors_generated,colors_reference),axis=0) #generate new colors

    return generated, colors_generated

### Load training data (Cavalli)

In [None]:
data = pd.read_csv('Medulloblastoma Files\Medulloblastoma_Cavalli_VAE_data.csv', sep=',', na_values=".")
print("The shape of the data is: ", data.shape)
data = data.rename(columns={'Unnamed: 0': 'Patient'})

subgroups = pd.read_csv('Medulloblastoma Files\GSE85218_subgroups.csv', sep=' ',header=None)
print("The shape of the subgroups is: ", subgroups.shape)

### Load test data (Northcott)

In [None]:
data_test = pd.read_csv('Medulloblastoma Files\Medulloblastoma_Northcott_VAE_data.csv', sep=',', na_values=".")
print("The shape of the data is: ", data_test.shape)
data_test = data_test.rename(columns={'Unnamed: 0': 'Patient'})

subgroups_test = pd.read_csv('Medulloblastoma Files\GSE37382_subgroups.csv', sep=' ',header=None)
print("The shape of the subgroups is: ", subgroups_test.shape)
colors = subgroups_test[1].values #column with the subgroups of tumor to label each observation in plots

### Normalize and prepare data

In [None]:
data = data.drop(['Patient'],axis=1)
data_test = data_test.drop(['Patient'],axis=1)

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()

scaler.fit(data)
data = scaler.transform(data) 

scaler.fit(data_test)
data_test = scaler.transform(data_test)

In [None]:
data = pd.DataFrame(data)
train_dataset = torch.tensor(data.values).float()

data_test = pd.DataFrame(data_test)
test_dataset = torch.tensor(data_test.values).float()


train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
)

### Load model
#### Two layers decoder - two layers decoder
#### 512 dimension latent space

In [None]:
features = 512

class VAE(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        #encoder layers
        self.encoder = nn.Sequential(
            nn.Linear(in_features=kwargs["input_shape"], out_features=kwargs["mid_dim"]),
            nn.ReLU(),
            nn.Linear(in_features=kwargs["mid_dim"], out_features=features*2)
        )
            
        #decoder layers
        self.decoder = nn.Sequential(
            nn.Linear(in_features=features, out_features=kwargs["mid_dim"]),
            nn.ReLU(),
            nn.Linear(in_features=kwargs["mid_dim"], out_features=kwargs["input_shape"]),
            nn.Tanh()
        )

    def reparametrize(self, mu, log_var):

        # mu: mean of the encoder's latent space distribution
        # log_var: variance from the encoder's latient space distribution
        if self.training:
            std = torch.exp(0.5*log_var) 
            eps = torch.randn_like(std) 
            sample = mu + (eps*std) 
        else:
            sample = mu
        return sample

    def forward(self, x):
        
        mu_logvar = self.encoder(x).view(-1,2,features)
        mu = mu_logvar[:, 0, :] 
        log_var = mu_logvar[:, 1, :] 

        z = self.reparametrize(mu,log_var) 
        reconstruction = self.decoder(z)
        
        return reconstruction, mu, log_var, z
    
model = VAE(input_shape=12087, mid_dim=4096)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

### Load trained model

In [None]:
PATH = './vaeDecoderPrunned512000001.pth'
model = VAE(input_shape=12087, mid_dim=4096)
model.load_state_dict(torch.load(PATH))

### Get embeddings, mean and logvar from Northcott data

In [None]:
reconstructed, coded, mean, logvar = get_embeddings(model, test_loader)

### Calculate centroids of subgroups

In [None]:
mean = pd.DataFrame(mean)
mean['Subgroup'] = subgroups_test[1] #add subgroups column to the embedding data

mean_shh = mean[mean['Subgroup'] == 'SHH']
mean_g3 = mean[mean['Subgroup'] == 'Group3']
mean_g4 = mean[mean['Subgroup'] == 'Group4']

# Centroids of Means
standard_shh = mean_shh.mean().values.reshape(1,512) 
standard_g3 = mean_g3.mean().values.reshape(1,512)
standard_g4 = mean_g4.mean().values.reshape(1,512)

In [None]:
logvar = pd.DataFrame(mean)
logvar['Subgroup'] = subgroups_test[1] #add subgroups column to the embedding data

logvar_shh = logvar[logvar['Subgroup'] == 'SHH']
logvar_g3 = logvar[logvar['Subgroup'] == 'Group3']
logvar_g4 = logvar[logvar['Subgroup'] == 'Group4']

logvar_shh = logvar_shh.mean().values.reshape(1,512)
logvar_g3 = logvar_g3.mean().values.reshape(1,512)
logvar_g4 = logvar_g4.mean().values.reshape(1,512)

### Calculate euclidean distance between centroids

In [None]:
distance_g3_shh = np.cumsum(abs(standard_shh - standard_g3))
distance_G3G4 = np.cumsum(abs(standard_g4 - standard_g3))
distance_shh_G4 = np.cumsum(abs(standard_g4 - standard_shh))

print("Distance G3-SHH: ",distance_g3_shh[-1])
print("Distance G3-G4: ",distance_G3G4[-1])
print("Distance SHH-G4: ",distance_shh_G4[-1])

### Sample from the distribution of some subgroup

In [31]:
generated, new_colors = data_generation(48, 'G3', test_dataset) #generate 48 data of G3

In [32]:
umap_plot(generated,new_colors) #plot it in UMAP with 2 dimensions
#umap_plot(generated,new_colors,3) plot it in UMAP with 3 dimensions

### Interpolate between centroids of two subgroups

In [52]:
interpolation_data, interpolation_colors = data_interpolation(32,standard_g3, logvar_g3,standard_g4, logvar_g4, colors, test_dataset)

In [53]:
umap_plot(interpolation_data,interpolation_colors) #plot it in UMAP with 2 dimensions

(318,)
torch.Size([318, 12087])
