In [None]:
import torch
import torch.nn as nn

In [None]:
class MultinomialNLLLossFromLogits(nn.Module):
    """Negative log-likelihood of observed counts under a multinomial given by logits.

    For counts ``y`` and logits ``y_pred`` along the category axis ``dim`` the
    per-sample log-likelihood is::

        log C(y) + sum_i y_i * log_softmax(y_pred)_i
        log C(y) = lgamma(sum_i y_i + 1) - sum_i lgamma(y_i + 1)

    and the loss is its negation, optionally reduced.

    Args:
        reduction: callable applied to the per-sample losses (default ``torch.mean``);
            ``None`` returns them unreduced.
    """

    def __init__(self, reduction=torch.mean):
        super(MultinomialNLLLossFromLogits, self).__init__()
        self.reduction = reduction

    def forward(self, y, y_pred, dim=-1):
        """Return the (optionally reduced) multinomial NLL of counts ``y`` given logits ``y_pred``.

        ``y`` and ``y_pred`` are expected to have the same shape; ``dim`` is the
        category axis that is summed out.
        """
        nll = -self.log_likelihood_from_logits(y, y_pred, dim)
        if self.reduction is not None:
            return self.reduction(nll)
        return nll

    def log_likelihood_from_logits(self, y, y_pred, dim=-1):
        """Per-sample multinomial log-likelihood (not negated, not reduced)."""
        log_probs = torch.log_softmax(y_pred, dim=dim)
        # The multinomial coefficient is an additive term in log space.
        return torch.sum(log_probs * y, dim=dim) + self.log_combinations(y, dim)

    def log_combinations(self, input, dim=-1):
        """log(n! / prod_i y_i!) along ``dim``, computed with lgamma for stability."""
        total_permutations = torch.lgamma(torch.sum(input, dim=dim) + 1)
        redundant_permutations = torch.sum(torch.lgamma(input + 1), dim=dim)
        return total_permutations - redundant_permutations

In [None]:
class Conv1DFirstLayer(nn.Module):
    """Stem layer: a single length-preserving ('same'-padded) Conv1d followed by ReLU.

    Args:
        in_chan: number of input channels.
        filters: number of output channels.
        kernel_size: convolution width.
    """

    def __init__(self, in_chan, filters=128, kernel_size=12):
        super().__init__()
        self.conv1d = nn.Conv1d(
            in_chan,
            filters,
            kernel_size=kernel_size,
            padding='same',
        )
        self.act = nn.ReLU()

    def forward(self, inputs, **kwargs):
        """Apply conv then ReLU; extra keyword arguments are accepted and ignored."""
        return self.act(self.conv1d(inputs))

In [None]:
class Conv1DResBlock(nn.Module):
    """Dilated Conv1d -> BatchNorm -> ReLU -> Dropout, with an optional identity skip.

    Args:
        in_chan: number of input channels.
        filters: number of output channels; must equal ``in_chan`` when ``residual`` is True
            because the skip connection adds the input unchanged.
        kernel_size: convolution width ('same' padding keeps the sequence length).
        dropout: dropout probability.
        dilation: convolution dilation.
        residual: if True, return ``inputs + block(inputs)``.

    Raises:
        ValueError: if ``residual`` is True and ``in_chan != filters``.
    """

    def __init__(self, in_chan, filters=128, kernel_size=3, dropout=0.25, dilation=1, residual=True):
        super(Conv1DResBlock, self).__init__()

        # Fail at construction instead of with a shape-mismatch error inside forward().
        if residual and in_chan != filters:
            raise ValueError(
                f'residual=True requires in_chan == filters, got in_chan={in_chan}, filters={filters}'
            )

        self.conv1d = nn.Conv1d(in_chan, filters, kernel_size=kernel_size, dilation=dilation, padding='same')
        self.batch_norm = nn.BatchNorm1d(filters)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.residual = residual

    def forward(self, inputs, **kwargs):
        """Apply the block; extra keyword arguments are accepted and ignored."""
        x = self.conv1d(inputs)
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.dropout(x)
        if self.residual:
            x = inputs + x
        return x

