In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

In [None]:
import torch

In [None]:
from model import Critic, Generator

In [None]:
import tensorflow as tf

## Load data

Importing `torchvision` conflicts with scikit-learn in this environment, so MNIST is loaded through TensorFlow (`tf.keras.datasets.mnist`) instead.

In [None]:
CHANNELS_IMG: int = 1  # MNIST digits are grayscale: a single image channel

In [None]:
# Keras ships MNIST; load_data() returns one (images, labels) pair per split.
dataset = tf.keras.datasets.mnist
train_split, test_split = dataset.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
# Side length of the images. Critic takes a single `img_size`, which presumes
# height == width, so check that assumption here instead of failing later.
IMAGE_SIZE = x_train.shape[-1]
assert x_train.shape[-2] == IMAGE_SIZE, f'expected square images, got shape {x_train.shape[1:]}'
# Number of distinct labels present in the training set.
NUM_CLASSES = len(np.unique(y_train))

In [None]:
size_train, size_test = len(x_train), len(x_test)
# The training-set maximum is used for both splits so they share one scale.
scale = x_train.max()
# Affine map of pixel values: 0 -> -1 and `scale` -> +1.
x_train_scale, x_test_scale = (((split / scale) - 0.5) / 0.5 for split in (x_train, x_test))

In [None]:
# Pixel-wise mean of the rescaled training images for each digit 0-9,
# laid out row-major on a 2x5 grid.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for digit_class, ax in enumerate(axs.flat):
    ax.imshow(x_train_scale[y_train == digit_class].mean(axis=0))
    ax.set_title(f'{digit_class}')
fig.suptitle('Mean training image per class')
plt.tight_layout()

In [None]:
# Pixel-wise mean of the rescaled test images for each digit 0-9,
# laid out row-major on a 2x5 grid (same layout as the training figure).
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for digit_class, ax in enumerate(axs.flat):
    ax.imshow(x_test_scale[y_test == digit_class].mean(axis=0))
    ax.set_title(f'{digit_class}')
fig.suptitle('Mean test image per class')
plt.tight_layout()

## Load models (critic and generator)

PyTorch models loaded onto the CPU are incompatible with scikit-learn in this environment, so the models are loaded onto the GPU (the Apple `mps` device) instead.

In [None]:
FEATURES = 32

# Prefer Apple's Metal backend (the original target), then CUDA. Fall back to the
# CPU only when no accelerator exists; a hard-coded 'mps' would crash there anyway.
if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
critic = Critic(features=FEATURES, channels_img=CHANNELS_IMG, img_size=IMAGE_SIZE, num_classes=NUM_CLASSES)
# map_location lets a checkpoint saved on a different device (e.g. CUDA) load here;
# the tensors are copied into the freshly built critic, which is then moved to `device`.
critic.load_state_dict(torch.load('cwdcgan-gp_critic_2024-04-23_1738.pt', map_location=device))
critic.to(device)
critic.eval();  # evaluation mode for inference; trailing `;` suppresses the repr output

In [None]:
NOISE_DIM: int = 100  # length of the generator's input noise vector (passed as channels_noise)

In [None]:
gen = Generator(channels_noise=NOISE_DIM, features=FEATURES, channels_img=CHANNELS_IMG, num_classes=NUM_CLASSES)
# map_location lets a checkpoint saved on a different device (e.g. CUDA) load here;
# the tensors are copied into the freshly built generator, which is then moved to `device`.
gen.load_state_dict(torch.load('cwdcgan-gp_gen_2024-04-23_1738.pt', map_location=device))
gen.to(device)
gen.eval();  # evaluation mode for inference; trailing `;` suppresses the repr output

## Plot metrics

In [None]:
# Loss history for the same training run (same timestamp) as the checkpoints;
# the CSV's first column becomes the index.
loss_log_path = 'cwdcgan-gp_loss_2024-04-23_1738.csv'
loss = pd.read_csv(loss_log_path, index_col=0)

In [None]:
# Total critic objective rebuilt from the logged terms: fake - real + gradient penalty.
# NOTE(review): assumes the GP column already includes its weighting factor - confirm in the training script.
loss_critic = (
    loss['Loss Critic Fake']
    .sub(loss['Loss Critic Real'])
    .add(loss['Loss Critic Gradient Penalty'])
)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=[10, 5])

# Left panel: the reconstructed total critic objective next to the generator loss.
axs[0].plot(loss_critic, label='loss critic', color='k')
axs[0].plot(loss['Loss Gen'], label='loss gen', color='C3')
axs[0].set_title('Critic vs. generator loss')

# Right panel: every logged term on its own curve (default colour cycle, same order as before).
for column, curve_label in [
    ('Loss Critic Real', 'loss critic real'),
    ('Loss Critic Fake', 'loss critic fake'),
    ('Loss Critic Gradient Penalty', 'loss critic gp'),
    ('Loss Gen', 'loss gen'),
]:
    axs[1].plot(loss[column], label=curve_label)
axs[1].set_title('Logged loss terms')

# The x-axis is the CSV's index column; use its name when the file provides one.
for ax in axs:
    ax.set_xlabel(loss.index.name or 'log row')
    ax.set_ylabel('loss')
    ax.legend()
plt.tight_layout()

## Generate samples

