In [9]:
import torch
from torchvision import datasets, transforms as T, models
import os
from PIL import Image
import torch.nn as nn

In [5]:
# Preprocessing: resize the raw images to a fixed size.

In [7]:
def resize_image(image, size):
    """Resize an image to the given size using a LANCZOS (high-quality) filter.

    Args:
        image: a PIL image.
        size: target size, passed unchanged to ``image.resize`` (PIL expects a
            ``(width, height)`` pair).

    Returns:
        The resized image returned by ``image.resize``.
    """
    # ``Image.ANTIALIAS`` was removed in Pillow 10; it was only an alias of the
    # LANCZOS filter. ``Image.Resampling`` exists from Pillow 9.1 on; older
    # releases expose the same constant directly on the ``Image`` module.
    resampling = getattr(Image, "Resampling", Image)
    return image.resize(size, resampling.LANCZOS)

def resize_images(image_dir, output_dir, size):
    """Resize the images in 'image_dir' and save into 'output_dir'.

    Every regular file in ``image_dir`` is opened with PIL, resized with
    ``resize_image`` and saved under the same file name in ``output_dir``,
    keeping the source image format. A progress line is printed every 100 images.

    Args:
        image_dir: directory containing the source images.
        output_dir: destination directory; created (with parents) if missing.
        size: target size forwarded to ``resize_image``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # os.listdir also returns sub-directories; opening one would raise, so keep
    # only regular files.
    images = [name for name in os.listdir(image_dir)
              if os.path.isfile(os.path.join(image_dir, name))]
    num_images = len(images)
    for i, image in enumerate(images):
        # Sources are only read, so open read-only ('r+b' would fail on
        # read-only files).
        with open(os.path.join(image_dir, image), 'rb') as f:
            with Image.open(f) as img:
                # Capture the format *before* resizing: an image produced by
                # resize() has format=None, which would make save() guess the
                # format from the file extension instead of keeping the source's.
                image_format = img.format
                resized = resize_image(img, size)
                resized.save(os.path.join(output_dir, image), image_format)
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_images, output_dir))

In [11]:
# Image preprocessing: random crop/flip augmentation plus mean/std normalization for the pretrained ResNet.
def transform_img(crop_size):
    """Build the training-time image pipeline used with the pretrained ResNet.

    The returned transform applies, in order: a random crop of ``crop_size``,
    a random horizontal flip, conversion to a tensor, and per-channel
    normalization with the mean/std constants defined below.

    Args:
        crop_size: forwarded unchanged to ``T.RandomCrop``.

    Returns:
        A ``T.Compose`` object chaining the steps above.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    pipeline_steps = [
        T.RandomCrop(crop_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(pipeline_steps)

In [3]:
# We use a ResNet-152 model pretrained on the ILSVRC-2012-CLS (ImageNet) image classification dataset.

In [10]:
class EncoderCNN(nn.Module):
    """Image encoder: a frozen pretrained ResNet-152 backbone plus a trainable head.

    The backbone is ResNet-152 with its final fc layer removed. It is run under
    ``torch.no_grad()`` and is kept in eval mode at all times, so its weights and
    its BatchNorm running statistics stay at their pretrained values. The
    trainable part is ``linear`` (projection to ``embed_size``) followed by ``bn``.
    """

    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace top fc layer.

        Args:
            embed_size: dimensionality of the output feature vectors.
        """
        super(EncoderCNN, self).__init__()
        resnet = self._load_pretrained_resnet152()
        modules = list(resnet.children())[:-1]      # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        # Freeze the backbone's BatchNorm statistics from the start; train()
        # below keeps it that way whenever the module's mode is switched.
        self.resnet.eval()
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    @staticmethod
    def _load_pretrained_resnet152():
        """Return ResNet-152 with ImageNet weights, supporting old and new torchvision."""
        weights_enum = getattr(models, "ResNet152_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13: ``pretrained=True`` is deprecated and maps to
            # the IMAGENET1K_V1 weights.
            return models.resnet152(weights=weights_enum.IMAGENET1K_V1)
        return models.resnet152(pretrained=True)

    def train(self, mode=True):
        """Set train/eval mode for the head while keeping the backbone in eval mode.

        ``torch.no_grad()`` in ``forward`` only disables gradients; BatchNorm
        layers in train mode would still update their running statistics.
        Forcing ``self.resnet`` into eval mode keeps the backbone truly frozen.
        ``eval()`` also routes through this method.
        """
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Extract feature vectors from input images.

        Args:
            images: batch fed to the ResNet backbone; presumably a float tensor of
                shape (batch, 3, H, W) normalized as in ``transform_img`` — TODO confirm.

        Returns:
            Tensor of shape (batch, embed_size).
        """
        with torch.no_grad():
            features = self.resnet(images)
        # Flatten the pooled backbone output to (batch, features).
        features = features.reshape(features.size(0), -1)
        # NOTE: in train mode self.bn needs a batch size > 1.
        features = self.bn(self.linear(features))
        return features
