In [None]:
import os

import lpips
import torch as th
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import cv2

from scripts import *

In [None]:
# --- run configuration ---------------------------------------------------
# GAN checkpoint to invert, and the folder where every artefact of this
# notebook (cached datasets, encoder checkpoints) is written.
checkpoint_dir = "experiments/dcgan_disc_double_params_padding_reflect_lr_scheduler/lightning_logs/version_0/checkpoints/epoch=899.ckpt"
save_path = "experiments/gan_inversion/dcgan_disc_double_params_padding_reflect_lr_scheduler_epoch=899"

# --- tunables ------------------------------------------------------------
bs = 64        # batch size used by the data loaders
n_iter = 1000  # number of optimisation steps
r = 32         # image rows (height)
c = 64         # image columns (width)

# make sure the output folder exists before anything tries to write into it
os.makedirs(save_path, exist_ok=True)

In [None]:
# load model
# Restore the trained GAN from the checkpoint, switch it to inference mode
# and move it to the GPU.
model = GAN.load_from_checkpoint(checkpoint_dir)
model.eval()
# Trailing semicolon suppresses the module repr as cell output (replaces the
# former stray `1` literal that served the same purpose).
model.cuda();

In [None]:
# fake data loader
# Pairs of (generated image, latent code that produced it), cached on disk
# so the generator only has to be sampled once.

# cache locations
X_fake_path = os.path.join(save_path, "X_fake.npy")
y_fake_path = os.path.join(save_path, "y_fake.npy")

if os.path.exists(X_fake_path) and os.path.exists(y_fake_path):
    # both cache files are present: reuse them
    X_fake = th.tensor(np.load(X_fake_path))
    y_fake = th.tensor(np.load(y_fake_path))
else:
    latents, images = [], []

    # 10 chunks of 1000 samples each, generated without tracking gradients
    with th.no_grad():
        for i in range(10):
            z = th.normal(0, 1, (1000, model.generator.latent_dim), device=model.device)
            x = model.generator(z)

            images.append(x.detach().cpu())
            latents.append(z.detach().cpu())

    X_fake = th.cat(images, dim=0)
    y_fake = th.cat(latents, dim=0)

    # persist the fake dataset for later runs
    np.save(X_fake_path, X_fake.detach().cpu().numpy())
    np.save(y_fake_path, y_fake.detach().cpu().numpy())

# images act as inputs, latent codes as targets
dataset = TensorDataset(X_fake.detach().cpu(), y_fake.detach().cpu())
dataloader_fake = DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=2)

y_fake.shape, X_fake.shape

In [None]:
# real data loader
# Builds (or reloads from cache) the training images as a float tensor of
# shape (N, 3, r, c) scaled to [-1, 1], and wraps them in a DataLoader.

X_real_path = os.path.join(save_path, "X_real.npy")

if not os.path.exists(X_real_path):
    # Only read the raw crops on a cache miss; they are not needed otherwise.
    # NOTE(review): allow_pickle=True unpickles arbitrary objects - only use
    # this with a trusted data file.
    cars = np.load("../potsdam_data/potsdam_cars/cars.npy", allow_pickle=True)

    # cv2.resize takes dsize as (width, height), hence (c, r); the result is
    # moved to channel-first layout and rescaled from [0, 255] to [-1, 1].
    X_real = np.stack(
        [
            2 * (cv2.resize(car, (c, r)).transpose(2, 0, 1).astype(np.float32) / 255) - 1
            for car in tqdm.tqdm(cars)
        ],
        axis=0,
    )

    # save real dataset
    np.save(X_real_path, X_real)
else:
    X_real = np.load(X_real_path)

# keep the same dtype as the generated images
X_real = th.tensor(X_real, dtype=X_fake.dtype)

# the second element mirrors the first so batches unpack like the fake loader's
dataset = TensorDataset(X_real, X_real)
dataloader_real = DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=2)

X_real.shape

In [None]:
# val data loader
# Same preprocessing as the real training images, applied to the validation
# crops. NOTE: this cell reuses the names `X_real` and `dataset`, so after it
# runs they refer to the validation data.