# %%
class IndexEmbeddingOutputHead(nn.Module):
    """Score every position against one learned vector per task.

    The weight matrix of an ``nn.Embedding(n_tasks, dims)`` (shape
    ``(n_tasks, dims)``) is used directly as a linear read-out, so each output
    logit is the dot product of a position's feature vector with a task vector.
    """

    def __init__(self, n_tasks, dims):
        super().__init__()
        # One learned vector per task/experiment: weight shape (n_tasks, dims).
        self.embedding = torch.nn.Embedding(n_tasks, dims)

    def forward(self, bottleneck, **kwargs):
        """Map features (batch, dims, n) to logits (batch, n, n_tasks)."""
        positions_first = bottleneck.transpose(-1, -2)       # (batch, n, dims)
        task_matrix = self.embedding.weight.transpose(0, 1)  # (dims, n_tasks), no batch dim
        return positions_first @ task_matrix

In [None]:
class IndexEmbeddingOutputHead(nn.Module):
    """Output head that scores each position against one learned vector per task.

    NOTE(review): this cell re-defines the `IndexEmbeddingOutputHead` class from
    the preceding cell verbatim and silently shadows it; one copy can be dropped.
    """
    def __init__(self, n_tasks, dims):
        super(IndexEmbeddingOutputHead, self).__init__()

        # protein/experiment embedding of shape (p, d); its `weight` is used directly as the read-out matrix
        self.embedding = torch.nn.Embedding(n_tasks, dims)
    
    def forward(self, bottleneck, **kwargs):
        """Return logits of shape (batch, n, p) from features of shape (batch, d, n)."""
        # bottleneck of shape (batch, d, n) --> (batch, n, d)
        bottleneck = torch.transpose(bottleneck, -1, -2)
        
        # embedding weight of shape (p, d) --> (d, p); there is no batch dimension
        embedding = torch.transpose(self.embedding.weight, 0, 1)

        logits = torch.matmul(bottleneck, embedding) # (batch, n, d) @ (d, p) --> (batch, n, p)
        return logits

In [None]:
class Network(nn.Module):
    """Dilated residual Conv1d body followed by a per-task embedding read-out head.

    Args:
        tasks: sequence of task identifiers, stored as ``self.tasks``; only
            ``len(tasks)`` is used, to size the output head.
        nlayers: number of ``Conv1DResBlock`` layers; block ``i`` uses dilation ``2**i``.
        in_channels: number of input channels of the stem convolution (default 4).
        filters: channel width of the body and embedding dimension of the head (default 128).
    """

    def __init__(self, tasks, nlayers=9, in_channels=4, filters=128):
        super(Network, self).__init__()

        self.tasks = tasks

        self.body = nn.Sequential(
            Conv1DFirstLayer(in_channels, filters),
            *[Conv1DResBlock(filters, filters=filters, dilation=2 ** i) for i in range(nlayers)],
        )
        self.head = IndexEmbeddingOutputHead(len(self.tasks), dims=filters)

    def forward(self, inputs, **kwargs):
        """Map inputs (batch, in_channels, n) to logits (batch, n, len(tasks))."""
        return self.head(self.body(inputs))

In [None]:
# Full-size model: one output head column per task index 0..222.
net = Network(tasks=[*range(223)])
net

In [None]:
# Smoke test: a random batch of 2 inputs with 4 channels and length 201.
y_pred = net(torch.rand((2, 4, 201)))
y_pred.size()

In [None]:
from bioflow import io
import tensorflow as tf
import torch

