In [17]:
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.insert(0, '..')

import gc
import pysam
import pandas as pd
import re
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import tqdm


import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from helpers.plots import MetricsHandler, MotifMetrics
from helpers.motifs import motifs#, motifs_all

import encoding_utils.sequence_encoders as sequence_encoders
import encoding_utils.sequence_utils as sequence_utils
from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd

from models.baseline_din import DiNucDist

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
# Parameters
species_agnostic = True

In [19]:

class SeqDataset(Dataset):
    
    def __init__(self, fasta_fa, seq_df, transform, motifs):
        
        self.fasta = pysam.FastaFile(fasta_fa)
        
        self.val_fraction = 0.1 
        N_train = int(len(seq_df) * (1-self.val_fraction))
        self.start_index = N_train 
        self.seq_df = seq_df
        self.transform = transform

        self.motifs = motifs
        
    def __len__(self):
        
        return len(self.seq_df[self.start_index:])
    
    def __getitem__(self, idx):
        
        seq = self.fasta.fetch(self.seq_df.iloc[self.start_index + idx].seq_name).upper()
        #print(seq)
                
        species_label = self.seq_df.iloc[self.start_index + idx].species_label
        #print(species_label)
        # x_batch, y_masked_batch, y_batch, mask_batch, motif_mask_batch 
        masked_sequence, target_labels_masked, target_labels, mask, motif_mask_batch = self.transform(seq, motifs = self.motifs)
        
        masked_sequence = (masked_sequence, species_label)
        return masked_sequence, target_labels_masked, target_labels, motif_mask_batch
    
    def close(self):
        self.fasta.close()

# Reading sequences and filtering motifs


In [20]:
fasta_fa = "../../test/Homo_sapiens_3prime_UTR.fa"
species_list = "../240_species.txt"

# seq_df = pd.read_csv(fasta_fa + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
# seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')[1])
# species_encoding = pd.read_csv(species_list, header=None).squeeze().to_dict()
# species_encoding = {species:idx for idx,species in species_encoding.items()}
# species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']

# if not species_agnostic:
#     species_encoding = {species:idx for idx,species in species_encoding.items()}
# else:
#     species_encoding = {species:0 for _,species in species_encoding.items()}

# seq_df['species_label'] = seq_df.species_name.map(species_encoding)

seq_df = pd.read_csv(fasta_fa + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')[1])
species_encoding = pd.read_csv(species_list, header=None).squeeze().to_dict()

if not species_agnostic:
    species_encoding = {species:idx for idx,species in species_encoding.items()}
else:
    species_encoding = {species:0 for _,species in species_encoding.items()}

species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']
seq_df['species_label'] = seq_df.species_name.map(species_encoding)

seq_df

Unnamed: 0,seq_name,species_name,species_label
0,ENST00000641515.2_utr3_2_0_chr1_70009_f:Homo_s...,Homo_sapiens,0
1,ENST00000616016.5_utr3_13_0_chr1_944154_f:Homo...,Homo_sapiens,0
2,ENST00000327044.7_utr3_18_0_chr1_944203_r:Homo...,Homo_sapiens,0
3,ENST00000338591.8_utr3_11_0_chr1_965192_f:Homo...,Homo_sapiens,0
4,ENST00000379410.8_utr3_15_0_chr1_974576_f:Homo...,Homo_sapiens,0
...,...,...,...
18129,ENST00000303766.12_utr3_11_0_chrY_22168542_r:H...,Homo_sapiens,0
18130,ENST00000250831.6_utr3_11_0_chrY_22417604_f:Ho...,Homo_sapiens,0
18131,ENST00000303728.5_utr3_4_0_chrY_22514071_f:Hom...,Homo_sapiens,0
18132,ENST00000382407.1_utr3_0_0_chrY_24045793_r:Hom...,Homo_sapiens,0


In [21]:
motifs.dict

{'GGGGG': 8,
 'ACACA': 11,
 'AAAAA': 4,
 'GAAGA': 7,
 'CCCCC': 3,
 'GCATG': 1,
 'GTATG': 2,
 'TTTTT': 9,
 'TTTCT': 12,
 'TATGT': 10,
 'TATAT': 0,
 'TGTAT': 5,
 'ATAAA': 6}

kseq_len = 5000
total_len = 5000

seq_transform = sequence_encoders.RollingMasker()
                       
test_dataset = SeqDataset(fasta_fa, seq_df, transform = seq_transform, motifs=motifs.dict)
test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = None, shuffle = False)
len(test_dataset)



