In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(5, 3)
print(x)


  from .autonotebook import tqdm as notebook_tqdm


tensor([[0.3701, 0.7939, 0.7184],
        [0.3029, 0.3060, 0.6226],
        [0.6476, 0.6285, 0.1258],
        [0.3469, 0.6439, 0.4477],
        [0.8731, 0.4962, 0.7543]])


In [None]:
from .layers import Conv1DResBlock, IndexEmbeddingOutputHead

# %%
class Network(nn.Module):
    def __init__(self, tasks, nlayers=9):
        super(Network, self).__init__()

        self.tasks = tasks

        self.body = [Conv1DResBlock() for _ in range(nlayers)]
        self.head = IndexEmbeddingOutputHead(len(self.tasks), dims=128)
    
    def forward(self, x, **kwargs):
        x = x['input']

        for layer in self.body:
            x = layer(x)

        return self.head(x)

In [142]:
import tensorflow as tf
from bioflow import io

dataset = io.load_indexed_tfrecord('example-data/data.tfrecord', shuffle=1000)
# dataset = dataset.map(lambda sample: (sample['inputs'], tf.random.uniform(shape=())))
dataset = dataset.batch(2)
dataset = dataset.map(lambda example: (example['inputs'], example['outputs']))

example = next(iter(dataset))
example = tf.nest.map_structure(lambda s: torch.tensor(s.numpy()).to(torch.float32), next(iter(dataset)))
example[0]['input'] = torch.transpose(example[0]['input'], 1, 2)
example

tf.Tensor(1788, shape=(), dtype=int64)


