In [None]:
import os
from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl
from image_inpainting.datamodule.tiny_image_net_data_module import TinyImageNetDataModule
from pytorch_lightning.callbacks import ModelCheckpoint
from image_inpainting.model.context_encoder import ContextEncoder
from image_inpainting.utils import print_results_images

In [None]:
from datetime import datetime
now = datetime.now()
now = now.strftime("%Y-%m-%d_%H-%M-%S")

## Create the datamodule

In [None]:
data_dir = "data"

dm = TinyImageNetDataModule(
    data_dir=os.path.join(data_dir, "tiny-imagenet-200"), 
    batch_size_train=128,
    batch_size_val=128,
    batch_size_test=128,
    num_workers=10, 
    pin_memory=True, 
    persistent_workers=True
)

## Create a ContextEncoder model from scratch

In [None]:
model = ContextEncoder(input_size=(3, 128, 128), hidden_size=4000, save_image_per_epoch=True)

## Or load it from a checkpoint

In [None]:
# model = ContextEncoder.load_from_checkpoint("checkpoints/tiny_imagenet/2024-12-16_14-34-46-epoch=76-val_loss=0.32.ckpt") # change the path to your checkpoint
# model.enable_save_image_per_epoch()
# model.to("cuda")

## Train it

In [None]:
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/tiny_imagenet',
    filename=now+'-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=-1,  # Save all checkpoints
    every_n_epochs=1  # Save a checkpoint at the end of every epoch
)

tb_logger = pl_loggers.TensorBoardLogger("Context_Encoder_Inpainting")
trainer = pl.Trainer(max_epochs=300, devices=-1, accelerator="cuda", logger=tb_logger, callbacks=[checkpoint_callback])
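The `filename` template above combines the run timestamp with the epoch and the validation loss. With Lightning's default `auto_insert_metric_name=True`, each `{metric:fmt}` field is prefixed with `metric=`, which is how names like the one in the loading cell are produced. The function below is a hypothetical helper that only approximates this expansion for simple templates. It is not Lightning's own implementation.

```python
import re


def expand_checkpoint_name(template: str, metrics: dict) -> str:
    """Approximate how ModelCheckpoint expands a filename template (hypothetical helper).

    Mimics auto_insert_metric_name=True: "{epoch:02d}" becomes "epoch={epoch:02d}".
    """
    # Collect field names such as "epoch" and "val_loss" from "{name}" or "{name:fmt}"
    names = re.findall(r"\{([^{}:]+)(?::[^{}]*)?\}", template)
    for name in names:
        # Prefix each field with "name=" (first occurrence only)
        template = template.replace("{" + name, name + "={" + name, 1)
    # Fill in the values using the format specs from the template, then add the extension
    return template.format(**metrics) + ".ckpt"


print(expand_checkpoint_name("2024-12-16_14-34-46-{epoch:02d}-{val_loss:.2f}",
                             {"epoch": 76, "val_loss": 0.3172}))
# 2024-12-16_14-34-46-epoch=76-val_loss=0.32.ckpt
```

Because of the `:.2f` format, the loss in the file name is rounded to two decimals, so several epochs can end up with the same value in their names.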

In [None]:
trainer.fit(model, dm)

A separate "tensorboard" notebook is available to follow how the metrics evolve during training.

## Display some images and evaluate the model's performance

The results of this cell were obtained after 100 epochs on the Tiny ImageNet dataset (about 10x fewer images than ImageNet) at 64x64 resolution.

- **Number of steps**: 19 599
- **Time**: 3 h 45 min
- **Observation (from TensorBoard)**: this dataset is small compared to the full ImageNet, with 100 000 training images. Overfitting is clearly visible after step 2 000 (after about 23 min, around epoch 10). This cell therefore shows the model from the last epoch, and the following one shows the epoch with the best validation loss.
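The relation between the step count and the number of epochs is plain arithmetic. The sketch below assumes one optimizer step per batch, that the distributed sampler gives each device `ceil(N / num_devices)` images, and that the last incomplete batch is kept. `steps_per_epoch` is a hypothetical helper, and the number of GPUs used for the run above is not stated in this notebook.

```python
import math


def steps_per_epoch(num_images: int, batch_size: int, num_devices: int = 1) -> int:
    """Optimizer steps per epoch with data parallelism (hypothetical helper).

    Assumes each device receives ceil(num_images / num_devices) samples and
    the last, incomplete batch is not dropped.
    """
    per_device = math.ceil(num_images / num_devices)
    return math.ceil(per_device / batch_size)


# Tiny ImageNet has 100 000 training images; batch_size_train=128 in the datamodule
for devices in (1, 2, 4):
    n = steps_per_epoch(100_000, 128, devices)
    print(f"{devices} device(s): {n} steps/epoch, {n * 100} steps for 100 epochs")
```

With a single GPU this gives 782 steps per epoch. With 4 devices it gives 196 steps per epoch, i.e. 19 600 steps for 100 epochs and about 2 000 steps for 10 epochs, which is close to the figures reported above. Treat this only as a plausible reading of the numbers, not as a record of the hardware used.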

In [None]:
import torch

trainer.test(model, dm)

model.eval()  # make sure dropout/batch-norm layers run in inference mode

x, y = next(iter(dm.test_dataloader()))
x = x.to(model.device)
y = y.to(model.device)

with torch.no_grad():  # no gradients needed for visualization
    out = model(x)

print_results_images(x, y, out, "Results on test set", dm.inverse_transform)

dm.setup("fit")  # make sure the training split exists, e.g. if trainer.fit was not run in this session

x, y = next(iter(dm.train_dataloader()))
x = x.to(model.device)
y = y.to(model.device)
with torch.no_grad():
    out = model(x)

print_results_images(x, y, out, "Results on training set", dm.inverse_transform)

## Results of the model with the best validation loss

The results of this cell were obtained after 10 epochs on the Tiny ImageNet dataset (about 10x fewer images than ImageNet) at 64x64 resolution.

- **Number of steps**: about 2 000
- **Time**: about 20 min
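The evaluation cell uses whatever weights `model` currently holds. While the `checkpoint_callback` object from the training cell is still alive, `checkpoint_callback.best_model_path` gives the path of the checkpoint with the lowest monitored `val_loss`. In a fresh session, the best file can instead be picked from the names written by the `filename` template. `best_checkpoint` below is a hypothetical helper that assumes the `epoch=XX-val_loss=Y.YY.ckpt` suffix used in this notebook.

```python
import re
from typing import Iterable, Optional

# Suffix produced by the ModelCheckpoint filename template, e.g.
# "2024-12-16_14-34-46-epoch=76-val_loss=0.32.ckpt"
CKPT_PATTERN = re.compile(r"epoch=(\d+)-val_loss=(\d+(?:\.\d+)?)\.ckpt$")


def best_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Return the name with the lowest val_loss (earliest epoch on ties), or None."""
    best_key, best_name = None, None
    for name in filenames:
        match = CKPT_PATTERN.search(name)
        if match is None:
            continue  # skip files that do not follow the naming scheme
        key = (float(match.group(2)), int(match.group(1)))
        if best_key is None or key < best_key:
            best_key, best_name = key, name
    return best_name


# Made-up file names, for illustration only
names = [
    "run-epoch=03-val_loss=0.41.ckpt",
    "run-epoch=05-val_loss=0.41.ckpt",
    "run-epoch=12-val_loss=0.47.ckpt",
    "notes.txt",
]
print(best_checkpoint(names))  # run-epoch=03-val_loss=0.41.ckpt
```

For example, `best_checkpoint(p.name for p in Path("checkpoints/tiny_imagenet").glob("*.ckpt"))` (with `from pathlib import Path`) returns a file name that can be joined to that directory and passed to `ContextEncoder.load_from_checkpoint`, as in the loading cell. Several runs share this directory, so filter on the run timestamp if needed. Since the names only keep two decimals, `best_model_path` is more precise when it is available.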

In [None]:
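import torch

# `model` must hold the best-validation weights here (e.g. loaded with the checkpoint cell above)
trainer.test(model, dm)

model.eval()  # make sure dropout/batch-norm layers run in inference mode

x, y = next(iter(dm.test_dataloader()))
x = x.to(model.device)
y = y.to(model.device)

with torch.no_grad():  # no gradients needed for visualization
    out = model(x)

print_results_images(x, y, out, "Results on test set", dm.inverse_transform)

dm.setup("fit")  # make sure the training split exists, e.g. if trainer.fit was not run in this session

x, y = next(iter(dm.train_dataloader()))
x = x.to(model.device)
y = y.to(model.device)
with torch.no_grad():
    out = model(x)

print_results_images(x, y, out, "Results on training set", dm.inverse_transform)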

## Comparison

The outputs from the early epochs are still blurry: even though their validation loss is lower, they look less realistic. The outputs from the last epoch are sharper, but still contain some noise. This is especially visible in the images logged on the validation set, which are not displayed here but can be viewed in the tensorboard notebook and in the animated GIF produced by the create_output_per_epoch_animated_result notebook.
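One common explanation for this behaviour, assuming the validation loss is (or includes) a pixel-wise L2 reconstruction loss, which this notebook does not state: when several completions are plausible, the prediction that minimises the expected squared error is their pixel-wise average, i.e. a blurry image. The toy example below compares a "sharp" guess with the average of two equally likely targets.

```python
def mse(pred, target):
    """Mean squared error between two equal-length lists of pixel values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


# Two equally likely ground truths: a bright pixel at one of two positions
target_a = [1.0, 0.0]
target_b = [0.0, 1.0]


def expected_mse(pred):
    # Each target occurs with probability 0.5
    return 0.5 * mse(pred, target_a) + 0.5 * mse(pred, target_b)


sharp = target_a       # commits to one realistic completion
blurry = [0.5, 0.5]    # pixel-wise average of both targets

print(expected_mse(sharp))   # 0.5
print(expected_mse(blurry))  # 0.25
```

The blurry guess has half the expected loss even though it matches neither target, which is consistent with early outputs having a lower loss while looking less realistic.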

Further experiments in other notebooks try a larger dataset (ImageNet) and different parameters on Tiny ImageNet.