In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree as jt
from jaxtyping import Array
import optax
from optax import OptState, softmax_cross_entropy_with_integer_labels
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm, trange

from hyper_lap.datasets import MediDec, SlicedDataset
from hyper_lap.models import Unet

In [2]:
BATCH_SIZE = 16
EPOCHS = 10

In [3]:
_key = jr.key(0)


def consume():
    global _key
    _key, _consume = jr.split(_key)
    return _consume

In [4]:
dataset = MediDec("/media/LinuxData/datasets/MediDec/Task01_BrainTumour")

sliced_dataset = SlicedDataset(dataset)

train_loader = DataLoader(sliced_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [5]:
model = Unet(8, [1, 2, 4, 6, 6], in_channels=4, out_channels=4, key=consume())

In [6]:
opt = optax.adamw(1e-4)

opt_state = opt.init(eqx.filter(model, eqx.is_array))

In [7]:
@jax.jit
def loss_fn(logits: Array, labels: Array) -> Array:
    # b c h w
    neg_log_prob = softmax_cross_entropy_with_integer_labels(
        jnp.moveaxis(logits, 0, -1), labels
    )

    # sum over spatial dims
    neg_log_likelihood = neg_log_prob.sum()

    return neg_log_likelihood

In [8]:
@eqx.filter_jit
def training_step(model: Unet, images: Array, labels: Array, opt_state: OptState) -> tuple[float, Unet, OptState]:
    dynamic_model, static_model = eqx.partition(model, eqx.is_array)

    def grad_fn(dynamic_model: Unet) -> Array:
        model = eqx.combine(dynamic_model, static_model)

        logits = jax.vmap(model)(images)

        loss = jax.vmap(loss_fn)(logits, labels).sum()

        return loss

    loss, grads = eqx.filter_value_and_grad(grad_fn)(dynamic_model)

    updates, opt_state = opt.update(grads, opt_state, dynamic_model)

    dynamic_model = eqx.apply_updates(dynamic_model, updates)

    model = eqx.combine(dynamic_model, static_model)

    return loss, model, opt_state

In [9]:
for epoch in (pbar := trange(EPOCHS)):
    losses = []

    for batch in tqdm(train_loader, leave=False):
        batch = jt.map(jnp.asarray, batch)

        images = batch["image"]
        labels = batch["label"]

        loss, model, opt_state = training_step(model, images, labels, opt_state)

        losses.append(loss.item())

    mean_loss = jnp.mean(jnp.array(losses))

    pbar.write(f"Loss: {mean_loss:.3}")

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/31 [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


KeyboardInterrupt: 

In [None]:
it = iter(train_loader)

losses = []

In [None]:
images.shape

In [None]:
batch = jt.map(jnp.asarray, next(it))

images = batch["image"]
labels = batch["label"]

loss, model, opt_state = training_step(model, images, labels, opt_state)

losses.append(loss.item())

In [None]:
mean_loss = jnp.mean(jnp.array(losses))

print(f"Loss: {mean_loss:.3}")