In [25]:
import numpy as np
import pandas as pd

import os
import gc

import pysam

import torch
from torch.utils.data import DataLoader, Dataset

from tqdm.notebook import tqdm

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
from encoding_utils import sequence_encoders

import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from helpers.metrics import MaskedAccuracy

from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd

In [27]:
class SeqDataset(Dataset):
    """Map-style dataset serving transformed sequences from an indexed FASTA file.

    Item ``idx`` corresponds to row ``idx`` (positional) of ``self.seq_df``:
    the sequence named by that row's ``seq_name`` is fetched from the FASTA
    file, upper-cased and passed through ``transform``.

    ``seq_df`` is a plain attribute and may be reassigned after construction
    (e.g. to switch to another subset of rows); ``__len__`` and
    ``__getitem__`` always read the current ``self.seq_df``.
    """

    def __init__(self, fasta_fa, seq_df, transform):
        """
        Parameters
        ----------
        fasta_fa : str
            Path of the FASTA file, opened with ``pysam.FastaFile``.
        seq_df : pandas.DataFrame
            One row per sequence; must provide the ``seq_name`` and
            ``species_label`` columns.
        transform : callable
            Called as ``transform(seq, motifs = {})``; must return a 5-tuple
            whose first three elements are the masked sequence, the masked
            target labels and the target labels.
        """
        self.fasta = pysam.FastaFile(fasta_fa)
        self.seq_df = seq_df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the current ``self.seq_df``."""
        return len(self.seq_df)

    def __getitem__(self, idx):
        """Return ``((masked_sequence, species_label), target_labels_masked, target_labels)``."""
        # Read from the dataset's own frame, not from a notebook-level global:
        # the train/validation datasets hold different subsets of the rows.
        row = self.seq_df.iloc[idx]

        seq = self.fasta.fetch(row.seq_name).upper()
        species_label = row.species_label

        # The last two elements returned by the transform are not used here.
        masked_sequence, target_labels_masked, target_labels, _, _ = self.transform(seq, motifs = {})

        # The species label travels together with the model input.
        masked_sequence = (masked_sequence, species_label)

        return masked_sequence, target_labels_masked, target_labels

    def close(self):
        """Close the underlying FASTA file handle."""
        self.fasta.close()

In [28]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('\nCUDA device: GPU\n' if device.type == 'cuda' else '\nCUDA device: CPU\n')


CUDA device: CPU



In [29]:
# Root directory of the project data. The input paths in input_params are built
# by plain string concatenation onto it, so the trailing slash is required.
datadir = '/s/project/mll/sergey/effect_prediction/MLM/'

In [30]:
# Run configuration, accessed by attribute throughout the notebook.
# NOTE(review): misc.dotdict is presumably a dict with attribute access; later
# cells read keys that may never be set (e.g. optimizer_weight), so it is
# assumed to return a falsy value for missing keys - confirm in helpers.misc.
input_params = misc.dotdict({})

# FASTA with the sequences; its '.fai' index (same path + '.fai') is read to list the sequence names
input_params.fasta = datadir + 'fasta/240_mammals/240_mammals.shuffled.fa'
# headerless text file of species names; the row position of a species becomes its integer label
input_params.species_list = datadir + 'fasta/240_mammals/240_species.txt'

# weights are saved under <output_dir>/weights, embeddings to <output_dir>/embeddings.npy
input_params.output_dir = './test'

# False: train and validate; True: inference only
input_params.test = False

# passed as both seq_len and total_len to SequenceDataEncoder
input_params.seq_len = 5000

input_params.tot_epochs = 11      # number of the last training epoch (inclusive)
input_params.val_fraction = 0.1   # trailing fraction of the sequence table held out for validation
input_params.train_splits = 4     # training rows are cut into this many contiguous folds, one fold per epoch

input_params.save_at = []         # epochs after which weights are saved; empty means never
input_params.validate_every = 1   # validate every this many epochs (and always at the last epoch)

# model hyperparameters passed to SpecAdd / DSSResNetEmb
input_params.d_model = 128
input_params.n_layers = 4
input_params.dropout = 0.

# DataLoader batch size and Adam optimizer settings
input_params.batch_size = 128
input_params.learning_rate = 1e-4
input_params.weight_decay = 0

In [43]:
# Overrides of the defaults above for an inference-only run.
# species_list is left unchanged from the defaults.
input_params.fasta = datadir + 'griesemer/fasta/GRCh38_UTR_variants.fa'
# checkpoint to restore; the epoch number is later parsed from this name
# (third '_'-separated token from the end)
input_params.model_weight = datadir + 'nnc_logs/seq_len_5000/weights/epoch_11_weights_model'

input_params.output_dir = './test'
input_params.batch_size = 32

input_params.test = True              # skip training, run evaluation only
input_params.get_embeddings = False   # False: SequenceDataEncoder test branch; True: RollingMasker branch with batch size 1

In [59]:
# Table of sequences: names are the first column of the FASTA index (.fai).
fai_path = input_params.fasta + '.fai'
seq_df = pd.read_csv(fai_path, header=None, sep='\t', usecols=[0], names=['seq_name'])

# The species name is the last ':'-separated field of the sequence name.
seq_df['species_name'] = seq_df.seq_name.apply(lambda name: name.split(':')[-1])

# Disabled length filter, kept for reference:
#seq_df['seq_len'] = seq_df.seq_name.apply(lambda x:int(x.split(':')[-1]))
#seq_df = seq_df[seq_df.seq_len>60]

# Species label = row position of the species in the species list file.
index_to_species = pd.read_csv(input_params.species_list, header=None).squeeze().to_dict()
species_encoding = {name: label for label, name in index_to_species.items()}

# Human sequences reuse the chimpanzee label.
species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']

seq_df['species_label'] = seq_df.species_name.map(species_encoding)

#seq_df = seq_df.sample(frac = 1., random_state = 1) #DO NOT SHUFFLE, otherwise too slow

In [60]:
#seq_df = seq_df.iloc[:100000]

In [61]:
# Build the datasets and dataloaders for the selected mode:
#   - training:   train_dataloader + test_dataloader (validation split)
#   - embeddings: test_dataloader over all sequences, one sequence per batch
#   - test:       test_dataloader over all sequences
if not input_params.test:

    #Train and Validate

    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len,
                                                      mask_rate = 0.15, split_mask = True)

    # The first (1 - val_fraction) of the rows train, the remaining rows validate.
    N_train = int(len(seq_df)*(1-input_params.val_fraction))

    # Explicit copies: a column is added to train_df below, which must not be
    # a write into a slice of seq_df (pandas SettingWithCopyWarning).
    train_df = seq_df.iloc[:N_train].copy()
    test_df = seq_df.iloc[N_train:].copy()

    # Contiguous folds 0..train_splits-1 of (almost) equal size; the repeated
    # array is at least N_train long and is truncated to exactly N_train.
    train_fold = np.repeat(list(range(input_params.train_splits)),repeats = N_train // input_params.train_splits + 1 )
    train_df['train_fold'] = train_fold[:N_train]

    train_dataset = SeqDataset(input_params.fasta, train_df, transform = seq_transform)
    train_dataloader = DataLoader(dataset = train_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = None, shuffle = False)

    test_dataset = SeqDataset(input_params.fasta, test_df, transform = seq_transform)
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = None, shuffle = False)

elif input_params.get_embeddings:

    #Test and get sequence embeddings (MPRA)

    seq_transform = sequence_encoders.RollingMasker(mask_stride = 50, frame = 0)

    test_dataset = SeqDataset(input_params.fasta, seq_df, transform = seq_transform)
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = None, shuffle = False)

else:

    #Test

    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len,
                                                      mask_rate = 0.15, split_mask = True, frame = 0)

    test_dataset = SeqDataset(input_params.fasta, seq_df, transform = seq_transform)
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = None, shuffle = False)

In [62]:
# Species-label encoder handed to the network; it shares d_model with the model.
species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = input_params.d_model)

# Build the network and move it to the selected device in one step.
model = DSSResNetEmb(
    d_input = 5,
    d_output = 5,
    d_model = input_params.d_model,
    n_layers = input_params.n_layers,
    dropout = input_params.dropout,
    embed_before = True,
    species_encoder = species_encoder,
).to(device)

# Only parameters that require gradients are given to the optimizer.
model_params = list(filter(lambda param: param.requires_grad, model.parameters()))

optimizer = torch.optim.Adam(
    model_params,
    lr = input_params.learning_rate,
    weight_decay = input_params.weight_decay,
)

In [63]:
# Epoch to resume from; stays 0 when no checkpoint is given.
last_epoch = 0

if input_params.model_weight:

    # map_location = device covers both the GPU and the CPU case, so a single
    # code path is enough: tensors are mapped onto the device the model lives on.
    model.load_state_dict(torch.load(input_params.model_weight, map_location = device))

    # Optimizer state is optional and only restored when a path is configured.
    if input_params.optimizer_weight:
        optimizer.load_state_dict(torch.load(input_params.optimizer_weight, map_location = device))

    # Infer the previous epoch from the checkpoint name: the third
    # '_'-separated token from the end, e.g. 'epoch_11_weights_model' -> 11.
    last_epoch = int(input_params.model_weight.split('_')[-3])

weights_dir = os.path.join(input_params.output_dir, 'weights') #dir to save model weights at save_at epochs

# Create the weights directory only if some epoch is actually going to be saved.
if input_params.save_at:
    os.makedirs(weights_dir, exist_ok = True)

In [64]:
def metrics_to_str(metrics):
    """Format a ``(loss, total_acc, masked_acc)`` triple as a one-line summary.

    The loss is shown with 4 significant digits, both accuracies with
    3 decimal places.
    """
    loss, total_acc, masked_acc = metrics
    fields = (
        f'loss: {loss:.4}',
        f'total acc: {total_acc:.3f}',
        f'masked acc: {masked_acc:.3f}',
    )
    return ', '.join(fields)

In [66]:
#from utils.misc import print    #print function that displays time

# Driver: either train (with periodic validation) or run a single
# test/inference pass, depending on input_params.test.
if not input_params.test:

    for epoch in range(last_epoch+1, input_params.tot_epochs+1):

        print(f'EPOCH {epoch}: Training...')

        # Each epoch trains on one fold, cycling through the folds in order.
        train_dataset.seq_df = train_df[train_df.train_fold == (epoch-1) % input_params.train_splits]
        print(f'using train samples: {list(train_dataset.seq_df.index[[0,-1]])}')

        train_metrics = train_eval.model_train(model, optimizer, train_dataloader, device,
                            silent = False)

        print(f'epoch {epoch} - train, {metrics_to_str(train_metrics)}')

        if epoch in input_params.save_at: #save model weights

            misc.save_model_weights(model, optimizer, weights_dir, epoch)

        # Validate at the last epoch and at every validate_every-th epoch.
        if input_params.val_fraction>0 and ( epoch==input_params.tot_epochs or
                            (input_params.validate_every and epoch%input_params.validate_every==0)):

            print(f'EPOCH {epoch}: Validating...')

            val_metrics, _ =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                    silent = False)

            print(f'epoch {epoch} - validation, {metrics_to_str(val_metrics)}')

else:

    print(f'EPOCH {last_epoch}: Test/Inference...')

    test_metrics, test_embeddings =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                                                          get_embeddings = input_params.get_embeddings, silent = False)

    print(f'epoch {last_epoch} - test, {metrics_to_str(test_metrics)}')

    os.makedirs(input_params.output_dir, exist_ok = True)

    # Write embeddings only when some were returned: np.vstack raises on an
    # empty sequence. Stacking happens before the file is opened so that a
    # failure cannot leave an empty embeddings.npy behind.
    if test_embeddings is not None and len(test_embeddings) > 0:
        test_embeddings = np.vstack(test_embeddings)
        with open(input_params.output_dir + '/embeddings.npy', 'wb') as f:
            np.save(f, test_embeddings)

print()
# CUDA memory statistics are only queried for a CUDA device; on CPU report 0.
peak_bytes = torch.cuda.max_memory_allocated(device) if device.type == 'cuda' else 0
print(f'peak GPU memory allocation: {round(peak_bytes/1024/1024)} Mb')
print('Done')

EPOCH 11: Test/Inference...


  0%|                                                                                                         …

KeyboardInterrupt: 