In [19]:
import pickle
import random
import json
import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image 
import matplotlib.pyplot as plt
import requests
from io import BytesIO
from tqdm import tqdm
# Import dotenv library to get environment variables
from dotenv import load_dotenv
# Import pymongo to insert data into MongoDB
import pymongo

## Loading Pre Trained Model

Since we do not need the classification part of the network, we take the output of an intermediate layer (here `avgpool`) as the image's feature vector. The best layer to extract from has to be chosen by trial and error; alternatively, only the last layers of the model can be retrained.

In [2]:
# Load ResNet-18 together with its pretrained weights.
model = models.resnet18(pretrained=True)
# Keep a handle on the average-pooling layer: its output (not the classifier's)
# is what we use as the image embedding.
layer = model.avgpool
# We only run inference with the pretrained weights, so switch to evaluation
# mode (batch-norm layers then use their running statistics).
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## Transform the image to fit Resnet requirements

In [3]:
# Preprocessing pieces for ResNet input: resize to 224x224, convert to a
# tensor, and normalize with the standard ImageNet channel means/stds.
scaler = transforms.Resize((224, 224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()


def transform_image(image):
    """Turn a PIL image into a normalized 1x3x224x224 batch tensor for ResNet.

    Non-RGB images (grayscale, palette, RGBA, CMYK, ...) are converted to RGB
    first, because ``normalize`` expects exactly three channels and fails
    on any other channel count.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # ``Variable`` has been a deprecated no-op wrapper since PyTorch 0.4;
    # the plain tensor is equivalent.
    return normalize(to_tensor(scaler(image))).unsqueeze(0)

## Converting image to vector

Steps:
1. Download and load the image.
2. Transform the image to fit the ResNet input requirements.
3. Create a zero vector to store the feature vector.
4. Register a forward hook on the selected layer, run the image through ResNet, and copy that layer's output into the vector.

In [4]:
def get_vector(image_url, timeout=30):
    """Download an image and return the output of the selected ResNet layer.

    Parameters
    ----------
    image_url : str
        URL of the image to fetch (here, the IIIF image link of a photo).
    timeout : float, optional
        Seconds to wait for the HTTP server before giving up. Without a
        timeout ``requests.get`` can block indefinitely on a stalled server.

    Returns
    -------
    torch.Tensor
        A 1-D tensor of 512 values copied from the hooked ``layer``.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status. Without this check the
        error page would be handed to PIL and fail with a confusing decode error.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    t_img = transform_image(img)
    # The avgpool layer that we selected has an output size of 512.
    my_embedding = torch.zeros(512)

    def copy_data(module, inputs, output):
        # Forward hook: squeeze the layer output down to 512 values and copy
        # it into the buffer above.
        my_embedding.copy_(output.data.squeeze())

    handle = layer.register_forward_hook(copy_data)
    try:
        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            model(t_img)
    finally:
        # Always detach the hook, even if the forward pass fails. Otherwise
        # stale hooks accumulate on `layer` and fire on every later call.
        handle.remove()
    return my_embedding

## Image selection

A small class that stores a photo's IIIF representation (from which the image URL is read) together with its country.

In [5]:
class Terzani_Photo(object):
    """A Terzani archive photo: its IIIF representation plus its country.

    The country is stored under the historical attribute name ``photo``. That
    name is kept because instances are presumably unpickled from
    ``terzani_recto_iiif.pickle``, and unpickling restores attributes directly
    without calling ``__init__``. The read-only ``country`` property exposes
    the value under its real meaning.
    """

    def __init__(self, iiif, country):
        # IIIF dict; get_photo_link reads iiif["images"][0]["resource"]["@id"].
        self.iiif = iiif
        # Kept as ``photo`` (not ``country``) for compatibility with pickled instances.
        self.photo = country

    @property
    def country(self):
        """The country passed to ``__init__`` (stored in ``self.photo``)."""
        return self.photo

    def get_photo_link(self):
        """Return the ``@id`` (image URL) of the first image resource."""
        return self.iiif["images"][0]["resource"]["@id"]

In [6]:
nu_of_images_per_type = 40  # number of photos to sample for this experiment
SAMPLE_SEED = 42  # fixed seed so every run samples the same photos

# NOTE(review): pickle.load can execute arbitrary code; only load this file
# from a trusted source. Unpickling presumably needs Terzani_Photo defined above.
with open("terzani_recto_iiif.pickle", "rb") as fp:
    all_photos = pickle.load(fp)
# Cap the sample size: random.sample raises ValueError if asked for more
# items than the population holds.
all_photos = random.Random(SAMPLE_SEED).sample(
    all_photos, min(nu_of_images_per_type, len(all_photos))
)

In [7]:
image_vecs = {}  # image label -> 512-value feature vector from the ResNet-18 avgpool layer
failed_images_vecs = []  # labels of images whose feature vector could not be computed
for img in tqdm(all_photos):
    img_lbl = img.iiif["label"]
    # Skip labels already handled (successfully or not) so duplicates are fetched once.
    if img_lbl in image_vecs or img_lbl in failed_images_vecs:
        continue
    try:
        image_vecs[img_lbl] = get_vector(img.get_photo_link())
    except Exception as exc:
        # Best effort: record the failure and continue with the next photo.
        # Catching Exception (not a bare except) keeps Ctrl-C working.
        failed_images_vecs.append(img_lbl)
        tqdm.write(f"Failed to compute feature vector for {img_lbl}: {exc!r}")

100%|██████████| 40/40 [01:29<00:00,  2.23s/it]


In [14]:
# Tensors are not JSON-serializable, so the label -> tensor dict is saved with
# torch.save (reload it with torch.load).
torch.save(image_vecs, 'image_vecs.pt')

# Record the labels of images that could not be vectorized, if any.
if failed_images_vecs:
    print("There are failed images")
    with open('failed_images_vecs.json', 'w') as fp:
        json.dump(failed_images_vecs, fp, indent=4)

## Set up the credentials used to connect to MongoDB

In [20]:
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

MANGO_CLIENT_URI = os.getenv('MONGO_URI')
if not MANGO_CLIENT_URI:
    # Fail early with an actionable message. Otherwise the os.environ
    # assignment below raises an opaque TypeError when the variable is unset.
    raise RuntimeError("MONGO_URI is not set; add it to the environment or to a .env file")
os.environ['MANGO_CLIENT_URI'] = MANGO_CLIENT_URI

In [21]:
# MongoDB client built from the connection URI loaded above.
mangoclient = pymongo.MongoClient(MANGO_CLIENT_URI)
# Handle to the "terzani_photos" database on that server.
mango_db = mangoclient.get_database("terzani_photos")

### Storing the Image Feature Vectors

In [23]:
# Collection <sample_image_vecs> in the terzani_photos database
# (MongoDB creates it on the first insert if it does not exist).
mango_tag_collection = mango_db["sample_image_vecs"]
# One document per image: its label plus the feature vector as a plain list.
documents = [
    {"image": img_label, "feature_vec": img_vec.tolist()}
    for img_label, img_vec in image_vecs.items()
]
# One bulk insert instead of a round trip per image. insert_many rejects an
# empty list, hence the guard.
# NOTE(review): re-running this cell inserts duplicate documents.
if documents:
    mango_tag_collection.insert_many(documents)