In [39]:
import json
import os
import random
from collections import defaultdict

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Preprocessing applied to every training image:
#   1. resize to a fixed 128x128 resolution,
#   2. convert the PIL image into a PyTorch tensor.
train_transform = transforms.Compose(
    [transforms.Resize((128, 128)), transforms.ToTensor()]
)

# Base data directory, given relative to the notebook's working directory.
PATH = "../data/"

# Short dataset key -> folder name of that dataset
# (used as the sub-folder of "class_splits" when locating splits.json).
DATASETS = dict(
    zooplankton="SYKE-plankton_ZooScan_2024",
    phytoplankton="SYKE-plankton_IFCB_2022",
)

# Short dataset key -> folder name; currently identical to DATASETS.
# NOTE(review): presumably the image folders; not referenced by the code
# visible in this notebook -- confirm before relying on it.
IMAGE_PATHS = dict(
    zooplankton="SYKE-plankton_ZooScan_2024",
    phytoplankton="SYKE-plankton_IFCB_2022",
)




In [48]:

# Generic JSON loader (used for splits.json and the per-split files)
def load_json(filepath):
    """
    Parse a JSON file and return the decoded Python object.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Read as UTF-8 explicitly: without it, open() falls back to the
    # platform's locale encoding (e.g. cp1252 on Windows), which would
    # mis-decode non-ASCII characters in the file.
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

# Entry point for dataset loading (work in progress: only reads splits.json so far)
def get_datasets(dataset_name, transform, split, data_dir, N=None):
    """
    Locate and read the ``splits.json`` of a dataset.

    The file is looked up at
    ``<data_dir>/class_splits/<dataset folder>/splits.json`` and its path is
    printed before it is read.

    NOTE(review): despite its name, this function does not build or return
    any train/validation/test dataset yet. ``transform``, ``split`` and ``N``
    are accepted but currently unused, and the parsed splits are discarded.

    Args:
        dataset_name: Key into ``DATASETS`` (e.g. ``'zooplankton'``).
        transform: Unused for now.
        split: Unused for now.
        data_dir: Base data directory. A falsy value (``None`` / ``''``)
            falls back to the module-level ``PATH``.
        N: Unused for now.

    Returns:
        None

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    dataset = DATASETS.get(dataset_name)
    if dataset is None:
        raise KeyError(f'{dataset_name} not in available datasets: 'f'{", ".join(DATASETS.keys())}')

    # Honour the caller-supplied directory; previously this argument was
    # ignored and the global PATH was always used.
    base_dir = data_dir if data_dir else PATH

    # Load the splits.json file containing partitions and categories
    class_split_path = os.path.join(base_dir, "class_splits", dataset, "splits.json")
    print(class_split_path)
    splits_data = load_json(class_split_path)



# Run get_datasets for the zooplankton dataset (split 0) with the training
# transform; N=1000 is the intended number of images per training class.
get_datasets(
    dataset_name="zooplankton",
    data_dir="../data",
    split=0,
    transform=train_transform,
    N=1000,
)



