In [1]:
import random
import collections
import os
import wandb
from dotenv import load_dotenv, find_dotenv
import json
import tensorflow as tf
from tqdm import tqdm
import numpy as np
from text_utils import generate_text_artifacts
from image_utils import load_image
from config import subset, batch_size, max_length, vocabulary_size, train_split
from config import WANDB_PROJECT, WANDB_ENTITY
from utils import save_to_pickle, load_from_pickle
import pandas as pd
# Resolve the .env file via find_dotenv() and load it into the process
# environment. The call is left as the cell's trailing expression so its
# return value is still displayed.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path)

True

In [2]:
# Start the W&B run under which the split artifact and tables are logged.
run_settings = {
    "project": WANDB_PROJECT,
    "entity": WANDB_ENTITY,
    "name": "log-test-train-split",
    "job_type": "data_process",
}
run = wandb.init(**run_settings)

[34m[1mwandb[0m: Currently logged in as: [33ma-sh0ts[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [3]:
# Pull the inputs of this step from previously logged artifacts: the image
# files themselves and the table pairing each image name with a caption.
images_art = run.use_artifact("images:latest")
images_path = images_art.download()

caption_table_art = run.use_artifact("image_caption_table:latest")
img_cap_table = caption_table_art.get("img_cap_table")
captions = img_cap_table.get_column("caption")
img_names = img_cap_table.get_column("name")

# One path per table row, pointing at the downloaded copy of the image.
img_name_vector = [
    os.path.join(images_path, file_name) for file_name in img_names
]

[34m[1mwandb[0m: Downloading large artifact images:latest, 12867.46MB. 82783 files... Done. 0:0:0
[34m[1mwandb[0m: Downloading large artifact image_caption_table:latest, 951.74MB. 6001 files... Done. 0:0:0


In [4]:
# Seed for the train/validation shuffle. A dedicated, seeded generator makes
# the split identical on every "Restart & Run All" without touching the
# global `random` state.
SPLIT_SEED = 42


def _flatten_split(img_paths, caps_by_img):
    """Expand image paths into parallel (image name, caption vector) lists.

    Args:
        img_paths: image paths belonging to one split.
        caps_by_img: mapping from image path to its list of caption vectors.

    Returns:
        A `(names, caps)` tuple of equally long lists. Each image's basename
        is repeated once per caption so that `names[i]` is the image that
        `caps[i]` describes.
    """
    names, caps = [], []
    for img_path in img_paths:
        img_caps = caps_by_img[img_path]
        names.extend([os.path.basename(img_path)] * len(img_caps))
        caps.extend(img_caps)
    return names, caps


caption_dataset = tf.data.Dataset.from_tensor_slices(captions)
# cap_vector holds each caption as a sequence of max_length entries, where
# the entry at a word position is that word's vocabulary index.
_, cap_vector, _, _ = generate_text_artifacts(
    caption_dataset, max_length=max_length, vocabulary_size=vocabulary_size, return_mapping=False)

# Group caption vectors by image so that all captions of one image end up in
# the same split.
img_to_cap_vector = collections.defaultdict(list)
for img, cap in zip(img_name_vector, cap_vector):
    img_to_cap_vector[img].append(cap)

# Shuffle the images reproducibly, then give the first `train_split` fraction
# to training and the remainder to validation.
img_keys = list(img_to_cap_vector.keys())
random.Random(SPLIT_SEED).shuffle(img_keys)

slice_index = int(len(img_keys) * train_split)
img_name_train_keys = img_keys[:slice_index]
img_name_val_keys = img_keys[slice_index:]

img_name_train, cap_train = _flatten_split(img_name_train_keys, img_to_cap_vector)
img_name_val, cap_val = _flatten_split(img_name_val_keys, img_to_cap_vector)

# Directory that collects the pickled split files for the artifact.
split_art_dir = os.path.join(".", "split_data")
os.makedirs(split_art_dir, exist_ok=True)

In [5]:
# Write every split list to its own pickle file inside split_art_dir.
split_pickles = {
    "img_name_train.pkl": img_name_train,
    "cap_train.pkl": cap_train,
    "img_name_val.pkl": img_name_val,
    "cap_val.pkl": cap_val,
}
for pickle_name, payload in split_pickles.items():
    save_to_pickle(payload, os.path.join(split_art_dir, pickle_name))

In [6]:
# Create the "split" dataset artifact and attach the directory holding the
# pickled split files.
split_art_spec = {"name": "split", "type": "dataset"}
split_art = wandb.Artifact(**split_art_spec)
split_art.add_dir(split_art_dir)

[34m[1mwandb[0m: Adding directory to artifact (.\split_data)... Done. 0.1s


In [7]:
def _as_wandb_images(file_names):
    """Return one wandb.Image per entry, resolving each name under images_path."""
    return [
        wandb.Image(os.path.join(images_path, file_name))
        for file_name in file_names
    ]


# One image object per (image, caption) row of each split.
train_wandb_imgs = _as_wandb_images(img_name_train)
val_wandb_imgs = _as_wandb_images(img_name_val)

In [8]:
def _build_img_cap_table(names, images, caps):
    """Build a wandb.Table with one row per (image name, image, caption).

    Args:
        names: image file names, one per row.
        images: wandb.Image objects, parallel to `names`.
        caps: caption vectors, parallel to `names`; every element of a
            caption must expose `.numpy()`.

    Returns:
        A wandb.Table with the columns "name", "image" and
        "word_index_0" ... "word_index_{max_length - 1}".
    """
    columns = ["name", "image", *[f"word_index_{i}" for i in range(max_length)]]
    table = wandb.Table(columns=columns)
    # zip() has no length, so hand tqdm the row count explicitly to get a
    # percentage and ETA instead of a bare counter.
    row_count = min(len(names), len(images), len(caps))
    for name, image, caption in tqdm(zip(names, images, caps), total=row_count):
        table.add_data(name, image, *[cap_index.numpy() for cap_index in caption])
    return table


train_img_cap_table = _build_img_cap_table(img_name_train, train_wandb_imgs, cap_train)
val_img_cap_table = _build_img_cap_table(img_name_val, val_wandb_imgs, cap_val)

24012it [03:55, 102.11it/s]
6001it [01:00, 98.99it/s] 


In [9]:
# Register the "name" column of each split table as a foreign key into the
# "name" column of the source image/caption table.
for split_table in (train_img_cap_table, val_img_cap_table):
    split_table.set_fk("name", img_cap_table, "name")

In [10]:
# Add both split tables to the artifact under their own entry names.
split_tables = {
    "train_img_cap_table": train_img_cap_table,
    "val_img_cap_table": val_img_cap_table,
}
for entry_name, split_table in split_tables.items():
    manifest_entry = split_art.add(split_table, entry_name)
# Trailing expression so the cell still displays the result of the last add.
manifest_entry

<ManifestEntry digest: mZ/rlU22nHZHBE95Ue54RQ==>

In [11]:
# Log both tables to the run, publish the split artifact, then close the run.
logged_tables = {
    "train_img_cap_table": train_img_cap_table,
    "val_img_cap_table": val_img_cap_table,
}
run.log(logged_tables)
run.log_artifact(split_art)
run.finish()

VBox(children=(Label(value=' 980.25MB of 980.25MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=…