In [10]:
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import numpy as np

def one_hot_encode(seq):
    """One-hot encode a DNA sequence.

    Bases are matched case-insensitively, so soft-masked (lowercase) sequence
    is encoded like its uppercase counterpart. Any character other than
    A, C, G or T (for example the ambiguity code N) is encoded as an all-zero
    row instead of raising a ``KeyError``.

    Parameters
    ----------
    seq : str
        Nucleotide sequence.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(seq), 4)`` and dtype ``int8``; the columns are
        ordered A, C, G, T.
    """
    mapping = dict(zip("ACGT", range(4)))
    unknown = 4

    # Rows 0-3 are the identity matrix; row 4 is all zeros and is selected for
    # every base that is not one of A, C, G, T.
    lookup = np.vstack([np.eye(4, dtype=np.int8),
                        np.zeros((1, 4), dtype=np.int8)])

    # An explicit integer dtype keeps the fancy indexing valid for an empty
    # sequence as well (result shape (0, 4)).
    indices = np.asarray([mapping.get(base, unknown) for base in seq.upper()],
                         dtype=np.intp)
    return lookup[indices]

class BiasDataset(Dataset):
    """Background sequences for bias model training.

    Each item pairs a one-hot encoded sequence with the ATAC track row that
    shares its index label, plus that row's ``time`` and ``cell_type`` values.
    """

    def __init__(self, path_sequences, path_ATAC_signal):
        """
        Arguments:
            path_sequences (string): Path to the pickle file with background regions sequences.
                The unpickled object must expose a ``sequence`` attribute supporting ``.apply``
                (presumably a pandas DataFrame with a ``sequence`` column -- confirm).
            path_ATAC_signal (string): Path to the pickle file with ATAC tracks per datasets and time points.
                The unpickled object must have ``time`` and ``cell_type`` columns; the first
                column (position 0) is returned as the track.

        NOTE(review): ``pickle.load`` can execute arbitrary code; only load trusted files.
        """
        with open(path_sequences, 'rb') as handle:
            raw_sequences = pickle.load(handle).sequence

        # One-hot encode every sequence once, up front.
        self.sequences = raw_sequences.apply(one_hot_encode)

        with open(path_ATAC_signal, 'rb') as handle:
            self.ATAC_track = pickle.load(handle)

        # Store the two metadata columns as categoricals.
        self.ATAC_track.time = self.ATAC_track.time.astype('category')
        self.ATAC_track.cell_type = self.ATAC_track.cell_type.astype('category')

    def __len__(self):
        """Number of rows in the ATAC track table."""
        n_rows = self.ATAC_track.shape[0]
        return n_rows

    def __getitem__(self, idx):
        """Return ``(sequence, time, cell_type, track)`` for the row at position ``idx``."""
        row = self.ATAC_track.iloc[idx, :]

        # Sequences are looked up by index *label*, not by position.
        region_label = self.ATAC_track.index[idx]
        encoded_sequence = self.sequences[region_label]

        # The track is read from the first column of the table.
        track = self.ATAC_track.iloc[idx, 0]

        return encoded_sequence, row.time, row.cell_type, track


In [11]:
# Load the background sequences and their ATAC tracks, then serve them in
# shuffled mini-batches of 32 from the main process (no worker subprocesses).
dataset = BiasDataset(
    path_sequences='../results/peaks_seq.pkl',
    path_ATAC_signal='../results/ATAC_peaks.pkl',
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

In [12]:
from tqdm import tqdm

# Smoke test: pull a single batch from the loader, show the first encoded
# sequence in it, and stop. The loop is only used to obtain the first batch.
for data in tqdm(dataloader): 
    # Each batch unpacks in the order returned by BiasDataset.__getitem__.
    inputs, time, cell_type, tracks = data 
    print(inputs[0])
    break

  0%|          | 0/24960 [00:01<?, ?it/s]

tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        ...,
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [1, 0, 0, 0]], dtype=torch.int8)





In [45]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

