# In [3]:
from aiutils import CamouflagedAnimalsModel
from dataset import CamouflagedAnimalsDataset, colorMaskToOneHot

import lightning as L
from lightning.pytorch.callbacks import ModelSummary
import torch as T
from torchvision.transforms import v2 as TV

from torchview import draw_graph
from lightning.pytorch.callbacks import Callback

class VisualizeModel(Callback):
    """Placeholder Lightning callback reserved for rendering the model graph.

    The training-batch hook is currently a deliberate no-op. An earlier draft
    rendered the architecture with hiddenlayer (``hl.build_graph`` followed by
    ``graph.save("model_graph", format="png")``), but that module is not
    imported in this file. The architecture diagram is produced with
    ``torchview.draw_graph`` at script level instead.
    """

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Hook invoked by Lightning before each training batch; does nothing."""
        return None
        

# Let float32 matrix multiplications use faster reduced-precision internals
# (TF32 / bfloat16 on hardware that supports them). "medium" trades some
# numerical precision for throughput; it does not enable or select a TPU.
T.set_float32_matmul_precision("medium")

# Set to a checkpoint path to start from saved weights, e.g.
# "lightning_logs/version_17/checkpoints/epoch=16-step=289.ckpt".
# NOTE: load_from_checkpoint restores the module's weights/hyperparameters
# only; optimizer state and the epoch counter are resumed only when the
# checkpoint is also passed to trainer.fit(ckpt_path=...).
checkpoint = None

model_lightning = (
    CamouflagedAnimalsModel()
    if checkpoint is None
    else CamouflagedAnimalsModel.load_from_checkpoint(checkpoint)
)

# Training resolution and batch sizes. The model graph below is traced with
# the same input shape, so the rendered diagram matches what training sees.
IMAGE_SIZE = (256, 256)
TRAIN_BATCH_SIZE = 15
VALID_BATCH_SIZE = 10

# Deterministic resize + dtype/scale steps shared by training and validation.
# NEAREST avoids blending mask colours at boundaries.
# NOTE(review): if the dataset routes the image through this same transform,
# the image is also resized with NEAREST; confirm that is intended.
_resize_and_scale = [
    TV.Resize(IMAGE_SIZE, interpolation=TV.InterpolationMode.NEAREST),
    TV.ToDtype(T.float, scale=True),
]

# Training: random flip (common), colour jitter / equalize (image only).
common_transform = TV.Compose([TV.RandomHorizontalFlip(0.5), *_resize_and_scale])
image_transform = TV.Compose(
    [TV.ColorJitter(brightness=0.3, hue=0.15), TV.RandomEqualize(0.5)]
)
mask_transform = TV.Compose([colorMaskToOneHot])

dataset = CamouflagedAnimalsDataset(
    images_path="images",
    masks_path="masks",
    common_transform=common_transform,
    image_transform=image_transform,
    mask_transform=mask_transform,
)

# Un-augmented view of the same files for validation. Random flips and colour
# jitter would otherwise make validation metrics noisy and not comparable
# across epochs.
# NOTE(review): assumes CamouflagedAnimalsDataset lists files in the same order
# for identical paths, so the split indices address the same samples in both.
eval_dataset = CamouflagedAnimalsDataset(
    images_path="images",
    masks_path="masks",
    common_transform=TV.Compose(_resize_and_scale),
    image_transform=TV.Compose([TV.Identity()]),
    mask_transform=mask_transform,
)

# Seeded generator so the 90/10 split is reproducible across runs.
seed = T.Generator().manual_seed(42)
train_set, _valid_split = T.utils.data.random_split(
    dataset, [0.9, 0.1], generator=seed
)
valid_set = T.utils.data.Subset(eval_dataset, _valid_split.indices)

train_loader = T.utils.data.DataLoader(
    train_set,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
)
valid_loader = T.utils.data.DataLoader(
    valid_set,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    persistent_workers=True,
)

# Render the architecture traced with the real training input shape
# (batch, 3 channels, height, width).
graph = draw_graph(
    model_lightning,
    input_size=(TRAIN_BATCH_SIZE, 3, *IMAGE_SIZE),
    expand_nested=True,
)
graph.visual_graph.render(format="png")

# Callbacks: print a 2-level model summary; VisualizeModel is a placeholder hook.
trainer_callbacks = [ModelSummary(max_depth=2), VisualizeModel()]

# Hardware selection is left to Lightning. Each epoch is capped at 250
# training batches, and max_epochs=9999 effectively means "train until stopped".
trainer = L.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    max_epochs=9999,
    limit_train_batches=250,
    log_every_n_steps=5,
    callbacks=trainer_callbacks,
)

trainer.fit(
    model_lightning,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)