In [1]:
#HYPERPARAMETERS
BATCH_SIZE_TRAIN_LOADER = 32
BATCH_SIZE_TEST_LOADER = 4
TRAIN_EPOCHS = 3
INPUT_RESIZE = 224
FEATURE_VEC_SIZE = 256
LEARNING_RATE = 0.001
#DATA './Data/Sketchy/256x256' 12500 Images
DATA = './Data/Sketchy/256x256'

In [2]:
def imshow(img):
    img = img.to("cpu")
    img = torchvision.utils.make_grid(img)
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [3]:
#Compare two paths, if both paths have same string until "-" then sketch matches image
#Argument order IS RELEVANT!
def is_matching(path_image, path_sketch):
    splitted_image = path_image.split('.')[0]
    splitted_sketch = path_sketch.split('-')[0]
    if splitted_image==splitted_sketch:
        return True
    else:
        return False
    
    
path_image = 'n07679356_23722.jpg'
path_sketch = 'n07679356_23722-1.png'

is_matching(path_image, path_sketch)

True

In [4]:
def showImgSketchAtIndex(dataset,index):
    print(f"Path: {dataset[index][2]}")
    imshow(dataset[index][0])
    print(f"Path: {dataset[index][3]}")
    imshow(dataset[index][1])
    print("---------------------------------------")


def show_pairs(dataset, start, stop):
    for i in range (start, stop):
        showImgSketchAtIndex(dataset,i)

In [5]:
import torch
from torch.utils.data import Dataset
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     torchvision.transforms.Resize((INPUT_RESIZE,INPUT_RESIZE))])

  warn(f"Failed to load image Python extension: {e}")


In [6]:
#Expects root dir with EXACTLY 2 dirs in it
#                               ->Images(dir name irrelevant)
#                               ->Sketches(dir name irrelevant)
# In THAT order
class ImgSketchDataset(datasets.ImageFolder):
    def __init__(self, root, transform=None, target_transform=None, train=True, trainsplit_decimal=0.9):
        if not (trainsplit_decimal>0.0 and trainsplit_decimal<1.0):
            raise ValueError("trainsplit_decimal must be in range ]0.0,1.0[")

        super().__init__(root, transform, target_transform)
        self.trainsplit_decimal = trainsplit_decimal
        self.train = train
        # imgs_dir e.g. QuickDraw_images_two
        self.imgs_dirs = os.path.join(root, os.listdir(self.root)[0])
        # sketches_dir e.g. QuickDraw_sketches_two
        self.sketches_dirs = os.path.join(root, os.listdir(self.root)[1])
        self.samples = self.make_samples(self.imgs_dirs, self.sketches_dirs, multiple_sketches_per_img=False)

    
    def __getitem__(self, index: int):
        path_img, path_sketch = self.samples[index]
        img = self.loader(path_img)
        sketch = self.loader(path_sketch)
        
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            sketch = self.target_transform(sketch)

        return sketch, img, path_sketch, path_img

    #return list of tuples with path to imgs,sketches like [(airplane1img.jpeg, airplane1sketch.jpeg),...]
    #needs paths like './data/QuickDraw_images_two/airplane/image_00001.jpg' to load with loader   
    #If more sketches than imgs => only #images pairs    
    #Sketch and Image have to contain same folder structure (i.e. airplanes, alarmclocks... folders both in images and sketches folders)    
    # Given trainsplit_decimal == 0.9 IF Trainset => first 90% Images, first 90% Sketches, IF Testset => last 10% images, last 10%sketches
    #TODO Make only 1 sketch per 1 image pairs, no multiple sketches for each img -> Why? Because of ClipLoss (Use multiple_sketches_per_img=False)
    def make_samples(self, imgs_dirs, sketches_dirs, multiple_sketches_per_img=False):
        samples = []
        # nth_loop = 0
        for dir in os.listdir(imgs_dirs):
            #./data/QuickDraw_images_two/airplane
            path_to_imgs = os.path.join(imgs_dirs, dir)
            #./data/QuickDraw_sketches_two/airplane
            path_to_sketches = os.path.join(sketches_dirs, dir)

            imgs = os.listdir(path_to_imgs)
            sketches = os.listdir(path_to_sketches)
            dataset_size_imgs = int(self.trainsplit_decimal*len(imgs))
            if self.train == True:
                imgs = imgs[:dataset_size_imgs]
                # sketches = sketches[:dataset_size_imgs]
            else:
                imgs = imgs[dataset_size_imgs:]
                #Cut off not needed sketches
                sketches = sketches[dataset_size_imgs:]

            #makes (img,sketch) pairs in form [(img1,sketch1-1),(img1,sketch1-2),(img2,sketch2-1),...,(img4,sketch4-3)]
            if multiple_sketches_per_img == True:
                index = 0
                for img in imgs:
                    at_least_one_match = False
                    while(index<len(sketches)):
                        if is_matching(img,sketches[index]):
                            imgPath = os.path.join(path_to_imgs,img)
                            sketchPath = os.path.join(path_to_sketches,sketches[index])
                            samples.append((imgPath,sketchPath))
                            at_least_one_match = True
                            index+=1
                        elif at_least_one_match == False:
                            index+=1
                        else:
                            break
            #alternatively makes (img,sketch) pairs in form [(img1,sketch1-1), (img2,sketch2-1),...,(img4,sketch4-1)]            
            else:
                index = 0
                for img in imgs:
                    while(index<len(sketches)):
                        if is_matching(img,sketches[index]):
                            imgPath = os.path.join(path_to_imgs,img)
                            sketchPath = os.path.join(path_to_sketches,sketches[index])
                            samples.append((imgPath,sketchPath))
                            index+=1
                            break
                        else:
                            index+=1


        return samples

