# Predict expression from genotype

1. Load weights from trained LM and freeze
2. Take only the encoder part
3. Add a prediction head -> learnable params

In [1]:
import numpy as np
import pandas as pd
import pickle
import gc
import os
import re

import sys
import pysam
import torch
from torch.utils.data import DataLoader, Dataset
from itertools import chain
from torch import nn

if not "../" in sys.path:
    sys.path.append("../")

from encoding_utils import sequence_encoders
import helpers.train_eval_a as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from models.spec_dss import DSSResNetEmb, SpecAdd, DSSResNetExpression

%load_ext autoreload
%autoreload 2

### Load the original model with checkpoints

In [2]:
input_params = misc.dotdict({})

input_params.dataset =  '../../extdata/datasets/phase3_top10/dataset.parquet'
input_params.model_weight = '../../extdata/checkpoints/phase3_top10/aware_large_splitmsk/weights/epoch_100_weights_model.pt'


input_params.output_dir = '../../results/test'

input_params.split_mask = False
input_params.mask_rate = 0.2 #[0.012,0.2]#RAN #single float or 2 floats for reference and alternative
input_params.masking = 'none' # stratified_maf or none

input_params.diploid = True

input_params.test = False

input_params.get_embeddings = False
input_params.mask_at_test = True

input_params.agnostic = False

input_params.seq_len = 500

input_params.tot_epochs = 1
input_params.fold = 0
input_params.Nfolds = 5

input_params.train_splits = 1

input_params.save_at = [1]
input_params.validate_every = 1

input_params.d_model = 256
input_params.n_layers = 8
input_params.dropout = 0.

input_params.batch_size = 8
input_params.learning_rate = 1e-4
input_params.weight_decay = 0

