In [9]:
import os
import tarfile

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import scipy.io

from PIL import Image

In [None]:
# with tarfile.open(r'<path>', 'r:gz') as tar:
#     tar.list()
#     tar.extractall(path='./data')

## 1 Load & Preprocess Data


In [10]:
class FlowerDataset(Dataset):
    """Map-style dataset over a directory of ``.jpg`` images with labels from a MATLAB file.

    Images are every file in ``image_dir`` whose name ends in ``.jpg``, sorted by filename;
    the i-th sorted image is paired with the i-th entry of the ``labels`` array stored in
    ``label_file``. This positional pairing assumes the directory holds exactly the images
    the label file describes, in the same order as their sorted filenames — TODO confirm
    for the dataset in use.

    Args:
        image_dir: Directory containing the ``.jpg`` images.
        label_file: Path to a ``.mat`` file with a ``labels`` variable of 1-based class ids.
        transform: Optional callable applied to each RGB image before it is returned.

    Raises:
        ValueError: If the number of labels differs from the number of ``.jpg`` files found.
    """

    def __init__(self, image_dir, label_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Load image labels. Cast to a signed type before shifting: if the .mat array is
        # unsigned (e.g. uint8), subtracting 1 from a 0 entry would silently wrap around.
        mat = scipy.io.loadmat(label_file)
        self.labels = mat['labels'].astype('int64').flatten() - 1  # 1-based (MATLAB) -> 0-based (Python)

        # Load image file names; sorting fixes the index -> file order used to pair with labels.
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))

        # Labels are paired with images by position, so the counts must agree; otherwise
        # __getitem__ would either mislabel images or index past the end of self.labels.
        if len(self.labels) != len(self.image_files):
            raise ValueError(
                f"{label_file!r} has {len(self.labels)} labels but {image_dir!r} "
                f"contains {len(self.image_files)} .jpg images"
            )

    def __len__(self):
        """Return the number of ``.jpg`` images found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the image at sorted position ``idx``.

        ``image`` is the RGB image, passed through ``self.transform`` when one was given;
        ``label`` is a 0-d ``torch.long`` tensor holding the 0-based class id.
        """
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # Context manager closes the underlying file handle; convert() returns a new,
        # fully loaded image that remains usable after the source file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = self.labels[idx]
        return img, torch.tensor(label, dtype=torch.long)
        

In [18]:
# Hyperparameters: each image is resized to image_size x image_size; batch_size images per batch.
image_size = 64
batch_size = 64

# Input locations (relative paths, resolved against the notebook's working directory).
img_dir = "./data/jpg"
label_file = "imagelabels.mat"

# Preprocessing: resize -> float tensor -> per-channel (x - 0.5) / 0.5,
# which maps ToTensor's [0, 1] pixel range to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(size=(image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])

dataset = FlowerDataset(image_dir=img_dir, label_file=label_file, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)  # reshuffled every epoch