In [None]:
import torch

from rbpnet import io

In [None]:
tfrecord_filepath = '../example/data.matrix/windows.chr13.4.data.matrix.filtered.tfrecord'
# keep records serialized here; `features` decodes them one at a time later
dataset = io.dataset_ops.load_tfrecord(tfrecord_filepath, deserialize=False)
features = io.dataset_ops.features_from_json_file(tfrecord_filepath + '.features.json')

In [None]:
# `proto` is not defined by any earlier cell on a fresh kernel; take the first serialized record explicitly
proto = next(dataset.as_numpy_iterator())
features.deserialize_example(proto)

In [None]:
# torch.gather needs an index with the same number of dims as the input, so a 1-D index on a
# 3-D tensor raises; index_select picks whole slices 0, 1 and 7 along dim 0 as intended here
selected_rows = torch.index_select(torch.rand(8, 3, 4), dim=0, index=torch.tensor([0, 1, 7]))
selected_rows.shape

In [None]:
# a boolean index on the leading axis keeps rows 0 and 2; trailing axes are implicitly kept whole
torch.rand(5, 3, 4)[torch.tensor([True, False, True, False, False])].shape

In [None]:
# materialise every serialized element of the dataset in memory
proto_samples = list(dataset.as_numpy_iterator())

In [None]:
# number of serialized elements yielded by the tfrecord dataset
len(proto_samples)

In [None]:
import torch
import tensorflow as tf
from torch.utils.data import Dataset, DataLoader

class TFRecordDataset(Dataset):
    """Map-style PyTorch dataset over a tfrecord file held in memory as serialized protos.

    All protos are read eagerly when the dataset is built; decoding is deferred to
    ``__getitem__`` so only the requested example is deserialized.

    Args:
        filepath: path to the ``.tfrecord`` file.
        features_filepath: path to the features JSON; defaults to ``filepath + '.features.json'``.
    """

    def __init__(self, filepath, features_filepath=None):
        self._tf_dataset = io.dataset_ops.load_tfrecord(filepath, deserialize=False)
        self._serialized_protos = list(self._tf_dataset.as_numpy_iterator())

        # the feature spec sits next to the tfrecord unless a path is given explicitly
        self.features = io.dataset_ops.features_from_json_file(
            filepath + '.features.json' if features_filepath is None else features_filepath
        )

    def __len__(self):
        return len(self._serialized_protos)

    @staticmethod
    def _to_float32_tensor(tf_tensor):
        # eager tf tensor -> numpy -> torch, then cast to float32
        return torch.tensor(tf_tensor.numpy()).to(torch.float32)

    def __getitem__(self, idx):
        decoded = self.features.deserialize_example(self._serialized_protos[idx])
        inputs_and_outputs = (decoded['inputs'], decoded['outputs'])
        return tf.nest.map_structure(self._to_float32_tensor, inputs_and_outputs)

In [None]:
# NOTE: rebinds `dataset` from the raw tf dataset to the map-style torch wrapper
dataset = TFRecordDataset(filepath='../example/data.matrix/windows.chr13.4.data.matrix.filtered.tfrecord')
dataset

In [None]:
# one (inputs, outputs) pair whose leaves are float32 tensors, decoded on demand by __getitem__
example = dataset[2]
example

In [None]:
# 8 worker processes call __getitem__ in parallel; 128 examples are collated per batch
dataloader = DataLoader(dataset=dataset, num_workers=8, batch_size=128)

In [None]:
# Drain the loader once (e.g. to gauge throughput). Binding batches to `_` would shadow IPython's
# last-output variable and stop it from updating, so count the batches instead.
n_batches = sum(1 for _batch in dataloader)
n_batches

In [None]:
from torchrbpnet.data import tfrecord_to_dataloader

# Drain the torchrbpnet loader once; count batches rather than assigning to `_`, which would
# shadow IPython's last-output variable.
n_batches_tfrecord = sum(
    1 for _batch in tfrecord_to_dataloader('../example/data.matrix/windows.chr13.4.data.matrix.filtered.tfrecord')
)
n_batches_tfrecord

In [None]:
# One boolean per line of experiments.txt (lines are "<symbol>_<cell line>"): True for HepG2.
# rsplit on the last '_' so an RBP symbol that itself contains '_' no longer breaks the parse;
# a line without any '_' still fails (IndexError) instead of being silently misread.
with open('../example/experiments.txt') as f:
    cell_lines = [line.strip().rsplit('_', 1)[1] for line in f]
