In [None]:
import csv
import os
import re
import string
from itertools import groupby
from multiprocessing.pool import ThreadPool

import h5py
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
from ete3 import Tree
from matplotlib.pylab import plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.manifold import trustworthiness
#from sklearn.metrics import silhouette_score
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

#import multiprocess as mp
import library as lib

In [None]:
# Fix the random seed through Lightning so repeated runs are comparable;
# workers=True asks Lightning to seed dataloader workers as well.
pl.seed_everything(42, workers=True)

# Parameters

In [None]:
# Show the kernel's working directory; the relative paths below resolve against it.
print(os.getcwd())

# --- Locations ---
work_dir = "/c/Users/alex/FortgeschrittenesPraktikum/PyTorch_Lightning_VAE"
inputs, out_path = "./inputs/", "./out_activ_fct/sig/"

# --- Switches ---
# writer: write newick files to disk; gpu: request a GPU accelerator for training
writer, gpu = False, False

# Tag used in output file names and plot titles
activation_function = "sigmoid"

# Distance metrics evaluated by the tree-building cells
metric_list = ["cosine", "euclidean", "manhattan", "ts_ss"]
print("metric: ", metric_list)

# Number of training epochs
epoch = 10

# Layer widths: input width and its 1/2, 1/8 and 1/16 fractions
dim = 2560
dim1, dim2, dim3 = dim // 2, dim // 8, dim // 16

In [None]:
#with open(exclude, "r") as ex:
#    reader = csv.reader(ex, delimiter='\t')
#    column1 = [row[0] for row in reader]
#print(column1)

# Data Pre-Processing

In [None]:
# Build parallel arrays of sequence headers and embeddings from the HDF5 file.
#esm2 = h5py.File(inputs+"KLK_esm2_reduced.h5", 'r')
embedding = []
headers = []
# Context manager so the HDF5 handle is released once all datasets are read
# (the original left the file open for the lifetime of the kernel).
with h5py.File("chumby_data/phosphatase_esm2_3b.h5", 'r') as esm2:
    # One dataset per entry; the dataset name doubles as the sequence header
    keys = list(esm2.keys())
    for key in keys:
        # [:] materialises the dataset in memory, so it stays valid after the file closes
        emb = esm2[key][:]
        # Clean up characters in the header ("replacing mistakes")
        key = key.replace("isoform=", "isoform_").replace("=", "").replace(":", "_")
        if key == "GZMA_Canis_lupus":
            print("header is now changed")
            key = "GZMA_Canis_lupus_1"
        #exclude
        #if key not in column1:
        headers.append(key)
        embedding.append(emb)

headers = np.array(headers)
embedding = np.array(embedding)

# Hold out 10% as the test set ...
headers_first, headers_test, embedding_first, embedding_test = train_test_split(headers, embedding, test_size=0.1, random_state=42)
# ... then split the remaining 90% into training (80%) and validation (20%)
headers_train, headers_validate, embedding_train, embedding_validate = train_test_split(headers_first, embedding_first, test_size =0.2, random_state=42)
print("total:", headers.shape)
print("training+validation:", headers_first.shape)
print("test:", headers_test.shape)
print("training:", headers_train.shape)
print("validation:", headers_validate.shape)
print(len(headers_train), len(embedding_train))
print(len(headers_validate), len(embedding_validate))

#training_df = pd.DataFrame({'header': headers_train, 'embedding' : embedding_train})
#validate_df = pd.DataFrame({'header': headers_validate, 'embedding' : embedding_validate})

# Dataloaders

In [None]:
# Settings shared by both loaders.
# NOTE(review): drop_last=True also discards the final partial validation batch -- confirm intended.
loader_options = dict(batch_size=64, num_workers=0, drop_last=True)

# Training data is reshuffled every epoch
train_dataset = lib.makedataset(headers_train, embedding_train)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)

# Validation data keeps a fixed order
val_dataset = lib.makedataset(headers_validate, embedding_validate)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)

# Training VAE

In [None]:
#if prott5 used: dtype=float16 -> error with .linear()

