In [1]:
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.insert(0, '..')

import gc
import pysam
import pandas as pd
import re
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import tqdm


import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from helpers.plots import MetricsHandler, MotifMetrics

import encoding_utils.sequence_encoders as sequence_encoders
import encoding_utils.sequence_utils as sequence_utils
from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd

In [2]:

class SeqDataset(Dataset):
    
    def __init__(self, fasta_fa, seq_df, transform, motifs):
        
        self.fasta = pysam.FastaFile(fasta_fa)
        
        self.seq_df = seq_df
        self.transform = transform

        self.motifs = motifs
        
    def __len__(self):
        
        return len(self.seq_df)
    
    def __getitem__(self, idx):
        
        seq = self.fasta.fetch(self.seq_df.iloc[idx].seq_name).upper()
        #print(seq)
                
        species_label = self.seq_df.iloc[idx].species_label
        #print(species_label)
        # x_batch, y_masked_batch, y_batch, mask_batch, motif_mask_batch 
        masked_sequence, target_labels_masked, target_labels, mask, motif_mask_batch = self.transform(seq, motifs = self.motifs)
        
        masked_sequence = (masked_sequence, species_label)
        return masked_sequence, target_labels_masked, target_labels, motif_mask_batch
    
    def close(self):
        self.fasta.close()

# Reading sequences and filtering motifs


In [3]:
fasta_fa = "../../test/Homo_sapiens_3prime_UTR.fa"
species_list = "../240_species.txt"

seq_df = pd.read_csv(fasta_fa + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')[1])
species_encoding = pd.read_csv(species_list, header=None).squeeze().to_dict()
species_encoding = {species:idx for idx,species in species_encoding.items()}
species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']
seq_df['species_label'] = seq_df.species_name.map(species_encoding)
seq_df

Unnamed: 0,seq_name,species_name,species_label
0,ENST00000641515.2_utr3_2_0_chr1_70009_f:Homo_s...,Homo_sapiens,181
1,ENST00000616016.5_utr3_13_0_chr1_944154_f:Homo...,Homo_sapiens,181
2,ENST00000327044.7_utr3_18_0_chr1_944203_r:Homo...,Homo_sapiens,181
3,ENST00000338591.8_utr3_11_0_chr1_965192_f:Homo...,Homo_sapiens,181
4,ENST00000379410.8_utr3_15_0_chr1_974576_f:Homo...,Homo_sapiens,181
...,...,...,...
18129,ENST00000303766.12_utr3_11_0_chrY_22168542_r:H...,Homo_sapiens,181
18130,ENST00000250831.6_utr3_11_0_chrY_22417604_f:Ho...,Homo_sapiens,181
18131,ENST00000303728.5_utr3_4_0_chrY_22514071_f:Hom...,Homo_sapiens,181
18132,ENST00000382407.1_utr3_0_0_chrY_24045793_r:Hom...,Homo_sapiens,181


In [4]:
from helpers.motifs import MotifHandler

motif_overlap = [
    ("EWSR1","GGGGG"),
    ("FUS", "GGGGG"),
    ("TAF15", "GGGGG"),
    ("HNRNPL", "ACACA"),
    ("PABPN1L", "AAAAA"),
    ("TRA2A", "GAAGA"),
    ("PCBP2", "CCCCC"),
    ("RBFOX2", "GCATG"),
    ("TARDBP", "GTATG"),
    ("HNRNPC", "TTTTT"),
    ("TIA1","TTTTT"),
    ("PTBP3", "TTTCT"),
    ("CELF1", "TATGT"),
    ("FUBP3", "TATAT"),
    ("KHSRP", "TGTAT"),
    ("PUM1", "TGTAT"),
    ("KHDRBS2", "ATAAA")
]

# create a motif to id mapping, ids can overlap but proteins should not
motifs = list(set(map(lambda x: x[1], motif_overlap)))
motifs = dict(zip(motifs, range(len(motifs)))) # (motif, id)

#now add ids to the motif_overlap
motif_overlap = list(map(lambda x: (x[0], x[1], motifs[x[1]]), motif_overlap))

# MotifHandler takes a list of tuples 
# (protein, motif, id, motif_regex)
motifs = MotifHandler(motif_overlap)
motifs.dict

{'GGGGG': 7,
 'ACACA': 4,
 'AAAAA': 0,
 'GAAGA': 1,
 'CCCCC': 12,
 'GCATG': 11,
 'GTATG': 9,
 'TTTTT': 2,
 'TTTCT': 10,
 'TATGT': 3,
 'TATAT': 8,
 'TGTAT': 6,
 'ATAAA': 5}

In [5]:
kseq_len = 5000
total_len = 5000

seq_transform = sequence_encoders.RollingMasker()
                       
test_dataset = SeqDataset(fasta_fa, seq_df, transform = seq_transform, motifs=motifs.dict)
test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = None, shuffle = False)



In [7]:
print(torch.cuda.is_available())

False


In [10]:
gc.collect()
torch.cuda.empty_cache()
# test wether cuda is available - if cpu_bool is set to True, cuda is not used

