In [1]:
import torch
from matplotlib import pyplot as plt
import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import LearningRateFinder,LearningRateMonitor,EarlyStopping,StochasticWeightAveraging,ModelPruning,ModelSummary,ModelCheckpoint
from torchinfo import summary

In [2]:
# Global run configuration.
SEED = 42

# Seed torch's global RNG (CPU and CUDA) so that everything drawing from it is
# repeatable across "Restart Kernel -> Run All".
# NOTE(review): Python's `random` and NumPy are not seeded here; if the
# transforms or datamodule draw from those, seed them as well.
torch.manual_seed(SEED)

# Allow reduced-precision float32 matmuls (faster, slightly less precise).
torch.set_float32_matmul_precision('medium')

In [3]:
from data import OxfordPetDatamodule
from model import UNet
from transformation import mask_transforms,img_trasfroms

In [4]:
# Datamodule wired with separate transform pipelines for images and masks.
dm = OxfordPetDatamodule(
    transforms=img_trasfroms,
    mask_transforms=mask_transforms,
    batch_size=16,
)

# Set up eagerly: the model cell below calls dm.train_dataloader() before
# Trainer.fit ever runs (presumably setup() is what makes that call usable).
dm.setup()

In [5]:
# Hyper-parameters for the U-Net, collected in one place and unpacked below.
unet_config = dict(
    # Architecture
    inchannels=3,
    outchannels=3,
    expansion_mode='upsample',
    contraction_mode='maxpool',
    channels_list=[16, 32, 64, 128],
    # Optimisation
    # NOTE(review): epoch=5 here, but the Trainer further down is built with
    # max_epochs=15. If UNet uses `epoch` to size its LR schedule, the two
    # should agree -- confirm against model.py.
    epoch=5,
    lr=1e-4,
    # Number of batches in one pass over the training dataloader.
    scheduler_step=len(dm.train_dataloader()),
)

model = UNet(**unet_config)

In [6]:
# Layer-by-layer overview of the network for a dummy input of the given shape.
# NOTE(review): the batch dimension here is 62 while the datamodule uses
# batch_size=16 -- batch-dependent figures such as Mult-Adds are reported for
# 62 samples; confirm this is intended.
probe_shape = (62, 3, 256, 256)

summary_columns = [
    "input_size",
    "output_size",
    "num_params",
    "params_percent",
    "kernel_size",
    "trainable",
]
summary_rows = ["depth", "var_names"]

# Keep the call as the cell's last expression so the table is rendered.
summary(
    model,
    input_size=probe_shape,
    col_names=summary_columns,
    row_settings=summary_rows,
)

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Param %                   Kernel Shape              Mult-Adds                 Trainable
UNet (UNet)                              [62, 3, 256, 256]         [62, 3, 256, 256]         --                             --                   --                        --                        True
├─Encoder (encoder1): 1-1                [62, 3, 256, 256]         [62, 16, 128, 128]        --                             --                   --                        --                        True
│    └─Sequential (conv): 2-1            [62, 3, 256, 256]         [62, 16, 256, 256]        --                             --                   --                        --                        True
│    │    └─Conv2d (0): 3-1              [62, 3, 256, 256]         [62, 16, 256, 256]        448                         0.08%                   [3, 3]                    1,820,327,936   

In [7]:
# Callbacks attached to the run. Two more are kept below, commented out.
active_callbacks = [
    # Disabled: EarlyStopping(monitor='train_loss',mode='min',patience=3,verbose=False,check_on_train_epoch_end=True,check_finite=True,),
    # Log the learning rate at every optimizer step.
    LearningRateMonitor(logging_interval='step'),
    # Disabled: ModelCheckpoint(dirpath="experiment",monitor='train_loss',enable_version_counter=True),
    # Print the full-depth model summary when fitting starts; this is why the
    # Trainer's built-in summary is switched off below.
    ModelSummary(max_depth=-1),
]

# NOTE(review): max_epochs=15 here, while the model above was built with
# epoch=5 -- confirm which value is intended.
trainer = pl.Trainer(
    max_epochs=15,
    precision='16-mixed',
    log_every_n_steps=30,
    callbacks=active_callbacks,
    fast_dev_run=False,
    enable_model_summary=False,
    enable_progress_bar=True,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
# Train the model on the datamodule's training/validation loaders.
trainer.fit(model=model, datamodule=dm)

# Run one validation pass with the trained weights. The datamodule is passed
# through `datamodule=` (not `dataloaders=`, which is meant for DataLoader
# objects) so the call matches `fit` above. Left as the last expression so the
# returned metrics are displayed as the cell output.
trainer.validate(model=model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
c:\Users\muthu\miniconda3\envs\venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.

   | Name                 | Type        | Params
------------------------------------------------------
0  | encoder1             | Encoder     | 2.8 K 
1  | encoder1.conv        | Sequential  | 2.8 K 
2  | encoder1.conv.0      | Conv2d      | 448   
3  | encoder1.conv.1      | BatchNorm2d | 32    
4  | encoder1.conv.2      | ReLU        | 0     
5  | encoder1.conv.3      | Dropout2d   | 0     
6  | encoder1.conv.4      | Conv2d      | 2.3 K 
7  | encoder1.conv.5      | BatchNorm2d | 32    
8  | encoder1.conv.6      | ReLU        | 0     
9  | encoder1.downsample  | M

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\muthu\miniconda3\envs\venv\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]