In [1]:
import sys
import glob
import numpy
import scipy

from scipy import sparse

from dragonnfruit.io import LocusGenerator
from dragonnfruit.io import GenomewideGenerator

from dragonnfruit.models import CellStateController
from dragonnfruit.models import DynamicBPNet
from dragonnfruit.models import DragoNNFruit
from dragonnfruit.io import load_data, save_data

import torch
import os

# Backend switches: enable cuDNN's benchmark mode and allow TF32 in CUDA
# matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Seed the global torch and numpy random number generators with one value.
# numpy is seeded last so the cell ends on a call that returns None.
SEED = 0
torch.manual_seed(SEED)
numpy.random.seed(SEED)

In [2]:
from dragonnfruit.preprocessing import read_chrom_sizes

# Chromosome sizes as returned by read_chrom_sizes (used below via .items()).
chrom_sizes, header = read_chrom_sizes("/workspaces/torch_ddsm/_data_pool1/general/hg38.chrom.sizes")

# chr1..chr22 plus chrX; every other entry of the sizes file is dropped.
canonical_chroms = [f"chr{number}" for number in range(1, 23)] + ['chrX']
filtered_chrom_sizes = {
    name: size
    for name, size in chrom_sizes.items()
    if name in canonical_chroms
}
chroms_prepro = list(filtered_chrom_sizes)

# Chromosomes held out from training.
test_chroms = ['chr2']
validation_chroms = ['chr8', 'chr10']

# Training set = canonical chromosomes that are in neither hold-out list.
training_chroms = [
    chrom
    for chrom in chroms_prepro
    if chrom not in test_chroms and chrom not in validation_chroms
]
chroms = training_chroms + validation_chroms

In [3]:
# --- Model and batching hyperparameters ---
n_filters, n_layers = 256, 8

batch_size = 256
window = 2114
# Half of the difference between the window length and 1000.
trimming = (window - 1000) // 2

# --- Paths ---
data_dir = "/workspaces/torch_ddsm/_data_pool1"
peak_file = f"{data_dir}/10x_data/pbmc3kv2/pbmc_granulocyte_sorted_3k_atac_peaks.bed"
run_prefix = "fit_atac_pca"
os.makedirs(f"{data_dir}/test_dragonnfruit_pbmc/{run_prefix}", exist_ok=True)

# --- Single-cell inputs ---
k = 500
n_nodes = 1024
n_outputs = 256

# Neighbour index table, truncated to its first k columns.
neighbors = numpy.load(f"{data_dir}/test_dragonnfruit_pbmc/atac_neighbors.npz")['arr_0']
neighbors = neighbors[:, :k]

# Cell-state matrix, standardised along axis 0 (subtract mean, divide by std).
cell_states = numpy.load(f"{data_dir}/test_dragonnfruit_pbmc/atac_pca.npz")['arr_0'].astype('float32')
cell_states = numpy.divide(
    numpy.subtract(cell_states, cell_states.mean(axis=0, keepdims=True)),
    cell_states.std(axis=0, keepdims=True),
)

# Read depths: gather the values at each row's neighbour indices, sum them
# along axis 1, take log2(x + 1), and reshape to a single column.
read_depths = numpy.load(f"{data_dir}/test_dragonnfruit_pbmc/atac_read_depths.npz")['arr_0'].astype('float32')
read_depths = numpy.log2(read_depths[neighbors].sum(axis=1) + 1).reshape(-1, 1)



#### Import the single-cell ATAC signals

In [4]:
# Load the ATAC signal container from HDF5. The encoding cell further down
# iterates over X_cscs.keys() and indexes X_cscs by those keys, which it also
# uses as chromosome names when reading the FASTA.
X_cscs = load_data(f"{data_dir}/test_dragonnfruit_pbmc/dragonnfruit_data.h5")

#### One-hot encode the reference sequences and cache them to HDF5

In [5]:
import pyfaidx
from bpnetlite.io import one_hot_encode
import h5py
import hdf5plugin

# Reference genome FASTA, accessed by chromosome name.
sequences_faidx = pyfaidx.Fasta('/workspaces/torch_ddsm/_data_pool1/10x_data/refdata-gex-GRCh38-2020-A/fasta/genome.fa')

# For every key of the signal container, one-hot encode the upper-cased
# reference sequence of that chromosome and keep the signal under the same key.
signals = {}
sequences = {}
for key in X_cscs.keys():
    print(f"encoding sequence {key}")
    sequences[key] = one_hot_encode(sequences_faidx[key][:].seq.upper(),
                                    alphabet=['A', 'C', 'G', 'T'])
    print(f"dim: {sequences[key].shape}")
    signals[key] = X_cscs[key]

# Write one Blosc-compressed dataset per chromosome. The file is opened in a
# `with` block so that it is closed even if one of the writes raises; before,
# an exception would skip the explicit close() and leave the handle open.
with h5py.File(f"{data_dir}/test_dragonnfruit_pbmc/{run_prefix}/chrom_onehots.h5", 'w') as outfile:
    for chrom, data in sequences.items():
        print(f"adding {chrom} to hdf5")
        outfile.create_dataset(chrom, data=data.data, **hdf5plugin.Blosc(clevel=9))

