# Libraries

In [None]:
import yaml
from pathlib import Path

from src.utils import  get_data_info, train_val_test_split, focal_loss_with_logits, masked_smoothl1
from src.classes import NeuronPatchDataset, SwimUNETR_Heatmap_offsets

import tifffile
import os
from torch.utils.data import DataLoader

# Main

### read config

In [None]:
# Load the notebook settings from the YAML file next to this notebook.
config_yaml_path = "config.yaml"
config_path = Path(config_yaml_path)

# safe_load only builds plain Python objects, so the file cannot run code.
with config_path.open("r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Input / output directories.
dir_cfg = cfg["dir"]
data_path = dir_cfg["input"]
output_dir = dir_cfg["output"]

# Suffixes later handed to get_data_info to locate the image and marker files.
input_cfg = cfg["input_data"]
tiff_suffixes = input_cfg["tiff_suffixes"]
marker_suffix = input_cfg["marker_suffix"]

# Settings forwarded unchanged to NeuronPatchDataset.
ground_truth_config = cfg["ground_truth_config"]
pre_processing_config = cfg["pre_processing_config"]

# Echo everything that was read, one value per line, for a quick sanity check.
print(data_path, output_dir, sep="\n")
print(tiff_suffixes, marker_suffix, sep="\n")
print(ground_truth_config, pre_processing_config, sep="\n")

### create dataframe with all the data info

In [None]:
# Collect the per-sample data info, using the paths and suffixes read from the config.
df = get_data_info(data_path=data_path, tiff_suffixes=tiff_suffixes, marker_suffix=marker_suffix)

In [None]:
# Show the collected data-info table as the cell output.
df

In [None]:
# from collections import Counter

# Counter(df["n_neurons"])

In [None]:
# df[df["n_neurons"] == 748]

### split the data into train, val and test

In [None]:
# Partition the data into train / validation / test subsets.
# n_neurons_bins presumably controls binning on the neuron count for the
# split - TODO confirm against src.utils.train_val_test_split.
splits = train_val_test_split(df=df, n_neurons_bins=5)
train_df, eval_df, test_df = splits

# Report the size of each subset.
print(train_df.shape, eval_df.shape, test_df.shape)

In [None]:
# train_df[train_df["tiff_img_name"] == "SST_11_20.tif"]

# train_df.iloc[34]

### initialise the NeuronPatchDataset objects

In [None]:
# Both datasets share the same ground-truth and pre-processing settings;
# only the underlying dataframe differs.
dataset_kwargs = {
    "ground_truth_config": ground_truth_config,
    "pre_processing_config": pre_processing_config,
}

train_ds = NeuronPatchDataset(df=train_df, **dataset_kwargs)
eval_ds = NeuronPatchDataset(df=eval_df, **dataset_kwargs)

In [None]:
# index = 34
# data = train_ds.__getitem__(index=index)

In [None]:
# for key, gt in data.items():
#     tifffile.imwrite(
#         os.path.join(output_dir,key+f"_{index}_.tiff"),
#         data=gt.numpy()
#     )

### initialise the dataloader objects

In [None]:
# Settings common to both loaders: one sample per batch, four worker processes.
loader_kwargs = {"batch_size": 1, "num_workers": 4}

# Only the training loader shuffles; the evaluation loader keeps dataset order.
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
eval_dl = DataLoader(eval_ds, **loader_kwargs)


### initialise the model

In [None]:
# Build the network with its default constructor arguments; the training loop
# expects it to return a (heatmap_logits, offsets_pred) pair.
model = SwimUNETR_Heatmap_offsets()

### initialise optimiser and accelerator

In [None]:
from accelerate import Accelerator
from torch.optim import AdamW

# AdamW over every model parameter, with matching learning rate and weight decay.
optimiser = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Let Accelerate wrap the loaders, model and optimiser, with device placement
# enabled for each of the four objects.
accelerator = Accelerator()
prepared = accelerator.prepare(
    train_dl,
    eval_dl,
    model,
    optimiser,
    device_placement=[True] * 4,
)
train_dl, eval_dl, model, optimiser = prepared

### initialise the scheduler

In [None]:
from transformers import get_scheduler

# Total schedule length: one scheduler step per training batch, for every epoch.
num_epochs = 200
steps_per_epoch = len(train_dl)
num_training_steps = num_epochs * steps_per_epoch

# Reserve the first 10% of the steps for warm-up.
warmup_fraction = 0.1
num_warmup_steps = int(warmup_fraction * num_training_steps)

# Cosine learning-rate schedule driving the optimiser created above.
lr_schedular = get_scheduler(
    "cosine",
    optimizer=optimiser,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

### Training the model

In [None]:
import mlflow
import mlflow.pytorch
from tqdm.notebook import tqdm
import torch


# One tick per optimiser step across all epochs; train() advances it once per batch.
progress_bar = tqdm(range(num_training_steps))

def train(
    tracking_uri="http://127.0.0.1:8080",
    experiment_name="/NeuroCloud",
    offsets_loss_weight=1.2,
    max_grad_norm=1.0,
):
    """Train the model and log the mean training loss of every epoch to MLflow.

    Uses objects created earlier in the notebook: ``model``, ``optimiser``,
    ``lr_schedular``, ``train_dl``, ``accelerator``, ``num_epochs`` and
    ``progress_bar``.

    Args:
        tracking_uri: URI of the MLflow tracking server.
        experiment_name: MLflow experiment the run is recorded under.
        offsets_loss_weight: Factor applied to the offsets loss before it is
            added to the heatmap loss.
        max_grad_norm: Maximum gradient norm used for clipping.

    Raises:
        ValueError: If ``train_dl`` yields no batches. Without this check the
            epoch average would fail with a ``ZeroDivisionError`` after an
            MLflow run had already been opened.
    """
    if len(train_dl) == 0:
        raise ValueError("train_dl is empty; there is nothing to train on")

    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

    # NOTE(review): eval_dl is prepared in the notebook but no validation pass
    # is run here - confirm whether a per-epoch evaluation is intended.
    with mlflow.start_run():
        for epoch in range(num_epochs):
            # Ensure model is in training mode
            model.train()
            epoch_losses = []

            for batch in train_dl:
                # Shape notes are carried over from the original author.
                img = batch["image"]  # (B,1,Z,Y,X)
                heatmap = batch["heatmap"]  # (B,1,Z,Y,X)
                offsets = batch["offsets"]  # (B,3,Z,Y,X)
                offset_mask = batch["offset_mask"]  # (B,1,Z,Y,X)

                heatmap_logits, offsets_pred = model(img)

                loss_hmap = focal_loss_with_logits(
                    logits=heatmap_logits,
                    target=heatmap,
                )
                loss_offsets = masked_smoothl1(
                    pred=offsets_pred,
                    target=offsets,
                    mask=offset_mask,
                )
                loss = loss_hmap + offsets_loss_weight * loss_offsets

                accelerator.backward(loss)
                # Clip through the accelerator so gradients are unscaled first
                # when mixed precision is enabled.
                accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimiser.step()
                lr_schedular.step()
                optimiser.zero_grad()

                # Count the step only once it has actually completed.
                progress_bar.update(1)

                # .item() already returns a plain Python float, so no
                # no_grad()/detach() wrapper is needed.
                epoch_losses.append(loss.item())

            # Mean training loss of the epoch, logged against the epoch index.
            train_avg_loss = sum(epoch_losses) / len(epoch_losses)
            mlflow.log_metric(
                key="loss",
                value=train_avg_loss,
                step=epoch,
            )

In [None]:
from accelerate import notebook_launcher

# Run train() through Accelerate's notebook launcher with a single process.
# NOTE(review): the Accelerator is created at notebook level, outside train();
# Accelerate's notebook guidance appears to expect it inside the launched
# function when num_processes > 1 - confirm before scaling this up.
notebook_launcher(train, num_processes=1)