In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import torch
from edice import loggers
from edice.data.datasets import EpigenomeSliceDataset, EpigenomeSliceWithTargets
from edice.data.dataset_config import load_dataset
from edice.model.edice import eDICEModel, eDICE
from edice.training import train

from matplotlib import pyplot as plt

# TODO

* check masking
* check inputs construction
* check MSE dims

In [None]:
# Load the "PredictdSample" dataset configuration and record how many distinct
# cells and assays it exposes; both counts size the model built below.
data_module = load_dataset("PredictdSample")
n_cells = len(data_module.cells)
n_assays = len(data_module.assays)
print(f"n cells {n_cells}, n assays {n_assays}")

In [None]:
# Build the network with its two required size arguments (number of cells,
# number of assays); the bare trailing expression displays the module repr.
model = eDICEModel(n_cells, n_assays)
model

# 2. Stepping through the model

### 2.1 Constructing a C x A matrix from a list of tracks

where C is the number of cells and A is the number of assays.

In [None]:
from edice.model.encoders import InputExpander

In [None]:
# Build a toy batch for stepping through the model: every sample in the batch
# shares the same (cell id, assay id) pairs -- one pair per training track --
# and receives random values as stand-in signal.
batch_size = 3
train_split_tracks = data_module.splits["train"]
cell_ids = [data_module.cell2id[data_module.get_track_cell(t)] for t in train_split_tracks]
assay_ids = [data_module.assay2id[data_module.get_track_assay(t)] for t in train_split_tracks]
vals = torch.rand(batch_size, len(train_split_tracks))
# Broadcast the 1-D id vectors along a new leading batch dimension. Use
# batch_size (not a literal) so the id tensors always match the shape of vals.
batch_cell_ids = torch.from_numpy(np.array(cell_ids)).expand(batch_size, -1)
batch_assay_ids = torch.from_numpy(np.array(assay_ids)).expand(batch_size, -1)

In [None]:
# Scatter the per-track values into a dense cell-by-assay layout.
# Presumably the result is (batch, C, A), going by the section heading and the
# 2-D assert applied to obs[0] further down -- the shape display confirms it.
vec2mat = InputExpander(n_cells, n_assays)
obs = vec2mat(
    vals,
    batch_cell_ids,
    batch_assay_ids,
)
obs.shape

In [None]:
def cell_assay_vals(mat, cell):
    """Return one cell's row of ``mat`` as a dict keyed by assay name.

    Args:
        mat: 2-D (unbatched) matrix; rows are indexed by cell id and the last
            axis by assay id, using the id maps on the global ``data_module``.
        cell: Cell name, looked up in ``data_module.cell2id``.

    Returns:
        Dict mapping each assay name to the entry of ``mat`` for that
        (cell, assay) pair, covering every column of ``mat``.

    Raises:
        AssertionError: if ``mat`` is not 2-D.
        KeyError: if ``cell`` is not a known cell name.
    """
    assert mat.ndim == 2  # no batch dim
    row_ix = data_module.cell2id[cell]
    # Invert assay name -> id so that columns can be labelled by name.
    assay_names = {ix: name for name, ix in data_module.assay2id.items()}
    row = mat[row_ix]
    return {assay_names[col]: row[col] for col in range(mat.shape[-1])}

# Sanity check for cell E001 on the first batch element: show the assays that
# carry a non-zero value next to the training tracks that belong to E001.
print(
    {
        assay: val
        for assay, val in cell_assay_vals(obs[0], "E001").items()
        if val != 0
    },
    [
        track
        for track in data_module.splits["train"]
        if data_module.get_track_cell(track) == "E001"
    ],
)

TODO: check that the non-zero assay values printed for E001 actually correspond to the E001 tracks listed from the train split.

# 3. Training on sample data

### 3.1 Setup and load data

In [None]:
# --- Hyperparameters and split selection ---------------------------------
train_splits = ["train"]
val_split = "val"
n_targets = 120
lr = 3e-4
transformation = "arcsinh"
batch_size = 256  # rebinds the toy batch_size used while stepping through the model

# --- Optimiser, training wrapper and logger ------------------------------
optim = torch.optim.Adam(model.parameters(), lr=lr)
edice = eDICE(model, optim, device=torch.device("cpu"), n_targets=n_targets)
logger = loggers.StdOutLogger(log_freq=1)

# --- Data -----------------------------------------------------------------
# Pool the tracks of every training split, then let the data module prepare
# them together with their cell / assay ids.
train_tracks = [track for name in train_splits for track in data_module.splits[name]]
train_tracks, train_cell_ids, train_assay_ids = data_module.prepare_data(train_tracks)

if val_split is None:
    # No validation split: plain slice dataset and no validation loader.
    data = EpigenomeSliceDataset(
        train_tracks, train_cell_ids, train_assay_ids, transform=transformation
    )
    val_loader = None
else:
    val_tracks = data_module.splits[val_split]
    val_tracks, val_cell_ids, val_assay_ids = data_module.prepare_data(val_tracks)
    # A single dataset object carries both the training tracks and the
    # validation tracks; both loaders below wrap this same object.
    data = EpigenomeSliceWithTargets(
        train_tracks,
        train_cell_ids,
        train_assay_ids,
        val_tracks,
        val_cell_ids,
        val_assay_ids,
        transform=transformation,
    )
    val_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)

# The training loader is built the same way whichever dataset was chosen.
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)



### 3.2 Let's visualise the training data

Note: before visualising, we first apply `np.sinh` (the inverse of the `arcsinh`
transformation) to map the data back to the original log10 p-value scale.

#### TODO: annotate plots with track names and slices

In [None]:
# Plot a few randomly positioned windows from randomly chosen training tracks.
# NOTE(review): `train_data` and `val_data` are not assigned in any cell visible
# in this notebook (the setup cell only creates `data`); they must be defined
# before this cell can run on a fresh kernel -- TODO confirm where they come from.
num_track_slices = 5
slice_size = 4000
train_X = train_data.X
num_bins, num_tracks = train_X.shape
print(f"Train data (num bins: {num_bins}, num tracks {num_tracks})")
print(f"Val data (num bins: {val_data.X.shape[0]}, num tracks {val_data.X.shape[1]})")

for i in range(num_track_slices):
    track_ix = np.random.choice(num_tracks)
    # Window starts are drawn from 0 .. num_bins - slice_size - 1 so that the
    # whole window fits inside the track.
    start_ix = np.random.choice(num_bins - slice_size)
    fig, ax = plt.subplots()
    # np.sinh is applied to the stored values before plotting.
    ax.plot(np.arange(slice_size), np.sinh(train_X[start_ix:start_ix+slice_size,track_ix]))
    # Label each figure so it can be read on its own when skimming the notebook.
    ax.set(
        title=f"track index {track_ix}, bins {start_ix}-{start_ix + slice_size}",
        xlabel="bin offset within slice",
        ylabel="sinh(stored signal)",
    )

TODO: add a correlation evaluator.

In [None]:
# Run 20 epochs of training with the wrapper, loaders and logger set up above;
# whatever train() returns is kept in `hist`.
hist = train(
    edice, train_loader,
    epochs=20, logger=logger, batch_size=batch_size,
    validation_loader=val_loader,
)