# In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler

def custom_collate(batch):
    """Collate ``(image, label)`` samples into flattened tensors.

    Every image is converted with ``transforms.ToTensor``, normalized with
    mean 0.5 and standard deviation 0.5, and flattened to one dimension
    before the batch is stacked.

    Args:
        batch: Iterable of ``(img, label)`` pairs. ``img`` must be something
            ``transforms.ToTensor`` accepts, and all images must flatten to
            the same length because they are combined with ``torch.stack``.

    Returns:
        A tuple ``(inputs, labels)``: ``inputs`` is a float tensor with one
        flattened image per row, ``labels`` is an int64 tensor with one
        entry per sample.
    """
    # Build the transforms once per batch instead of once per image.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5,), (0.5,))

    inputs = [normalize(to_tensor(img)).reshape(-1) for img, _ in batch]
    labels = [label for _, label in batch]

    # Inputs must stay floating point: casting the normalized pixel values
    # to int64 would truncate their fractional part and wipe out the image.
    return torch.stack(inputs).float(), torch.tensor(labels).long()

def DeFungiDataset(data_dir, use_rgb=True, batch_size=32, test_size=0.2,
                   num_workers=16, seed=None):
    """Build training and testing DataLoaders from an image folder.

    Images are read with ``datasets.ImageFolder``, so each sub-directory of
    ``data_dir`` is treated as one class. The sample indices are shuffled and
    then split into a test part (``test_size`` of the data) and a training
    part (the rest). Shapes and dtypes of the first training batch are
    printed as a sanity check.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        use_rgb: If False, images are passed through
            ``transforms.Grayscale()``; if True, no transform is applied.
        batch_size: Number of samples per batch for both loaders.
        test_size: Fraction of the dataset reserved for testing; must lie
            in ``[0, 1]``.
        num_workers: Number of DataLoader worker processes for each loader.
        seed: Optional seed for the train/test shuffle. ``None`` gives a
            different split on every call.

    Returns:
        A tuple ``(trainloader, testloader)`` of ``DataLoader`` objects.

    Raises:
        ValueError: If ``test_size`` is outside ``[0, 1]``.
    """
    # Local import keeps this block self-contained; only the split needs it.
    import random

    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be in [0, 1], got {test_size!r}")

    # No transform at all is the same as an identity transform, and unlike a
    # lambda it can be pickled when DataLoader workers are spawned.
    transform = None if use_rgb else transforms.Grayscale()

    # Extract images. Assumes sub-directories indicate class label
    dataset = datasets.ImageFolder(
        root=data_dir,
        transform=transform,
    )

    # Split dataset into training and testing. ImageFolder lists samples
    # class by class, so the indices are shuffled first; otherwise the test
    # split would contain only the leading class(es).
    num_data = len(dataset)
    indices = list(range(num_data))
    random.Random(seed).shuffle(indices)
    split = int(test_size * num_data)
    train_indices, test_indices = indices[split:], indices[:split]

    # Samplers restricting each loader to its own subset of indices
    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)

    # DataLoader for training and testing
    trainloader = DataLoader(
        dataset,
        sampler=train_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=custom_collate,
    )
    testloader = DataLoader(
        dataset,
        sampler=test_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=custom_collate,
    )

    # Sanity check: report dtype and shape of the first training batch only.
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        print(f"Batch {batch_idx + 1}:")
        print("Inputs Type:", inputs.type())
        print("Labels Type:", labels.type())
        print("Inputs Shape:", inputs.shape)
        print("Labels Shape:", labels.shape)
        break  # Stop after printing the first batch to avoid long output

    return trainloader, testloader