## Load the PGM dataset and save it as a single `.npz` file

In [None]:
import os
import numpy as np

def load_data(folder_name, train_test_split=False, small=True, save_name="data.npz",
              max_samples=1000):
    """Load PGM puzzle files from *folder_name* into NumPy arrays.

    Every ``*.npz`` file directly inside the folder must contain an ``"image"``
    array (reshaped here to ``(16, 160, 160)``) and a ``"target"`` array.

    Args:
        folder_name: Directory scanned (non-recursively) for ``.npz`` files.
            Files are visited in sorted filename order.
        train_test_split: If True, route each file to the train / val / test
            split depending on whether ``"train"``, ``"val"`` or ``"test"``
            occurs in its filename (checked in that order; files matching none
            are ignored). Nothing is saved in this mode and ``small`` is ignored.
        small: If True (and ``train_test_split`` is False), load at most
            ``max_samples`` files.
        save_name: Path of the ``.npz`` file written with keys ``"X"`` and
            ``"y"`` when ``train_test_split`` is False.
        max_samples: Sample cap used when ``small`` is True.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)`` when
        ``train_test_split`` is True, otherwise ``(X, y)``.
    """
    X, y = [], []
    X_train, y_train = [], []
    X_val, y_val = [], []
    X_test, y_test = [], []

    # os.listdir order is arbitrary; sorting makes the `small` subset and the
    # sample order reproducible across runs and machines.
    for filename in sorted(os.listdir(folder_name)):
        if not filename.endswith(".npz"):
            continue
        # Cap reached: stop before opening yet another file.
        if small and not train_test_split and len(X) >= max_samples:
            break

        file_path = os.path.join(folder_name, filename)
        # The context manager closes the NpzFile handle on every path
        # (the previous explicit close() was skipped on early exit / error).
        with np.load(file_path) as data:
            image = data["image"].reshape(16, 160, 160)
            target = data["target"]

        if train_test_split:
            if "train" in filename:
                X_train.append(image)
                y_train.append(target)
            elif "val" in filename:
                X_val.append(image)
                y_val.append(target)
            elif "test" in filename:
                X_test.append(image)
                y_test.append(target)
        else:
            X.append(image)
            y.append(target)

    if train_test_split:
        return (np.array(X_train), np.array(X_val), np.array(X_test),
                np.array(y_train), np.array(y_val), np.array(y_test))

    X = np.array(X)
    y = np.array(y)
    np.savez(save_name, X=X, y=y)
    return X, y

In [None]:
import matplotlib.pyplot as plt


def plot_pgm_sample(images, answer, sample_idx=None):
    """Show one PGM puzzle: 8 context panels on top, 8 candidate answers below.

    Args:
        images: The 16 panels of one sample (one entry of ``X`` from
            ``load_data``). The first 8 are context panels and the last 8 are
            the candidate answers.
        answer: Index of the correct candidate among the last 8 panels
            (the matching entry of ``y``).
        sample_idx: Optional sample index shown as the figure title.
    """
    answer = int(answer)
    fig, axes = plt.subplots(2, 8, figsize=(16, 6))

    # Top row: context panels. Bottom row: candidate answers.
    rows = ((axes[0], images[:8], "Context"), (axes[1], images[8:16], "Target"))
    for row_axes, panels, label in rows:
        for col, (ax, panel) in enumerate(zip(row_axes, panels)):
            ax.imshow(panel, cmap="gray")  # panels are single-channel 2-D arrays
            ax.axis("off")
            ax.set_title(f"{label} {col + 1}")

    # Highlight the correct candidate: green bold title plus a green frame.
    # A frame drawn in axes coordinates stays visible even with axis("off").
    correct_ax = axes[1, answer]
    correct_ax.set_title(f"Target {answer + 1} (Correct)", color="green",
                         fontweight="bold")
    correct_ax.add_patch(plt.Rectangle((0, 0), 1, 1, transform=correct_ax.transAxes,
                                       fill=False, edgecolor="green", linewidth=4,
                                       clip_on=False))

    if sample_idx is not None:
        fig.suptitle(f"Sample {sample_idx}")
    fig.tight_layout()
    plt.show()


# Define the folder name
folder_name = "path/to/folder"

# load_data also writes X and y to "data.npz" (keys "X" and "y"),
# so no separate save is needed here.
X, y = load_data(folder_name)

# Plot the first few samples, each with its own answer label.
for sample_idx, (images, answer) in enumerate(zip(X[:5], y[:5])):
    plot_pgm_sample(images, answer, sample_idx=sample_idx)
    