In [7]:
#root = './Data/Sketchy/256x256_less'
root = DATA
train_set = ImgSketchDataset(root, transform, transform, train=True, trainsplit_decimal=0.9)
test_set =  ImgSketchDataset(root, transform, transform, train=False, trainsplit_decimal=0.9)

In [8]:
from torch.utils.data import DataLoader
#TODO Trainset batch size? Not sure what best size is, but for CLIPLOSS it should be pretty big
train_loader = DataLoader(train_set, BATCH_SIZE_TRAIN_LOADER, shuffle=True)
test_loader = DataLoader(test_set, BATCH_SIZE_TEST_LOADER, shuffle=True)


In [9]:
#Model
#TODO fc3 output size?
import torch.nn as nn
import torch.nn.functional as F
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        #for 32x32 inputs
        #self.fc1 = nn.Linear(16 * 5 * 5, 120)

        #for 64x64 inputs
        #self.fc1 = nn.Linear(2704, 120)
        
        #for 128x128 inputs
        #self.fc1 = nn.Linear(13456, 120)
        
        #for 224x224 inputs
        self.fc1 = nn.Linear(44944, 120)
        
        
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, FEATURE_VEC_SIZE)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        #print(x.shape)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net().to(device)



In [10]:
from torchvision.models import resnet50
class resNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = resnet50(pretrained=True)
       #self.model = nn.Sequential(*list(self.model.modules())[:-1]) # strips off last linear layer
        self.model = nn.Sequential(*(list(self.model.children())[:-1]))
        self.fc1 = nn.Linear(2048, FEATURE_VEC_SIZE)

    
    def forward(self, x):
       # print(x.shape)
        x = self.model(x)
        x = torch.squeeze(x)
       # print(x.shape)
        x = self.fc1(x)
        return x

net = resNet()
net = net.to(device)

In [11]:
class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

#What is logit_scale?
    def forward(self, image_features, text_features, logit_scale=1.0):
        device = image_features.device

        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logit_scale * text_features @ image_features.T

        # calculated ground-truth and cache if enabled
        num_logits = logits_per_image.shape[0]
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
            ) / 2
        return total_loss

