In [None]:
#@title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the "License")

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Storing training data in Firestore

This notebook presupposes a slightly unconventional data collection technique: scraping data from an online source (e.g. Reddit) and using that data to train a Vertex AI model. Here, you will collect fictional maps, used for virtual role-playing games, to train a model that can detect gridlines on said maps.

This notebook covers the following steps:

1. Collecting images from Reddit
1. Storing the images in Cloud Storage
1. Inferring training data from the images
1. Storing that training data in a Firestore collection

## Before you begin

You'll need to have a [Reddit API key](https://www.reddit.com/wiki/api/) to access Reddit programmatically. Once you have your API key, you must store it somewhere safe. We recommend storing your API key as a JSON-formatted string in Cloud Secret Manager.

### Store your API key in Cloud Secret Manager

Although you can [create a new secret in Cloud Secret Manager programmatically](https://cloud.google.com/secret-manager/docs/creating-and-accessing-secrets#create), in this notebook you must create it using the Cloud Console.

To create a new secret in the Cloud Console, do the following:

  1. Open the [Cloud Console](https://console.cloud.google.com/security/secret-manager).
  1. Click **Create secret**.
  1. In the **Create secret** page, do the following:
     
     + Give your secret a memorable name. This notebook uses the Reddit API, so the name of the secret
       is `reddit-api-key`.
     + Upload the credentials file. In this example, the `client_id`, `secret`, and `user_agent` credentials
       provided by Reddit are stored as JSON in a single file.
  
  1. Click **Create secret** at the bottom of the page.
  

### Set IAM permissions

[REVIEWERS: this may only be true for accessing Secret Manager from a pipeline]

When you run a notebook on Vertex Workbench, the notebook runs in a Compute Engine context that has its own service account. You will need to give your service account IAM permissions to access Secret Manager before you can use it (in a pipeline).



### Enable the Cloud resources

For this notebook, you must have a Google Cloud project with the following resources:

+ A Cloud Storage bucket
+ The following APIs enabled:
  - Cloud Firestore
  - Vertex AI
  - Storage
  - Secret Manager

In [1]:
# Get your GCP project id from gcloud
shell_output=!gcloud config list --format 'value(core.project)' 2>/dev/null
PROJECT_ID=shell_output[0]
print("Project ID: ", PROJECT_ID)

Project ID:  fantasymaps-334622


In [2]:
BUCKET = "fantasy-maps" # Google Cloud Storage bucket
COLLECTION_NAME = "FantasyMapsTest" # Firestore collection name

### Install the required Python libraries

In [3]:
! rm -rfd requirements.txt

In [4]:
%%writefile requirements.txt
google-cloud-firestore
google-cloud-storage
google-cloud-secret-manager
praw
pandas
numpy
spacy
pillow

Writing requirements.txt


In [5]:
! pip install -r requirements.txt



We will also use a simple natural language parsing library to analyze posts. For this use case, we'll use the open source library [spaCy](https://spacy.io). spaCy requires that a language model be downloaded before it can be used.

In [6]:
! python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.4.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl (12.8 MB)
     |████████████████████████████████| 12.8 MB 2.0 MB/s            
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


## Get Reddit API key from Secret Manager

The important bit about an API key is that it should remain _secret_. You don't want to have it embedded in a notebook where anyone can see it!

The next step is to make sure that you can access your Reddit API key programmatically from the notebook. We'll use the API key stored in Secret Manager to make calls to Reddit, both in the notebook and later from a Vertex AI pipeline.

This notebook assumes that your Reddit API key is stored as a JSON-formatted string, with the following fields:

```
{
    "secret": "YOUR_SECRET",
    "client_id": "YOUR_CLIENT_ID",
    "user_agent": "YOUR_USER_AGENT",
    "user_name": "YOUR_REDDIT_USER_NAME"
}
```

In [88]:
def get_reddit_credentials(project_id):
    """Gets the Reddit API key out of Secrets Manager
    
    Arguments:
        project_id (str): the current project ID
    
    Returns:
        JSON object (dict)
    """
    from google.cloud import secretmanager
    import json

    client = secretmanager.SecretManagerServiceClient()

    secret_resource_name = f"projects/{project_id}/secrets/reddit-api-key/versions/1"
    response = client.access_secret_version(request={"name": secret_resource_name})

    payload = response.payload.data.decode("UTF-8")
    reddit_key_json = json.loads(payload)

    return reddit_key_json

In [77]:
reddit_key_json = get_reddit_credentials(PROJECT_ID)

## Query data (posts) on Reddit

Now that we have our API key ready for use, we can query Reddit for our data! In the next cell, we will read the top 100 "hot" posts from a subreddit.

For our use-case, we want to check the posts to see whether: 1) they have an image associated with them; and 2) the title gives us some clues as to the contents (e.g. columns and rows) contained in the image.

To make data visualization, we're going to store the data from Reddit in a `pandas` dataframe.

In [85]:
def get_reddit_posts(reddit_credentials, subreddit_name, limit):
    """Gets posts from a subreddit.
    
    Arguments:
        reddit_credentials (dict): a dictionary with client_id, secret, and user_agent
        subreddit_name (str): the name of the subreddit to scrape posts from
        limit (int): the maximum number of posts to grab
    
    Returns:
        List of Reddit API objects
    """
    import praw

    reddit = praw.Reddit(client_id=reddit_credentials["client_id"], 
                 client_secret=reddit_credentials["secret"],
                 user_agent=reddit_credentials["user_agent"])
    
    return reddit.subreddit(subreddit_name).hot(limit=limit)

