# NN Project Data Generator

### Imports

In [None]:
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt
import clip
from PIL import Image
import requests
from importlib import reload
import io
# Import our custom modules
import image_loader
reload(image_loader)

### Read Data From H5 Files

In [None]:
# Read captions, image ids, URLs and the word table from the training file.
# The "train_ims" dataset is not read here.
with h5py.File("data/eee443_project_dataset_train.h5", "r") as train_file:
    print("Keys: %s" % train_file.keys())
    train_cap, train_imid, train_url, word_code = (
        np.array(train_file[name])
        for name in ("train_cap", "train_imid", "train_url", "word_code")
    )

# Field names of the word table, and the values stored in its first record.
words = np.array(word_code.dtype.names)
word_indices = np.array(list(word_code[0]), dtype=np.int32)

# Same for the test file; note its caption dataset is keyed "test_caps".
# The "test_ims" dataset is not read here.
with h5py.File("data/eee443_project_dataset_test.h5", "r") as test_file:
    print("Keys: %s" % test_file.keys())
    test_cap, test_imid, test_url = (
        np.array(test_file[name])
        for name in ("test_caps", "test_imid", "test_url")
    )

# Number of caption rows in each split.
train_N, test_N = train_cap.shape[0], test_cap.shape[0]

# Cell output: the shapes of everything that was loaded.
tuple(
    loaded.shape
    for loaded in (train_cap, train_imid, train_url, test_cap, test_imid, test_url)
)

### Load CLIP

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP model and its accompanying preprocessor on that device.
model, preprocessor = clip.load("ViT-B/32", device=device)

### Load or Calculate Tokenized Captions

In [None]:
# Tokenized training captions: reuse the cached tensor when present, otherwise
# rebuild the caption strings from the integer matrix, tokenize and cache them.
try:
    tokenized_train_captions = torch.load("data/tokenized_train_captions.pt", map_location=device)
except FileNotFoundError:
    # Column position of code 2 in every row; used as the exclusive slice end.
    ends = np.where(train_cap == 2)[1]
    all_caption = []
    for row_idx, row in enumerate(train_cap):
        # Skip column 0, stop before the code-2 column, and drop codes 0-3.
        kept_codes = [code for code in row[1:ends[row_idx]] if code not in (0, 1, 2, 3)]
        # NOTE(review): words is indexed by position and word_indices is not
        # consulted -- confirm the field order matches the caption codes.
        all_caption.append(" ".join(words[kept_codes]))
    tokenized_train_captions = clip.tokenize(all_caption).to(device)
    torch.save(tokenized_train_captions, "data/tokenized_train_captions.pt")

# Tokenized test captions: same procedure on the test caption matrix.
try:
    tokenized_test_captions = torch.load("data/tokenized_test_captions.pt", map_location=device)
except FileNotFoundError:
    ends = np.where(test_cap == 2)[1]
    all_caption = []
    for row_idx, row in enumerate(test_cap):
        kept_codes = [code for code in row[1:ends[row_idx]] if code not in (0, 1, 2, 3)]
        all_caption.append(" ".join(words[kept_codes]))
    tokenized_test_captions = clip.tokenize(all_caption).to(device)
    torch.save(tokenized_test_captions, "data/tokenized_test_captions.pt")

### Load or Calculate Text Features

In [None]:
# Number of captions pushed through the CLIP text encoder per forward pass.
TEXT_ENCODE_BATCH = 1000

# Training caption embeddings: reuse the cached tensor when present, otherwise
# encode the tokenized captions batch by batch and cache the result.
try:
    encoded_train_captions = torch.load("data/encoded_train_captions.pt", map_location=device)
except FileNotFoundError:
    encoded_train_captions = torch.empty((train_N, 512), device=device)
    with torch.no_grad():
        # Step by the batch size so that every caption is visited exactly once
        # (the last batch may be shorter) and no empty batch reaches the model.
        for start in range(0, train_N, TEXT_ENCODE_BATCH):
            stop = start + TEXT_ENCODE_BATCH
            batch_features = model.encode_text(tokenized_train_captions[start:stop])
            # Store as float32 regardless of the dtype the encoder returns.
            encoded_train_captions[start:stop] = batch_features.float()
            print(f"Encoded {start} captions", end="\r")
    torch.save(encoded_train_captions, "data/encoded_train_captions.pt")

# Test caption embeddings: same procedure on the tokenized test captions.
try:
    encoded_test_captions = torch.load("data/encoded_test_captions.pt", map_location=device)
except FileNotFoundError:
    encoded_test_captions = torch.empty((test_N, 512), device=device)
    with torch.no_grad():
        for start in range(0, test_N, TEXT_ENCODE_BATCH):
            stop = start + TEXT_ENCODE_BATCH
            batch_features = model.encode_text(tokenized_test_captions[start:stop])
            encoded_test_captions[start:stop] = batch_features.float()
            print(f"Encoded {start} captions", end="\r")
    torch.save(encoded_test_captions, "data/encoded_test_captions.pt")

### Remove Tokenized Captions (only needed for caption encoding)

