# 2. VGG16 Image Embeddings

_created by Austin Poor_

In this notebook, I use a pretrained VGG-16 model to create image embeddings for each of the film stills.

The notebook [1.format-images.ipynb](./1.format-images.ipynb) has placed uniformly formatted images in an S3 bucket. This notebook pulls those images down, processes them, and uploads the results (one parquet file per image) to another S3 bucket.

In [1]:
import datetime as dt
from pathlib import Path

import boto3
import numpy as np
from PIL import Image

import pyarrow as pa
import pyarrow.parquet as pq

import tensorflow as tf

In [2]:
from concurrent.futures import ThreadPoolExecutor

In [3]:
# Scratch directory for downloaded images and the parquet files built from them.
tmp_dir = Path("./tmp")
tmp_dir.mkdir(exist_ok=True)

# Remove files left behind by a previous run (sub-directories are left alone).
# A plain loop is used instead of a side-effect list comprehension so the cell
# neither builds a throwaway list nor echoes it as output.
for leftover in tmp_dir.glob("*"):
    if leftover.is_file():
        leftover.unlink()

[]

In [4]:
# Bucket the images are read from, and bucket the parquet encodings are written to.
SOURCE_BUCKET = "apoor-clean-movie-stills"
DEST_BUCKET = "apoor-vgg-movie-vecs"

# Single S3 client shared by the listing, download and upload helpers.
s3 = boto3.client(service_name="s3")

In [5]:
# Number of keys requested per S3 listing call, and therefore the number of
# images handled per batch in the processing loop.
batch_size = 100 # S3 returns at most 1,000 keys per listing call

In [6]:
# Shape of a single model input (height, width, channels), without the batch axis.
input_shape = (300, 300, 3)

# VGG16 with ImageNet weights and no classifier head (include_top=False).
vgg16 = tf.keras.applications.VGG16(
    weights="imagenet",
    include_top=False,
    input_shape=input_shape,
)
vgg16.summary()

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 300, 300, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 300, 300, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 300, 300, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 150, 150, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 150, 150, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 150, 150, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 75, 75, 128)       0     

In [7]:
def iter_keys(bucket: str, batch_size: int = 1_000):
    """Yield the object keys of an S3 bucket, one page at a time.

    Pages through ``bucket`` with ``list_objects_v2``, resuming each request
    after the last key returned by the previous one.

    Args:
        bucket: Name of the bucket to list.
        batch_size: Maximum number of keys requested per page (sent to S3 as
            ``MaxKeys``).

    Yields:
        A non-empty list of key strings for each page. Nothing is yielded for
        an empty bucket.
    """
    last_key = ""
    while True:
        resp = s3.list_objects_v2(
            Bucket=bucket,
            MaxKeys=batch_size,
            StartAfter=last_key
        )
        # The response has no "Contents" entry when the page holds no objects.
        keys = [c["Key"] for c in resp.get("Contents", [])]
        if not keys:
            return
        yield keys
        if not resp.get("IsTruncated"):
            return
        last_key = keys[-1]
            
            
def download_object(bucket: str, key: str, tmp_dir: Path) -> Path:
    """Download one S3 object into ``tmp_dir`` and return its local path.

    The object is written to ``tmp_dir / key``. A key containing "/" maps to
    a nested path, so missing parent directories are created first.

    Args:
        bucket: Name of the bucket holding the object.
        key: Key of the object to fetch.
        tmp_dir: Local directory to write into.

    Returns:
        Path of the file that was written.
    """
    res = s3.get_object(Bucket=bucket, Key=key)
    filename = tmp_dir / key
    # Without this, a key such as "film/still.jpg" fails with
    # FileNotFoundError because "film/" does not exist locally.
    filename.parent.mkdir(parents=True, exist_ok=True)
    with open(filename, "wb") as f:
        f.write(res["Body"].read())
    return filename
    
    
def batch_download(bucket: str, keys: [str], tmp_dir: Path) -> [Path]:
    """Download several objects from one bucket concurrently.

    Each key is fetched with ``download_object`` on a thread pool created
    with default settings. The returned paths are in the same order as
    ``keys``.
    """
    with ThreadPoolExecutor() as pool:
        local_paths = pool.map(
            lambda key: download_object(bucket, key, tmp_dir),
            keys,
        )
        return [path for path in local_paths]

    
def clean_tmp_files(paths: [Path]):
    """Delete every file listed in ``paths``.

    Entries are wrapped in ``Path`` before unlinking, so strings are accepted
    too. A path that cannot be unlinked raises and stops the cleanup.
    """
    for entry in paths:
        Path(entry).unlink()
    
    
def load_image(path: Path) -> np.ndarray:
    """Read an image file and return its pixel data as a NumPy array.

    Args:
        path: Location of the image file.

    Returns:
        The array produced by ``np.array`` on the opened image.
    """
    # Image.open is lazy; converting inside the context manager forces the
    # pixel data to be read and then releases the file handle deterministically.
    with Image.open(path) as img:
        return np.array(img)

    
