# Notebook ML4RG-Project

First download the data and install the needed packages

In [1]:
![[ ! -d ML4RG-2023-project ]] && git clone https://github.com/Hugenotte585/ML4RG-2023-project.git
!gdown https://drive.google.com/uc?id=16BUHUYXNYvndfsiECB8-C7cwWq82oTg-

Cloning into 'ML4RG-2023-project'...
remote: Enumerating objects: 402, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (17/17), done.[K
remote: Total 402 (delta 13), reused 6 (delta 4), pack-reused 381[K
Receiving objects: 100% (402/402), 9.86 MiB | 15.36 MiB/s, done.
Resolving deltas: 100% (220/220), done.
Downloading...
From: https://drive.google.com/uc?id=16BUHUYXNYvndfsiECB8-C7cwWq82oTg-
To: /content/ML4RG-2023-project.tar
100% 39.8M/39.8M [00:00<00:00, 116MB/s] 


In [2]:
!tar -xvf ML4RG-2023-project.tar
!rm ML4RG-2023-project.tar

Homo_sapiens_3prime_UTR.fa
Homo_sapiens_3prime_UTR.fa.fai
MLM_mammals_species_aware_5000_weights


In [3]:
!pip -q install pysam
!pip -q install torchmetrics
!pip -q install einops
!pip -q install omegaconf
!pip -q install biopython
!pip -q install logomaker

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.0/20.0 MB[0m [31m52.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.0/117.0 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for antlr4-python3-runtime (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.8/11.8 MB[0m [31m112.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [5]:
%load_ext autoreload
%autoreload 2

colab = True
import sys, os
if colab:
    sys.path.insert(0, './ML4RG-2023-project')
else:
    sys.path.insert(0, '..')


import gc
import pysam
import pandas as pd
import re
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np


import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions

import encoding_utils.sequence_encoders as sequence_encoders
import encoding_utils.sequence_utils as sequence_utils
from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd
from models.baseline.markov_model import *
from models.baseline.markov_for_dinuc import *
from Bio import SeqIO
import pickle
import glob

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Example script usage ^^

In [None]:
#!cd ML4RG-2023-project && python main.py --test --fasta ../Homo_sapiens_3prime_UTR.fa --species_list 240_species.txt --output_dir ./test --model_weight ../MLM_mammals_species_aware_5000_weights

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:
!cp -r "/content/prbs.pt" "/content/drive/MyDrive/MLRG2023"

cp: cannot stat '/content/prbs.pt': No such file or directory


In [2]:
# Parameters
species_agnostic = False

In [3]:
class SeqDataset(Dataset):

    def __init__(self, fasta_fa, seq_df, transform, motifs):

        self.fasta = pysam.FastaFile(fasta_fa)

        self.val_fraction = 0.1
        N_train = int(len(seq_df) * (1-self.val_fraction))
        self.start_index = N_train
        self.seq_df = seq_df
        self.transform = transform

        self.motifs = motifs

    def __len__(self):

        return len(self.seq_df[self.start_index:])

    def __getitem__(self, idx):
      seq = self.fasta.fetch(self.seq_df.iloc[self.start_index + idx].seq_name).upper()
        #print(seq)
      species_label = self.seq_df.iloc[self.start_index + idx].species_label
        #print(species_label)
        # x_batch, y_masked_batch, y_batch, mask_batch, motif_mask_batch
      masked_sequence, target_labels_masked, target_labels, mask, motif_mask_batch = self.transform(seq, motifs = self.motifs)

      masked_sequence = (masked_sequence, species_label)
      return masked_sequence, target_labels_masked, target_labels, motif_mask_batch

    def close(self):
      self.fasta.close()

# Read the data

In [7]:
fasta_fa = "./Homo_sapiens_3prime_UTR.fa"
if not colab:
    fasta_fa = glob.glob("../../test/*.fa")[0]
species_list = "ML4RG-2023-project/240_species.txt"
if not colab:
    species_list = "../240_species.txt"
seq_df = pd.read_csv(fasta_fa + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')[1])
species_encoding = pd.read_csv(species_list, header=None).squeeze().to_dict()

if not species_agnostic:
    species_encoding = {species:idx for idx,species in species_encoding.items()}
else:
    species_encoding = {species:0 for _,species in species_encoding.items()}

species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']
seq_df['species_label'] = seq_df.species_name.map(species_encoding)

seq_df

Unnamed: 0,seq_name,species_name,species_label
0,ENST00000641515.2_utr3_2_0_chr1_70009_f:Homo_s...,Homo_sapiens,181
1,ENST00000616016.5_utr3_13_0_chr1_944154_f:Homo...,Homo_sapiens,181
2,ENST00000327044.7_utr3_18_0_chr1_944203_r:Homo...,Homo_sapiens,181
3,ENST00000338591.8_utr3_11_0_chr1_965192_f:Homo...,Homo_sapiens,181
4,ENST00000379410.8_utr3_15_0_chr1_974576_f:Homo...,Homo_sapiens,181
...,...,...,...
18129,ENST00000303766.12_utr3_11_0_chrY_22168542_r:H...,Homo_sapiens,181
18130,ENST00000250831.6_utr3_11_0_chrY_22417604_f:Ho...,Homo_sapiens,181
18131,ENST00000303728.5_utr3_4_0_chrY_22514071_f:Hom...,Homo_sapiens,181
18132,ENST00000382407.1_utr3_0_0_chrY_24045793_r:Hom...,Homo_sapiens,181


In [9]:
# Motif:id
motifs = {"GTATG":1}

In [10]:
kseq_len = 5000
total_len = 5000

seq_transform = sequence_encoders.RollingMasker()

test_dataset = SeqDataset(fasta_fa, seq_df, transform = seq_transform, motifs=motifs)
test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = None, shuffle = False)
len(test_dataset)

1814

# Load the model
## Model params

In [11]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [12]:
d_model = 128
n_layers = 4
dropout = 0.
learn_rate = 1e-4
weight_decay = 0.
output_dir = "./test/"
get_embeddings = True
save_at = None

species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = 128)

model = DSSResNetEmb(d_input = 5, d_output = 5, d_model = d_model, n_layers = n_layers,
                     dropout = dropout, embed_before = True, species_encoder = species_encoder)

model = model.to(device)

model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(model_params, lr = learn_rate, weight_decay = weight_decay)

last_epoch = 0

In [13]:
species_agnostic

False

In [14]:
if not species_agnostic:
    model_weight = "MLM_mammals_species_aware_5000_weights"
else :
    model_weight = "MLM_mammals_species_agnostic_5000_weights"
# load model but avoid torch._C._cuda_getDeviceCount() > 0 failed error
model.load_state_dict(torch.load(model_weight, map_location=device))

<All keys matched successfully>

In [15]:
predictions_dir = os.path.join(output_dir, 'predictions') #dir to save predictions
weights_dir = os.path.join(output_dir, 'weights') #dir to save model weights at save_at epochs
if save_at:
    os.makedirs(weights_dir, exist_ok = True)

def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'

## Model
If the following line fails:

```
model = model.to(device)
```
Either use:


```
device = torch.device("cpu")
```
Or go to Runtime -> change runtime type -> Hardware Accellerator 'GPU'



In [16]:
from helpers.metrics import MaskedAccuracy
def model_eval_check(model, optimizer, dataloader, device, get_embeddings = False, silent=False):
    criterion = torch.nn.CrossEntropyLoss(reduction = "mean")

    metric = MaskedAccuracy().to(device)
    motif_metric = MaskedAccuracy().to(device)

    model.eval() #model to train mode

    if not silent:
        tot_itr = len(dataloader.dataset)//dataloader.batch_size #total train iterations
        pbar = tqdm(total = tot_itr, ncols=700) #progress bar

    avg_loss, masked_acc, total_acc = 0., 0., 0.

    all_embeddings = []
    outputs = []
    with torch.no_grad():
      for itr_idx, (((masked_sequence, species_label), targets_masked, targets, motif_mask)) in enumerate(dataloader):

            if get_embeddings:
                #batches are generated by transformation in the dataset,
                #so remove extra batch dimension added by dataloader
                masked_sequence, targets_masked, targets = masked_sequence[0], targets_masked[0], targets[0]
                species_label = species_label.tile((len(masked_sequence),))

            masked_sequence = masked_sequence.to(device)
            targets_masked = targets_masked.to(device)

            motif_targets=targets.detach().clone()
            motif_targets[motif_mask.squeeze()== 0] = -100.0
            print(f"{itr_idx}: {motif_targets.shape}")
            motif_targets[targets_masked == -100] = -100.0
            targets = targets.to(device)
            species_label = torch.tensor(species_label).long().to(device)

            logits, embeddings = model(masked_sequence, species_label)

            loss = criterion(logits, targets_masked)
            avg_loss += loss.item()

            preds = torch.argmax(logits, dim=1)
            preds = preds.to(device)
            motif_targets = motif_targets.to(device)


            test_acc_motif = motif_metric(preds, motif_targets)
            masked_acc += metric(preds, targets_masked).detach() # compute only on masked nucleotides
            total_acc += metric(preds, targets).detach()
            #print(masked_acc/(itr_idx+1))

            if get_embeddings:
                # only get embeddings of the masked nucleotide
                sequence_embedding = embeddings["seq_embedding"]
                sequence_embedding = sequence_embedding.transpose(-1,-2)[targets_masked!=-100]
                # shape # B, L, dim  to L,dim, left with only masked nucleotide embeddings
                # average over sequence
                #print(sequence_embedding.shape)
                sequence_embedding = sequence_embedding.mean(dim=0) # if we mask
                #sequence_embedding = sequence_embedding[0].mean(dim=-1) # no mask

                sequence_embedding = sequence_embedding.detach().cpu().numpy()
                all_embeddings.append(sequence_embedding)
            if not silent:
                pbar.update(1)
                pbar.set_description(f"acc: {total_acc/(itr_idx+1):.2}, masked acc: {masked_acc/(itr_idx+1):.2}, motif acc {test_acc_motif/(itr_idx+1):.2} loss: {avg_loss/(itr_idx+1):.4}")
            outputs.append({"loss": loss, "preds": preds, "logits": logits, "targets": targets_masked, "motifs": motif_mask})
    if not silent:
        del pbar
    return outputs

In [17]:
import time
start = time.time()
outputs = model_eval_check(model, optimizer, test_dataloader, device,
                                                        get_embeddings = get_embeddings, silent = True)
end = time.time()
print("Time taken in mins: ", (end-start)/60)

0: torch.Size([30, 4956])


  species_label = torch.tensor(species_label).long().to(device)
  return einsum('chn,hnl->chl', W, S).float(), state                   # [C H L]


1: torch.Size([30, 1214])
2: torch.Size([30, 892])
3: torch.Size([30, 511])
4: torch.Size([30, 504])
5: torch.Size([30, 555])
6: torch.Size([30, 2448])
7: torch.Size([30, 2522])
8: torch.Size([30, 6014])
9: torch.Size([30, 2441])
10: torch.Size([30, 891])
11: torch.Size([30, 2879])
12: torch.Size([30, 1550])
13: torch.Size([30, 218])
14: torch.Size([30, 154])
15: torch.Size([30, 1021])
16: torch.Size([30, 1276])
17: torch.Size([30, 466])
18: torch.Size([30, 5654])
19: torch.Size([30, 777])
20: torch.Size([30, 2377])
21: torch.Size([30, 5823])
22: torch.Size([30, 426])
23: torch.Size([30, 3850])
24: torch.Size([30, 1714])
25: torch.Size([30, 594])
26: torch.Size([30, 1518])
27: torch.Size([30, 1907])
28: torch.Size([30, 1101])
29: torch.Size([30, 1656])
30: torch.Size([30, 258])
31: torch.Size([30, 1986])
32: torch.Size([30, 2798])
33: torch.Size([30, 126])
34: torch.Size([30, 2245])
35: torch.Size([30, 1003])
36: torch.Size([30, 7883])
37: torch.Size([30, 1511])
38: torch.Size([30, 116

KeyboardInterrupt: ignored

In [56]:
import pickle
outputs_file = "outputs.pickle"
with open(outputs_file, "wb") as f:
    pickle.dump(outputs, f)

In [18]:
species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = d_model)

model = DSSResNetEmb(d_input = 5, d_output = 5, d_model = d_model, n_layers = n_layers, dropout = dropout, embed_before = True, species_encoder = species_encoder)

model = model.to(device)

model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(model_params, lr = learn_rate, weight_decay = weight_decay)

last_epoch = 0
model_weight = "MLM_mammals_species_aware_5000_weights"
model.load_state_dict(torch.load(model_weight))

<All keys matched successfully>

In [None]:
predictions_dir = os.path.join(output_dir, 'predictions') #dir to save predictions
weights_dir = os.path.join(output_dir, 'weights') #dir to save model weights at save_at epochs
if save_at:
    os.makedirs(weights_dir, exist_ok = True)

def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'

from helpers.misc import print    #print function that displays time
print(f'Test/Inference...')

test_metrics, test_embeddings =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                                                          get_embeddings = get_embeddings, silent = True)




[2023/07/02-08:23:40]- Test/Inference...


  species_label = torch.tensor(species_label).long().to(device)
  return einsum('chn,hnl->chl', W, S).float(), state                   # [C H L]