cpu_bool = False
device = torch.device("cuda" if torch.cuda.is_available() and not cpu_bool else "cpu")

d_model = 128
n_layers = 4
dropout = 0.
learn_rate = 1e-4
weight_decay = 0.
output_dir = "./test/"
get_embeddings = True
save_at = None

species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = 128)

model = DSSResNetEmb(d_input = 5, d_output = 5, d_model = d_model, n_layers = n_layers, 
                     dropout = dropout, embed_before = True, species_encoder = species_encoder)

model = model.to(device) 

model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(model_params, lr = learn_rate, weight_decay = weight_decay)

last_epoch = 0

In [11]:

model_weight = "../../test/MLM_mammals_species_aware_5000_weights"
# load model but avoid torch._C._cuda_getDeviceCount() > 0 failed error
model.load_state_dict(torch.load(model_weight, map_location=device))

<All keys matched successfully>

In [12]:

predictions_dir = os.path.join(output_dir, 'predictions') #dir to save predictions
weights_dir = os.path.join(output_dir, 'weights') #dir to save model weights at save_at epochs
if save_at:
    os.makedirs(weights_dir, exist_ok = True)

def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'


In [13]:

from helpers.metrics import MaskedAccuracy
def model_eval_check(model, optimizer, dataloader, device, get_embeddings = False, silent=False):
    criterion = torch.nn.CrossEntropyLoss(reduction = "mean")

    metric = MaskedAccuracy().to(device)
    motif_metric = MaskedAccuracy().to(device)

    model.eval() #model to train mode

    if not silent:
        tot_itr = len(dataloader.dataset)//dataloader.batch_size #total train iterations
        pbar = tqdm(total = tot_itr, ncols=700) #progress bar

    avg_loss, masked_acc, total_acc = 0., 0., 0.
    
    all_embeddings = []
    outputs = []
    with torch.no_grad():
        #               x_batch, y_masked_batch, y_batch, mask_batch, motif_mask_batch
        for itr_idx, (((masked_sequence, species_label), targets_masked, targets, motif_mask)) in enumerate(dataloader):
            
            if get_embeddings:
                #batches are generated by transformation in the dataset,
                #so remove extra batch dimension added by dataloader
                masked_sequence, targets_masked, targets = masked_sequence[0], targets_masked[0], targets[0]
                species_label = species_label.tile((len(masked_sequence),))
            
            masked_sequence = masked_sequence.to(device)
            targets_masked = targets_masked.to(device)

            motif_targets=targets.detach().clone()
            motif_targets[motif_mask.squeeze() == 0] = -100.0
            motif_targets[targets_masked == -100] = -100.0
            targets = targets.to(device)
            species_label = torch.tensor(species_label).long().to(device)
            
            logits, embeddings = model(masked_sequence, species_label)

            loss = criterion(logits, targets_masked)

            avg_loss += loss.item()
                
            preds = torch.argmax(logits, dim=1)


            test_acc_motif = motif_metric(preds, motif_targets)
            masked_acc += metric(preds, targets_masked).detach() # compute only on masked nucleotides
            total_acc += metric(preds, targets).detach()
            #print(masked_acc/(itr_idx+1))
                
            if get_embeddings:
                # only get embeddings of the masked nucleotide
                sequence_embedding = embeddings["seq_embedding"]
                sequence_embedding = sequence_embedding.transpose(-1,-2)[targets_masked!=-100]
                # shape # B, L, dim  to L,dim, left with only masked nucleotide embeddings
                # average over sequence 
                #print(sequence_embedding.shape)
                sequence_embedding = sequence_embedding.mean(dim=0) # if we mask
                #sequence_embedding = sequence_embedding[0].mean(dim=-1) # no mask

                sequence_embedding = sequence_embedding.detach().cpu().numpy()
                all_embeddings.append(sequence_embedding)
                
            if not silent:
                pbar.update(1)
                pbar.set_description(f"acc: {total_acc/(itr_idx+1):.2}, masked acc: {masked_acc/(itr_idx+1):.2}, motif acc {test_acc_motif/(itr_idx+1):.2} loss: {avg_loss/(itr_idx+1):.4}")
            if itr_idx == 2:
                break
            outputs.append({"loss": loss, "preds": preds, "logits": logits, "targets": targets_masked, "motifs": motif_mask})
    if not silent:
        del pbar
    return outputs
    #return (avg_loss/(itr_idx+1), total_acc/(itr_idx+1), masked_acc/(itr_idx+1)), all_embeddings



In [14]:
outputs = model_eval_check(model, optimizer, test_dataloader, device, 
                                                        get_embeddings = get_embeddings, silent = True)


  species_label = torch.tensor(species_label).long().to(device)
  return einsum('chn,hnl->chl', W, S).float(), state                   # [C H L]


tensor(0.3989)
tensor(0.4726)
tensor(0.4895)


In [19]:
# creates the .pt files and writes them to the cwd
# motif dict can be gotten from the motif handler
motif_metrics = MotifMetrics(outputs=outputs, plots_dir="../test_dir/", motif_dict=motifs.dict, save=True)