# In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, SubsetRandomSampler

def custom_collate(batch):
    """Collate a list of ``(image, label)`` samples into batched tensors.

    Images that are already ``torch.Tensor`` objects (for example because the
    dataset's transform pipeline ends in ``ToTensor``) are used unchanged.
    Any other image is converted with ``transforms.ToTensor()``.

    Args:
        batch: Sequence of ``(image, label)`` pairs. All images must have the
            same shape once converted, otherwise ``torch.stack`` fails.

    Returns:
        A tuple ``(inputs, labels)``: the images stacked along a new leading
        batch dimension, and the labels as a 1-D tensor.
    """
    inputs = []
    labels = []
    for img, label in batch:
        # ToTensor rejects inputs that are already tensors, so only convert
        # samples that still need it.
        if not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)
        inputs.append(img)
        labels.append(label)
    return torch.stack(inputs), torch.tensor(labels)

def DeFungiDataset(data_dir, use_grayscale=True, batch_size=32, test_size=0.2,
                   num_workers=16, seed=None):
    """Load an image-folder dataset and return train/test ``DataLoader``s.

    The directory is read with ``datasets.ImageFolder``, so each
    sub-directory of ``data_dir`` is treated as one class label. The sample
    indices are shuffled and then split into a test portion and a training
    portion; each portion is served through a ``SubsetRandomSampler`` over
    the same underlying dataset.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        use_grayscale: If true, images are converted to single-channel
            grayscale before being turned into tensors.
        batch_size: Number of samples per batch for both loaders.
        test_size: Fraction of the dataset (between 0 and 1) held out for
            testing.
        num_workers: Number of worker processes used by each loader.
        seed: Optional seed for the shuffle that decides the train/test
            split. ``None`` gives a different split on every call.

    Returns:
        A tuple ``(trainloader, testloader)``.

    Raises:
        ValueError: If ``test_size`` is outside ``[0, 1]``.
    """
    import random

    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size!r}")

    # Build the pipeline from only the steps that are needed. This avoids a
    # no-op lambda transform, which cannot be pickled when worker processes
    # are started with the "spawn" method.
    steps = []
    if use_grayscale:
        steps.append(transforms.Grayscale())
    steps.append(transforms.ToTensor())
    # One mean/std entry per channel: 1 for grayscale, 3 for the RGB images
    # ImageFolder loads by default. mean = 0.5, std = 0.5.
    channels = 1 if use_grayscale else 3
    steps.append(transforms.Normalize((0.5,) * channels, (0.5,) * channels))
    transform = transforms.Compose(steps)

    # Extract images. Assumes sub-directories indicate class label.
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)

    # ImageFolder lists samples grouped by class, so shuffle before slicing;
    # otherwise the test split would contain only the first class(es).
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)
    split = int(test_size * len(indices))
    test_indices = indices[:split]
    train_indices = indices[split:]

    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)

    # Samples are already normalized tensors after `transform`, so the
    # default collate function is sufficient to stack them into batches.
    trainloader = DataLoader(
        dataset,
        sampler=train_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    testloader = DataLoader(
        dataset,
        sampler=test_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    return trainloader, testloader