mask_HepG2 = torch.tensor([cell == 'HepG2' for cell in cell_lines])
torch.save(mask_HepG2, 'experiment-mask.HepG2.pt')

In [None]:
# One boolean per line of experiments.txt (lines are "<symbol>_<cell line>"): True for K562.
# rsplit on the last '_' so an RBP symbol that itself contains '_' no longer breaks the parse;
# a line without any '_' still fails (IndexError) instead of being silently misread.
with open('../example/experiments.txt') as f:
    cell_lines = [line.strip().rsplit('_', 1)[1] for line in f]
mask_K562 = torch.tensor([cell == 'K562' for cell in cell_lines])
torch.save(mask_K562, 'experiment-mask.K562.pt')

In [None]:
# Keep exactly one experiment per RBP symbol: its HepG2 experiment when the RBP was profiled in
# HepG2, otherwise its first listed experiment. The previous first-occurrence rule only honoured
# the HepG2 preference when HepG2 lines happened to precede the other cell lines in the file.
with open('../example/experiments.txt') as f:
    symbol_cell_pairs = [line.strip().rsplit('_', 1) for line in f]

symbols_with_HepG2 = {symbol for symbol, cell in symbol_cell_pairs if cell == 'HepG2'}
selected_symbols = set()
mask_unique_ENCODE_prefer_HepG2 = []
for symbol, cell in symbol_cell_pairs:
    eligible = cell == 'HepG2' or symbol not in symbols_with_HepG2
    keep = eligible and symbol not in selected_symbols
    if keep:
        selected_symbols.add(symbol)
    mask_unique_ENCODE_prefer_HepG2.append(keep)
mask_unique_ENCODE_prefer_HepG2 = torch.tensor(mask_unique_ENCODE_prefer_HepG2)

# invariant: every distinct symbol contributes exactly one selected experiment
assert int(mask_unique_ENCODE_prefer_HepG2.sum()) == len({s for s, _c in symbol_cell_pairs})
print(torch.sum(mask_unique_ENCODE_prefer_HepG2))
torch.save(mask_unique_ENCODE_prefer_HepG2, 'experiment-mask.ENCODE-150.prioritize-HepG2.pt')

In [None]:
# short alias for the unique-ENCODE mask, used by the exploration cells below
mask = mask_unique_ENCODE_prefer_HepG2
mask

In [None]:
# torch.Size of the mask (one entry per experiment line)
mask.size()

In [None]:
# one index per experiment; take the length from the mask instead of hard-coding 223
indices = torch.arange(len(mask))
print(indices.shape)
print(indices.dtype)

In [None]:
# for equal-length 1-D tensors, boolean indexing matches torch.masked_select
selected_indices = indices[mask]
selected_indices

In [None]:
# Rebuild the mask as 0/1 counts from the selected indices. The depth must be the number of
# experiments (len(mask)); the hard-coded 233 (presumably a typo for 223) gave a longer vector.
torch.sum(torch.nn.functional.one_hot(selected_indices, len(mask)), dim=0)

In [None]:
# re-check the mask length before sampling from it
mask.size()

In [None]:
def sample_positives_from_mask(boolean_mask, n, generator=None):
    """Draw ``n`` distinct indices where ``boolean_mask`` is True, uniformly without replacement.

    Args:
        boolean_mask: 1-D boolean tensor.
        n: number of indices to draw; must not exceed the number of True entries.
        generator: optional ``torch.Generator`` for reproducible draws.

    Returns:
        int64 tensor of ``n`` indices into ``boolean_mask``; each one points at a True entry.

    Raises:
        AssertionError: if ``boolean_mask`` is not 1-D.
        ValueError: if ``n`` exceeds the number of True entries.
    """
    assert len(boolean_mask.shape) == 1
    # use the argument (the previous version read the notebook-global `mask` here)
    positive_indices = torch.masked_select(torch.arange(0, len(boolean_mask)), boolean_mask)
    if n > len(positive_indices):
        raise ValueError(f'cannot draw {n} distinct indices from {len(positive_indices)} positives')
    # torch.multinomial treats its input as weights and returns positions into it, so feeding it
    # the index values themselves biased the draw (index 0 had weight 0) and returned positions
    # instead of mask indices. Permute positions uniformly and map them back to mask indices.
    picks = torch.randperm(len(positive_indices), generator=generator)[:n]
    return positive_indices[picks]

# draw 10 distinct experiment indices from the positions set in `mask`
sample = sample_positives_from_mask(boolean_mask=mask, n=10)
sample

