In [None]:
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from tqdm.auto import tqdm

import arpesnet as an

# Hide all warnings to keep the notebook output readable.
# NOTE(review): this blanket filter also hides deprecation/FutureWarnings from torch
# and arpesnet (e.g. torch.load notices) - consider narrowing it.
warnings.filterwarnings("ignore")

print("Python", sys.version)
# Either Apple-silicon (MPS) or CUDA acceleration counts as "GPU enabled".
has_mps = torch.backends.mps.is_available()
GPU_ENABLED = has_mps or torch.cuda.is_available()
print(f"Pytorch version: {torch.__version__} | GPU enabled = {GPU_ENABLED}")

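#### WARNING: If you are running on Apple Silicon (M-series) and get a `NotImplementedError` for an operation the MPS backend does not support, enable PyTorch's CPU fallback by setting `PYTORCH_ENABLE_MPS_FALLBACK=1` **before PyTorch is imported**. For example, start Jupyter with `PYTORCH_ENABLE_MPS_FALLBACK=1 jupyter lab`, or set `os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"` at the very top of the first cell, before `import torch`. Then restart the kernel and run the notebook again.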

In [None]:
# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK when it is first imported, so assigning
# os.environ here (after `import torch` in the first cell) would have no effect.
# To enable the CPU fallback for unsupported MPS ops, launch Jupyter with the
# variable already set, e.g. `PYTORCH_ENABLE_MPS_FALLBACK=1 jupyter lab`, and
# restart the kernel. This cell only reports the current setting.
import os

MPS_FALLBACK_FLAG = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK")
if torch.backends.mps.is_available():
    print(f"PYTORCH_ENABLE_MPS_FALLBACK = {MPS_FALLBACK_FLAG!r}")

# Set random seed for reproducibility

In [None]:
# Seed torch's global RNG (used by the Poisson noise simulation and data shuffling
# later, if they draw from torch) so that runs are reproducible.
SEED = 42
torch.manual_seed(SEED)

# load data

### Set the path to the data directory
The `root` variable should be set to the path of the directory containing the data files. It should contain:
- `train_data/` (directory containing the training images per material, as `.pt` files)
- `test_data/` (directory containing the test images per material, as `.pt` files)
- `test_imgs.pt` (a small set of images used for visual inspection)

In [None]:
# Point this at your local data folder (see the layout described above).
root = Path(r"path/to/your/data/folder/")
assert root.is_dir(), f"data folder not found: {root.resolve()} (set `root` to your data directory)"
# Entries read by later cells. Checking them here fails fast with a clear message
# instead of an opaque torch.stack error on an empty file list.
_missing = [name for name in ("train_data", "test_data", "test_imgs.pt") if not (root / name).exists()]
assert not _missing, f"{root} is missing expected entries: {_missing}"

# Set up data transforms

## Image resizing and normalization
The `normalizeAndResize` transform resizes every image to 256 × 256 pixels and then applies `NormalizeMinMax(0, 100)`, which scales intensities to the range [0, 100].

In [None]:
# Preprocessing pipeline: Resize((256, 256)) followed by NormalizeMinMax(0, 100)
# (presumably rescaling intensities to [0, 100] - see arpesnet.transform).
_resize = an.transform.Resize((256, 256))
_normalize = an.transform.NormalizeMinMax(0, 100)
normalizeAndResize = an.transform.Compose([_resize, _normalize])

## Poissonian noise simulation

In [None]:
# Number of counts for the simulated Poisson noise (100 000), used below to make
# noisy versions of the test images.
N = 10**5
setExposure = an.transform.SetRandomPoissonExposure(N)

# Load test images
A small set of test images, used for visual inspection of the model's reconstructions.

In [None]:
# Load the visual-inspection images, bring them to the training resolution/range,
# and simulate a noisy exposure of each one.
# NOTE(review): torch.load unpickles its input - only load trusted files.
raw_test_imgs = torch.load(root / "test_imgs.pt")
test_imgs = torch.stack(list(map(normalizeAndResize, raw_test_imgs)))
test_imgs_noisy = torch.stack([img for img in map(setExposure, test_imgs)])

In [None]:
# Top row: clean test images; bottom row: the same images after simulated Poisson
# exposure, titled with total counts and counts per pixel.
# squeeze=False keeps `ax` 2-D even for a single test image, so ax[row, col] always works.
fig, ax = plt.subplots(2, len(test_imgs), figsize=(8, 3), squeeze=False)
for i, (clean, noisy) in enumerate(zip(test_imgs, test_imgs_noisy)):
    ax[0, i].imshow(clean.numpy(), cmap="viridis", origin="lower")
    ax[1, i].imshow(noisy.numpy(), cmap="viridis", origin="lower")
    counts = noisy.sum().item()
    counts_per_pixel = counts / clean.nelement()
    ax[1, i].set_title(f"N: {counts:,.0f}\nN/px: {counts_per_pixel:.2f}", fontsize=8)
    ax[0, i].axis("off")
    ax[1, i].axis("off")
fig.suptitle("clean (top) vs. simulated Poisson exposure (bottom)", fontsize=9)
plt.show()

# load training dataset

In [None]:
# Load every training file, preprocess it, and flatten everything into a single
# (n_images, 256, 256) tensor.
# sorted(): Path.glob order is filesystem-dependent. Sorting makes the stacking order,
# and therefore the seeded shuffle during training, reproducible across machines.
all_files = sorted((root / "train_data").glob("*.pt"))
assert all_files, f"no *.pt files found in {root / 'train_data'}"
# NOTE(review): torch.load unpickles its input - only load trusted files.
train_data = torch.stack(
    [normalizeAndResize(torch.load(f)) for f in tqdm(all_files, desc="loading train data")]
).view(-1, 256, 256)
print(f"loaded {len(train_data):,.0f} training images with shape {train_data.shape[1:]}")

# Set up training configuration

In [None]:
# Start from the package's default configuration and override it for this run.
config = an.load_config(Path(an.__file__).parent.parent / "config.yml")

# Shared image geometry and intensity range (matching normalizeAndResize above).
# The same list objects are referenced from every section below.
input_shape = [256, 256]
norm = [0, 100]

# --- autoencoder architecture ---
model_cfg = config["model"]
model_cfg["aenc"] = "arpesnet"
model_cfg["kwargs"] = {
    "kernel_size": 11,
    "kernel_decay": 2,
    "n_layers": 1,
    "start_channels": 4,
    "max_channels": 32,
    "n_blocks": 6,
    "input_shape": input_shape,
    "relu": "PReLU",
    "relu_kwargs": {"num_parameters": 1, "init": 0.25},
}
model_cfg["input_shape"] = input_shape

# --- data pipelines (key order within each section is kept as-is, in case it
#     determines the order in which transforms are applied) ---
prep_cfg = config["preprocessing"]
prep_cfg["Resize"] = input_shape
prep_cfg["NormalizeMinMax"] = norm

train_aug_cfg = config["training_augmentations"]
train_aug_cfg["NormalizeMinMax"] = norm
train_aug_cfg["RandomResizedCrop"]["size"] = input_shape

val_aug_cfg = config["validation_augmentations"]
val_aug_cfg["NormalizeMinMax"] = norm
val_aug_cfg["Resize"] = input_shape

noise_aug_cfg = config["noise_augmentations"]
noise_aug_cfg["NormalizeMinMax"] = norm
# presumably the (min, max) range of total counts for the random exposure - confirm in arpesnet
noise_aug_cfg["SetRandomPoissonExposure"] = [50_000, 100_000_000]

# --- loss, optimizer and training loop ---
config["loss"]["criteria"] = ["mse"]

optim_cfg = config["optimizer"]
optim_cfg["name"] = "Adam"
optim_cfg["lr"] = 0.001

loop_cfg = config["train"]
loop_cfg["batch_size"] = 32
loop_cfg["denoiser"] = False
loop_cfg["shuffle"] = True
loop_cfg["drop_last"] = True

In [None]:
# Build the trainer from the configuration above and the preloaded training tensor,
# then print a summary of the model.
trainer = an.ModelTrainer(config, train_dataset=train_data, verbose="full")
trainer.describe_model()

In [None]:
# Train the model, plotting progress on the noisy test images and saving to the
# current directory.
# NOTE(review): only 2 epochs are run, but milestones list 4 and 6 and
# milestone_every is 10. If these are epoch counts (unverified), they are never
# reached - confirm whether n_epochs=2 is just a quick-demo setting.
train_kwargs = dict(
    n_epochs=2,
    milestones=[2, 4, 6],
    milestone_every=10,
    save_dir="./",
    plot=True,
    save=True,
    test_imgs=test_imgs_noisy,
)
trainer.train(**train_kwargs)


# Test: visualize and evaluate the model

## load test data

In [None]:
# Load a subset of the test files and keep every TEST_STRIDE-th image, to keep
# evaluation cheap.
MAX_TEST_FILES = 10  # number of test files to load
TEST_STRIDE = 10  # keep one image out of every TEST_STRIDE
# sorted(): Path.glob order is filesystem-dependent, so without it `[:MAX_TEST_FILES]`
# would select different files on different machines.
all_test_files = sorted((root / "test_data").glob("*.pt"))
assert all_test_files, f"no *.pt files found in {root / 'test_data'}"
# NOTE(review): torch.load unpickles its input - only load trusted files.
test_data = torch.stack(
    [normalizeAndResize(torch.load(f)) for f in tqdm(all_test_files[:MAX_TEST_FILES], desc="loading test data")]
).view(-1, 256, 256)[::TEST_STRIDE]
print(f"loaded {len(test_data):,.0f} test images with shape {test_data.shape[1:]}")

## clean data

In [None]:
# Plot the training loss and the model's reconstructions of the clean test images.
# NOTE(review): description inferred from the method name - confirm in arpesnet.
trainer.plot_loss_and_reconstruction(test_imgs)

In [None]:
# MSE / PSNR on the clean test subset, then averaged with .mean()
# (presumably one row per image - confirm the return type of test_model).
clean_scores = trainer.test_model(test_data, metrics=["mse", "psnr"])
clean_scores.mean()

## noisy data

In [None]:
# Simulate a Poisson-noise exposure of each test image.
test_data_noisy = torch.stack(list(map(setExposure, test_data)))

In [None]:
# Plot the training loss and the model's reconstructions of the noisy test images.
# NOTE(review): description inferred from the method name - confirm in arpesnet.
trainer.plot_loss_and_reconstruction(test_imgs_noisy)

In [None]:
# MSE / PSNR on the noisy test subset, then averaged with .mean()
# (presumably one row per image - confirm the return type of test_model).
noisy_scores = trainer.test_model(test_data_noisy, metrics=["mse", "psnr"])
noisy_scores.mean()