In [None]:
"""
Add the zebrafish jaw lib to python path
"""

import os, sys

# Make the directory two levels above the current working directory importable
sys.path.append(os.path.abspath(os.path.join(os.pardir, os.pardir)))

In [None]:
# Switch for the optional sanity-check figures: the cells guarded by
# `if debug_plots:` only draw anything when this is True
debug_plots = False

In [None]:
"""
The first thing we'll do is read a model from disk - we want to access its config file

"""

from fishjaw.model import model

# Name of the saved model. Every load in this notebook goes through this one
# variable so that they cannot drift apart.
model_name = "attempt_3.pkl"
jaw_model = model.load_model(model_name)

# The configuration that is stored with the model
jaw_config = jaw_model.config

In [None]:
"""
We'll fine tune on some quadrates that Wahab segmented - we'll want to read these from the RDSF and crop to the right region of interest

"""

import re
import pathlib

import tifffile
from tqdm import tqdm

# Sort the paths: glob() yields files in arbitrary order, and later cells split
# the data into train/validation/test by list position, so the order must be
# reproducible between runs and machines
wahab_labels = sorted(
    (
        jaw_config["rdsf_dir"]
        / pathlib.Path("1Felix and Rich make models/Training dataset Tiffs/Training set 1")
    ).glob("*.tif")
)

# Remove these ones, since the 3D tifs don't exist
bad_labels = re.compile(r"(351|401|420|441)")
wahab_labels = [label for label in wahab_labels if not bad_labels.search(label.name)]

# Read the label volumes, then reduce each one to a boolean mask of the voxels
# labelled 4 or 5 (presumably the two quadrates - TODO confirm the label scheme)
quadrate_labels = [tifffile.imread(path) for path in tqdm(wahab_labels)]
quadrate_labels = [(label == 4) | (label == 5) for label in tqdm(quadrate_labels)]

In [None]:
"""
Read in the images
"""

from fishjaw.util import files

img_paths = [files.get_3d_tif(label_path) for label_path in wahab_labels]

# Fail before reading anything, and name every scan that is missing
missing_imgs = [p for p in img_paths if not p.exists()]
assert not missing_imgs, f"Missing 3D tifs: {[str(p) for p in missing_imgs]}"

quadrate_imgs = [tifffile.imread(path) for path in tqdm(img_paths)]

In [None]:
"""
Find the centre of the quadrates and crop the labels and images

"""

from scipy.ndimage import center_of_mass

# Centre of mass of each boolean mask, with every coordinate rounded to an integer
centroids = [
    tuple(map(round, center_of_mass(mask)))  # type: ignore
    for mask in tqdm(quadrate_labels)
]

In [None]:
from fishjaw.images import transform

# The same window size, taken from the model config, is used for every crop
window_size = transform.window_size(jaw_config)

# Crop each mask and each scan around that scan's centroid
cropped_labels = [
    transform.crop(mask, centre, window_size, centred=True)
    for mask, centre in zip(tqdm(quadrate_labels), centroids)
]
cropped_quadrates = [
    transform.crop(scan, centre, window_size, centred=True)
    for scan, centre in zip(tqdm(quadrate_imgs), centroids)
]

In [None]:
"""
Plot the jaws and labels just to check

"""
from fishjaw.visualisation import images_3d

# Only drawn when the debug switch is on: one figure per cropped scan/mask pair
if debug_plots:
    for cropped_img, cropped_mask in zip(cropped_quadrates, cropped_labels):
        images_3d.plot_slices(cropped_img, cropped_mask)

In [None]:
"""
Create a dataloader for these

"""

import torchio as tio
from fishjaw.model import data


# Number of subjects used for training; one further subject is held out for
# validation and one for testing
n_train = 10
n_holdout = 2

# Without this check the slices below would silently overlap, putting the
# validation or test subject in the training set
assert len(cropped_quadrates) >= n_train + n_holdout, (
    f"Need at least {n_train + n_holdout} scans for disjoint train/val/test sets, "
    f"got {len(cropped_quadrates)}"
)

# This is the size of the training data
jaw_config["batch_size"] = n_train

# Because we're in Jupyter
jaw_config["num_workers"] = 0