In [None]:
def indices_to_mask(indices, depth):
    """Turn an int64 tensor of indices into a length-``depth`` boolean mask that is True at each index."""
    one_hot_rows = torch.nn.functional.one_hot(indices, depth)
    # summed one-hot counts are non-negative, so "> 0" marks every index that occurred
    return one_hot_rows.sum(dim=0) > 0

# depth is the number of experiments, i.e. the length of the mask the sample was drawn from
sample_mask = indices_to_mask(sample, depth=len(mask))
print(sample_mask)
print(torch.logical_not(sample_mask))

In [None]:
# line index -> (symbol, cell line), parsed by splitting on the last '_' so that symbols
# containing '_' stay intact; a line without '_' still raises ValueError on unpacking
with open('../example/experiments.txt') as f:
    parts_per_line = (line.strip().rsplit('_', 1) for line in f)
    idx2symbol_cell = {i: (symbol, cell) for i, (symbol, cell) in enumerate(parts_per_line)}
torch.save(idx2symbol_cell, 'ENCODE.idx2symbol-cell.pt')

In [None]:
# number of experiments kept by the unique-ENCODE mask (tensor method instead of Python's sum)
mask_unique_ENCODE_prefer_HepG2.sum()

In [None]:
example = next(
    iter(
        tfrecord_to_dataloader(
            '../example/data.matrix/windows.chr13.4.data.matrix.filtered.tfrecord',
            batch_size=16,
        )
    )
)

In [None]:
# keep only the K562 entries of axis 2 of the outputs (presumably the experiment axis — TODO confirm)
example[1][:, :, mask_K562].shape

In [None]:
# torch.logical_and() needs two tensors and raises a TypeError with none; combine the HepG2 and
# unique-ENCODE masks into a composite mask and count how many experiments survive both
torch.logical_and(mask_HepG2, mask_unique_ENCODE_prefer_HepG2).sum()

