In [4]:
import yaml
import os
import sys
import random
# sys.path.append('../../')
import numpy as np
from collections import Counter
from omegaconf import OmegaConf
import itertools
from selene_sdk.utils import load_path, parse_configs_and_run
from selene_sdk.utils.config_utils import module_from_dir, module_from_file
from selene_sdk.utils.config import instantiate
from src.dataset import EncodeDataset, LargeRandomSampler, encode_worker_init_fn
from src.transforms import *
from src.utils import interval_from_line
# from torchvision import transforms
# from torchmetrics import BinnedAveragePrecision, AveragePrecision, Accuracy
from tqdm import tqdm
import pandas as pd
import copy
from src.utils import expand_dims
import gc
gc.enable()

from src.metrics import jaccard_score, threshold_wrapper
from sklearn.metrics import average_precision_score
from selene_sdk.utils.performance_metrics import compute_score

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
path = 'model_configs/biox_dnase_multi_ct_crossval.yaml'
configs = load_path(path, instantiate=False)
configs['dataset']['debug'] = True
configs['dataset']['loader_args']['batch_size'] = 20

In [47]:
from src.deepct_model_multi_ct import DeepCT

model = DeepCT(**configs['model']['class_args'])

In [19]:
from selene_sdk.utils.config_utils import get_full_dataset

full_dataset = get_full_dataset(configs)

DEBUG MODE ON: 1000


In [3]:
from sklearn.model_selection import KFold, GroupKFold, StratifiedKFold

n_folds = 10
k_fold = KFold(n_folds, shuffle=True, random_state=666)

In [8]:
dataset_info = configs["dataset"]

# all intervals
genome_intervals = []
with open(dataset_info["sampling_intervals_path"])  as f:
    for line in f:
        chrom, start, end = interval_from_line(line)
        genome_intervals.append((chrom, start, end))

# bedug mode
if dataset_info['debug']:
    genome_intervals = random.sample(genome_intervals, k=1000)
    print("DEBUG MODE ON:", len(genome_intervals))

with open(dataset_info["distinct_features_path"]) as f:
    distinct_features = list(map(lambda x: x.rstrip(), f.readlines()))

# print(len(distinct_features))

with open(dataset_info["target_features_path"]) as f:
    target_features = list(map(lambda x: x.rstrip(), f.readlines()))

DEBUG MODE ON: 1000


In [9]:
# genome_intervals

In [46]:
splits = []
for train_idx, test_idx in k_fold.split(genome_intervals):
    splits.append((train_idx, test_idx))

len(splits)

10

In [17]:
ct_masks = np.array_split(range(configs['model']['class_args']['n_cell_types']), n_folds)
[len(c) for c in ct_masks]

[64, 63, 63, 63, 63, 63, 63, 63, 63, 63]

In [25]:
train_folds_idx = splits[0][0]
valid_folds_idx = splits[0][1]

len(train_folds_idx), len(valid_folds_idx)


(900, 100)

In [30]:
current_fold_idx = np.append(train_folds_idx, valid_folds_idx)

In [31]:
len(current_fold_idx)

1000

