In [None]:
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import tifffile
from careamics_portfolio import PortfolioManager
from pytorch_lightning import Trainer

from careamics import CAREamicsModule
from careamics.lightning_datamodule import (
    CAREamicsPredictDataModule,
    CAREamicsTrainDataModule,
)
from careamics.lightning_prediction import CAREamicsPredictionLoop

### Import Dataset Portfolio

In [None]:
# Browse the CAREamics dataset portfolio and list the denoising datasets it offers.
portfolio = PortfolioManager()
print(portfolio.denoising)

In [None]:
# Download the N2V SEM dataset under ./data and keep the returned file list.
root_path = Path("./data")
files = portfolio.denoising.N2V_SEM.download(root_path)
print("List of downloaded files: " + str(files))

### Visualize training data

In [None]:
# Load the first downloaded file, which this notebook uses as the training image,
# and display it. The trailing `;` suppresses the matplotlib object repr.
train_image = tifffile.imread(files[0])
print(f"Train image shape: {train_image.shape}")
plt.imshow(train_image, cmap="gray")
plt.title("Training image");

### Visualize validation data

In [None]:
# Load the second downloaded file, which this notebook uses as the validation image,
# and display it. The trailing `;` suppresses the matplotlib object repr.
val_image = tifffile.imread(files[1])
print(f"Validation image shape: {val_image.shape}")
plt.imshow(val_image, cmap="gray")
plt.title("Validation image");

In [None]:
# The datamodules below are given directories (train_path / val_path), so copy
# each downloaded image into its own folder under data/n2v_sem/.
data_path = root_path / "n2v_sem"
train_path = data_path / "train"
val_path = data_path / "val"

train_path.mkdir(parents=True, exist_ok=True)
val_path.mkdir(parents=True, exist_ok=True)

# Use the paths returned by `download` as-is, exactly like the loading cells do.
# Prefixing them with `root_path` is redundant for absolute paths and produces a
# wrong path (data/data/...) for relative ones.
shutil.copy(files[0], train_path / "train_image.tif")
shutil.copy(files[1], val_path / "val_image.tif");

### Initialize the Model

Please take a look at the [documentation](https://careamics.github.io) for the full list of parameters and configuration options.

In [None]:
# Switch between N2V (False) and N2V2 (True). N2V2 changes both the UNet model
# and the Dataset (augmentations), so this single flag should drive both.
use_n2v2: bool = False

In [None]:
# Build the N2V model (UNet architecture, N2V loss, learning rate 1e-3).
# The `n2v2` model parameter follows the `use_n2v2` flag. The training datamodule
# also receives `use_n2v2`, so the architecture and the augmentations stay in
# sync. Previously `n2v2` was hard-coded to False, so flipping the flag only
# changed the dataset.
model = CAREamicsModule(
    algorithm="n2v",
    loss="n2v",
    architecture="UNet",
    model_parameters={"n2v2": use_n2v2},
    optimizer_parameters={"lr": 1e-3},
    lr_scheduler_parameters={"factor": 0.5, "patience": 10},
)

### Initialize the datamodule

In [None]:
# Training datamodule: reads the TIFF images from the train/val folders prepared
# above (2D, axes "YX"), cut into 64x64 patches, 128 patches per batch.
train_data_module = CAREamicsTrainDataModule(
    # where the data lives and how to read it
    train_data=train_path,
    val_data=val_path,
    data_type="tiff",
    axes="YX",
    # patching / batching
    patch_size=(64, 64),
    batch_size=128,
    # N2V2 also alters the dataset augmentations, so forward the same flag
    use_n2v2=use_n2v2,
    # presumably forwarded to the PyTorch DataLoader (0 = load in the main process)
    dataloader_params={"num_workers": 0},
)

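### Run training

The paths to the training and validation data were already passed to the datamodule above; the `Trainer` only needs the training settings (here, the number of epochs and the output directory).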

In [None]:
# Lightning trainer for a short demo run (a single epoch). Per Lightning's docs,
# `default_root_dir` is where logs/checkpoints go when no logger/callback is set.
trainer = Trainer(
    default_root_dir="sem_n2v2_test_struct",
    max_epochs=1,
)

In [None]:
# Train `model` on the batches served by the training datamodule.
trainer.fit(model=model, datamodule=train_data_module)

### Define the prediction datamodule

In [None]:
# Prediction datamodule: denoise the TIFF image(s) in train_path (axes "YX"),
# processed as 256x256 tiles with a batch size of 1.
pred_data_module = CAREamicsPredictDataModule(
    pred_data=train_path,
    data_type="tiff",
    axes="YX",
    tile_size=(256, 256),
    batch_size=1,
    tta_transforms=True,  # presumably test-time augmentation — confirm in CAREamics docs
)

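### Run prediction

The path to the data we want to denoise was already given to the prediction datamodule above. Before predicting, we replace the trainer's default prediction loop with `CAREamicsPredictionLoop`.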

In [None]:
# Install the CAREamics prediction loop on the trainer (kept in `tiled_loop`,
# presumably because it is what handles the tiled prediction).
trainer.predict_loop = tiled_loop = CAREamicsPredictionLoop(trainer)

In [None]:
# Denoise the images served by the prediction datamodule; keep the outputs.
preds = trainer.predict(model=model, datamodule=pred_data_module)

### Visualize the prediction

In [None]:
# Show the denoised prediction next to the noisy input.
# `pred_data` is `train_path`, which this notebook fills only with the training
# image, so prediction index 0 corresponds to `train_image`.
# NOTE(review): assumes each entry of `preds` is one full (stitched) image — confirm.
image_idx = 0
fig, subplot = plt.subplots(1, 2, figsize=(10, 10))

# Index with `image_idx` (previously hard-coded to 0, leaving the variable unused).
subplot[0].imshow(preds[image_idx].squeeze(), cmap="gray")
subplot[0].set_title("Prediction")
subplot[1].imshow(train_image, cmap="gray")
subplot[1].set_title("Initial image");