In [84]:
import torch
from torch import nn
from torch.utils.data import Dataset
import pandas as pd
import cv2
import numpy as np
import os

class DeepFashionDataset(Dataset):
    def __init__(self,img_dir,train_path=None,test_path=None,validation_path=None,mode="train", transform=None, ):
        
        self.transform= transform
        self.img_dir = img_dir

        if mode=="train":
            assert(train_path is not None)
            self.file_list = train_path
        elif mode=="test":
            assert(test_path is not None)
            self.file_list = test_path
        elif mode=="validation":
            assert(validation_path is not None)
            self.file_list = validation_path
        else:
            return
        


        df = pd.read_csv(self.file_list, header=0)
        
        self.data = df["file"].to_numpy().tolist()
        self.main_labels = df["main_category"].to_numpy().tolist() # class id of each sample
        self.sub_labels = df["sub_category"].to_numpy().tolist()
        self.clothes_types = df["clothes_type"].to_numpy().tolist()
        self.source_types = df["source_type"].to_numpy().tolist()
        self.variation_types=df["variation_type"].to_numpy().tolist()
        self.bboxes = df["bbox"].to_numpy().tolist()
        self.landmarks = df["landmarks"].to_numpy().tolist()
        self.attributes = df["attributes"].to_numpy().tolist()
        self.labels = [i+"/"+j for i,j in zip(self.main_labels,self.sub_labels)  ]
        del df

        self.unique_labels = set(self.labels)
        #
        # ["file","main_category","sub_category","clothes_type","source_type","variation_type","bbox","landmarks","attributes"]
        # with open(self.file_list,"r") as file : 
        #     lines = file.readlines()
        #     for line in lines[1:]:
        #         splitted_data = line.split(",")
        #         self.data.append(splitted_data[0])
        #         self.main_labels.append(splitted_data[1])
        #         self.sub_labels.append(splitted_data[2])
        #         self.clothes_types.append(splitted_data[3])
        #         self.source_types.append(splitted_data[4])
        #         self.variation_types.append(splitted_data[5])
        #         self.bboxes.append(eval(splitted_data[6]))
        #         self.landmarks.append(eval(splitted_data[7]))
        #         self.attributes.append(eval(splitted_data[8]))






        # for class_path in self.file_list:
        #     class_name = class_path.split("/")[-1]
        #     self.classes.append(class_name)
        #     for img_path in glob.glob(class_path + "/*.jpg"):
        #         self.data.append([img_path, class_name])
                
        self.idx_to_class = {i:j for i, j in enumerate(self.labels)}
        self.class_to_idx = {label: np.squeeze(np.where(np.array(self.labels) == label)).tolist()
                                 for label in self.unique_labels}

        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        class_name = self.labels[idx]
        img = cv2.imread(os.path.join(self.img_dir,*img_path.split("/")))
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.permute(2, 0, 1)
        
        if self.transform is not None:
            image = self.transform(img_tensor)
        else :
            image = img_tensor
        return image, class_name

In [85]:
dataset = DeepFashionDataset(img_dir="./img",train_path="DeepFashionDataset_train.csv",mode="train")

In [86]:
len(dataset)

159495

In [87]:
dataset.unique_labels

{'CLOTHING/Blouse',
 'CLOTHING/Coat',
 'CLOTHING/Jeans',
 'CLOTHING/Pants',
 'CLOTHING/Polo_Shirt',
 'CLOTHING/Summer_Wear',
 'CLOTHING/T_Shirt',
 'CLOTHING/Tank_Top',
 'DRESSES/Dress',
 'DRESSES/Lace_Dress',
 'DRESSES/Skirt',
 'DRESSES/Sleeveless_Dress',
 'DRESSES/Suspenders_Skirt',
 'TOPS/Blouse',
 'TOPS/Chiffon',
 'TOPS/Coat',
 'TOPS/Lace_Shirt',
 'TOPS/Summer_Wear',
 'TOPS/T_Shirt',
 'TOPS/Tank_Top',
 'TROUSERS/Leggings',
 'TROUSERS/Pants',
 'TROUSERS/Summer_Wear'}