# Layer widths: the decoder mirrors the encoder, and the latent size equals the narrowest width
latent_dim = dim3
encoder_layers = [dim, dim1, dim2, dim3]
decoder_layers = encoder_layers[::-1]

# Initialize the VAE model
vae = lib.VariationalAutoencoder(encoder_layers, latent_dim, decoder_layers)

# Trainer options common to both hardware choices
trainer_settings = dict(max_epochs=epoch, log_every_n_steps=1, check_val_every_n_epoch=1)
if gpu:
    trainer_settings.update(accelerator="gpu", devices=1)
else:
    trainer_settings.update(accelerator="auto")
trainer = pl.Trainer(**trainer_settings)

# Train the VAE model
trainer.fit(vae, train_dataloader, val_dataloader)

# Fix Formatting

In [None]:
def find_largest_number(root_dir):
    """Return the highest N among the ``version_N`` directories directly inside root_dir.

    Only immediate sub-directories whose full name is ``version_<digits>`` are
    considered. The previous implementation walked the whole tree and took the
    first digit run of any directory name, so unrelated or nested directories
    could produce a wrong (too large) result.

    Parameters
    ----------
    root_dir : str
        Directory holding the run folders (e.g. ``.../lightning_logs``).

    Returns
    -------
    int
        Largest version number found; 0 when root_dir does not exist or
        contains no matching directory.
    """
    largest_number = 0
    if not os.path.isdir(root_dir):
        return largest_number
    for name in os.listdir(root_dir):
        match = re.fullmatch(r'version_(\d+)', name)
        # Skip plain files that merely happen to be named like a run folder
        if match and os.path.isdir(os.path.join(root_dir, name)):
            largest_number = max(largest_number, int(match.group(1)))
    return largest_number

largest_num = find_largest_number(work_dir+"/lightning_logs")
#hard coded cause its broken
largest_num = 1

# Single source of truth for the run folder. The original read from a hard-coded
# absolute path and wrote to work_dir+'lightning_logs/...' (path separator missing).
version_dir = work_dir + '/lightning_logs/version_' + str(largest_num)

#Fix formatting
with open(version_dir + '/metrics.csv', 'r') as file:
    rows = list(csv.reader(file))

combined_rows = []
header_row = ['train_loss','reconstruction_loss','kl_loss','epoch','step','validation_loss']

combined_rows.append(header_row)
# Rows are consumed in pairs after the header line: row i supplies the three
# training columns, row i+1 supplies epoch, step and validation loss.
for i in range(1, len(rows) -1, 2):
    train_row, val_row = rows[i], rows[i+1]
    # NOTE: the global `epoch` is rebound here on purpose -- the plotting cell and the
    # "Epoch" cell call int(epoch) on the last value read from the CSV.
    epoch = val_row[3]
    combined_rows.append([train_row[0], train_row[1], train_row[2], epoch, val_row[4], val_row[5]])

with open(version_dir + '/new_metrics.csv', 'w', newline = "") as file:
    # Deliberately NOT named `writer`: that name is the global on/off flag for newick
    # output, and overwriting it with a csv writer object made every later
    # `if writer:` check true.
    csv_writer = csv.writer(file)
    csv_writer.writerows(combined_rows)

# Plot Loss

In [None]:
# Load the reformatted metrics table
df = pd.read_csv('lightning_logs/version_' + str(largest_num) + '/new_metrics.csv')

# Second view: only the rows from the halfway epoch onwards
df1 = df[df['epoch'] >= int(epoch)/2]

