In [1]:
import os

import yaml
from icecream import ic
from PIL.Image import NEAREST

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.transforms import Compose, ToTensor, Normalize, Resize

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from datamodules.datamodule import VOC2012SegmentationDataModule
from models.mobilenet import MobileNetV2Segmentation

In [2]:
# Root of the extracted VOC2012 devkit. All dataset paths below are derived from it.
# Set the VOC_ROOT environment variable to point at a different copy without editing code.
VOC_ROOT = os.environ.get("VOC_ROOT", "/home/haim/hdd/data/voc/VOCdevkit/VOC2012")

IMAGE_DIR = f"{VOC_ROOT}/JPEGImages"
# SegmentationClass stores per-pixel *class* indices, which is what a semantic
# segmentation model with `num_classes` outputs must be trained against.
# SegmentationObject (used previously) stores per-*instance* ids, not class labels.
MASK_DIR = f"{VOC_ROOT}/SegmentationClass"
TRAIN_FILE = f"{VOC_ROOT}/ImageSets/Segmentation/train.txt"
VAL_FILE = f"{VOC_ROOT}/ImageSets/Segmentation/val.txt"

MAX_EPOCHS = 2
BATCH_SIZE = 4
NUM_WORKERS = 2
SEED = 42

# Seed python / numpy / torch (and dataloader workers) so runs are reproducible.
pl.seed_everything(SEED, workers=True)

In [3]:
# Both pipelines resize to one fixed size so variable-sized VOC samples can be stacked
# into a batch. Image and mask must use the same size to stay pixel-aligned.
IMAGE_SIZE = (374, 500)  # (height, width)

image_transforms = Compose(
    [
        ToTensor(),  # PIL / uint8 array -> float tensor scaled to [0, 1], shape (C, H, W)
        Resize(IMAGE_SIZE, antialias=True),
        # TODO(review): add Normalize(mean, std) if the backbone expects ImageNet statistics.
    ]
)

mask_transforms = Compose(
    [
        # PILToTensor keeps the raw uint8 label values. ToTensor would rescale them to
        # floats in [0, 1] (class k -> k / 255), destroying the class indices.
        # NOTE(review): assumes the dataset hands masks over as PIL images.
        transforms.PILToTensor(),
        # Nearest-neighbour so resized pixels keep existing label values instead of being
        # blended into labels that do not exist; antialiasing does not apply to it.
        Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    ]
)

In [4]:
# Gather the data-module configuration in one place (data locations, per-sample
# transforms, loader settings), then build the module from it.
datamodule_kwargs = dict(
    image_dir=IMAGE_DIR,
    mask_dir=MASK_DIR,
    train_file=TRAIN_FILE,
    val_file=VAL_FILE,
    image_transforms=image_transforms,
    mask_transforms=mask_transforms,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
datamodule = VOC2012SegmentationDataModule(**datamodule_kwargs)

In [6]:
# Outside `Trainer.fit`, Lightning does not invoke the data-module hooks for us.
# Call `setup` explicitly in case the module builds its datasets there (the standard
# LightningDataModule pattern) before requesting the training loader.
datamodule.setup(stage="fit")
train_dataloader = datamodule.train_dataloader()

In [7]:
# Pull a single batch from a fresh iterator to inspect what the loader yields.
first_batch = next(iter(train_dataloader))
imgs, masks = first_batch

In [10]:
# `ic` is already imported in the top import cell.
# Sanity-check the batch: images and masks must come in equal counts and matching
# spatial sizes, otherwise a per-pixel loss would compare misaligned tensors.
assert imgs.shape[0] == masks.shape[0], "image/mask batch sizes differ"
assert imgs.shape[-2:] == masks.shape[-2:], "image/mask spatial sizes differ"

ic(imgs[0].shape)
ic(masks[0].shape)

ic| imgs[0].shape: torch.Size([3, 374, 500])
ic| masks[0].shape: torch.Size([1, 374, 500])


torch.Size([1, 374, 500])

In [None]:
# PASCAL VOC has 20 object classes plus background, so class labels run 0..20
# (255 marks ignored boundary pixels) and the head needs 21 outputs.
# TODO(review): confirm MobileNetV2Segmentation does not add a background channel itself.
NUM_CLASSES = 21
model = MobileNetV2Segmentation(num_classes=NUM_CLASSES)

In [None]:
# `fast_dev_run=True` runs a single train/val batch as a smoke test and overrides
# `max_epochs`. Set it to False to actually train for MAX_EPOCHS epochs.
FAST_DEV_RUN = True

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    # "auto" uses a GPU when one is available and falls back to CPU otherwise,
    # instead of raising on machines without CUDA.
    accelerator="auto",
    devices=1,
    fast_dev_run=FAST_DEV_RUN,
)

In [None]:
# Train the model. The data module is passed by keyword because the second positional
# parameter of `Trainer.fit` is `train_dataloaders`; Lightning treats both forms the same.
trainer.fit(model, datamodule=datamodule)