def load_tf_dataset_to_torch(filepath, features_filepath=None, batch_size=64, cache=True, shuffle=None):
    """Yield batches from a TFRecord file as float32 torch tensors.

    Pipeline: load raw records -> optional cache -> optional shuffle (buffer size
    ``shuffle``) -> deserialize with the feature spec -> batch -> swap the last two
    axes of the inputs.

    Args:
        filepath: path of the TFRecord file.
        features_filepath: feature-spec JSON; defaults to ``filepath + '.features.json'``.
        batch_size: number of examples per batch.
        cache: whether to cache the raw records.
        shuffle: shuffle buffer size; falsy disables shuffling.

    Yields:
        ``(inputs, outputs)`` where every leaf tensor is float32. ``inputs`` is
        ``e['inputs']['input']`` transposed with ``perm=[0, 2, 1]`` (presumably
        (batch, length, channels) -> (batch, channels, length) — TODO confirm).

    NOTE(review): this is a generator; each call creates a fresh single-pass iterator.
    """
    dataset = io.dataset_ops.load_tfrecord(filepath, deserialize=False)

    # cache before shuffling so the shuffle is applied on top of the cached records
    if cache:
        dataset = dataset.cache()

    if shuffle:
        dataset = dataset.shuffle(shuffle)

    # deserialize
    if features_filepath is None:
        features_filepath = filepath + '.features.json'
    features = io.dataset_ops.features_from_json_file(features_filepath)
    dataset = io.dataset_ops.deserialize_dataset(dataset, features)

    # batch
    dataset = dataset.batch(batch_size)

    # format dataset: (transposed inputs, nested outputs)
    dataset = dataset.map(lambda e: (tf.transpose(e['inputs']['input'], perm=[0, 2, 1]), e['outputs']))

    for example in dataset.as_numpy_iterator():
        # Convert straight to float32 in one copy rather than building an
        # intermediate tensor in the source dtype and then casting it.
        yield tf.nest.map_structure(lambda x: torch.tensor(x, dtype=torch.float32), example)

# Generator over shuffled chr13 batches (single pass; `next()` consumes it).
torch_dataset = load_tf_dataset_to_torch(
    'example-data-matrix/windows.chr13.4.data.matrix.filtered.tfrecord',
    shuffle=1_000_000,
)

In [None]:
class TFIterableDataset(torch.utils.data.IterableDataset):
    """Stream pre-batched examples from a TFRecord file as float32 torch tensors.

    The tf.data pipeline is: load raw records -> optional cache -> optional
    shuffle (buffer size ``shuffle``) -> deserialize with the feature spec ->
    batch -> swap the last two axes of the inputs. Each iteration yields
    ``(inputs, outputs)`` with every leaf cast to float32. Because batching is
    done here, wrap it in ``DataLoader(..., batch_size=None)``.

    NOTE(review): ``__iter__`` is not worker-aware; with ``DataLoader(num_workers>0)``
    each worker process replays the full dataset, duplicating samples.
    """

    def __init__(self, filepath, features_filepath=None, batch_size=64, cache=True, shuffle=None):
        # `super` must be bound to the instance: `super(TFIterableDataset)` without
        # `self` is an unbound proxy and never runs the base initializer on `self`.
        super(TFIterableDataset, self).__init__()

        self.dataset = io.dataset_ops.load_tfrecord(filepath, deserialize=False)

        # cache before shuffling so the shuffle is applied on top of the cached records
        if cache:
            self.dataset = self.dataset.cache()

        if shuffle:
            self.dataset = self.dataset.shuffle(shuffle)

        # deserialize
        if features_filepath is None:
            features_filepath = filepath + '.features.json'
        self.features = io.dataset_ops.features_from_json_file(features_filepath)
        self.dataset = io.dataset_ops.deserialize_dataset(self.dataset, self.features)

        # batch
        self.dataset = self.dataset.batch(batch_size)

        # format dataset: inputs transposed with perm [0, 2, 1]
        # (presumably (batch, length, channels) -> (batch, channels, length) — TODO confirm)
        self.dataset = self.dataset.map(lambda e: (tf.transpose(e['inputs']['input'], perm=[0, 2, 1]), e['outputs']))

    def __iter__(self):
        """Yield each batch with every leaf converted to a float32 torch tensor."""
        for example in self.dataset.as_numpy_iterator():
            # one copy straight to float32, no intermediate tensor in the source dtype
            yield tf.nest.map_structure(lambda x: torch.tensor(x, dtype=torch.float32), example)

# IterableDataset over the same shuffled chr13 records; re-iterable, unlike the generator above.
dataset = TFIterableDataset(
    'example-data-matrix/windows.chr13.4.data.matrix.filtered.tfrecord',
    shuffle=1_000_000,
)

In [None]:
# Sanity check: pull the first batch directly from the IterableDataset (no DataLoader).
next(iter(dataset))

In [None]:
# `dataset` already yields whole batches, so DataLoader's automatic batching is disabled.
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=None)

In [None]:
# Inspect only the first batch the DataLoader yields, then stop.
for s in dataloader:
    print(len(s), s[0].shape, s, sep='\n')
    break

In [None]:
# `torch_dataset` is a generator (its own iterator), so this consumes one batch from it.
batch = next(iter(torch_dataset))

In [None]:
# Shapes of the inputs and of the 'total' signal target in the same batch.
print(batch[0].shape, batch[1]['signal']['total'].shape, sep='\n')

In [None]:
def example_dataset_generator(n=1000):
    """Yield ``n`` synthetic ``(inputs, targets)`` batches for smoke tests.

    ``inputs`` is uniform noise of shape ``(8, 4, 101)`` (float32) and
    ``targets`` is ``{'signal': {'total': counts}}`` where ``counts`` holds
    integer values in ``[0, 10)`` stored as float32 with shape ``(8, 101, 7)``.
    """
    for _batch_index in range(n):
        inputs = torch.rand(8, 4, 101, dtype=torch.float32)
        counts = torch.randint(10, (8, 101, 7)).to(torch.float32)
        yield (inputs, {'signal': {'total': counts}})

# A generator is its own iterator, so `next` can be applied to it directly.
next(example_dataset_generator())

In [None]:
from tqdm import tqdm

# Forward-only throughput check: no gradients are needed, so skip building autograd graphs.
# NOTE(review): `net` is never switched to eval mode, so BatchNorm running statistics
# are still updated by these passes.
with torch.no_grad():
    for epoch in range(5):
        print(f'Epoch: {epoch}/5')
        for sample in tqdm(example_dataset_generator(100), total=100):
            _ = net(sample[0])

In [None]:
def dataset():
    """Return a fresh generator of 100 synthetic batches.

    NOTE(review): this rebinds `dataset`, which earlier cells use for the TFIterableDataset.
    """
    return example_dataset_generator(100)

In [None]:
# NOTE(review): prints every one of the 100 synthetic batches; slice for a quicker look.
for i in iter(dataset()):
    print(i)

In [None]:
import tqdm

# Small 7-task model matching the 7-task synthetic targets.
test_net = Network(tasks=list(range(7)))

import torch.optim as optim

# Build the optimizer over the model that is trained below (`test_net`); an optimizer
# over `net.parameters()` would step a different model and leave `test_net` untouched.
optimizer = optim.SGD(test_net.parameters(), lr=0.001, momentum=0.9)
criterion = MultinomialNLLLossFromLogits()

def train(net, dataset, epochs=2, optimizer=None, criterion=None):
    """Train ``net`` for ``epochs`` passes over ``dataset`` and print the summed loss per epoch.

    Args:
        net: module mapping a batch of inputs ``x`` to logits ``y_pred``.
        dataset: zero-argument callable returning a fresh iterable of ``(x, y)``
            pairs each epoch; ``y['signal']['total']`` holds the target counts.
        epochs: number of passes over ``dataset()``.
        optimizer: optimizer to step. Defaults to a new ``torch.optim.SGD`` over
            ``net.parameters()`` (lr=0.001, momentum=0.9), so the model passed in
            is the one that is updated.
        criterion: loss called as ``criterion(y_true, y_pred)``. Defaults to
            ``MultinomialNLLLossFromLogits()``.

    NOTE(review): the criterion is called without ``dim``, so it reduces over the last
    axis (tasks), whereas the multinomial check later in this notebook uses ``dim=-2``
    (positions) — confirm which axis the training loss should use.
    """
    # Local import: the global name `tqdm` is rebound (function vs. module) between cells.
    from tqdm import trange

    if optimizer is None:
        optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    if criterion is None:
        criterion = MultinomialNLLLossFromLogits()

    net.train()
    for epoch in trange(epochs):
        epoch_running_loss = 0.0
        print(f'Epoch {epoch}')
        for x, y in dataset():
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            y_pred = net(x)
            loss = criterion(y['signal']['total'], y_pred)
            loss.backward()
            optimizer.step()

            # add to running loss
            epoch_running_loss += loss.item()
        print(f'Loss {epoch_running_loss}')

# Three epochs over 100 fresh synthetic batches each.
train(test_net, dataset=lambda: example_dataset_generator(100), epochs=3)


In [None]:
# Draw one batch and inspect both halves of it: calling `next()` once per print would
# pull two different batches from the single-pass generator.
inputs_batch, outputs_batch = next(iter(torch_dataset))
print(inputs_batch.shape)
print(outputs_batch.keys())

