### Notebook for training

Before running this notebook, make sure you have run the `download_data` script (or notebook), so that the training images and textures are available under `../data/`.
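
A quick way to confirm that step has been done is to check that the data directories exist and are not empty. This is a minimal sketch. The paths are copied from the data-loading cell below. `REQUIRED_DIRS` and `missing_data_dirs` are hypothetical helpers and are not part of the project's `src` package.

```python
from pathlib import Path

# Hypothetical helper: paths copied from the data-loading cell of this notebook.
REQUIRED_DIRS = ["../data/div2k-hr-train/", "../data/textures/"]


def missing_data_dirs(dirs=REQUIRED_DIRS):
    """Return the directories that are absent or contain no files."""
    missing = []
    for d in dirs:
        path = Path(d)
        # A directory counts as ready only if it exists and holds at least one file.
        if not path.is_dir() or not any(p.is_file() for p in path.rglob("*")):
            missing.append(d)
    return missing


missing = missing_data_dirs()
if missing:
    print("Run download_data first; missing or empty:", missing)
```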

First, import the data loader, trainer, and model classes from our source files:

In [None]:
from src.dataloader import Dataloader
from src.train import Trainer
from src.model import UNet

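Next, load the training and validation datasets. Here 20% of the images are held out for validation (`validation_split=0.2`), and the `noisy` and `textured` options enable noise and a texture overlay on the images: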

In [None]:
image_dir = "../data/div2k-hr-train/"
texture_dir = "../data/textures/"
color_mode = "rgb"

dataloader = Dataloader(
    image_dir,
    texture_dir,
    image_size=(32 * 21, 32 * 32),  # = (672, 1024)
    batch_size=8,
    color_mode=color_mode,
    validation_split=0.2,
)

train_ds, val_ds = dataloader.load_datasets(
    noisy=True,
    textured=True,
    texture_alpha=0.1,
    shuffle=True,
)

dataloader.show_samples(2)
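
The `noisy`, `textured`, and `texture_alpha` options suggest that each clean image is paired with a corrupted copy. The actual corruption is implemented in `src/dataloader.py`. The NumPy sketch below only assumes what such a step might look like. It alpha-blends a texture over the image with weight `texture_alpha` and then adds Gaussian noise. `corrupt` and its `noise_std` parameter are hypothetical and are not options of `Dataloader`.

```python
import numpy as np


def corrupt(image, texture, texture_alpha=0.1, noise_std=0.05, rng=None):
    """Hypothetical corruption: texture overlay followed by additive Gaussian noise.

    image, texture: float arrays in [0, 1] with the same shape (H, W, C).
    noise_std is an illustrative parameter, not one exposed by Dataloader.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Alpha-blend: keep (1 - alpha) of the image and mix in alpha of the texture.
    textured = (1.0 - texture_alpha) * image + texture_alpha * texture
    # Add zero-mean Gaussian noise, then clip back into the valid pixel range.
    noisy = textured + rng.normal(0.0, noise_std, size=image.shape)
    return np.clip(noisy, 0.0, 1.0)


# Example: a mid-grey 672x1024 RGB image and a random texture of the same size.
rng = np.random.default_rng(0)
clean = np.full((672, 1024, 3), 0.5)
texture = rng.random((672, 1024, 3))
corrupted = corrupt(clean, texture, texture_alpha=0.1, rng=rng)
# A training pair would then be (input=corrupted, target=clean).
```

With `texture_alpha=0.1`, the texture contributes only a tenth of each pixel value, so the underlying image stays clearly dominant.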

Now we are ready to train the model:

In [None]:
# instantiate the model
model = UNet(color_mode=color_mode)

# instantiate the trainer
trainer = Trainer(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    checkpoint_filepath="../checkpoints/checkpoint.weights.h5",
    epochs=10,
    learning_rate=1e-3,
)

# run the training loop
trainer.train(num_samples_shown=4)
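
The image size in the data-loading cell is written as `(32 * 21, 32 * 32)`, so both sides are multiples of 32 = 2^5. A plausible reason is that the `UNet` halves the spatial resolution up to five times, and each stage must divide evenly for the decoder's upsampled maps to line up with the skip connections. This is an assumption; check `src/model.py` for the actual depth. The sketch below uses a hypothetical `encoder_shapes` helper to trace the sizes through five 2x downsamples.

```python
def encoder_shapes(image_size, num_downsamples):
    """Spatial size after each 2x downsampling stage of a U-Net-style encoder."""
    height, width = image_size
    shapes = [(height, width)]
    for _ in range(num_downsamples):
        # Exact halving requires even sizes; otherwise the decoder's
        # upsampled maps would no longer match the skip connections.
        if height % 2 or width % 2:
            raise ValueError(f"cannot halve {height}x{width} exactly")
        height, width = height // 2, width // 2
        shapes.append((height, width))
    return shapes


print(encoder_shapes((32 * 21, 32 * 32), 5))
# [(672, 1024), (336, 512), (168, 256), (84, 128), (42, 64), (21, 32)]
```

A sixth halving would raise an error, because 21 is odd. Under this assumption, 672 x 1024 supports at most five pooling stages.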