In [1]:
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.insert(0, '../../')

import gc
import pysam

import pandas as pd
import re
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import openpyxl


import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from helpers.plots import MetricsHandler, MotifMetrics
from helpers.motifs import motifs

import encoding_utils.sequence_encoders as sequence_encoders
import encoding_utils.sequence_utils as sequence_utils
from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd

In [13]:
### Finding all the 

xlsx_file_path = "../../../dataset/1-s2.0-S1097276518303514-mmc4.xlsx"



def find_index(element, my_list):
    return next((index for index, value in enumerate(my_list) if value == element), -1)


# Load the XLSX file
workbook = openpyxl.load_workbook(xlsx_file_path)
sheet = workbook['logo_5mers.prop_in_logo']  # Replace "Sheet1" with the actual sheet name
data = []
for row in sheet.iter_rows(values_only=True):
    data.append(row)

strong_motifs = []
seq_only = []
motif_tup = []
r1_val = []
protein_name = []
iteri =0 
# Print the data

for i,row in enumerate(data):
    if i == 0:
        continue
    if i == 1:
        for j in range(0,int(len(row)/2)):
            protein_name.append(row[(j*2)])
    if i > 1:        
        for j in range(0,int(len(row)/2)):
            mot = row[(j*2)]
            r1 = row[(j*2)+1]
            if mot not in strong_motifs and mot!=None:
                strong_motifs.append(mot)
                seq_only.append(mot)
                motif_tup.append((protein_name[j],mot,iteri))
                iteri +=1
                r1_val.append(r1)
            elif mot in strong_motifs: 
                ind = find_index(mot,strong_motifs)
                if r1_val[ind] < r1 :
                    r1_val[ind] = r1


print (motif_tup)
print (len(protein_name))
print (len(strong_motifs))

