In [14]:
import os
from dataset import EchoDatasetMasks, EchoDatasetHeatmaps
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.model_selection import train_test_split
import FILE_PATHS

# transforms
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2


In [15]:
def get_loaders_masks(
    images_dir,
    masks_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    test_size=0.2,
    pin_memory=True,
):
    image_paths = [os.path.join(images_dir, img) for img in os.listdir(images_dir)]
    mask_paths = [os.path.join(masks_dir, mask) for mask in os.listdir(masks_dir)]

    (
        train_image_paths,
        val_image_paths,
        train_mask_paths,
        val_mask_paths,
    ) = train_test_split(image_paths, mask_paths, test_size=test_size, random_state=42)

    print("TRAIN PATHS")
    print(train_image_paths)
    print(train_mask_paths)

    print("VALIDATION PATHS")
    print(val_image_paths)
    print(val_mask_paths)

    train_dataset = EchoDatasetMasks(
        train_image_paths, train_mask_paths, transform=train_transform
    )
    val_dataset = EchoDatasetMasks(
        val_image_paths, val_mask_paths, transform=val_transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

In [16]:
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
NUM_EPOCHS = 12
NUM_WORKERS = 8
IMAGE_HEIGHT = 112
IMAGE_WIDTH = 112
PIN_MEMORY = True
LOAD_MODEL = False

In [17]:
train_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        # A.Rotate(limit=35, p=1.0),
        # A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

In [18]:
train_loader, val_loader = get_loaders_masks(
    FILE_PATHS.IMAGES,
    FILE_PATHS.MASKS,
    BATCH_SIZE,
    train_transforms,
    val_transforms,
    NUM_WORKERS,
    test_size=0.2,
)

TRAIN PATHS
['../EchoNet-Dynamic/data/images\\0X1A2E9496910EFF5B_39.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2A76BDB5B98BED_78.jpeg', '../EchoNet-Dynamic/data/images\\0X1A3D565B371DC573_176.jpeg', '../EchoNet-Dynamic/data/images\\0X1A0A263B22CCD966_83.jpeg', '../EchoNet-Dynamic/data/images\\0X1A8D85542DBE8204_117.jpeg', '../EchoNet-Dynamic/data/images\\0X1A6ACFE7B286DAFC_124.jpeg', '../EchoNet-Dynamic/data/images\\0X1A3E7BF1DFB132FB_73.jpeg', '../EchoNet-Dynamic/data/images\\0X1A0A263B22CCD966_72.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2E9496910EFF5B_55.jpeg', '../EchoNet-Dynamic/data/images\\0X1A8D85542DBE8204_91.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2A76BDB5B98BED_63.jpeg', '../EchoNet-Dynamic/data/images\\0X1A3E7BF1DFB132FB_55.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2C60147AF9FDAE_62.jpeg', '../EchoNet-Dynamic/data/images\\0X1A3D565B371DC573_152.jpeg', '../EchoNet-Dynamic/data/images\\0X1A5FAE3F9D37794E_35.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2C60147AF9FDAE_4

In [19]:
def get_loaders_heatmaps(
    images_dir,
    heatmaps_dir,
    masks_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    test_size=0.2,
    pin_memory=True,
):
    image_paths = [os.path.join(images_dir, img) for img in os.listdir(images_dir)]
    mask_paths = [os.path.join(masks_dir, mask) for mask in os.listdir(masks_dir)]
    heatmap_paths = [
        os.path.join(heatmaps_dir, mask) for mask in os.listdir(heatmaps_dir)
    ]

    (
        train_image_paths,
        val_image_paths,
        train_heatmaps_paths,
        val_heatmaps_paths,
        train_mask_paths,
        val_mask_paths,
    ) = train_test_split(
        image_paths, heatmap_paths, mask_paths, test_size=test_size, random_state=42
    )

    print("TRAIN HEATMAP")
    print(train_image_paths)
    print(train_mask_paths)
    print(train_heatmaps_paths)

    print("VALIDATION HEATMAP")
    print(val_image_paths)
    print(val_mask_paths)
    print(val_heatmaps_paths)

    train_dataset = EchoDatasetHeatmaps(
        train_image_paths,
        train_mask_paths,
        train_heatmaps_paths,
        transform=train_transform,
    )
    val_dataset = EchoDatasetHeatmaps(
        val_image_paths, val_mask_paths, val_heatmaps_paths, transform=val_transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

In [20]:
train_loader, val_loader = get_loaders_heatmaps(
    FILE_PATHS.IMAGES,
    f"{FILE_PATHS.HEATMAPS}/gaussian",
    FILE_PATHS.MASKS,
    BATCH_SIZE,
    train_transforms,
    val_transforms,
    NUM_WORKERS,
    test_size=0.2,
)

TRAIN HEATMAP
['../EchoNet-Dynamic/data/images\\0X1A2E9496910EFF5B_39.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2A76BDB5B98BED_78.jpeg', '../EchoNet-Dynamic/data/images\\0X1A3D565B371DC573_176.jpeg', '../EchoNet-Dynamic/data/images\\0X1A0A263B22CCD966_83.jpeg', '../EchoNet-Dynamic/data/images\\0X1A8D85542DBE8204_117.jpeg', '../EchoNet-Dynamic/data/images\\0X1A6ACFE7B286DAFC_124.jpeg', '../EchoNet-Dynamic/data/images\\0X1A3E7BF1DFB132FB_73.jpeg', '../EchoNet-Dynamic/data/images\\0X1A0A263B22CCD966_72.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2E9496910EFF5B_55.jpeg', '../EchoNet-Dynamic/data/images\\0X1A8D85542DBE8204_91.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2A76BDB5B98BED_63.jpeg', '../EchoNet-Dynamic/data/images\\0X1A3E7BF1DFB132FB_55.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2C60147AF9FDAE_62.jpeg', '../EchoNet-Dynamic/data/images\\0X1A3D565B371DC573_152.jpeg', '../EchoNet-Dynamic/data/images\\0X1A5FAE3F9D37794E_35.jpeg', '../EchoNet-Dynamic/data/images\\0X1A2C60147AF9FDAE