# Sample notebook: supervised inpainting on Urban100

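Use this notebook for local experiments: it trains a deepinv U-Net to reconstruct Urban100 images from inpainting measurements using a supervised loss. For Weights & Biases integration, use `scripts/train.py` instead.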

In [None]:
# Change the working directory to the parent folder, presumably the repository
# root, so that `from utils import *` and relative paths such as "Urban100" and
# "models/" resolve there (TODO confirm the notebook lives one level below root).
# NOTE: %cd is not idempotent — re-running this cell moves up one more level,
# so run it only once per kernel session.
%cd ..

In [None]:
from munch import DefaultMunch
import wandb

import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor
from torchvision.datasets.utils import download_and_extract_archive
import deepinv as dinv

from utils import *

In [None]:
# Run configuration; DefaultMunch gives attribute-style access (config.seed, ...).
config = DefaultMunch(epochs=1, batch_size=1, lr_init=1e-3, seed=0)

# Seed the NumPy and PyTorch global RNGs so a fresh Restart & Run All repeats.
np.random.seed(config.seed)
torch.manual_seed(config.seed)

# Prefer the least-loaded GPU when CUDA is present; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = dinv.utils.get_freer_gpu()
else:
    device = "cpu"

## Define the physics (inpainting forward operator)

In [None]:
# Inpainting operator for 3x256x256 images. Create it on the same `device` that
# generate_dataset and the Trainer receive below: Inpainting's mask otherwise
# stays on its default device (CPU), which mismatches GPU tensors when CUDA is used.
physics = dinv.physics.Inpainting((3, 256, 256), device=device)

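## Prepare dataloaders

Download Urban100, split it 80/20 into train and test sets, and simulate the inpainting measurements for every image once, storing them in an HDF5 dataset.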

In [None]:
# Download and extract the Urban100 HR archive into ./Urban100; the md5 checksum
# verifies the archive (an already-downloaded, verified file is reused).
download_and_extract_archive(
    "https://huggingface.co/datasets/eugenesiow/Urban100/resolve/main/data/Urban100_HR.tar.gz?download=true",
    "Urban100",
    filename="Urban100_HR.tar.gz",
    md5="65d9d84a34b72c6f7ca1e26a12df1e4c",
)

# Resize(256) alone only rescales the shorter edge and keeps the aspect ratio, so
# non-square images would not match the (3, 256, 256) shape the Inpainting physics
# expects. Center-cropping after the resize makes every sample exactly 256x256.
# ImageFolder treats subdirectories of ./Urban100 as classes; the labels are
# presumably unused by generate_dataset (TODO confirm).
urban100 = ImageFolder(
    "Urban100",
    transform=Compose([Resize(256), CenterCrop(256), ToTensor()]),
)

# A dedicated, seeded generator keeps the 80/20 split identical on every run,
# independent of how many global random draws happened earlier in the session.
train_dataset, test_dataset = random_split(
    urban100,
    (0.8, 0.2),
    generator=torch.Generator().manual_seed(config.seed),
)

# Apply the physics once to every image and store (image, measurement) pairs.
dataset_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    device=device,
    save_dir="Urban100",
)

train_dataloader = DataLoader(
    dinv.datasets.HDF5Dataset(dataset_path, train=True),
    shuffle=True,
    batch_size=config.batch_size,
)
test_dataloader = DataLoader(
    dinv.datasets.HDF5Dataset(dataset_path, train=False),
    shuffle=False,
    batch_size=config.batch_size,
)

## Define the loss

In [None]:
# Supervised loss (presumably compares the model's reconstruction with the
# ground-truth image); the HDF5 dataset built above supplies those ground truths.
losses = dinv.loss.SupLoss()

## Define the model

In [None]:
# deepinv's UNet defaults to single-channel input/output; the images and the
# Inpainting physics are RGB (3, 256, 256), so request 3 channels explicitly.
model = dinv.models.UNet(in_channels=3, out_channels=3).to(device)

## Define the trainer and train the model

In [None]:
# Checkpoints go under models/<run name>. This notebook never calls wandb.init()
# (W&B integration lives in scripts/train.py), so wandb.run is None here and
# `wandb.run.id` would raise AttributeError; fall back to a fixed local name.
run_name = wandb.run.id if wandb.run is not None else "notebook"

trainer = dinv.training.Trainer(
    model=model,
    physics=physics,
    optimizer=torch.optim.Adam(model.parameters(), lr=config.lr_init),
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=config.epochs,
    losses=losses,
    scheduler=None,
    metrics=dinv.loss.PSNR(),
    # Presumably reuse the measurements stored in the HDF5 dataset instead of
    # re-simulating them each batch — TODO confirm against deepinv docs.
    online_measurements=False,
    ckp_interval=1000,
    device=device,
    eval_interval=1,
    save_path=f"models/{run_name}",
    plot_images=True,
    wandb_vis=False,
)

# Trailing ';' keeps the cell from dumping the return value's repr.
trainer.train();