# Turn them all into tio subjects first
subjects = [
    data.imgs2subject(img, label)
    for img, label in zip(cropped_quadrates, cropped_labels)
]

# First n_train subjects for training, second-to-last for validation, last for testing
train_subjects = tio.SubjectsDataset(
    subjects[:n_train], transform=data._transforms(jaw_config["transforms"])
)
val_subjects = tio.SubjectsDataset(
    [subjects[-2]], transform=data._transforms(jaw_config["transforms"])
)
test_subject = subjects[-1]

quadrate_data = data.DataConfig(jaw_config, train_subjects, val_subjects)

In [None]:
"""
Plot the first bit of training data just to visualise it

"""

# Only drawn when the debug switch is on
if debug_plots:
    for batch in quadrate_data.train_data:
        images = batch[tio.IMAGE][tio.DATA]
        masks = batch[tio.LABEL][tio.DATA]
        # One figure per image in the batch
        for image, mask in zip(images, masks):
            images_3d.plot_slices(image.squeeze().numpy(), mask.squeeze().numpy())

In [None]:
"""
Plot the ground truth for the test data
"""
# Assigned to `_` so that the return value is not echoed as the cell's output
_ = images_3d.plot_subject(test_subject)

In [None]:
"""
Train a model from scratch
"""

import torch


def train_new_model(
    data_config: data.DataConfig,
) -> tuple[
    tuple[torch.nn.Module, list[list[float]], list[list[float]]], torch.optim.Optimizer
]:
    """
    Build an untrained network from the notebook-level `jaw_config` and train it

    :param data_config: the training and validation data to train on

    :returns: a pair; the first element is whatever `model.train` returns
              (the network, the training losses and the validation losses),
              the second is the optimiser that was used

    """
    device = jaw_config["device"]

    # Fresh network, moved onto the configured device
    net = model.model(jaw_config["model_params"]).to(device)
    print(f"Model loaded to {jaw_config['device']}")

    # Optimiser and loss function both come from the config
    optimiser = model.optimiser(jaw_config, net)
    loss = model.lossfn(jaw_config)

    # Exponential learning-rate schedule whose gamma is the config's `lr_lambda`
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimiser, gamma=jaw_config["lr_lambda"]
    )
    train_config = model.TrainingConfig(device, jaw_config["epochs"], scheduler)

    training_result = model.train(net, optimiser, loss, data_config, train_config)
    return training_result, optimiser

In [None]:
# Override the epoch count in the loaded config: train_new_model reads
# jaw_config["epochs"] when it builds its training configuration
jaw_config["epochs"] = 450
(net, train_losses, val_losses), optimiser = train_new_model(quadrate_data)

In [None]:
"""
Plot the training and validation losses

"""

from fishjaw.visualisation import training

# Loss curves for the model that was trained from scratch above
fig = training.plot_losses(train_losses, val_losses)

In [None]:
"""
Plot the test subject
"""

# Run the from-scratch network on the held-out test subject; the patch size,
# activation and batch size all come from the model config
fig = images_3d.plot_inference(
    net,
    test_subject,
    patch_size=data.get_patch_size(jaw_config),
    patch_overlap=(4, 4, 4),
    activation=model.activation_name(jaw_config),
    batch_size=jaw_config["batch_size"],
)

In [None]:
"""
Now fine-tune the trained model on the testing data. It should perform better.
"""

from monai.networks.nets.attentionunet import AttentionBlock


def _unfreeze_attention_psi(
    net: torch.nn.Module, block_indices: set[int], verbose: bool
) -> None:
    """
    Freeze every parameter of `net`, then re-enable gradients for the `psi` child
    of each AttentionBlock whose position (counting AttentionBlocks in
    `net.modules()` order, from 0) is in `block_indices`

    """
    for param in net.parameters():
        param.requires_grad = False

    attention_blocks = (m for m in net.modules() if isinstance(m, AttentionBlock))
    for block_index, block in enumerate(attention_blocks):
        for name, submodule in block.named_children():
            if name != "psi":
                continue

            if block_index in block_indices:
                if verbose:
                    print("unfreezing")
                for param in submodule.parameters():
                    param.requires_grad = True
            elif verbose:
                print("not unfreezing")


