In [6]:
# basics + plotting
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Global matplotlib defaults for every figure drawn in this notebook:
# render at 250 dpi and use the generic sans-serif font family.
plt.rcParams.update({
    "figure.dpi": 250,
    "font.family": "sans serif",
})

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler

# custom
import os, sys

# The project's helper modules live in the `utils` and `models` directories
# that sit next to the notebook's working directory (i.e. in its parent).
# os.path is used so the result does not depend on '/' being the separator.
UTILS_PATH = os.path.join(os.path.dirname(os.getcwd()), 'utils')
MODELS_PATH = os.path.join(os.path.dirname(os.getcwd()), 'models')

# Insert at index 1 (directly after the first sys.path entry). The membership
# guards keep the cell idempotent: re-running it does not pile up duplicate
# entries. On a first run MODELS_PATH ends up ahead of UTILS_PATH.
if UTILS_PATH not in sys.path:
    sys.path.insert(1, UTILS_PATH)
if MODELS_PATH not in sys.path:
    sys.path.insert(1, MODELS_PATH)
import plotting_utils
import data_utils
import mlp

In [2]:
# Select the compute backend: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU. The choice is echoed for the reader.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cpu device


In [16]:
import json

In [9]:
# Read the gene identifiers, one per line, stripping surrounding whitespace.
# Blank lines are skipped so they cannot turn into empty gene names (and
# therefore malformed per-gene paths) further down.
with open("../gene_list.txt", 'r') as file:
    gene_list = [line.strip() for line in file if line.strip()]

# For every gene in `gene_list`, the locations of its four parquet tables,
# keyed by table kind.
table_paths = {
    gene: {
        'variants': f'../data/data/{gene}/variants_table.parquet',
        'sequences': f'../data/data/{gene}/seq_table.parquet',
        'patients': f'../data/data/{gene}/patients_table.parquet',
        'haplotypes': f'../data/data/{gene}/hap_table.parquet',
    }
    for gene in gene_list
}

# For every gene, the location of its .npy array file, keyed by data kind.
data_paths = {
    gene: {'seq-var-matrix': f'../data/data/{gene}/seq_var_matrix.npy'}
    for gene in gene_list
}

In [12]:
# Build a single-column 0/1 label frame indexed by participant id ('eid').
# Column '131338-0.0' is binarised: missing entries are filled with 0 and
# every remaining non-zero value becomes 1. The comparison is vectorised
# instead of calling a Python lambda once per row; the result is the same
# one-column int64 DataFrame.
# NOTE(review): '131338-0.0' looks like a phenotype field id - confirm its meaning.
phendata = (
    pd.read_csv("../../../../all_phenotyes_aug2022/split_phenotypes/131338.csv",
                low_memory=False)
      .set_index('eid')
      .loc[:, ['131338-0.0']]
      .fillna(0)
      .ne(0)
      .astype('int64')
)
# Index labels are converted to strings.
# NOTE(review): presumably to match the ids used by the dataset - confirm.
phendata.index = phendata.index.astype(str)

In [14]:
# Assemble the dataset from the per-gene table/array paths and the binary
# phenotype labels built above.
# NOTE(review): keep_genes_separate=False presumably merges all genes into one
# combined table (the cell output reports "Combining tables") - confirm in data_utils.
data = data_utils.VariationDataset(
    table_paths=table_paths,
    data_paths=data_paths,
    phenotypes=phendata,
    keep_genes_separate=False,
)

Fetching FHL1 data ... Done.
Fetching ACTC1 data ... Done.
Fetching ACTN2 data ... Done.
Fetching CSRP3 data ... Done.
Fetching MYBPC3 data ... Done.
Fetching MYH6 data ... Done.
Fetching MYH7 data ... Done.
Fetching MYL2 data ... Done.
Fetching MYL3 data ... Done.
Fetching MYOZ2 data ... Done.
Fetching LDB3 data ... Done.
Fetching TCAP data ... Done.
Fetching TNNC1 data ... Done.
Fetching TNNI3 data ... Done.
Fetching TNNT2 data ... Done.
Fetching TPM1 data ... Done.
Fetching TRIM63 data ... Done.
Fetching PLN data ... Done.
Fetching JPH2 data ... Done.
Fetching FLNC data ... Done.
Fetching ALPK3 data ... Done.
Fetching LMNA data ... Done.
Fetching NEXN data ... Done.
Fetching VCL data ... Done.
Fetching MYOM2 data ... Done.
Fetching CASQ2 data ... Done.
Fetching CAV3 data ... Done.
Fetching MYLK2 data ... Done.
Fetching CRYAB data ... Done.
Combining tables ... Done.
Integrating with phenotypes data ...Done.


In [30]:
# Load the MLP configuration. The file is opened in a `with` block so the
# handle is closed deterministically instead of being left to the garbage
# collector.
with open("../params/mlp.json") as params_file:
    params = json.load(params_file)

In [31]:
# Instantiate the MLP from the 'model_params' section of the loaded config.
# The instance is not bound to a name: as the cell's last expression it is
# only displayed, so its layer structure can be inspected.
mlp.MLP(**params['model_params'])

MLP(
  (fc): Sequential(
    (0): Linear(in_features=10, out_features=5, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=5, out_features=5, bias=True)
    (4): ReLU()
    (5): Linear(in_features=5, out_features=2, bias=True)
  )
)