class BPNet(nn.Module):
    """BPNet architecture as in the paper.

    Parameters
    -----------
    nb_conv: int (default 8)
        number of convolutional layers

    nb_filters: int (default 64)
        number of filters in the convolutional layers

    first_kernel: int (default 21)
        size of the kernel in the first convolutional layer

    rest_kernel: int (default 3)
        size of the kernel in all convolutional layers except the first one

    profile_kernel_size: int (default 75)
        size of the kernel in the profile (transposed) convolution

    out_pred_len: int (default 1000)
        number of bp for which ATAC signal is predicted

    Model Architecture
    ------------------------

    - Body: sequence of unpadded convolutional layers with residual skip connections,
    dilated convolutions (dilation 2**i for the i-th layer, i >= 1) and ReLU activation functions

    - Head:
        > Profile prediction head: one score per position, produced by a transposed
        convolution (deconvolution) and center-cropped to ``out_pred_len``. No softmax is
        applied in ``forward``; the scores are meant to be used as multinomial logits.
        > Total count prediction: global average pooling followed by a linear layer
        producing a single total-count value per example

    The predicted (expected) count at a specific position is a multiplication of the predicted total
    counts and the multinomial probability at that position.

    -------------------------

    Reference: Avsec, Ž., Weilert, M., Shrikumar, A. et al. Base-resolution models of transcription-factor binding
    reveal soft motif syntax. Nat Genet 53, 354–366 (2021). https://doi.org/10.1038/s41588-021-00782-6
    """

    def __init__(self, nb_conv=8, nb_filters=64, first_kernel=21, rest_kernel=3, profile_kernel_size=75, out_pred_len=1000):
        super().__init__()

        #Define parameters
        self.nb_conv = nb_conv
        self.nb_filters = nb_filters
        self.first_kernel = first_kernel
        self.rest_kernel = rest_kernel
        self.profile_kernel = profile_kernel_size
        self.out_pred_len = out_pred_len

        #Convolutional layers
        self.convlayers = nn.ModuleList()

        self.convlayers.append(nn.Conv1d(in_channels=4,
                                         out_channels=self.nb_filters,
                                         kernel_size=self.first_kernel))
        for i in range(1, self.nb_conv):
            self.convlayers.append(nn.Conv1d(in_channels=self.nb_filters,
                                             out_channels=self.nb_filters,
                                             kernel_size=self.rest_kernel,
                                             dilation=2**i))

        #Profile prediction head
        self.profile_conv = nn.ConvTranspose1d(self.nb_filters, 1, kernel_size=self.profile_kernel)
        self.flatten = nn.Flatten()

        #Total count prediction head
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.linear = nn.Linear(self.nb_filters, 1)

    @staticmethod
    def _center_crop(tensor, target_len):
        """Return the central ``target_len`` positions along the last axis of a 3-D tensor.

        An explicit end index is used instead of ``a:-a`` so that a crop of zero
        keeps the whole tensor (``x[..., 0:-0]`` would be empty) and an odd
        length difference still yields exactly ``target_len`` positions.
        """
        current_len = tensor.size(2)
        if target_len > current_len:
            raise ValueError(
                f"Cannot crop a length-{current_len} tensor to {target_len} positions; "
                "the input sequence is too short for this configuration."
            )
        start = (current_len - target_len) // 2
        return tensor[:, :, start:start + target_len]

    def forward(self, x):
        """Run the network on a batch of sequences.

        Parameters
        ----------
        x : torch.Tensor
            Float tensor of shape ``(batch, 4, length)``.

        Returns
        -------
        x : torch.Tensor
            Output of the convolutional body, ``(batch, nb_filters, body_length)``.
        profile : torch.Tensor
            Per-position profile scores, ``(batch, out_pred_len)``.
        count : torch.Tensor
            Total count prediction, ``(batch, 1)``.

        Raises
        ------
        ValueError
            If the profile head output is shorter than ``out_pred_len``.
        """
        #Residual + Dilated convolution layers
        #-----------------------------------------------
        x = F.relu(self.convlayers[0](x))

        for layer in self.convlayers[1:]:
            conv_x = F.relu(layer(x))

            #Crop output of previous layer to size of current (convolutions are unpadded)
            x = self._center_crop(x, conv_x.size(2))

            #Skipped connection
            x = conv_x + x

        #Profile head
        #-----------------------------------------------
        profile = self.profile_conv(x)
        profile = self._center_crop(profile, self.out_pred_len)
        profile = self.flatten(profile)

        #Total count head
        #-----------------------------------------------
        count = self.global_pool(x)
        # Drop only the pooled length axis; a bare squeeze() would also drop
        # the batch axis for a batch of size 1.
        count = count.squeeze(-1)
        count = self.linear(count)

        return x, profile, count

The model predicts the base-resolution 1,000 bp length Tn5 insertion count profile using two complementary outputs: (1) the total Tn5 insertion counts over the 1,000 bp region, and (2) a multinomial probability of Tn5 insertion counts at each position in the 1,000 bp sequence. The predicted (expected) count at a specific position is a multiplication of the predicted total counts and the multinomial probability at that position.

In [47]:
# Shape check on random data: a batch of 32 inputs with 4 channels and 2114
# positions. The tensor is not named `input`, which would shadow the Python
# builtin for the rest of the kernel session.
m = BPNet()
random_input = torch.randn(32, 4, 2114)
x, profile, count = m(random_input)

profile

tensor([[-4.9177e+00, -4.1933e+00, -4.1938e+00,  ..., -6.8152e+00,
         -2.7607e+00, -1.0761e+01],
        [-8.8051e+00, -9.7412e-01, -5.7450e+00,  ..., -5.6425e+00,
         -3.1305e+00, -5.7297e+00],
        [-3.5067e+00, -4.0354e+00, -9.2691e+00,  ..., -6.7895e+00,
         -1.0056e+01, -4.8932e+00],
        ...,
        [-2.6666e+00, -1.8417e+00, -6.3831e+00,  ..., -7.5926e+00,
         -7.9752e+00, -3.2860e+00],
        [-5.9738e+00, -1.1024e+01, -6.0759e+00,  ..., -7.0113e+00,
         -1.0954e+01, -4.0736e+00],
        [ 2.6391e-03, -2.6899e+00, -4.6821e+00,  ..., -6.7192e+00,
         -1.1125e+00, -7.1009e+00]], grad_fn=<ReshapeAliasBackward0>)