def fine_tune_model(
    data_config: data.DataConfig,
    train_layers: str = "1,2",
    lr_multiplier: float = 0.1,  # Lower learning rate for fine-tuning
    epochs_frozen: int = 150,  # Train with frozen layers
    epochs_unfrozen: int = 50,  # Additional training with all layers
    verbose: bool = True,
) -> tuple[torch.nn.Module, list[list[float]], list[list[float]]]:
    """
    Fine-tune the saved model (`model_name`) on the provided data

    The model is re-loaded from disk on every call, so repeated calls always
    start from the same weights.

    :param data_config: the training and validation data to fine-tune on
    :param train_layers: which parameters to train, either:
        - "all": every parameter is trainable
        - a comma-separated list of integers: indices of the attention blocks
          whose `psi` sub-module is trained; everything else is frozen
    :param lr_multiplier: factor applied to the config's learning rate
    :param epochs_frozen: number of epochs to train for
    :param epochs_unfrozen: NOTE(review): accepted but currently unused - no second,
        fully unfrozen training phase is run
    :param verbose: whether to print which parameters are trainable

    :returns: whatever `model.train` returns (the network, the training losses
              and the validation losses)
    :raises ValueError: if `train_layers` cannot be parsed, or if it selects
                        no trainable parameters

    """
    train_all = train_layers == "all"

    # Parse into a separate name rather than rebinding the `str` parameter to a list
    layer_indices = (
        set() if train_all else {int(x) for x in train_layers.split(",")}
    )

    # Load the model from disk fresh so that we don't overwrite anything in memory
    new_model = model.load_model(model_name)
    net = new_model.load_model(set_eval=False)
    net.to(jaw_config["device"])

    if not train_all:
        _unfreeze_attention_psi(net, layer_indices, verbose)

        if verbose:
            for name, param in net.named_parameters():
                if param.requires_grad:
                    print(f"Trainable: {name}")

    # Fail with a clear message instead of the optimiser's generic
    # "empty parameter list" error
    trainable_params = [p for p in net.parameters() if p.requires_grad]
    if not trainable_params:
        raise ValueError(
            f"train_layers={train_layers!r} selects no trainable parameters"
        )

    # Create a new optimiser that only updates the unfrozen layers
    # Get the right optimiser from the config
    # and scale the learning rate by lr_multiplier
    optimiser = getattr(torch.optim, jaw_config["optimiser"])(
        trainable_params,
        lr=jaw_config["learning_rate"] * lr_multiplier,
    )

    # Create a loss function
    loss = model.lossfn(jaw_config)

    train_config = model.TrainingConfig(
        jaw_config["device"],
        epochs_frozen,
        torch.optim.lr_scheduler.ExponentialLR(
            optimiser, gamma=jaw_config["lr_lambda"]
        ),
    )

    print(f"Training with selective freezing for {epochs_frozen} epochs...")
    return model.train(net, optimiser, loss, data_config, train_config)

In [None]:
"""
First demonstrate what the fine-tuning looks like if we basically don't train it at all
"""

# train_layers="all" leaves every parameter trainable.
# NOTE(review): only `epochs_frozen` sets the number of epochs here;
# fine_tune_model never reads `epochs_unfrozen`.
fine_tuned_model, fine_tune_train_losses, fine_tune_val_losses = fine_tune_model(
    quadrate_data,
    train_layers="all",
    lr_multiplier=2.0,
    epochs_frozen=2,
    epochs_unfrozen=1,
)

In [None]:
# Prediction of the briefly fine-tuned network on the held-out test subject
_ = images_3d.plot_inference(
    fine_tuned_model,
    test_subject,
    patch_size=data.get_patch_size(jaw_config),
    patch_overlap=(4, 4, 4),
    activation=model.activation_name(jaw_config),
    batch_size=jaw_config["batch_size"],
)

In [None]:
"""
Now train it more properly
"""

# Same settings as the short run but with 100 epochs; this overwrites the
# results of the short run.
# NOTE(review): `epochs_unfrozen` is not read by fine_tune_model, so no
# additional 50-epoch phase takes place.
fine_tuned_model, fine_tune_train_losses, fine_tune_val_losses = fine_tune_model(
    quadrate_data,
    train_layers="all",
    lr_multiplier=2.0,
    epochs_frozen=100,
    epochs_unfrozen=50,
)

