In [1]:
# Imports
import os
import torch
import pandas as pd
import numpy as np
import xarray as xr
import torch.nn as nn

import seqpro as sp
import seqmodels as sm
import seqdata as sd

The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed')).History will not be written to the database.


In [2]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU, and report the choice
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print('Using device:', device)

Using device: cuda


In [3]:
# Change working directory so that the relative outputs written by later cells
# (log/, best_model.ckpt, final_log.csv, test_predictions.csv, *.png) land in the training folder.
# NOTE(review): absolute, user-specific path — the notebook only runs where this directory exists.
os.chdir("/cellar/users/aklie/data/datasets/Chiou2021_islet_snATAC-seq/results/sequence_models/training")

In [4]:
# Set seeds for every RNG that can influence training so that reruns are comparable.
# Seeding NumPy alone leaves PyTorch (weight initialisation, shuffling) and Python's
# `random` module unseeded.
import random

SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the default generator on all devices, CUDA included

# Data

In [5]:
# Hold out fixed chromosomes for validation and testing; train on the remaining autosomes (chr1-chr22)
valid_chroms = ['chr8', 'chr20']
test_chroms = ['chr1', 'chr3', 'chr6']
held_out = set(valid_chroms) | set(test_chroms)
train_chroms = [f'chr{idx}' for idx in range(1, 23) if f'chr{idx}' not in held_out]
len(train_chroms), len(valid_chroms), len(test_chroms)

(17, 2, 3)

In [6]:
# Open the zarr-backed SeqData object and call .load() on it
# (presumably this pulls the lazily-opened arrays into memory — TODO confirm)
zarr_path = "/cellar/users/aklie/data/datasets/Chiou2021_islet_snATAC-seq/results/sequence_models/zarrs/beta_1.zarr"
sdata = sd.open_zarr(zarr_path)
sdata.load()

In [7]:
# Upper-case the stored sequences, then one-hot encode them.
# The encoding is produced as (_sequence, _length, _alphabet) and transposed to
# (_sequence, _alphabet, _length) before being stored.
upper_seqs = np.char.upper(sdata["seq"])
sdata["seq"] = xr.DataArray(upper_seqs, dims=["_sequence", "_length"])

ohe_values = sp.ohe(sdata["seq"].values, alphabet=sp.DNA)
ohe_array = xr.DataArray(ohe_values, dims=["_sequence", "_length", "_alphabet"])
sdata["ohe_seq"] = ohe_array.transpose("_sequence", "_alphabet", "_length")

In [8]:
# Take the first 10 examples as a small float32 probe batch for smoke-testing the model
first_ten = slice(None, 10)
test_seqs = torch.tensor(sdata["ohe_seq"][first_ten].values, dtype=torch.float32)
test_cov = torch.tensor(sdata["cov"][first_ten].values, dtype=torch.float32)
test_dict = dict(ohe_seq=test_seqs)
targets_dict = dict(cov=test_cov)

# Architecture

In [9]:
from bpnetlite.bpnet import BPNet

In [10]:
# Input window, output window, and the number of positions trimmed from each side
seq_len, target_len = 2114, 1000
total_flank = seq_len - target_len
trimming = total_flank // 2
seq_len, target_len, trimming

(2114, 1000, 557)

In [11]:
# Build the BPNet architecture (single output track, no control tracks) and smoke-test it
# on the probe batch. The forward pass runs once, under no_grad, because only the
# shapes of the two outputs are inspected here.
arch = BPNet(n_outputs=1, n_control_tracks=0, trimming=trimming)
with torch.no_grad():
    arch_outs = arch(test_seqs)
arch, arch_outs[0].shape, arch_outs[1].shape