In [None]:
# NOTE(review): despite the name, this records axis 0 of each example's `inputs/input`
# tensor rather than the `outputs/signal/total` length (the earlier commented-out
# variant used `e['outputs']['signal']['total'].shape[1]`) — confirm intent.
total_lengths = [
    int(e['inputs']['input'].shape[0])
    for e in io.load_tfrecord('example-data-matrix/windows.chr15.4.data.matrix.filtered.tfrecord')
]

In [None]:
# Distinct lengths observed across the file.
{*total_lengths}

In [None]:
# Softmax over axis -2 (positions, assuming the head's (batch, n, tasks) layout),
# so each (example, task) column should sum to 1.
res = torch.softmax(y_pred, dim=-2)
print(res[0, :, 0].shape)
print(res[0, :, 3].sum())

In [None]:
# Toy shapes mirroring the head: features (batch, n, d) and a (d, tasks) read-out.
ex_pred = torch.rand(2, 201, 128)
embed = torch.rand(128, 223)

print(ex_pred.shape)
print(embed.shape)
print(embed.unsqueeze(0).shape)

In [None]:
# Batched matmul broadcasts the 2-D read-out over the batch axis.
(ex_pred @ embed).shape

In [None]:
# Element-wise product broadcasts the (3, 4) operand across the leading axis.
torch.rand(2, 3, 4) * torch.rand(3, 4)

In [1]:
import torch
import torchmetrics

  from .autonotebook import tqdm as notebook_tqdm


In [33]:
# Module-style metric computing one Pearson coefficient per column (num_outputs=4).
corr = torchmetrics.PearsonCorrCoef(num_outputs=4)
corr(
    torch.rand(101, 4),
    torch.rand(101, 4),
)

tensor([-0.0397, -0.0754, -0.0656,  0.0886])

In [None]:
# Inspect the functional Pearson correlation (displays the function object).
torchmetrics.functional.pearson_corrcoef

In [2]:
# corr = torchmetrics.PearsonCorrCoef()

In [27]:
def transparent_corr(x, y):
    """Debugging wrapper: print the shapes of ``x`` and ``y``, then return their Pearson correlation.

    Delegates to ``torchmetrics.functional.pearson_corrcoef(x, y)``.
    """
    print(x.shape, y.shape)
    correlation = torchmetrics.functional.pearson_corrcoef(x, y)
    return correlation

In [35]:
import functorch

# Two random "batches": 2 examples x 101 positions x 7 columns.
a, b = torch.rand(2, 101, 7), torch.rand(2, 101, 7)

# Correlation for a single example: one coefficient per column.
print(torchmetrics.functional.pearson_corrcoef(a[0], b[0]).shape)

# NOTE: an earlier attempt vectorised this over the batch with
# `functorch.vmap(transparent_corr, in_dims=0, out_dims=0)(a, b)`
# (and `in_dims=(0, 2), out_dims=(0, 2)`), storing the result in `out`.

torch.Size([7])


In [45]:
def batched_pearson_corrcoef(y_batch, y_pred_batch):
    """Per-column Pearson correlations computed per example, then summed over the batch.

    Element ``i`` of the leading axis of each argument is passed to
    ``torchmetrics.functional.pearson_corrcoef``; the stacked results are summed
    along ``dim=0``. NOTE(review): this is a sum, not a mean — divide by the batch
    size if an average is intended.
    """
    per_example = []
    for idx in range(y_batch.shape[0]):
        per_example.append(torchmetrics.functional.pearson_corrcoef(y_batch[idx], y_pred_batch[idx]))
    return torch.stack(per_example).sum(dim=0)

batched_pearson_corrcoef(a, b)

tensor([-0.1186,  0.0394, -0.1787, -0.0259, -0.0223,  0.0634,  0.1822])

In [43]:
# Shape of the per-example correlations before reduction: (batch, columns).
torch.stack([
    torchmetrics.functional.pearson_corrcoef(a[idx], b[idx])
    for idx in range(len(a))
]).size()

torch.Size([2, 7])

In [None]:
# Display the helper defined above (function object repr).
batched_pearson_corrcoef