In [12]:
import torch.optim as optim
from tqdm import tqdm
import torch
from torch.utils.tensorboard import SummaryWriter
#Train the model 
#For each epoch trains all data in trainloader and then validates on all testloader data
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
clip_loss = ClipLoss()
def train(EPOCHS=5):
    writer = SummaryWriter()
    net.train()
    for epoch in range(EPOCHS):
        #Training loop
        train_loss = 0
        train_iter = 0
        train_epoch_cosSim_mean = 0
        loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
        for i, batch in loop:
            optimizer.zero_grad()
            sketches = batch[0].to(device)
            images = batch[1].to(device)
            sketch_feature_vecs = net(sketches)
            img_feature_vecs = net(images)
    
            #print(sketch_feature_vecs)
            #print(sketch_feature_vecs.shape)

            #loss = torch.mean(torch.cdist(sketch_feature_vecs, img_feature_vecs))
            loss = clip_loss.forward(sketch_feature_vecs, img_feature_vecs)
            train_loss += loss
            train_iter += 1
            loss.backward()
            optimizer.step()
            
            loop.set_description(f"Epoch [{epoch}/{EPOCHS}]")
            loop.set_postfix(loss = loss.item())

            # print(loss.item())
            cos = nn.CosineSimilarity(dim=1)
            similarity = cos(sketch_feature_vecs, img_feature_vecs)
            mean_cosSim_ofBatch = torch.mean(similarity)
            if i % 20 == 0:
                print(f"Training Mean Cosine Similarity of Batch #{i} is {mean_cosSim_ofBatch}]")

            train_epoch_cosSim_mean += mean_cosSim_ofBatch
        train_epoch_cosSim_mean = train_epoch_cosSim_mean/train_iter
        train_epoch_loss = train_loss/train_iter    
        print(f"Training Mean Cosine Similarity of EPOCH#{epoch} is [{train_epoch_cosSim_mean}]")    
        print(f"Training Mean Loss of EPOCH#{epoch} is [{train_epoch_loss}]")
        # print(loss.item())
        writer.add_scalar("CosineSim/Train", train_epoch_cosSim_mean, epoch)
        writer.add_scalar("Loss/Train", train_epoch_loss, epoch)


        #Validation Loop
        loop = tqdm(enumerate(test_loader), total=len(test_loader), leave=False)
        net.eval()
        val_loss = 0
        val_iter = 0
        validation_epoch_cosSim_mean = 0
        with torch.no_grad():
            for i, batch in loop:
                sketches = batch[0].to(device)
                images = batch[1].to(device)
                sketch_feature_vecs = net(sketches)
                img_feature_vecs = net(images)
                

                #print(sketch_feature_vecs)
                #print(sketch_feature_vecs.shape)


                #loss = torch.mean(torch.cdist(sketch_feature_vecs, img_feature_vecs))
                loss = clip_loss.forward(sketch_feature_vecs, img_feature_vecs)
                val_loss += loss
                val_iter += 1
                loop.set_description(f"Validation Epoch [{epoch}/{EPOCHS}]")
                loop.set_postfix(loss = loss.item())

                #Cosinesim sketch_vec,img_vec
                cos = nn.CosineSimilarity(dim=1)
                similarity = cos(sketch_feature_vecs, img_feature_vecs)
                mean_cosSim_ofBatch = torch.mean(similarity)
                if i % 20 == 0:
                    print(f"Validation Mean Cosine Similarity of Batch #{i} is {mean_cosSim_ofBatch}]")

                validation_epoch_cosSim_mean += mean_cosSim_ofBatch

        validation_epoch_cosSim_mean = validation_epoch_cosSim_mean/train_iter
        validation_epoch_loss = val_loss/val_iter   
        print(f"Validation Mean Cosine Similarity of EPOCH#{epoch} is [{validation_epoch_cosSim_mean}]")    
        print(f"Validation Mean Loss of EPOCH#{epoch} is [{validation_epoch_loss}]")
        writer.add_scalar("CosineSim/Validation", validation_epoch_cosSim_mean, epoch)
        writer.add_scalar("Loss/Validation", validation_epoch_loss, epoch)

    writer.flush()
    writer.close
    PATH = './sketchy_net.pth'
    torch.save(net.state_dict(), PATH)




