In [None]:
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile
from careamics_portfolio import PortfolioManager

from careamics import CAREamist

### Import Dataset Portfolio

In [None]:
# List the denoising datasets offered by the CAREamics portfolio
portfolio = PortfolioManager()
denoising_datasets = portfolio.denoising
print(denoising_datasets)

In [None]:
# Download the Noise2Noise SEM dataset under ./data and report what was fetched
root_path = Path("./data")
n2n_sem_dataset = portfolio.denoising.N2N_SEM
files = n2n_sem_dataset.download(root_path)
print(f"List of downloaded files: {files}")

### Visualize training data

In [None]:
# Load the training stack; each slice along axis 0 is drawn as one panel
train_image = tifffile.imread(files[0])
print(f"Train image shape: {train_image.shape}")

# Display images on the smallest square grid that fits every slice.
# squeeze=False keeps `ax` a 2D array even for a 1x1 grid, so `.flat` always exists.
side = int(np.ceil(np.sqrt(train_image.shape[0])))
fig, ax = plt.subplots(side, side, figsize=(15, 15), squeeze=False)

# Turn every panel's axes off first so unused trailing panels stay blank.
for panel in ax.flat:
    panel.axis("off")
for i in range(train_image.shape[0]):
    ax.flat[i].imshow(train_image[i], cmap="gray")

### Visualize validation data

In [None]:
# NOTE(review): this reads files[2], but the "Set paths" cell copies files[1] as the
# validation/test image — confirm which downloaded file is the intended validation stack.
val_image = tifffile.imread(files[2])
print(f"Validation image shape: {val_image.shape}")

# Display images on the smallest square grid that fits every slice.
# squeeze=False keeps `ax` a 2D array even for a 1x1 grid, so `.flat` always exists.
side = int(np.ceil(np.sqrt(val_image.shape[0])))
fig, ax = plt.subplots(side, side, figsize=(15, 15), squeeze=False)

# Turn every panel's axes off first so unused trailing panels stay blank.
for panel in ax.flat:
    panel.axis("off")
for i in range(val_image.shape[0]):
    ax.flat[i].imshow(val_image[i], cmap="gray")

In [None]:
# Set paths: lay out the downloaded TIFFs as ./data/n2n_sem/{train,val}
data_path = root_path / "n2n_sem"
train_path = data_path / "train"
test_path = data_path / "val"

train_path.mkdir(parents=True, exist_ok=True)
test_path.mkdir(parents=True, exist_ok=True)

# Entries of `files` are already usable as paths (the notebook reads files[0] directly
# with tifffile.imread), so copy from them as-is; prefixing root_path would duplicate
# the directory whenever the returned paths are relative.
# NOTE(review): the validation visualization reads files[2], not files[1] — confirm.
shutil.copy(files[0], train_path / "train_image.tif")
shutil.copy(files[1], test_path / "test_image.tif")

### Initialize the model

Create a PyTorch Lightning module by instantiating `CAREamist` from the `n2n_2D_SEM.yml` configuration file.

Please take a look at the [documentation](https://careamics.github.io) to see the full list of parameters and configuration options.

In [None]:
# Build the CAREamics engine from the YAML configuration shipped next to this notebook
config_file = "n2n_2D_SEM.yml"
engine = CAREamist(source=config_file)

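### Run training

We pass the training and validation data directly as in-memory arrays. Frames 0 and 2 of the training stack form the training source/target pair, and frames 1 and 3 form the validation pair.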

In [None]:
# Noise2Noise source/target pairs taken from frames of the training stack:
# frames 0 -> 2 for training, frames 1 -> 3 for validation.
# NOTE(review): the validation pair comes from `train_image`, not `val_image` — confirm intended.
engine.train(
    train_source=train_image[0],
    train_target=train_image[2],
    val_source=train_image[1],
    val_target=train_image[3],
)

### Run prediction


In [None]:
# Denoise the first validation frame, predicting on 256x256 tiles
tile_shape = (256, 256)
preds = engine.predict(source=val_image[0], tile_size=tile_shape)

### Visualize the prediction

In [None]:
# Compare the prediction with the frame it was computed from. val_image[0] is the
# `source` passed to engine.predict, i.e. the network input — not a clean ground truth.
fig, ax = plt.subplots(1, 2, figsize=(15, 7.5))
ax[0].imshow(preds[0].squeeze(), cmap="gray")
ax[0].set_title("Prediction")
ax[1].imshow(val_image[0].squeeze(), cmap="gray")
ax[1].set_title("Input")
for panel in ax:
    panel.axis("off")
# Show the figure explicitly instead of echoing the last set_title's Text repr.
plt.show()