In [None]:
import torch
from torch.utils.data import DataLoader
from src.datasets.INRIAAerialImageLabellingDataset import (
    INRIAAerialImageLabellingDataset,
)
from src.datasets.UAVidSemanticSegmentationDataset import (
    UAVidSemanticSegmentationDataset,
)
from src.utils import (
    process_inria_dataloader_and_save,
    process_uavid_dataloader_and_save,
)

In [None]:
# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
# If you hit a cryptic CUDA error, set device to "cpu" and try again.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## UAVid

In [None]:
# --- UAVid configuration -------------------------------------------------
# Dataset root and output directories (paths are relative to the working directory).
UAVID_DATASET_PATH = "data/UAVidSemanticSegmentationDataset"
SAVE_VAL_DIR = "outputs/UAVid/val"
SAVE_TEST_DIR = "outputs/UAVid/test"

# Loader settings: BATCH_SIZE is passed to every UAVid DataLoader below.
BATCH_SIZE = 1
SEED = 42
VAL_SIZE = 0.2

# Image dimensions (presumably pixels; not referenced by the cells visible here).
IMAGE_WIDTH = 1024
IMAGE_HEIGHT = 576

## Prepare data

### Train

In [None]:
# Build the UAVid dataset for the "train" split and print its length.
train_dataset = UAVidSemanticSegmentationDataset(UAVID_DATASET_PATH, split="train")
print(len(train_dataset))

In [None]:
# Shuffled loader over the train split.
# The shuffle order is drawn from a generator seeded with SEED, so re-running
# this cell (or the whole notebook on a fresh kernel) yields the same order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# Run the UAVid processing-and-save helper (imported from src.utils) over the train loader.
# NOTE(review): the meaning of `train=True` is defined in src.utils, not in this
# notebook — presumably it marks a split that has labels; confirm.
process_uavid_dataloader_and_save(train_loader, train=True)

### Val

In [None]:
# Build the UAVid dataset for the "valid" split and print its length.
val_dataset = UAVidSemanticSegmentationDataset(UAVID_DATASET_PATH, split="valid")
print(len(val_dataset))

In [None]:
# Unshuffled loader over the validation split.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Run the UAVid processing-and-save helper over the validation loader.
# NOTE(review): the validation split is passed with train=True, the same value
# used for the train split; only the test split is passed train=False.
# Presumably the flag means "labels available" — confirm this is intended.
process_uavid_dataloader_and_save(val_loader, train=True)

### Test

In [None]:
# Build the UAVid dataset for the "test" split and print its length.
test_dataset = UAVidSemanticSegmentationDataset(UAVID_DATASET_PATH, split="test")
print(len(test_dataset))

In [None]:
# Unshuffled loader over the test split.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Run the UAVid processing-and-save helper over the test loader, this time with train=False.
# NOTE(review): what train=False changes is defined in src.utils — presumably
# it handles a split without labels; confirm.
process_uavid_dataloader_and_save(test_loader, train=False)

## INRIA

In [None]:
# --- INRIA configuration -------------------------------------------------
# NOTE: VAL_SIZE, BATCH_SIZE, SEED, SAVE_VAL_DIR and SAVE_TEST_DIR reuse the
# names of the UAVid configuration, so once this cell has run they hold the
# INRIA values (e.g. BATCH_SIZE becomes 4 for any UAVid cell re-run afterwards).

# Dataset root: switch between the full dataset and the test subset by
# toggling the two lines below.
INRIA_DATASET_PATH = "data/INRIAAerialImageLabellingDataset"  # home PC
# INRIA_DATASET_PATH = "data/TestSubsets/INRIAAerialImageLabellingDataset"  # laptop

# Output directories.
SAVE_VAL_DIR = "outputs/INRIA/val"
SAVE_TEST_DIR = "outputs/INRIA/test"

# Loader settings: BATCH_SIZE is passed to both INRIA DataLoaders below.
BATCH_SIZE = 4
SEED = 42
VAL_SIZE = 0.2

# Image size (presumably pixels; not referenced by the cells visible here).
IMAGE_SIZE = 576

### Train

In [None]:
# Build the INRIA dataset for the "train" split and print its length.
labeled_dataset = INRIAAerialImageLabellingDataset(INRIA_DATASET_PATH, split="train")
print(len(labeled_dataset))

In [None]:
# Unshuffled loader over the INRIA "train" split, built with the DataLoader
# name imported at the top of the notebook (same class as torch.utils.data.DataLoader).
labeled_dataloader = DataLoader(
    labeled_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4
)

In [None]:
# NOTE(review): the INRIA processing call below is commented out, so running
# this cell does nothing; uncomment it to process the labeled loader.
# process_inria_dataloader_and_save(labeled_dataloader)

### Test

In [None]:
# Build the INRIA dataset for the "test" split and print its length.
unlabeled_dataset = INRIAAerialImageLabellingDataset(INRIA_DATASET_PATH, split="test")
print(len(unlabeled_dataset))

In [None]:
# Unshuffled loader over the INRIA "test" split, built with the DataLoader
# name imported at the top of the notebook (same class as torch.utils.data.DataLoader).
unlabeled_dataloader = DataLoader(
    unlabeled_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4
)

In [None]:
# NOTE(review): the INRIA processing call below is commented out, so running
# this cell does nothing; uncomment it to process the unlabeled loader with train=False.
# process_inria_dataloader_and_save(unlabeled_dataloader, train=False)