encoding sequence chr12
dim: (4, 133275309)
encoding sequence chr15
dim: (4, 101991189)
encoding sequence chr4
dim: (4, 190214555)
encoding sequence chr22
dim: (4, 50818468)
encoding sequence chr5
dim: (4, 181538259)
encoding sequence chr19
dim: (4, 58617616)
encoding sequence chr8
dim: (4, 145138636)
encoding sequence chr11
dim: (4, 135086622)
encoding sequence chr1
dim: (4, 248956422)
encoding sequence chr21
dim: (4, 46709983)
encoding sequence chr7
dim: (4, 159345973)
encoding sequence chr20
dim: (4, 64444167)
encoding sequence chr13
dim: (4, 114364328)
encoding sequence chr9
dim: (4, 138394717)
encoding sequence chr2
dim: (4, 242193529)
encoding sequence chr17
dim: (4, 83257441)
encoding sequence chrX
dim: (4, 156040895)
encoding sequence chr6
dim: (4, 170805979)
encoding sequence chr18
dim: (4, 80373285)
encoding sequence chr10
dim: (4, 133797422)
encoding sequence chr16
dim: (4, 90338345)
encoding sequence chr14
dim: (4, 107043718)
encoding sequence chr3
dim: (4, 198295559)
addin

#### Load the one-hot sequences back from the HDF5 cache on disk

In [6]:
import h5py

# Location of the one-hot cache inside the run directory.
file_path = f"{data_dir}/test_dragonnfruit_pbmc/{run_prefix}/chrom_onehots.h5"

# Read every dataset of the file in full (the [()] index) and key the result
# by dataset name; the file is closed when the block exits.
with h5py.File(file_path, 'r') as infile:
    sequences = {chrom: infile[chrom][()] for chrom in infile.keys()}

In [7]:
# --- Data loaders ---

# Training batches: a GenomewideGenerator restricted to the training
# chromosomes, wrapped in a DataLoader with 8 worker processes.
X = torch.utils.data.DataLoader(
    GenomewideGenerator(
        sequence=sequences,
        signal=signals,
        neighbors=neighbors,
        cell_states=cell_states,
        read_depths=read_depths,
        trimming=trimming,
        window=window,
        chroms=training_chroms,
        random_state=None,
    ),
    batch_size=batch_size,
    num_workers=8,
    pin_memory=True,
    # Seed numpy inside each worker with that worker's id.
    worker_init_fn=lambda worker_id: numpy.random.seed(worker_id),
)

# Validation examples: loci from the peak file on the validation chromosomes.
X_valid = LocusGenerator(
    sequence=sequences,
    signal=signals,
    loci_file=peak_file,
    neighbors=neighbors,
    cell_states=cell_states,
    read_depths=read_depths,
    trimming=trimming,
    window=window,
    chroms=validation_chroms,
    random_state=0,
)

# --- Model ---

# The bias model is loaded onto the CPU first and then moved to the GPU.
bias_model = torch.load(
    "/workspaces/torch_ddsm/_data_pool1/test_dragonnfruit_pbmc/pbmc_granulocyte_sorted_3k_atac.final.torch",
    map_location='cpu',
).cuda()
controller = CellStateController(
    n_inputs=cell_states.shape[-1],
    n_nodes=n_nodes,
    n_layers=1,
    n_outputs=n_outputs,
).cuda()
accessibility_model = DynamicBPNet(
    n_filters=n_filters,
    n_layers=n_layers,
    trimming=trimming,
    controller=controller,
).cuda()

# Name passed to DragoNNFruit; it encodes the run directory and hyperparameters.
name = f"{data_dir}/test_dragonnfruit_pbmc/{run_prefix}/dragonnfruit.fibr.{n_filters}.{n_layers}.{k}.{n_nodes}.{n_outputs}"
model = DragoNNFruit(bias_model, accessibility_model, name).cuda()

# --- Training ---

# Short run: 5 epochs, 50 validation samples, validation_iter=1000.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
model.fit(
    X,
    X_valid,
    optimizer,
    n_validation_samples=50,
    max_epochs=5,
    validation_iter=1000,
    batch_size=batch_size,
)

# Longer configuration, kept for reference:
# optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# model.fit(X, X_valid, optimizer, n_validation_samples=50324, max_epochs=50,
# validation_iter=1000, batch_size=batch_size)

Epoch	Iteration	Training Time	Validation Time	Training MNLL	Validation MNLL	Validation Profile Correlation	Validation Count Correlation	Saved?
0	0	5.1913	0.7624	25.8086	0.0	0.0	0.0	False
0	1000	142.7383	0.0934	17.8639	0.0	0.0	0.0	False


KeyboardInterrupt: 