In [None]:
import csv
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torchvision.transforms import Resize, InterpolationMode, ToTensor, ToPILImage, TenCrop, Compose, Lambda
from torchmetrics import JaccardIndex, Precision, Recall, F1Score
import segmentation_models_pytorch as smp

from src.models.BaselineModel import BaselineModel
from src.evaluation.evaluate_result import evaluate_result
from src.callbacks.SaveRandomImagesCallback import SaveRandomImagesCallback
from src.callbacks.SaveTestPredsMulticlass import SaveTestPredsMulticlass
from src.datasets.utils.Squeeze5DimIfNeeded import Squeeze5DimIfNeeded
from src.datasets.DubaiSemanticSegmentationDataset import (
    DubaiSemanticSegmentationDataset,
)
from src.datasets.utils.ResizeToDivisibleBy32 import ResizeToDivisibleBy32

## Prepare environment

In [None]:
# Ask PyTorch whether CUDA is currently available; the boolean result is
# displayed as this cell's output.
torch.cuda.is_available()

In [None]:
# Choose the compute device: the GPU when CUDA is available, the CPU otherwise.
# Troubleshooting: if an obscure CUDA error appears, force device to "cpu" and rerun.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# --- Run settings ---
SEED = 42
BATCH_SIZE = 1  # batch size handed to the DataLoader
VAL_SIZE = 0.2  # presumably the validation share of the labeled data; confirm where the split is made

# --- Image sizes ---
IMAGE_SIZE = 576
CROP_IMAGE_SIZE = 224  # crop size given to TenCrop when the dataset is built

# --- Dataset locations ---
INPUT_DUBAI_DATASET_PATH = "data/DubaiSemanticSegmentationDataset"
OUTPUT_DUBAI_DATASET_PATH = "data/DubaiSemanticSegmentationDatasetPatches"

# --- Output directories ---
SAVE_VAL_DIR = "outputs/Dubai/val"
SAVE_TEST_DIR = "outputs/Dubai/test"

In [None]:
# Transform pipeline handed to the dataset:
#   1. TenCrop produces ten CROP_IMAGE_SIZE crops (four corners and the center,
#      plus their vertically flipped versions).
#   2. The crops are stacked into one tensor along a new leading dimension.
# NOTE(review): torch.stack needs tensor crops, so this assumes the dataset
# passes tensors (not PIL images) to the transform — TODO confirm.
ten_crop_transforms = Compose(
    [
        TenCrop(CROP_IMAGE_SIZE, vertical_flip=True),
        Lambda(lambda crops: torch.stack(list(crops))),
    ]
)

labeled_dataset = DubaiSemanticSegmentationDataset(
    INPUT_DUBAI_DATASET_PATH,
    transforms=ten_crop_transforms,
)

# Report how many items the dataset exposes.
print(len(labeled_dataset))

## Data preparation

### Sanity check data

In [None]:
# Loader over the labeled dataset that keeps the dataset's own order
# (no shuffling), so repeated inspections see the same first batch.
sanity_check_loader = DataLoader(
    dataset=labeled_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Fetch only the first batch from the loader and print its structure.
batch = next(iter(sanity_check_loader), None)
if batch is not None:
    print(type(batch))
    print(len(batch))
    # The first two entries of the batch are read as images and masks.
    images, masks = batch[0], batch[1]
    print(images.shape)
    print(masks.shape)
    # Only the type of the first entry in the batch is reported.
    sample = next(iter(batch), None)
    if sample is not None:
        print(type(sample))

# with TenCrop use:
# for batch in train_loader:
#     images, masks = batch
#     print(images.shape)
#     print(masks.shape)
#     break