In [86]:
def convert_posts_to_dataframe(posts, columns):
    import numpy as np
    import pandas as pd
    
    filtered_posts = [[s.title, s.selftext, s.id, s.url] for s in posts]
    filtered_posts = np.array(filtered_posts)
    reddit_posts_df = pd.DataFrame(filtered_posts,
                               columns=columns)

    return reddit_posts_df

In [87]:
columns = ['Title', 'Post', 'ID', 'URL']

subreddit_name = "battlemaps"
posts = get_reddit_posts(reddit_credentials=reddit_key_json,
                                subreddit_name=subreddit_name, limit=100)
reddit_df = convert_posts_to_dataframe(posts, columns)

Now that we have the top 100 "hot" posts from the subreddit, we're going to filter for only the posts that we want. Again, our criteria are: 1) must have an image; 2) the title must have the grid dimensions of the image.

In [10]:
import re
jpg_df = reddit_posts_df[(reddit_posts_df["URL"].str.contains("jpg")) &
                         (reddit_posts_df["Title"].str.contains(pat = "\d+x\d"))]

jpg_df.head(10)

Unnamed: 0,Title,Post,ID,URL
3,We've just released the third part of Falthrin...,,zlq4cn,https://i.redd.it/phzh9ina1v5a1.jpg
4,Biolab Dungeon [32x22][70px/square],,zlqn3b,https://i.redd.it/x2bilzsr5v5a1.jpg
7,Here's our latest Czepeku Battlemap! A bit of ...,,zlw1u2,https://i.redd.it/d3y8vpo69w5a1.jpg
10,A path through snowy hills [25x30],,zlowht,https://i.redd.it/bd5ui60bqu5a1.jpg
14,Wicker Trail Bridge [30x20] [4200x2800px],,zlvmhw,https://i.redd.it/0uramedf2w5a1.jpg
15,"Hey Folks, here's my brand-new battlemap! Smug...",,zlxrd8,https://i.redd.it/4u0rfn0zkw5a1.jpg
16,Underground Fighting Pit [22x29],,zlvb2s,https://i.redd.it/trimemjs2w5a1.jpg
18,Processing [30x30],,zlo6wu,https://i.redd.it/llq7uj3wiu5a1.jpg
19,Under the Vanthampur Villa - Baldur's Gate: De...,,zlkeu2,https://i.redd.it/jwygttnwct5a1.jpg
20,Gothic Cathedral Ritual Illustration and Battl...,,zlr6r7,https://i.redd.it/x026ghq7av5a1.jpg


## Process the images and their metadata

In this next step, we will process all of the Reddit posts with images posts:

1. Download the image itself
2. Parse each image's metadata
3. (If needed) Split the image into smaller images
4. Store the metadata in Firestore
5. Save the images on Cloud Storage

### Download the image

First we're going to download the image locally. We'll need a meaningful filename to save the image.

In [11]:
def make_nice_filename(name):
    """Converts Reddit post title into a meaningful(ish) filename.
    
    Arguments:
        name (str): title of the post
    
    Returns:
        String. Format is `<adj.>-<nouns>.<cols>x<rows>.jpg`
    """
    import re
    
    dims = re.findall("\d+x\d+", name)
    if len(dims) is 0:
        return ""
    
    dims = dims[0].split("x")
    if len(dims) is not 2:
        return ""
    
    tokens = get_tokens(name)
    new_name = name.lower()[:30]
    
    if len(tokens) > 0:
        tokens = tokens[:6] # Arbitrarily keep new names to six words or less
        new_name = "_".join(tokens)
    
    return f"{new_name}.{dims[0]}x{dims[1]}.jpg"

In [12]:
def get_tokens(title):
    """Analyzes a post for nouns, proper nouns, and adjectives.
    
    Arguments:
        title (str): title of the post
    
    Returns:
        List of string. Words to use in a filename.    
    """
    import spacy

    POS = ["PROPN", "NOUN", "ADJ"]
    
    spacy.prefer_gpu()
    nlp = spacy.load("en_core_web_sm")
    
    words = []
    
    tokens = nlp(title)
    for t in tokens:
        pos = t.pos_
        
        if pos in POS:
            words.append(t.text.lower())
    
    return words 

One last thing: we want to avoid downloading the same image more than once. We'll need to compare the images programmatically to verify that each image is unique.

The easiest way to do this will be to reduce each image to a unique hash value and then ensure that we never have two copies of the same hash value. For the sake of simplicity, we'll use these hash values as the unique ID for each image.

In [13]:
def convert_image_to_hash(content, hashes):
    """Convert image data to hash value (str).
    
    Arguments:
        content (byte array): the image
        hashes (list): a list of hashes from converted strings
    
    Return:
        Bool. Indicates whether the process was success.
    """
    import hashlib
    
    sha1 = hashlib.sha1()
    jpg_hash = sha1.update(content)
    jpg_hash = sha1.hexdigest()
        
    if jpg_hash in hashes:
        hashes.append("")
        return False

    hashes.append(jpg_hash)
    return True

In [14]:
def download_image_local(url, path, hashes):
    """Download an image from the internet to local file system.
    
    Arguments:
        url (str): the image to download
        path (str): the local path to save the image.
        hashes (list): the list of UIDs for downloaded images
    
    Returns:
        Bool. Indicates whether downloading the image was successful.
    """
    import requests
    import shutil
    
    r = requests.get(url, stream=True)
    if r.status_code == 200:
        r.raw.decode_content = True
        
        is_unique = convert_image_to_hash(r.content, hashes)
        if not is_unique:
            return False
        
        with open(path, 'wb') as f:
            f.write(r.content)
    else:
        return False
    
    return True