In [11]:
# FIXME: `out` was only produced by the commented-out `functorch.vmap` experiment, so it
# is undefined on a fresh kernel (Restart & Run All raised NameError here). Show its
# shape only if an earlier session defined it.
out.shape if 'out' in globals() else None

torch.Size([101, 2])

In [None]:
# Reference multinomial NLL for one observation of 42 draws over 4 categories.
multinomial = torch.distributions.Multinomial(
    total_count=42,
    logits=torch.tensor([2, 3.2, 5, 1.9]),
)
nll = multinomial.log_prob(torch.tensor([7, 8, 20, 7])).neg()
nll

In [None]:
from torch.distributions import Multinomial

In [None]:
# Toy counts (4 examples x 42 positions x 7 tasks) and matching logits.
# NOTE(review): rebinds `y_pred`, which earlier cells used for the model output.
y = torch.randint(0, 10, size=(4, 42, 7))
y_pred = torch.rand(4, 42, 7)

In [None]:
# Reference NLL: one torch Multinomial per (example i, task j) column along axis 1,
# with total_count set to that column's observed total; averaged over all columns.
manual_nll = []
for i in range(y.size(0)):
    for j in range(y.size(2)):
        single_y = y[i, :, j]
        single_y_pred = y_pred[i, :, j]
        manual_nll.append(
            Multinomial(int(single_y.sum()), logits=single_y_pred).log_prob(single_y).neg()
        )
true_nll = torch.tensor(manual_nll).mean()
true_nll

In [None]:
class MultinomialNLLLossFromLogits(nn.Module):
    """Negative log-likelihood of observed counts under a multinomial given by logits.

    Per sample, along the category axis ``dim``::

        log p(y) = lgamma(sum_i y_i + 1) - sum_i lgamma(y_i + 1) + sum_i y_i * log_softmax(y_pred)_i

    The loss is ``-log p(y)``, optionally reduced.

    Args:
        reduction: callable applied to the per-sample losses (default ``torch.mean``);
            ``None`` returns them unreduced.
    """

    def __init__(self, reduction=torch.mean):
        super(MultinomialNLLLossFromLogits, self).__init__()
        self.reduction = reduction

    # Implemented as `forward` (not `__call__`) so nn.Module's call machinery,
    # including forward hooks, still runs.
    def forward(self, y, y_pred, dim=-1):
        """Return the (optionally reduced) NLL of counts ``y`` given logits ``y_pred`` along ``dim``."""
        neg_log_probs = self.log_likelihood_from_logits(y, y_pred, dim) * -1
        if self.reduction is not None:
            return self.reduction(neg_log_probs)
        return neg_log_probs

    def log_likelihood_from_logits(self, y, y_pred, dim):
        """Per-sample multinomial log-likelihood (not negated, not reduced)."""
        return torch.sum(torch.mul(torch.log_softmax(y_pred, dim=dim), y), dim=dim) + self.log_combinations(y, dim)

    def log_combinations(self, input, dim):
        """log(n! / prod_i y_i!) along ``dim``, computed with lgamma."""
        total_permutations = torch.lgamma(torch.sum(input, dim=dim) + 1)
        counts_factorial = torch.lgamma(input + 1)
        redundant_permutations = torch.sum(counts_factorial, dim=dim)
        return total_permutations - redundant_permutations

print(y.shape, y_pred.shape, sep='\n')

# Multinomial over the position axis (-2): one distribution per (example, task) column.
nll_loss = MultinomialNLLLossFromLogits(reduction=torch.mean)
nll = nll_loss(y, y_pred, dim=-2)
nll

In [None]:
# Compare with a tolerance: the vectorised loss and the per-column reference reduce
# floating-point terms in different orders, so exact equality is fragile.
assert torch.allclose(nll, true_nll, rtol=1e-5, atol=1e-5), (nll, true_nll)

In [None]:
# Log-probability of the last column visited by the reference loop.
Multinomial(total_count=int(single_y.sum()), logits=single_y_pred).log_prob(value=single_y)

In [None]:
# Same column through the vectorised loss (1-D input, so the default dim=-1 applies).
nll_loss(y=single_y, y_pred=single_y_pred)

In [None]:
# Total count of that column (the Multinomial's total_count).
single_y.sum()