# Unzip defungi.zip

In [1]:
%%capture
# Extract the DeFungi archive into ./data on the first run only.
# An existing `data/` directory is taken as proof that extraction already
# happened, so a partial/failed extraction is NOT retried on re-run.
# Because of %%capture above, the output of the shell commands (including any
# tar error messages) is swallowed, so a failed extraction is silent here.
# NOTE(review): `tar -xz` expects a gzip-compressed tarball. bsdtar (the macOS
# default) can read .zip archives, but GNU tar does not support the zip format.
# If this fails on Linux, consider `unzip -q defungi.zip -d data` or Python's
# `zipfile` module. Confirm the archive format on the target platform.
import os
if not os.path.isdir('data'): 
    !mkdir data
    !tar -xvzf defungi.zip -C data

# Necessary Imports

In [2]:
import numpy as np
import polars as pl
import cv2
import os
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, random_split, DataLoader

# Constants

In [3]:
# Map each class-folder / file-name prefix to a contiguous integer label.
# There is no 'H4' class, so labels are positional (0-4) and not derived
# from the number in the prefix.
CLASSES = {name: label for label, name in enumerate(('H1', 'H2', 'H3', 'H5', 'H6'))}

# Note: despite its name, SEED is a seeded torch.Generator (not an int).
# Generator.manual_seed() returns the generator itself, so this is equivalent
# to the one-line `torch.Generator().manual_seed(42)`.
SEED = torch.Generator()
SEED.manual_seed(42)


# Define Dataset class

In [4]:
class FungiDataset(Dataset):
    """In-memory dataset of DeFungi images paired with integer class labels.

    Every file is decoded once at construction time and cached as a tensor,
    so ``__getitem__`` is a list lookup plus an optional per-sample transform.
    """

    def __init__(self, files, labels, transform=None):
        """
        Arguments:
            files (sequence of str): Paths to the image files.
            labels (sequence of int): Class label for each entry of ``files``.
            transform (callable, optional): Applied to the cached image tensor
                every time a sample is fetched. Defaults to None (no transform).

        Raises:
            ValueError: If ``files`` and ``labels`` have different lengths.
        """
        if len(files) != len(labels):
            raise ValueError(
                f"files and labels must have the same length, "
                f"got {len(files)} and {len(labels)}"
            )
        self.images = []
        self.files = files
        self.labels = labels
        # Decoding transform used only by _preprocess: PIL image -> tensor (C, H, W).
        self.transform = transforms.Compose([transforms.PILToTensor()])
        # Optional user transform applied on each __getitem__ call.
        self.sample_transform = transform
        self._preprocess()

    def _preprocess(self):
        """Decode every file in ``self.files`` once and cache it in ``self.images``."""
        for file in self.files:
            # The context manager closes the underlying file handle. PILToTensor
            # materializes the pixel data, so the tensor remains valid afterwards.
            with Image.open(file) as image:
                self.images.append(self.transform(image))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{'file': str, 'image': tensor, 'label': tensor}`` for ``idx``.

        The image is already a tensor (converted in ``_preprocess``), so the
        decoding transform is not re-applied here. The file path is returned as
        a plain string: ``torch.tensor`` cannot hold strings, and the default
        DataLoader collate function batches strings into a list.
        """
        image = self.images[idx]
        if self.sample_transform is not None:
            image = self.sample_transform(image)
        return {
            'file': self.files[idx],
            'image': image,
            'label': torch.tensor(self.labels[idx]),
        }

# Collect image paths and their class labels

In [5]:
def load_images_from_folder(folder, classes=None):
    """Collect ``(path, label)`` pairs for every image in the class folders of ``folder``.

    Only files below a top-level subdirectory of ``folder`` whose name starts
    with ``'H'`` (e.g. ``H1``) are considered. Hidden files (names starting
    with ``'.'``, such as ``.DS_Store``) are ignored. The label is looked up
    from the part of the file name before the first ``'_'``.

    Arguments:
        folder (str): Root directory the archive was extracted into.
        classes (dict, optional): Maps a file-name prefix to an integer label.
            Defaults to the module-level ``CLASSES``.

    Returns:
        list of (str, int): Image paths and labels, sorted by path. The order
        (and therefore any seeded random split built on it) does not depend
        on the filesystem's ``os.walk`` ordering.

    Raises:
        KeyError: If a non-hidden file in a class folder has an unknown prefix.
    """
    if classes is None:
        classes = CLASSES
    images = []
    for root, _, files in os.walk(folder):
        # Path of this directory relative to `folder`, e.g. 'H1' or 'H1/sub'.
        # The root itself yields '.', which is skipped like any non-'H' dir.
        if not os.path.relpath(root, folder).startswith('H'):
            continue
        for file in files:
            if file.startswith('.'):
                continue
            prefix = file.split('_')[0]
            if prefix not in classes:
                raise KeyError(
                    f"unknown class prefix {prefix!r} for file {os.path.join(root, file)!r}"
                )
            images.append((os.path.join(root, file), classes[prefix]))
    return sorted(images)

images = load_images_from_folder('data/')

# Sanity check on the extraction: this notebook expects exactly 9114 images.
# A different count means the unzip step failed or picked up extra files.
assert len(images) == 9114, f"expected 9114 images, found {len(images)}"

# Split the (path, label) pairs into two parallel tuples.
files, labels = zip(*images)


# Instantiate the PyTorch dataset and split it 70/30 into train/test

In [12]:
dataset = FungiDataset(files=files, labels=labels)

# random_split advances the state of the generator it is given. Passing the
# shared SEED generator would make every re-run of this cell produce a
# different split. Seeding a fresh generator from SEED's seed keeps the split
# identical on every run (and identical to the first run with SEED itself).
train, test = random_split(
    dataset=dataset,
    lengths=[0.7, 0.3],
    generator=torch.Generator().manual_seed(SEED.initial_seed()),
)

# train_loader = DataLoader(train, batch_size=4)
# test_loader = DataLoader(test, batch_size=4)