In [None]:
%load_ext autoreload
%autoreload 2

!mkdir -p /data/sets/nuimages  # Make the directory to store the nuImages dataset in.
!wget https://www.nuscenes.org/data/nuimages-v1.0-mini.tgz  # Download the nuImages mini split.
!tar -xf nuimages-v1.0-mini.tgz -C /data/sets/nuimages  # Uncompress the nuImages mini split.
!pip install lightning numpy nuscenes-devkit pillow torch

!git clone https://github.com/GordonGustafson/semantic-segmentation.git
%cd semantic-segmentation/

In [None]:
# Bring the repository in the current working directory up to date with its remote.
# NOTE(review): this relies on the working directory set by the `%cd semantic-segmentation/`
# in the setup cell -- confirm the directory before running this cell on its own.
!git pull

In [None]:
# Import the one name this cell needs explicitly instead of `import *`, so it is
# clear which module provides it.
from dataloader import NuImagesDataset
from nuimages import NuImages

from matplotlib import pyplot as plt


# Open the mini split that the setup cell unpacked under /data/sets/nuimages.
nuimages = NuImages(dataroot='/data/sets/nuimages', version='v1.0-mini', verbose=True, lazy=True)
dataset = NuImagesDataset(nuimages)

# Look up a single sample by key.
# NOTE(review): the key is presumably a nuImages token -- confirm in dataloader.py.
sem_seg_sample = dataset['0f37924ef2b54da7a233091d95311a38']

# Two stacked axes: the image on top, its segmentation mask underneath.
fig, axs = plt.subplots(2)
axs[0].imshow(sem_seg_sample.image, interpolation='nearest')
axs[0].set_title('Image')
# Nearest-neighbour interpolation avoids blending neighbouring mask values when drawing.
axs[1].imshow(sem_seg_sample.segmentation_mask, interpolation='nearest')
axs[1].set_title('Segmentation mask')
# plt.show() is the backend-independent way to display figures; Figure.show() is
# intended for GUI backends and may warn when the active backend is not one.
plt.show()

In [None]:
from dataloader import get_mini_dataloader, image_transform_to_dict_transform

from torch import optim, nn
import torchvision.transforms as T
import lightning as L

# Number of segmentation classes; used below as the number of output channels of the model.
# TODO: figure out the correct value for this.
# 31 is the largest value I've seen so far, but haven't looked very hard.
# NOTE(review): the nuscenes-devkit helper `name_to_index_mapping` appears to reserve
# index 0 for background and to assign index 31 to `vehicle.ego`, which would make
# 32 the right value -- confirm against the devkit before removing the TODO.
NUM_CLASSES = 32


# define the LightningModule
class PixelWiseSegmentation(L.LightningModule):
    """LightningModule that trains a per-pixel classifier with cross-entropy loss.

    Args:
        model: Module applied to the image batch. Its output is passed to
            ``cross_entropy`` as the prediction, so for (N, H, W) targets it is
            expected to have shape (N, num_classes, H, W).
        lr: Learning rate passed to AdamW. Defaults to the previously
            hard-coded value of 1e-3.
        betas: AdamW ``betas`` pair. Defaults to the previously hard-coded
            value of (0.9, 0.999).
    """

    def __init__(self, model, lr=1e-3, betas=(0.9, 0.999)):
        super().__init__()
        self.model = model
        # Kept on the instance so configure_optimizers can read them later.
        self.lr = lr
        self.betas = betas

    def training_step(self, batch, batch_idx):
        """Run the model on one batch and return the cross-entropy loss.

        Args:
            batch: Mapping with an ``"image"`` entry and a
                ``"segmentation_mask"`` entry.
            batch_idx: Index of the batch; unused.

        Returns:
            The scalar loss tensor for this batch.
        """
        images = batch["image"]                             # shape = (N, C, H, W)
        gt_segmentation_masks = batch["segmentation_mask"]  # shape = (N, H, W)
        predicted_segmentation_masks = self.model(images)
        # cross_entropy takes class-index targets as int64; the cast returns the
        # same tensor unchanged when the masks already have that dtype.
        loss = nn.functional.cross_entropy(predicted_segmentation_masks,
                                           gt_segmentation_masks.long())
        # Record the loss so training progress is visible in the progress bar and logger.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters."""
        return optim.AdamW(self.model.parameters(), lr=self.lr, betas=self.betas)


# Seed the random number generators before any weights are created, so that
# weight initialisation is repeatable from one run to the next.
SEED = 42
L.seed_everything(SEED, workers=True)

# init the model
# A 1x1 convolution: every output pixel is computed only from the 3 input channel
# values at the same location, producing NUM_CLASSES scores per pixel.
conv = nn.Conv2d(in_channels=3, out_channels=NUM_CLASSES, kernel_size=1, padding=0, stride=1)
model = PixelWiseSegmentation(conv)

# setup data
# NOTE(review): image_transform_to_dict_transform presumably applies ToTensor to the
# image entry of each sample dict only -- confirm in dataloader.py.
transform = image_transform_to_dict_transform(T.ToTensor())
train_dataloader = get_mini_dataloader(batch_size=2, transform=transform)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
# limit_train_batches and max_epochs cap the run at 100 batches of a single epoch.
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=model, train_dataloaders=train_dataloader)