(BPNet(
   (iconv): Conv1d(4, 64, kernel_size=(21,), stride=(1,), padding=(10,))
   (irelu): ReLU()
   (rconvs): ModuleList(
     (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
     (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
     (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
     (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,))
     (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(32,))
     (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(64,))
     (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(128,), dilation=(128,))
     (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(256,), dilation=(256,))
   )
   (rrelus): ModuleList(
     (0-7): 8 x ReLU()
   )
   (fconv): Conv1d(64, 1, kernel_size=(75,), stride=(1,), padding=(37,))
   (linear): Linear(in_features=64, out_features=1

# Training module

In [12]:
from seqmodels import Module

import sys
sys.path.append("/cellar/users/aklie/data/datasets/Chiou2021_islet_snATAC-seq/bin/sequence_models")
from bpnet_utils import bpnetlite_loss, bpnetlite_metrics

In [13]:
# Wrap the architecture in the seqmodels training Module: which dataset variables feed the
# model, what its two outputs are called, the loss / validation-metric functions, and the optimizer.
module = Module(
    arch=arch,
    input_vars=["ohe_seq"],
    output_vars=["profile", "counts"],
    target_vars=["cov"],
    loss_fxn=bpnetlite_loss,
    val_metrics_fxn=bpnetlite_metrics,
    val_metrics_kwargs={"alpha": arch.alpha},
    optimizer="adam",
    optimizer_lr=1e-3,
)
# Smoke-test the wrapper with a single forward pass; no_grad because only shapes are read
with torch.no_grad():
    module_outs = module(test_dict)
module, module_outs[0].shape, module_outs[1].shape

(Module(
   (arch): BPNet(
     (iconv): Conv1d(4, 64, kernel_size=(21,), stride=(1,), padding=(10,))
     (irelu): ReLU()
     (rconvs): ModuleList(
       (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
       (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
       (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
       (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,))
       (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(32,))
       (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(64,))
       (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(128,), dilation=(128,))
       (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(256,), dilation=(256,))
     )
     (rrelus): ModuleList(
       (0-7): 8 x ReLU()
     )
     (fconv): Conv1d(64, 1, kernel_size=(75,), stride=(1,), padding=(37,))

# Dataloaders

In [14]:
import sys
sys.path.append("/cellar/users/aklie/data/datasets/Chiou2021_islet_snATAC-seq/bin/sequence_models")
from bpnet_utils import get_transforms

In [15]:
# Partition the dataset by chromosome membership into train / valid / test subsets
seq_chroms = sdata["chrom"]
in_train = seq_chroms.isin(train_chroms).compute()
in_valid = seq_chroms.isin(valid_chroms).compute()
in_test = seq_chroms.isin(test_chroms).compute()

train_sdata = sdata.sel(_sequence=in_train)
valid_sdata = sdata.sel(_sequence=in_valid)
test_sdata = sdata.sel(_sequence=in_test)

# Report how many sequences ended up in each split
print(f"Train: {len(train_sdata['seq'])}")
print(f"Valid: {len(valid_sdata['seq'])}")
print(f"Test: {len(test_sdata['seq'])}")

Train: 263053
Valid: 28158
Test: 79850


In [16]:
# Augmentation settings: training uses jitter and reverse-complement, validation uses neither.
# Both share the same trimming / jitter / rc-probability arguments.
max_jitter = 128
rc_prob = 0.5
shared_transform_kwargs = dict(trimming=trimming, max_jitter=max_jitter, rc_prob=rc_prob)
train_transforms = get_transforms(jitter=True, rc=True, **shared_transform_kwargs)
valid_transforms = get_transforms(jitter=False, rc=False, **shared_transform_kwargs)

In [17]:
# Training loader: shuffled batches of 32 passed through the training transforms
train_dl_kwargs = dict(
    sample_dims=['_sequence'],
    variables=['ohe_seq', 'cov'],
    prefetch_factor=None,
    batch_size=32,
    transforms=train_transforms,
    return_tuples=False,
    shuffle=True,
)
train_dl = sd.get_torch_dataloader(train_sdata, **train_dl_kwargs)

# Pull one batch to sanity-check the input / target shapes
batch = next(iter(train_dl))
batch['ohe_seq'].shape, batch['cov'].shape



(torch.Size([32, 4, 2114]), torch.Size([32, 1, 1000]))

In [18]:
# Validation loader: fixed order, batches of 64, passed through the validation transforms
valid_dl_kwargs = dict(
    sample_dims=['_sequence'],
    variables=['ohe_seq', 'cov'],
    prefetch_factor=None,
    batch_size=64,
    transforms=valid_transforms,
    return_tuples=False,
    shuffle=False,
)
valid_dl = sd.get_torch_dataloader(valid_sdata, **valid_dl_kwargs)

# Pull one batch to sanity-check the input / target shapes
batch = next(iter(valid_dl))
batch["ohe_seq"].shape, batch["cov"].shape

(torch.Size([64, 4, 2114]), torch.Size([64, 1, 1000]))

# Trainer

In [19]:
from pytorch_lightning import Trainer

In [20]:
# Logger
# CSV logger rooted at "log" with empty name and version strings; the checkpoint
# callback cell joins these same three fields to build its checkpoint directory.
from pytorch_lightning.loggers import CSVLogger
logger = CSVLogger(save_dir="log", name="", version="")

In [21]:
# Checkpoint callback: keep the 5 best checkpoints, monitored on "val_loss_epoch",
# written to <save_dir>/<name>/<version>/checkpoints of the logger.
# NOTE(review): EarlyStopping and LearningRateMonitor are imported but NOT added to
# `callbacks` — only ModelCheckpoint is active. Add them here if they are intended.
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
callbacks = [
ModelCheckpoint(
    dirpath=os.path.join(
        logger.save_dir, 
        logger.name, 
        logger.version, 
        "checkpoints"
    ),
    save_top_k=5,
    monitor="val_loss_epoch",
)
]

In [22]:
# Assemble the Lightning Trainer: CSV logging, the checkpoint callback list, at most 50 epochs
trainer_kwargs = dict(
    logger=logger,
    callbacks=callbacks,
    max_epochs=50,
)
trainer = Trainer(**trainer_kwargs)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Fit

In [None]:
# Fit the weights: train `module` on train_dl and validate on valid_dl
# (the Trainer above caps this at max_epochs=50)
trainer.fit(module, train_dl, valid_dl)

In [24]:
# Copy the best checkpoint to a stable file name in the working directory.
# shutil.copy is used instead of shelling out to `cp`, so paths containing spaces or
# shell metacharacters are handled and a failed copy raises instead of returning a code.
import shutil

best_model_path = trainer.checkpoint_callback.best_model_path
if not best_model_path:
    # best_model_path stays empty until a checkpoint has actually been saved
    raise RuntimeError("No best checkpoint recorded; run trainer.fit before copying the weights.")
copy_path = "best_model.ckpt"
shutil.copy(best_model_path, copy_path)

0

# Training summary

In [26]:
import sys
sys.path.append("/cellar/users/aklie/data/datasets/Chiou2021_islet_snATAC-seq/bin/sequence_models")
from utils import training_summary

In [27]:
# Save a loss curve
# training_summary is a project helper (utils); presumably it reads the CSV logs under
# logger.save_dir and writes the figure to training_summary.png — TODO confirm against utils.py
training_summary(logger.save_dir, logger="csv", save="training_summary.png")

# Performance

In [28]:
import sys
sys.path.append("/cellar/users/aklie/data/datasets/Chiou2021_islet_snATAC-seq/bin/sequence_models")
from bpnet_utils import crop
from utils import scatter

In [29]:
# Grab the best model weights.
# The checkpoint is mapped onto `device` (chosen at the top of the notebook) instead of a
# hard-coded .cuda(), so the evaluation section also runs on CPU-only machines.
# The constructor arguments mirror the ones used when the Module was first built.
module = Module.load_from_checkpoint(
    "best_model.ckpt",
    map_location=device,
    arch=arch,
    input_vars=["ohe_seq"],
    output_vars=["profile", "counts"],
    target_vars=["cov"],
    loss_fxn=bpnetlite_loss,
    val_metrics_fxn=bpnetlite_metrics,
    val_metrics_kwargs={"alpha": arch.alpha},
    optimizer="adam",
    optimizer_lr=1e-3,
).eval().to(device)

In [31]:
# Get test data
# crop is a project helper (bpnet_utils). Presumably it removes the max_jitter padding
# and, for the targets, additionally trims to the output window — TODO confirm.
# NOTE(review): the shapes recorded for this cell are torch.Size, so crop appears to return tensors.
X_test = crop(test_sdata["ohe_seq"].values, max_jitter=max_jitter)
y_test = crop(test_sdata["cov"].values, trimming=trimming, max_jitter=max_jitter)
X_test.shape, y_test.shape

(torch.Size([79850, 4, 2114]), torch.Size([79850, 1, 1000]))

In [33]:
# Get the predictions
# Inputs are keyed by the module's input variable name ("ohe_seq"); the returned dict is
# indexed by the output names "profile" and "counts". Prediction runs in batches of 256.
inputs_dict = {"ohe_seq": X_test}
preds_dict = module.predict(inputs_dict, batch_size=256)
preds_dict["profile"].shape, preds_dict["counts"].shape

Predicting on batches:   0%|          | 0/311 [00:00<?, ?it/s]

(torch.Size([79850, 1, 1000]), torch.Size([79850, 1]))

In [None]:
# Get performance on the test set.
# Targets are moved to `device` (chosen at the top of the notebook) rather than a
# hard-coded .cuda(), so this cell also works on CPU-only machines.
targets_dict = {"cov": y_test.to(device)}
metrics_dict = bpnetlite_metrics(preds_dict, targets_dict, alpha=arch.alpha)
profile_corr = metrics_dict['profile_corr']
count_corr = metrics_dict['count_corr']
# Combined loss: mean profile MNLL plus the alpha-weighted mean count MSE
loss = metrics_dict['profile_mnll'].mean() + arch.alpha * metrics_dict['count_mse'].mean()

In [35]:
# Save the final log: which checkpoint was best and how it scores on the test set.
# Epoch and step are parsed from the checkpoint file name with a regex, so a trailing
# de-duplication suffix such as "-v1" before ".ckpt" no longer breaks the int() conversion.
import re

best_ckpt_path = trainer.checkpoint_callback.best_model_path
ckpt_match = re.search(r"epoch=(\d+)-step=(\d+)", os.path.basename(best_ckpt_path))
if ckpt_match is None:
    raise ValueError(f"Cannot parse epoch/step from checkpoint path: {best_ckpt_path!r}")
best_epoch, best_step = (int(group) for group in ckpt_match.groups())

final_log = pd.Series(
    {
        "Epoch": best_epoch,
        "Iteration": best_step,
        "Test MNLL": metrics_dict['profile_mnll'].mean().item(),
        # NaN correlations are replaced by 0 before averaging
        "Test Profile Pearson": np.nan_to_num(profile_corr).mean(),
        "Test Count Pearson": np.nan_to_num(count_corr).mean(),
        "Test Count MSE": metrics_dict['count_mse'].mean().item(),
    }
)
final_log.to_csv("final_log.csv", index=True)

In [37]:
# Grab the counts prediction.
# True values are log(total coverage + 1), summed over the track and position axes.
# torch.log is used directly on the tensor rather than a NumPy ufunc followed by .numpy().
true_log_counts = torch.log(y_test.sum(dim=(1, 2)) + 1).numpy()
# Drop only the trailing singleton output axis, so a one-sample batch stays 1-D
y_counts = preds_dict["counts"].cpu().detach().numpy().squeeze(-1)

In [38]:
# Write the true and predicted log total counts side by side, one row per test sequence
prediction_columns = {
    "true_log_counts_total": true_log_counts,
    "pred_log_counts_total": y_counts,
}
df = pd.DataFrame(prediction_columns)
df.to_csv("test_predictions.csv", index=False)

In [39]:
# Scatter of true vs. predicted log total counts on the test set, drawn in steel blue
# (#4682B4) and passed save="counts_scatter.png".
# NOTE(review): the axis labels read "Log ... counts + 1", but the true values are
# log(counts + 1) — consider rewording the labels.
scatter(
    x=true_log_counts,
    y=y_counts,
    c="#4682B4",
    alpha=0.8,
    xlabel="Log true counts + 1",
    ylabel="Log pred counts + 1",
    density=True,
    rasterized=True,
    s=5,
    save="counts_scatter.png",
)

# DONE!

---