In [None]:
class TFIterableDataset(torch.utils.data.IterableDataset):
    """Iterable PyTorch dataset that streams pre-batched examples from a tf.data pipeline.

    The tfrecord is read without deserialization, optionally cached and shuffled, then
    deserialized with the feature spec, batched, reshaped by ``_format_example`` and handed to
    PyTorch as float32 tensors. ``process_example`` is a hook that subclasses override.

    Args:
        filepath: path to the ``.tfrecord`` file.
        features_filepath: features JSON; defaults to ``filepath + '.features.json'``.
        batch_size: number of examples per yielded batch (batching happens in tf.data).
        cache: cache the serialized records after the first pass.
        shuffle: shuffle buffer size; ``None``/0 disables shuffling.
    """

    def __init__(self, filepath, features_filepath=None, batch_size=64, cache=True, shuffle=None):
        # `super(TFIterableDataset).__init__()` built an *unbound* super object and never ran the
        # parent initializer; the zero-argument form binds to `self` as intended
        super().__init__()

        # load the tfrecord file and build the tf.data pipeline
        self.dataset = self._load_dataset(filepath, features_filepath, batch_size, cache, shuffle)

    def _load_dataset(self, filepath, features_filepath=None, batch_size=64, cache=True, shuffle=None):
        """Build the tf.data pipeline; also sets ``self.features`` from the features JSON."""
        # do not deserialize yet: cache/shuffle operate on the serialized protos
        dataset = io.dataset_ops.load_tfrecord(filepath, deserialize=False)
        if cache:
            dataset = dataset.cache()
        if shuffle:
            dataset = dataset.shuffle(shuffle)

        # deserialize proto to example
        if features_filepath is None:
            features_filepath = filepath + '.features.json'
        self.features = io.dataset_ops.features_from_json_file(features_filepath)
        dataset = io.dataset_ops.deserialize_dataset(dataset, self.features)

        # batch & prefetch
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        # reformat batches in parallel & prefetch
        dataset = dataset.map(self._format_example, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

    def _format_example(self, example):
        """Rename keys and swap the last two axes of each tensor.

        Returns:
            ``{'inputs': {'sequence': ...}, 'outputs': {'total': ..., 'control': ...}}``.
        """
        # perm=[0, 2, 1] assumes rank-3 batched tensors: (batch, a, b) -> (batch, b, a),
        # moving the channel dim from -1 to -2
        def swap_last_two(tensor):
            return tf.transpose(tensor, perm=[0, 2, 1])

        signal = example['outputs']['signal']
        return {
            'inputs': {
                'sequence': swap_last_two(example['inputs']['input']),
            },
            'outputs': {
                'total': swap_last_two(signal['total']),
                'control': swap_last_two(signal['control']),
            },
        }

    def process_example(self, example):
        """Hook for subclasses: transform one formatted batch (numpy leaves) before tensor conversion."""
        return example

    def _to_pytorch_compatible(self, example):
        # convert every leaf of the nested structure to a float32 torch tensor
        return tf.nest.map_structure(lambda x: torch.tensor(x).to(torch.float32), example)

    def __iter__(self):
        # NOTE(review): no per-worker sharding here — with DataLoader(num_workers > 1) each worker
        # would replay the whole file; batches are already formed in tf.data, so presumably this is
        # meant to be wrapped with DataLoader(..., batch_size=None) — TODO confirm
        for example in self.dataset.as_numpy_iterator():
            processed_pytorch_example = self._to_pytorch_compatible(self.process_example(example))
            yield processed_pytorch_example['inputs'], processed_pytorch_example['outputs']

In [None]:
class MaskedTFIterableDataset(TFIterableDataset):
    """TFIterableDataset that keeps only the experiments selected by one or more boolean masks.

    Args:
        masks: optional sequence of boolean tensors; they are AND-ed into ``self.composite_mask``.
            ``None`` or an empty sequence disables masking.
        **kwargs: forwarded to ``TFIterableDataset``.
    """

    def __init__(self, masks=None, **kwargs):
        super(MaskedTFIterableDataset, self).__init__(**kwargs)
        self.composite_mask = None
        # an empty sequence used to crash in _make_composite_mask (masks[0]); treat it like None
        if masks is not None and len(masks) > 0:
            self.composite_mask = self._make_composite_mask(masks)

    def _make_composite_mask(self, masks):
        """Element-wise AND of all masks (a single mask is returned unchanged)."""
        masks = list(masks)
        composite_mask = masks[0]
        for mask in masks[1:]:
            composite_mask = torch.logical_and(composite_mask, mask)
        return composite_mask

    def mask_structure(self, structure, mask):
        """Apply ``mask`` to the last axis of every leaf in ``structure``."""
        # NOTE(review): indexes axis 2 of each leaf; `_format_example` swaps the last two axes of the
        # outputs, so confirm that axis 2 is the experiment axis and not the position axis
        return tf.nest.map_structure(lambda tensor: tensor[:, :, mask], structure)

    def process_example(self, example):
        if self.composite_mask is not None:
            example['outputs'] = self.mask_structure(example['outputs'], self.composite_mask)
        return example

In [None]:
class MeanESMEmbeddingMaskedTFIterableDataset(MaskedTFIterableDataset):
    """Masked dataset that also supplies a protein embedding matrix as ``inputs['embedding']``.

    Args:
        embedding_matrix_filepath: file read with ``torch.load``; presumably row ``i`` is the mean
            ESM embedding of experiment ``i`` (TODO confirm). ``torch.load`` unpickles, so only
            point it at trusted files.
        masks: forwarded to ``MaskedTFIterableDataset``; the same composite mask selects rows.
        **kwargs: forwarded to ``TFIterableDataset``.
    """

    def __init__(self, embedding_matrix_filepath, masks=None, **kwargs):
        super(MeanESMEmbeddingMaskedTFIterableDataset, self).__init__(masks, **kwargs)
        self.embedding_matrix = torch.load(embedding_matrix_filepath)

    def process_example(self, example):
        embedding = self.embedding_matrix
        if self.composite_mask is not None:
            # keep the rows of the selected experiments so they line up with the masked outputs
            embedding = embedding[self.composite_mask]
        # hand over numpy like every other leaf: `_to_pytorch_compatible` calls torch.tensor() on
        # each leaf, which emits a copy-construct warning when given an existing tensor
        if isinstance(embedding, torch.Tensor):
            embedding = embedding.detach().cpu().numpy()
        example['inputs']['embedding'] = embedding
        # output masking is the parent's job; reuse it instead of duplicating the logic
        return super().process_example(example)

In [None]:
tfrecord_filepath = '../example/data.matrix/windows.chr13.4.data.matrix.filtered.tfrecord'
esm_filepath = '../example/esm2_t33_650M_UR50D.ENCODE.idx2mean.pt'
mask_HepG2 = torch.load('experiment-mask.HepG2.pt')

# HepG2-only outputs plus the matching rows of the ESM embedding matrix
esm_masked_dataset = MeanESMEmbeddingMaskedTFIterableDataset(
    embedding_matrix_filepath=esm_filepath,
    masks=[mask_HepG2],
    filepath=tfrecord_filepath,
)
example = next(iter(esm_masked_dataset))
print(example[0].keys())
print(example[0]['embedding'].shape)
print(example[1]['total'].shape)

In [None]:
mask_HepG2 = torch.load('experiment-mask.HepG2.pt')

# outputs restricted to HepG2 experiments, no protein embedding
masked_dataset = MaskedTFIterableDataset(
    masks=[mask_HepG2],
    filepath='../example/data.matrix/windows.chr13.4.data.matrix.filtered.tfrecord',
)
example = next(iter(masked_dataset))
example[1]['total'].shape

In [None]:
# unmasked baseline: no experiment is filtered out
example = next(iter(
    TFIterableDataset('../example/data.matrix/windows.chr13.4.data.matrix.filtered.tfrecord')
))
example[1]['control'].shape

In [None]:
# outputs dict of the current batch ('total' and 'control' tensors)
example[1]

In [None]:
# torch.tensor('abc'.encode('UTF-8'), dtype=torch.uint8)

In [None]:
# int.from_bytes('abc'.encode('UTF-8'), byteorder='big')

In [None]:
# presumably one mean ESM embedding per experiment (file name: idx2mean) — TODO confirm;
# torch.load unpickles the file, so only load trusted files
idx2esm = torch.load('../example/esm2_t33_650M_UR50D.ENCODE.idx2mean.pt')
idx2esm.shape

In [None]:
# rows of the embedding matrix selected by the HepG2 mask
idx2esm[mask_HepG2].shape

In [None]:
# full (unmasked) embedding matrix shape, for comparison with the HepG2 subset above
idx2esm.shape

In [None]:
# length of the saved 1-D HepG2 mask (one entry per experiment line)
torch.load('experiment-mask.HepG2.pt').shape[0]

In [None]:
# example[1] is the outputs dict ('total', 'control') yielded by TFIterableDataset and has no
# .shape; use the 'total' signal tensor as the target for the metric check below
y = example[1]['total']
y.shape

In [None]:
import torchmetrics

class BatchedPCC(torchmetrics.MeanMetric):
    """Running mean of Pearson correlations between prediction and target, computed per batch.

    For every batch element ``i`` the PCC of ``y_pred[i]`` vs ``y[i]`` is computed with
    ``torchmetrics.functional.pearson_corrcoef`` (one value per column). Assumes ``y`` is
    (batch_size, positions, experiments) — TODO confirm. An entry contributes only if its PCC
    is not NaN AND it passes every enabled filter; the batch value is the mean over the
    contributing entries (NaN if none contribute).

    Args:
        min_height: if not None, require ``max(y)`` over positions >= this value.
        min_count: if not None, require ``sum(y)`` over positions >= this value.
    """

    def __init__(self, min_height=2, min_count=None):
        super(BatchedPCC, self).__init__()

        self.min_height = min_height
        self.min_count = min_count

    def update(self, y_pred: torch.Tensor, y: torch.Tensor):
        if y_pred.shape != y.shape:
            # was a plain string, so the shapes were never interpolated
            raise ValueError(f'shapes y_pred {y_pred.shape} and y {y.shape} are not the same. ')

        mean_pcc = self._compute_mean_pcc(y_pred, y)

        # update
        super().update(mean_pcc)

    def _compute_mean_pcc(self, y_pred: torch.Tensor, y: torch.Tensor):
        # (batch_size, experiments): one PCC per column of each batch element
        values = torch.stack([
            torchmetrics.functional.pearson_corrcoef(y_pred[i], y[i]) for i in range(y.shape[0])
        ])

        # an entry passes only if ALL checks hold; the previous `sum(...) > 0` was a logical OR,
        # contrary to its own comment, and let NaN (then zeroed) entries into the mean
        passed = torch.logical_not(torch.isnan(values))
        if self.min_height is not None:
            # required peak height per experiment, shape (batch_size, experiments)
            passed = torch.logical_and(passed, torch.max(y, dim=-2).values >= self.min_height)
        if self.min_count is not None:
            # required total count per experiment, shape (batch_size, experiments)
            passed = torch.logical_and(passed, torch.sum(y, dim=-2) >= self.min_count)

        # mean over passing entries only; an empty selection yields NaN, as 0/0 did before
        return values[passed].mean()


# random predictions with the same shape as the target: expect a mean PCC near 0
m = BatchedPCC(min_height=2, min_count=5)
m(torch.rand_like(y), y)

In [None]:
# max over the last axis -> shape (2, 2)
torch.rand(2, 2, 3).max(dim=-1).values

In [None]:
torch.ones((2, 3))

In [None]:
# `example` was last rebound to a batch without protein embeddings, so ['embedding'] raised a
# KeyError; draw a fresh batch from the ESM-augmented dataset, which the cells below also use
example = next(iter(esm_masked_dataset))
example[0]['embedding'].shape

In [None]:
# project each protein embedding from 1280 to 256 features
torch.nn.Linear(1280, 256)(example[0]['embedding']).shape

In [None]:
import torch.nn as nn

from torchrbpnet.layers import Conv1DFirstLayer, Conv1DResBlock, LinearProjection

class ProteinEmbeddingMultiRBPNet(nn.Module):
    """Score every RNA position against every protein via a dot product of learned projections.

    The RNA input goes through a conv body (first layer plus dilated residual blocks), and each
    position is projected to ``projection_dim``. Protein embeddings are projected to the same
    size, and their matrix product gives one output track per protein.

    Args:
        n_tasks: stored as ``self.n_tasks``; the number of output tracks actually follows the
            number of protein embeddings passed to ``forward``.
        n_layers: number of residual blocks; block ``i`` uses dilation ``2**i``.
        n_body_filters: channels of the conv body.
        embedding_dim: size of each input protein embedding (default 1280, as before).
        projection_dim: size of the shared RNA/protein projection space (default 256, as before).
    """

    def __init__(self, n_tasks, n_layers=9, n_body_filters=256, embedding_dim=1280, projection_dim=256):
        super(ProteinEmbeddingMultiRBPNet, self).__init__()

        self.n_tasks = n_tasks

        # layers RNA (module order unchanged, so state_dict keys body.0, body.1, ... are unchanged)
        self.body = nn.Sequential(
            Conv1DFirstLayer(4, n_body_filters, 6),
            *[Conv1DResBlock(n_body_filters, n_body_filters, dilation=(2 ** i)) for i in range(n_layers)],
        )
        self.rna_projection = nn.Linear(in_features=n_body_filters, out_features=projection_dim, bias=False)

        # layers protein
        self.protein_projection = nn.Linear(in_features=embedding_dim, out_features=projection_dim, bias=False)

    def forward(self, inputs, **kwargs):
        # forward RNA: (batch_size, n_body_filters, N) per the original comments
        x_r = self.body(inputs['sequence'])
        # (batch_size, dim, N) -> (batch_size, N, dim) -> (batch_size, N, projection_dim)
        x_r = self.rna_projection(torch.transpose(x_r, dim0=-2, dim1=-1))

        # forward protein: (#proteins, embedding_dim) -> (#proteins, projection_dim)
        x_p = self.protein_projection(inputs['embedding'])

        # swap the last two axes (not dims 0/1) so a batched (batch_size, #proteins, dim) embedding
        # also works; for 2-D embeddings this is the same transpose as before
        x_p = torch.transpose(x_p, dim0=-2, dim1=-1)  # (projection_dim, #proteins)

        return torch.matmul(x_r, x_p)  # (batch_size, N, #proteins)

# n_tasks=223; output width actually follows the number of embeddings in the batch
network = ProteinEmbeddingMultiRBPNet(223)

In [None]:
# forward pass on the current inputs (forward reads the 'sequence' and 'embedding' keys)
y_pred = network(example[0])
y_pred.shape

In [None]:
# y_pred = network(example[0])
# print(y_pred[0].shape)
# print(y_pred[1].shape)

# y_pred_1_t = torch.transpose(y_pred[1], dim0=1, dim1=0)
# y_pred_1_t.shape

# y_pred_0_t = torch.transpose(y_pred[0], dim0=-2, dim1=0)
# y_pred_0_t.shape

In [None]:
# from functorch import vmap

In [None]:
# y_pred_0_t / y_pred_1_t only existed in a commented-out cell (NameError on a fresh kernel).
# Recompute the RNA and protein projections explicitly and multiply them as forward() does.
rna_repr = network.rna_projection(torch.transpose(network.body(example[0]['sequence']), -2, -1))
protein_repr = network.protein_projection(example[0]['embedding'])
torch.matmul(rna_repr, torch.transpose(protein_repr, -2, -1)).shape