In [1]:
# basics + plotting
import os, sys
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.dpi"] = 250
plt.rcParams["font.family"] = "sans serif"

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler

# custom
PROJECT_PATH = '/'.join(os.getcwd().split('/')[:-1])
sys.path.insert(1, PROJECT_PATH)

from utils import (
    data_utils, 
    eval_utils, 
    plotting_utils, 
    train_test_utils
)

from models import (
    esm_transfer_learning,
    gnn
)

import importlib
data_utils = importlib.reload(data_utils)
eval_utils = importlib.reload(eval_utils)
train_test_utils = importlib.reload(train_test_utils)
esm_transfer_learning = importlib.reload(esm_transfer_learning)
gnn = importlib.reload(gnn)

In [2]:
data = data_utils.load_variation_dataset("../data/data/", 
                                         "../gene_list.txt", 
                                         ["esm-tokens", "indicators"], 
                                         "../data/phenotypes_hcm_only.parquet",
                                         predict=["hcm"], 
                                         low_memory=True,
                                         embeddings_file='esm2m_embeddings.npy',
                                         ppi_graph_path='../ppi_networks/string_interactions_short.tsv')

train_dataset, test_dataset = data.train_test_split(balance_on=['hcm','ethnicity'])

Fetching FHL1 data ... Done
Fetching ACTC1 data ... Done
Fetching ACTN2 data ... Done
Fetching CSRP3 data ... Done
Fetching MYBPC3 data ... Done
Fetching MYH6 data ... Done
Fetching MYH7 data ... Done
Fetching MYL2 data ... Done
Fetching MYL3 data ... Done
Fetching MYOZ2 data ... Done
Fetching LDB3 data ... Done
Fetching TCAP data ... Done
Fetching TNNC1 data ... Done
Fetching TNNI3 data ... Done
Fetching TNNT2 data ... Done
Fetching TPM1 data ... Done
Fetching TRIM63 data ... Done
Fetching PLN data ... Done
Fetching JPH2 data ... Done
Fetching FLNC data ... Done
Fetching ALPK3 data ... Done
Fetching LMNA data ... Done
Fetching NEXN data ... Done
Fetching VCL data ... Done
Fetching MYOM2 data ... Done
Fetching CASQ2 data ... Done
Fetching CAV3 data ... Done
Fetching MYLK2 data ... Done
Fetching CRYAB data ... Done
Combining tables ... Done
Integrating with phenotypes data ...Done
Processing PPI graph ...Done


In [4]:
batch_size = 8

train_loader = DataLoader(
    dataset = train_dataset, 
    batch_size = batch_size,
    sampler = WeightedRandomSampler(train_dataset.weights('hcm',flatten_factor=1), num_samples = len(train_dataset)),
    num_workers=14
)
    
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    # sampler = WeightedRandomSampler(test_dataset.weights('hcm',flatten_factor=1), num_samples = 20),
    num_workers=14
)

gat_params = {
    'in_dim': 0,
    'embed_dim': 256,
    'n_heads': 4,
    'n_nodes': 0,
    'mlp_hidden_dims': [128],
    'mlp_actn': 'gelu'
}

gcn_params = {
    'in_dim': 0,
    'embed_dim': 256,
    'n_nodes': 0,
    'mlp_hidden_dims': [256],
    'mlp_actn': 'leakyrelu'
}

mlp_params = {
    'in_dim': 0,
    'hidden_dims': [512],
    'out_dim':0,
    'actn': 'gelu'
}

model = esm_transfer_learning.ESMTransferLearner(
    esm_model_name="esm2_t6_8M_UR50D",
    agg_emb_method="weighted_average", 
    predict_method="gat",
    n_genes=29,
    predictor_params=gat_params,
    add_residue_features=data.get_feature_dimensions(['indicators']),
    edge_index=data.edge_index.to('cuda'),
    first_finetune_layer=-1,
)
    

Preparing for fine-tuning:
- Freezing layer embed_tokens.weight
- Freezing layer layers.0.self_attn.k_proj.weight
- Freezing layer layers.0.self_attn.k_proj.bias
- Freezing layer layers.0.self_attn.v_proj.weight
- Freezing layer layers.0.self_attn.v_proj.bias
- Freezing layer layers.0.self_attn.q_proj.weight
- Freezing layer layers.0.self_attn.q_proj.bias
- Freezing layer layers.0.self_attn.out_proj.weight
- Freezing layer layers.0.self_attn.out_proj.bias
- Freezing layer layers.0.self_attn_layer_norm.weight
- Freezing layer layers.0.self_attn_layer_norm.bias
- Freezing layer layers.0.fc1.weight
- Freezing layer layers.0.fc1.bias
- Freezing layer layers.0.fc2.weight
- Freezing layer layers.0.fc2.bias
- Freezing layer layers.0.final_layer_norm.weight
- Freezing layer layers.0.final_layer_norm.bias
- Freezing layer layers.1.self_attn.k_proj.weight
- Freezing layer layers.1.self_attn.k_proj.bias
- Freezing layer layers.1.self_attn.v_proj.weight
- Freezing layer layers.1.self_attn.v_proj.b

In [5]:
train_test_utils.train(model, 
      train_loader,
      test_loader,
      save_model_to = '../data/results/esm2s_wavg_gat',
      save_metrics_to = 'esm2s_wavg_gat',
      log_every = 10,
      lr=1e-3, 
      n_epochs=6)

Epoch #0:


RuntimeError: CUDA out of memory. Tried to allocate 1.28 GiB (GPU 0; 22.17 GiB total capacity; 20.08 GiB already allocated; 424.81 MiB free; 20.77 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF