In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os
sys.path.append(os.path.abspath("../src/"))
import model.basset_model as basset_model
import model.binary_performance as binary_performance
import feature.make_binary_dataset as make_binary_dataset
import torch
import numpy as np
import scipy.stats
import tqdm
# Import the submodule explicitly so that `tqdm.notebook.tqdm` is available
# below, rather than instantiating a throwaway progress bar through the
# deprecated `tqdm.tqdm_notebook()` factory just for its import side effect
import tqdm.notebook

In [None]:
# --- Configuration: model, labels, genome ---

# Pretrained Basset weights and the output column used for K562
# (index taken from https://github.com/kipoi/models/blob/master/Basset/target_labels.txt)
k562_index = 120
basset_weights_path = "/users/amtseng/att_priors/data/processed/basset/pretrained_model_reloaded_th.pth"

# K562 binary labels handed to the data loader
bin_labels_npy_path = "/users/amtseng/att_priors/data/processed/ENCODE_DNase/binary/labels/K562/K562_bin_labels.npy"
labels_hdf5_path = "/users/amtseng/att_priors/data/processed/ENCODE_DNase/binary/labels/K562/K562_labels.h5"

# Sequence length requested from the data loader, the chromosomes to draw
# examples from, and the reference genome
input_length = 600
chrom_set = ["chr1"]
reference_fasta_path = "/users/amtseng/genomes/hg38.fasta"

### Import the Basset model
And restore its weights

In [None]:
# Build the Basset network via the project's model module; the pretrained
# weights are loaded into it in a later cell
basset = basset_model.get_model()

In [None]:
# Deserialize the checkpoint onto the CPU so that loading does not depend on
# whichever device the tensors were saved from; the model (still on the CPU at
# this point) is moved to CUDA in a later cell
state_dict = torch.load(basset_weights_path, map_location="cpu")
basset.load_state_dict(state_dict)

In [None]:
# Move the model to the GPU; the prediction loop also places its inputs on CUDA
basset = basset.cuda()

### Create data loader for K562
Basset takes input sequences of length 600. Our input sequences are 1000 bp long, and are labeled as positive if at least half of the central 400 bp overlaps with an IDR-optimal peak. Using the central 600-bp region should therefore be a reasonable approximation.

In [None]:
# Batch size and reverse-complement flag; both are reused later when the
# result buffers are sized
batch_size = 128
revcomp = True

loader = make_binary_dataset.create_data_loader(
    labels_hdf5_path,
    bin_labels_npy_path,
    batch_size,
    reference_fasta_path,
    # Sequence options
    input_length=input_length,
    revcomp=revcomp,
    chrom_set=chrom_set,
    # Simulation-related options (simulation is turned off)
    simulate_seqs=False,
    motif_path=None,
    motif_bound=0,
    gc_prob=0.5,
    # Sampling options; no seeds are given
    negative_ratio=1,
    peak_retention=None,
    peak_signals_npy_or_array=None,
    negative_seed=None,
    shuffle_seed=None,
    shuffle=True,
    # Loader mechanics; the prediction loop unpacks `coords` from each batch
    num_workers=10,
    return_coords=True,
)

# NOTE(review): presumably prepares the dataset's sampling for the upcoming
# pass -- confirm against make_binary_dataset
loader.dataset.on_epoch_start()

### Compute predictions for the test set

In [None]:
# Allocate arrays to store results. They are sized for the largest number of
# examples the loader could yield and trimmed to the actual count afterwards.
num_expected = batch_size * len(loader.dataset) * (2 if revcomp else 1)
all_seqs = np.empty((num_expected, input_length, 4))
all_true_vals = np.empty((num_expected, 1))
all_pred_vals = np.empty((num_expected, 1))
all_coords = np.empty((num_expected, 3), dtype=object)

# nn.Module instances start in training mode; eval() switches any dropout or
# batch-norm layers to their inference behavior. Whether get_model() already
# does this is not visible here, and calling it again is harmless.
basset.eval()

num_actual = 0
# Only forward passes are needed, so skip building the autograd graph
with torch.no_grad():
    for batch in tqdm.notebook.tqdm(loader, total=len(loader.dataset)):
        seqs, vals, statuses, coords = batch

        # Input to Basset must be shape B x 4 x 600 x 1
        input_seqs = torch.tensor(
            np.swapaxes(np.expand_dims(seqs, axis=3), 1, 2)
        ).float().cuda()
        # Keep only the K562 output, as a B x 1 column
        preds = basset(input_seqs)[:, k562_index : k562_index + 1]
        preds = preds.cpu().numpy()

        # Write this batch into the next free rows of the result buffers
        num_in_batch = len(seqs)
        batch_slice = slice(num_actual, num_actual + num_in_batch)
        all_seqs[batch_slice] = seqs
        all_true_vals[batch_slice] = vals
        all_pred_vals[batch_slice] = preds
        all_coords[batch_slice] = coords
        num_actual += num_in_batch

# Cut off excess
all_seqs = all_seqs[:num_actual]
all_true_vals = all_true_vals[:num_actual]
all_pred_vals = all_pred_vals[:num_actual]
all_coords = all_coords[:num_actual]

### Compute performance metrics

In [None]:
# Score the collected predictions against the true labels, passing along the
# negative-to-positive imbalance recorded on the dataset's bins batcher
neg_to_pos_imbalance = loader.dataset.bins_batcher.neg_to_pos_imbalance
perf_metrics = binary_performance.compute_performance_metrics(
    all_true_vals, all_pred_vals, neg_to_pos_imbalance
)

In [None]:
# Report the first entry of every metric
for metric_name, metric_values in perf_metrics.items():
    print("%s: %.6f" % (metric_name, metric_values[0]))