In [40]:
def create_split_loaders(configs, full_dataset, split):
    """
    Called for each split, this creates a two DataLoaders for each split. 
    One DataLoader for the samples in the training folds and one DataLoader 
    for the samples in the validation fold.
    """
    dataset_info = configs["dataset"]
    # current_fold = configs["dataset"]['dataset_args']['fold']
    # print('current fold:', current_fold)

    train_folds_idx = split[0]
    valid_folds_idx = split[1]
    current_fold_idx = np.append(train_folds_idx, valid_folds_idx)

    train_subset = torch.utils.data.Subset(
        full_dataset, 
        current_fold_idx
        )

    val_subset = torch.utils.data.Subset(
        full_dataset, 
        valid_folds_idx
        )
    val_transform = instantiate(dataset_info["val_transform"])
    val_subset.dataset.transform = val_transform

    module = None
    if os.path.isdir(dataset_info["path"]):
        module = module_from_dir(dataset_info["path"])
    else:
        module = module_from_file(dataset_info["path"])

    train_sampler_class = getattr(module, dataset_info["sampler_class"])
    gen = torch.Generator()
    gen.manual_seed(configs["random_seed"])
    train_sampler = train_sampler_class(
        train_subset, replacement=False, generator=gen
    )

    train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=dataset_info["loader_args"]["batch_size"],
        num_workers=dataset_info["loader_args"]["num_workers"],
        worker_init_fn=module.subset_encode_worker_init_fn,
        sampler=train_sampler,
    )

    val_sampler_class = getattr(module, dataset_info["validation_sampler_class"])
    gen = torch.Generator()
    gen.manual_seed(configs["random_seed"])

    val_sampler = val_sampler_class(
        data_source=val_subset, 
        num_samples=dataset_info['validation_sampler_args']['num_samples'], 
        generator=gen
    )

    val_loader = torch.utils.data.DataLoader(
            val_subset,
            batch_size=configs['dataset']["loader_args"]["batch_size"],
            num_workers=configs['dataset']["loader_args"]["num_workers"],
            worker_init_fn=module.subset_encode_worker_init_fn,
            sampler=val_sampler,
        )

    return (train_loader, val_loader) 


def get_all_split_loaders(dataset, cv_splits):
    """Create DataLoaders for each split.

    Keyword arguments:
    dataset -- Dataset to sample from.
    cv_splits -- Array containing indices of samples to 
                 be used in each fold for each split.
    aug_count -- Number of variations for each sample in dataset.
    batch_size -- batch size.
    
    """
    split_samplers = []
    
    for i in range(len(cv_splits)):
        split_samplers.append(
            create_split_loaders(
                configs,
                dataset,
                cv_splits[i]
                )
        )
    return split_samplers




In [41]:
train_loader_0, val_loader_0 = create_split_loaders(configs, full_dataset, splits[0])
len(train_loader_0), len(val_loader_0)

In [44]:
dataloaders = get_all_split_loaders(full_dataset, splits)
len(dataloaders)

10

In [45]:
dataloaders[0]

(<torch.utils.data.dataloader.DataLoader at 0x7fc79232feb8>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc791e55a20>)

In [49]:
for i, batch in tqdm(enumerate(dataloaders[0][1])):
    sequence_batch, cell_type_batch, targets, target_mask = batch
    break

0it [00:00, ?it/s]


In [60]:
target_mask.sum()

tensor(12620)

In [57]:
# torch.masked_select(target_mask, torch.tensor(ct_masks[0]))

target_mask_tr = target_mask.clone()
target_mask_tr[:, ct_masks[0].min(): ct_masks[0].max()+1] = False

In [59]:
target_mask_tr.sum()

tensor(11340)

In [64]:
def train(model, batch, fold):
    """
    Trains the model on a batch of data.

    Returns
    -------
    float
        The training loss.

    """
    model.train()
    
    # retrieved_seq, cell_type, target, target_mask
    sequence_batch = batch[0]#.to(device)
    cell_type_batch = batch[1]#.to(device)
    targets = batch[2]#.to(device)
    target_mask = batch[3]#.to(device)

    # make train mask
    target_mask_tr = target_mask.clone()
    target_mask_tr[:, ct_masks[fold].min(): ct_masks[fold].max()+1] = False

    outputs = model(sequence_batch, cell_type_batch)

    criterion.weight = target_mask_tr
    loss = criterion(outputs, targets)
    if criterion.reduction == "sum":
        loss = loss / criterion.weight.sum()
    predictions = torch.sigmoid(outputs)

    # predictions = predictions * target_mask_tr
    # targets = targets * target_mask_tr

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return (
        predictions.detach().numpy(),
        targets.detach().numpy(),
        target_mask_tr.numpy(),
        loss.item(),
    )

