In [1]:
import yaml
import os
import sys
import random
# sys.path.append('../../')
import numpy as np
from collections import Counter
from omegaconf import OmegaConf
import itertools
from selene_sdk.utils import load_path, parse_configs_and_run
from selene_sdk.utils.config_utils import module_from_dir, module_from_file
from selene_sdk.utils.config import instantiate
from src.dataset import EncodeDataset, LargeRandomSampler, SubsetRandomSampler, encode_worker_init_fn
from src.transforms import *
from src.utils import interval_from_line
# from torchvision import transforms
# from torchmetrics import BinnedAveragePrecision, AveragePrecision, Accuracy
from tqdm import tqdm
import pandas as pd
import copy
from src.utils import expand_dims
import gc
gc.enable()

from src.metrics import jaccard_score, threshold_wrapper
from sklearn.metrics import average_precision_score
from selene_sdk.utils.performance_metrics import compute_score

%load_ext autoreload
%autoreload 2

In [2]:
ct_means = np.load('results/ct_mean_targets_02.npy', allow_pickle=True)
ct_means.shape

(631,)

In [3]:
path = 'model_configs/biox_dnase_multi_ct_crossval_shuffle_loaders_fold_00.yaml'
configs = load_path(path, instantiate=False)
# configs['dataset']['path'] = 'src/bad_dataset.py'
configs['dataset']['debug'] = True
configs['dataset']['loader_args']['batch_size'] = 20

In [3]:
# from src.deepct_model_multi_ct import DeepCT

# model = DeepCT(**configs['model']['class_args'])

In [4]:
from selene_sdk.utils.config_utils import get_full_dataset, get_full_dataloader

full_dataset = get_full_dataset(configs)
# full_dataloader = get_full_dataloader(configs)

DEBUG MODE ON: 1000


In [5]:
dataset_info = configs["dataset"]

# all intervals
genome_intervals = []
with open(dataset_info["sampling_intervals_path"])  as f:
    for line in f:
        chrom, start, end = interval_from_line(line)
        if chrom not in dataset_info["test_holdout"]:
            genome_intervals.append((chrom, start, end))

# bedug mode
if dataset_info['debug']:
    genome_intervals = random.sample(genome_intervals, k=1000)
    print("DEBUG MODE ON:", len(genome_intervals))

print(len(genome_intervals))  # 1248877 vs 1377454

with open(dataset_info["distinct_features_path"]) as f:
    distinct_features = list(map(lambda x: x.rstrip(), f.readlines()))

with open(dataset_info["target_features_path"]) as f:
    target_features = list(map(lambda x: x.rstrip(), f.readlines()))


random.seed(666)

genome_intervals_arr = np.asarray(genome_intervals, dtype='U10,i8,i8')
random.shuffle(genome_intervals_arr)
seq_splits = np.array_split(genome_intervals_arr, 10)
len(seq_splits)

DEBUG MODE ON: 1000
1000


10

In [6]:
# kfold_intervals = []
# for train_intervals in seq_splits:
#     val_size = int(len(train_intervals)*0.2)
#     random.seed(666)
#     val_intervals = random.sample(train_intervals.tolist(), val_size)
#     kfold_intervals.append((train_intervals, val_intervals))

In [7]:
# kfold_intervals[0]

In [6]:
ct_list = list(range(configs['model']['class_args']['n_cell_types'])) 
ct_masks = []
for fold in range(10):
    random.seed(666)
    random.shuffle(ct_list)
    ct_masks.append(np.array_split(ct_list, 10))

In [9]:
len(ct_masks[0])

10

In [10]:
sum([len(c) for c in ct_masks[0]])

631

In [15]:
len(ct_masks[0])

10

In [7]:
ct_masks = np.load(configs['dataset']['ct_fold_ids'], allow_pickle=True)
# набор масок для текущей модели
curr_fold = configs['dataset']['dataset_args']['fold']
ct_masks = ct_masks[curr_fold]
print('# cell_type folds:', len(ct_masks))

# cell_type folds: 10


In [39]:
[len(c) for c in ct_masks]

[64, 63, 63, 63, 63, 63, 63, 63, 63, 63]

In [8]:
from selene_sdk.utils.config_utils import get_all_split_loaders


splits = np.load(configs['dataset']['seq_fold_ids'], allow_pickle=True)
full_dataset = get_full_dataset(configs)
print('full_dataset len:', len(full_dataset))
dataloaders = get_all_split_loaders(configs, full_dataset, splits)
print('# dataloaders:', len(dataloaders))
print('train/val lens:', len(dataloaders[0][0]), len(dataloaders[0][1]))