# Same three curves for the full run (plot 0) and for the second half (plot 1)
loss_columns = ('train_loss', 'reconstruction_loss', 'validation_loss')
for frame in (df, df1):
    for column in loss_columns:
        plt.plot('epoch', column, data=frame)
    plt.title('Epoch vs Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(['loss', 'reconstruction loss', 'validation_loss'], loc='upper right')
    plt.show()

# Encode

In [None]:
# Pass the complete embedding array (before the train/validation/test split) through the encoder.
# NOTE(review): only `vae.encoder` is applied, and the model is not explicitly put into
# eval mode before this call -- confirm both are intended.
encode = vae.encoder(torch.Tensor(embedding))
print(encode)

# Detached from the autograd graph; this is what the tree and projection cells consume
embedding_out = encode.detach() 


# Epoch

In [None]:
# `epoch` is rebound while the metrics CSV is parsed (normally leaving the last logged
# epoch value as a string); convert it to int and add one.
# NOTE(review): this cell is not idempotent -- every re-run increments `epoch` again.
epoch = int(epoch) +1
print("epoch: ",epoch)


# Neighbor joining

In [None]:
# Neighbor_joining and create tree
def nj(metric_list, emb, headers, writer, out_path, epoch, trained, activation_function,title):
    """Build a neighbor-joining tree per distance metric and save a heatmap of each distance matrix.

    For every metric the function prints the metric name, the newick target path
    and the trustworthiness between the distance matrix and the tree's cophenetic
    distance matrix, then saves the heatmap as a PNG.

    Parameters
    ----------
    metric_list : iterable of str
        Metric names passed to ``lib.distance_metric(...).get_metric``.
    emb : array-like
        Embeddings the distances are computed from.
    headers : sequence
        Names passed to the tree builder and to ``lib.cophenetic_distmat``.
    writer : bool
        When truthy, the newick string is written to disk.
    out_path : str
        Prefix (directory including trailing separator) for all output files.
    epoch : str
        Label inserted into the heatmap title.
    trained : str
        Tag appended to output file names.
    activation_function : str
        Tag inserted into output file names.
    title : str
        Suffix appended to the heatmap title.
    """
    for metric in metric_list:
        print(metric)

        out_newick = out_path+"KLK_"+activation_function+"_esm2_nj_"+metric+"_"+trained+".newick"
        print(out_newick)

        #distance metric
        dist = lib.distance_metric(emb)
        distmat = dist.get_metric(emb, metric)

        #neighbor joining to tree (local name differs from the function so `nj` is not shadowed)
        tree_builder = lib.neighbor_joining(distmat,headers)
        newick = tree_builder.get_newick(distmat,headers)

        #trustworthiness
        _distmat  = lib.cophenetic_distmat(newick, names=headers)
        _trustworthiness = trustworthiness(distmat, _distmat, n_neighbors=10, metric='precomputed')
        print("trustworthiness: ", _trustworthiness)

        # Fresh figure per metric. Drawing on the current axes (as before) stacked
        # every heatmap and colorbar of the loop into the same figure, so all PNGs
        # after the first one contained leftovers of the previous metrics.
        fig, ax = plt.subplots()
        sns.heatmap(data= distmat,cmap = sns.cm.rocket_r, ax=ax).set(title=metric.capitalize()+' Distance Matrix from '+str(epoch)+' Training '+title)
        fig.savefig(out_path+'heatmap_'+activation_function+'_'+metric+'_'+trained+'.png', dpi=400)
        # Render in the notebook, then release the figure to keep memory flat
        plt.show()
        plt.close(fig)

        if writer:
            # The with-block closes the file; no explicit close() needed
            with open(out_newick, 'w') as w:
                w.write(newick)

print("epoch: ",epoch)
nj(metric_list, embedding_out,headers, writer, out_path, str(epoch)+" Epoch", str(epoch),activation_function, '('+activation_function.capitalize()+')')

# Reference NJ: no training 

In [None]:
# Baseline: the same NJ pipeline on the raw (un-encoded) embeddings.
# writer is passed as False, so no newick file is written; heatmaps are still saved.
nj(metric_list, embedding, headers, False, out_path, "No", "notraining",activation_function, "")   

# UPGMA and create tree

In [None]:
def upgma(metric_list, emb, headers, writer, out_path, epoch, trained ,activation_function):
    """Build one UPGMA tree per distance metric and print its trustworthiness.

    Parameters
    ----------
    metric_list : iterable of str
        Metric names passed to ``lib.distance_metric(...).get_metric``.
    emb : array-like
        Embeddings the distances are computed from.
    headers : sequence
        Names passed to ``lib.upgma`` and ``lib.cophenetic_distmat``.
    writer : bool
        When truthy, the tree is written to a ``.newick`` file.
    out_path : str
        Prefix (directory including trailing separator) for the newick file.
    epoch : str
        Accepted for signature parity with ``nj``; not used in the body.
    trained : str
        Tag appended to the newick file name.
    activation_function : str
        Tag inserted into the newick file name.
    """
    for metric in metric_list:
        print("metric:", metric)

        # Target file for this metric
        out_newick = out_path+"KLK_"+activation_function+"_esm2_upgma_"+metric+"_"+trained+".newick"

        # Pairwise distances for this metric
        distmat = lib.distance_metric(emb).get_metric(emb, metric)

        # Tree from the distance matrix
        tree_newick = lib.upgma(distmat,headers)

        # Compare the input distances with the distances implied by the tree
        cophenetic = lib.cophenetic_distmat(tree_newick, names=headers)
        score = trustworthiness(distmat, cophenetic, n_neighbors=10, metric='precomputed')
        print("trustworthiness: ", score)

        if not writer:
            continue
        with open(out_newick, 'w') as handle:
            handle.write(tree_newick)


print("epoch: ",epoch)
upgma(metric_list, embedding_out, headers, writer, out_path, str(epoch)+" Epochs", str(epoch),activation_function)

# Reference UPGMA: no training

In [None]:
# Baseline: UPGMA on the raw (un-encoded) embeddings; writer is False, so nothing is written to disk.
upgma(metric_list, embedding, headers, False, out_path, "No","notraining",activation_function)


# T-SNE

In [None]:
# Read the iTOL annotation table and discard its first three rows
colour_df = pd.read_csv(
    inputs+"KLK_iTOL.txt",
    sep='\t',
    names=('protein', 'range', 'colour', 'family'),
).drop([0,1,2])

# Family labels are merged step by step; order matters because later steps
# consume labels produced by earlier ones.
family_merges = [
    (['KLK4','KLK5','KLK7'], 'KLK4/5/7'),
    (['KLK8','KLK10','KLK12'], 'KLK8/10/12'),
    (['KLK1','Mammalian KLK1','non-mammalian KLK1'], 'KLK1'),
    (['KLK1','KLK2','Reptilian KLKs'], 'KLK1/2/Reptilian KLK'),
    (['KLK4/5/7','KLK8/10/12'], 'KLK4/5/7/8/10/12'),
]
for members, merged_label in family_merges:
    colour_df['family'] = colour_df['family'].replace(members, merged_label, regex = False)

# Number of distinct family labels; used to size the colour palettes
types = len(colour_df["family"].unique())
print(types)



In [None]:
def tsne(embeddings, df_protein_types, title_name, dataset,n_types):
    """Project embeddings to 2-D with t-SNE, plot them coloured by family and save the figure.

    Parameters
    ----------
    embeddings : array-like
        Input vectors, one row per protein.
    df_protein_types : pandas.DataFrame
        Table whose ``family`` column supplies the colour label of each row.
        Labels are matched to rows by position.
        NOTE(review): assumes this table is in the same order as ``embeddings`` -- confirm.
    title_name : str
        Plot title.
    dataset : str
        Tag used in the output file name (``tsne_<dataset>.png`` under ``out_path``).
    n_types : int
        Number of colours taken from the ``tab20`` palette.

    Returns
    -------
    numpy.ndarray
        The 2-D t-SNE coordinates.
    """
    reducer = TSNE(perplexity = 22, learning_rate = 100, n_iter = 1000, random_state= 42)
    tsne_results = reducer.fit_transform(embeddings)

    # Plot table: family label plus both t-SNE axes
    df_subset = pd.DataFrame({"Type": df_protein_types["family"]})
    for axis, column in enumerate(('tsne-2d-one', 'tsne-2d-two')):
        df_subset[column] = tsne_results[:, axis]

    ax = sns.scatterplot(
        data=df_subset,
        x="tsne-2d-one", y="tsne-2d-two",
        hue="Type",
        palette=sns.color_palette("tab20",n_types),
        edgecolor = 'white',
        legend="full",
    )
    ax.set(title= title_name)

    # Legend placed outside the axes, to the right
    plt.legend(bbox_to_anchor=(1.02, 1),loc='upper left', borderaxespad=0,fontsize="6")
    plt.tight_layout()
    plt.savefig(out_path+'tsne_'+dataset+'.png', dpi=600)

    return tsne_results


tsne_results_trained = tsne(embedding_out,colour_df, f"T-SNE of VAE Embeddings ({activation_function.capitalize()})", f"{activation_function}_trained",types)


In [None]:
# Baseline t-SNE on the raw (un-encoded) embeddings; the figure is saved as tsne_reference.png under out_path.
tsne_results_ref = tsne(embedding,colour_df, "T-SNE of Embeddings Without Training","reference",types)

In [None]:
# NOTE(review): the shape printed belongs to the *trained* projection, while the trees below are
# built from the *reference* projection (tsne_results_ref) -- confirm which one is intended.
print(tsne_results_trained.shape)
# activation_function and title are passed as empty strings, so the output names contain "KLK__esm2_nj_" and "heatmap__".
nj(metric_list, tsne_results_ref, headers, writer, out_path, "T-SNE", "tsne","","")
#upgma(metric_list,  tsne_results_ref, headers, writer, out_path, str(epoch)+" Epoch T-SNE", "tsne", activation_function)


# PCA

In [None]:
def pca(embeddings, df_protein_types, title_name, dataset,n_types):
    """Project embeddings onto their first two principal components, plot and save the figure.

    Parameters
    ----------
    embeddings : array-like
        Input vectors, one row per protein.
    df_protein_types : pandas.DataFrame
        Table whose ``family`` column supplies the colour label of each row.
        Labels are matched to rows by position.
        NOTE(review): assumes this table is in the same order as ``embeddings`` -- confirm.
    title_name : str
        Plot title.
    dataset : str
        Tag used in the output file name (``pca_<dataset>.png`` under ``out_path``).
    n_types : int
        Number of colours taken from the ``tab20`` palette.

    Returns
    -------
    numpy.ndarray
        The 2-D PCA coordinates.
    """
    # Local name differs from the function so `pca` is not shadowed
    reducer = PCA(n_components=2,random_state=42)
    pca_result = reducer.fit_transform(embeddings)

    df_plot = pd.DataFrame()
    # Use the table that was passed in; the previous version ignored this
    # parameter and always read the global `colour_df`.
    df_plot["Type"] = df_protein_types["family"]
    df_plot['pca-one'] = pca_result[:,0]
    df_plot['pca-two'] = pca_result[:,1]
    #df_plot['pca-three'] = pca_result[:,2]
    print()

    print('Explained variation per principal component: {}'.format(reducer.explained_variance_ratio_))


    plt.figure(figsize=(14,10))
    sns.scatterplot(
        x="pca-one", y="pca-two",
        hue="Type",
        palette=sns.color_palette("tab20", n_types),
        data=df_plot,
        legend="full").set(title= title_name)

    plt.tight_layout()
    plt.savefig(out_path+'pca_'+dataset+'.png', dpi=600)

    return pca_result



pca_results_trained = pca(embedding_out, colour_df, f"PCA of VAE Embeddings ({activation_function.capitalize()})", f"{activation_function}_trained",types)


In [None]:
# Baseline PCA on the raw (un-encoded) embeddings; the figure is saved as pca_reference.png under out_path.
pca_results_reference = pca(embedding, colour_df,  f"PCA of Embeddings Without Training", "reference",types)

In [None]:
# NOTE(review): the shape printed belongs to the *trained* projection, while the trees below are
# built from the *reference* projection (pca_results_reference) -- confirm which one is intended.
print(pca_results_trained.shape)
#nj(metric_list, pca_results_reference, headers, writer, out_path, "PCA", "pca", activation_function)
# UPGMA trees on the 2-D PCA coordinates; newick files are tagged "pca"
upgma(metric_list,  pca_results_reference, headers, writer, out_path, "PCA", "pca", activation_function)

# Branch Support

In [None]:
def get_nj_fast(metric, emb, headers):
    """Return the neighbor-joining tree of ``emb`` under ``metric`` in newick form.

    Stripped-down variant of ``nj``: no printing, plotting or file output,
    which makes it suitable for building many replicate trees.

    Parameters
    ----------
    metric : str
        Metric name passed to ``lib.distance_metric(...).get_metric``.
    emb : array-like
        Embeddings the distances are computed from.
    headers : sequence
        Names passed to the tree builder.
    """
    # Pairwise distance matrix for the requested metric
    distmat = lib.distance_metric(emb).get_metric(emb, metric)

    # Neighbor joining on that matrix, returned as newick
    return lib.neighbor_joining(distmat,headers).get_newick(distmat,headers)

In [None]:
def get_samples(n_samples, emb):
    """Encode ``emb`` with the global ``vae`` encoder ``n_samples`` times.

    Changes from the previous version: the loop is bounded by ``range`` (the old
    ``while n_samples != 0`` never terminated for negative counts), the forward
    passes run under ``torch.no_grad()`` because the outputs are detached anyway,
    and the full tensor is no longer printed once per replicate.

    NOTE(review): whether repeated passes through ``vae.encoder`` yield different
    results depends on the encoder (e.g. stochastic layers in training mode);
    if it is deterministic, all replicates are identical -- confirm.

    Parameters
    ----------
    n_samples : int
        Number of encodings to produce; zero or negative yields an empty list.
    emb : array-like
        Embeddings handed to ``torch.Tensor`` before encoding.

    Returns
    -------
    list of torch.Tensor
        One detached encoder output per replicate.
    """
    sample_list = []

    with torch.no_grad():
        for _ in range(n_samples):
            encode = vae.encoder(torch.Tensor(emb))
            sample_list.append(encode.detach())

    return sample_list

In [None]:
n_samples = 200
THREADS = 6

samples = get_samples(n_samples, embedding)

# One (metric, embeddings, headers) argument tuple per replicate
args = [("manhattan", sample, headers) for sample in samples]

# `mp` was never imported (its import is commented out), so the old
# `mp.Pool(...)` raised NameError on a fresh kernel. ThreadPool exposes the same
# Pool API (context manager, starmap, ordered results) and needs no pickling of
# notebook-defined functions, so it also works where worker processes are spawned.
with ThreadPool(THREADS) as pool:
    output = pool.starmap(get_nj_fast, args)


#with open(newick_replicates, 'w') as w:
#    w.write('\n'.join(output))

     

In [None]:
# Newick file used as the reference topology for branch support.
# NOTE(review): this path uses "out_activ_fct/sigmoid/", whereas out_path in the parameters
# cell is "./out_activ_fct/sig/" -- confirm the file exists at this location.
newick_reference = "./out_activ_fct/sigmoid/KLK_sigmoid_esm2_nj_manhattan_1.newick"

# Load the tree with ete3 and get its bipartitions plus the matching nodes from the project helper
reference_tree = Tree(newick_reference)
reference_bits, reference_nodes = lib.get_bipartitions(reference_tree, headers)
print("total: ",len(headers))
print("after bi-partition: ",len(reference_bits))
#print(reference_bits)

#tree_list = [Tree(i) for i in output]
#print(len(tree_list))
#replicate_trees  = ((reference_bits, i, headers) for i in map(Tree, output))
#for tree in replicate_trees:
#    print(tree)
"""
with mp.Pool(THREADS) as pool:
    args = [(reference_bits, i, headers) for i in output]
    #print(args[0])
    support_bits = pool.starmap(lib.get_support, args)
    support = (100 * np.array(support_bits).mean(0)).astype(int)
    #print(support)


for percentage, node in zip(support, reference_nodes):
    node.support = percentage
    #print(percentage)
    #print(node.support)

count = 0
total = 0

for node in reference_tree.traverse():
    total = total+1
    if node not in reference_nodes:
        node.support = 0
        count = count+1
        
print(total)
print(count)

with open(out_path+"branch_support.newick", 'w') as w:
    w.write(reference_tree.write())
"""