({'input': tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
           [1., 0., 0.,  ..., 1., 1., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 1., 1.,  ..., 0., 0., 1.]],
  
          [[1., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 1., 1., 1.],
           [0., 1., 1.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]])},
 {'PTBP1_HepG2': {'control': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0

In [130]:
for s in dataset.as_numpy_iterator():
    print(s)
    break

({'input': array([[[0, 0, 0, 1],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        ...,
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0]],

       [[0, 0, 1, 0],
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        ...,
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]]], dtype=int8)}, {'PTBP1_HepG2': {'control': array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
     

In [108]:
x[1]['U2AF2_HepG2']['total'].shape

torch.Size([2, 201])

In [92]:
for x in dataset.as_numpy_iterator():
    x = tf.nest.map_structure(torch.tensor, x)
    x[0]['input'] = torch.transpose(x[0]['input'], 1, 2)
    print(x)
    break

({'input': tensor([[[0, 0, 0,  ..., 0, 1, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 1, 0, 1],
         [1, 0, 0,  ..., 0, 0, 0]],

        [[0, 1, 1,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 1, 0],
         [1, 0, 0,  ..., 1, 0, 1]]], dtype=torch.int8)}, tensor([0.3080, 0.6998]))


In [7]:
# class TFIterableDataset(torch.utils.data.IterableDataset):
#     def __init__(self, filepath):
#         super(TFIterableDataset).__init__()
#         self.dataset = dataset = io.load_indexed_tfrecord(filepath, shuffle=1000)
    
#     def __iter__(self):
#         for sample in self.dataset:
#             yield tf.nest.map_structure(lambda x: x.numpy(), sample)

# dataset = TFIterableDataset('example-data/data.tfrecord')
# dataset

tf.Tensor(1788, shape=(), dtype=int64)


<__main__.TFIterableDataset at 0x7f2f5917e8f0>

In [9]:
# data_loader = torch.utils.data.DataLoader(dataset, batch_size=4)
# print(next(iter(data_loader)))

{'inputs': {'input': tensor([[[0, 0, 0, 1],
         [0, 0, 1, 0],
         [0, 0, 0, 1],
         ...,
         [0, 1, 0, 0],
         [0, 1, 0, 0],
         [0, 0, 0, 1]],

        [[0, 1, 0, 0],
         [0, 0, 0, 1],
         [0, 0, 0, 1],
         ...,
         [0, 0, 1, 0],
         [0, 1, 0, 0],
         [0, 1, 0, 0]],

        [[0, 0, 0, 1],
         [0, 0, 1, 0],
         [1, 0, 0, 0],
         ...,
         [0, 0, 0, 1],
         [0, 0, 0, 1],
         [0, 0, 1, 0]],

        [[0, 0, 1, 0],
         [1, 0, 0, 0],
         [1, 0, 0, 0],
         ...,
         [1, 0, 0, 0],
         [0, 0, 0, 1],
         [0, 0, 0, 1]]], dtype=torch.int8)}, 'meta': [b'chr19:54461584-54461785:+:U2AF2_HepG2_rep01', b'chr19:39375533-39375734:+:U2AF2_HepG2_rep02', b'chr19:18991482-18991683:-:U2AF2_HepG2_rep02', b'chr19:44157054-44157255:+:U2AF2_HepG2_rep02'], 'outputs': {'PTBP1_HepG2': {'control': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 

In [72]:
class Conv1DResBlock(nn.Module):
    def __init__(self, in_chan, out_chan):
        super(Conv1DResBlock, self).__init__()

        self.conv1d = nn.Conv1d(in_chan, out_chan, kernel_size=3, padding='same')
        self.batch_norm = nn.BatchNorm1d(out_chan)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(0.25)
    
    def forward(self, x, **kwargs):
        x = self.conv1d(x)
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return x

In [98]:
class Head(nn.Module):
    def __init__(self, in_chan):
        super(Head, self).__init__()

        self.conv1d = nn.Conv1d(in_chan, 1, 1)
    
    def forward(self, x, **kwargs):
        x = self.conv1d(x)
        return torch.squeeze(x)

In [103]:
class Network(nn.Module):
    def __init__(self, tasks):
        super(Network, self).__init__()

        self.tasks = tasks

        # self.conv1d = nn.Conv1d(4, 16, 3, padding='same')
        self.conv1d_resblock = Conv1DResBlock(4, 16)
        self.heads = {task: Head(16) for task in self.tasks}
    
    def forward(self, x, **kwargs):
        x = x['input']
        x = self.conv1d_resblock(x)

        return {task: self.heads[task](x) for task in self.tasks}

network = Network(tasks=['PTBP1_HepG2', 'U2AF2_HepG2'])
network

Network(
  (conv1d_resblock): Conv1DResBlock(
    (conv1d): Conv1d(4, 16, kernel_size=(3,), stride=(1,), padding=same)
    (batch_norm): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU()
    (dropout): Dropout(p=0.25, inplace=False)
  )
)

In [109]:
# x_torch = tf.nest.map_structure(torch.tensor, x[0])
# x_torch = tf.nest.map_structure(lambda x: torch.transpose(x, dim0=1, dim1=2), x_torch)
# x_torch['input'] = x_torch['input'].to(torch.float32)
# print(x_torch[1].keys())
# print(x_torch['U2AF2_HepG2']['total'].shape)

In [53]:
# x_input_torch = x_torch['input']
# x_input_torch = x_input_torch.to(torch.float32)
# print(x_input_torch.shape)
# print(x_input_torch.dtype)

torch.Size([2, 4, 201])
torch.float32


In [143]:
pred = network.forward(example[0])
print(pred.keys())
print(pred['U2AF2_HepG2'].shape)

dict_keys(['PTBP1_HepG2', 'U2AF2_HepG2'])
torch.Size([2, 201])


In [116]:
torch.save(network, 'network.pth')

In [117]:
loaded_network = torch.load('network.pth')
loaded_network(x[0]).keys()

dict_keys(['PTBP1_HepG2', 'U2AF2_HepG2'])

In [94]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(network.parameters(), lr=0.001, momentum=0.9)

In [118]:
loaded_network.tasks

['PTBP1_HepG2', 'U2AF2_HepG2']

In [167]:
from functorch import vmap

class MultinomialNLLLoss(nn.Module):
    def __init__(self):
        super(MultinomialNLLLoss, self).__init__()
    
    def __call__(self, y, y_pred):
        cummulative_loss = torch.tensor(0.0)
        for i in range(y.shape[0]):
            cummulative_loss += self.unbatched_multinomial_nll_loss(y[i], y_pred[i])
        return cummulative_loss/y.shape[0]

    def unbatched_multinomial_nll_loss(self, y, y_pred):
        return -torch.distributions.Multinomial(int(torch.sum(y)), logits=y_pred).log_prob(y)

l = MultinomialNLLLoss()
l.unbatched_multinomial_nll_loss(y['U2AF2_HepG2']['total'][0], y_pred['U2AF2_HepG2'][0])

tensor(183.9577, grad_fn=<NegBackward0>)

In [170]:
l = MultinomialNLLLoss()

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(dataset.as_numpy_iterator()):
        data = tf.nest.map_structure(lambda s: torch.tensor(s).to(torch.float32), data)
        x, y = data
        x['input'] = torch.transpose(x['input'], 1, 2)

        # print(x)
        # print(y)
        # break

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y_pred = network(x)
        loss = 0
        for task in network.tasks:
            loss += l(y[task]['total'], y_pred[task])
            # loss += criterion(y_pred[task], y[task]['total'])
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

[1,   100] loss: 9.601
[1,   200] loss: 8.499
[1,   300] loss: 9.696
[1,   400] loss: 9.754
[1,   500] loss: 8.583
[1,   600] loss: 8.493
[1,   700] loss: 7.722
[1,   800] loss: 8.097
[2,   100] loss: 10.195
[2,   200] loss: 7.939
[2,   300] loss: 9.896
[2,   400] loss: 8.742
[2,   500] loss: 8.607
[2,   600] loss: 7.987
[2,   700] loss: 8.550
[2,   800] loss: 8.977


In [145]:
y_pred = network(example[0])
y = example[1]

In [154]:
torch.sum(y['U2AF2_HepG2']['total'], dim=1).to(torch.int32)[0]

tensor(67, dtype=torch.int32)

In [156]:
m = torch.distributions.Multinomial(int(torch.sum(y['U2AF2_HepG2']['total'], dim=1).to(torch.int32)[0]), logits=y_pred['U2AF2_HepG2'][0])
m

Multinomial()

In [168]:
l(y['U2AF2_HepG2']['total'], y_pred['U2AF2_HepG2'])

tensor(130.5727, grad_fn=<DivBackward0>)

In [160]:
- m.log_prob(y['U2AF2_HepG2']['total'][0])

tensor(183.9577, grad_fn=<NegBackward0>)

In [137]:
data

({'input': tensor([[[0., 1., 0., 0.],
           [0., 1., 0., 0.],
           [0., 0., 0., 1.],
           ...,
           [1., 0., 0., 0.],
           [0., 0., 1., 0.],
           [0., 0., 1., 0.]],
  
          [[1., 0., 0., 0.],
           [0., 1., 0., 0.],
           [0., 0., 0., 1.],
           ...,
           [0., 0., 0., 1.],
           [0., 0., 0., 1.],
           [0., 1., 0., 0.]]])},
 {'PTBP1_HepG2': {'control': tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [None]:
# def _manual_multinomial_negative_log_likelihood(x, log_p):
#     return tf.cast(-np.log( np.prod(tf.math.softmax(log_p)**x) * (factorial(np.sum(x)) / np.prod(factorial(x))) ), dtype=tf.float32)

In [1]:
from bioflow import io
import tensorflow as tf

dataset = io.load_tfrecord('example-data-matrix/windows.chr13.4.data.matrix.filtered.tfrecord')
dataset = dataset.batch(4)
print(dataset)
example = next(iter(dataset))

  from .autonotebook import tqdm as notebook_tqdm


<BatchDataset element_spec={'inputs': {'input': TensorSpec(shape=(None, None, 4), dtype=tf.int8, name=None)}, 'meta': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'outputs': {'counts': {'control': TensorSpec(shape=(None, 223), dtype=tf.float32, name=None), 'total': TensorSpec(shape=(None, 223), dtype=tf.float32, name=None)}, 'signal': {'control': TensorSpec(shape=(None, 223, None), dtype=tf.float32, name=None), 'total': TensorSpec(shape=(None, 223, None), dtype=tf.float32, name=None)}}}>


In [2]:
tf.reduce_sum(example['outputs']['signal']['total'])

<tf.Tensor: shape=(), dtype=float32, numpy=443.0>

In [9]:
y = torch.tensor(example['outputs']['signal']['total'].numpy())
y.shape

torch.Size([4, 223, 1000])

In [11]:
torch.sum(y, dim=-1).shape

torch.Size([4, 223])

In [23]:
torch.lgamma(torch.tensor([[0, 3, 4, 5], [0, 3, 4, 5]]))

tensor([[   inf, 0.6931, 1.7918, 3.1781],
        [   inf, 0.6931, 1.7918, 3.1781]])

In [27]:
def log_combinations(x):
    total_permutations = torch.lgamma(torch.sum(y, dim=-1) + 1)
    counts_factorial = torch.lgamma(y + 1)
    redundant_permutations = torch.sum(counts_factorial, dim=-1)
    return total_permutations - redundant_permutations

y_logc = log_combinations(y)
y_logc.shape

torch.Size([4, 223])

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(5, 3)
print(x)


tensor([[0.4186, 0.4810, 0.8704],
        [0.0866, 0.1233, 0.7130],
        [0.6400, 0.0403, 0.8048],
        [0.1123, 0.8429, 0.5790],
        [0.1690, 0.2289, 0.2578]])


In [47]:
class MultinomialNLLLossFromLogits(nn.Module):
    def __init__(self, reduction=torch.mean):
        super(MultinomialNLLLossFromLogits, self).__init__()
        self.reduction = reduction
    
    def __call__(self, y, y_pred):
        return self.log_likelihood_from_logits(y, y_pred)

    def log_likelihood_from_logits(self, y, y_pred):
        log_prob = -torch.sum(torch.mul(torch.log_softmax(y_pred, dim=-1), y), dim=-1) * self.log_combinations(y)
        if self.reduction is not None:
            return self.reduction(log_prob)
        return log_prob

    def log_combinations(self, input):
        total_permutations = torch.lgamma(torch.sum(input, dim=-1) + 1)
        counts_factorial = torch.lgamma(input + 1)
        redundant_permutations = torch.sum(counts_factorial, dim=-1)
        return total_permutations - redundant_permutations

loss_fn = MultinomialNLLLossFromLogits()
y_log_prob = loss_fn(y, torch.rand(4, 223, 1000))
y_log_prob

tensor(28.5873)

In [35]:
torch.sum(y_log_prob[0])

tensor(6299.0479)

In [None]:
class AdditiveMixLayer(nn.Module):
    def __init__(self):
        super(AdditiveMixLayer, self).__init__()
    
    def forward(self, logits_a, logits_b):
        pass

In [23]:
class IndexEmbeddingOutputHead(nn.Module):
    def __init__(self, n_tasks, dims):
        super(IndexEmbeddingOutputHead, self).__init__()

        # embedding of shape (p, d)
        self.embedding = torch.nn.Embedding(n_tasks, dims)
    
    def forward(self, bottleneck, **kwargs):
        # bottleneck of shape (n, d)
        logits = torch.matmul(bottleneck, torch.transpose(self.embedding.weight, 0, 1))  
        return logits

out_head = OutputHead(223, 128)

In [24]:
embedding = torch.nn.Embedding(223, 128)

In [21]:
seq = torch.rand(1000, 128)
seq.shape

torch.Size([1000, 128])

In [25]:
out_head.forward(seq).shape

torch.Size([1000, 223])