In [23]:
print(torch.cuda.is_available())

True


In [24]:
gc.collect()
torch.cuda.empty_cache()
# test wether cuda is available - if cpu_bool is set to True, cuda is not used

# TODO checkout why gpu isn't working
cpu_bool = True
device = torch.device("cuda" if torch.cuda.is_available() and not cpu_bool else "cpu")

d_model = 128
n_layers = 4
dropout = 0.
learn_rate = 1e-4
weight_decay = 0.
output_dir = "./test/"
get_embeddings = True
save_at = None

species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = 128)

model = DSSResNetEmb(d_input = 5, d_output = 5, d_model = d_model, n_layers = n_layers, 
                     dropout = dropout, embed_before = True, species_encoder = species_encoder)

#model = DiNucDist()

model = model.to(device) 

model_params = [p for p in model.parameters() if p.requires_grad]

#optimizer = torch.optim.Adam(model_params, lr = learn_rate, weight_decay = weight_decay)

last_epoch = 0

In [25]:
if not species_agnostic:
    model_weight = "../../test/MLM_mammals_species_aware_5000_weights"
else :
    model_weight = "../../test/MLM_mammals_species_agnostic_5000_weights"
# load model but avoid torch._C._cuda_getDeviceCount() > 0 failed error
model.load_state_dict(torch.load(model_weight, map_location=device))

RuntimeError: Error(s) in loading state_dict for DiNucDist:
	Unexpected key(s) in state_dict: "encoder.weight", "encoder.bias", "s4_layers.0.D", "s4_layers.0.kernel.log_dt", "s4_layers.0.kernel.Lambda", "s4_layers.0.kernel.W", "s4_layers.0.output_linear.0.weight", "s4_layers.0.output_linear.0.bias", "s4_layers.1.D", "s4_layers.1.kernel.log_dt", "s4_layers.1.kernel.Lambda", "s4_layers.1.kernel.W", "s4_layers.1.output_linear.0.weight", "s4_layers.1.output_linear.0.bias", "s4_layers.2.D", "s4_layers.2.kernel.log_dt", "s4_layers.2.kernel.Lambda", "s4_layers.2.kernel.W", "s4_layers.2.output_linear.0.weight", "s4_layers.2.output_linear.0.bias", "s4_layers.3.D", "s4_layers.3.kernel.log_dt", "s4_layers.3.kernel.Lambda", "s4_layers.3.kernel.W", "s4_layers.3.output_linear.0.weight", "s4_layers.3.output_linear.0.bias", "norms.0.weight", "norms.0.bias", "norms.1.weight", "norms.1.bias", "norms.2.weight", "norms.2.bias", "norms.3.weight", "norms.3.bias", "decoder.weight", "decoder.bias", "resnet_layer.0.conv1.weight", "resnet_layer.0.conv1.bias", "resnet_layer.0.bn1.weight", "resnet_layer.0.bn1.bias", "resnet_layer.0.bn1.running_mean", "resnet_layer.0.bn1.running_var", "resnet_layer.0.bn1.num_batches_tracked", "resnet_layer.0.conv2.weight", "resnet_layer.0.conv2.bias", "resnet_layer.0.bn2.weight", "resnet_layer.0.bn2.bias", "resnet_layer.0.bn2.running_mean", "resnet_layer.0.bn2.running_var", "resnet_layer.0.bn2.num_batches_tracked", "resnet_layer.0.layer.0.weight", "resnet_layer.0.layer.0.bias", "resnet_layer.0.layer.1.weight", "resnet_layer.0.layer.1.bias", "resnet_layer.0.layer.1.running_mean", "resnet_layer.0.layer.1.running_var", "resnet_layer.0.layer.1.num_batches_tracked", "resnet_layer.0.layer.3.weight", "resnet_layer.0.layer.3.bias", "resnet_layer.0.layer.4.weight", "resnet_layer.0.layer.4.bias", "resnet_layer.0.layer.4.running_mean", "resnet_layer.0.layer.4.running_var", "resnet_layer.0.layer.4.num_batches_tracked", "resnet_layer.1.conv1.weight", "resnet_layer.1.conv1.bias", "resnet_layer.1.bn1.weight", "resnet_layer.1.bn1.bias", "resnet_layer.1.bn1.running_mean", "resnet_layer.1.bn1.running_var", "resnet_layer.1.bn1.num_batches_tracked", "resnet_layer.1.conv2.weight", "resnet_layer.1.conv2.bias", "resnet_layer.1.bn2.weight", "resnet_layer.1.bn2.bias", "resnet_layer.1.bn2.running_mean", "resnet_layer.1.bn2.running_var", "resnet_layer.1.bn2.num_batches_tracked", "resnet_layer.1.layer.0.weight", "resnet_layer.1.layer.0.bias", "resnet_layer.1.layer.1.weight", "resnet_layer.1.layer.1.bias", "resnet_layer.1.layer.1.running_mean", "resnet_layer.1.layer.1.running_var", "resnet_layer.1.layer.1.num_batches_tracked", "resnet_layer.1.layer.3.weight", "resnet_layer.1.layer.3.bias", "resnet_layer.1.layer.4.weight", "resnet_layer.1.layer.4.bias", "resnet_layer.1.layer.4.running_mean", "resnet_layer.1.layer.4.running_var", "resnet_layer.1.layer.4.num_batches_tracked", "resnet_layer.2.conv1.weight", "resnet_layer.2.conv1.bias", "resnet_layer.2.bn1.weight", "resnet_layer.2.bn1.bias", "resnet_layer.2.bn1.running_mean", "resnet_layer.2.bn1.running_var", "resnet_layer.2.bn1.num_batches_tracked", "resnet_layer.2.conv2.weight", "resnet_layer.2.conv2.bias", "resnet_layer.2.bn2.weight", "resnet_layer.2.bn2.bias", "resnet_layer.2.bn2.running_mean", "resnet_layer.2.bn2.running_var", "resnet_layer.2.bn2.num_batches_tracked", "resnet_layer.2.layer.0.weight", "resnet_layer.2.layer.0.bias", "resnet_layer.2.layer.1.weight", "resnet_layer.2.layer.1.bias", "resnet_layer.2.layer.1.running_mean", "resnet_layer.2.layer.1.running_var", "resnet_layer.2.layer.1.num_batches_tracked", "resnet_layer.2.layer.3.weight", "resnet_layer.2.layer.3.bias", "resnet_layer.2.layer.4.weight", "resnet_layer.2.layer.4.bias", "resnet_layer.2.layer.4.running_mean", "resnet_layer.2.layer.4.running_var", "resnet_layer.2.layer.4.num_batches_tracked", "species_encoder.species_embedder.weight". 