X_real_path = os.path.join(save_path, "X_real_val.npy")

if not os.path.exists(X_real_path):
    # Only read the raw crops on a cache miss; they are not needed otherwise.
    # NOTE(review): allow_pickle=True unpickles arbitrary objects - only use
    # this with a trusted data file.
    cars = np.load("../potsdam_data/potsdam_cars_val/cars.npy", allow_pickle=True)

    # cv2.resize takes dsize as (width, height), hence (c, r); the result is
    # moved to channel-first layout and rescaled from [0, 255] to [-1, 1].
    X_real = np.stack(
        [
            2 * (cv2.resize(car, (c, r)).transpose(2, 0, 1).astype(np.float32) / 255) - 1
            for car in cars
        ],
        axis=0,
    )

    # save real dataset
    np.save(X_real_path, X_real)
else:
    X_real = np.load(X_real_path)

# keep the same dtype as the generated images
X_real = th.tensor(X_real, dtype=X_fake.dtype)

# validation uses batches twice as large as training
dataset = TensorDataset(X_real, X_real)
dataloader_val = DataLoader(dataset, batch_size=2*bs, shuffle=True, num_workers=2)

X_real.shape

In [None]:
# Encoder to be trained, built around the loaded GAN's generator.
# (`generator` alone is not defined anywhere in this notebook; every other
# cell reaches it through `model.generator`.)
net = EncoderLatent(model.generator).cuda()
optimizer = th.optim.Adam(net.parameters(), lr=0.001)

# latent-space loss: mean absolute error between predicted and true codes
loss_latent = th.nn.L1Loss(reduction="mean")
# image-space loss: LPIPS perceptual distance with a VGG backbone
# (the former L1 assignment to `loss_rec` was dead code, overwritten right away)
loss_rec = lpips.LPIPS(net='vgg').cuda()

In [None]:
# Show `net` (the encoder built in the previous cell) as the cell output.
net

In [None]:
from torchsummary import summary

# Summarise the encoder for one input image. The input size is derived from
# the configured image height/width (r, c) instead of repeating 32 and 64.
summary(net.cuda(), (3, r, c))

In [None]:
# Train the encoder.
#   loss_1: L1 between predicted and true latent codes of generated images
#   loss_2: LPIPS between generated images and their reconstructions
#   loss_3: LPIPS between real images and their reconstructions
# Every 100 iterations the reconstruction loss on the validation set is
# logged and the encoder weights are checkpointed to `save_path`.
net.train()

learning_curve = []
learning_curve_val = []
n_iter = 300


def _endless(loader):
    """Yield batches from `loader` forever, starting a new pass when one ends."""
    while True:
        for batch in loader:
            yield batch


# One persistent batch stream per loader. Calling next(iter(loader)) inside
# the loop would rebuild the DataLoader iterator - and respawn its worker
# processes - on every single step.
fake_batches = _endless(dataloader_fake)
real_batches = _endless(dataloader_real)

for i in range(n_iter):

    x_fake, z_fake = next(fake_batches)
    x_real, _ = next(real_batches)
    x_fake, z_fake, x_real = x_fake.cuda(), z_fake.cuda(), x_real.cuda()

    optimizer.zero_grad()
    # the encoder returns (latent code, reconstruction)
    z_fake_, x_fake_ = net(x_fake)
    _, x_real_ = net(x_real)

    loss_1 = loss_latent(z_fake_, z_fake)
    # Images here live in [-1, 1] (real data is scaled that way on loading),
    # which is LPIPS' native input range. `normalize=True` is only meant for
    # [0, 1] inputs and would shift these images out of range.
    loss_2 = th.mean(loss_rec(x_fake_, x_fake, normalize=False))
    loss_3 = th.mean(loss_rec(x_real_, x_real, normalize=False))
    total_loss = loss_1 + loss_2 + loss_3

    learning_curve.append(total_loss.item())

    total_loss.backward()
    optimizer.step()

    if i % 10 == 0:
        print("iteration ", i, "loss", learning_curve[-1])

    # Step decay: divide the learning rate by 10 every 90 iterations.
    # Skip i == 0, otherwise the configured rate is cut after the first step.
    if i > 0 and i % 90 == 0:
        for g in optimizer.param_groups:
            g['lr'] = g['lr']*0.1

    if i % 100 == 0:
        with th.no_grad():
            net.eval()
            # NOTE(review): this is a sum of per-batch means, so its scale
            # depends on the number of validation batches.
            val_loss = 0
            for x_val, _ in dataloader_val:
                x_val = x_val.cuda()
                _, x_val_ = net(x_val)
                val_loss += th.mean(loss_rec(x_val_, x_val, normalize=False)).item()
            print("iteration ", i, "val loss", val_loss)
            learning_curve_val.append(val_loss)
            net.train()
        th.save(net.state_dict(), os.path.join(save_path, f"iter={i}.pkl"))

