In [1]:
import os
import tarfile
import urllib.request

import scipy
import scipy.io
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

In [2]:
def _fetch(url, dest_path):
    """Download ``url`` to ``dest_path`` without leaving a partial file at ``dest_path``."""
    partial_path = dest_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # only publish the file under its final name once it is complete
        os.replace(partial_path, dest_path)
    finally:
        # a failed download/rename must not leave the temp file behind
        if os.path.exists(partial_path):
            os.remove(partial_path)


def download_dataset(root_dir="flower_data", force=False):
    """Download and unpack the Oxford 102 Flowers dataset.

    After a successful call ``root_dir`` contains the image archive
    (``102flowers.tgz``), its extracted contents, and the label file
    (``imageslabels.mat``).

    The function is safe to re-run: files that are already present are not
    downloaded again, and the archive is only re-extracted when it was just
    downloaded or when the ``jpg`` directory is missing.

    Args:
        root_dir: Directory that receives the downloads. Created if missing.
        force: When True, download and extract everything again even if the
            files already exist (use this to recover from a corrupted or
            partially extracted copy).

    Raises:
        urllib.error.URLError: If a download fails.
        tarfile.TarError: If the archive cannot be read or extracted.
    """
    #download oxford 102 flowers dataset
    images_url = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz"
    labels_url = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat"

    images_path = os.path.join(root_dir, "102flowers.tgz")
    # NOTE: the local name differs from the remote one ("imageslabels" vs
    # "imagelabels"); it is kept because the dataset class reads this exact name.
    labels_path = os.path.join(root_dir, "imageslabels.mat")
    extracted_dir = os.path.join(root_dir, "jpg")

    os.makedirs(root_dir, exist_ok=True)

    #download images
    archive_is_new = False
    if force or not os.path.isfile(images_path):
        print("Downloading images..")
        _fetch(images_url, images_path)
        archive_is_new = True

    #extract images
    if archive_is_new or not os.path.isdir(extracted_dir):
        print("Extracting images..")
        with tarfile.open(images_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # extraction filters are available: refuse members that would
                # escape root_dir or create special files
                tar.extractall(root_dir, filter="data")
            else:
                # older Python without extraction filters
                tar.extractall(root_dir)

    #download labels
    if force or not os.path.isfile(labels_path):
        print("Downloading labels..")
        _fetch(labels_url, labels_path)

    print("Download complete")

In [67]:
# Populate ./flower_data with the image archive, its extracted contents and the label file (needs network access).
download_dataset()

Downloading images..
Extracting images..


  tar.extractall("flower_data")


Downloading labels..
Download complete


In [3]:
class OxfordFlowersDataset(Dataset):
    """Map-style dataset pairing flower images with zero-based class labels.

    ``root_dir`` must contain:
      * ``jpg/`` with images named ``image_00001.jpg``, ``image_00002.jpg``, ...
      * ``imageslabels.mat`` with a ``labels`` array holding one 1-based class
        id per image, in file-name order.

    Args:
        root_dir: Directory holding the ``jpg`` folder and the label file.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, 'jpg')

        #load matlab labels
        labels_mat = scipy.io.loadmat(os.path.join(root_dir, 'imageslabels.mat'))
        # loadmat keeps whatever dtype the file stores, which may be a narrow
        # unsigned type. Widen to int64 *before* subtracting so the shift to
        # zero-based labels cannot wrap around, and so batches collate to the
        # int64 targets that classification losses expect.
        self.labels = labels_mat['labels'].ravel().astype('int64') - 1
        self.transform = transform

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Negative indices count from the end, like a list.

        Returns:
            The image (after ``transform`` if one was given, otherwise an RGB
            PIL image) and its label as a plain ``int``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = int(idx)
        num_samples = len(self.labels)
        if idx < 0:
            idx += num_samples
        if not 0 <= idx < num_samples:
            # raising IndexError (rather than failing on a missing file) also
            # lets plain ``for sample in dataset`` iteration terminate
            raise IndexError(f"index {idx} out of range for dataset of size {num_samples}")

        #build image filename (file numbering starts at 00001, hence the +1)
        img_name = f'image_{idx+1:05d}.jpg'
        img_path = os.path.join(self.img_dir, img_name)

        #load image; convert() reads the pixels so the file can be closed
        #right away, and guarantees three channels for downstream transforms
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = int(self.labels[idx])

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [4]:
# #create dataset
# dataset = OxfordFlowersDataset('./flower_data')
# print(f"Total samples: {len(dataset)}")
# #loading image
# img, label = dataset[0]

# #testing dataset
# import numpy as np
# img_np = np.array(img)

# print(f"First image: ({img_np.shape[0]}, {img_np.shape[1]}), Label: {label}")

In [5]:
#img transformation
# img, _ = dataset[0]
# resized = transforms.Resize(256)(img)
# print(f"After resize: {resized.size}") 
# cropped = transforms.CenterCrop(224)(resized)
# print(f"After crop: {cropped.size}")
# img = cropped

# img_tensor = transforms.ToTensor()(img)
# print(f"Tensor size: {img_tensor.shape}")

# print(img_tensor[0, :3, :3]) 

In [6]:
#preprocessing pipeline: resize -> center crop -> tensor -> per-channel normalisation
# per-channel statistics used for normalisation
# (presumably the standard ImageNet values - TODO confirm)
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
]
transform = transforms.Compose(preprocessing_steps)

#build the dataset with the pipeline attached
dataset = OxfordFlowersDataset(root_dir='./flower_data', transform=transform)

In [9]:
#data loading
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

#smoke test: pull a single batch to confirm images load and collate
for images, labels in dataloader:
    print(f"Success! Batch shape: {images.shape}")
    break

Success! Batch shape: torch.Size([4, 3, 224, 224])
