In [None]:
"""
Add the zebrafish jaw lib to python path
"""

import os, sys

# The library lives two directories above the notebook's working directory;
# resolve that location to an absolute path before making it importable
_lib_root = os.path.abspath(os.path.join(os.path.dirname(""), "..", ".."))
sys.path.append(_lib_root)

In [None]:
# Set to True to draw the sanity-check plots in the cells guarded by `if debug_plots:`
debug_plots = False

In [None]:
"""
The first thing we'll do is read a model from disk - we want to access its config file

"""

from fishjaw.model import model

# Load a previously saved model so that its configuration can be reused by the rest of the notebook
# NOTE(review): the .pkl extension suggests this is a pickle - only load files from a trusted source
jaw_model = model.load_model("attempt_3.pkl")
# The configuration is indexed like a dict further down (e.g. jaw_config["rdsf_dir"])
jaw_config = jaw_model.config

In [None]:
"""
We'll fine tune on some quadrates that Wahab segmented - we'll want to read these from the RDSF and crop to the right region of interest

"""

import re
import pathlib

import tifffile
from tqdm import tqdm

# Directory on the RDSF holding the segmented label tifs
label_dir = jaw_config["rdsf_dir"] / pathlib.Path(
    "1Felix and Rich make models/Training dataset Tiffs/Training set 1"
)

# Remove these ones, since the 3D tifs don't exist
bad_labels = re.compile(r"(351|401|420|441)")

# glob yields paths in a filesystem-dependent order; sort them so that the
# positional train/validation split made further down is reproducible
wahab_labels = sorted(
    label for label in label_dir.glob("*.tif") if not bad_labels.search(label.name)
)

# Read the labels
quadrate_labels = [tifffile.imread(path) for path in tqdm(wahab_labels)]

# Keep only voxels labelled 4 or 5, giving a boolean mask per volume
# (presumably these two values mark the quadrates - TODO confirm)
quadrate_labels = [
    (label_volume == 4) | (label_volume == 5) for label_volume in tqdm(quadrate_labels)
]

In [None]:
"""
Read in the images
"""

from fishjaw.util import files

# Look up the 3D tif for every label path
img_paths = list(map(files.get_3d_tif, wahab_labels))

# Stop straight away if any of the images is missing
for img_path in img_paths:
    assert img_path.exists()

# Load each image volume, showing progress as we go
quadrate_imgs = []
for img_path in tqdm(img_paths):
    quadrate_imgs.append(tifffile.imread(img_path))

In [None]:
"""
Find the centre of the quadrates and crop the labels and images

"""

from scipy.ndimage import center_of_mass

# One integer-valued centre of mass per label mask, used as the crop centre
centroids = []
for label_mask in tqdm(quadrate_labels):
    centre = center_of_mass(label_mask)
    centroids.append(tuple(round(coord) for coord in centre))

In [None]:
from fishjaw.images import transform

window_size = transform.window_size(jaw_config)


def _crop_about_centroids(volumes):
    """
    Crop each volume to `window_size`, centred on the centroid at the same index

    """
    return [
        transform.crop(volume, centroid, window_size, centred=True)
        for volume, centroid in zip(tqdm(volumes), centroids)
    ]


# Labels and images are cropped about the same points so they stay aligned
cropped_labels = _crop_about_centroids(quadrate_labels)
cropped_quadrates = _crop_about_centroids(quadrate_imgs)

In [None]:
"""
Plot the jaws and labels just to check

"""
from fishjaw.visualisation import images_3d

# Optional sanity check: one slice plot per cropped image/label pair
if debug_plots:
    for img_and_label in zip(cropped_quadrates, cropped_labels):
        images_3d.plot_slices(*img_and_label)

In [None]:
"""
Create a dataloader for these

"""

import torchio as tio
from fishjaw.model import data


# This is the size of the training data
jaw_config["batch_size"] = 11

# Because we're in Jupyter
jaw_config["num_workers"] = 0

# Turn them all into tio subjects first
subjects = [
    data.imgs2subject(img, label)
    for img, label in zip(cropped_quadrates, cropped_labels)
]

train_subjects = tio.SubjectsDataset(
    subjects[:11], transform=data._transforms(jaw_config["transforms"])
)
val_subjects = tio.SubjectsDataset(
    [subjects[-1]], transform=data._transforms(jaw_config["transforms"])
)

quadrate_data = data.DataConfig(jaw_config, train_subjects, val_subjects)

In [None]:
"""
Plot the first bit of training data just to visualise it

"""

# Optional sanity check: plot every image/mask pair that the training loader yields
if debug_plots:
    for batch in quadrate_data.train_data:
        images = batch[tio.IMAGE][tio.DATA]
        masks = batch[tio.LABEL][tio.DATA]
        # Images per batch
        for image, mask in zip(images, masks):
            # squeeze() drops singleton dimensions before converting to numpy for plotting
            images_3d.plot_slices(image.squeeze().numpy(), mask.squeeze().numpy())

In [None]:
"""
Train a model from scratch
"""

import torch


def train_model(
    data_config: data.DataConfig, *, input_model: torch.nn.Module = None
) -> tuple[
    tuple[torch.nn.Module, list[list[float]], list[list[float]]], torch.optim.Optimizer
]:
    """
    Train a model on the given data and return it

    The device, model parameters, optimiser, loss function, number of epochs and
    learning rate decay are all read from the module-level `jaw_config`.

    :param data_config: the training and validation data to train on
    :param input_model: an existing model to fine-tune, if specified. If not specified,
                        a new model is created from `jaw_config["model_params"]`.

    :returns: a tuple of (the result of `model.train`, the optimiser), where the first
              element unpacks as (model, training losses, validation losses)

    """
    device = jaw_config["device"]

    # Fine-tune the provided model if there is one, otherwise start from scratch
    if input_model is None:
        net = model.model(jaw_config["model_params"])
    else:
        net = input_model

    net = net.to(device)
    print(f"Model loaded to {device}")

    # The optimiser must be built from the network that will actually be trained
    optimiser = model.optimiser(jaw_config, net)

    # Define loss function
    loss = model.lossfn(jaw_config)

    # Exponentially decay the learning rate using the factor from the config
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimiser, gamma=jaw_config["lr_lambda"]
    )
    train_config = model.TrainingConfig(device, jaw_config["epochs"], scheduler)

    return (
        model.train(net, optimiser, loss, data_config, train_config),
        optimiser,
    )


# Override the number of epochs from the loaded config for this run
jaw_config["epochs"] = 500

# train_model returns the training result and the optimiser as a pair;
# the training result holds the trained network and the two loss histories
training_result, optimiser = train_model(quadrate_data)
net, train_losses, val_losses = training_result

In [None]:
"""
Plot the training and validation losses

"""

from fishjaw.visualisation import training

# Plot the loss histories returned by the training run above
fig = training.plot_losses(train_losses, val_losses)

In [None]:
"""
Plot the trained model's inference on the validation subject
"""

# The validation set holds a single subject, so take the first one it yields
val_subject = next(iter(val_subjects))

# Patch size, activation and batch size all come from the config
inference_kwargs = {
    "patch_size": data.get_patch_size(jaw_config),
    "patch_overlap": (4, 4, 4),
    "activation": model.activation_name(jaw_config),
    "batch_size": jaw_config["batch_size"],
}

fig = images_3d.plot_inference(net, val_subject, **inference_kwargs)