def load_images(paths: [Path]) -> np.ndarray:
    """Load every image in ``paths`` and stack them along a new leading axis.

    The images must all have the same shape; the result has one more
    dimension than a single image, indexed first by position in ``paths``.
    """
    frames = [load_image(image_path) for image_path in paths]
    return np.stack(frames, axis=0)


def format_input(data: np.ndarray) -> np.ndarray:
    """Convert pixel values to float32 and divide them by 255.

    Inputs in [0, 255] therefore end up in [0.0, 1.0]. A new array is
    returned; ``data`` itself is not modified.

    NOTE(review): Keras' ImageNet VGG16 weights are normally paired with
    ``tf.keras.applications.vgg16.preprocess_input`` rather than a plain
    [0, 1] rescale -- confirm this scaling is intentional.
    """
    as_float = data.astype("float32")
    return np.divide(as_float, 255)


def vgg_process(data: np.ndarray):
    """Run ``data`` through the module-level ``vgg16`` model.

    Returns whatever ``vgg16.predict`` produces for the batch.
    """
    return vgg16.predict(data)


def format_output(data: np.ndarray) -> np.ndarray:
    """Flatten everything after the first axis.

    An array of shape ``(n, ...)`` becomes ``(n, k)``, where ``k`` is the
    product of the remaining dimensions.
    """
    n_rows, *_feature_dims = data.shape
    return np.reshape(data, (n_rows, -1))


def make_arrow_table(row: np.ndarray, filename: Path) -> pa.Table:
    """Build a one-column Arrow table holding ``row``.

    The column is named after the stem of ``filename`` (its final path
    component without the last suffix).
    """
    column_name = Path(filename).stem
    columns = {column_name: row}
    return pa.table(columns)


def write_parquet(row: np.ndarray, filename: Path) -> Path:
    """Write ``row`` to a parquet file stored next to ``filename``.

    The output path is ``filename`` with its suffix replaced by
    ``.parquet``; that path is returned.
    """
    target = filename.with_suffix(".parquet")
    pq.write_table(make_arrow_table(row, filename), target)
    return target


def write_parquets(data: np.ndarray, filenames: [Path]) -> [Path]:
    """Write one parquet file per row of ``data``.

    Rows and filenames are paired positionally (pairing stops at the shorter
    of the two). Returns the parquet paths in the same order.
    """
    written = []
    for row, source_path in zip(data, filenames):
        written.append(write_parquet(row, source_path))
    return written


def upload_parquet_files(bucket: str, filenames: [Path]):
    """Upload each local file to ``bucket``, one after another.

    The object key is the file's base name, so any local directory part of
    the path is dropped.
    """
    for local_path in filenames:
        s3.upload_file(
            Filename=str(local_path),
            Bucket=bucket,
            Key=local_path.name,
        )

In [None]:
# Process the source bucket one batch at a time:
# download -> load -> encode -> write parquet -> upload -> clean up.
start_time = dt.datetime.now()
print(f"START TIME: {start_time}")
print(f"Loading batches of {batch_size:,d} images.\n")

for i, image_keys in enumerate(iter_keys(SOURCE_BUCKET, batch_size)):
    print(f"[{i:4,d}] Batch starting.")
    print("> Downloading images...")
    # `s` is reused as the start mark for each timed step in the batch.
    s = dt.datetime.now()
    image_paths = batch_download(SOURCE_BUCKET, image_keys, tmp_dir)
    print(f"  TIME TO COMPLETE: {dt.datetime.now() - s}")
    print("> Loading into array...")
    input_data = load_images(image_paths)
    input_data = format_input(input_data)
    print("> Removing local image files...")
    # The pixel data is already in memory, so the image files can go now.
    # `image_paths` is still needed below, but only to derive output names.
    clean_tmp_files(image_paths)
    print("> VGG encoding images...")
    s = dt.datetime.now()
    encoding = vgg_process(input_data)
    print(f"  TIME TO COMPLETE: {dt.datetime.now() - s}")
    # Flatten each image's encoding into a single row.
    output_data = format_output(encoding)
    print("> Saving to parquet...")
    # One parquet file per image, named after the image with a .parquet suffix.
    parquet_paths = write_parquets(output_data, image_paths)
    print("> Uploading encodings...")
    s = dt.datetime.now()
    upload_parquet_files(DEST_BUCKET, parquet_paths)
    print(f"  TIME TO COMPLETE: {dt.datetime.now() - s}")
    print("> Removing local parquet files...")
    clean_tmp_files(parquet_paths)
    print("> Complete.")
    print("="*70)

print(f"\nFULL TIME TO COMPLETE: {dt.datetime.now() - start_time}")