In [1]:
from pathlib import Path
from typing import Union

import albumentations as Aug
import matplotlib.pyplot as plt
import tifffile
from careamics_portfolio import PortfolioManager
from pytorch_lightning import Trainer

from careamics.config.data import DataModel
from careamics.dataset.dataset_utils import read_tiff
# `predict_tiled_simple` is intentionally not imported: the installed
# `careamics.lightning_module` does not export it (importing it raised
# ImportError and aborted this cell), and no cell in this notebook uses it.
from careamics.lightning_module import (
    CAREamicsFiring,
    CAREamicsModule,
    CAREamicsPredictDataModule,
    CAREamicsTrainValDataModule,
)
from careamics.prediction import stitch_prediction
from careamics.utils.metrics import psnr
from careamics.utils.transforms import ManipulateN2V
# NOTE(review): Union, DataModel, read_tiff and stitch_prediction are not used
# in the cells below — consider dropping them once confirmed.

In [2]:
import numpy as np

def calculate_border_distances(radius):
    """
    Distances from the centre of a discretised circle to each of its border pixels.

    The circle is placed in a square grid of side ``2 * radius + 3`` whose
    centre pixel is ``(radius + 1, radius + 1)``. A pixel counts as a border
    pixel when its Euclidean distance to the centre lies strictly between
    ``radius - 0.5`` and ``radius + 0.5``.

    Parameters:
    - radius: The radius of the circle (must be an int, since it sizes a ``range``).

    Returns:
    A list of distances, one per border pixel, in grid order
    (first grid index outer, second grid index inner).
    """
    mid = radius + 1            # centre index; one cell of margin keeps the ring off index 0
    side = 2 * radius + 3       # grid wide enough to contain the whole ring
    inner, outer = radius - 0.5, radius + 0.5

    all_distances = (
        np.sqrt((i - mid) ** 2 + (j - mid) ** 2)
        for i in range(side)
        for j in range(side)
    )
    return [dist for dist in all_distances if inner < dist < outer]

# Example: border distances for a circle of radius 10, previewing the first ten
radius = 10
border_distances = calculate_border_distances(radius=radius)
border_distances[0:10]


[10.44030650891055,
 10.198039027185569,
 10.04987562112089,
 10.0,
 10.04987562112089,
 10.198039027185569,
 10.44030650891055,
 10.295630140987,
 9.848857801796104,
 9.848857801796104]

### Import Dataset Portfolio

In [None]:
# Explore portfolio: print the denoising entries of the CAREamics data
# portfolio (the N2V_BSD68 entry used below is one of them).
portfolio = PortfolioManager()
print(portfolio.denoising)

In [None]:
# Download and unzip the N2V_BSD68 dataset into `root_path` (./data).
# The data-path cell below looks for the unzipped folder inside `root_path`.
root_path = Path("data")
files = portfolio.denoising.N2V_BSD68.download(root_path)
print(f"List of downloaded files: {files}")

In [None]:
# Locations of the unzipped BSD68 data. These directories are *inputs*, so
# check that they exist rather than creating them: creating empty folders would
# hide a failed or moved download, which would then only show up later as a
# confusing error when the first image is read.
data_path = root_path / "denoising-N2V_BSD68.unzip" / "BSD68_reproducibility_data"
train_path = data_path / "train"
val_path = data_path / "val"
test_path = data_path / "test" / "images"
gt_path = data_path / "test" / "gt"

missing_dirs = [p for p in (train_path, val_path, test_path, gt_path) if not p.is_dir()]
if missing_dirs:
    raise FileNotFoundError(
        f"Expected data directories not found (did the download succeed?): {missing_dirs}"
    )

### Visualize training data

In [None]:
# Show the first sample (index 0 along S; training data is read with axes="SYX")
# of the first training stack. `rglob` does not guarantee any order, so sort the
# files to show the same image on every machine.
train_files = sorted(train_path.rglob("*.tiff"))
assert train_files, f"No .tiff files found under {train_path}"
train_image = tifffile.imread(train_files[0])[0]
print(f"Train image shape: {train_image.shape}")

fig, ax = plt.subplots()
ax.imshow(train_image, cmap="gray")
ax.set_title(f"Training data: {train_files[0].name}")
plt.show()

### Visualize validation data

In [None]:
# Show the first sample (index 0 along S; validation data shares axes="SYX")
# of the first validation stack. Files are sorted so the choice is deterministic.
val_files = sorted(val_path.rglob("*.tiff"))
assert val_files, f"No .tiff files found under {val_path}"
val_image = tifffile.imread(val_files[0])[0]
print(f"Validation image shape: {val_image.shape}")

fig, ax = plt.subplots()
ax.imshow(val_image, cmap="gray")
ax.set_title(f"Validation data: {val_files[0].name}")
plt.show()

### Initialize the Model

Create a PyTorch Lightning module.

Please take a look at the [documentation](https://careamics.github.io) for the full list of parameters and configuration options.

In [None]:
# Noise2Void: N2V algorithm and N2V loss on a UNet architecture.
model = CAREamicsModule(architecture="UNet", algorithm="n2v", loss="n2v")


### Define the Transforms

In [None]:
# Training pipeline, applied in order: random flip, random 90-degree rotation,
# normalization, and finally the N2V-specific ManipulateN2V step.
# NOTE(review): Normalize() gets no mean/std, so it falls back to albumentations'
# defaults (presumably ImageNet RGB statistics) — confirm these suit
# single-channel BSD68 images.
transforms = Aug.Compose(
    transforms=[
        Aug.Flip(),
        Aug.RandomRotate90(),
        Aug.Normalize(),
        ManipulateN2V(),
    ],
)

### Initialize the datamodule

In [None]:
train_data_module = CAREamicsTrainValDataModule(
    # where the data lives and how it is laid out
    train_path=train_path,
    val_path=val_path,
    data_type="tiff",
    axes="SYX",
    # patching and batching
    patch_size=(64, 64),
    batch_size=128,
    # transform pipeline from the previous section
    transforms=transforms,
    # presumably the number of DataLoader worker processes
    num_workers=4,
)

### Run training 

The training and validation paths were already passed to the datamodule above; here we create a `Trainer` and fit the model.

In [None]:
# A single epoch keeps this example quick; a real run would typically train longer.
trainer = Trainer(
    max_epochs=1,
)

In [None]:
# Train; the datamodule supplies both the training and the validation data.
trainer.fit(model=model, datamodule=train_data_module)

### Define a prediction datamodule

In [None]:
# Prediction only normalizes: no augmentation and no N2V manipulation.
# NOTE(review): no cell in this notebook undoes this normalization before the
# predictions are compared with the ground truth — confirm whether the
# prediction loop does it.
transforms_predict = Aug.Compose(transforms=[Aug.Normalize()])

In [None]:
pred_data_module = CAREamicsPredictDataModule(
    # images to denoise and their layout (2D images, hence "YX")
    pred_path=test_path,
    data_type="tiff",
    axes="YX",
    # prediction is done on tiles of this size
    tile_size=(256, 256),
    # one item per batch; num_workers=0 presumably loads in the main process
    batch_size=1,
    num_workers=0,
    transforms=transforms_predict,
)

### Run prediction

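The prediction datamodule above already points to the images we want to denoise. Here we replace the trainer's prediction loop with CAREamics' `CAREamicsFiring` loop and run prediction.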

In [None]:
# Build CAREamics' custom prediction loop from the trainer used for fitting.
tiled_loop = CAREamicsFiring(trainer)

In [None]:
# Replace the trainer's prediction loop with the CAREamicsFiring loop built above
# (assumes Lightning's Trainer.predict dispatches through `trainer.predict_loop`).
trainer.predict_loop = tiled_loop

In [None]:
# Denoise the test images through the replaced prediction loop.
preds = trainer.predict(model=model, datamodule=pred_data_module)

### Visualize results and compute metrics


In [None]:
# Ground-truth images, read in sorted filename order so that gts[i] is meant to
# line up with preds[i].
# NOTE(review): this assumes the prediction datamodule also visits the test
# images in sorted filename order — confirm.
gts = list(map(tifffile.imread, sorted(gt_path.glob("*.tiff"))))

In [None]:
# Plot a single prediction next to its ground truth.
# The figure is kept in `fig` rather than assigned to `_`: in IPython `_` holds
# the last cell output, and assigning to it shadows that history variable.

image_idx = 0
fig, (ax_pred, ax_gt) = plt.subplots(1, 2, figsize=(10, 5))

ax_pred.imshow(preds[image_idx].squeeze(), cmap="gray")
ax_pred.set_title("Prediction")
ax_gt.imshow(gts[image_idx], cmap="gray")
ax_gt.set_title("Ground truth")
plt.show()

In [None]:
# PSNR of one prediction against its ground truth (ground truth is passed first,
# prediction second).
psnr_single = psnr(
    gts[image_idx],
    preds[image_idx].squeeze(),
)
print(f"PSNR for image {image_idx}: {psnr_single}")

In [None]:
# Mean PSNR over the test set.
# Predictions are squeezed exactly as in the single-image PSNR cell, so both
# cells compare the same arrays. The length check replaces zip's silent
# truncation, which previously still divided by len(preds).
assert preds, "No predictions to evaluate"
assert len(preds) == len(gts), (
    f"{len(preds)} predictions vs {len(gts)} ground-truth images"
)

psnr_total = sum(psnr(gt, pred.squeeze()) for pred, gt in zip(preds, gts))

print(f"Mean PSNR: {psnr_total / len(preds)}")