../data/class_splits\SYKE-plankton_ZooScan_2024\splits.json
{'categories': {'0': 'Bivalvia', '1': 'Bivalvia_multiple', '2': 'Bosmina_sp', '3': 'Bubbles', '4': 'Ceriodaphnia_sp', '5': 'Copepoda_calanoida', '6': 'Copepoda_cyclopoida', '7': 'Copepoda_nauplius', '8': 'Daphnia_sp', '9': 'Eggs', '10': 'Evadne_sp', '11': 'Fibers_etc', '12': 'Fish_eggs', '13': 'Gastropoda', '14': 'Harpacticoida', '15': 'Mysis_sp', '16': 'Podon_sp', '17': 'Polychaeta', '18': 'Sessilia', '19': 'Synchaeta_sp'}, 'images': {'0': {'train': ['Bivalvia_242.jpg', 'Bivalvia_293.jpg', 'Bivalvia_168.jpg', 'Bivalvia_019.jpg', 'Bivalvia_100.jpg', 'Bivalvia_024.jpg', 'Bivalvia_173.jpg', 'Bivalvia_150.jpg', 'Bivalvia_178.jpg', 'Bivalvia_288.jpg', 'Bivalvia_245.jpg', 'Bivalvia_107.jpg', 'Bivalvia_217.jpg', 'Bivalvia_239.jpg', 'Bivalvia_237.jpg', 'Bivalvia_082.jpg', 'Bivalvia_073.jpg', 'Bivalvia_001.jpg', 'Bivalvia_266.jpg', 'Bivalvia_190.jpg', 'Bivalvia_236.jpg', 'Bivalvia_238.jpg', 'Bivalvia_182.jpg', 'Bivalvia_231.jpg', 'Biv

In [None]:
# Define the function to get images for a specific class
def getImagesForClass(path, class_name):
    """
    List the files stored in the directory of one class.

    Args:
        path: Root directory that contains one sub-directory per class.
        class_name: Name of the class sub-directory.

    Returns:
        list[str]: ``<path>/<class_name>/<entry>`` for every entry of the
        class directory, sorted by entry name. Entries are not filtered, so
        any non-image file in the directory is returned as well.

    Raises:
        OSError: If the class directory does not exist or cannot be read.
    """
    class_path = os.path.join(path, class_name)
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and any later "first N" truncation) identical across runs and
    # machines.
    return [os.path.join(class_path, entry) for entry in sorted(os.listdir(class_path))]

# Custom Dataset with Oversampling Logic for Equal Number of Images Per Class
class PlanktonSet(Dataset):
    """
    Image-classification dataset over per-class image directories.

    Every class ``c`` in ``classes`` is expected to live in ``<path>/<c>/``.
    The label of a sample is the position of its class name in ``classes``.

    Args:
        path: Root directory containing one sub-directory per class.
        classes: Class (sub-directory) names; their order defines the labels.
        transform: Callable applied to each grayscale PIL image in
            ``__getitem__``.
        class_names: Optional display names; defaults to ``classes``.
        N: If given, every class contributes exactly ``N`` samples: larger
            classes are cut to their first ``N`` images, smaller ones are
            padded with randomly repeated image paths. No pixel-level
            augmentation happens here.
        class_files: Optional mapping ``class name -> iterable of file names``.
            When given, only these files (resolved as
            ``<path>/<class name>/<file name>``) are used; otherwise the whole
            class directory is listed via ``getImagesForClass``.
    """
    def __init__(self, path, classes, transform, class_names=None, N=None, class_files=None):
        self.transform = transform
        self.classes_uniq = classes
        self.class_images = defaultdict(list)

        # Collect the image paths of each class
        for class_name in self.classes_uniq:
            if class_files is None:
                # No explicit file list: use everything in the class directory.
                self.class_images[class_name] = getImagesForClass(path, class_name)
            else:
                # Restrict the class to the files requested by the caller
                # (e.g. the members of one train/valid/test partition).
                self.class_images[class_name] = [
                    os.path.join(path, class_name, file_name)
                    for file_name in class_files[class_name]
                ]

        # Balance classes to N images each when requested (meant for training)
        if N is not None:
            self.images, self.labels = self.oversample_classes(N)
        else:
            self.images, self.labels = self.get_all_images_and_labels()

        self.class_names = class_names if class_names is not None else classes
        self.num_classes = len(self.classes_uniq)

    def oversample_classes(self, N):
        """
        Build a balanced sample list with exactly ``N`` images per class.

        Classes with fewer than ``N`` images are padded with paths drawn (with
        replacement) by ``random.choices``; call ``random.seed`` beforehand
        for reproducible padding. Classes with more than ``N`` images are cut
        to their first ``N``. ``self.class_images`` is left untouched.

        Returns:
            tuple[list, list]: image paths and their integer labels, aligned.

        Raises:
            ValueError: If ``N`` is negative, or if a class has no images and
                therefore cannot be padded.
        """
        if N < 0:
            raise ValueError(f"N must be non-negative, got {N}")

        images = []
        labels = []

        for idx, class_name in enumerate(self.classes_uniq):
            # Work on a copy: padding in place would silently add the
            # duplicates to self.class_images as well.
            class_images = list(self.class_images[class_name])
            shortfall = N - len(class_images)

            if shortfall > 0:
                if not class_images:
                    raise ValueError(
                        f"Cannot oversample class '{class_name}': it has no images"
                    )
                class_images += random.choices(class_images, k=shortfall)

            selected = class_images[:N]
            images.extend(selected)
            # Derive the label count from the images actually taken so both
            # lists always stay aligned.
            labels.extend([idx] * len(selected))

        return images, labels

    def get_all_images_and_labels(self):
        """
        Return every image path of every class, with labels, unbalanced.

        Returns:
            tuple[list, list]: image paths and their integer labels, aligned.
        """
        images = []
        labels = []

        for idx, class_name in enumerate(self.classes_uniq):
            class_images = self.class_images[class_name]
            images.extend(class_images)
            labels.extend([idx] * len(class_images))

        return images, labels

    def __len__(self):
        """Number of samples (after optional balancing)."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            tuple: ``(transform(image), label)`` where ``image`` is the file
            converted to single-channel ("L") mode and ``label`` is a
            ``torch.long`` scalar tensor.
        """
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # Image.open is lazy and keeps the file open; the context manager
        # closes it once convert() has produced an independent in-memory copy.
        with Image.open(self.images[idx]) as source_image:
            image = source_image.convert("L")
        image = self.transform(image)

        return image, label


# Function to create train/val/test datasets with optional oversampling for training
def create_dataset_splits(path, splits_data, transform, dataset_type='train', N=None):
    """
    Build the dataset of one partition described by ``splits_data``.

    Only the files listed for ``dataset_type`` are included; each one is
    resolved as ``<path>/<class name>/<file name>``.

    Args:
        path: Root directory containing one sub-directory per class.
        splits_data: Parsed ``splits.json`` with
            ``'categories'`` (class id -> class name) and
            ``'images'`` (class id -> {'train' / 'valid' / 'test': [file names]}).
        transform: Transform forwarded to ``PlanktonSet``.
        dataset_type: ``'train'``, ``'valid'`` or ``'test'``.
        N: Images per class after balancing; only applied when
            ``dataset_type == 'train'``.

    Returns:
        PlanktonSet: Dataset restricted to the requested partition.

    Raises:
        ValueError: If ``dataset_type`` is not one of the three partitions.
        KeyError: If a class has no entry for ``dataset_type`` or no name in
            ``splits_data['categories']``.
    """
    if dataset_type not in ('train', 'valid', 'test'):
        raise ValueError(
            f"dataset_type must be 'train', 'valid' or 'test', got {dataset_type!r}"
        )

    classes = []
    class_files = {}

    for class_id, image_data in splits_data['images'].items():
        class_name = splits_data['categories'][class_id]
        classes.append(class_name)
        class_files[class_name] = list(image_data[dataset_type])

    # Passing class_files is what keeps the partitions disjoint: without it
    # the dataset would contain every file of each class directory.
    return PlanktonSet(
        path=path,
        classes=classes,
        transform=transform,
        # For training, pass N to the dataset for oversampling
        N=N if dataset_type == 'train' else None,
        class_files=class_files,
    )

def get_dataset_files(dataset_name):
    """
    Load the split metadata files of a dataset.

    Files are read from ``<PATH>/class_splits/<dataset folder>/``.

    Args:
        dataset_name: Key into ``DATASETS``.

    Returns:
        tuple: ``(splits_data, class_splits)`` where ``splits_data`` is the
        parsed ``splits.json`` and ``class_splits`` is a list with the parsed
        content of every other ``*.json`` file in that directory. Files with
        purely numeric names (``0.json``, ``1.json``, ...) come first in
        numeric order, any others follow in alphabetical order.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    dataset = DATASETS.get(dataset_name)
    if dataset is None:
        raise KeyError(f'{dataset_name} not in available datasets: 'f'{", ".join(DATASETS.keys())}')

    # NOTE(review): this function used to reference an undefined `path`.
    # The class_splits folder of the dataset is assumed here because that is
    # where this notebook loads splits.json from -- confirm.
    split_dir = os.path.join(PATH, "class_splits", dataset)

    def split_order(file_name):
        # Numeric stems sort numerically ("2.json" before "10.json").
        stem = os.path.splitext(file_name)[0]
        return (0, int(stem), stem) if stem.isdecimal() else (1, 0, stem)

    # Load the splits.json file containing partitions and categories
    splits_data = load_json(os.path.join(split_dir, "splits.json"))

    # Load the class splits, 0.json, 1.json etc.
    # os.listdir order is arbitrary, so sort for a deterministic result.
    class_splits = [
        load_json(os.path.join(split_dir, split_file))
        for split_file in sorted(os.listdir(split_dir), key=split_order)
        if split_file.endswith('.json') and split_file != 'splits.json'
    ]

    return splits_data, class_splits


