In [1]:
import glob
from itertools import chain
import os
import random
import zipfile
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models

In [2]:
torch.__version__

'2.0.1'

In [3]:
PATH_TRAIN = "../flower_data/train"
PATH_VALID = "../flower_data/valid"

## Model and transforms

In [8]:
class TripletData(Dataset):
    def __init__(self, path, transforms, split="train"):
        self.path = path
        self.split = split    # train or valid
        self.cats = 102       # number of categories
        self.transforms = transforms
    
    def __getitem__(self, idx):
        # our positive class for the triplet
        idx = str(idx%self.cats + 1)
        
        # choosing our pair of positive images (im1, im2)
        positives = os.listdir(os.path.join(self.path, idx))
        im1, im2 = random.sample(positives, 2)
        
        # choosing a negative class and negative image (im3)
        negative_cats = [str(x+1) for x in range(self.cats)]
        negative_cats.remove(idx)
        negative_cat = str(random.choice(negative_cats))
        negatives = os.listdir(os.path.join(self.path, negative_cat))
        im3 = random.choice(negatives)
        
        im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)
        
        im1 = self.transforms(Image.open(im1))
        im2 = self.transforms(Image.open(im2))
        im3 = self.transforms(Image.open(im3))
        
        return [im1, im2, im3]
        
    # we'll put some value that we want since there can be far too many triplets possible
    # multiples of the number of images/ number of categories is a good choice
    def __len__(self):
        return self.cats*8
    

# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


# Datasets and Dataloaders
train_data = TripletData(PATH_TRAIN, train_transforms)
val_data = TripletData(PATH_VALID, val_transforms)

# train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)
# val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)
train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=0)

## Loss Function

In [9]:
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        distance_positive = self.calc_euclidean(anchor, positive)
        distance_negative = self.calc_euclidean(anchor, negative)
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

## Trainning

In [10]:
epochs = 2
# device = 'cuda'
device = 'cpu'

# Our base model
# model = models.resnet18().cuda()
model = models.resnet18()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()

# Training
for epoch in range(epochs):
    
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        x1,x2,x3 = data
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
    print("Train Loss: {}".format(epoch_loss.item()))

  0%|          | 0/26 [00:00<?, ?it/s]

Train Loss: 222.89385986328125


  0%|          | 0/26 [00:00<?, ?it/s]

Train Loss: 32.15928649902344


## faiss index

In [12]:
#!pip install faiss-cpu
import faiss                            
faiss_index = faiss.IndexFlatL2(1000)   # build the index

im_indices = []
with torch.no_grad():
    for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):
        im = Image.open(f)
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()])#.cuda()

    
        preds = model(im)
        preds = np.array([preds[0].cpu().numpy()])
        faiss_index.add(preds) #add the representation to index
        im_indices.append(f)   #store the image name to find it later on
        
        

## test

In [13]:
PATH_TEST = "../flower_data/test/"

In [15]:
with torch.no_grad():
    for f in os.listdir(PATH_TEST):
        im = Image.open(os.path.join(PATH_TEST,f))
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()])
    
        test_embed = model(im).cpu().numpy()
        _, I = faiss_index.search(test_embed, 5)
        print("Retrieved Image: {}".format(im_indices[I[0][0]]))

Retrieved Image: ../flower_data/train/71/image_08089.jpg
Retrieved Image: ../flower_data/train/4/image_05649.jpg
Retrieved Image: ../flower_data/train/41/image_02286.jpg
Retrieved Image: ../flower_data/train/20/image_04952.jpg
Retrieved Image: ../flower_data/train/101/image_07945.jpg
Retrieved Image: ../flower_data/train/65/image_03273.jpg
Retrieved Image: ../flower_data/train/73/image_00411.jpg
Retrieved Image: ../flower_data/train/29/image_04131.jpg
Retrieved Image: ../flower_data/train/88/image_00519.jpg
Retrieved Image: ../flower_data/train/71/image_04530.jpg
Retrieved Image: ../flower_data/train/84/image_02619.jpg
Retrieved Image: ../flower_data/train/77/image_00181.jpg
Retrieved Image: ../flower_data/train/4/image_05682.jpg
Retrieved Image: ../flower_data/train/83/image_01812.jpg
Retrieved Image: ../flower_data/train/93/image_06047.jpg
Retrieved Image: ../flower_data/train/74/image_01234.jpg
Retrieved Image: ../flower_data/train/9/image_06427.jpg
Retrieved Image: ../flower_data/t

Retrieved Image: ../flower_data/train/11/image_03152.jpg
Retrieved Image: ../flower_data/train/96/image_07664.jpg
Retrieved Image: ../flower_data/train/45/image_07146.jpg
Retrieved Image: ../flower_data/train/81/image_00812.jpg
Retrieved Image: ../flower_data/train/51/image_01418.jpg
Retrieved Image: ../flower_data/train/43/image_02354.jpg
Retrieved Image: ../flower_data/train/50/image_06549.jpg
Retrieved Image: ../flower_data/train/11/image_03155.jpg
Retrieved Image: ../flower_data/train/8/image_03368.jpg
Retrieved Image: ../flower_data/train/57/image_07243.jpg
Retrieved Image: ../flower_data/train/74/image_01160.jpg
Retrieved Image: ../flower_data/train/49/image_06233.jpg
Retrieved Image: ../flower_data/train/96/image_07630.jpg
Retrieved Image: ../flower_data/train/59/image_05060.jpg
Retrieved Image: ../flower_data/train/51/image_01348.jpg
Retrieved Image: ../flower_data/train/51/image_01428.jpg
Retrieved Image: ../flower_data/train/81/image_00839.jpg
Retrieved Image: ../flower_data/

