In [1]:
import torch
from bpnet_customize.io_ import PeakGenerator
from bpnet_customize.io_ import extract_loci
import json
from bpnet_customize.bpnet import BPNet
import nbimporter
import cnn
import numpy as np

In [2]:
default_fit_parameters = {
    'n_filters': 64,
    'n_layers': 8,
    'profile_output_bias': True,
    'count_output_bias': True,
    'name': None,
    'batch_size': 64,
    'in_window': 2114,
    'out_window': 1000,
    'max_jitter': 128,
    'reverse_complement': True,
    'max_epochs': 50,
    'validation_iter': 100,
    'lr': 0.001,
    'alpha': 1,
    'verbose': False,

    'min_counts': 0,
    'max_counts': 99999999,

    'training_chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 
        'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 
        'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'],
    'validation_chroms': ['chr8', 'chr10'],
    'sequences': None,
    'loci': None,
    'signals': None,
    'controls': None,
    'random_state': None
    }

def merge_parameters(parameters, default_parameters):
    with open(parameters, "r") as infile:
        parameters = json.load(infile)

    for parameter, value in default_parameters.items():
        if parameter not in parameters:
            if value is None and parameter != "controls":
                raise ValueError("Must provide value for '{}'".format(parameter))

            parameters[parameter] = value

    return parameters

In [3]:
# Fit parameter

fit_para = merge_parameters("Data/json/fit.json", default_fit_parameters)

In [4]:
# Fit

training_data = PeakGenerator(
    loci=fit_para['loci'], 
    sequences=fit_para['sequences'],
    signals=fit_para['signals'],
    controls=fit_para['controls'],
    chroms=fit_para['training_chroms'],
    in_window=fit_para['in_window'],
    out_window=fit_para['out_window'],
    max_jitter=fit_para['max_jitter'],
    reverse_complement=fit_para['reverse_complement'],
    min_counts=fit_para['min_counts'],
    max_counts=fit_para['max_counts'],
    random_state=fit_para['random_state'],
    batch_size=fit_para['batch_size'],
    verbose=fit_para['verbose']
)

# training_data = cnn.getInput("Data/hg38.fa")

valid_data = extract_loci(
    sequences=fit_para['sequences'],
    signals=fit_para['signals'],
    controls=fit_para['controls'],
    loci=fit_para['loci'],
    chroms=fit_para['validation_chroms'],
    in_window=fit_para['in_window'],
    out_window=fit_para['out_window'],
    max_jitter=0,
    verbose=fit_para['verbose']
)

if fit_para['controls'] is not None:
    valid_sequences, valid_signals, valid_controls = valid_data
    n_control_tracks = 2
else:
    valid_sequences, valid_signals = valid_data
    valid_controls = None
    n_control_tracks = 0
    
trimming = (fit_para['in_window'] - fit_para['out_window']) // 2

model = BPNet(n_filters=fit_para['n_filters'], 
    n_layers=fit_para['n_layers'],
    n_outputs=len(fit_para['signals']),
    n_control_tracks=n_control_tracks,
    profile_output_bias=fit_para['profile_output_bias'],
    count_output_bias=fit_para['count_output_bias'],
    alpha=fit_para['alpha'],
    trimming=trimming,
    name=fit_para['name'],
    verbose=fit_para['verbose'])#.cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=fit_para['lr'])

if fit_para['verbose']:
    print("Training Set Size: ", training_data.dataset.sequences.shape[0])
    display(training_data.dataset.sequences)
    print("Validation Set Size: ", valid_sequences.shape[0])

model.fit(training_data, optimizer, X_valid=valid_sequences, 
    X_ctl_valid=valid_controls, y_valid=valid_signals, 
    max_epochs=fit_para['max_epochs'], 
    validation_iter=fit_para['validation_iter'], 
    batch_size=fit_para['batch_size'])

Loading Loci: 100%|███████████████████████| 3749/3749 [00:02<00:00, 1640.39it/s]
Loading Loci: 100%|███████████████████████| 5666/5666 [00:02<00:00, 1927.33it/s]


Training Set Size:  3749


tensor([[[1., 0., 0.,  ..., 0., 1., 1.],
         [0., 1., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 1., 0.],
         [1., 0., 0.,  ..., 1., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 1.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 1., 1., 0.]],

        [[0., 0., 0.,  ..., 1., 0., 0.],
         [1., 0., 1.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 1.]],

        [[0., 0., 1.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 1., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0

Validation Set Size:  5666
Epoch	Iteration	Training Time	Validation Time	Training MNLL	Training Count MSE	Validation MNLL	Validation Profile Pearson	Validation Count Pearson	Validation Count MSE	Saved?
0	0	3.8665	68.2027	408.7278	13.3795	891.6004	0.0005122318	0.13394558	4.5433	True
1	100	121.1912	72.9908	519.8365	3.3667	377.0948	0.042791344	0.10566397	2.0463	True
3	200	70.0543	66.8396	423.8701	0.5417	363.0977	0.115572385	0.1471116	1.0467	True
5	300	16.0024	66.8227	433.6124	0.7197	330.6847	0.23323897	0.21064186	0.7596	True
6	400	132.2588	84.9186	383.2237	0.4858	320.3094	0.25925264	0.31981876	0.8492	True
8	500	79.6285	73.0666	396.0485	0.4515	317.5337	0.26761872	0.41197154	0.8388	True
10	600	26.4561	62.1292	413.5856	0.5612	314.5002	0.27353555	0.46600217	0.7039	True
11	700	136.218	63.7203	437.7029	0.4682	313.4678	0.27578235	0.48485813	0.6687	True
13	800	87.8322	71.5903	358.0885	0.4695	312.556	0.27718776	0.5130254	0.7055	True
15	900	50.3742	68.9826	387.1257	0.4125	311.4565	0.2802622	0.51013

In [9]:
# Predict parameters

default_predict_parameters = {
    'batch_size': 64,
    'in_window': 2114,
    'out_window': 1000,
    'verbose': False,
    'chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 
        'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 
        'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'],
    'sequences': None,
    'loci': None,
    'controls': None,
    'model': None,
    'profile_filename': 'y_profile.npz',
    'counts_filename': 'y_counts.npz'
}

predict_para = merge_parameters("Data/json/predict.json", default_predict_parameters)

In [10]:
# Predict

model = torch.load(predict_para['model'])#.cuda()

examples = extract_loci(
    sequences=predict_para['sequences'],
    controls=predict_para['controls'],
    loci=predict_para['loci'],
    chroms=predict_para['chroms'],
    max_jitter=0,
    verbose=predict_para['verbose']
)

if predict_para['controls'] == None:
    X = examples
    if model.n_control_tracks > 0:
        X_ctl = torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1])
    else:
        X_ctl = None
else:
    X, X_ctl = examples

y_profiles, y_counts = model.predict(X, X_ctl=X_ctl, 
    batch_size=predict_para['batch_size'])

np.savez_compressed(predict_para['profile_filename'], y_profiles)
np.savez_compressed(predict_para['counts_filename'], y_counts)

Loading Loci: 100%|███████████████████████| 5666/5666 [00:02<00:00, 2665.70it/s]


In [11]:
np.load(predict_para['profile_filename'])
np.load(predict_para['counts_filename'])

<numpy.lib.npyio.NpzFile at 0x1051eaa90>