In [None]:
# Unbind both token tensors; the cells above only use them to produce the
# caption encodings.
del tokenized_train_captions
del tokenized_test_captions

### Load Image Features

In [None]:
try:
    test_image_features = torch.load("data/test_image_features.pt")
except FileNotFoundError:
    print("Test Image Fatures Missing")
try:
    train_image_features = torch.load("data/train_image_features.pt")
except FileNotFoundError:
    print("Train Image Fatures Missing")

### Load URL Health Masks

In [None]:
# URL health masks are only loaded here, never computed. When a file is absent
# a message is printed and the corresponding name is left unbound.
test_mask_path = "data/healty_test_urls.npy"
train_mask_path = "data/healty_train_urls.npy"

try:
    healty_test_urls = np.load(test_mask_path)
except FileNotFoundError:
    print("Healty Test URLs Missing")

try:
    healty_train_urls = np.load(train_mask_path)
except FileNotFoundError:
    print("Healty Train URLs Missing")


### Remove Missing Images from ALL Datasets and Train-Validation Split

In [None]:
# Final datasets: reuse the cached splits when every file is present. If any
# one of them is missing, rebuild all of them: drop the samples whose image is
# flagged unhealthy, carve a validation set out of the training data, and cache
# everything.
try:
    test_X = np.load("data/test_X.npy")
    test_Y = np.load("data/test_Y.npy")
    train_X = np.load("data/train_X.npy")
    train_Y = np.load("data/train_Y.npy")
    validation_X = np.load("data/validation_X.npy")
    validation_Y = np.load("data/validation_Y.npy")
    encoded_train_X = torch.load("data/encoded_train_X.pt", map_location=device)
    encoded_test_X = torch.load("data/encoded_test_X.pt", map_location=device)
    encoded_validation_X = torch.load("data/encoded_validation_X.pt", map_location=device)
except FileNotFoundError:
    print("Data Missing")
    # Fraction of the cleaned training samples held out for validation.
    validation_split = 0.1

    # Positions in the health masks whose entry is falsy.
    missing_train_url_indices = np.where(np.logical_not(healty_train_urls))[0]
    missing_test_url_indices = np.where(np.logical_not(healty_test_urls))[0]

    # True for every sample whose image id equals one of those positions.
    # np.isin does this in one vectorized pass instead of one full scan of the
    # id array per missing index.
    # NOTE(review): this compares mask positions directly with the values in
    # *_imid -- confirm both use the same numbering base.
    train_missing_data_mask = np.isin(train_imid, missing_train_url_indices)
    test_missing_data_mask = np.isin(test_imid, missing_test_url_indices)

    # Keep only the samples that are not flagged.
    clean_train_cap = train_cap[~train_missing_data_mask]
    clean_train_imid = train_imid[~train_missing_data_mask]
    clean_test_cap = test_cap[~test_missing_data_mask]
    clean_test_imid = test_imid[~test_missing_data_mask]
    clean_encoded_train_cap = encoded_train_captions[~train_missing_data_mask]
    clean_encoded_test_cap = encoded_test_captions[~test_missing_data_mask]
    clean_test_N = clean_test_cap.shape[0]
    clean_train_N = clean_train_cap.shape[0]

    # Random train/validation split without replacement. No seed is set in this
    # cell, so a rebuilt split differs between runs; the files saved below are
    # what the load branch above reads back afterwards.
    val_N = int(validation_split * clean_train_N)
    validation_indices = np.random.choice(clean_train_N, val_N, replace=False)
    train_indices = np.setdiff1d(np.arange(clean_train_N), validation_indices)

    # X holds caption rows, Y holds the matching image ids; encoded_* holds the
    # caption encodings for the same rows.
    train_X = clean_train_cap[train_indices]
    train_Y = clean_train_imid[train_indices]
    validation_X = clean_train_cap[validation_indices]
    validation_Y = clean_train_imid[validation_indices]
    test_X = clean_test_cap
    test_Y = clean_test_imid
    encoded_train_X = clean_encoded_train_cap[train_indices]
    encoded_validation_X = clean_encoded_train_cap[validation_indices]
    encoded_test_X = clean_encoded_test_cap

    # Cache every split.
    np.save("data/train_X.npy", train_X)
    np.save("data/train_Y.npy", train_Y)
    np.save("data/validation_X.npy", validation_X)
    np.save("data/validation_Y.npy", validation_Y)
    np.save("data/test_X.npy", test_X)
    np.save("data/test_Y.npy", test_Y)
    torch.save(encoded_train_X, "data/encoded_train_X.pt")
    torch.save(encoded_validation_X, "data/encoded_validation_X.pt")
    torch.save(encoded_test_X, "data/encoded_test_X.pt")

    # Unbind the intermediates that are no longer needed.
    del clean_train_cap, clean_train_imid, clean_test_cap, clean_test_imid
    del train_missing_data_mask, test_missing_data_mask
    del missing_train_url_indices, missing_test_url_indices
    del train_indices, validation_indices

# From here on the sample counts refer to the cleaned splits.
train_N = train_X.shape[0]
val_N = validation_X.shape[0]
test_N = test_X.shape[0]