In [None]:
import torch
import lightning.pytorch as pl
from torchinfo import summary
from lightning.pytorch.tuner import Tuner

In [None]:
# Reproducibility: seed Python, NumPy and torch RNGs (plus dataloader workers)
# before the datamodule and model are created, so weight initialisation and
# data shuffling repeat across Restart & Run All.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Allow float32 matmuls to use faster, lower-precision internal kernels where the
# hardware supports them (trades a little numeric precision for speed).
torch.set_float32_matmul_precision('medium')

In [None]:
from data import OxfordPetDatamodule
from model import LitUNet
from transformation import mask_transforms,img_trasfroms

In [None]:
# Oxford-IIIT Pet datamodule: images and masks get their own transform pipelines.
# setup() is called eagerly because len(dm.train_dataloader()) is needed when the
# model is constructed (passed as `scheduler_step`).
dm = OxfordPetDatamodule(
    batch_size=16,
    transforms=img_trasfroms,
    mask_transforms=mask_transforms,
)
dm.setup()

In [None]:
# U-Net with a 3-channel input and 3 output channels (presumably one per trimap
# class — TODO confirm against the datamodule's mask encoding).
model = LitUNet(
    # --- architecture ---
    inchannels=3,
    outchannels=3,
    channels_list=[16, 32, 64, 128],
    contraction_mode='strided_conv',
    expansion_mode='upsample',
    # --- optimisation ---
    loss_fn='dice',
    lr=1e-4,
    # Keep `epoch` equal to the Trainer's max_epochs: together with the number of
    # train batches per epoch it presumably sizes the LR schedule.
    epoch=15,
    scheduler_step=len(dm.train_dataloader()),
)

In [None]:
# Layer-by-layer shape / parameter table for one dummy batch of 16 images of
# size 3x256x256 (the batch size mirrors the datamodule's batch_size=16).
summary(
    model=model,
    input_size=(16, 3, 256, 256),
    row_settings=["depth", "var_names"],
    col_names=[
        "input_size",
        "output_size",
        "num_params",
        "params_percent",
        "kernel_size",
        "trainable",
    ],
)

In [None]:
# Single-run trainer. Lightning's built-in model summary is disabled because the
# torchinfo table in the previous cell already shows it.
trainer = pl.Trainer(
    max_epochs=15,  # keep in sync with LitUNet(epoch=...)
    precision='32',
    log_every_n_steps=30,
    enable_progress_bar=True,
    enable_model_summary=False,
    fast_dev_run=False,
    # Optional extras. None of these classes are imported at the top of the
    # notebook, so add the imports (lightning.pytorch.callbacks / .loggers) first.
    # Note that CSVLogger is a logger, so it goes in `logger=`, not `callbacks=`.
    # callbacks=[
    #     EarlyStopping(monitor='train_loss', mode='min', patience=3, verbose=False,
    #                   check_on_train_epoch_end=True, check_finite=True),
    #     LearningRateMonitor(logging_interval='step'),
    #     ModelCheckpoint(dirpath="experiment", monitor='train_loss', enable_version_counter=True),
    #     ModelSummary(max_depth=-1),
    # ],
    # logger=CSVLogger(save_dir='experiments/'),
)

In [None]:
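# Optional learning-rate range test: uncomment to let Lightning's Tuner suggest a
# learning rate, plot loss against learning rate, and copy the suggestion onto the
# model's `lr` and `max_lr` attributes before calling fit.
# tuner = Tuner(trainer)
# lr_finder = tuner.lr_find(model=model, datamodule=dm)
# maxlr = lr_finder.suggestion()
# fig = lr_finder.plot(suggest=True)
# model.lr = maxlr
# model.max_lr = maxlr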

In [None]:
# Train, then re-run validation on the final weights. Both calls receive the
# datamodule through the explicit `datamodule=` keyword, so Lightning obtains the
# train/val loaders from the same object in the same way.
trainer.fit(model=model, datamodule=dm)
trainer.validate(model=model, datamodule=dm)