In [1]:
from PIL import Image 
import matplotlib.pylab as plt
import numpy as np
import os
from scraper import *
import torch
from torch import dist
from torch import nn
from torch.nn import functional as f
from torch import optim
from torchvision.transforms import *

# Turn on autograd anomaly detection for the whole session. PyTorch documents
# this mode as a debugging aid that slows execution, so switch it off
# (set_detect_anomaly(False)) for real training runs.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f3b0c18f2b0>

In [2]:
def tensor_from_image(path, dims):
    """Load an image file and return it as a float tensor of shape (1, C, H, W).

    Args:
        path: Path of the image file; opened with ``PIL.Image.open``.
        dims: Target size, forwarded unchanged to torchvision's ``Resize``.

    Returns:
        A float tensor with a leading batch axis of size 1, followed by the
        channel, height and width axes. ``None`` is returned when the decoded
        pixel array is not 3-dimensional (for example a single-channel image),
        which callers use as the signal to skip the file.
        NOTE(review): images with an alpha channel pass through with 4
        channels -- confirm downstream code expects that.
    """
    # The context manager releases the file handle even if resizing or
    # decoding raises part-way through.
    with Image.open(path) as image:
        resized = Resize(dims)(image)
        # np.array copies the pixels into a writable buffer; np.asarray on a
        # PIL image can hand back a read-only array, which makes
        # torch.from_numpy emit a "not writable" warning.
        pixels = np.array(resized)

    if pixels.ndim != 3:
        return None

    # (H, W, C) -> (C, H, W), cast to float, then add the batch axis.
    return torch.from_numpy(pixels).permute(2, 0, 1).float().unsqueeze(0)

In [3]:
def normalize(imagetensor):
    """Standardize every channel of an image tensor in place.

    The mean and (unbiased) standard deviation are computed over the last two
    axes, so the function accepts a single image shaped (C, H, W) as well as
    batched input shaped (..., C, H, W), with any number of channels.

    Args:
        imagetensor: Floating-point tensor whose last two axes are the
            spatial axes. It is modified in place.

    Returns:
        The same tensor object that was passed in, now holding
        ``(x - mean) / std`` per channel. A channel whose values are all
        identical (std == 0) becomes all zeros instead of NaN.
    """
    means = torch.mean(imagetensor, dim=(-2, -1), keepdim=True)
    stds = torch.std(imagetensor, dim=(-2, -1), keepdim=True)

    # Dividing by a zero std would fill a constant channel with NaN; divide by
    # one instead so the (already centred) channel stays at zero.
    stds = torch.where(stds == 0, torch.ones_like(stds), stds)

    # In-place arithmetic keeps the original contract: the caller's tensor is
    # updated and that same object is returned.
    imagetensor.sub_(means).div_(stds)
    return imagetensor


In [4]:
class ImageDataset:
    """In-memory collection of (query, positive, negative) image triplets.

    Each entry of ``source_dir`` that contains the files ``q.png``, ``p.png``
    and ``n.png`` becomes one sample: a dict with the keys ``'name'`` (the
    entry's name), ``'query'``, ``'pos'`` and ``'neg'``. The three image
    values are whatever ``tensor_from_image`` returns for the given ``dims``.
    """

    # (sample dict key, file name inside the sample directory)
    _IMAGE_FILES = (('query', 'q.png'), ('pos', 'p.png'), ('neg', 'n.png'))

    def __init__(self, source_dir, dims):
        """Load every complete triplet found under ``source_dir``.

        Args:
            source_dir: Directory whose sub-directories each hold one triplet.
            dims: Target image size, forwarded to ``tensor_from_image``.
        """
        samples = []

        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order (and therefore indices) reproducible across runs.
        for name in sorted(os.listdir(source_dir)):
            paths = {
                key: os.path.join(source_dir, name, filename)
                for key, filename in self._IMAGE_FILES
            }

            # Skip stray files and sample directories that lack one of the
            # three images instead of failing while trying to open them.
            if not all(os.path.isfile(image_path) for image_path in paths.values()):
                continue

            sample = {'name': name}
            for key, image_path in paths.items():
                sample[key] = tensor_from_image(image_path, dims)

            # tensor_from_image signals an unusable image with None; a triplet
            # is only kept when all three images loaded.
            if any(sample[key] is None for key in paths):
                continue

            samples.append(sample)

        self.samples = samples      # list of sample dicts, in sorted name order
        self.size = len(samples)    # number of samples that were kept
        self.dims = dims            # size every image was resized to

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __iter__(self):
        """Iterate over the sample dicts in stored order."""
        return iter(self.samples)

    def __getitem__(self, idx):
        """Return the sample dict at position ``idx``."""
        return self.samples[idx]

    def display(self, idx):
        """Plot the query, positive and negative images of sample ``idx`` side by side."""
        _, ax = plt.subplots(1, 3, figsize=(15, 15))
        sample = self.samples[idx]
        for i, category in enumerate(['query', 'pos', 'neg']):
            # detach() lets tensors that track gradients be converted to numpy.
            imagematrix = sample[category].detach().cpu().numpy()
            # Drop the batch axis, then reorder (C, H, W) -> (H, W, C) for imshow.
            imagematrix = np.transpose(imagematrix[0], (1, 2, 0))
            # Integer pixel data is interpreted by imshow on the 0-255 scale.
            imagematrix = np.int32(imagematrix)
            ax[i].set_title(category)
            ax[i].imshow(imagematrix)