## Setup

In [83]:
import os
import importlib

# Pytorch and lighting module
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import Adam
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

# Custom modules
import segmentation.utils
importlib.reload(segmentation.utils)
from segmentation.utils import preprocessing


import segmentation.dataset
importlib.reload(segmentation.dataset)
from segmentation.dataset import PointDataset

# Importing the model
from models.unet_model import UNET



In [84]:
# Defining some constants that will be used throughout the notebook
DATA_DIR = 'Dataset/'
x_test_dir = os.path.join(DATA_DIR, 'Test/color')
y_test_dir = os.path.join(DATA_DIR, 'Test/label')
x_trainVal_dir = os.path.join(DATA_DIR, 'TrainVal/color')
y_trainVal_dir = os.path.join(DATA_DIR, 'TrainVal/label')

# Fraction of the TrainVal images held out for validation.
VAL_FRACTION = 0.2

# Seed Python's `random`, NumPy and torch so weight initialisation and
# DataLoader shuffling are reproducible on Restart & Run All.
# NOTE(review): this only makes train_val_split reproducible if it draws from
# one of those RNGs (or takes its own fixed seed) — confirm in preprocessing.
SEED = 42
pl.seed_everything(SEED)

# Getting a list of relative paths to the test images (x) and the masks/labels (y)
x_test_fps, y_test_fps = preprocessing.get_testing_paths(x_test_dir, y_test_dir)

# Splitting the TrainVal path names into training and validation
x_train_fps, x_val_fps, y_train_fps, y_val_fps = preprocessing.train_val_split(
    x_trainVal_dir, y_trainVal_dir, VAL_FRACTION
)

In [85]:
# Build the training and validation datasets.
# Both use the same preprocessing function; only the augmentation
# pipeline differs between the two splits.
train_augmentation = preprocessing.get_training_augmentation()
preprocessing_fn = preprocessing.get_preprocessing()
val_augmentation = preprocessing.get_validation_augmentation()

train_dataset = PointDataset(
    x_train_fps,
    y_train_fps,
    augmentation=train_augmentation,
    preprocessing=preprocessing_fn,
)
val_dataset = PointDataset(
    x_val_fps,
    y_val_fps,
    augmentation=val_augmentation,
    preprocessing=preprocessing_fn,
)

In [86]:
class UNetLightningModule(pl.LightningModule):
    """LightningModule that trains the project's ``UNET`` with cross-entropy loss.

    Args:
        in_channels: number of input channels forwarded to ``UNET``.
        out_channels: number of output channels forwarded to ``UNET``; these are
            the class logits given to ``CrossEntropyLoss``. NOTE(review): with
            the default of 1 (and assuming the class axis is dim 1 of the UNET
            output) the softmax over a single class is always 1, so the loss is
            identically zero — pass the real number of classes (>= 2).
        features: channel widths forwarded to ``UNET``. The default is an
            immutable tuple so instances never share one mutable list; a fresh
            list copy is what ``UNET`` receives.
        lr: learning rate used by the Adam optimizer.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512), lr=1e-3):
        super().__init__()
        # Records the init arguments in self.hparams and in checkpoints.
        self.save_hyperparameters()
        # Pass a copy so UNET never holds (or mutates) the caller's sequence.
        self.model = UNET(in_channels, out_channels, list(features))
        # Target pixels labelled 255 are excluded from the loss and gradient.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=255)
        self.lr = lr

    def forward(self, x):
        """Run ``x`` through the wrapped UNet and return its raw output (used as logits)."""
        return self.model(x)

    def _shared_step(self, batch):
        """Unpack an ``(images, masks)`` batch and return the loss; shared by train and val."""
        images, masks = batch
        preds = self(images)
        return self.loss_fn(preds, masks)

    def training_step(self, batch, batch_idx):
        """One iteration of training; logs ``train_loss`` per step and per epoch."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """One iteration of validation; logs ``val_loss`` aggregated over the epoch."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return a single Adam optimizer over all parameters with learning rate ``self.lr``."""
        return Adam(self.parameters(), lr=self.lr)

In [87]:
# DataLoader settings shared by both splits.
# num_workers stays 0 (as before). Worker processes on macOS (the MPS backend
# is in use here) are started with "spawn", which requires the dataset and its
# augmentation/preprocessing callables to be picklable. Raise NUM_WORKERS once
# that is verified to address Lightning's "does not have many workers" warning.
BATCH_SIZE = 8
NUM_WORKERS = 0

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# NOTE(review): in_channels=4 / out_channels=3 must match the image channels
# and number of label classes PointDataset produces — confirm against the dataset.
model = UNetLightningModule(in_channels=4, out_channels=3, lr=1e-3)

trainer = Trainer(
    max_epochs=10,
    accelerator="auto",  # uses MPS/CUDA when available, otherwise falls back to CPU
    devices=1,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [88]:
# Train the model, running validation on val_loader during fitting.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | UNET             | 31.0 M | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 M    Total params
124.153   Total estimated model params size (MB)
91        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/opt/miniconda3/envs/cv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

/opt/miniconda3/envs/cv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0:   1%|          | 2/368 [00:01<05:21,  1.14it/s, v_num=7]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined