In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
from careamics_portfolio import PortfolioManager

# from itkwidgets import compare, view  # "pip install itkwidgets "if necessary
from pytorch_lightning import Trainer

from careamics import CAREamicsModule
from careamics.lightning_datamodule import (
    CAREamicsPredictDataModule,
    CAREamicsTrainDataModule,
)
from careamics.lightning_prediction import CAREamicsPredictionLoop

In [None]:
%reload_ext autoreload

### Import Dataset Portfolio


In [None]:
# Instantiate the dataset portfolio and print its `denoising` collection.
portfolio = PortfolioManager()
denoising_datasets = portfolio.denoising
print(denoising_datasets)

Download the Flywing denoising dataset from the portfolio into a local `data` folder.

In [None]:
# Local folder that receives the downloaded dataset.
root_path = Path("data")

# Fetch the Flywing denoising dataset and report what `download` returned.
flywing_dataset = portfolio.denoising.Flywing
files = flywing_dataset.download(root_path)
print(f"List of downloaded files: {files}")

In [None]:
# Folder holding the Flywing files (presumably where `download` unpacks the
# archive — TODO confirm against the portfolio).
data_path = root_path / "denoising-Flywing.unzip"

# Create the folder, including missing parents; a no-op if it already exists.
data_path.mkdir(exist_ok=True, parents=True)

### Visualize the data

In [None]:
# Load one TIFF volume from the dataset folder.
# The matches are sorted so the same file is selected on every run (directory
# traversal order depends on the filesystem), and an empty result is reported
# with a clear error instead of a bare StopIteration.
tif_files = sorted(data_path.rglob("*.tif"))
if not tif_files:
    raise FileNotFoundError(f"No .tif file found under {data_path}")

train_image = tifffile.imread(tif_files[0])
print(f"Train image shape: {train_image.shape}")

# Display the maximum-intensity projection along the first axis.
plt.imshow(np.max(train_image, axis=0), cmap="magma")

### [Optional] Visualize the data in 3D

In [None]:
# View the 3D training image interactively (optional step).
# itkwidgets is an optional dependency that the import cell does not load, so
# it is imported here; when it is missing the cell reports it instead of
# failing with a NameError.
try:
    from itkwidgets import view
except ImportError:
    viewer = None
    print('itkwidgets is not available; run "pip install itkwidgets" to enable the 3D viewer.')
else:
    viewer = view(train_image)

# Last expression of the cell: renders the widget (nothing is shown for None).
viewer

### Initialize the Model

Create a PyTorch Lightning module.

Please take a look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options.

In [None]:
# Switch between N2V and N2V2.
# N2V2 requires changes to the UNet model and to the Dataset (augmentations),
# which is why this flag is passed both to the model (`model_parameters["n2v2"]`)
# and to the training datamodule (`use_n2v2`) in the cells below.
use_n2v2 = False  # change to True to use N2V2

In [None]:
# Settings forwarded to the UNet; `n2v2` follows the notebook-wide switch.
# (`conv_dims=3` presumably selects 3D convolutions — confirm in the CAREamics docs.)
unet_parameters = dict(n2v2=use_n2v2, conv_dims=3)

# Settings forwarded to the optimizer and to the learning-rate scheduler.
optimizer_settings = dict(lr=1e-3)
scheduler_settings = dict(factor=0.5, patience=10)

# Lightning module bundling algorithm, loss, architecture and the settings above.
model = CAREamicsModule(
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
    model_parameters=unet_parameters,
    optimizer_parameters=optimizer_settings,
    lr_scheduler_parameters=scheduler_settings,
)

### Initialize the datamodule

The data module can take a `Path` or `str` to a folder or file, or a `np.ndarray`.

For custom types, you need to pass a read function and an extension_filter.

In [None]:
# Training datamodule fed directly with the in-memory image loaded earlier.
train_data_module = CAREamicsTrainDataModule(
    train_data=train_image,
    data_type="array",  # to use np.ndarray, set data_type to "array"
    patch_size=(32, 64, 64),  # three entries, matching the three axes in `axes`
    axes="ZYX",
    batch_size=32,
    dataloader_params={"num_workers": 0},
    use_n2v2=use_n2v2,  # same switch as the model's `n2v2` parameter
    struct_n2v_axis="none",  # choice between "horizontal", "vertical", or "none" (no structN2V)
    struct_n2v_span=7,
)

### Run training 

We create a PyTorch Lightning `Trainer` and fit the model using the training datamodule defined above, which already holds the training data.

In [None]:
# Lightning trainer limited to a single epoch, with "bsd_test" as its root folder.
# NOTE(review): the folder is named "bsd_test" although this notebook works on
# the Flywing dataset — presumably a leftover name; confirm before renaming.
trainer = Trainer(max_epochs=1, default_root_dir="bsd_test")

In [None]:
# Train the model on the data served by the training datamodule.
trainer.fit(model, datamodule=train_data_module)

### Define a prediction datamodule

In [None]:
# Prediction datamodule over a crop of the training image: all of the first
# axis, and at most the first 128 entries of each of the last two axes.
# Tile size and overlap are given per axis, in the order declared by `axes`.
pred_data_module = CAREamicsPredictDataModule(
    pred_data=train_image[:, :128, :128],
    data_type="array",
    tile_size=(32, 64, 64),
    tile_overlap=(16, 48, 48),
    axes="ZYX",
    batch_size=1,
    tta_transforms=True,
    dataloader_params={"num_workers": 0},
)

### Run prediction

First, we want to use the CAREamics prediction loop, which allows tiling:

In [None]:
# Replace the trainer's prediction loop with the CAREamics one so that
# `trainer.predict` goes through the tiling-aware loop.
tiled_loop = CAREamicsPredictionLoop(trainer)
trainer.predict_loop = tiled_loop

Then, we predict using the datamodule.

In [None]:
# Run inference with the trained model on the prediction datamodule.
# NOTE(review): later cells call `preds.shape` and `preds.squeeze()`, i.e. they
# expect an array-like result rather than a list — confirm what the custom
# prediction loop returns.
preds = trainer.predict(model, datamodule=pred_data_module)

### Visualize predictions


In [None]:
# Report the shape of the prediction; the label previously said "Train image",
# which was misleading because the value printed is the prediction's shape.
print(f"Prediction shape: {preds.shape}")

# Drop singleton dimensions, then display the maximum-intensity projection
# along the first remaining axis.
plt.imshow(np.max(preds.squeeze(), axis=0), cmap="magma")

### [Optional] Visualize predictions in 3D

In [None]:
# Side-by-side 3D comparison of the input and its prediction (optional step).
# itkwidgets is an optional dependency that the import cell does not load, so
# it is imported here; when it is missing the cell reports it instead of
# failing with a NameError.
try:
    from itkwidgets import compare
except ImportError:
    comparison = None
    print('itkwidgets is not available; run "pip install itkwidgets" to enable the 3D comparison.')
else:
    # Compare against the same crop that was given to the prediction datamodule
    # (keep this slice in sync with `pred_data`), not the full training image.
    comparison = compare(train_image[:, :128, :128], preds.squeeze())

# Last expression of the cell: renders the widget (nothing is shown for None).
comparison