if input_params.dataset.endswith('.fa'):
    seq_df = pd.read_csv(input_params.dataset + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
elif input_params.dataset.endswith('.parquet'):
    seq_df = pd.read_parquet(input_params.dataset).reset_index()
    
seq_df[['split','sample_id','seg_name']] =  seq_df['seq_name'].str.split(':',expand=True)

if not input_params.agnostic:
    #for segment-aware model, assign a label to each segment
    seg_name = seq_df.seq_name.apply(lambda x:':'.join(x.split(':')[2:]))
    segment_encoding = seg_name.drop_duplicates().reset_index(drop=True)
    segment_encoding = {seg_name:idx for idx,seg_name in segment_encoding.items()}
    seq_df['seg_label'] = seg_name.map(segment_encoding)
else:
    seq_df['seg_label'] = 0


if input_params.test:
    seq_df = seq_df[seq_df.split=='test']
else:
    seq_df = seq_df[seq_df.split!='test']

In [3]:
expression_data = pd.read_csv("../../extdata/GD660.GeneQuantRPKM.txt.gz", sep="\t")\
    .rename(columns=lambda x: re.sub("\..*",'',x))\
    .melt(id_vars = ["TargetID", "Gene_Symbol", "Chr", "Coord"], var_name = "sample_id", value_name = "expr")[["Gene_Symbol", "sample_id", "expr"]]

In [4]:
seq_expression_df = pd.merge(seq_df, expression_data, left_on=["sample_id", "seg_name"], right_on=["sample_id", "Gene_Symbol"])

In [5]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('\nCUDA device: GPU\n')
else:
    device = torch.device('cpu')
    print('\nCUDA device: CPU\n')
gc.collect()
torch.cuda.empty_cache()

device = "cpu"
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"


CUDA device: GPU



In [6]:
class SeqDataset(Dataset):

    def __init__(self, seq_df, transform, max_augm_shift=0, 
                 mode='train'):

        if input_params.dataset.endswith('.fa'):
            self.fasta = pysam.FastaFile(input_params.dataset)
        else:
            self.fasta = None

        self.seq_df = seq_df
        self.transform = transform
        self.max_augm_shift = max_augm_shift
        self.mode = mode

    def __len__(self):

        return 2*len(self.seq_df) # times two because returns both haplotypes 

    def __getitem__(self, idx):

        if self.fasta:
            seq = self.fasta.fetch(self.seq_df.iloc[idx].seq_name).upper()
        else:
            seq = self.seq_df.iloc[idx].seq.upper()

        shift = np.random.randint(self.max_augm_shift+1) #random shift at training, must be chunk_size-input_params.seq_len

        seq = seq[shift:shift+input_params.seq_len] #shift the sequence and limit its size
        seg_label = self.seq_df.iloc[idx].seg_label #label for segment-aware training
        #'''
        seq1 = seq.replace('-','').replace('B','A').replace('F','A').replace('M','R') # father
        seq2 = seq.replace('-','').replace('B','A').replace('M','A').replace('F','R') # mother 

        masked_sequence1, target_labels_masked1, target_labels1, _, _ = self.transform(seq1)
        masked_sequence2, target_labels_masked2, target_labels2, _, _ = self.transform(seq2)

        masked_sequence = torch.vstack((masked_sequence1, masked_sequence2))
        seg_label = torch.vstack((torch.tensor(seg_label), torch.tensor(seg_label)))
        masked_sequence = (masked_sequence, seg_label)

        target_labels_masked = torch.vstack((target_labels_masked1, target_labels_masked2))
        target_labels = torch.vstack((target_labels1, target_labels2))
        seq = (seq1, seq2)
        return masked_sequence, target_labels_masked, target_labels, seq
        

    def close(self):
        self.fasta.close()


class ExpressionDataset(SeqDataset): 

    def __init__(self, *args, **kwargs): 
        super().__init__(*args, **kwargs)
        pass

    def __getitem__(self, idx): 
        if self.fasta:
            seq = self.fasta.fetch(self.seq_df.iloc[idx].seq_name).upper()
        else:
            seq = self.seq_df.iloc[idx].seq.upper()

        shift = np.random.randint(self.max_augm_shift+1) #random shift at training, must be chunk_size-input_params.seq_len

        seq = seq[shift:shift+input_params.seq_len] #shift the sequence and limit its size
        seg_label = self.seq_df.iloc[idx].seg_label #label for segment-aware training
        #'''
        seq1 = seq.replace('-','').replace('B','A').replace('F','A').replace('M','R') # father
        seq2 = seq.replace('-','').replace('B','A').replace('M','A').replace('F','R') # mother 

        masked_sequence1, target_labels_masked1, target_labels1, _, _ = self.transform(seq1)
        masked_sequence2, target_labels_masked2, target_labels2, _, _ = self.transform(seq2)

        masked_sequence = torch.vstack((masked_sequence1, masked_sequence2))
        seg_label = torch.vstack((torch.tensor(seg_label), torch.tensor(seg_label)))
        masked_sequence = (masked_sequence, seg_label)

        target_labels_masked = torch.vstack((target_labels_masked1, target_labels_masked2))
        target_labels = torch.vstack((target_labels1, target_labels2))
        seq = (seq1, seq2)

        seq_expr = self.seq_df.iloc[idx].expr.astype(np.float32)

        return masked_sequence, target_labels_masked, target_labels, seq, seq_expr

    


def collate_expression_fn(data):
    """collate fn that adds expression values for each sequence.
    """ 
    #masked sequence
    masked_sequence = [x[0][0] for x in data]
    masked_sequence = [torch.stack(torch.split(d, 3)) for d in masked_sequence] 
    masked_sequence = torch.concat(masked_sequence)
    #seg labels
    seg_labels = [x[0][1] for x in data]
    seg_labels = torch.concat(seg_labels).flatten()
    # target labels masked
    target_labels_masked = [x[1] for x in data]
    target_labels_masked = torch.concat(target_labels_masked)
    # target labels 
    target_labels = [x[2] for x in data]
    target_labels = torch.concat(target_labels)
    #seq
    seqs = [x[3] for x in data]
    seqs = tuple(chain.from_iterable(seqs))

    seg_expr = [x[4] for x in data]
    # repeat each element twice, once for each haplotype
    seg_expr = torch.Tensor(seg_expr).repeat_interleave(2)

    return (masked_sequence, seg_labels, seg_expr),target_labels_masked, target_labels, seqs

In [7]:
test_df = None 

if not input_params.test: #Train and Validate
    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len,
                                                      mask_rate = input_params.mask_rate, split_mask = input_params.split_mask)

    #N_train = int(len(seq_expression_df)*(1-input_params.val_fraction))
    if input_params.fold is not None:
        
        samples = seq_expression_df.sample_id.unique()
        val_samples = samples[input_params.fold::input_params.Nfolds] 
        train_df = seq_expression_df[~seq_expression_df.sample_id.isin(val_samples)] 
        test_df = seq_expression_df[seq_expression_df.sample_id.isin(val_samples)]
        test_dataset = ExpressionDataset(test_df, transform = seq_transform, mode='eval')
        test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 0, collate_fn = collate_expression_fn, shuffle = False)
    else:
        train_df = seq_expression_df
        #train_df = seq_expression_df[seq_expression_df.split=='train']
        #test_df = seq_expression_df[seq_expression_df.split=='val']
  
    N_train = len(train_df)
    train_fold = np.repeat(list(range(input_params.train_splits)),repeats = N_train // input_params.train_splits + 1 )
    train_df['train_fold'] = train_fold[:N_train]
    # create training dataset & dataloader 
    train_dataset = ExpressionDataset(train_df, transform = seq_transform,  mode='train')
    train_dataloader = DataLoader(dataset = train_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = collate_expression_fn, shuffle = False)

elif input_params.get_embeddings:
    if input_params.mask_at_test:
        seq_transform = sequence_encoders.RollingMasker(mask_stride = 50, frame = 0)
    else:
        seq_transform = sequence_encoders.PlainOneHot(frame = 0, padding = 'none')
    # create test dataset & dataloader 
    test_dataset = ExpressionDataset(seq_expression_df, transform = seq_transform, mode='eval')
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = collate_expression_fn, shuffle = False)

