##### Imports

In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from torchvision.transforms import functional as F

##### Mushroom dataset (image names and labels are read from a CSV file)

In [None]:
class BaseMushroomDataset(Dataset):
    """Image dataset driven by a CSV file of image names and optional labels.

    Column 0 of the CSV holds the image file name without its extension and
    column 1 (only read when ``has_labels`` is true) holds the label. Images
    are loaded from ``<root_dir>/<name>.jpg`` and converted to RGB.
    """

    def __init__(self, csv_file, root_dir, transform=None, has_labels=True):
        """Read the CSV and remember where and how to load the images.

        Args:
            csv_file: Path of the CSV file listing the images.
            root_dir: Directory that contains the ``.jpg`` files.
            transform: Optional callable applied to each loaded image.
            has_labels: Whether column 1 of the CSV contains labels.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.has_labels = has_labels
        # Column 0 is read as text so it can be concatenated with '.jpg'
        # (presumably this also keeps numeric-looking names intact).
        self.annotations = pd.read_csv(csv_file, dtype={0: str})

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The label is ``-1`` when the dataset was created with
        ``has_labels=False``.
        """
        file_name = self.annotations.iloc[idx, 0] + '.jpg'
        picture = Image.open(os.path.join(self.root_dir, file_name)).convert("RGB")

        target = int(self.annotations.iloc[idx, 1]) if self.has_labels else -1

        if not self.transform:
            return picture, target
        return self.transform(picture), target

##### Adding white padding around images

In [None]:
def pad_to_square(image):
    """Pad ``image`` with value 255 (white) so that width equals height.

    ``image.size`` is read as ``(width, height)``. The shorter dimension is
    padded on both sides so the original content stays centred.
    """
    width, height = image.size
    side = max(width, height)

    # Split the missing pixels evenly; an odd leftover pixel goes to the
    # right / bottom edge.
    left, spare_w = divmod(side - width, 2)
    top, spare_h = divmod(side - height, 2)
    borders = (left, top, left + spare_w, top + spare_h)

    return F.pad(image, borders, fill=255, padding_mode='constant')


##### Transform images to uniform size 224x224 (for AlexNet)

In [None]:
# Preprocessing pipeline applied to every image, in this order.
_preprocessing_steps = [
    # Pad to a square first so the resize below does not distort the image.
    transforms.Lambda(pad_to_square),
    transforms.Resize(256),
    # Keep the central 224x224 region.
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Per-channel normalisation (these look like the usual ImageNet
    # statistics - confirm against the pretrained model in use).
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)

##### Paths

In [None]:
# Everything lives under <parent of the working directory>/dataset.
root_path = os.path.dirname(os.getcwd())
dataset_path = os.path.join(root_path, 'dataset')
dataset_raw_path = os.path.join(dataset_path, 'raw')
dataset_preprocessed_path = os.path.join(dataset_path, 'preprocessed')
csv_path = os.path.join(dataset_path, 'csv_mappings')

# Output folders for the preprocessed tensors.
preprocessed_train_path = os.path.join(dataset_preprocessed_path, 'train')
preprocessed_test_path = os.path.join(dataset_preprocessed_path, 'test')

# CSV files that list the image names (and labels for the training set).
train_csv_path = os.path.join(csv_path, 'train.csv')
test_csv_path = os.path.join(csv_path, 'test.csv')

# Create the folders up front; exist_ok makes re-running the cell harmless.
for _folder in (preprocessed_train_path, preprocessed_test_path, csv_path):
    os.makedirs(_folder, exist_ok=True)

##### Dataloaders

In [None]:
# Both datasets read their images from the raw folder and share the same
# preprocessing; only the training CSV carries labels.
train_dataset = BaseMushroomDataset(
    csv_file=train_csv_path,
    root_dir=dataset_raw_path,
    transform=transform,
    has_labels=True,
)
test_dataset = BaseMushroomDataset(
    csv_file=test_csv_path,
    root_dir=dataset_raw_path,
    transform=transform,
    has_labels=False,
)

# Training batches are shuffled; test batches keep the CSV order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
)

In [None]:
# Quick sanity check: load the first training sample and print the shape of
# the transformed image.
single_image, single_label = train_dataset[0]  
print(f"Data Shape: {single_image.shape}")


##### Sample images