In [None]:
num_samples = 16
digit = 0
# A dedicated, seeded RNG makes the samples reproducible under Restart & Run All
# without touching torch's global RNG state.
SAMPLE_SEED = 0
sample_rng = torch.Generator().manual_seed(SAMPLE_SEED)
latent = torch.randn(num_samples, NOISE_DIM, 1, 1, generator=sample_rng).to(device)
# The same class label for every sample, as an int64 (LongTensor) vector.
label = torch.full((num_samples,), digit, dtype=torch.long, device=device)

In [None]:
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    x_gen_torch = gen(latent, label)
# Drop the channel axis so each sample is a 2-D image for imshow (assumes single-channel output).
x_gen = x_gen_torch.cpu().numpy().reshape(num_samples, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
n = 4
# squeeze=False keeps `axs` 2-D even when n == 1.
fig, axs = plt.subplots(n, n, figsize=[5, 5], sharex=True, sharey=True, squeeze=False)
# zip stops at the shorter sequence, so num_samples < n*n no longer raises IndexError.
for ax, img in zip(axs.flat, x_gen):
    ax.imshow(img)
# Hide any panels left without a sample.
for ax in axs.flat[len(x_gen):]:
    ax.set_axis_off()
fig.suptitle(f'Generated samples (digit {digit})')
plt.tight_layout()

## Interpolation

In [None]:
digit1, digit2 = 3, 6
# Single-element int64 label tensors for the two interpolation endpoints.
label1, label2 = (
    torch.full((1,), endpoint_digit, dtype=torch.long, device=device)
    for endpoint_digit in (digit1, digit2)
)

In [None]:
# Look up each endpoint's class embedding and flatten it to an (embed_dim, 1, 1) array
# so it can later be concatenated with the (NOISE_DIM, 1, 1) latents along dim 1.
# reshape(-1, ...) gives the same result as the old hard-coded NOISE_DIM when the
# sizes match, and no longer breaks if the embedding size differs.
with torch.no_grad():
    embed1 = gen.embed(label1).cpu().numpy().reshape(-1, 1, 1)
    embed2 = gen.embed(label2).cpu().numpy().reshape(-1, 1, 1)

In [None]:
steps = 20
# np.linspace on array endpoints interpolates element-wise, stacking
# embed1 + k/(steps-1) * (embed2 - embed1) for k = 0..steps-1 on a new leading axis.
# Cast to float32 on the CPU, then move to `device`.
embed_interp = torch.from_numpy(np.linspace(embed1, embed2, steps)).float().to(device)

In [None]:
# Seeded RNG so the interpolation endpoints (and therefore the figure) are
# reproducible across Restart & Run All.
LATENT_SEED = 0
latent_rng = np.random.default_rng(LATENT_SEED)
latent1 = latent_rng.standard_normal((NOISE_DIM, 1, 1))
latent2 = latent_rng.standard_normal((NOISE_DIM, 1, 1))

In [None]:
# Element-wise linear path from latent1 to latent2 in `steps` points, cast to float32 and moved to `device`.
latent_interp = torch.from_numpy(np.linspace(latent1, latent2, steps)).float().to(device)

In [None]:
# Stack noise and class embedding along dim 1 (the channel axis).
# NOTE(review): presumably the same layout gen(latent, label) builds internally - confirm in model.py.
x_interp = torch.cat((latent_interp, embed_interp), dim=1)

In [None]:
# Feed the concatenated (latent, embedding) stack straight into the generator body.
# no_grad skips autograd bookkeeping; the reshape drops the channel axis (single-channel output assumed).
with torch.no_grad():
    gen_interp = gen.gen(x_interp).cpu().numpy().reshape(steps, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
# Grid sized from `steps` (5 columns, as many rows as needed) instead of a fixed 4x5,
# so changing `steps` no longer breaks the plot. squeeze=False keeps `axs` 2-D.
ncols = 5
nrows = int(np.ceil(steps / ncols))
fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, squeeze=False)
for ax, img in zip(axs.flat, gen_interp):
    ax.imshow(img)
# Hide panels beyond the last interpolation step.
for ax in axs.flat[len(gen_interp):]:
    ax.set_axis_off()
fig.suptitle(f'Interpolation from digit {digit1} to digit {digit2}')
plt.tight_layout()

## Generator filter weights

In [None]:
# NOTE(review): the name says "block1" but the weights come from gen.block4[0];
# confirm which block is intended. The name is kept because the plotting cell
# that follows reads `gen_block1_conv_filters`.
gen_block1_conv_filters = gen.block4[0].weight.detach().cpu().numpy()

In [None]:
# Plot slice [ind, 0] of up to nrows*ncols filters on one shared colour scale
# (vmin/vmax taken over the whole weight array) so panels are directly comparable.
nrows, ncols = 4, 8
n_shown = min(nrows * ncols, gen_block1_conv_filters.shape[0])
vmin = gen_block1_conv_filters.min()
vmax = gen_block1_conv_filters.max()
fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True)
for ind, ax in enumerate(axs.flat):
    if ind >= n_shown:
        # Fewer filters than panels: hide the spare axes instead of raising IndexError.
        ax.set_axis_off()
        continue
    i, j = divmod(ind, ncols)
    ax.set_title(f'{ind}: {i}, {j}')
    ax.imshow(gen_block1_conv_filters[ind, 0], vmin=vmin, vmax=vmax)
fig.suptitle('Generator conv filter weights')
plt.tight_layout()