def evaluate(model, batch, target_mask_tr):
    """
    Makes predictions for some labeled input data.

    Parameters
    ----------
    data_in_batches : list(SamplesBatch)
        A list of tuples of the data, where the first element is
        the example, and the second element is the label.

    Returns
    -------
    tuple(float, list(numpy.ndarray))
        Returns the average loss, and the list of all predictions.

    """
    model.eval()

    batch_losses = []
    all_predictions = []
    all_targets = []
    all_target_masks = []

    sequence_batch = batch[0]#.to(device)
    cell_type_batch = batch[1]#.to(device)
    targets = batch[2]#.to(device)
    target_mask = batch[3]#.to(device)
    # print('targets', targets.shape)

    # val mask
    target_mask_val = target_mask.clone()
    target_mask_val = ~target_mask_tr

    if target_mask_val.shape[0] != targets.shape[0]:
        target_mask_val = target_mask_val[:targets.shape[0], ...]

    with torch.no_grad():
        outputs = model(sequence_batch, cell_type_batch)

        criterion.weight = target_mask_val
        loss = criterion(outputs, targets)
        if criterion.reduction == "sum":
            loss = loss / criterion.weight.sum()

        predictions = torch.sigmoid(outputs)
        predictions = predictions.view(-1, predictions.shape[-1])
        targets = targets.view(-1, targets.shape[-1])

        target_mask = target_mask_val.view(-1, target_mask_val.shape[-1])

        all_predictions.append(predictions.data.numpy())
        all_targets.append(targets.data.numpy())
        all_target_masks.append(target_mask.data.numpy())
        batch_losses.append(loss.item())

    all_predictions = expand_dims(np.concatenate(all_predictions))
    all_targets = expand_dims(np.concatenate(all_targets))
    all_target_masks = expand_dims(np.concatenate(all_target_masks))

    return np.average(batch_losses), all_predictions, all_targets, all_target_masks


In [48]:
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
optimizer = torch.optim.Adam(params=model.parameters(), lr = 0.0001, weight_decay = 1e-6)

In [65]:
# train_batch_loader -- batches all samples in training folds.
# valid_batch_loader -- batches all samples in validation fold.
for fold, (train_batch_loader, valid_batch_loader) in enumerate(dataloaders):
    # Loop through all batches in training folds for a given split.
    model.train()
    for batch in train_batch_loader:
        # Train model on the training folds in the split.
        prediction, target, target_mask, loss = train(model, batch, fold)

        break
    
    # Loop through all batches in validation fold for a given split.
    model.eval()
    for batch in valid_batch_loader:
        # Test model on the validation fold in the split.
        (
            average_loss,
            all_predictions,
            all_targets,
            all_target_masks,
        ) = evaluate(model, batch, target_mask_tr)
        print(average_loss)
        break
    break

0.6919394135475159


In [7]:
n_folds = 5
ct_mask_range = np.array_split(range(configs['model']['class_args']['n_cell_types']), n_folds)
mask_iterator = itertools.cycle(ct_mask_range)

val_mask_idx = next(mask_iterator)
val_mask_idx

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126])