else: #Test
    print("not getting embeddings")
    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len,
                                                      mask_rate=input_params.mask_rate, split_mask = input_params.split_mask)
    # create test dataset & dataloader 
    test_dataset = ExpressionDataset(seq_expression_df, transform = seq_transform, mode='eval')
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = collate_expression_fn, shuffle = False)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df['train_fold'] = train_fold[:N_train]


In [14]:
check_batch = next(iter(train_dataloader))

In [15]:
check_batch[0][0].shape

torch.Size([8, 3, 100])

In [8]:
seg_encoder = SpecAdd(embed = True, encoder = 'label', Nsegments=seq_df.seg_label.nunique(), d_model = input_params.d_model)

model = DSSResNetExpression(d_input = 3, d_output = 3, d_model = input_params.d_model, n_layers = input_params.n_layers, 
                     dropout = input_params.dropout, embed_before = True, species_encoder = seg_encoder)



model = model.to(device)

# load the weights
model.load_state_dict(torch.load(input_params.model_weight, map_location=device), strict=False)

# # get the encoder part of the model
# encoder = DSSResNetEncoder(full_model)

# model = DSSResNetExpression(encoder = encoder, d_encoder = 256, d_output = 1, freeze_encoder = True)

# define which layers to freeze
for param, weights in model.state_dict().items(): 
    if not param.startswith("regression"):
        weights.requires_grad = False

model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(model_params, lr = input_params.learning_rate, weight_decay = input_params.weight_decay)

weights_dir = os.path.join(input_params.output_dir, 'weights') #dir to save model weights at save_at epochs


In [9]:
def metrics_to_str(metrics):
    expr_loss, loss, accuracy, masked_acc, masked_recall, masked_IQS = metrics
    return f'expr_loss: {expr_loss:.4}, loss: {loss:.4}, acc: {accuracy:.4}, masked acc: {masked_acc:.4}, {misc.print_class_recall(masked_recall, "masked recall: ")}, masked IQS: {masked_IQS:.4}'

In [11]:
from helpers.metrics import MeanRecall, MaskedAccuracy, IQS
from helpers.misc import EMA, print_class_recall
from tqdm import tqdm

def train_reg_model(model, optimizer, dataloader, device, silent=False):
    """train function with added regression loss"""

    criterion = torch.nn.CrossEntropyLoss(reduction = "mean")

    accuracy = MaskedAccuracy(smooth=True).to(device)
    masked_recall = MeanRecall().to(device)
    masked_accuracy = MaskedAccuracy(smooth=True).to(device)
    masked_IQS = IQS().to(device)

    reg_criterion = nn.MSELoss()
    
    model.train() #model to train mode

    if not silent:
        tot_itr = len(dataloader.dataset)//dataloader.batch_size #total train iterations
        pbar = tqdm(total = tot_itr, ncols=750) #progress bar

    loss_EMA = EMA()

    for itr_idx, ((masked_sequence, species_label, expr), targets_masked, targets, _) in enumerate(dataloader):

        masked_sequence = masked_sequence.to(device)
        species_label = species_label.to(device)
        targets_masked = targets_masked.to(device)
        targets = targets.to(device)

        logits, _, expr_pred = model(masked_sequence, species_label)

        masked_loss = criterion(logits, targets_masked)
        expr_loss = reg_criterion(expr_pred, expr)

        optimizer.zero_grad()

        loss = expr_loss + masked_loss

        loss.backward()

        #if max_abs_grad:
        #    torch.nn.utils.clip_grad_value_(model.parameters(), max_abs_grad)

        optimizer.step()

        smoothed_loss = loss_EMA.update(loss.item())

        preds = torch.argmax(logits, dim=1)

        accuracy.update(preds, targets)
        masked_recall.update(preds, targets_masked)
        masked_accuracy.update(preds, targets_masked)
        masked_IQS.update(preds, targets_masked)
        
        if not silent:
            pbar.update(1)
            pbar.set_description(f'loss: {expr_loss:.4}, acc: {accuracy.compute():.4}, {print_class_recall(masked_recall.compute(), "masked recall: ")}, masked acc: {masked_accuracy.compute():.4}, masked IQS: {masked_IQS.compute():.4}, loss: {smoothed_loss:.4}')

    if not silent:
        pbar.reset()
        del pbar

    return expr_loss, loss, accuracy, masked_accuracy, masked_recall, masked_IQS