BPNet uses a composite loss function consisting of a linear combination of a mean squared error (MSE) loss on the log of the total counts and a multinomial negative log-likelihood loss (MNLL) for the profile probability output. We use a weight of [4.9, 4.3, 18.5, 9.8, 8.9, 4.8, 4.6, 4.9, 12.4, 15.4, 4.3, 6.3, 1.4, 2.6, 7.6, 2.3, 16.3, 7.1 & 3.7] for the MSE loss for clusters c0–c20 (c15–c16 combined as one model), and a weight of 1 for the MNLL loss in the linear combination. The MSE loss weight is derived as the median of total counts across all peak regions for each cluster divided by a factor of 10.

In [None]:
#Custom losses functions
import torch.distributions as dist

class MultinomialNLLLoss(nn.Module):
    """Multinomial negative log-likelihood of observed counts given profile logits.

    The log-likelihood is computed directly rather than through
    ``torch.distributions.Multinomial``, which only accepts a single integer
    ``total_count`` and therefore cannot handle a different total per example.
    """

    def __init__(self):
        super().__init__()

    def forward(self, true_counts, logits):
        """Return the batch-averaged multinomial negative log-likelihood.

        Parameters
        ----------
        true_counts : torch.Tensor
            Observed counts, shape ``(batch, positions)``.
        logits : torch.Tensor
            Unnormalized log-probabilities, same shape as ``true_counts``.

        Returns
        -------
        torch.Tensor
            Scalar: the summed negative log-likelihood divided by the batch size.

        Raises
        ------
        ValueError
            If ``true_counts`` and ``logits`` do not have the same shape.
        """
        if true_counts.shape != logits.shape:
            raise ValueError(
                f"true_counts {tuple(true_counts.shape)} and logits "
                f"{tuple(logits.shape)} must have the same shape"
            )

        # lgamma needs a floating dtype; counts may arrive as integers.
        true_counts = true_counts.to(logits.dtype)
        log_probs = torch.log_softmax(logits, dim=-1)

        # log P(x) = log N! - sum_i log x_i! + sum_i x_i * log p_i
        counts_per_example = torch.sum(true_counts, dim=-1)
        log_factorial_total = torch.lgamma(counts_per_example + 1)
        log_factorial_counts = torch.lgamma(true_counts + 1).sum(dim=-1)
        log_powers = (true_counts * log_probs).sum(dim=-1)
        log_likelihood = log_factorial_total - log_factorial_counts + log_powers

        return (-torch.sum(log_likelihood)) / float(true_counts.shape[0])


In [None]:
def _prepare_inputs(inputs, in_channels, dtype):
    """Cast a batch to a floating dtype and move channels to axis 1 when they are last."""
    if not isinstance(inputs, torch.Tensor):
        return inputs

    # Convolution weights are floating point; integer one-hot batches are not.
    if not torch.is_floating_point(inputs):
        inputs = inputs.to(dtype)

    # Only transpose when the layout is unambiguously channels-last, so
    # batches that are already (batch, channels, length) pass through.
    if (in_channels is not None
            and inputs.dim() == 3
            and inputs.size(1) != in_channels
            and inputs.size(2) == in_channels):
        inputs = inputs.transpose(1, 2)

    return inputs


def train(model, criterion, optimizer, num_epochs, dataloader):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    Parameters
    ----------
    model : torch.nn.Module
        Model called as ``model(inputs)``.
    criterion : callable
        Called as ``criterion(outputs, tracks)`` where ``outputs`` is whatever
        ``model`` returns; must return a scalar tensor supporting ``backward()``.
    optimizer : torch.optim.Optimizer
        Optimizer over the model parameters.
    num_epochs : int
        Number of passes over the data.
    dataloader : iterable
        Yields ``(inputs, time, cell_type, tracks)`` batches. Integer inputs
        are cast to the model's floating dtype, and 3-D inputs laid out as
        ``(batch, length, channels)`` are transposed to ``(batch, channels, length)``
        when the model's first ``Conv1d`` expects that many channels.

    Returns
    -------
    list of float
        Mean batch loss of each epoch (``nan`` for an epoch without batches).
    """
    # Look up the expected input layout/dtype once, outside the loops.
    first_conv = next((module for module in model.modules()
                       if isinstance(module, torch.nn.Conv1d)), None)
    in_channels = first_conv.in_channels if first_conv is not None else None
    param_dtype = next((param.dtype for param in model.parameters()
                        if param.is_floating_point()), torch.float32)

    epoch_losses = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0

        for data in tqdm(dataloader):
            inputs, time, cell_type, tracks = data
            inputs = _prepare_inputs(inputs, in_channels, param_dtype)

            optimizer.zero_grad()

            outputs = model(inputs)

            loss = criterion(outputs, tracks)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        # Guard against an empty loader instead of dividing by zero.
        epoch_loss = running_loss / n_batches if n_batches else float('nan')
        epoch_losses.append(epoch_loss)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    print('Finished Training')
    return epoch_losses