DeepCT(
  (conv_net): Sequential(
    (0): Conv1d(4, 320, kernel_size=(8,), stride=(1,))
    (1): ReLU(inplace=True)
    (2): Conv1d(320, 320, kernel_size=(8,), stride=(1,))
    (3): ReLU(inplace=True)
    (4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (5): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv1d(320, 480, kernel_size=(8,), stride=(1,))
    (7): ReLU(inplace=True)
    (8): Conv1d(480, 480, kernel_size=(8,), stride=(1,))
    (9): ReLU(inplace=True)
    (10): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    (11): BatchNorm1d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Dropout(p=0.2, inplace=False)
    (13): Conv1d(480, 960, kernel_size=(8,), stride=(1,))
    (14): ReLU(inplace=True)
    (15): Conv1d(960, 960, kernel_size=(8,), stride=(1,))
    (16): ReLU(inplace=True)
    (17): BatchNorm1d(960, eps=1e-05, momentum=0.1, affine=True,

In [9]:
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
optimizer = torch.optim.Adam(params=model.parameters(), lr = 0.0001, weight_decay = 1e-6)

In [10]:
def create_random_samples(t):
    """
    """
    ids = torch.multinomial(
        input=torch.ones(t.shape[0]).flatten(),
        num_samples=5,
        replacement=False,
    )
    return ids 

In [11]:
map_baseline_list = []
map_model_list = []

predictions_list = []
baselines = []
targets_list = []
val_masks = []


mask_iterator = itertools.cycle(ct_mask_range)

for i, batch in tqdm(enumerate(full_dataloader)):
    sequence_batch, cell_type_batch, targets, target_mask = batch

    # new masks
    val_mask_idx = next(mask_iterator)
    target_mask_tr = target_mask.clone()
    target_mask_tr[:, val_mask_idx[0]: val_mask_idx[-1]+1] = False
    target_mask_val = ~target_mask_tr

    # compute baseline (mean feature on train step)
    mean_seq_tr = (targets * target_mask_tr).sum(axis=1) / target_mask_tr.sum(axis=1)
    # print('mean_seq_tr', mean_seq_tr)
    mean_seq_batch = torch.repeat_interleave(mean_seq_tr.unsqueeze(1), 631, dim=1)
    # print(mean_seq_batch.shape)

    # model
    logits = model(sequence_batch, cell_type_batch)
    predictions = torch.sigmoid(logits)

    # для обучения используем target_mask_tr
    criterion.weight = target_mask_tr
    loss = criterion(predictions, targets.float())
    loss = loss / criterion.weight.sum()
    # print("train loss:", loss.item())

    # compute baseline score on val step
    map_baseline, ap_baseline = compute_score(
        mean_seq_batch.detach().numpy(), 
        targets.detach().numpy(), 
        average_precision_score, 
        target_mask=target_mask_val.detach().numpy(),
        )
    # compute model score on val step
    map_model, ap_model = compute_score(
        predictions.detach().numpy(), 
        targets.detach().numpy(), 
        average_precision_score, 
        target_mask=target_mask_val.detach().numpy(),
        )

    print(map_baseline, map_model)

    map_baseline_list.append(map_baseline)
    map_model_list.append(map_model)

    # accum preds and targets
    # ids = create_random_samples(targets)
    # targets_list.append(targets[ids, :])
    # predictions_list.append(predictions[ids, :])
    # val_masks.append(target_mask_val[ids, :])
    # baselines.append(mean_seq_batch[ids, :])

    if i > 3:
        break


1it [00:02,  2.21s/it]

0.4791359413406657 0.07153907834551546


2it [00:04,  2.42s/it]

0.2166265392455869 0.08058089365585772


3it [00:06,  2.23s/it]

0.09720248040736762 0.029146953564666213


4it [00:09,  2.44s/it]

0.13918852726275693 0.03524491458875838


4it [00:11,  2.92s/it]

0.8111756077086069 0.2407245478068364





In [206]:
# на полном батче
np.mean(map_baseline_list), np.mean(map_model_list)

(0.3070952526525577, 0.05648780249548171)

In [225]:
45220/3000

15.073333333333334

In [208]:
targets_cat = torch.cat(targets_list, dim=0)
predictions_cat = torch.cat(predictions_list, dim=0)
val_masks_cat = torch.cat(val_masks, dim=0)
baselines_cat = torch.cat(baselines, dim=0)

# compute metrics
map_baseline, ap_baseline = compute_score(
    baselines_cat.detach().numpy(), 
    targets_cat.detach().numpy(), 
    average_precision_score, 
    target_mask=val_masks_cat.detach().numpy(),
    )

map_model, ap_model = compute_score(
    predictions_cat.detach().numpy(), 
    targets_cat.detach().numpy(), 
    average_precision_score, 
    target_mask=val_masks_cat.detach().numpy(),
    )    

map_baseline, map_model

(0.07229601151752332, 0.03363876850943481)

In [212]:
val_masks_cat[0].sum(), val_masks_cat[1].sum()

(tensor(127), tensor(127))

In [142]:
indices = torch.multinomial(
        input=torch.ones(targets.shape[0]).flatten(),
        num_samples=5,
        replacement=False,
    )

indices

tensor([0, 1, 9, 3, 5])

In [144]:
targets[indices, :].shape

torch.Size([5, 631, 1])

In [194]:
# на случ подвыборке

map_baseline_list = []
map_model_list = []

predictions_list = []
baselines = []
targets_list = []
val_masks = []

mask_iterator = itertools.cycle(ct_mask_range)

for i, batch in tqdm(enumerate(full_dataloader)):
    sequence_batch, cell_type_batch, targets, target_mask = batch

    # new masks
    val_mask_idx = next(mask_iterator)
    target_mask_tr = target_mask.clone()
    target_mask_tr[:, val_mask_idx[0]: val_mask_idx[-1]+1] = False
    target_mask_val = ~target_mask_tr

    # compute baseline
    mean_seq_tr = (targets * target_mask_tr).sum(axis=1) / target_mask_tr.sum(axis=1)
    # print('mean_seq_tr', mean_seq_tr)
    mean_seq_batch = torch.repeat_interleave(mean_seq_tr.unsqueeze(1), 631, dim=1)
    # print(mean_seq_batch.shape)

    # model
    logits = model(sequence_batch, cell_type_batch)
    predictions = torch.sigmoid(logits)

    criterion.weight = target_mask_tr
    loss = criterion(predictions, targets.float())
    loss = loss / criterion.weight.sum()
    # print("train loss:", loss.item())

    # compute metrics
    map_baseline, ap_baseline = compute_score(
        mean_seq_batch.detach().numpy(), 
        targets.detach().numpy(), 
        average_precision_score, 
        target_mask=target_mask_val.detach().numpy(),
        )

    map_model, ap_model = compute_score(
        predictions.detach().numpy(), 
        targets.detach().numpy(), 
        average_precision_score, 
        target_mask=target_mask_val.detach().numpy(),
        )    

    print(map_baseline, map_model)
    
    map_baseline_list.append(map_baseline)
    map_model_list.append(map_model)

    # accum preds and targets
    ids = create_random_samples(targets)
    targets_list.append(targets[ids, :])
    predictions_list.append(predictions[ids, :])
    val_masks.append(target_mask_val[ids, :])
    baselines.append(mean_seq_batch[ids, :])

    if i > 3:
        break

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fe7eda426a0>>Exception ignored in: 
<bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fe7eda426a0>>Traceback (most recent call last):
  File "/home/thurs/genv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__

Traceback (most recent call last):
      File "/home/thurs/genv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()self._shutdown_workers()

  File "/home/thurs/genv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
  File "/home/thurs/genv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():    if w.is_alive():
  File "/usr/lib/python3.6/multiprocessing/proces

0.03430623803603497 0.02881303427174626


2it [00:06,  3.53s/it]

None None
0.433577806122449 0.05816887400500888


4it [00:13,  3.50s/it]

0.008272707231040566 0.008317516752354277


4it [00:17,  4.36s/it]

0.07802232003912675 0.03296399140914681





In [195]:
# на полном батче
np.mean(map_baseline_list), np.mean(map_model_list)

TypeError: unsupported operand type(s) for +: 'float' and 'NoneType'

(0.14884606016357732, 0.02611064614409038)

In [189]:
map_model_list

[None,
 0.01994768862137945,
 0.034396911356918525,
 0.02605354163114649,
 0.04695420318794074]