In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms
from pathlib import Path
import sys

In [None]:
# The notebook runs from a subdirectory of the project, so the project root is
# the parent of the current working directory.
NB_DIR = Path.cwd()
PROJECT_ROOT = NB_DIR.parent

# Register the project root on sys.path at most once so `src.*` imports resolve.
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.append(str(PROJECT_ROOT))

In [None]:
from src.dataset import ImageDataset

In [None]:
# Base dataset, constructed without a `transform` argument; used below to inspect
# raw samples and class names (contrast with `dataset_transformed`).
dataset = ImageDataset()

##### Check classes

In [None]:
# Class names known to the dataset; a bare last expression is shown as the cell's
# rich output instead of plain printed text.
dataset.classes

##### Check a sample

In [None]:
# Inspect the last sample: print its class name (`label` indexes `dataset.classes`)
# and display the image itself via the bare expression on the last line.
image, label = dataset[-1]
print(dataset.classes[label])
image

In [None]:
# `.size` is read as an attribute here (no call), unlike the tensor `.size()` used
# after transforming — presumably a PIL image, giving (width, height); TODO confirm.
image.size

##### Resize the images to a fixed 512×512, convert them to tensors, and check the resulting dimensions

In [None]:
# Resize every image to a fixed 512x512 and convert it to a tensor
# (for 8-bit images, ToTensor also rescales pixel values to [0, 1]).
test_transform = transforms.Compose(
    [transforms.Resize((512, 512)), transforms.ToTensor()]
)

In [None]:
# Same dataset, but presumably applying `test_transform` to each sample
# (the next cells treat its images as tensors).
dataset_transformed = ImageDataset(transform=test_transform)

In [None]:
image, label = dataset_transformed[-1]
# Tensor shape of the transformed sample; expect (channels, 512, 512) after Resize.
image.shape

In [None]:
# Number of samples in the dataset, shown as the cell's output.
len(dataset_transformed)

In [None]:
# Fixed order (shuffle=False): this loader is only used for a shape check and for
# dataset-wide statistics, where sample order does not matter.
loader = DataLoader(dataset=dataset_transformed, batch_size=32, shuffle=False)

In [None]:
# Pull one batch and report tensor shapes (presumably images are
# (batch, channels, height, width) and labels are (batch,)).
sample = next(iter(loader))
image, label = sample
print(f"Image shape: {image.shape}", f"Labels shape: {label.shape}", sep="\n")

##### Now let's calculate the per-channel mean and standard deviation of our dataset.

In [None]:
def get_mean_and_std(loader):
    """Compute the dataset-wide per-channel mean and standard deviation.

    Per-channel pixel sums and sums of squares are accumulated over every batch,
    so the result is the true mean/std over all pixels of all images
    (population std). Averaging per-image stds instead would underestimate the
    dataset std, because it ignores how much the image means differ from each
    other.

    Args:
        loader: iterable of ``(images, labels)`` batches where ``images`` is a
            float tensor of shape ``(batch, channels, *spatial_dims)``.

    Returns:
        Tuple ``(mean, std)`` of float32 tensors with one entry per channel.

    Raises:
        ValueError: if the loader yields no images.
    """
    channel_sum = 0.0
    channel_sq_sum = 0.0
    pixel_count = 0  # pixels per channel seen so far

    for images, _ in loader:
        # (B, C, H, W) -> (B, C, H*W). float64 keeps the sum of squares accurate
        # over millions of pixels; flatten also accepts non-contiguous tensors.
        pixels = images.double().flatten(2)
        channel_sum += pixels.sum(dim=(0, 2))
        channel_sq_sum += pixels.pow(2).sum(dim=(0, 2))
        pixel_count += pixels.size(0) * pixels.size(2)

    if pixel_count == 0:
        raise ValueError("loader yielded no images; cannot compute mean/std")

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    variance = (channel_sq_sum / pixel_count - mean.pow(2)).clamp(min=0)
    return mean.float(), variance.sqrt().float()

In [None]:
# Per-channel (mean, std) over every image yielded by `loader` (512x512 tensors).
get_mean_and_std(loader)

##### After running the function, it returned (one value per channel):
- mean: `[0.4843, 0.4340, 0.3911]`
- std: `[0.2415, 0.2331, 0.2263]`

##### Let's test the data-loader helper we wrote (`get_loaders` from `src.loaders`).

In [None]:
from src.loaders import get_loaders

In [None]:
# Build loaders with the project helper; presumably returns (train, validation).
test_train_loader, test_val_loader = get_loaders(batch_size=32)

In [None]:
# Sanity-check the project loader: pull one training batch and report its shapes.
sample = next(iter(test_train_loader))
image, label = sample
print(f"Image shape: {image.shape}", f"Labels shape: {label.shape}", sep="\n")

In [None]:
import matplotlib.pyplot as plt
import math

In [None]:
# Show up to `cols` example images per class: one row per class, with the class
# name on the first image of each row. `dataset.dataset` is the wrapped dataset
# (presumably a torchvision ImageFolder) whose `samples` are (path, class_index)
# pairs and whose `classes` maps index -> class name.
image_folder = dataset.dataset

cols = 5  # example images per class

# Collect the first `cols` file paths of each class, in dataset order.
unique_samples = {}
for path, label_idx in image_folder.samples:
    class_paths = unique_samples.setdefault(image_folder.classes[label_idx], [])
    if len(class_paths) < cols:
        class_paths.append(path)

sorted_classes = sorted(unique_samples)
if not sorted_classes:
    raise ValueError("No samples found in the dataset; nothing to plot.")

# One row per class. A hard-coded row count breaks the grid as soon as the dataset
# has a different number of classes.
rows = len(sorted_classes)

fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows), squeeze=False)
for row_axes, class_name in zip(axes, sorted_classes):
    class_paths = unique_samples[class_name]
    for img_idx, ax in enumerate(row_axes):
        ax.axis('off')
        # Classes with fewer than `cols` images leave the remaining slots blank
        # instead of raising IndexError.
        if img_idx < len(class_paths):
            ax.imshow(plt.imread(class_paths[img_idx]))
        if img_idx == 0:
            ax.set_title(class_name, fontsize=14, loc='left')

fig.tight_layout()
plt.show()