In [None]:
def denormalize(image, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    Args:
        image: Tensor whose first dimension indexes the channels.
        mean: Per-channel means, one entry per channel.
        std: Per-channel standard deviations, one entry per channel.

    Returns:
        A new tensor; the input tensor is left untouched.
    """
    restored = image.clone()
    channel_stats = zip(mean, std)
    # Iterating a tensor yields views of its channels, so the in-place ops
    # below write straight into ``restored``.
    for channel, (shift, scale) in zip(restored, channel_stats):
        channel.mul_(scale)
        channel.add_(shift)
    return restored

In [None]:
def show_samples(dataset, num_samples=10, images_per_row=5):
    """Plot the first ``num_samples`` items of ``dataset`` in a grid.

    Every item is read as an ``(image, label)`` pair. The image is
    denormalised with the mean/std hard-coded below, moved to channel-last
    layout, clipped to ``[0, 1]`` and drawn with its label as the title.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        num_samples: Maximum number of items to draw; values larger than
            ``len(dataset)`` are reduced to the dataset size.
        images_per_row: Number of columns in the grid.
    """
    # Never index past the end of the dataset (same clamping as
    # show_preprocessed_samples).
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        # Nothing to draw; a grid with zero rows cannot be created.
        return

    num_rows = (num_samples + images_per_row - 1) // images_per_row  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so flatten() also
    # works when the grid is a single cell or a single row/column.
    _, axes = plt.subplots(
        num_rows,
        images_per_row,
        figsize=(20, num_rows * 4),
        squeeze=False,
    )
    axes = axes.flatten()

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for i in range(num_samples):
        image, label = dataset[i]
        image = denormalize(image, mean, std)
        image = image.permute(1, 2, 0).numpy()  # channel-first -> channel-last for imshow
        image = np.clip(image, 0, 1)  # guard against rounding outside the displayable range

        axes[i].imshow(image)
        axes[i].set_title(f'Label: {label}')
        axes[i].axis('off')

    # Hide the unused cells of the last row.
    for ax in axes[num_samples:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first training images together with their labels.
show_samples(train_dataset)


In [None]:
# Preview the first test images; the titles show -1 because the test dataset
# was created with has_labels=False.
show_samples(test_dataset)


##### Save preprocessed

In [None]:
def save_preprocessed_images_as_tensors(dataset, save_dir):
    """Save every transformed image of ``dataset`` as ``<name>.pt``.

    The file name is taken from column 0 of ``dataset.annotations`` for the
    same row index; labels are not written.

    Args:
        dataset: Object supporting ``len()`` and integer indexing that
            returns ``(image, label)`` pairs and exposes an ``annotations``
            DataFrame whose first column holds the image names.
        save_dir: Target directory; created if it does not exist yet.
    """
    # Make the function usable on its own instead of relying on the caller
    # having created the folder beforehand.
    os.makedirs(save_dir, exist_ok=True)

    for idx in range(len(dataset)):
        image, _ = dataset[idx]  # the label is not needed here
        stem = dataset.annotations.iloc[idx, 0]
        torch.save(image, os.path.join(save_dir, stem + '.pt'))

In [None]:
# Write the transformed train and test images to their preprocessed folders,
# one .pt file per image.
save_preprocessed_images_as_tensors(train_dataset, preprocessed_train_path)
save_preprocessed_images_as_tensors(test_dataset, preprocessed_test_path)

##### Reload the preprocessed tensors from disk and display them as a check

In [None]:
def show_preprocessed_samples(preprocessed_path, num_samples=25, images_per_row=5):
    """Plot up to ``num_samples`` saved ``.pt`` tensors from a folder.

    Each file is loaded with ``torch.load``, denormalised with the mean/std
    hard-coded below, clamped to ``[0, 1]`` and drawn with its file name as
    the title.

    Args:
        preprocessed_path: Folder containing the ``.pt`` files.
        num_samples: Maximum number of files to draw; reduced to the number
            of ``.pt`` files found.
        images_per_row: Number of columns in the grid.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the
    # preview reproducible between runs.
    image_files = sorted(
        f for f in os.listdir(preprocessed_path) if f.endswith('.pt')
    )
    num_samples = min(num_samples, len(image_files))
    if num_samples <= 0:
        # Empty folder (or nothing requested): a zero-row grid cannot be created.
        return

    num_rows = (num_samples + images_per_row - 1) // images_per_row  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so flatten() also
    # works when the grid is a single cell or a single row/column.
    _, axes = plt.subplots(
        num_rows,
        images_per_row,
        figsize=(20, num_rows * 4),
        squeeze=False,
    )
    axes = axes.flatten()

    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    to_pil = transforms.ToPILImage()

    for ax, file_name in zip(axes, image_files[:num_samples]):
        image_tensor = torch.load(os.path.join(preprocessed_path, file_name))
        # Clamp before the conversion to 8-bit so that values pushed slightly
        # outside [0, 1] by rounding cannot wrap around; show_samples clips
        # for the same reason.
        image_tensor = denormalize(image_tensor, mean, std).clamp(0, 1)
        ax.imshow(to_pil(image_tensor))
        ax.set_title(f'Image: {file_name}')
        ax.axis('off')

    # Hide the unused cells of the last row.
    for ax in axes[num_samples:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Visual check that the saved tensors can be loaded back and still look right.
show_preprocessed_samples(preprocessed_train_path)
show_preprocessed_samples(preprocessed_test_path)