In [1]:
import os
import random
import re
import sys
from datetime import datetime
from random import sample

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from sklearn import preprocessing
from sklearn.metrics import log_loss, mean_squared_error
from sklearn.preprocessing import StandardScaler
from torch.autograd import Variable
from torch.distributions.kl import kl_divergence
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import DataLoader, Dataset

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torchvision.utils import save_image

In [2]:
path = "/Users/M283455/VAE_prject/scripts/"
sys.path.insert(0, path)

In [3]:
import VAE_tybalt
from VAE_tybalt import VAE

In [4]:
model = VAE(input_dim=5000, hidden_dim=[100], z_dim=100)

In [5]:
model.load_state_dict(torch.load("../output/models/vae_weights.pth"))

<All keys matched successfully>

In [6]:
tcga_tybalt_file_location = (
    "../../VAE_prject_data/raw/pancan_scaled_zeroone_rnaseq.tsv.gz"
)
rnaseq_df = pd.read_table(tcga_tybalt_file_location)
rnaseq_df.drop(columns=rnaseq_df.columns[0], axis=1, inplace=True)
rnaseq_df = rnaseq_df.dropna()

In [7]:
rnaseq_df.head(5)

Unnamed: 0,RPS4Y1,XIST,KRT5,AGR2,CEACAM5,KRT6A,KRT14,CEACAM6,DDX3Y,KDM5D,...,FAM129A,C8orf48,CDK5R1,FAM81A,C13orf18,GDPD3,SMAGP,C2orf85,POU5F1B,CHST2
0,0.678296,0.28991,0.03423,0.0,0.0,0.084731,0.031863,0.037709,0.746797,0.687833,...,0.44061,0.428782,0.732819,0.63434,0.580662,0.294313,0.458134,0.478219,0.168263,0.638497
1,0.200633,0.654917,0.181993,0.0,0.0,0.100606,0.050011,0.092586,0.103725,0.140642,...,0.620658,0.363207,0.592269,0.602755,0.610192,0.374569,0.72242,0.271356,0.160465,0.60256
2,0.78598,0.140842,0.081082,0.0,0.0,0.0,0.0,0.0,0.730648,0.657189,...,0.437658,0.471489,0.868774,0.471141,0.487212,0.385521,0.466642,0.784059,0.160797,0.557074
3,0.720258,0.122554,0.180042,0.0,0.0,0.0,0.0,0.0,0.720306,0.719855,...,0.553306,0.373344,0.818608,0.691962,0.635023,0.430647,0.45369,0.364494,0.161363,0.607895
4,0.767127,0.210393,0.034017,0.0,0.061161,0.0,0.053021,0.0,0.739546,0.665684,...,0.601268,0.379943,0.506839,0.68432,0.607821,0.320113,0.47619,0.122722,0.389544,0.698548


In [8]:
# load model
def load_trained_model(model, model_name):
    model.load_state_dict(torch.load("../output/models/" + model_name))
    return model


# reconstruct the input dataframe with VAE
def VAE_reconstruct_df(df, model):
    reconstruct_tensor = model.forward(torch.tensor(df.values, dtype=torch.float32))[0]
    dataframe = pd.DataFrame(reconstruct_tensor.detach().numpy(), columns=df.columns)

    return dataframe


# How well does the model reconstruct the input RNAseq data
def VAE_latent_out(df):
    mu = model.forward(torch.tensor(df.values, dtype=torch.float32))[2]
    sigma = model.forward(torch.tensor(df.values, dtype=torch.float32))[3]

    column_names = [str(i) for i in range(100)]

    mu = pd.DataFrame(mu.detach().numpy(), columns=column_names)
    sigma = pd.DataFrame(sigma.detach().numpy(), columns=column_names)

    return mu, sigma


def gene_summary(df, df_reconstruct):
    reconstruction_fidelity = df - df_reconstruct

    gene_mean = reconstruction_fidelity.mean(axis=0)
    gene_abssum = reconstruction_fidelity.abs().sum(axis=0).divide(rnaseq_df.shape[0])
    gene_summary = pd.DataFrame(
        [gene_mean, gene_abssum], index=["gene mean", "gene abs(sum)"]
    ).T
    gene_summary = gene_summary.sort_values(by="gene abs(sum)", ascending=True)

    return gene_summary


In [9]:
df_reconstruct = VAE_reconstruct_df(rnaseq_df, model)
gene_summary_sort = gene_summary(rnaseq_df, df_reconstruct)
best_500_genes = list(gene_summary_sort.index[:500])

In [10]:
best_100_genes = random.sample(best_500_genes,500)

In [11]:
import pandas as pd
from sklearn.feature_selection import mutual_info_regression

# Assume you have a pandas DataFrame df
df = rnaseq_df.copy()  # Add your data here

# Specify the target column
target_column = "RPS4Y1"  # Insert target column name here

def calculate_mutual_info(df, target):
    mutual_info = mutual_info_regression(df, df[target])
    return pd.Series(mutual_info, index=df.columns)

mi_series = calculate_mutual_info(df, target_column)

# Sort by mutual information in descending order
mi_series = mi_series.sort_values(ascending=False)

print(mi_series) 

RPS4Y1    6.000008
DDX3Y     1.165206
KDM5D     1.070255
EIF1AY    1.045620
ZFY       1.038007
            ...   
AIFM3     0.000000
COL9A3    0.000000
BEAN      0.000000
ERAP2     0.000000
SHANK2    0.000000
Length: 5000, dtype: float64


