In [1]:
import torch
import torch.utils.data as data
from dataset import SNPmarkersDataset
import json
from utils import train_DL_model
import numpy as np
import random
from torch.utils.data import Dataset
from sklearn.feature_selection import mutual_info_regression
from Models.GPTransformer import GPTransformer, EmbeddingType

In [2]:
# --- Optimisation -----------------------------------------------------------
BATCH_SIZE = 8                     # batch_size of both dataloaders
LEARNING_RATE = 1e-5               # lr handed to the Adam optimizer
N_EPOCHS = 5                       # passed to train_DL_model

# --- Regularisation ---------------------------------------------------------
DROPOUT = 0                        # NOTE(review): not referenced by any visible cell -- confirm it is still needed
MASK_PROBABILITY = 0.2             # GPTransformer(mask_probability=...)

# --- Transformer architecture -----------------------------------------------
N_EMBEDDING = 100                  # GPTransformer(embedding_size=...)
N_HEADS = 2                        # GPTransformer(n_heads=...)
N_LAYERS = 2                       # GPTransformer(n_blocks=...)
HIDDEN_NODES = 256                 # GPTransformer(n_hidden=...)
HIDDEN_MLP_SIZE = [128, 32]        # GPTransformer(output_hidden_size=...)
LINEAR_PROJECTOR_OUTPUT_SIZE = 4   # GPTransformer(linear_projector_output_size=...)

In [3]:
# Phenotype key: selects the entry read from `dataset.phenotypes` and is
# forwarded to `train_DL_model` through its `phenotype` argument.
selected_phenotypes = "ep_res"

In [4]:

# Load the training / validation splits and choose the SNP columns fed to the model.
FEATURE_SELECTION_SEED = 2307   # same value as the random_state of the disabled MI run below
KEEP_PROBABILITY = 0.1          # every SNP is kept independently with this probability


def load_split(mode, phenotype):
    """Return ``(X, y)`` for one dataset split.

    ``X`` is whatever ``get_all_SNP()`` yields for the split and ``y`` is the
    ``phenotype`` entry of ``dataset.phenotypes``.
    """
    dataset = SNPmarkersDataset(mode = mode, skip_check=True)
    # NOTE(review): assigned, not called -- presumably a property setter on
    # SNPmarkersDataset; confirm against dataset.py.
    dataset.set_phenotypes = phenotype
    return dataset.get_all_SNP(), dataset.phenotypes[phenotype]


# Only the two splits consumed by the following cells are loaded. The "test"
# split was needed solely by the disabled mutual-information ranking.
X_train, y_train = load_split("local_train", selected_phenotypes)
X_val, y_val = load_split("validation", selected_phenotypes)

# Disabled: mutual-information based selection. It accumulated the score of
# every split (local_train, validation, test) and averaged them:
#
#     mi += mutual_info_regression(X,y, n_jobs=-1, discrete_features=True, random_state=2307)
#     mi /= len(modes)
#     indexes = np.where(mi < 0.02)[0]
#     print(f"Nb of selected features: {len(indexes)}")

# Random stand-in for the ranking above: a 0/1 mask with one entry per SNP.
# The mask keeps the name `mi` because it replaces the mutual-information scores.
# A dedicated, seeded generator makes the selected subset identical on every run
# without touching the global numpy RNG state.
n_features = X_train.shape[1]   # assumes a 2-D (samples x markers) table -- TODO confirm
rng = np.random.default_rng(FEATURE_SELECTION_SEED)
mi = rng.choice([0, 1], size=n_features, p=[1 - KEEP_PROBABILITY, KEEP_PROBABILITY])
indexes = np.where(mi == 1)[0]
print(f"Nb of selected features: {len(indexes)}")

Nb of selected features: 3562


In [5]:
class SNPResidualDataset(Dataset):
    """Map-style dataset pairing each entry of ``X`` with the matching entry of ``y``."""

    def __init__(self, X, y):
        """Keep references to the inputs ``X`` and the targets ``y`` (no copy is made)."""
        self.X, self.y = X, y

    def __len__(self):
        """Number of samples, taken from the length of ``X``."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``(X[index], y[index])`` pair."""
        sample = self.X[index]
        target = self.y[index]
        return sample, target