In [12]:
from helpers import misc

from IPython.display import clear_output
last_epoch = 0

clear_output()

#from helpers.misc import print    #print function that displays time

if not input_params.test:

    for epoch in range(last_epoch+1, input_params.tot_epochs+1):

        print(f'EPOCH {epoch}: Training...')

        #if input_params.masking == 'stratified_maf':

        #    meta = get_random_mask()

        train_dataset.seq_df = train_df[train_df.train_fold == (epoch-1) % input_params.train_splits]
        print(f'using train samples: {list(train_dataset.seq_df.index[[0,-1]])}')

        train_metrics = train_reg_model(model, optimizer, train_dataloader, device,
                            silent = False)
            
        print(f'epoch {epoch} - train, {metrics_to_str(train_metrics)}')

        if epoch in input_params.save_at: #save model weights

            misc.save_model_weights(model, optimizer, weights_dir, epoch)

        if test_df is not None  and ( epoch==input_params.tot_epochs or
                            (input_params.validate_every and epoch%input_params.validate_every==0)):

            print(f'EPOCH {epoch}: Validating...')

            val_metrics, *_ =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                    silent = False)

            print(f'epoch {epoch} - validation, {metrics_to_str(val_metrics)}')
            
        #lr_scheduler.step()
else:

    print(f'EPOCH {last_epoch}: Test/Inference...')

    test_metrics, test_embeddings, motif_probas =  train_eval.model_eval(model, test_dataloader, device, 
                                                          get_embeddings = input_params.get_embeddings, diploid=input_params.diploid,
                                                          silent = False)
    
    

    print(f'epoch {last_epoch} - test, {metrics_to_str(test_metrics)}')

    if input_params.get_embeddings:
        
        os.makedirs(input_params.output_dir, exist_ok = True)

        with open(input_params.output_dir + '/embeddings.pickle', 'wb') as f:
            #test_embeddings = np.vstack(test_embeddings)
            #np.save(f,test_embeddings)
            pickle.dump(test_embeddings,f)
            #pickle.dump(seq_df.seq_name.tolist(),f)
            
print()
print(f'peak GPU memory allocation: {round(torch.cuda.max_memory_allocated(device)/1024/1024)} Mb')
print('Done')

EPOCH 1: Training...
using train samples: [2, 1549]


  return einsum('chn,hnl->chl', W, S).float(), state                   # [C H L]
  return F.mse_loss(input, target, reduction=self.reduction)
loss: nan, acc: 0.6287, masked recall: R=0.749;A=0.263;AVG=0.506, masked acc: 0.6416, masked IQS: 0.09914, loss: nan:   8%|█████████████████████████████████████████████▉                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    | 24/310 [01:34<18:19,  3.85s/it]

KeyboardInterrupt: 

In [None]:
last_epoch = 0

if input_params.model_weight:
    print("With loaded model weights:")

    if torch.cuda.is_available():
        print("Using GPU...")
        #load on gpu
        model.load_state_dict(torch.load(input_params.model_weight))
        if input_params.optimizer_weight:
            optimizer.load_state_dict(torch.load(input_params.optimizer_weight))
    else:
        #load on cpu
        model.load_state_dict(torch.load(input_params.model_weight, map_location=torch.device('cpu')))
        if input_params.optimizer_weight:
            optimizer.load_state_dict(torch.load(input_params.optimizer_weight, map_location=torch.device('cpu')))

    last_epoch = int(input_params.model_weight.split('_')[-3]) #infer previous epoch from input_params.model_weight

weights_dir = os.path.join(input_params.output_dir, 'weights') #dir to save model weights at save_at epochs

if input_params.save_at:
    os.makedirs(weights_dir, exist_ok = True)

#lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
#        milestones=input_params.lr_sch_milestones, gamma=input_params.lr_sch_gamma, verbose=False) 

With loaded model weights:
Using GPU...


RuntimeError: Error(s) in loading state_dict for DSSResNetExpression:
	Missing key(s) in state_dict: "regression_head.0.weight", "regression_head.0.bias", "regression_head.2.weight", "regression_head.2.bias", "regression_head.4.weight", "regression_head.4.bias". 