## Creating a generative AI with twinLab

In [None]:
# Standard imports
import pickle

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# twinLab
import twinlab as tl

In [None]:
# Parameters

# Experiment: selects which dataset is loaded below (leave exactly one active)
# experiment = "MNIST-lowres" # 8x8 images
experiment = "MNIST" # 28x28 images
# experiment = "CIFAR-10" # https://www.cs.toronto.edu/~kriz/cifar.html

# Random numbers: seed handed to np.random.seed
random_seed = 123

# Training parameters
training_samples = 1500 # Passed to campaign.fit as "train_test_split"
explained_variance = 0.5 # Passed to the campaign as "output_explained_variance"
onehot_encode = True # If True, the "number" label column is one-hot encoded

In [None]:
# Functions
def unpickle(file):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. The file is decoded with ``encoding="bytes"``,
        so 8-bit strings pickled by Python 2 come back as ``bytes`` (which is
        why the CIFAR-10 batch is indexed with keys such as ``b"data"``).
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(file, "rb") as handle:
        return pickle.load(handle, encoding="bytes")

def wrangle_image(linear_image, npix):
    """Reshape a flattened square image into a displayable array.

    Args:
        linear_image: 1D sequence of pixel values. Either ``npix**2`` values
            (single channel) or ``3 * npix**2`` values laid out as all R
            values, then all G values, then all B values.
        npix: Side length of the square image in pixels.

    Returns:
        An array of shape ``(npix, npix)`` for single-channel input, or of
        shape ``(npix, npix, 3)`` cast to ``uint8`` for RGB input.

    Raises:
        ValueError: If the number of values matches neither layout.
    """
    linear_image = np.asarray(linear_image)  # Also accept lists/tuples
    pix = npix**2
    n_values = len(linear_image)
    if n_values == pix:
        return linear_image.reshape(npix, npix)
    if n_values == 3*pix:
        # Channels are stored as three consecutive planes; move the channel
        # axis last so that the result is (row, column, RGB)
        planes = linear_image.reshape(3, npix, npix)
        return np.moveaxis(planes, 0, -1).astype(np.uint8)
    raise ValueError(
        f"Expected {pix} (single-channel) or {3*pix} (RGB) pixel values "
        f"for a {npix}x{npix} image, got {n_values}."
    )

In [None]:
# Calculations
# Seed NumPy's global random number generator with the configured seed
np.random.seed(random_seed)

In [None]:
# Load the selected dataset into `df` and set `npix` (image side length).
# Every branch is expected to leave a "number" column holding the class label;
# all remaining columns are treated as pixel values further down.
if experiment == "MNIST-lowres":

    # 8x8 pixel pictures of numbers 0 to 9. 1798 pictures
    npix = 8
    filepath = "MNIST-lowres/data.csv"
    df = pd.read_csv(filepath)

elif experiment == "MNIST":

    # 10,000 examples of 28x28 pixel pictures of numbers 0 to 9
    npix = 28
    filepath = "MNIST/test.csv"
    df = pd.read_csv(filepath)
    df.rename({"label": "number"}, axis="columns", inplace=True)

elif experiment == "CIFAR-10":

    # 32x32 pixel pictures. 10 pictures of 10 different types of object
    npix = 32
    filepath = "CIFAR-10/data_batch_1"
    data = unpickle(filepath)

    # Each row holds the R plane, then the G plane, then the B plane of one
    # picture, so a 3D image is unpacked into one row of the 2D dataframe
    df = pd.DataFrame(data[b"data"])
    df.columns = [f"{RGB}-{i}-{j}" for RGB in ["R", "G", "B"] for i in range(npix) for j in range(npix)]
    df["number"] = data[b"labels"] # TODO: Try to insert this as the first column
    image_dict = {
        "airplane": 0,
        "automobile": 1,
        "bird": 2,
        "cat": 3,
        "deer": 4,
        "dog": 5,
        "frog": 6,
        "horse": 7,
        "ship": 8,
        "truck": 9
    }
    df = df[df["number"] == image_dict["dog"]] # TODO: Only dogs for tests!