[('A1CF', 'AATTA', 0), ('BOLL', 'TTTTT', 1), ('CELF1', 'TATGT', 2), ('CNOT4', 'ACACA', 3), ('DAZ3', 'AGTTA', 4), ('DAZAP1', 'ATATA', 5), ('EIF4G2', 'GTTGC', 6), ('ESRP1', 'GGGGG', 7), ('FUBP3', 'TATAT', 8), ('HNRNPA0', 'TATAG', 9), ('HNRNPD', 'TATTA', 10), ('HNRNPDL', 'TAATT', 11), ('HNRNPK', 'GCCCA', 12), ('KHDRBS2', 'ATAAA', 13), ('KHSRP', 'TGTAT', 14), ('MBNL1', 'CGCTT', 15), ('MSI1', 'TAGTT', 16), ('NOVA1', 'TTCAT', 17), ('NUPL2', 'AAAAA', 18), ('PCBP1', 'GCCCC', 19), ('PCBP2', 'CCCCC', 20), ('PCBP4', 'ATCCC', 21), ('PRR3', 'ATAAG', 22), ('PTBP3', 'TTTCT', 23), ('RBFOX2', 'GCATG', 24), ('RBM22', 'ACCGG', 25), ('RBM24', 'GTGTG', 26), ('RBM4', 'GCGCG', 27), ('RBM41', 'TACTT', 28), ('RBM45', 'ACGCA', 29), ('RBM47', 'AATCA', 30), ('RBM6', 'CGTCC', 31), ('RC3H1', 'ATATT', 32), ('SF1', 'TAACA', 33), ('SFPQ', 'TGTAA', 34), ('SNRPA', 'TGCAC', 35), ('SRSF10', 'AGCAG', 36), ('SRSF11', 'AGGGG', 37), ('SRSF8', 'GCAGC', 38), ('SRSF9', 'AGGAG', 39), ('TARDBP', 'GTATG', 40), ('TRA2A', 'GAAGA', 41

In [15]:
from helpers.motifs import MotifHandler

new_motifs = MotifHandler(motif_tup)

In [33]:
class SeqDataset(Dataset):
    
    def __init__(self, fasta_fa, seq_df, transform, motifs):
        
        self.fasta = pysam.FastaFile(fasta_fa)
        
        self.val_fraction = 0.1 
        N_train = int(len(seq_df) * (1-self.val_fraction))
        self.start_index = N_train 
        self.seq_df = seq_df
        self.transform = transform

        self.motifs = motifs
        
    def __len__(self):
        
        return len(self.seq_df[self.start_index:])
    
    def __getitem__(self, idx):
        
        seq = self.fasta.fetch(self.seq_df.iloc[self.start_index + idx].seq_name).upper()
        #print(seq)
                
        species_label = self.seq_df.iloc[self.start_index + idx].species_label
        #print(species_label)
        # x_batch, y_masked_batch, y_batch, mask_batch, motif_mask_batch 
        masked_sequence, target_labels_masked, target_labels, mask, motif_mask_batch = self.transform(seq, motifs = self.motifs)
        
        masked_sequence = (masked_sequence, species_label)
        return masked_sequence, target_labels_masked, target_labels, motif_mask_batch
    
    def close(self):
        self.fasta.close()

In [34]:
fasta_fa = "../../../dataset/data_homo/Homo_sapiens_3prime_UTR.fa"
species_list = "../../240_species.txt"
species_agnostic = False

seq_df = pd.read_csv(fasta_fa + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')[1])
species_encoding = pd.read_csv(species_list, header=None).squeeze().to_dict()

if not species_agnostic:
    species_encoding = {species:idx for idx,species in species_encoding.items()}
else:
    species_encoding = {species:0 for _,species in species_encoding.items()}

species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']
seq_df['species_label'] = seq_df.species_name.map(species_encoding)

seq_df

Unnamed: 0,seq_name,species_name,species_label
0,ENST00000641515.2_utr3_2_0_chr1_70009_f:Homo_s...,Homo_sapiens,181
1,ENST00000616016.5_utr3_13_0_chr1_944154_f:Homo...,Homo_sapiens,181
2,ENST00000327044.7_utr3_18_0_chr1_944203_r:Homo...,Homo_sapiens,181
3,ENST00000338591.8_utr3_11_0_chr1_965192_f:Homo...,Homo_sapiens,181
4,ENST00000379410.8_utr3_15_0_chr1_974576_f:Homo...,Homo_sapiens,181
...,...,...,...
18129,ENST00000303766.12_utr3_11_0_chrY_22168542_r:H...,Homo_sapiens,181
18130,ENST00000250831.6_utr3_11_0_chrY_22417604_f:Ho...,Homo_sapiens,181
18131,ENST00000303728.5_utr3_4_0_chrY_22514071_f:Hom...,Homo_sapiens,181
18132,ENST00000382407.1_utr3_0_0_chrY_24045793_r:Hom...,Homo_sapiens,181


In [46]:
kseq_len = 5000
total_len = 5000

seq_transform = sequence_encoders.RollingMasker()      
test_dataset = SeqDataset(fasta_fa, seq_df, transform = seq_transform, motifs=new_motifs.dict)
test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = None, shuffle = False)
len(test_dataset)

1814

In [47]:
gc.collect()
torch.cuda.empty_cache()
# test wether cuda is available - if cpu_bool is set to True, cuda is not used

# TODO checkout why gpu isn't working
cpu_bool = True
device = torch.device("cuda" if torch.cuda.is_available() and not cpu_bool else "cpu")

d_model = 128
n_layers = 4
dropout = 0.
learn_rate = 1e-4
weight_decay = 0.
output_dir = "./test/"
get_embeddings = True
save_at = None

species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = 128)

model = DSSResNetEmb(d_input = 5, d_output = 5, d_model = d_model, n_layers = n_layers, 
                     dropout = dropout, embed_before = True, species_encoder = species_encoder)

model = model.to(device) 

model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(model_params, lr = learn_rate, weight_decay = weight_decay)

last_epoch = 0

In [48]:
if not species_agnostic:
    model_weight = "../../../dataset/data_homo/MLM_mammals_species_aware_5000_weights"
else :
    model_weight = "../../../dataset/data_homo/MLM_mammals_species_agnostic_5000_weights"
# load model but avoid torch._C._cuda_getDeviceCount() > 0 failed error
model.load_state_dict(torch.load(model_weight, map_location=device))

<All keys matched successfully>

In [49]:

predictions_dir = os.path.join(output_dir, 'predictions') #dir to save predictions
weights_dir = os.path.join(output_dir, 'weights') #dir to save model weights at save_at epochs
if save_at:
    os.makedirs(weights_dir, exist_ok = True)

def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'


In [50]:

from helpers.metrics import MaskedAccuracy
def model_eval_check(model, optimizer, dataloader, device, get_embeddings = False, silent=False):
    criterion = torch.nn.CrossEntropyLoss(reduction = "mean")

    metric = MaskedAccuracy()
    motif_metric = MaskedAccuracy()

    model.eval() #model to train mode

    if not silent:
        tot_itr = len(dataloader.dataset)//dataloader.batch_size #total train iterations
        pbar = tqdm(total = tot_itr, ncols=700) #progress bar

    avg_loss, masked_acc, total_acc = 0., 0., 0.
    
    all_embeddings = []
    outputs = []
    with torch.no_grad():
        #               x_batch, y_masked_batch, y_batch, mask_batch, motif_mask_batch
        for itr_idx, (((masked_sequence, species_label), targets_masked, targets, motif_mask)) in enumerate(dataloader):
            
            if get_embeddings:
                #batches are generated by transformation in the dataset,
                #so remove extra batch dimension added by dataloader
                masked_sequence, targets_masked, targets = masked_sequence[0], targets_masked[0], targets[0]
                species_label = species_label.tile((len(masked_sequence),))
            
            masked_sequence = masked_sequence.to(device)
            targets_masked = targets_masked.to(device)

            motif_targets=targets.detach().clone()
            motif_targets[motif_mask.squeeze()== 0] = -100.0
            print(f"{itr_idx}: {motif_targets.shape}")
            motif_targets[targets_masked == -100] = -100.0
            targets = targets.to(device)
            species_label = torch.tensor(species_label).long().to(device)
            
            logits, embeddings = model(masked_sequence, species_label)

            loss = criterion(logits, targets_masked)

            avg_loss += loss.item()
                
            preds = torch.argmax(logits, dim=1)


            test_acc_motif = motif_metric(preds, motif_targets)
            masked_acc += metric(preds, targets_masked).detach() # compute only on masked nucleotides
            total_acc += metric(preds, targets).detach()
            #print(masked_acc/(itr_idx+1))
                
            if get_embeddings:
                # only get embeddings of the masked nucleotide
                sequence_embedding = embeddings["seq_embedding"]
                sequence_embedding = sequence_embedding.transpose(-1,-2)[targets_masked!=-100]
                # shape # B, L, dim  to L,dim, left with only masked nucleotide embeddings
                # average over sequence 
                #print(sequence_embedding.shape)
                sequence_embedding = sequence_embedding.mean(dim=0) # if we mask
                #sequence_embedding = sequence_embedding[0].mean(dim=-1) # no mask

                sequence_embedding = sequence_embedding.detach().cpu().numpy()
                all_embeddings.append(sequence_embedding)
                
            if not silent:
                pbar.update(1)
                pbar.set_description(f"acc: {total_acc/(itr_idx+1):.2}, masked acc: {masked_acc/(itr_idx+1):.2}, motif acc {test_acc_motif/(itr_idx+1):.2} loss: {avg_loss/(itr_idx+1):.4}")
            outputs.append({"loss": loss, "preds": preds, "logits": logits, "targets": targets_masked, "motifs": motif_mask})
    if not silent:
        del pbar
    return outputs


In [None]:
import time
start = time.time()
outputs = model_eval_check(model, optimizer, test_dataloader, device, 
                                                        get_embeddings = get_embeddings, silent = True)
end = time.time()
print("Time taken in mins: ", (end-start)/60)

0: torch.Size([30, 4956])


  species_label = torch.tensor(species_label).long().to(device)


1: torch.Size([30, 1214])
2: torch.Size([30, 892])
3: torch.Size([30, 511])
4: torch.Size([30, 504])
5: torch.Size([30, 555])
6: torch.Size([30, 2448])
7: torch.Size([30, 2522])
8: torch.Size([30, 6014])


In [44]:
import pickle
outputs_file = "../../../our_code/results/species_aware/outputs.pickle"
with open(outputs_file, "wb") as f:
    pickle.dump(outputs, f)

In [45]:
motif_metrics = MotifMetrics(outputs=outputs, plots_dir="../../../our_code/results/species_aware/", motif_dict=new_motifs.dict, save=True)

No nan entries in preds


In [None]:
#sns.set_theme(context="poster")
sns.set_theme()
#sns.set_context("talk", font_scale=1.5, rc={"lines.linewidth": 2.5})
sns.set_context("talk", font_scale=1.7)#, rc={"font.size": 7})
#
#sns.set_context("poster", 
sns.set_style(style={'xtick.bottom': True,'ytick.left': True, 'axes.edgecolor': 'black'})
#sns.set_style("ticks")