In [15]:
import os

local_reddit_data_dir = "reddit_maps_data"

if not os.path.exists(local_reddit_data_dir):
    os.mkdir(local_reddit_data_dir)

paths = []
hashes = []
    
for _, row, in jpg_df.head(50).iterrows():
    url = row["URL"]
    filename = make_nice_filename(row["Title"])
    path = f"{local_reddit_data_dir}/{filename}"
    is_downloaded = download_image_local(url, path, hashes)
    
    if is_downloaded:
        paths.append(path)
    else:
        paths.append("")
    
jpg_df = jpg_df.assign(Path=paths, UID=hashes)
jpg_df.head(10)

Unnamed: 0,Title,Post,ID,URL,Path,UID
3,We've just released the third part of Falthrin...,,zlq4cn,https://i.redd.it/phzh9ina1v5a1.jpg,reddit_maps_data/third_part_falthringor_ancien...,455abfb0550a324cc42663c426654f99ec41b40c
4,Biolab Dungeon [32x22][70px/square],,zlqn3b,https://i.redd.it/x2bilzsr5v5a1.jpg,reddit_maps_data/biolab_dungeon_square.32x22.jpg,f00be2ed7b19b8d0f44915a1886437193aa224aa
7,Here's our latest Czepeku Battlemap! A bit of ...,,zlw1u2,https://i.redd.it/d3y8vpo69w5a1.jpg,reddit_maps_data/latest_czepeku_battlemap_bit_...,0bf1a303b4885d91f7a2cd72e27fc247595ad0e4
10,A path through snowy hills [25x30],,zlowht,https://i.redd.it/bd5ui60bqu5a1.jpg,reddit_maps_data/path_snowy_hills.25x30.jpg,b1fe854b9c658cf9312319a447b594743513499b
14,Wicker Trail Bridge [30x20] [4200x2800px],,zlvmhw,https://i.redd.it/0uramedf2w5a1.jpg,reddit_maps_data/wicker_trail_bridge.30x20.jpg,5bc994d7d38c1fc4654c58666d674200f731986b
15,"Hey Folks, here's my brand-new battlemap! Smug...",,zlxrd8,https://i.redd.it/4u0rfn0zkw5a1.jpg,reddit_maps_data/hey_folks_brand_new_battlemap...,34362be27ada680a58e42d26758bf08c01d3460a
16,Underground Fighting Pit [22x29],,zlvb2s,https://i.redd.it/trimemjs2w5a1.jpg,reddit_maps_data/underground_fighting_pit.22x2...,f358f2900db41c68c9c24ae9364a27d406bad1d2
18,Processing [30x30],,zlo6wu,https://i.redd.it/llq7uj3wiu5a1.jpg,reddit_maps_data/processing.30x30.jpg,e086c1cf420e27448bc1a45147b5c43df4b3d8d0
19,Under the Vanthampur Villa - Baldur's Gate: De...,,zlkeu2,https://i.redd.it/jwygttnwct5a1.jpg,reddit_maps_data/vanthampur_villa_baldur_gate_...,3294f8b4cc574e5f67bac75caeb49c0cc22745fb
20,Gothic Cathedral Ritual Illustration and Battl...,,zlr6r7,https://i.redd.it/x026ghq7av5a1.jpg,reddit_maps_data/gothic_cathedral_ritual_illus...,9105ca984984d3c3ffd61d4e1420ec649416c582


### Parse the image for metadata

All we know from the Reddit API is that these posts have JPGs associated with them and that they contain a substring in the format "NNxNN." However, we need more than just images and rough columns and rows for training a Vertex AI AutoML image object detection model. We would even need more data just to use these images in a VTT app.

To get valid training data and VTT data, we need to make some inferences about the images based upon the data that we have (or can get). The data we have are the number of columns and rows stated in the post's title (granted, these are not always accurate). The data we can get is the image's width and height. From these four data points, and assuming that they are accurate and that all cells in the map are uniform, we can infer the width and height of cells in the map.

Using the cell width and height, we can compute the rest of the data required for both an image object detection model and the data needed for a VTT app.

In [16]:
def get_image_width_and_height(path):
    """Open the image and get the image's height and width in pixels.
    
    Arguments:
        path (str):
        
    Returns:
        Tuple of width, height
    """
    from PIL import Image
    import math
    
    img = Image.open(path)
    w, h = img.size
    
    return (math.floor(w), math.floor(h))

#### Virtual tabletop (VTT) data

The easiest for us to compute is the VTT data. This data provides us with the `cellWidth` and `cellHeight` data that will allow us to complete the ML training data.

The JSON structure of VTT data is:

```json
{
    "imageHeight": ##,
    "imageWidth": ##,
    "cellHeight": ##,
    "cellWidth": ##,
    "cellOffsetY": ##,
    "cellOffsetX": ##
}

```


In [31]:
def compute_vtt_data(width, height, columns, rows):
    """Calculate the VTT values for the image.
    
    Arguments:
        width (int):
        height (int):
        columns (int):
        rows (int): 
    Returns:
        Dict.
    """
    
    return {
        "cellsOffsetX": 0, # Assumes no offset
        "cellsOffsetY": 0, # Assumes no offset
        "imageWidth": int(width),
        "imageHeight": int(height),
        "cellWidth": int(width / columns),
        "cellHeight": int(height / rows)
    }

#### Image object detection training data

