In [8]:
import yaml
import os
import sys
# sys.path.append('../../')

import torch
from selene_sdk.utils import load_path, parse_configs_and_run
from selene_sdk.utils.config_utils import module_from_dir, module_from_file
from selene_sdk.utils.config import instantiate
%load_ext autoreload
%autoreload 

In [6]:
path = 'model_configs/biox_dnase_multi_ct_train.yaml'

configs = load_path(path, instantiate=False)

In [10]:
from src.notebook_utils import load_datasets, get_loader
import torchvision
from src.transforms import PermuteSequenceChannels, RandomReverseStrand, MaskTracks, MaskFeatures


boix_folder = '/mnt/datasets/DeepCT/dataset_data/Biox_et_al/dnase'
train_transform = torchvision.transforms.Compose([
    PermuteSequenceChannels(),
    RandomReverseStrand(p=0.5),
    # MaskTracks(track_mask)
])

val_transform = torchvision.transforms.Compose([
    PermuteSequenceChannels(),
    #MaskFeatures(feature_mask)
])

boix_train, boix_val = load_datasets(
    boix_folder, train=True, val=True, test=False, 
    train_transform=train_transform, val_transform=val_transform)
    
boix_val_loader = get_loader(boix_val, batch_size=256, shuffle=1448)
boix_train_loader = get_loader(boix_train, batch_size=256, shuffle=1447)

In [11]:
len(boix_train_loader)

35787

In [3]:
dataset_info = configs["dataset"]

In [4]:
dataset_info = configs["dataset"]

<module 'dataset'>

In [5]:
with open(dataset_info["sampling_intervals_path"]) as f:
    for line in f:
        print(line)
        chrom, start, end = line.rstrip().split("\t")[:3]
        1/0

chr1	10500	10730



ZeroDivisionError: division by zero

In [6]:
dataset_info["class"]

'EncodeDataset'

In [7]:
dataset_class = getattr(module, dataset_info["class"])

load_train_val = True
load_test = False

if "dataset" in configs:
    dataset_info = configs["dataset"]
    train_intervals = []
    val_intervals = []
    test_intervals = []
    with open(dataset_info["sampling_intervals_path"]) as f:
        for line in f:
            chrom, start, end = line.rstrip().split("\t")[:3]
            start = int(start)
            end = int(end)

            if load_train_val and chrom in dataset_info["validation_holdout"]:
                val_intervals.append((chrom, start, end))
            elif load_test and chrom in dataset_info["test_holdout"]:
                test_intervals.append((chrom, start, end))
            elif load_train_val:
                train_intervals.append((chrom, start, end))

    with open(dataset_info["distinct_features_path"]) as f:
        distinct_features = list(map(lambda x: x.rstrip(), f.readlines()))

    with open(dataset_info["target_features_path"]) as f:
        target_features = list(map(lambda x: x.rstrip(), f.readlines()))

In [8]:
# with open(dataset_info["distinct_features_path"]) as f:
#     distinct_features = list(map(lambda x: x.rstrip(), f.readlines()))

# with open(dataset_info["target_features_path"]) as f:
#     target_features = list(map(lambda x: x.rstrip(), f.readlines()))

dataset_info["dataset_args"]["target_features"] = target_features
dataset_info["dataset_args"]["distinct_features"] = distinct_features

In [9]:
dataset_info['path']

'src/dataset.py'

In [10]:
train_config = dataset_info["dataset_args"].copy()
train_config["intervals"] = train_intervals
train_config

{'reference_sequence_path': '/mnt/datasets/DeepCT/male.hg19.fasta',
 'target_path': '/mnt/datasets/DeepCT/dataset_data/Biox_et_al/sorted_data.bed.gz',
 'cell_wise': True,
 'multi_ct_target': True,
 'sequence_length': 1000,
 'center_bin_to_predict': 200,
 'feature_thresholds': 0.5,
 'position_skip': 120,
 'target_features': ['ATAC-seq',
  'CTCF',
  'DNase-seq',
  'EP300',
  'H2AFZ',
  'H2AK5ac',
  'H2AK9ac',
  'H2BK120ac',
  'H2BK12ac',
  'H2BK15ac',
  'H2BK20ac',
  'H2BK5ac',
  'H3F3A',
  'H3K14ac',
  'H3K18ac',
  'H3K23ac',
  'H3K23me2',
  'H3K27ac',
  'H3K27me3',
  'H3K36me3',
  'H3K4ac',
  'H3K4me1',
  'H3K4me2',
  'H3K4me3',
  'H3K56ac',
  'H3K79me1',
  'H3K79me2',
  'H3K9ac',
  'H3K9me1',
  'H3K9me2',
  'H3K9me3',
  'H3T11ph',
  'H4K12ac',
  'H4K20me1',
  'H4K5ac',
  'H4K8ac',
  'H4K91ac',
  'POLR2A',
  'RAD21',
  'SMC3'],
 'distinct_features': ['22Rv1_treated_with_10_nM_17b-hydroxy-5a-androstan-3-one_for_4_hours|CTCF|None',
  '22Rv1_treated_with_10_nM_17b-hydroxy-5a-androstan-3-o

In [11]:
train_config['distinct_features']

['22Rv1_treated_with_10_nM_17b-hydroxy-5a-androstan-3-one_for_4_hours|CTCF|None',
 '22Rv1_treated_with_10_nM_17b-hydroxy-5a-androstan-3-one_for_4_hours|H3K27ac|None',
 '22Rv1|CTCF|None',
 '22Rv1|H3K27ac|None',
 '8988T|DNase-seq|None',
 'A172|DNase-seq|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|CTCF|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|EP300|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H2AFZ|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K27ac|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K27me3|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K36me3|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K4me1|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K4me2|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K4me3|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K79me2|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K9ac|None',
 'A549_treated_with_0.02pct_ethanol_for_1_hour|H3K9me3|None',
 '

In [12]:
if "train_transform" in dataset_info:
    # load transforms
    train_transform = instantiate(dataset_info["train_transform"])
    train_config["transform"] = train_transform

train_dataset = dataset_class(**train_config)

In [13]:
train_dataset.__getitem__(0)

(array([[0., 1., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 1., 1.],
        [1., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 0.0,
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 array([[False,  True, False, ..., False, False, False],
        [False,  True, False, ..., False, False, False],
        [False, False,  True, ..., False, False, False],
        ...,
        [False,  True,  True, ...,  True, False, False],
        [False,  True,  True, ...,  True, False, False],
        [False,  True,  True, ...,  True, False, False]]))

In [None]:
sampler_class = getattr(module, dataset_info["sampler_class"])
gen = torch.Generator()
gen.manual_seed(configs["random_seed"])

train_sampler = sampler_class(
    train_dataset, replacement=False, generator=gen
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=dataset_info["loader_args"]["batch_size"],
    num_workers=dataset_info["loader_args"]["num_workers"],
    worker_init_fn=module.encode_worker_init_fn,
    sampler=train_sampler,
)