In [12]:
idxs = mi_series.index
mi_dataframe = pd.DataFrame()
for target in best_100_genes:
    mi_series = calculate_mutual_info(rnaseq_df, target)
    mi_dataframe[target] = mi_series[idxs]
    

  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[target] = mi_series[idxs]
  mi_dataframe[t

In [13]:
mi_dataframe

Unnamed: 0,AMICA1,AQP1,SLC12A8,TC2N,SLC22A17,RASGRP2,FAM171A1,SLAMF1,GPNMB,GRB7,...,PLAUR,AFP,RBM47,NCF1,CD3E,APOC1,SLC26A3,SLC22A12,IL2RG,BOC
RPS4Y1,0.039157,0.043493,0.024495,0.033889,0.046532,0.070019,0.019264,0.052931,0.031684,0.052880,...,0.035403,0.047225,0.032292,0.028127,0.035431,0.033820,0.047613,0.038488,0.029133,0.042729
DDX3Y,0.029967,0.044617,0.060035,0.029154,0.045610,0.083609,0.021912,0.031515,0.041699,0.060148,...,0.032918,0.043083,0.039003,0.027186,0.035111,0.034890,0.063969,0.055669,0.018705,0.052993
KDM5D,0.017049,0.044110,0.044610,0.050344,0.047725,0.113694,0.050004,0.049288,0.038469,0.062298,...,0.038092,0.041435,0.052859,0.029422,0.048110,0.030035,0.046950,0.046473,0.031569,0.055177
EIF1AY,0.027419,0.030647,0.016578,0.026552,0.034998,0.061384,0.032875,0.057451,0.041416,0.041258,...,0.037681,0.026626,0.036443,0.042483,0.043696,0.024216,0.026470,0.029604,0.025828,0.035722
ZFY,0.012230,0.031849,0.023801,0.032778,0.032894,0.053877,0.028903,0.036142,0.035582,0.045879,...,0.024107,0.102474,0.028200,0.014590,0.006995,0.041911,0.106248,0.075482,0.026945,0.039342
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
AIFM3,0.046306,0.005115,0.021127,0.021458,0.067119,0.020470,0.034336,0.020968,0.057033,0.038116,...,0.009161,0.005200,0.032951,0.029590,0.019403,0.037518,0.038223,0.023649,0.036466,0.013524
COL9A3,0.036382,0.021848,0.044244,0.048345,0.035896,0.036333,0.030323,0.031782,0.030085,0.057183,...,0.024785,0.013544,0.051143,0.014782,0.031214,0.000000,0.021121,0.048774,0.033575,0.033324
BEAN,0.017538,0.039949,0.043057,0.006092,0.057717,0.007866,0.024796,0.006863,0.032175,0.033919,...,0.031981,0.021287,0.009816,0.021220,0.012325,0.047954,0.030778,0.015138,0.007549,0.034316
ERAP2,0.096115,0.041746,0.018792,0.014424,0.038835,0.023134,0.024640,0.106990,0.023940,0.015589,...,0.035712,0.012311,0.022511,0.072080,0.118813,0.025126,0.009214,0.006264,0.116212,0.011108


In [14]:
max_mut_info = {}
for column in mi_dataframe.columns:
    df = mi_dataframe[str(column)].copy()
    max_mut_info[column]=df.sort_values()[-3:].index[-2]

In [15]:
max_mut_info.keys()

dict_keys(['AMICA1', 'AQP1', 'SLC12A8', 'TC2N', 'SLC22A17', 'RASGRP2', 'FAM171A1', 'SLAMF1', 'GPNMB', 'GRB7', 'SLC43A1', 'WFDC12', 'DCT', 'PLG', 'LOC388387', 'ST8SIA3', 'APOF', 'ACMSD', 'VDR', 'SDC2', 'ASPDH', 'ACTL6B', 'PLUNC', 'ENPEP', 'PERP', 'FBLN2', 'CAPN5', 'CTGF', 'AKAP12', 'KITLG', 'SLC2A2', 'ATCAY', 'FN1', 'RAD54L', 'GUCY1A3', 'FBLN1', 'COL12A1', 'MYO1G', 'MYLK', 'SYNPO2', 'TAT', 'SHCBP1', 'TOP2A', 'SERPINC1', 'SLC17A3', 'DBH', 'CXCL12', 'CGNL1', 'COL6A2', 'NNMT', 'CRB3', 'SLAMF8', 'SYNM', 'FEV', 'CFHR3', 'EDNRB', 'ASF1B', 'UGT2B10', 'TAC3', 'SPARCL1', 'KIAA1543', 'SLC17A4', 'KIF23', 'KIFC1', 'CDCA8', 'CRMP1', 'CD247', 'CYP2A6', 'DHCR24', 'KRT78', 'C17orf28', 'RGS5', 'SLC22A7', 'LOC285733', 'SPN', 'EPS8L2', 'SORBS1', 'RHOD', 'SEMA4G', 'S100A4', 'CAV2', 'GPM6B', 'TRIP13', 'RARRES2', 'AGER', 'CCNB2', 'MECOM', 'RHPN2', 'RNASE4', 'SI', 'RLBP1', 'SERPINB11', 'TNFSF10', 'F2', 'CHRNA2', 'AGXT', 'OLIG2', 'PRODH2', 'GPR116', 'PHLDB2', 'CCL5', 'C11orf82', 'CD3D', 'FBN1', 'RRM2', 'STAP2'

In [16]:
import json
    
with open("Mutinfo_300.json", "w") as outfile:
    json.dump(max_mut_info, outfile)