In [13]:
train(TRAIN_EPOCHS)

----------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------

INFERENCE PHASE FROM HERE

In [14]:
net = net.to(device)
PATH = './sketchy_net.pth'
net.load_state_dict(torch.load(PATH))
net.eval()

resNet(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256

In [15]:
# Run this after training is done to save all img feature vecs from training+test datasets, to be able to retrieve them later in testing.
# returns retrieval_imgs_db => [(img_feature_vec1, img1), (img_feature_vec2, img2),...]
def save_all_imgs_to_retrieval_db():
    retrieval_imgs_db = []
    with torch.no_grad():
        for batch in train_loader:
            images = batch[1].to(device)
            images_feature_vecs = net(images)
            zipped = tuple(zip(images_feature_vecs, images))
            for element in zipped:
                retrieval_imgs_db.append(element)
                
                    
        for batch in test_loader:
            images = batch[1].to(device)
            images_feature_vecs = net(images)
            zipped = tuple(zip(images_feature_vecs, images))
            for element in zipped:
                retrieval_imgs_db.append(element)

        assert torch.equal(images_feature_vecs[-2], retrieval_imgs_db[-2][0]) == True
        assert torch.equal(images[-1], retrieval_imgs_db[-1][1]) == True

        return retrieval_imgs_db


In [16]:
imgs_db = save_all_imgs_to_retrieval_db()

In [17]:
#goes over whole batch
#returns most similar images ranked list to corresponding sketches (n=batch size) list: [most_similar_imgs_list_to_sketch#1, most_similar_imgs_list_to_sketch#2, ... most_similar_imgs_list_to_sketch#n]
def retrieve_imgs(sketch_feature_vecs, images_db):
    list_of_images = []
    for sketch_feature_vec in sketch_feature_vecs:
        max_similarity = 0
        images = []
        for img_feature_vec, image in images_db:
            cos = nn.CosineSimilarity(dim=0)
            similarity = cos(sketch_feature_vec, img_feature_vec)
            if similarity > max_similarity:
                max_similarity = similarity
                max_image = image
                images.append(max_image)
                #print(max_similarity)
        images.reverse()
        list_of_images.append(images)        
    return list_of_images   

In [18]:
#runs over data_loader and calculates how many images are found correctly in the imgs_db given input sketch
def calc_model_accuracy(data_loader):
    with torch.no_grad():
        hits = 0
        images_amount = len(data_loader.dataset)
        for data in tqdm(data_loader):
            sketches = data[0].to(device)
            images = data[1].to(device)
            sketch_vecs = net(sketches)
            result_imgs = retrieve_imgs(sketch_vecs, imgs_db)

            # print("no zipped sketch")
            # imshow(sketches[0])
            # print("no zipped inpute image")
            # imshow(images[0])
            # print("no zipped result images")
            # for rank, resultImage in enumerate(result_imgs[0],1):
            #     print("Rank ",rank)
            #     imshow(resultImage)
            
            inputSketch_inputImage_resultImages = tuple(zip(sketches, images, result_imgs))
            for inputSketch, inputImage, resultImages in inputSketch_inputImage_resultImages:
                # print("inputSketch")
                # imshow(inputSketch)
                # print("inputImage")
                # imshow(inputImage)
                # print("resultImages")
                for rank, resultImage in enumerate(resultImages,1):
                    # print("Rank ",rank)
                    # imshow(resultImage)
                    if torch.equal(inputImage,resultImage) == True:
                        hits += 1
                        print("Hit! at rank ",rank)
                        # print("Hit_Input_Sketch")
                        # imshow(inputSketch)
                        # print("Hit_Input_Image")
                        # imshow(inputImage)
                        # print("Hit_Result_Image")
                        # imshow(resultImage)
                        break
            
        accuracy = hits/images_amount     
        print("Total hits ",hits)  
        print("Accuracy: ",accuracy)
        

In [19]:
calc_model_accuracy(test_loader)

  5%|▌         | 16/313 [02:24<44:45,  9.04s/it]


KeyboardInterrupt: 