In [None]:
# Training loss is recorded every iteration, validation loss every 100
# iterations, so the validation points are placed at 0, 100, 200, ...
plt.plot(learning_curve, label="train (total)")
plt.plot(100 * np.arange(len(learning_curve_val)), learning_curve_val, label="val (reconstruction)")
plt.title("Loss Curve")
# The x axis counts optimisation steps, and the plotted quantities are
# L1 + LPIPS (train) and LPIPS (val) - neither epochs nor an MSE.
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend();

In [None]:
# Reload the best encoder checkpoint.
# The training loop validates and checkpoints every 100 iterations starting
# at 0, so entry k of `learning_curve_val` belongs to file iter={100*k}.pkl.
# With n_iter = 300 there is no iter=300.pkl (the last one is iter=200), so
# the checkpoint is picked from the validation curve instead of hard-coded.
best_model_iter = 100 * int(np.argmin(learning_curve_val))
net.load_state_dict(th.load(os.path.join(save_path, f"iter={best_model_iter}.pkl")))
net.cuda()
net.eval()

In [None]:
# testing
# Round trip on a freshly generated sample: latent -> image -> encoder ->
# latent -> image, then show the reconstruction above the original.


def _as_image(batch):
    """Squeeze a one-sample batch, move axis 0 last and map values through (x + 1) / 2."""
    arr = np.squeeze(batch.detach().cpu().numpy())
    return (arr.transpose(1, 2, 0) + 1) / 2


with th.no_grad():
    z = th.normal(0, 1, (1, model.generator.latent_dim), device=model.device)
    x_orig = model(z)

    z_rec, _ = net.forward(x_orig)
    x_rec = model.generator(z_rec)

# rescale
x_rec = _as_image(x_rec)
x_orig = _as_image(x_orig)

# plot: one panel per image, stacked vertically, ticks hidden
plt.figure(figsize=(8, 8))
panels = ((x_rec, "Reconstructed (Fake)"), (x_orig, "Original (Fake)"))
for slot, (img, caption) in enumerate(panels, start=1):
    plt.subplot(2, 1, slot)
    plt.imshow(img)
    plt.title(caption, fontsize=12)
    plt.xticks([])
    plt.yticks([])

In [None]:
# Round trip on a real validation image: image -> encoder -> latent ->
# generator, then show the reconstruction above the original.
sample_idx = 1  # which image of the (shuffled) validation batch to show
x_val, _ = next(iter(dataloader_val))

# Inference only, so no autograd graph is needed.
with th.no_grad():
    x_orig = th.unsqueeze(x_val[sample_idx], dim=0)
    z_rec, _ = net(x_orig.cuda())
    x_rec = model.generator(z_rec)

# rescale: drop the batch axis, go channel-last, map through (x + 1) / 2
x_rec = (np.squeeze(x_rec.detach().cpu().numpy()).transpose(1, 2, 0) + 1) / 2
x_orig = (np.squeeze(x_orig.detach().cpu().numpy()).transpose(1, 2, 0) + 1) / 2

# plot
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.imshow(x_rec)
plt.title("Reconstructed (Fake)", fontsize=16)
plt.xticks([])
plt.yticks([])
plt.subplot(2, 1, 2)
plt.imshow(x_orig)
plt.title("Original (Real)", fontsize=16)
plt.xticks([])
plt.yticks([]);