In [None]:
from fishjaw.visualisation import training

# Loss curves of the longer fine-tuning run, followed by its prediction on the
# held-out test subject
training.plot_losses(fine_tune_train_losses, fine_tune_val_losses)
_ = images_3d.plot_inference(
    fine_tuned_model,
    test_subject,
    patch_size=data.get_patch_size(jaw_config),
    patch_overlap=(4, 4, 4),
    activation=model.activation_name(jaw_config),
    batch_size=jaw_config["batch_size"],
)

In [None]:
def get_weight_deltas(
    model_before, model_after
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """
    Get the difference in weights between two models, and the original weights
    """
    deltas = {}
    orig_weights = {}
    for (name1, param1), (name2, param2) in zip(
        model_before.named_parameters(), model_after.named_parameters()
    ):
        assert name1 == name2, f"Names do not match: {name1} != {name2}"

        if param1.requires_grad and "num_batches_tracked" not in name1:
            delta = param2.data - param1.data
            deltas[name1] = delta

            orig_weights[name1] = param1.data
    return deltas, orig_weights

In [None]:
"""
Plot histograms
"""
# The "before" network comes from the model loaded at the top of the notebook,
# moved to the configured device; the "after" network is the fine-tuned one
deltas, orig_weights = get_weight_deltas(jaw_model.load_model().to(jaw_config["device"]), fine_tuned_model)

In [None]:
import textwrap
import numpy as np
import matplotlib.pyplot as plt

# Shared bin edges so that all the histograms are directly comparable.
# The global range is taken per tensor, which avoids concatenating every
# delta into one large tensor (twice) just to find its min and max.
bins = np.linspace(
    min(d.min().item() for d in deltas.values()),
    max(d.max().item() for d in deltas.values()),
    100,
)

# Round the row count up so that no parameter is left without a panel;
# squeeze=False keeps `axes` two-dimensional even when there is a single row
n_cols = 4
n_rows = -(-len(deltas) // n_cols)
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(8, n_rows * 2), sharey=True, squeeze=False
)
flat_axes = axes.flatten()

for axis, (name, delta) in zip(flat_axes, tqdm(deltas.items())):
    axis.hist(delta.flatten().cpu().numpy(), bins=bins, density=True)
    axis.set_title("\n".join(textwrap.wrap(name, 30)), fontsize=8)
    # axis.set_yscale("log")

# Hide any left-over panels in the last row
for axis in flat_axes[len(deltas) :]:
    axis.set_visible(False)

fig.tight_layout()

In [None]:
"""
We want a way to isolate each type of weight (conv, psi, merge, etc)
"""

import re

# Regex for each kind of parameter, matched against the parameter's dotted name.
# The dots between name components are escaped so that they only match a
# literal "."; the leading ".*" allows any prefix.
weight_type_regex = {
    "down_conv_0_weight": r".*conv\.0\.conv\.weight",
    "down_conv_0_bias": r".*conv\.0\.conv\.bias",
    "down_conv_1_weight": r".*conv\.1\.conv\.weight",
    "down_conv_1_bias": r".*conv\.1\.conv\.bias",
    "down_adn_0_weight": r".*conv\.0\.adn\.N\.weight",
    "down_adn_0_bias": r".*conv\.0\.adn\.N\.bias",
    "down_adn_1_weight": r".*conv\.1\.adn\.N\.weight",
    "down_adn_1_bias": r".*conv\.1\.adn\.N\.bias",
    "attention_wg_0_weight": r".*attention\.W_g\.0\.conv\.weight",
    "attention_wg_0_bias": r".*attention\.W_g\.0\.conv\.bias",
    "attention_wg_1_weight": r".*attention\.W_g\.1\.weight",
    "attention_wg_1_bias": r".*attention\.W_g\.1\.bias",
    "attention_wx_0_weight": r".*attention\.W_x\.0\.conv\.weight",
    "attention_wx_0_bias": r".*attention\.W_x\.0\.conv\.bias",
    "attention_wx_1_weight": r".*attention\.W_x\.1\.weight",
    "attention_wx_1_bias": r".*attention\.W_x\.1\.bias",
    "attention_psi_0_weight": r".*attention\.psi\.0\.conv\.weight",
    "attention_psi_0_bias": r".*attention\.psi\.0\.conv\.bias",
    "attention_psi_1_weight": r".*attention\.psi\.1\.weight",
    "attention_psi_1_bias": r".*attention\.psi\.1\.bias",
    "upconv_weight": r".*upconv\.up\.conv\.weight",
    "upconv_bias": r".*upconv\.up\.conv\.bias",
    "upconv_adn_weight": r".*upconv\.up\.adn\.N\.weight",
    "upconv_adn_bias": r".*upconv\.up\.adn\.N\.bias",
    "merge_weight": r".*merge\.conv\.weight",
    "merge_bias": r".*merge\.conv\.bias",
    # NOTE(review): this one targets adn.A rather than adn.N like the others - confirm it is intended
    "merge_adn": r".*merge\.adn\.A\.weight",
}

In [None]:
"""
Make a u-net style diagram
"""

from matplotlib.patches import FancyArrowPatch


def _draw_arrows(fig: plt.Figure, axes: dict) -> None:
    """
    Add curved arrows that link each encoder panel to its decoder panel, going
    through the small panel that sits between them

    The arrow end points are computed from the current positions of the axes
    and are given in figure coordinates.

    :param fig: the figure the arrows are added to
    :param axes: mapping from mosaic label to axes; must contain the labels
                 A-E, a-e and G-K

    """
    # (encoder, middle, decoder) labels for each level, from the top down:
    # A-a-K, B-b-J, C-c-I, D-d-H, E-e-G
    levels = list(zip("ABCDE", "abcde", "KJIHG"))

    arrow_style = {
        "connectionstyle": "arc3,rad=-0.3",
        "arrowstyle": "simple,head_length=5,head_width=5",
        "linewidth": 0.5,
        "transform": fig.transFigure,
        "color": "k",
    }

    for encoder_label, middle_label, decoder_label in levels:
        enc_box = axes[encoder_label].get_position()
        mid_box = axes[middle_label].get_position()
        dec_box = axes[decoder_label].get_position()

        # The large panels are joined 3/4 of the way up, the middle panel at half height
        enc_anchor = (enc_box.x1, enc_box.y0 + 0.75 * enc_box.height)
        mid_entry = (mid_box.x0, mid_box.y0 + 0.5 * mid_box.height)
        mid_exit = (mid_box.x1, mid_box.y0 + 0.5 * mid_box.height)
        dec_anchor = (dec_box.x0, dec_box.y0 + 0.75 * dec_box.height)

        # Right edge of the encoder -> left edge of the middle panel,
        # then right edge of the middle panel -> left edge of the decoder
        for tail, head in ((enc_anchor, mid_entry), (mid_exit, dec_anchor)):
            fig.patches.extend([FancyArrowPatch(tail, head, **arrow_style)])


def unet_hists(pattern: str) -> plt.Figure:
    """
    Build the U-shaped grid of panels, with connecting arrows, intended to hold
    one histogram per layer for the weights matching `pattern`

    NOTE(review): `pattern` is accepted but not used yet, so the panels are
    returned empty.

    :param pattern: regex selecting the weights to plot (currently ignored)
    :returns: the figure

    """
    # Each letter names one panel; "." leaves the cell empty
    fig, panels = plt.subplot_mosaic(
        """
        AA.........aKK
        AA..........KK
        .BB.......bJJ.
        .BB........JJ.
        ..CC.....cII..
        ..CC......II..
        ...DD...dHH...
        ...DD....HH...
        ....EE.eGG....
        ....EE..GG....
        ......FF......
        ......FF......
        """,
        figsize=(10, 10),
    )

    _draw_arrows(fig, panels)

    # Strip the tick marks from every panel
    for panel in panels.values():
        panel.set(xticks=[], yticks=[])

    return fig

In [None]:
# NOTE(review): unet_hists does not use its `pattern` argument yet, so this
# only draws the empty U-shaped layout with its arrows
_ = unet_hists(weight_type_regex["down_conv_1_weight"])