In [None]:
from PIL import Image
import os
import os.path

import torch.utils.data
import torchvision.transforms as transforms

import numpy as np
import pandas as pd

In [None]:
# Image loader helper function
def default_image_loader(path):
    """Load the image at ``path`` and return it converted to RGB.

    Args:
        path: Location of the image file, in any form accepted by
            ``Image.open``.

    Returns:
        The result of ``Image.convert('RGB')`` on the opened image.
    """
    # ``Image.open`` is lazy and keeps the underlying file open. The context
    # manager releases the handle deterministically once the pixels have been
    # decoded; ``convert`` returns a converted copy, so the returned image
    # stays usable after the source image has been closed.
    with Image.open(path) as img:
        return img.convert('RGB')

In [None]:
# Dataset: open a single sample image for a quick look.
# NOTE(review): this is an absolute, machine-specific Windows path, so the cell
# only runs where that exact directory exists; consider deriving it from a
# configurable data directory.
im = Image.open(r"C:\Users\xiaow\OneDrive\Desktop\Spring 2022\Introduction to Machine Learning\Introduction-to-Machine-Learning\Task 3\food\00001.jpg")

In [None]:
# Render the sample image inline in the notebook output.
display(im)

In [None]:
# View the sample image as a NumPy array and inspect its dimensions.
# Presumably (height, width, channels) for a colour JPEG — TODO confirm.
data = np.asarray(im)
data.shape

In [None]:
# Preview the sample at 354x242. PIL's Image.resize takes (width, height) and
# returns a new image; `im` itself is left unchanged.
# NOTE(review): IMAGE_SIZE = (354, 242) is later passed to transforms.Resize,
# which reads a pair as (height, width) — confirm the intended orientation.
im.resize((354,242))

In [None]:
class TripletImageLoader(torch.utils.data.Dataset):
    """Dataset whose items are (anchor, positive, negative) image triplets."""

    def __init__(self, base_path, triplets_file_name, transform=None, loader=default_image_loader):
        """Read the triplet list and remember how to load the images.

        Args:
            base_path: Directory containing the image files. An image id
                ``<id>`` from the triplets file is resolved to
                ``<base_path>/<id>.jpg``.
            triplets_file_name: Path of the text file listing the triplets. It
                is opened exactly as given (it is NOT joined with
                ``base_path``). Each non-blank line holds three
                whitespace-separated image ids: anchor, positive, negative.
                For example, the triplet "00723 00478 02630" denotes that the
                dish in image "00723.jpg" is more similar in taste to the dish
                in image "00478.jpg" than to the dish in image "02630.jpg"
                according to a human annotator. Fields after the third are
                ignored; blank lines are skipped.
            transform: Optional callable applied to each of the three loaded
                images; ``None`` returns the images as produced by ``loader``.
            loader: Callable that takes an image path and returns the image.

        Raises:
            OSError: If the triplets file cannot be opened.
            IndexError: If a non-blank line has fewer than three fields.
        """
        self.base_path = base_path
        triplets = []
        # Context manager so the file handle is closed as soon as parsing ends
        # instead of whenever the garbage collector gets to it.
        with open(triplets_file_name) as triplets_file:
            for line in triplets_file:
                fields = line.split()  # split once, reuse for all three ids
                if not fields:
                    continue  # tolerate blank lines, e.g. at the end of the file
                triplets.append((fields[0], fields[1], fields[2]))  # anchor, positive, negative
        self.triplets = triplets
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``index``-th triplet as ``(anchor, positive, negative)``.

        All three images are loaded first and then, if a transform was given,
        transformed in anchor/positive/negative order.
        """
        images = [
            self.loader(os.path.join(self.base_path, f'{image_id}.jpg'))
            for image_id in self.triplets[index]
        ]
        if self.transform is not None:
            images = [self.transform(image) for image in images]

        anchor, positive, negative = images
        return anchor, positive, negative

    def __len__(self):
        """Return the number of triplets read from the triplets file."""
        return len(self.triplets)

In [None]:
# Initialization: pick the compute device and define the shared settings.
import torch

# Use the GPU when CUDA is available (e.g. Google Colab's free, time-limited
# GPU runtime); otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

################# Configuration  ######################
# Target size handed to transforms.Resize. Bigger images improve performance
# but make training slower.
# NOTE(review): transforms.Resize reads a pair as (height, width), whereas the
# earlier PIL preview used (354, 242) as (width, height) — confirm orientation.
IMAGE_SIZE = (354, 242)

# Training parameters
BATCH_SIZE = 32

In [None]:
# Dataset and Transformations
import torchvision
import torchvision.transforms as transforms

############# Datasets and Dataloaders ################
# Per-channel statistics for Normalize:
#   output[channel] = (input[channel] - mean[channel]) / std[channel]
# With mean = std = 0.5 this maps ToTensor's [0, 1] range to [-1, 1].
NORMALIZE_MEAN = (0.5, 0.5, 0.5)
NORMALIZE_STD = (0.5, 0.5, 0.5)

transform_train = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor with values scaled to [0, 1]
    transforms.Resize(IMAGE_SIZE),
    # We want the network to be robust to geometric transformations that leave
    # the image semantically unchanged.
    transforms.RandomHorizontalFlip(p=0.5),
    # Normalize exactly like the validation/test pipelines below. Previously
    # this step was commented out, so training saw inputs in [0, 1] while
    # evaluation saw inputs in [-1, 1].
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Evaluation pipelines: deterministic (no random augmentation).
transform_validate = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SIZE),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SIZE),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Directory holding the <id>.jpg image files.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
path = 'C:/Users/xiaow/OneDrive/Desktop/Spring 2022/Introduction to Machine Learning/Introduction-to-Machine-Learning/Task 3/food'
# The triplets file name is relative, so it is resolved against the current
# working directory, not against `path`.
train_dataset = TripletImageLoader(path, 'train_triplets.txt', transform=transform_train)

In [None]:
# Spot-check: fetch one (anchor, positive, negative) triplet after transforms.
# NOTE(review): 59514 is presumably the last valid index of train_triplets.txt — confirm.
train_dataset[59514]

In [None]:
from torch.utils.data import DataLoader

# Smoke test: pull the first batch through a DataLoader (num_workers=0, so
# loading happens in the main process) and report how many anchor images it holds.
# NOTE(review): batch_size is hard-coded to 1000 rather than BATCH_SIZE — confirm this is intentional.
loader = DataLoader(
    dataset=train_dataset,
    batch_size=1000,
    num_workers=0,
)
data = next(iter(loader))
len(data[0])

In [None]:
# Test triplets, read from the same image directory and pushed through the
# deterministic test-time transform.
test_dataset = TripletImageLoader(
    base_path=path.rstrip('\n'),
    triplets_file_name='test_triplets.txt',
    transform=transform_test,
)

In [None]:
# Number of triplets read from test_triplets.txt.
len(test_dataset)