<a href="https://colab.research.google.com/github/MusliHyseni/CustomFunctions/blob/main/CustomFunctions_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import zipfile
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets, transforms

# Create a custom Dataset class that does the job of torchvision's ImageFolder
"""
1. Subclass torch.utils.data.Dataset.

2. Initialize our subclass with a target_dir parameter (the target data directory)
and a transform parameter (so we have the option to transform our data if needed).

3. Create several attributes: paths (the paths of our target images),
transform (the transforms we might like to use; this can be None), and classes and class_to_index (from our find_classes() function).

4. Create a function to load images from file and return them, this could be using PIL or torchvision.io (for input/output of vision data).

5. Override the __len__ method of torch.utils.data.Dataset to return the number of samples in the Dataset.
This is recommended but not required; it lets you call len(dataset).

6. Override the __getitem__ method of torch.utils.data.Dataset to return a single sample from the Dataset; this is required.
"""
# Custom Dataset implementing the steps above (a minimal stand-in for ImageFolder)
class ImageFolderCustom(Dataset):
    """Image dataset laid out as ``target_dir/<class_name>/<image>.jpg``.

    A minimal re-implementation of ``torchvision.datasets.ImageFolder``:
    every ``.jpg`` file exactly one directory below ``target_dir`` is a
    sample, and the name of its parent directory is its class.

    Attributes:
        paths: Sorted list of image file paths.
        transform: Optional callable applied to each loaded image in
            ``__getitem__`` (``None`` means return the raw PIL image).
        classes: Class names, as returned by ``find_classes(target_dir)``.
        class_to_index: Mapping from class name to integer label, as returned
            by ``find_classes(target_dir)``.
    """

    def __init__(
        self,
        target_dir: str,
        transform: Optional[Callable] = None,
        *,
        tranform: Optional[Callable] = None,
    ) -> None:
        """Index the images under ``target_dir``.

        Args:
            target_dir: Root directory containing one sub-directory per class.
            transform: Optional callable applied to each image when it is
                fetched with ``dataset[index]``.
            tranform: Deprecated, misspelled alias of ``transform``. It is kept
                so existing keyword callers keep working and is ignored when
                ``transform`` is given.
        """
        # Path.glob yields entries in filesystem order, which is arbitrary.
        # Sorting keeps sample indices stable across runs and machines.
        self.paths = sorted(Path(target_dir).glob("*/*.jpg"))
        self.transform = transform if transform is not None else tranform
        # find_classes is not defined in this block. It is assumed to return
        # (class_names, {class_name: index}) — TODO confirm against its definition.
        self.classes, self.class_to_index = find_classes(target_dir)

    def load_image(self, index: int) -> "Image.Image":
        """Open and return the image stored at ``self.paths[index]``."""
        return Image.open(self.paths[index])

    def __len__(self) -> int:
        """Return the number of images found under ``target_dir``."""
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, class_index)`` for the sample at ``index``.

        The class is the name of the image's parent directory, mapped through
        ``self.class_to_index``. The image is passed through
        ``self.transform`` when one was given; otherwise the PIL image from
        ``load_image`` is returned unchanged.
        """
        image = self.load_image(index)
        class_index = self.class_to_index[self.paths[index].parent.name]
        if self.transform is not None:
            image = self.transform(image)
        return image, class_index


# Training pipeline: resize every image to 64x64, then apply a random
# horizontal flip (p=0.5) as light augmentation, then convert to a tensor.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: same resize and tensor conversion, but no random
# augmentation, so test images are never flipped.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
    ]
)

# Wrap each split in the custom dataset. train_dir / test_dir are expected
# to be defined earlier in the notebook (they are not defined here).
custom_train_data = ImageFolderCustom(train_dir, train_transforms)
custom_test_data = ImageFolderCustom(test_dir, test_transforms)
print(custom_train_data, custom_test_data)



def display_random_images(dataset: torch.utils.data.dataset.Dataset,
                          classes: Optional[List[str]] = None,
                          n: int = 10,
                          display_shape: bool = True,
                          seed: Optional[int] = None) -> None:
    """Plot ``n`` randomly chosen samples of ``dataset`` side by side.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs.
            Images are assumed to be channel-first tensors, because they are
            permuted with ``(1, 2, 0)`` before ``plt.imshow``. Labels index
            into ``classes``.
        classes: Optional class names. When given, each subplot is titled
            with the sample's class.
        n: Number of samples to show. It is capped at 10 (which also turns off
            the shape display) and at ``len(dataset)``.
        display_shape: Append the permuted image shape to each title. This is
            only used when ``classes`` is given.
        seed: If not ``None``, passed to ``random.seed`` before sampling, so
            repeated calls pick the same samples. This reseeds the global
            ``random`` module.
    """
    if n > 10:
        n = 10
        display_shape = False
        print("For display purposes, n shouldn't be larger than 10, setting to 10 and removing shape display.")

    # Compare against None so that seed=0 (falsy) is still honoured.
    if seed is not None:
        random.seed(seed)

    # random.sample raises ValueError if k exceeds the population size.
    n = min(n, len(dataset))
    random_indices = random.sample(range(len(dataset)), k=n)
    plt.figure(figsize=(16, 8))

    for i, sample_index in enumerate(random_indices):
        # Index the dataset once per sample. Indexing it twice would load
        # and transform the image twice.
        target_image, target_label = dataset[sample_index]

        # matplotlib expects channels last; the tensor is channel-first.
        target_image_adjust = target_image.permute(1, 2, 0)

        plt.subplot(1, n, i + 1)
        plt.imshow(target_image_adjust)
        plt.axis("off")
        if classes:
            title = f"Class: {classes[target_label]}"
            if display_shape:
                title += f"\tShape: {target_image_adjust.shape}"
            # A title exists only when class names were provided.
            plt.title(title)

# Preview five random training samples, titled with their class names.
# seed=None means no reseeding, so each run shows a different selection.
# train_data and class_names are expected to come from an earlier
# ImageFolder-based cell (they are not defined here).
display_random_images(
    train_data,
    classes=class_names,
    n=5,
    seed=None,
)
