# Create Image Embeddings with VGG16

Convert each `224 x 224 x 3` (_aka_ `150,528`-dimensional) still image into a `4,096`-dimensional embedding using a pre-trained VGG16 network.

In [1]:
import json
import struct
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import h5py
from PIL import Image

In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models

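Load the pre-trained `VGG16` model from torchvision, then truncate its classifier after the second fully-connected layer (dropping the last ReLU, Dropout, and the 1,000-class output layer) so the network outputs a 4,096-dimensional embedding.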

In [3]:
# ImageNet-pretrained VGG16. Keep only the first four classifier modules
# (Linear 25088->4096, ReLU, Dropout, Linear 4096->4096) so the network outputs the
# 4,096-d activations of the second fully-connected layer instead of 1,000 class scores.
vgg16 = models.vgg16(pretrained=True)
vgg16.classifier = vgg16.classifier[:4]
# eval() switches off the Dropout layer that is kept at classifier index 2. Without it
# every forward pass randomly zeroes activations, so the same image would get a
# different embedding each time.
vgg16.eval()
vgg16

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [4]:
# Root folder of the extracted stills; each movie's frames sit in a zero-padded
# `{mid:04d}` sub-directory beneath it (see the embedding loop).
stills_dir = Path("data", "vgg16images", "stills")
stills_dir

PosixPath('data/vgg16images/stills')

In [5]:
# Per-channel RGB mean / std of ImageNet; torchvision's pre-trained models expect
# their inputs to be normalized with these values.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def load_image(mid, sid, path, size=224):
    """Load one still as a single-image VGG16 input batch.

    The image is converted to RGB. It is then padded with black rows (or cropped)
    top and bottom so its height is exactly `size`, keeping it vertically centred.
    The width is left unchanged; presumably the stills are already `size` pixels
    wide (TODO confirm). Pixels are scaled to [0, 1] and normalized with the
    ImageNet mean/std.

    Parameters
    ----------
    mid, sid : int
        Movie and still ids. Not used here; kept so existing callers keep working.
    path : pathlib.Path
        Image file to read.
    size : int, optional
        Target height in pixels (default 224, VGG16's input size).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (1, 3, size, width).
    """
    with path.open("rb") as fh:
        # convert() forces the pixel data to be read before the file is closed. It also
        # turns RGBA / grayscale / palette PNGs into the 3 channels VGG16 expects.
        img = Image.open(fh).convert("RGB")

    width, height = img.size
    # Integer box: crop() pads out-of-bounds regions with zeros (black in RGB). Using
    # integer coordinates guarantees exactly `size` rows even when the padding is odd.
    top = -((size - height) // 2)
    img = img.crop((0, top, width, top + size))

    pixels = np.asarray(img, dtype=np.float32) / 255.0
    pixels = (pixels - IMAGENET_MEAN) / IMAGENET_STD
    # (H, W, C) -> (1, C, H, W)
    return torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)

Load the stills metadata into a DataFrame with one row per still: the movie id (`mid`), the still's index within that movie (`sid`), and the still's `filename`.

In [6]:
import itertools as it

In [7]:
# image-data.json holds one record per movie: its "mid" and the list of its "stills".
with open("data/image-data.json") as json_file:
    movies = json.load(json_file)

# Flatten to one row per still; `sid` is the still's position within its movie.
images = pd.DataFrame(
    [
        {"mid": int(movie["mid"]), "sid": sid, "filename": still}
        for movie in movies
        for sid, still in enumerate(movie["stills"])
    ]
)
images.head()

Unnamed: 0,mid,sid,filename
0,0,0,0000_10CloverfieldLane_00000.png
1,0,1,0000_10CloverfieldLane_00001.png
2,0,2,0000_10CloverfieldLane_00002.png
3,0,3,0000_10CloverfieldLane_00003.png
4,0,4,0000_10CloverfieldLane_00004.png


In [8]:
from tqdm import tqdm

In [9]:
# One embedding row per still; the row width is the output size of the last
# (truncated) classifier layer.
n_images = images.shape[0]
vgg_output_shape = vgg16.classifier[-1].out_features
print(f"Output shape = {n_images:,d} x {vgg_output_shape:,d}")

Output shape = 132,617 x 4,096


Create an HDF5 file to hold the embedding data.

In [10]:
# Mode "w" creates the file, truncating any output left over from a previous run.
embeddings_file = "data/vgg16images/embeddings.hdf5"
f = h5py.File(embeddings_file, mode="w")
f

<HDF5 file "embeddings.hdf5" (mode r+)>

In [11]:
# One float32 ("f") row of `vgg_output_shape` values per still, indexed like `images`.
dset = f.create_dataset(
    "mydataset",
    shape=(n_images, vgg_output_shape),
    dtype="f",
)
dset

<HDF5 dataset "mydataset": shape (132617, 4096), type "<f4">

Loop through the images and:
* Load the image
* Pass it through VGG16
* Store the embedding in the HDF5 file (row `i` of the dataset corresponds to row `i` of `images`)
* Record every still that fails to load or to embed.

In [None]:
# Embed every still, writing row i of `dset` for row i of `images`.
# Per-image failures are recorded rather than raised, so one bad file doesn't abort the
# multi-hour run; the `dset` rows of failed stills are simply left unwritten.
load_errors = []
vgg_errors = []
images_processed = 0

# eval() disables the Dropout kept in the truncated classifier so embeddings are
# deterministic; calling it again is harmless if the model cell already did.
vgg16.eval()

progress_bar = tqdm(
    enumerate(images.itertuples(index=False)),
    desc="Images Processed",
    total=len(images),
    ncols=100,
)
# Only forward activations are needed, so skip building an autograd graph per image
# (saves memory and time).
with torch.no_grad():
    for i, (mid, sid, filename) in progress_bar:
        path = stills_dir / f"{mid:04d}" / filename
        # `except Exception` (not a bare except) so KeyboardInterrupt / SystemExit
        # still stop the run instead of being logged as an image error.
        try:
            img = load_image(mid, sid, path)
        except Exception:
            load_errors.append((mid, sid, filename))
            continue
        try:
            embedding = vgg16(img)
        except Exception:
            vgg_errors.append((mid, sid, filename))
            continue
        dset[i] = embedding.numpy().reshape((vgg_output_shape,))
        images_processed += 1

print(f"Successfully processed {images_processed:6,d} of {len(images):6,d} images")
print(f"{len(load_errors):6,d} image loading errors")
print(f"{len(vgg_errors):6,d} vgg16 embedding errors")

Images Processed:  24%|████████▎                         | 32305/132617 [2:45:39<8:30:02,  3.28it/s]

In [None]:
# Close the HDF5 embeddings file once the loop has finished writing to `dset`.
f.close()