In [88]:
len(dataset.idx_to_class)

159495

In [89]:
len(dataset.class_to_idx)

23

In [90]:
dataset.class_to_idx

{'CLOTHING/Jeans': [699,
  700,
  701,
  702,
  703,
  704,
  705,
  706,
  707,
  708,
  709,
  710,
  711,
  712,
  713,
  714,
  715,
  716,
  717,
  718,
  719,
  720,
  721,
  722,
  723,
  724,
  725,
  726,
  727,
  728,
  729,
  730,
  731,
  732,
  733,
  734,
  735,
  736,
  737,
  738,
  739,
  740,
  741,
  742,
  743,
  744,
  745,
  746,
  747,
  748,
  749,
  750,
  751,
  752,
  753,
  754,
  755,
  756,
  757,
  758,
  759,
  760,
  761,
  762,
  763,
  764,
  765,
  766,
  767,
  768,
  769,
  770,
  771,
  772,
  773,
  774,
  775,
  776,
  777,
  778,
  779,
  780,
  781,
  782,
  783,
  784,
  785,
  786,
  787,
  788,
  789,
  790,
  791,
  792,
  793,
  794,
  795,
  796,
  797,
  798,
  799,
  800,
  801,
  802,
  803,
  804,
  805,
  806,
  807,
  808,
  809,
  810,
  811,
  812,
  813,
  814,
  815,
  816,
  817,
  818,
  819,
  820,
  821,
  822,
  823,
  824,
  825,
  826,
  827,
  828,
  829,
  830,
  831,
  832,
  833,
  834,
  835,
  836,
  837,
  838,
  

In [91]:
dataset.__getitem__(0)

(tensor([[[107, 109, 110,  ...,  62,  62,  56],
          [107, 109, 110,  ...,  63,  60,  59],
          [105, 107, 109,  ...,  61,  60,  60],
          ...,
          [ 43,  45,  47,  ...,  13,  11,   6],
          [ 42,  45,  47,  ...,  15,  12,   9],
          [ 42,  45,  47,  ...,  18,  13,  14]],
 
         [[192, 194, 195,  ..., 151, 151, 148],
          [192, 194, 195,  ..., 152, 152, 151],
          [192, 194, 196,  ..., 153, 152, 152],
          ...,
          [146, 148, 150,  ...,  86,  86,  82],
          [145, 148, 150,  ...,  88,  87,  85],
          [145, 148, 150,  ...,  91,  90,  90]],
 
         [[248, 250, 251,  ..., 231, 231, 227],
          [248, 250, 251,  ..., 232, 231, 230],
          [248, 250, 252,  ..., 230, 229, 229],
          ...,
          [233, 235, 237,  ..., 166, 165, 164],
          [232, 235, 237,  ..., 168, 166, 167],
          [232, 235, 237,  ..., 171, 169, 173]]], dtype=torch.uint8),
 'CLOTHING/Blouse')

In [100]:
class TripletDeepFashion(Dataset):
    """
    Train: For each sample (anchor) randomly chooses a positive and negative samples
    Test: Creates fixed triplets for testing
    """

    def __init__(self, inner_dataset, mode="train"):
        self.inner_dataset = inner_dataset
        self.mode= mode
        
        self.data = self.inner_dataset.data
        self.labels= self.inner_dataset.labels
        self.labels_set = inner_dataset.unique_labels
        # keeps the sample indices belonging to each label:
        self.label_to_indices = inner_dataset.class_to_idx
        if mode != "train":
            # generate fixed triplets for testing

            random_state = np.random.RandomState(29)

            triplets = [[i,
                         random_state.choice(self.label_to_indices[self.labels[i].item()]),
                         random_state.choice(self.label_to_indices[
                                                 np.random.choice(
                                                     list(self.labels_set - set([self.labels[i].item()]))
                                                 )
                                             ])
                         ]
                        for i in range(len(self.data))]
            self.triplets = triplets

    def __getitem__(self, index):
        if self.mode == "train":
            #img1, label1 = self.data[index], self.labels[index].item()
            img1, label1 = self.inner_dataset.__getitem__(index)
            
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            
            negative_label = np.random.choice(list(self.labels_set - set([label1])))
            negative_index = np.random.choice(self.label_to_indices[negative_label])

            img2, label2= self.inner_dataset.__getitem__(positive_index)
            img3, label3= self.inner_dataset.__getitem__(negative_index)

            return (img1, img2, img3), [label1, label2, label3]
        else:
            img1, label1= self.inner_dataset.__getitem__(self.triplets[index][0])
            img2, label2 = self.inner_dataset.__getitem__(self.triplets[index][1])
            img3, label3= self.inner_dataset.__getitem__(self.triplets[index][2])

            return (img1, img2, img3), [label1, label2, label3]

    def __len__(self):
        return len(self.inner_dataset)

In [101]:
TripletDataset = TripletDeepFashion( dataset, mode="train")

In [106]:
TripletDataset.label_to_indices

{'CLOTHING/Jeans': [699,
  700,
  701,
  702,
  703,
  704,
  705,
  706,
  707,
  708,
  709,
  710,
  711,
  712,
  713,
  714,
  715,
  716,
  717,
  718,
  719,
  720,
  721,
  722,
  723,
  724,
  725,
  726,
  727,
  728,
  729,
  730,
  731,
  732,
  733,
  734,
  735,
  736,
  737,
  738,
  739,
  740,
  741,
  742,
  743,
  744,
  745,
  746,
  747,
  748,
  749,
  750,
  751,
  752,
  753,
  754,
  755,
  756,
  757,
  758,
  759,
  760,
  761,
  762,
  763,
  764,
  765,
  766,
  767,
  768,
  769,
  770,
  771,
  772,
  773,
  774,
  775,
  776,
  777,
  778,
  779,
  780,
  781,
  782,
  783,
  784,
  785,
  786,
  787,
  788,
  789,
  790,
  791,
  792,
  793,
  794,
  795,
  796,
  797,
  798,
  799,
  800,
  801,
  802,
  803,
  804,
  805,
  806,
  807,
  808,
  809,
  810,
  811,
  812,
  813,
  814,
  815,
  816,
  817,
  818,
  819,
  820,
  821,
  822,
  823,
  824,
  825,
  826,
  827,
  828,
  829,
  830,
  831,
  832,
  833,
  834,
  835,
  836,
  837,
  838,
  

In [108]:
TripletDataset.__getitem__(0)

((tensor([[[107, 109, 110,  ...,  62,  62,  56],
           [107, 109, 110,  ...,  63,  60,  59],
           [105, 107, 109,  ...,  61,  60,  60],
           ...,
           [ 43,  45,  47,  ...,  13,  11,   6],
           [ 42,  45,  47,  ...,  15,  12,   9],
           [ 42,  45,  47,  ...,  18,  13,  14]],
  
          [[192, 194, 195,  ..., 151, 151, 148],
           [192, 194, 195,  ..., 152, 152, 151],
           [192, 194, 196,  ..., 153, 152, 152],
           ...,
           [146, 148, 150,  ...,  86,  86,  82],
           [145, 148, 150,  ...,  88,  87,  85],
           [145, 148, 150,  ...,  91,  90,  90]],
  
          [[248, 250, 251,  ..., 231, 231, 227],
           [248, 250, 251,  ..., 232, 231, 230],
           [248, 250, 252,  ..., 230, 229, 229],
           ...,
           [233, 235, 237,  ..., 166, 165, 164],
           [232, 235, 237,  ..., 168, 166, 167],
           [232, 235, 237,  ..., 171, 169, 173]]], dtype=torch.uint8),
  tensor([[[255, 255, 255,  ..., 255, 25