else:
    raise ValueError(f"Experiment not recognised: {experiment!r}")

# Set up campaign inputs (class label) and outputs (everything else)
if onehot_encode:
    # Fix the category set to 0-9 so that every class gets its own column even
    # when it is absent from the data (e.g. the dogs-only CIFAR-10 subset);
    # otherwise dropping the ten input columns below raises a KeyError.
    number_classes = pd.Categorical(df["number"], categories=list(range(10)))
    df = pd.get_dummies(df.assign(number=number_classes), columns=["number"])
    inputs = [f"number_{i}" for i in range(10)]
else:
    inputs = ["number"]
outputs = list(df.drop(columns=inputs))

# Print to screen
display(df)

In [None]:
# Show the first 100 images of the dataset on a 10x10 grid
pixel_frame = df[outputs]
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    image = wrangle_image(pixel_frame.iloc[i].to_numpy(), npix)
    ax.imshow(image, cmap="binary_r", vmin=0., vmax=255.)
    # Hide the tick marks; only the picture itself is of interest
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

Data campaign

In [None]:
# Campaign configuration, forwarded to tl.Campaign as keyword arguments
setup_dict = dict(
    inputs=inputs,
    outputs=outputs,
    estimator="gaussian_process_regression",  # Type of model to use
    decompose_outputs=True,  # Equivalent of PCA/SVD; necessary to learn structure
    output_explained_variance=explained_variance,  # Toggle this number to improve accuracy
)
campaign = tl.Campaign(**setup_dict)

# Training configuration, forwarded to campaign.fit as keyword arguments
train_dict = dict(
    df=df,
    train_test_split=training_samples,  # How many rows are used for training
)
campaign.fit(**train_dict)

In [None]:
# Show some diagnostic information about the trained campaign
# (as the last expression of the cell, its return value is what gets shown)
campaign.get_diagnostics()

In [None]:
# Build the prediction inputs: one row for each class label 0 to 9,
# encoded the same way as the training inputs
labels = pd.DataFrame({"number": [*range(10)]})
df_predict = pd.get_dummies(labels, columns=["number"]) if onehot_encode else labels
# Alternative inputs kept from earlier experiments (non-integer labels):
# df_predict = pd.DataFrame({'number': np.linspace(0.5, 9.5, 10)})
# df_predict = pd.DataFrame({'number': np.linspace(5.5, 6.4, 10)})
display(df_predict)

# predict returns a pair: keep the first element (the mean prediction) and
# discard the second (the standard deviation, per the original note)
prediction = campaign.predict(df_predict)
df_mean, _ = prediction
display(df_mean)

In [None]:
# Draw the predicted mean image for each of the 10 prediction rows (2x5 grid)
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for row, ax in enumerate(axes.flat):
    image = wrangle_image(df_mean.iloc[row].to_numpy(), npix)
    ax.imshow(image, cmap="binary_r", vmin=0., vmax=255.)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

Output

In [None]:
# Generate some random samples from the trained model at the prediction inputs
# Note that "observation_noise" needs to be true in order to generate diverse samples
nsample = 5 # Number of samples requested from campaign.sample
df_samples = campaign.sample(df_predict, nsample, {"observation_noise": True})
# display(df_samples)

In [None]:
# Plot the random samples: one row of axes per sample, one column per
# prediction row (class)
nrow, ncol = nsample, 10
fig, axes = plt.subplots(nrow, ncol, figsize=(20, 2*nsample), squeeze=False)
for sample in range(nsample):
    # Select this sample's columns from level 1 of the column MultiIndex
    sample_frame = df_samples.xs(sample, axis="columns", level=1, drop_level=True)
    for number in range(10):
        linear_image = sample_frame.iloc[number].to_numpy()
        image = wrangle_image(linear_image, npix)
        ax = axes[sample, number]
        ax.imshow(image, cmap="binary_r", vmin=0., vmax=255.)
        ax.set_xticks([])
        ax.set_yticks([])
plt.show()
plt.show()