def convert_categorical_to_frequency(data, path = "gptranformer_embedding_data.json"):
    """Replace every categorical genotype by a frequency read from a JSON file.

    The JSON file maps the marker position (as a string: "0", "1", ...) to an
    object holding the numbers ``"p"`` and ``"q"``. For the marker at position
    ``i`` the genotype value ``g`` is replaced by element ``g`` of
    ``[p**2, 2*p*q, q**2]`` (the Hardy-Weinberg proportions when ``p`` and
    ``q`` are allele frequencies -- not verified here).

    Args:
        data: iterable of samples; each sample is a sequence of integer
            genotype codes used as list indices (0, 1 or 2; negative values
            follow Python's negative-index rule).
        path: location of the JSON frequency file.

    Returns:
        ``np.ndarray`` of dtype ``float32`` with one row per sample.

    Raises:
        KeyError: a marker position has no entry in the JSON file.
        IndexError: a genotype code is outside the three available entries.
    """
    with open(path,"r") as f:
        freq_data = json.load(f)

    # The three frequencies only depend on the marker position, so they are
    # computed once per position instead of once per (sample, marker) pair.
    frequencies_by_position = {}

    def marker_frequencies(position):
        """Return (and memoise) ``[p**2, 2*p*q, q**2]`` for one marker position."""
        triple = frequencies_by_position.get(position)
        if triple is None:
            entry = freq_data[str(position)]
            p, q = entry["p"], entry["q"]
            triple = [p**2, 2*p*q, q**2]
            frequencies_by_position[position] = triple
        return triple

    results = [
        [marker_frequencies(position)[genotype] for position, genotype in enumerate(sample)]
        for sample in data
    ]
    return np.array(results, dtype=np.float32)


# Keep only the selected SNP columns and hand plain numpy arrays to the datasets:
# int32 genotypes as inputs, float32 phenotype values as targets.
train_features = X_train[indexes].to_numpy(dtype=np.int32)
train_targets = y_train.to_numpy(dtype=np.float32)
train_dataset = SNPResidualDataset(train_features, train_targets)

validation_features = X_val[indexes].to_numpy(dtype=np.int32)
validation_targets = y_val.to_numpy(dtype=np.float32)
validation_dataset = SNPResidualDataset(validation_features, validation_targets)
            

In [6]:
# Hyper-parameters of the transformer, gathered in one mapping before the call.
model_config = dict(
    n_features=3,  # NOTE(review): presumably the number of genotype categories -- confirm
    sequence_length=len(indexes),  # one token per selected SNP
    embedding_size=N_EMBEDDING,
    n_hidden=HIDDEN_NODES,
    n_heads=N_HEADS,
    n_blocks=N_LAYERS,
    mask_probability=MASK_PROBABILITY,
    output_hidden_size=HIDDEN_MLP_SIZE,
    embedding_type=EmbeddingType.EmbeddingTable,
    linear_projector_output_size=LINEAR_PROJECTOR_OUTPUT_SIZE,
)
model = GPTransformer(**model_config)

# Worker initialisation hook for the dataloaders, taken from
# https://pytorch.org/docs/stable/notes/randomness.html#pytorch
def seed_worker(worker_id):
    """Seed numpy and ``random`` from the current torch initial seed.

    ``worker_id`` is accepted because DataLoader passes it to ``worker_init_fn``;
    it is not used.
    """
    seed = torch.initial_seed() % (1 << 32)  # reduced modulo 2**32
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(seed)

# `worker_init_fn` only derives the numpy / random seeds of a worker from the
# base seed it receives. Without a seeded generator that base seed (and the
# shuffling order) is drawn from the unseeded global torch RNG, so runs are not
# repeatable. Each loader therefore gets its own explicitly seeded generator.
DATALOADER_SEED = 2307
NUM_WORKERS = 4

train_generator = torch.Generator()
train_generator.manual_seed(DATALOADER_SEED)
validation_generator = torch.Generator()
validation_generator.manual_seed(DATALOADER_SEED)

train_dataloader = data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    worker_init_fn=seed_worker,
    generator=train_generator,
)
# The validation loader does not shuffle; its generator only fixes the worker seeds.
validation_dataloader = data.DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    worker_init_fn=seed_worker,
    generator=validation_generator,
)

# Adam on every model parameter, trained against the mean absolute error.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.L1Loss()

In [7]:
# Launch training with the objects built in the previous cells. The first six
# arguments are positional: model, optimizer, training loader, validation
# loader, epoch count and loss.
train_DL_model(
    model,
    optimizer,
    train_dataloader,
    validation_dataloader,
    N_EPOCHS,
    criterion,
    phenotype=selected_phenotypes,
    # Logging through the `log_wandb` option is switched off for this run.
    log_wandb=False,
    # NOTE(review): same value as N_EPOCHS (5) -- presumably early stopping can
    # then never end the run before the last epoch; confirm in utils.train_DL_model.
    early_stop_n_epoch=5,
)

Devices detected: cpu
Model architecture : 
 GPTransformer(
  (mask): Dropout(p=0.2, inplace=False)
  (embedding): Embedding(3, 100)
  (preprocessing): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (transformer): Sequential(
    (0): TransformerBlock(
      (multihead): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
      )
      (norm1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0, inplace=False)
      (fc1): Linear(in_features=100, out_features=256, bias=True)
      (relu): ReLU()
      (dropout2): Dropout(p=0, inplace=False)
      (fc2): Linear(in_features=256, out_features=100, bias=True)
      (norm2): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (multihead): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
      )
      (norm1): Laye

KeyboardInterrupt: 