In [4]:
import os
from torch.utils.data import DataLoader, Dataset ,random_split
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [5]:
class FishSegmentationDataset(Dataset):
    def __init__(self, base_path, transform=None, mask_transform=None):
        self.base_path = base_path
        self.transform = transform
        self.mask_transform = mask_transform
        self.fish_categories = [d for d in os.listdir(base_path) 
                               if os.path.isdir(os.path.join(base_path, d))]
        
        
        self.samples = []
        
        for category in self.fish_categories:
            img_path = os.path.join(base_path, category, category)
            mask_path = os.path.join(base_path, category, f"{category} GT")
            
            if not os.path.exists(img_path) or not os.path.exists(mask_path):
                continue
                
            image_files = [f for f in os.listdir(img_path) if f.endswith('.png')]
            image_files.sort(key=lambda x: int(x.split('.')[0]))
            
            for img_file in image_files:
                img_full_path = os.path.join(img_path, img_file)
                mask_full_path = os.path.join(mask_path, img_file)
                
                if os.path.exists(mask_full_path):
                    self.samples.append({
                        'image_path': img_full_path,
                        'mask_path': mask_full_path
                    })
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # Load image
        image = Image.open(sample['image_path']).convert("L")
        if self.transform:
            image = self.transform(image)
        
        # Load mask
        mask = Image.open(sample['mask_path']).convert("1")
        if self.mask_transform:
            mask = self.mask_transform(mask)
        
        # Return only image and mask for segmentation
        return image, mask

class TestSegmentationDataset(Dataset):
    def __init__(self, base_path, transform=None):
        self.base_path = base_path
        self.transform = transform
        self.fish_categories = [d for d in os.listdir(base_path) 
                               if os.path.isdir(os.path.join(base_path, d))]
        
       
        self.samples = []
        
        for category in self.fish_categories:
            category_path = os.path.join(base_path, category)
            
            if not os.path.exists(category_path):
                continue
                
            image_files = [f for f in os.listdir(category_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            
            for img_file in image_files:
                img_full_path = os.path.join(category_path, img_file)
                self.samples.append({
                    'image_path': img_full_path
                })
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # Load image
        image = Image.open(sample['image_path']).convert("L")
        if self.transform:
            image = self.transform(image)
        
        return image

def TrainSegmentationDataloader(train_batch_size=64, val_batch_size=64, val_split=0.2):
    base_path = "../src/data/Fish_Dataset/A_Fish_Dataset"
    
    # transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    
    mask_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (x > 0.5).float())
    ])
    
    # full dataset
    full_dataset = FishSegmentationDataset(base_path, transform, mask_transform)
    
    # Split dataset
    total_size = len(full_dataset)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size
    
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    # training dataloaders and validation dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=4, pin_memory=True)
    
    return train_dataloader, val_dataloader

def TestSegmentationDataloader(batch_size=32):
    test_base_path = "../src/data/Fish_Dataset/NA_Fish_Dataset"
    
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    
    test_dataset = TestSegmentationDataset(test_base_path, transform)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    
    return test_dataloader

In [None]:
train_dl, val_dl = TrainSegmentationDataloader()
train_batch = next(iter(train_dl))
print(f"Train batch: {len(train_batch)} items")
if len(train_batch) == 2:
    images, masks = train_batch
    print(f"  Images: {images.shape}")
    print(f"  Masks: {masks.shape}")

val_batch = next(iter(val_dl))
print(f"Val batch: {len(val_batch)} items")
if len(val_batch) == 2:
    val_images, val_masks = val_batch
    print(f"  Images: {val_images.shape}")
    print(f"  Masks: {val_masks.shape}")

test_dl = TestSegmentationDataloader()
test_batch = next(iter(test_dl))
print(f"Test batch: {len(test_batch)} items")

if isinstance(test_batch, tuple) and len(test_batch) == 1:
    test_images = test_batch[0]
elif isinstance(test_batch, torch.Tensor):
    test_images = test_batch
else:
    test_images = test_batch  # fallback

print(f"  Test Images: {test_images.shape}")

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

axes[0, 0].imshow(images[0].squeeze().numpy(), cmap='gray')
axes[0, 0].set_title('Train Image')
axes[0, 0].axis('off')

axes[0, 1].imshow(masks[0].squeeze().numpy(), cmap='gray')
axes[0, 1].set_title('Train Mask')
axes[0, 1].axis('off')

axes[0, 2].imshow(val_images[0].squeeze().numpy(), cmap='gray')
axes[0, 2].set_title('Val Image')
axes[0, 2].axis('off')

axes[1, 0].imshow(val_masks[0].squeeze().numpy(), cmap='gray')
axes[1, 0].set_title('Val Mask')
axes[1, 0].axis('off')

axes[1, 1].imshow(test_images[0].squeeze().numpy(), cmap='gray')
axes[1, 1].set_title('Test Image')
axes[1, 1].axis('off')

axes[1, 2].axis('off')
axes[1, 2].text(0.5, 0.5, 'No mask for\ntest data', ha='center', va='center', fontsize=12)

plt.tight_layout()
plt.show()

Train batch: 2 items
  Images: torch.Size([64, 1, 224, 224])
  Masks: torch.Size([64, 1, 224, 224])
Val batch: 2 items
  Images: torch.Size([64, 1, 224, 224])
  Masks: torch.Size([64, 1, 224, 224])
Test batch: 32 items


NameError: name 'torch' is not defined

: 