DEBUG MODE ON: 1000
full_dataset len: 7885
# dataloaders: 10
train/val lens: 65366 13324


In [35]:
dataloaders[0][0].dataset.transform, dataloaders[0][1].dataset.transform

(Compose(
     PermuteSequenceChannels()
     RandomReverseStrand()
 ),
 PermuteSequenceChannels())

In [45]:
dataloaders

[(<torch.utils.data.dataloader.DataLoader at 0x7f76c3673710>,
  <torch.utils.data.dataloader.DataLoader at 0x7f7758359438>),
 (<torch.utils.data.dataloader.DataLoader at 0x7f76b841fd68>,
  <torch.utils.data.dataloader.DataLoader at 0x7f76b73b6630>),
 (<torch.utils.data.dataloader.DataLoader at 0x7f76c05324a8>,
  <torch.utils.data.dataloader.DataLoader at 0x7f76bc29ccc0>),
 (<torch.utils.data.dataloader.DataLoader at 0x7f76bf9cbb38>,
  <torch.utils.data.dataloader.DataLoader at 0x7f76ba9d3400>),
 (<torch.utils.data.dataloader.DataLoader at 0x7f76bfc4d240>,
  <torch.utils.data.dataloader.DataLoader at 0x7f76b80ffb00>),
 (<torch.utils.data.dataloader.DataLoader at 0x7f76bdc55978>,
  <torch.utils.data.dataloader.DataLoader at 0x7f76ba447278>),
 (<torch.utils.data.dataloader.DataLoader at 0x7f76bb3aa0b8>,
  <torch.utils.data.dataloader.DataLoader at 0x7f76b8097940>),
 (<torch.utils.data.dataloader.DataLoader at 0x7f76c36c2780>,
  <torch.utils.data.dataloader.DataLoader at 0x7f76b711e080>),


In [9]:
fold_map = {}

for i in range(len(dataloaders)):
    fold_map[dataloaders[i]] = ct_masks[i]

In [119]:
len(fold_map.keys())

10

In [120]:
random.seed(666)
l = list(fold_map.items())
for chunk, ((train_batch_loader, valid_batch_loader), current_ct_mask) in enumerate(dict(random.sample(l, len(fold_map.keys()))).items()):
    print(chunk, current_ct_mask[:5])

0 [  8 583 216 575 465]
1 [387 623 587 278 480]
2 [106 315 630 515 310]
3 [307 229  74 591  55]
4 [423 202 442 435 403]
5 [ 40  37 139 298 206]
6 [ 84 246 397 574 260]
7 [ 60 239 485 449  80]
8 [168 238 402 497 407]
9 [195 143 588 100 288]


In [118]:
random.seed(666)
l = list(fold_map.items())
for chunk, ((train_batch_loader, valid_batch_loader), current_ct_mask) in enumerate(dict(random.sample(l, 10)).items()):
    if chunk >= 7:
        print(chunk, current_ct_mask[:5])

7 [ 60 239 485 449  80]
8 [168 238 402 497 407]
9 [195 143 588 100 288]


In [4]:
import math

((1/4) * (2/3)**3)/((1/4) * (2/3)**3 + (3/4) * (1/3)**3)

0.7272727272727272

In [124]:
# main train/val loop
# random.seed(666)
for epoch in tqdm(range(0, 10)):
    print('epoch:', epoch)
    # shuffle loaders
    # random.seed(666)

    random.seed(666+epoch)
    l = list(fold_map.items())
    for chunk, ((train_batch_loader, valid_batch_loader), current_ct_mask) in enumerate(dict(random.sample(l, len(fold_map.keys()))).items()):
        print(chunk, current_ct_mask[:5])
        
    print('--------------')
    # if epoch == 2:
    #     break


100%|██████████| 10/10 [00:00<00:00, 318.06it/s]

epoch: 0
0 [  8 583 216 575 465]
1 [387 623 587 278 480]
2 [106 315 630 515 310]
3 [307 229  74 591  55]
4 [423 202 442 435 403]
5 [ 40  37 139 298 206]
6 [ 84 246 397 574 260]
7 [ 60 239 485 449  80]
8 [168 238 402 497 407]
9 [195 143 588 100 288]
--------------
epoch: 1
0 [195 143 588 100 288]
1 [106 315 630 515 310]
2 [387 623 587 278 480]
3 [ 60 239 485 449  80]
4 [168 238 402 497 407]
5 [307 229  74 591  55]
6 [423 202 442 435 403]
7 [ 84 246 397 574 260]
8 [ 40  37 139 298 206]
9 [  8 583 216 575 465]
--------------
epoch: 2
0 [  8 583 216 575 465]
1 [195 143 588 100 288]
2 [ 40  37 139 298 206]
3 [168 238 402 497 407]
4 [106 315 630 515 310]
5 [ 84 246 397 574 260]
6 [ 60 239 485 449  80]
7 [307 229  74 591  55]
8 [423 202 442 435 403]
9 [387 623 587 278 480]
--------------
epoch: 3
0 [423 202 442 435 403]
1 [106 315 630 515 310]
2 [  8 583 216 575 465]
3 [ 60 239 485 449  80]
4 [195 143 588 100 288]
5 [307 229  74 591  55]
6 [ 84 246 397 574 260]
7 [387 623 587 278 480]
8 [ 40 




In [24]:
# main train/val loop
for epoch in tqdm(range(1, 10)):
    print('epoch:', epoch)
    # shuffle loaders
    random.seed(666)
    l = list(fold_map.items())
    random.shuffle(l)
    fold_map = dict(l)

    for chunk, ((train_batch_loader, valid_batch_loader), current_ct_mask) in enumerate(fold_map.items()):

        print(current_ct_mask[:5])
    print('--------------')

100%|██████████| 9/9 [00:00<00:00, 421.09it/s]

epoch: 1
[168 238 402 497 407]
[195 143 588 100 288]
[ 40  37 139 298 206]
[106 315 630 515 310]
[307 229  74 591  55]
[ 84 246 397 574 260]
[423 202 442 435 403]
[387 623 587 278 480]
[ 60 239 485 449  80]
[  8 583 216 575 465]
--------------
epoch: 2
[195 143 588 100 288]
[106 315 630 515 310]
[ 84 246 397 574 260]
[  8 583 216 575 465]
[168 238 402 497 407]
[307 229  74 591  55]
[ 40  37 139 298 206]
[ 60 239 485 449  80]
[423 202 442 435 403]
[387 623 587 278 480]
--------------
epoch: 3
[106 315 630 515 310]
[  8 583 216 575 465]
[307 229  74 591  55]
[387 623 587 278 480]
[195 143 588 100 288]
[168 238 402 497 407]
[ 84 246 397 574 260]
[423 202 442 435 403]
[ 40  37 139 298 206]
[ 60 239 485 449  80]
--------------
epoch: 4
[  8 583 216 575 465]
[387 623 587 278 480]
[168 238 402 497 407]
[ 60 239 485 449  80]
[106 315 630 515 310]
[195 143 588 100 288]
[307 229  74 591  55]
[ 40  37 139 298 206]
[ 84 246 397 574 260]
[423 202 442 435 403]
--------------
epoch: 5
[387 623 587 27




In [11]:
for (train_batch_loader, valid_batch_loader), ct_mask in fold_map.items():
    print(len(train_batch_loader), len(valid_batch_loader), len(ct_mask))

55055 11050 63
50855 10269 63
50502 10047 63
52922 10516 63
65366 13324 64
50500 10030 63
53064 10540 63
51413 10316 63
51438 10423 63
51486 10332 63


In [12]:
ct_mask

array([307, 229,  74, 591,  55,  13, 378, 134, 127, 176, 492, 562, 600,
       599, 153, 446, 363, 528, 616, 162, 113, 251, 326, 245, 450, 137,
         5, 241, 490,  65,  54, 595, 149, 205,  70, 286,  90, 496, 475,
       453, 265,  33, 208, 211, 356, 214, 316, 366, 474, 438, 498, 484,
       421, 531, 441, 148, 339, 170, 361, 384, 370, 199, 417])

In [13]:
for batch in valid_batch_loader:
    sequence_batch = batch[0]
    cell_type_batch = batch[1]
    targets = batch[2]
    target_mask = batch[3]

    # val mask
    target_mask_tr = target_mask.clone()
    # !!!
    target_mask_tr[:, ct_mask, :] = False
    target_mask_val = ~target_mask_tr
    break

In [14]:
target_mask.sum(), target_mask_tr.sum(), target_mask_val.sum()

(tensor(12620), tensor(11360), tensor(1260))

In [15]:
target_mask_tr.sum() + target_mask_val.sum()

tensor(12620)

10

array([('chr10', 66184220, 66184455), ('chr6',   934320,   934680),
       ('chr10', 66184220, 66184455)],
      dtype=[('f0', '<U10'), ('f1', '<i8'), ('f2', '<i8')])

In [48]:
seq_splits[0].shape[0]*0.25

31222.0

In [51]:
module = None
if os.path.isdir(dataset_info["path"]):
    module = module_from_dir(dataset_info["path"])
else:
    module = module_from_file(dataset_info["path"])

dataset_class = getattr(module, dataset_info["class"])
dataset_info["dataset_args"]["target_features"] = target_features
dataset_info["dataset_args"]["distinct_features"] = distinct_features

# load train dataset and loader
data_config = dataset_info["dataset_args"].copy()
data_config["intervals"] = seq_splits[0]#.tolist()

del data_config['fold']
del data_config['n_folds']
train_subset = dataset_class(**data_config)

train_sampler_class = getattr(module, dataset_info["sampler_class"])
gen = torch.Generator()
gen.manual_seed(configs["random_seed"])
train_sampler = train_sampler_class(
    train_subset, replacement=False, generator=gen
)

train_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=dataset_info["loader_args"]["batch_size"],
        num_workers=dataset_info["loader_args"]["num_workers"],
        worker_init_fn=module.encode_worker_init_fn,
        sampler=train_sampler,
    )

In [52]:
val_sampler_class = getattr(module, dataset_info["validation_sampler_class"])
gen = torch.Generator()
gen.manual_seed(configs["random_seed"])

val_sampler = val_sampler_class(
    data_source=train_subset, 
    num_samples=dataset_info['validation_sampler_args']['num_samples'], 
    generator=gen
)

val_loader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=configs['dataset']["loader_args"]["batch_size"],
        num_workers=configs['dataset']["loader_args"]["num_workers"],
        worker_init_fn=module.subset_encode_worker_init_fn,
        sampler=val_sampler,
    )

In [18]:
splits = []
for train_idx, test_idx in k_fold.split(genome_intervals):
    splits.append((train_idx, test_idx))

len(splits)

10

In [11]:
[len(s) for s in splits]

[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

In [12]:
len(splits[0][0]), len(splits[0][1])

(900, 100)

In [69]:
np.save(f'/home/thurs/DeepCT/results/kfold_splits_hold.npy', splits)

In [6]:
splits = np.load(f'/home/thurs/DeepCT/results/kfold_splits_hold.npy', allow_pickle=True)
len(splits)

10

In [19]:
for i in splits:
    print(i[0], i[1])

[ 124888  124889  124890 ... 1248874 1248875 1248876] [     0      1      2 ... 124885 124886 124887]
[      0       1       2 ... 1248874 1248875 1248876] [124888 124889 124890 ... 249773 249774 249775]
[      0       1       2 ... 1248874 1248875 1248876] [249776 249777 249778 ... 374661 374662 374663]
[      0       1       2 ... 1248874 1248875 1248876] [374664 374665 374666 ... 499549 499550 499551]
[      0       1       2 ... 1248874 1248875 1248876] [499552 499553 499554 ... 624437 624438 624439]
[      0       1       2 ... 1248874 1248875 1248876] [624440 624441 624442 ... 749325 749326 749327]
[      0       1       2 ... 1248874 1248875 1248876] [749328 749329 749330 ... 874213 874214 874215]
[      0       1       2 ... 1248874 1248875 1248876] [874216 874217 874218 ... 999100 999101 999102]
[      0       1       2 ... 1248874 1248875 1248876] [ 999103  999104  999105 ... 1123987 1123988 1123989]
[      0       1       2 ... 1123987 1123988 1123989] [1123990 1123991 11239

In [75]:
splits[1][0], splits[1][1]

(array([      0,       1,       2, ..., 1248873, 1248875, 1248876]),
 array([     10,      40,      50, ..., 1248846, 1248851, 1248874]))

In [74]:
random.seed(666)

train_folds_idx = splits[0][0]
valid_folds_idx = splits[0][1]
current_fold_idx = np.append(train_folds_idx, valid_folds_idx)
random.shuffle(current_fold_idx)
current_fold_idx


array([1140471,  732277,  782472, ..., 1012061,  878426, 1062745])

In [4]:
import numpy as np
import random
random.seed(666)

n_folds = configs['dataset']['dataset_args']['n_folds']
ct_list = list(range(configs['model']['class_args']['n_cell_types'])) 
ct_masks = []

for fold in range(n_folds):
    random.shuffle(ct_list)
    # print(ct_list)
    ct_masks.append(np.array_split(ct_list, n_folds))

print([len(c) for c in ct_masks])

[10, 10, 10, 10, 10, 10, 10, 10, 10, 10]


In [10]:
ct_masks[0][0], ct_masks[0][1]

array([ 60, 239, 485, 449,  80, 455, 138,  85, 183, 328, 257,  16,  71,
        32,   3, 594, 332, 401, 185, 473, 207, 413,  98, 493, 231, 306,
       454,  66, 524, 470, 556, 610, 275, 519, 367, 536, 261,  68, 412,
       428, 281, 406, 461, 572, 130, 547, 123,  67, 415, 424, 171, 615,
       254, 608, 248, 567, 146,  47, 486, 272, 182, 362, 252, 522])

In [14]:
# 10 val masks
[c.shape[0] for c in ct_masks[0]]

[64, 63, 63, 63, 63, 63, 63, 63, 63, 63]

In [19]:
ct_masks[0][0].min(), ct_masks[0][0].max()

(3, 615)

In [58]:
np.save('results/ct_random_ids_k10.npy', ct_masks)

In [59]:
ct_masks = np.load('results/ct_random_ids_k10.npy', allow_pickle=True)

In [101]:
dataloaders = get_all_split_loaders(configs, full_dataset, kfold_intervals)
len(dataloaders)

10

In [106]:
dataloaders[0][0].dataset.dataset.transform, dataloaders[0][1].dataset.dataset.transform

(PermuteSequenceChannels(), PermuteSequenceChannels())

In [22]:
train_loader_0, val_loader_0 = create_split_loaders(configs, full_dataset, splits[0])
len(train_loader_0), len(val_loader_0)

(50, 5)

In [23]:
dataloaders = get_all_split_loaders(full_dataset, splits)
len(dataloaders)

10

In [107]:
for i, batch in tqdm(enumerate(dataloaders[0][1])):
    sequence_batch, cell_type_batch, targets, target_mask = batch
    break

0it [00:00, ?it/s]


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/thurs/genv/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/thurs/genv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/thurs/genv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/thurs/genv/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 311, in __getitem__
    return self.dataset[self.indices[idx]]
  File "src/dataset.py", line 219, in __getitem__
    chrom, pos, cell_type_idx = self._get_chrom_pos_cell_by_idx(idx)
  File "src/dataset.py", line 251, in _get_chrom_pos_cell_by_idx
    interval_idx = bisect.bisect(self.intervals_length_sums, position_idx) - 1
TypeError: '<' not supported between instances of 'tuple' and 'int'


In [26]:
targets.sum()

tensor(134.)

In [19]:
targets.size()

torch.Size([20, 631, 1])

In [44]:
def train(model, batch, fold):
    """
    Trains the model on a batch of data.

    Returns
    -------
    float
        The training loss.

    """    
    # retrieved_seq, cell_type, target, target_mask
    sequence_batch = batch[0]#.to(device)
    cell_type_batch = batch[1]#.to(device)
    targets = batch[2]#.to(device)
    target_mask = batch[3]#.to(device)

    # make train mask
    target_mask_tr = target_mask.clone()
    target_mask_tr[:, ct_masks[fold].min(): ct_masks[fold].max()+1] = False

    outputs = model(sequence_batch, cell_type_batch)

    criterion.weight = target_mask_tr
    loss = criterion(outputs, targets)
    if criterion.reduction == "sum":
        loss = loss / criterion.weight.sum()
    predictions = torch.sigmoid(outputs)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return (
        predictions.detach().numpy(),
        targets.detach().numpy(),
        target_mask_tr.numpy(),
        loss.item(),
    )

def evaluate(model, batch, target_mask_tr):
    """
    Makes predictions for some labeled input data.

    Parameters
    ----------
    data_in_batches : list(SamplesBatch)
        A list of tuples of the data, where the first element is
        the example, and the second element is the label.

    Returns
    -------
    tuple(float, list(numpy.ndarray))
        Returns the average loss, and the list of all predictions.

    """

    sequence_batch = batch[0]#.to(device)
    cell_type_batch = batch[1]#.to(device)
    targets = batch[2]#.to(device)
    target_mask = batch[3]#.to(device)
    # print('targets', targets.shape)

    # val mask
    target_mask_val = target_mask.clone()
    target_mask_val = ~target_mask_tr

    if target_mask_val.shape[0] != targets.shape[0]:
        target_mask_val = target_mask_val[:targets.shape[0], ...]

    # compure a baseline
    baseline = (targets * target_mask_val).sum(axis=1) / target_mask_val.sum(axis=1)
    baseline = torch.repeat_interleave(baseline.unsqueeze(1), 631, dim=1)

    with torch.no_grad():
        outputs = model(sequence_batch, cell_type_batch)

        criterion.weight = target_mask_val
        loss = criterion(outputs, targets)
        if criterion.reduction == "sum":
            loss = loss / criterion.weight.sum()

        predictions = torch.sigmoid(outputs)
        predictions = predictions.view(-1, predictions.shape[-1])
        targets = targets.view(-1, targets.shape[-1])
        baseline = baseline.view(-1, baseline.shape[-1])
        target_mask = target_mask_val.view(-1, target_mask_val.shape[-1])

    return loss, predictions, targets, baseline, target_mask


In [22]:
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
optimizer = torch.optim.Adam(params=model.parameters(), lr = 0.0001, weight_decay = 1e-6)

In [37]:
from selene_sdk.utils import (
    PerformanceMetrics,
    initialize_logger,
    load_model_from_state_dict,
)
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    average_precision_score,
    confusion_matrix,
    roc_auc_score,
)


metrics=dict(roc_auc=roc_auc_score, average_precision=average_precision_score)

_test_metrics = PerformanceMetrics(
            lambda idx: train_batch_loader.dataset.dataset.target_features[idx],
            report_gt_feature_n_positives=10,
            metrics=metrics,
        )

print(train_batch_loader.dataset.dataset.target_features[0])


DNase-seq


In [48]:
# train_batch_loader -- batches all samples in training folds.
# valid_batch_loader -- batches all samples in validation fold.
for fold, (train_batch_loader, valid_batch_loader) in enumerate(dataloaders):
    print('fold:', fold, len(train_batch_loader), len(valid_batch_loader))
    # Loop through all batches in training folds for a given split.
    model.train()
    tain_losses = []
    for batch in tqdm(train_batch_loader):
        # Train model on the training folds in the split.
        prediction, target, target_mask, loss = train(model, batch, fold)
        tain_losses.append(loss)
    print('train loss:', np.average(tain_losses))

    # Loop through all batches in validation fold for a given split.
    model.eval()
    batch_losses = []
    all_predictions = []
    all_targets = []
    all_target_masks = []
    all_baselines = []
    for batch in tqdm(valid_batch_loader):
        # Test model on the validation fold in the split.
        (
            loss,
            predictions,
            targets,
            baseline,
            target_masks,
        ) = evaluate(model, batch, target_mask_tr)

        all_predictions.append(predictions.data.numpy())
        all_targets.append(targets.data.numpy())
        all_target_masks.append(target_masks.data.numpy())
        all_baselines.append(baseline.data.numpy())
        batch_losses.append(loss.item())

    all_predictions = expand_dims(np.concatenate(all_predictions))
    all_targets = expand_dims(np.concatenate(all_targets))
    all_baselines = expand_dims(np.concatenate(all_baselines))
    all_target_masks = expand_dims(np.concatenate(all_target_masks))

    # compute metrics
    average_scores = _test_metrics.update(
        all_predictions, all_targets, all_target_masks
    )
    baseline_score = _test_metrics.update(
        all_baselines, all_targets, all_target_masks
    )

    for name, score in average_scores.items():
        print(name, score)
    for name, score in baseline_score.items():
        print(f'baseline_{name}', score)   


fold: 0 50 5


100%|██████████| 50/50 [05:21<00:00,  6.43s/it]


train loss: 0.6676654028892517


100%|██████████| 5/5 [00:08<00:00,  1.76s/it]


roc_auc 0.8709891324896739
average_precision 0.6355959132281161
baseline_roc_auc 0.9642915112732995
baseline_average_precision 0.7226254161902645
fold: 1 50 5


100%|██████████| 50/50 [05:26<00:00,  6.53s/it]


train loss: 0.6404233521223068


100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


roc_auc 0.9067690793621503
average_precision 0.6030687925927048
baseline_roc_auc 0.9691030431640295
baseline_average_precision 0.7024954911209768
fold: 2 50 5


 28%|██▊       | 14/50 [01:31<03:55,  6.55s/it]


KeyboardInterrupt: 