Retrieved Image: ../flower_data/train/59/image_05050.jpg
Retrieved Image: ../flower_data/train/96/image_07671.jpg
Retrieved Image: ../flower_data/train/72/image_03623.jpg
Retrieved Image: ../flower_data/train/88/image_00552.jpg
Retrieved Image: ../flower_data/train/80/image_02064.jpg
Retrieved Image: ../flower_data/train/83/image_01801.jpg
Retrieved Image: ../flower_data/train/62/image_08164.jpg
Retrieved Image: ../flower_data/train/74/image_01204.jpg
Retrieved Image: ../flower_data/train/55/image_04704.jpg
Retrieved Image: ../flower_data/train/83/image_01761.jpg
Retrieved Image: ../flower_data/train/20/image_04898.jpg
Retrieved Image: ../flower_data/train/69/image_05964.jpg
Retrieved Image: ../flower_data/train/23/image_03443.jpg
Retrieved Image: ../flower_data/train/41/image_02291.jpg
Retrieved Image: ../flower_data/train/11/image_03179.jpg
Retrieved Image: ../flower_data/train/18/image_04301.jpg
Retrieved Image: ../flower_data/train/101/image_07975.jpg
Retrieved Image: ../flower_dat

Retrieved Image: ../flower_data/train/97/image_07693.jpg
Retrieved Image: ../flower_data/train/50/image_06300.jpg
Retrieved Image: ../flower_data/train/77/image_00094.jpg
Retrieved Image: ../flower_data/train/76/image_02539.jpg
Retrieved Image: ../flower_data/train/77/image_00089.jpg
Retrieved Image: ../flower_data/train/74/image_01195.jpg
Retrieved Image: ../flower_data/train/53/image_03713.jpg
Retrieved Image: ../flower_data/train/86/image_02914.jpg
Retrieved Image: ../flower_data/train/82/image_01593.jpg
Retrieved Image: ../flower_data/train/8/image_03304.jpg
Retrieved Image: ../flower_data/train/20/image_04899.jpg
Retrieved Image: ../flower_data/train/7/image_07201.jpg
Retrieved Image: ../flower_data/train/60/image_02943.jpg
Retrieved Image: ../flower_data/train/83/image_01826.jpg
Retrieved Image: ../flower_data/train/81/image_00793.jpg
Retrieved Image: ../flower_data/train/52/image_04164.jpg
Retrieved Image: ../flower_data/train/67/image_07057.jpg
Retrieved Image: ../flower_data/t

Retrieved Image: ../flower_data/train/78/image_01961.jpg
Retrieved Image: ../flower_data/train/75/image_02185.jpg
Retrieved Image: ../flower_data/train/81/image_00889.jpg
Retrieved Image: ../flower_data/train/41/image_02217.jpg
Retrieved Image: ../flower_data/train/56/image_02848.jpg
Retrieved Image: ../flower_data/train/83/image_01779.jpg
Retrieved Image: ../flower_data/train/8/image_03355.jpg
Retrieved Image: ../flower_data/train/69/image_05972.jpg
Retrieved Image: ../flower_data/train/87/image_05472.jpg
Retrieved Image: ../flower_data/train/62/image_07281.jpg
Retrieved Image: ../flower_data/train/94/image_07459.jpg
Retrieved Image: ../flower_data/train/73/image_00408.jpg
Retrieved Image: ../flower_data/train/101/image_07956.jpg
Retrieved Image: ../flower_data/train/94/image_07426.jpg
Retrieved Image: ../flower_data/train/99/image_07847.jpg
Retrieved Image: ../flower_data/train/95/image_07587.jpg
Retrieved Image: ../flower_data/train/64/image_06121.jpg
Retrieved Image: ../flower_data

Retrieved Image: ../flower_data/train/7/image_08102.jpg
Retrieved Image: ../flower_data/train/75/image_02175.jpg
Retrieved Image: ../flower_data/train/69/image_05960.jpg
Retrieved Image: ../flower_data/train/2/image_05128.jpg
Retrieved Image: ../flower_data/train/51/image_03914.jpg
Retrieved Image: ../flower_data/train/63/image_05902.jpg
Retrieved Image: ../flower_data/train/100/image_07921.jpg
Retrieved Image: ../flower_data/train/47/image_05015.jpg
Retrieved Image: ../flower_data/train/28/image_05219.jpg
Retrieved Image: ../flower_data/train/51/image_03949.jpg
Retrieved Image: ../flower_data/train/71/image_04547.jpg
Retrieved Image: ../flower_data/train/78/image_01898.jpg
Retrieved Image: ../flower_data/train/71/image_04525.jpg
Retrieved Image: ../flower_data/train/12/image_04068.jpg
Retrieved Image: ../flower_data/train/96/image_07614.jpg
Retrieved Image: ../flower_data/train/82/image_01631.jpg
Retrieved Image: ../flower_data/train/83/image_01761.jpg
Retrieved Image: ../flower_data/