In [26]:

predictions_dir = os.path.join(output_dir, 'predictions') #dir to save predictions
weights_dir = os.path.join(output_dir, 'weights') #dir to save model weights at save_at epochs
if save_at:
    os.makedirs(weights_dir, exist_ok = True)

def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'

print(predictions_dir)


./test/predictions


In [33]:

from helpers.metrics import MaskedAccuracy
def model_eval_check(model, optimizer, dataloader, device, get_embeddings = False, silent=False):
    criterion = torch.nn.CrossEntropyLoss(reduction = "mean")

    metric = MaskedAccuracy()
    motif_metric = MaskedAccuracy()

    model.eval() #model to train mode

    if not silent:
        tot_itr = len(dataloader.dataset)//dataloader.batch_size #total train iterations
        pbar = tqdm(total = tot_itr, ncols=700) #progress bar

    avg_loss, masked_acc, total_acc = 0., 0., 0.
    
    all_embeddings = []
    outputs = []
    with torch.no_grad():
        #               x_batch, y_masked_batch, y_batch, mask_batch, motif_mask_batch
        for itr_idx, (((masked_sequence, species_label), targets_masked, targets, motif_mask)) in enumerate(dataloader):
            
            if get_embeddings:
                #batches are generated by transformation in the dataset,
                #so remove extra batch dimension added by dataloader
                masked_sequence, targets_masked, targets = masked_sequence[0], targets_masked[0], targets[0]
                species_label = species_label.tile((len(masked_sequence),))
            
            masked_sequence = masked_sequence.to(device)
            targets_masked = targets_masked.to(device)

            motif_targets=targets.detach().clone()
            motif_targets[motif_mask.squeeze()== 0] = -100.0
            print(f"{itr_idx}: {motif_targets.shape}")
            motif_targets[targets_masked == -100] = -100.0
            targets = targets.to(device)
            species_label = torch.tensor(species_label).long().to(device)
            
            logits, embeddings = model(masked_sequence, species_label)

            loss = criterion(logits, targets_masked)

            avg_loss += loss.item()
                
            preds = torch.argmax(logits, dim=1)


            test_acc_motif = motif_metric(preds, motif_targets)
            masked_acc += metric(preds, targets_masked).detach() # compute only on masked nucleotides
            total_acc += metric(preds, targets).detach()
            #print(masked_acc/(itr_idx+1))
                
            #if get_embeddings:
                # only get embeddings of the masked nucleotide
                #sequence_embedding = embeddings["seq_embedding"]
                #sequence_embedding = sequence_embedding.transpose(-1,-2)[targets_masked!=-100]
                # shape # B, L, dim  to L,dim, left with only masked nucleotide embeddings
                # average over sequence 
                #print(sequence_embedding.shape)
                #sequence_embedding = sequence_embedding.mean(dim=0) # if we mask
                #sequence_embedding = sequence_embedding[0].mean(dim=-1) # no mask

                #sequence_embedding = sequence_embedding.detach().cpu().numpy()
                #all_embeddings.append(sequence_embedding)
                
            if not silent:
                pbar.update(1)
                pbar.set_description(f"acc: {total_acc/(itr_idx+1):.2}, masked acc: {masked_acc/(itr_idx+1):.2}, motif acc {test_acc_motif/(itr_idx+1):.2} loss: {avg_loss/(itr_idx+1):.4}")
            outputs.append({"loss": loss, "preds": preds, "logits": logits, "targets": targets_masked, "motifs": motif_mask})
    if not silent:
        del pbar
    return outputs


