In [1]:
from careamics import CAREamist
from careamics.config import create_n2v_configuration
import numpy as np
import logging as log
import sys
import os

# Add the top-level and the script directories to the sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', 'scripts')))

from helpers import get_paths, ground_truth, normalize_image
from metrics import compute_metrics

In [None]:
# Configure the root logger at INFO level so the `log.info(...)` calls in this notebook are not filtered out.
log.basicConfig(level=log.INFO)

def load_dataset(data_path, num_images=120, num_channels=3, image_shape=(512, 512), frame_index=249):
    """
    Load one frame per (image, channel) pair from the dataset into a single NumPy array.

    Files are read from ``<data_path>/Image<NNN>/wf_channel<C>.npy``, where ``NNN`` is the
    1-based image number zero-padded to three digits and ``C`` is the 0-based channel index.
    Each file is indexed as ``stack[frame_index, :, :]``, so it must hold a 3-D array.

    :param data_path: Path to the root data directory.
    :param num_images: Number of ``Image<NNN>`` folders to read (default 120).
    :param num_channels: Number of ``wf_channel<C>.npy`` files per image (default 3).
    :param image_shape: Expected ``(height, width)`` of a single frame (default (512, 512)).
    :param frame_index: Index along the first axis of each stored stack to extract (default 249).
    :return: float32 NumPy array of shape ``(num_images, num_channels, *image_shape)``.
    :raises ValueError: If an extracted frame does not have shape ``image_shape``.
    """
    log.info(f"Loading dataset from {data_path}")

    image_shape = tuple(image_shape)

    # Pre-allocate the output so each frame is written straight into place
    dataset = np.zeros((num_images, num_channels, *image_shape), dtype=np.float32)

    for channel in range(num_channels):
        for image_index in range(1, num_images + 1):
            image_index_str = str(image_index).zfill(3)
            image_path = os.path.join(data_path, f'Image{image_index_str}', f'wf_channel{channel}.npy')
            frame = np.load(image_path)[frame_index, :, :]
            # Fail loudly with the offending file instead of silently broadcasting
            # a frame of the wrong size into the output array.
            if frame.shape != image_shape:
                raise ValueError(
                    f"Frame {frame_index} of {image_path} has shape {frame.shape}, expected {image_shape}"
                )
            # Folder numbering is 1-based, array indexing is 0-based
            dataset[image_index - 1, channel, :, :] = frame

    return dataset

# Resolve the project directories, report them, then read the whole dataset into memory.
project_paths = get_paths()
data_path, output_path = project_paths
log.info(f"Data path: {data_path}")
log.info(f"Output path: {output_path}")

dataset = load_dataset(data_path)
log.info(f"Dataset shape: {dataset.shape}")


In [3]:
# Build the sample path from the configured data directory (same layout that
# `load_dataset` reads) instead of a path relative to the current working directory.
# NOTE(review): assumes `data_path` from get_paths() points at the directory that
# contains the Image001 folder (previously hard-coded as 'data/raw') — confirm.
sample_path = os.path.join(data_path, 'Image001', 'wf_channel0.npy')
image = np.load(sample_path)

# Frame 249 of the stack is the noisy input (same frame index that `load_dataset` extracts);
# the reference image is derived from the full stack by the project helpers.
noisy_image = image[249, :, :]
ground_truth_image = normalize_image(ground_truth(image))
print(ground_truth_image.shape)
print(noisy_image.shape)


(512, 512)
(512, 512)


In [4]:
# Fold the channel axis into the sample axis so that every (image, channel) pair
# becomes an independent single-channel sample: (N, C, H, W) -> (N * C, 1, H, W).
dataset_bis = dataset.reshape(-1, 1, *dataset.shape[-2:])

log.info("Splitting the dataset into training and validation sets...")
split_ratio = 0.8
split_idx = int(len(dataset_bis) * split_ratio)

