In [27]:
import numpy as np 
import h5py
import os 

import torch 
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

In [20]:
import os
import h5py
from torch.utils.data import Dataset

class LAHeart(Dataset):
    """Dataset of HDF5 volumes enumerated by ``<base_dir>/<split>.list``.

    Every non-blank line of the list file names one case. The data for a case is read
    from ``<base_dir>/2018LA_Seg_Training Set/<case>/mri_norm2.h5``, which must contain
    the HDF5 datasets ``image`` and ``label``.

    Args:
        base_dir: Root directory holding the ``*.list`` files and the case folders.
        split: List-file name without the ``.list`` suffix (e.g. ``'train'``).
        transform: Optional callable applied to the ``{'image', 'label'}`` sample dict.
        num: If given, keep only the first ``num`` cases from the list.

    Raises:
        ValueError: If the list file for ``split`` does not exist.
    """

    # Location of the per-case files, relative to ``base_dir``.
    _CASE_DIR = '2018LA_Seg_Training Set'
    _CASE_FILE = 'mri_norm2.h5'

    def __init__(self, base_dir, split='train', transform=None, num=None):
        self._base_dir = base_dir
        self.split = split
        self.transform = transform

        # Path for train/test list
        list_file = os.path.join(self._base_dir, f"{split}.list")
        if not os.path.isfile(list_file):
            raise ValueError(f"The {split} list file is missing: {list_file}")

        with open(list_file, 'r') as file:
            # Drop blank lines (e.g. a trailing newline): an empty case name would
            # only surface later as a confusing "file not found" in __getitem__.
            stripped = (line.strip() for line in file)
            self.sample_list = [case for case in stripped if case]

        if num is not None:
            self.sample_list = self.sample_list[:num]

        print(f"Mode = {self.split}, total samples: {len(self.sample_list)}")

    def __len__(self):
        """Return the number of cases in this split."""
        return len(self.sample_list)

    def __getitem__(self, index):
        """Load one case and return ``{'image': ..., 'label': ...}``.

        Both entries are read fully into memory from the case's HDF5 file; if a
        transform was supplied, its return value is returned instead.

        Raises:
            FileNotFoundError: If the HDF5 file of the case cannot be found.
        """
        case = self.sample_list[index]
        file_path = os.path.join(self._base_dir, self._CASE_DIR, case, self._CASE_FILE)

        try:
            with h5py.File(file_path, 'r') as h5f:
                image = h5f['image'][:]
                label = h5f['label'][:]
        except FileNotFoundError as err:
            # Re-raise with the full path in the message, keeping the original cause.
            raise FileNotFoundError(f"File not found: {file_path}") from err

        sample = {'image': image, 'label': label}
        if self.transform is not None:
            sample = self.transform(sample)

        return sample


In [29]:
def random_rot_flip(image, label):
    """Apply one shared random rotation and flip to an image/label pair.

    The pair is rotated by ``k`` quarter turns (``k`` drawn from 0..3) in the plane of
    axes 0 and 1 (the ``np.rot90`` default) and then always flipped along a randomly
    chosen axis (0 or 1). Both arrays therefore need at least two dimensions.

    Args:
        image: Array to augment.
        label: Array augmented with exactly the same rotation and flip.

    Returns:
        Tuple ``(image, label)`` of new, contiguous arrays; the inputs are not modified.
    """
    # Scalar draws (not a shape-(1,) array) so np.rot90 receives a plain integer.
    k = np.random.randint(0, 4)
    axis = np.random.randint(0, 2)

    # np.rot90 / np.flip return views, and the flip yields a negative stride that
    # torch cannot wrap (torch.as_tensor / DataLoader collation raise ValueError).
    # Copying materialises a regular contiguous array.
    image = np.flip(np.rot90(image, k), axis).copy()
    label = np.flip(np.rot90(label, k), axis).copy()

    return image, label

class RandomRotFlip:
    """Callable transform that runs ``random_rot_flip`` on a sample dict.

    Reads the ``'image'`` and ``'label'`` entries of ``sample`` and returns a new
    dict holding only those two keys with the transformed values.
    """

    def __call__(self, sample):
        new_image, new_label = random_rot_flip(sample['image'], sample['label'])
        return {'image': new_image, 'label': new_label}

In [30]:
# TODO: add a RandomCrop transform (not implemented yet).

In [None]:
# Augmentation pipeline applied to every {'image', 'label'} sample dict.
# (torchvision has no `transforms.RandomRot`; the notebook's own RandomRotFlip is used.)
data_transform = transforms.Compose([
    RandomRotFlip(),
    # Guard: torch cannot wrap negatively-strided NumPy views such as those produced
    # by np.flip / np.rot90. This is a no-op for arrays that are already contiguous.
    lambda sample: {key: np.ascontiguousarray(value) for key, value in sample.items()},
])
train_db = LAHeart(
    base_dir='LA',
    split='train',
    transform=data_transform,  # previously built but never handed to the dataset
)

trainloader = DataLoader(train_db, batch_size=1, shuffle=True)

# Sanity check: print the shapes of the first batch only.
for batch in trainloader:
    image, label = batch['image'], batch['label']
    print(f'Image.shape = {image.shape}')
    print(f'Label.shape = {label.shape}')
    break

Mode = train, total samples: 80
Image.shape = torch.Size([1, 183, 141, 88])
Label.shape = torch.Size([1, 183, 141, 88])