AutoML image object detection on Vertex AI requires a JSONL file with information for the training data. Each line in the JSONL file needs to contain: the Cloud Storage URI of the image; and the bounding boxes of the objects (cells) that we want to train the model to identify on the image.

The structure of the JSON data in the JSONL file is:

```json
{
    "imageGcsUri": "URI",
    "boundingBoxAnnotations": {
        "displayName": "LABEL_NAME",
        "xMin": ##,
        "xMax": ##,
        "yMin": ##,
        "yMax": ##,
    }
}
```

For each bounding box, we need to provide a percentage value that expresses the vertices of the bounding box as a set of x and y pairs. Also, each bounding box needs to be given a label for that bounding box; all of our bounding boxes are "cells" so each one gets the label `cell`.

**Note**: A training image in Vertex AI can only have at most 500 bounding boxes; many fantasy maps have many more than 500 cells. So that we can use the most of the training data, we will split too large images until smaller images, or "shards", and use the shards  for training.

You can read more about how to format your training manifest JSONL file in [the documentation](https://cloud.google.com/vertex-ai/docs/datasets/prepare-image#object-detection).

In [32]:
def compute_shard_coordinates(width, height, cell_width,
                              cell_height, columns, rows):
    """Converts image data into 1,or more shards.
    
    Arguments:
        width (int):
        height (int):
        cell_width (int):
        cell_height (int):
        columns (int):
        rows (int):
        
    Returns:
        List of tuples of (xMin, yMin, xMax, yMax, columns, rows)
        TODO: convert to pd.Series
    """
    import math
    import pandas as pd
    
    total_cells = columns * rows
    if total_cells <= 500:
        return
    
    # Assume that a perfectly square map that approaches 500 cells is 22 cols by 22 rows.
    # Cut an image into as many 22x22 shards as possible
    SQRT = 22
    
    h_shards = math.floor(columns / SQRT)
    h_rem = columns % SQRT
    v_shards = math.floor(rows / SQRT)
    v_rem = rows % SQRT
    shard_columns = shard_rows = SQRT
    
    # Edge case 1: we have a narrow width (portrait-oriented) map
    if h_shards == 0:
        h_shards = 1
        h_rem = 0
        shard_columns = columns
    
    # Edge case 2: we have a short height (landscape-oriented) map
    if v_shards == 0:
        v_shards = 1
        v_rem = 0
        shard_rows = rows
    
    shards = []
    curr_min_x = 0
    curr_min_y = 0
    for _ in range(h_shards):
        max_x = (cell_width * shard_columns) + curr_min_x
        if max_x > width:
            max_x = width
        for _ in range(v_shards):
            max_y = (cell_height * shard_rows) + curr_min_y
            if max_y > height:
                max_y = height
            
            shards.append((curr_min_x, curr_min_y, max_x, max_y, shard_columns, shard_rows))
            curr_min_y = max_y
            
        curr_min_y = 0
        curr_min_x = max_x
    
    # Get the right-side remainder
    curr_min_x = width - (h_rem * cell_width)
    curr_min_y = 0
    for _ in range(v_shards):
        max_y = (cell_height * shard_rows) + curr_min_y
        if max_y > height:
            max_y = height
        shards.append((curr_min_x, curr_min_y, width, max_y, h_rem, shard_rows))
        curr_min_y = max_y
    
    # Get the bottom-side remainder
    curr_min_y = height - (v_rem * cell_height)
    curr_min_x = 0
    for _ in range(h_shards):
        max_x = (cell_width * shard_columns) + curr_min_x
        if max_x > width:
            max_x = width
        shards.append((curr_min_x, curr_min_y, max_x, height, shard_columns, v_rem))
        curr_min_x = max_x
            
    return shards

In [33]:
def create_shard(x_min, y_min, x_max, y_max, cols, rows, img_path, parent_id):
    """Crops and saves an image.
    
    Arguments:
        x_min (int): the left-most point to crop, relative to the parent image
        y_min (int): the top-most point to crop, relative to the parent image
        x_max (int): the right-most point, relative to the parent image
        y_max (int): the bottom-most poinst, relative to the parent image
        cols (cols): the grid columns in this shard
        rows (rows): the grid rows in this shard
        img_path (str): the parent image's local path
        parent_id (str): the parent image's UID
    
    Returns:
        DataFrame with local path, UID, width, height, columns, and rows
    
    """
    try:
        from PIL import Image
        import math

        img = Image.open(img_path)
        shard = img.crop((int(x_min), int(y_min), int(x_max), int(y_max)))

        # Get new filepath name
        s_path = create_shard_path(img_path, x_min, y_min, cols, rows)

        # Get new UID
        hashes = []
        convert_image_to_hash(shard.tobytes(), hashes)

        shard.save(s_path)
        
        d = {
            "Width": math.floor(x_max - x_min),
            "Height": math.floor(y_max - y_min),
            "Columns": cols,
            "Rows": rows,
            "UID": hashes[0],
            "Path": s_path,
            "IsShard": True,
            "Parent": parent_id
        }
        
    except SystemError as e:
        print(f"Error: {img_path}, bounds: {x_max},{y_max}")
        return None
    
    return pd.DataFrame(data=d, index=[0])

In [34]:
def create_shard_path(path, x_min, y_min, cols, rows):
    """Convert an image path string to new string.
    
    Assumes the image path is of the format:
        <folder>/<name>.<cols>x<rows>.jpg
    
    Arguments:
        path (str):
        x_min (int):
        y_min (int):
        cols (int):
        rows (int):
        
    Returns:
        String. New image path.
    """
    import math
    
    paths = path.split(".")
    paths[-2] = f"{math.floor(x_min)}_{math.floor(y_min)}.{cols}x{rows}"
    s_path = ".".join(paths)
    return s_path

In [35]:
def compute_bboxes(*, dataframe=None, series=None, cell_width=0, cell_height=0):
    """Determines bounding boxes for image object detection.
    
    Arguments:
        dataframe (pandas.Dataframe): A DataFrame with Height, Width, Columns, and Rows
        series (pandas.Series): A Series with Height, Width, Columns, and Rows
        cell_width (int):
        cell_height (int):
    
    Returns:
        List of dict.
    """
    bboxes = []
    try:
        if dataframe is not None:
            width = dataframe.iloc[0]["Width"]
            height = dataframe.iloc[0]["Height"]
            columns = dataframe.iloc[0]["Columns"]
            rows = dataframe.iloc[0]["Rows"]
        elif series is not None:
            width = series["Width"]
            height = series["Height"]
            columns = series["Columns"]
            rows = series["Rows"]
        else:
            return bboxes

        BORDER = 1 # 1px border around the outside of the cell
        LABEL = "cell"

        curr_x = cell_width
        while curr_x < width:
            curr_y = cell_height
            while curr_y < height:
                x_min = (curr_x - BORDER) / width
                y_min = (curr_y - BORDER) / height
                x_max = (curr_x + cell_width + BORDER) / width
                y_max = (curr_y + cell_height + BORDER) / height
                bboxes.append({
                    "xMin": x_min,
                    "xMax": x_max,
                    "yMin": y_min,
                    "yMax": y_max,
                    "displayName": LABEL
                })
                curr_y = curr_y + cell_height
            curr_x = curr_x + cell_width
    except:
        print(f"Error: {dataframe}")
        
    return bboxes

With all of the helper functions in place, we can start processing each image to extract the VTT and image object detection data. We'll need a datastructure to store all of the newly computed data--luckily we can simply append this new data to each `Series` object in the `jpg_df` pandas DataFrame.

In [38]:
import json

shards_df = pd.DataFrame()

for i, row in jpg_df.iterrows():
    local_path = row["Path"]
    
    # Get width & height for original image
    w, h = get_image_width_and_height(local_path)
    
    jpg_df.at[i, "Width"] = w
    jpg_df.at[i, "Height"] = h
    
    # Get columns & rows for original image, based upon the name.
    paths = local_path.split(".")
    dims = paths[-2]
    cols, rows = dims.split("x")
    
    cols = int(cols)
    rows = int(rows)
    
    jpg_df.at[i, "Columns"] = cols
    jpg_df.at[i, "Rows"] = rows
    
    # Compute the vtt data for the image
    vtt = compute_vtt_data(width=w, height=h, columns=cols, rows=rows)
    
    # Note: pandas has issues storing a dict in a cell
    jpg_df.at[i, "VTT"] = json.dumps(vtt) 
    
    # If image doesn't need to be sharded, simply compute and continue
    if (cols * rows) <= 500:
        bboxes = compute_bboxes(series=row,
                                cell_width=vtt["cellWidth"],
                                cell_height=vtt["cellHeight"])
        jpg_df.at[i, "BBoxes"] = json.dumps({ "bboxes": bboxes })
        continue
                            
    # Compute the number of shards
    shards = compute_shard_coordinates(width=w, height=h, columns=cols, rows=rows,
                                       cell_width=vtt["cellWidth"], cell_height=vtt["cellHeight"])
            
    for shard in shards:
        shard_df = create_shard(x_min=shard[0], y_min=shard[1], x_max=shard[2],
                                   y_max=shard[3], cols=shard[4], rows=shard[5],
                                   img_path=local_path, parent_id=row["UID"])

        if shard_df is None:
            continue
            
        s_vtt = vtt
        s_vtt["width"] = int(shard_df.iloc[0]["Width"])
        s_vtt["height"] = int(shard_df.iloc[0]["Height"])
        shard_df.at[0, "VTT"] = json.dumps(s_vtt)
        
        bboxes = compute_bboxes(dataframe=shard_df,
                                cell_width=vtt["cellWidth"],
                                cell_height=vtt["cellHeight"])

        shard_df.at[0, "BBoxes"] = json.dumps({ "bboxes": bboxes })
        shards_df = pd.concat([shards_df, shard_df])

Error: reddit_maps_data/third_part_falthringor_ancient_mountain_stronghold.44x68.jpg, bounds: 6160,3080
Error: reddit_maps_data/third_part_falthringor_ancient_mountain_stronghold.44x68.jpg, bounds: 6160,6160
Error: reddit_maps_data/third_part_falthringor_ancient_mountain_stronghold.44x68.jpg, bounds: 6160,9240
Error: reddit_maps_data/biolab_dungeon_square.32x22.jpg, bounds: 1540,1540
Error: reddit_maps_data/wicker_trail_bridge.30x20.jpg, bounds: 3080,2800
Error: reddit_maps_data/underground_fighting_pit.22x29.jpg, bounds: 3080,3080
Error: reddit_maps_data/gothic_cathedral_ritual_illustration_battlemap_44x63.44x63.jpg, bounds: 3168,1584
Error: reddit_maps_data/gothic_cathedral_ritual_illustration_battlemap_44x63.44x63.jpg, bounds: 3168,3168
Error: reddit_maps_data/mysterious_village_nahaut_strange_customs_creatures.31x44.jpg, bounds: 4092,8192
Error: reddit_maps_data/small_camp_side_forest_road_campers.20x40.jpg, bounds: 2800,3080
Error: reddit_maps_data/gothic_cathedral_interior_scene_

Finally, now that we have all of the VTT and image object detection metadata computed for the original images and/or shards, we can join the two dataframes together into one. As part of this process, we also want to reindex the resulting dataframe so that it uses the UIDs we calculated instead of the automatically generated indices.

In [39]:
complete_df = pd.concat([jpg_df, shards_df])
complete_df.set_index("UID", inplace=True)
complete_df.fillna("", inplace=True)
complete_df.head(10)

Unnamed: 0_level_0,Title,Post,ID,URL,Path,Width,Height,Columns,Rows,VTT,BBoxes,IsShard,Parent
UID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
455abfb0550a324cc42663c426654f99ec41b40c,We've just released the third part of Falthrin...,,zlq4cn,https://i.redd.it/phzh9ina1v5a1.jpg,reddit_maps_data/third_part_falthringor_ancien...,6160.0,9520.0,44.0,68.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
f00be2ed7b19b8d0f44915a1886437193aa224aa,Biolab Dungeon [32x22][70px/square],,zlqn3b,https://i.redd.it/x2bilzsr5v5a1.jpg,reddit_maps_data/biolab_dungeon_square.32x22.jpg,2240.0,1540.0,32.0,22.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
0bf1a303b4885d91f7a2cd72e27fc247595ad0e4,Here's our latest Czepeku Battlemap! A bit of ...,,zlw1u2,https://i.redd.it/d3y8vpo69w5a1.jpg,reddit_maps_data/latest_czepeku_battlemap_bit_...,6300.0,8400.0,45.0,60.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
b1fe854b9c658cf9312319a447b594743513499b,A path through snowy hills [25x30],,zlowht,https://i.redd.it/bd5ui60bqu5a1.jpg,reddit_maps_data/path_snowy_hills.25x30.jpg,1750.0,2100.0,25.0,30.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
5bc994d7d38c1fc4654c58666d674200f731986b,Wicker Trail Bridge [30x20] [4200x2800px],,zlvmhw,https://i.redd.it/0uramedf2w5a1.jpg,reddit_maps_data/wicker_trail_bridge.30x20.jpg,4200.0,2800.0,30.0,20.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
34362be27ada680a58e42d26758bf08c01d3460a,"Hey Folks, here's my brand-new battlemap! Smug...",,zlxrd8,https://i.redd.it/4u0rfn0zkw5a1.jpg,reddit_maps_data/hey_folks_brand_new_battlemap...,4200.0,5600.0,60.0,80.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
f358f2900db41c68c9c24ae9364a27d406bad1d2,Underground Fighting Pit [22x29],,zlvb2s,https://i.redd.it/trimemjs2w5a1.jpg,reddit_maps_data/underground_fighting_pit.22x2...,3080.0,4060.0,22.0,29.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
e086c1cf420e27448bc1a45147b5c43df4b3d8d0,Processing [30x30],,zlo6wu,https://i.redd.it/llq7uj3wiu5a1.jpg,reddit_maps_data/processing.30x30.jpg,3600.0,3600.0,30.0,30.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
3294f8b4cc574e5f67bac75caeb49c0cc22745fb,Under the Vanthampur Villa - Baldur's Gate: De...,,zlkeu2,https://i.redd.it/jwygttnwct5a1.jpg,reddit_maps_data/vanthampur_villa_baldur_gate_...,4096.0,2521.0,65.0,40.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,
9105ca984984d3c3ffd61d4e1420ec649416c582,Gothic Cathedral Ritual Illustration and Battl...,,zlr6r7,https://i.redd.it/x026ghq7av5a1.jpg,reddit_maps_data/gothic_cathedral_ritual_illus...,3168.0,4544.0,44.0,63.0,"{""cellsOffsetX"": 0, ""cellsOffsetY"": 0, ""imageW...",,,


## Store the images in Google Cloud Storage

Now that we've created the image shards, we can begin uploading the images to Google Cloud Storage. We'll need to have a Storage bucket already created for this next cell of code.

In [40]:
def store_image_gcs(*, project_id, series, bucket_name, prefix):
    """Copies a local image to Google Cloud Storage.
    
    Arguments:
        project_id (str): the Google Cloud Project ID to use
        series (pd.Series): a Pandas Series with "Path" column
        bucket_name (str): the Cloud Storage bucket to use
        prefix (str): the prefix or "folder" to use in the bucket

    Returns:
        String. The Cloud Storage URI of the image.
    """
    from google.cloud import storage
    
    storage_client = storage.Client(project=project_id)
    bucket = storage_client.bucket(bucket_name)
    
    local_path = series["Path"]
    file_name = local_path.split("/")[-1]
    img_gcs_uri = f"gs://{bucket_name}/{prefix}/{file_name}"
    blob_name = f"{prefix}/{file_name}"
            
    file_blob = bucket.blob(blob_name)           
    file_blob.upload_from_filename(local_path)
    
    return img_gcs_uri
    

In [41]:
for i, row in complete_df.iterrows():
    gcs_uri = store_image_gcs(project_id=PROJECT_ID, series=row,
                                bucket_name=BUCKET, prefix="FantasyMapsTest")
    complete_df.at[i, "GCS URI"] = gcs_uri

## Store the metadata in Firestore

Next we're going to store all of this metadata and URI in Firestore. The benefit of using Firestore is that the fields with JSON-formatted strings-`VTT` and `BBoxes` will automatically be translated into the correct document structure in Firestore after they've been upserted. 

In [42]:
def store_metadata_fs(*, project_id, series, collection_name, uid):
    """Upserts image metadata into a Firestore collection.
    
    Arguments:
        project_id (str): the Google Cloud project to store these in
        series (pd.Series): a Pandas series with the image's metadata
        collection_name (str): the Firestore collection to store the data in
    """
    from google.cloud import firestore
    
    client = firestore.Client(project=project_id)
    
    series_dict = series.to_dict()
    
    # clean up the data a little bit before upserting
    vtt = series["VTT"]
    if vtt is not "":
        vtt = json.loads(vtt)
        series_dict["VTT"] = vtt
        
    bboxes = series["BBoxes"]
    if bboxes is not "":
        bboxes = json.loads(bboxes)["bboxes"]
        series_dict["BBoxes"] = bboxes
    
    file_name = series["Path"].split("/")[-1]
    series_dict.pop("Path", None)
    series_dict["filename"] = file_name
    
    img_gcs_uri = series["GCS URI"]
    series_dict.pop("GCS URI", None)
    series_dict["gcsURI"] = img_gcs_uri
    
    # upsert the dict directly into Firestore!
    client.collection(collection_name).document(uid).set(series_dict)

Very, very last step: iterate over all the training data and store the metadata in the Firestore collection.

In [43]:
for uid, row in complete_df.iterrows():
    store_metadata_fs(project_id=PROJECT_ID, series=row,
                      collection_name=COLLECTION_NAME, uid=uid)
        

## Check the results of the metadata creation

Now that (hopefully) all of the image metadata has been added to the Firestore collection, we can review the data to ensure that it is correct.

To do this, we'll review the documents stored in the Firestore collection to verify that it has all the data we need--the VTT data, bounding boxes, and the GCS URI of the image.

In [54]:
from google.cloud import firestore

client = firestore.Client(project=PROJECT_ID)
collection = client.collection(COLLECTION_NAME)

docs = collection.where("BBoxes", "!=", "").select(field_paths=["gcsURI", "filename", "VTT", "Parent"]).stream()

With this Firestore query, we can verify the image metadata against the stored image in the Storage bucket. We'll first take the results of this query, compose it into a `pandas.DataFrame` object, and then print it out to the cell output. We can first take a look at the parent map (assuming that the map has been sharded) and then conclude whether the map and all its shards should be removed from the training set.

In [55]:
import json
import pandas as pd

docs_list = ((d.to_dict(), d.id) for d in docs)
docs_df = pd.DataFrame()
for i, d in enumerate(docs_list):
    d_dict = d[0]
    vtt = d_dict["VTT"]
    d_dict["VTT"] = json.dumps(vtt)
    d_dict["UID"] = d[1]
    docs_df = pd.concat([docs_df, pd.DataFrame(data=d_dict, index=[0])], ignore_index=True)

In [57]:
docs_df.set_index("UID", inplace=True)
check_set_df = docs_df[["filename", "gcsURI", "Parent"]]
check_set_df.head(10)

KeyError: "None of ['UID'] are in the columns"

Not everyone on Reddit follows the same conventions. Sometimes, there might be be a post where there are dimensions mentioned in the post (e.g. "50x40"), but the image doesn't actually have gridlines.

We shouldn't allow these images into the training and test dataset for our model. Unfortunately, we have to review the images that we've collected on GCS and then verify that they do (or don't!) have gridlines visually.


We'll start by printing out the entirety of our `DataFrame`.

In [59]:
pd.set_option("display.max_rows", 1000)
check_set_df.sort_values(by="filename", ascending=True, inplace=True)
display(check_set_df)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return func(*args, **kwargs)


Unnamed: 0_level_0,filename,gcsURI,Parent
UID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
a816b6eb31b71cad5b50531d6e18ef46cb451cbd,[oc]ravenhall 24x18.24x18.jpg,gs://fantasy-maps/FantasyMapsTest/[oc]ravenhal...,
65b9e0658f564d1211d5bf4d0804807a06a0035e,abandoned_iron_mine_entrance.0_0.22x22.jpg,gs://fantasy-maps/FantasyMapsTest/abandoned_ir...,7a903be0bad0fc00bfabbefd682b6eef23263b67
d353a848c5ae3f697159848ec23827c0866a7bb7,abandoned_iron_mine_entrance.0_1540.22x13.jpg,gs://fantasy-maps/FantasyMapsTest/abandoned_ir...,7a903be0bad0fc00bfabbefd682b6eef23263b67
aca8edc7ef105b84bc58e70b6cff882e6a6faec1,abandoned_iron_mine_entrance.1540_0.3x22.jpg,gs://fantasy-maps/FantasyMapsTest/abandoned_ir...,7a903be0bad0fc00bfabbefd682b6eef23263b67
270dbb699ddd302a6f61a09ad15ed37bf28bf019,archaeologist_camp.0_0.22x22.jpg,gs://fantasy-maps/FantasyMapsTest/archaeologis...,ddce59ef9a9fb681ca68cb69c598e74372e79fb5
c0e3fbc2e6940da1b3e12eb3bd119e93fd27b7d7,archaeologist_camp.0_2684.22x22.jpg,gs://fantasy-maps/FantasyMapsTest/archaeologis...,ddce59ef9a9fb681ca68cb69c598e74372e79fb5
32500efc1587db341a1a4156a1e3ba92711562c4,archaeologist_camp.0_5412.22x6.jpg,gs://fantasy-maps/FantasyMapsTest/archaeologis...,ddce59ef9a9fb681ca68cb69c598e74372e79fb5
a84e8e94e9f63ef0cb2a6e14cb3bf2a55c6a3012,archaeologist_camp.2684_0.22x22.jpg,gs://fantasy-maps/FantasyMapsTest/archaeologis...,ddce59ef9a9fb681ca68cb69c598e74372e79fb5
b784a85ff19f4903f09ec7cbc7b44a5fb989ec6f,archaeologist_camp.2684_2684.22x22.jpg,gs://fantasy-maps/FantasyMapsTest/archaeologis...,ddce59ef9a9fb681ca68cb69c598e74372e79fb5
2da9b61fd048ffddb9551d329f8c07c81e4d8eba,archaeologist_camp.2684_5412.22x6.jpg,gs://fantasy-maps/FantasyMapsTest/archaeologis...,ddce59ef9a9fb681ca68cb69c598e74372e79fb5


This final step of data prepartion is to mark all of the unusable images in the Firestore collection. Luckily, we can use the Google Cloud Console to view the contents of our Storage bucket. We can even add new fields to the documents in our Firestore collection!

![Storage user interface in the Cloud Console](resources/StorageUI.png)
_Figure. The Google Cloud Storage user interface, showing images in a bucket._

![Firestore user interface in the Cloud Console](resources/FirestoreUI.png)
_Figure. The Cloud Firestore user interface, showing a "Usable" field being added to a document._

For this very last data preparation step, we will visually inspect all of the "parent" images in the Cloud Storage bucket. We will then create a list unusable images, where we store the image's UID. Finally, we will do a bulk update job to Firestore, setting a `Usable` field on the images to `False`.

In [62]:
unusable_parent_images = [
    "d3ee0039aeff5c33de778c5adbdd000f21c0b4cd",
    "9a7e82433239b0087121f6fd31e133f5a94fa7dc",
    "fc70f330cd5a3aaa3be94e0d603dab2876f9fca1",
    "890c0d27318b0286aebf67c392fab286bdc4e7c5",
    "2f4f496466d6e9fb8ad9791ccb81f7c13fd407db",
    "4c0d3eb86ed599f496f7e30e18025021aebfe153",
    "4bf88b8acc669331a65465e8c4b37fd8b9495e4b",
    "b15fe8185c30b3e7800e42e280a3792998a0b55f",
    "5cd7bbd5882fe8b9a270be1aa911c0ba858e818d",
    "67925948371de53b58c09b32c97af60f72c58e0f",
    "a816b6eb31b71cad5b50531d6e18ef46cb451cbd",
    "f00be2ed7b19b8d0f44915a1886437193aa224aa",
    "890c0d27318b0286aebf67c392fab286bdc4e7c5",
    # Others ....
]

usable_parent_images = [
    "7a903be0bad0fc00bfabbefd682b6eef23263b67",
    "b1fe854b9c658cf9312319a447b594743513499b",
    "3a52781685d7aed530a685739b58341fafd2e721",
    "3d17d612a843af7a5ad1a2b2d5dcce29e67d367f",
    "aa33a7ed5c3c87147fd25dafe6c0a1d3eb29dbe5",
    "2ce9f80408137e36531581fb22ee3fe892f41f76",
    "e086c1cf420e27448bc1a45147b5c43df4b3d8d0",
    "5bc994d7d38c1fc4654c58666d674200f731986b",
    "820c3bbfe4d14694ecb729ce3f45b4bda031f61a",
    "2d323018c74db7e0432ff368283ea429f13bd36d",
    "343833ab1d8cd17dae6b702830864500d1e66e19",
    "c60b952d0c20a63dc263220ec5b49a54fd20d175",
    "8ac01d3a84bc548a8b243d08c6e031206f293908",
    "f00be2ed7b19b8d0f44915a1886437193aa224aa",
    "950390c88b7bd9d6f886e5f01bae9460c0aa407b",
    "3322c1b4795e4a9e4477feafd55685d647b4e29c",
    "34362be27ada680a58e42d26758bf08c01d3460a",
    "60d815ba2c1458c3a4039595a5ed723a7501a36e",
    "928b339ba222a3933fce4523f8033fa3eb7ed62f",
    "4931f0033f0ab217fd0fe2a2024d22a119782e2c",
    "83846604933daff160cefc28c2a828bd93a84e1a",
    "d2fe7281d8b1e043009033c957ca347847343e14",
    "c1552a0046ef4ece5d146544043edb2deb97f7a1",
    # And more ...
]

**Note**: If you're thinking that this manual process should be automatable--you're right! In another notebook, we will use a pre-trained version of our gridline-detecting model to accept or reject images.

In [63]:
usable_set = set(usable_parent_images)
unusable_set = set(unusable_parent_images)

In [67]:
from google.cloud import firestore

firestore_client = firestore.Client(project=PROJECT_ID)
bulkwriter = firestore_client.bulk_writer()
collection = firestore_client.collection(COLLECTION_NAME)

In [72]:
# Iterate over all of the metadata entries & images that we want to delete
unusable_shards = collection.where("Parent", "in", list(unusable_set)[:10]).stream()
for doc in unusable_shards:
    bulkwriter.delete(doc.reference)

bulkwriter.flush()

In [74]:
# Iterate over all of the good entries
subset_start_index = 0
while subset_start_index < len(usable_set):
    subset = list(usable_set)[subset_start_index:subset_start_index + 10]
    usable_shards = collection.where("Parent", "in", subset).stream()
    
    for doc in usable_shards:
        bulkwriter.update(doc.reference, { "Usable": True})
        
    subset_start_index = subset_start_index + 10

bulkwriter.flush()

In [75]:
bulkwriter.close()