seed = 42
np.random.seed(seed)
# `dataset_bis` is a reshaped view of `dataset`, so shuffling it in place would also
# scramble `dataset` and make this cell give a different split each time it is re-run.
# Index with a seeded permutation instead: the source arrays stay untouched and
# re-running the cell always reproduces the same split.
shuffled_indices = np.random.permutation(len(dataset_bis))
train = dataset_bis[shuffled_indices[:split_idx]]
val = dataset_bis[shuffled_indices[split_idx:]]
log.info(f"Training set shape: {train.shape}")
log.info(f"Validation set shape: {val.shape}")

INFO:root:Splitting the dataset into training and validation sets...
INFO:root:Training set shape: (288, 1, 512, 512)
INFO:root:Validation set shape: (72, 1, 512, 512)


In [5]:
# Noise2Void configuration for single-channel 2D samples passed as in-memory arrays
# (axes "SYX": sample, height, width).
config = create_n2v_configuration(
    experiment_name="w2s_n2v_test",
    data_type="array",
    axes="SYX",
    n_channels=1,
    patch_size=(64, 64),
    batch_size=32,  # alternatives noted by the author: 256, 32
    num_epochs=100,  # alternatives noted by the author: 1000, 15
)

log.info("Initializing CAREamist...")
# Directory handed to CAREamist as its working directory
weights_dir = 'models/noise2void_weights/'
careamist = CAREamist(source=config, work_dir=weights_dir)

print(config)

INFO:root:Initializing CAREamist...


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


{'algorithm_config': {'algorithm': 'n2v',
                      'loss': 'n2v',
                      'lr_scheduler': {'name': 'ReduceLROnPlateau',
                                       'parameters': {}},
                      'model': {'architecture': 'UNet',
                                'conv_dims': 2,
                                'depth': 2,
                                'final_activation': 'None',
                                'in_channels': 1,
                                'independent_channels': True,
                                'n2v2': False,
                                'num_channels_init': 32,
                                'num_classes': 1},
                      'optimizer': {'name': 'Adam',
                                    'parameters': {'lr': 0.0001}}},
 'data_config': {'axes': 'SYX',
                 'batch_size': 32,
                 'data_type': 'array',
                 'patch_size': [64, 64],
                 'transforms': [{'flip_x': True,
    

In [6]:
log.info("Training the model...")
# Drop the singleton channel axis so the arrays match the configured "SYX" layout:
# (N, 1, 512, 512) -> (N, 512, 512).
train_stack = train.reshape(-1, 512, 512)
val_stack = val.reshape(-1, 512, 512)
careamist.train(train_source=train_stack, val_source=val_stack)
log.info("Training complete.")

INFO:root:Training the model...
Computed dataset mean: [155.10997], std: [67.09785]
/opt/anaconda3/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/gattimartina/ML4Science-CS433/models/noise2void_weights/checkpoints exists and is not empty.

  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | UNet | 509 K  | train
---------------------------------------
509 K     Trainable params
0         Non-trainable params
509 K     Total params
2.037     Total estimated model params size (MB)
39        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/opt/anaconda3/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 2:  24%|██▎       | 136/576 [01:40<05:24,  1.36it/s, train_loss_step=0.165, val_loss=0.152, train_loss_epoch=0.216] 


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
# Display the shape of the single noisy frame extracted earlier (sanity check before prediction).
noisy_image.shape

(512, 512)

In [None]:
log.info("Predicting on a noisy image...")

# Add a leading sample axis so the single frame is passed as a one-element "SYX" stack.
single_sample = noisy_image.reshape(1, 512, 512)
prediction = careamist.predict(source=single_sample, batch_size=1)

INFO:root:Predicting on a noisy image...
/opt/anaconda3/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Score the denoised output against the reference image
log.info("Computing metrics...")
# Take the first prediction and strip its singleton axes before comparison
denoised = np.array(prediction[0]).squeeze()
metrics = compute_metrics(denoised, ground_truth_image)

# Report the three scores, read from the result by position
log.info(f"PSNR: {metrics[0]}")
log.info(f"SI-PSNR: {metrics[1]}")
log.info(f"SSIM: {metrics[2]}")

INFO:root:Computing metrics...
INFO:root:Computing metrics for denoised image.
INFO:root:PSNR: 9.55374758808943
INFO:root:SI-PSNR: 15.430167449246257
INFO:root:SSIM: 0.28437857189073207