In [34]:
import time
start = time.time()
outputs = model_eval_check(model, None, test_dataloader, device, 
                                                        get_embeddings = get_embeddings, silent = True)
end = time.time()
print("Time taken in mins: ", (end-start)/60)


0: torch.Size([30, 4956])
1: torch.Size([30, 1214])
2: torch.Size([30, 892])
3: torch.Size([30, 511])
4: torch.Size([30, 504])
5: torch.Size([30, 555])
6: torch.Size([30, 2448])
7: torch.Size([30, 2522])
8: torch.Size([30, 6014])
9: torch.Size([30, 2441])
10: torch.Size([30, 891])
11: torch.Size([30, 2879])
12: torch.Size([30, 1550])


  species_label = torch.tensor(species_label).long().to(device)


13: torch.Size([30, 218])
14: torch.Size([30, 154])
15: torch.Size([30, 1021])
16: torch.Size([30, 1276])
17: torch.Size([30, 466])
18: torch.Size([30, 5654])
19: torch.Size([30, 777])
20: torch.Size([30, 2377])
21: torch.Size([30, 5823])
22: torch.Size([30, 426])
23: torch.Size([30, 3850])
24: torch.Size([30, 1714])
25: torch.Size([30, 594])
26: torch.Size([30, 1518])
27: torch.Size([30, 1907])
28: torch.Size([30, 1101])
29: torch.Size([30, 1656])
30: torch.Size([30, 258])
31: torch.Size([30, 1986])
32: torch.Size([30, 2798])
33: torch.Size([30, 126])
34: torch.Size([30, 2245])
35: torch.Size([30, 1003])
36: torch.Size([30, 7883])
37: torch.Size([30, 1511])
38: torch.Size([30, 1162])
39: torch.Size([30, 25])
40: torch.Size([30, 2509])
41: torch.Size([30, 5776])
42: torch.Size([30, 2316])
43: torch.Size([30, 2738])
44: torch.Size([30, 2716])
45: torch.Size([30, 3003])
46: torch.Size([30, 3446])
47: torch.Size([30, 2272])
48: torch.Size([30, 941])
49: torch.Size([30, 1284])
50: torch.Si

In [35]:
import pickle
# TODO 
outputs_file = "../results/dinuc_test/outputs.pickle"
with open(outputs_file, "wb") as f:
    pickle.dump(outputs, f)

In [None]:
print(motifs.dict)


In [36]:
# creates the .pt files and writes them to the cwd
# motif dict can be gotten from the motif handler
motif_metrics = MotifMetrics(outputs=outputs, plots_dir="../results/dinuc_test/", motif_dict=motifs.dict, save=True)

No nan entries in preds
