## Imports

In [None]:
import os, sys
sys.path.append("..")
sys.path.append("../ALAE")

import torch
import numpy as np
import torch.nn.functional as F

import deeplake
from tqdm import tqdm

import wandb
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from alae_ffhq_inference import load_model, encode, decode

from notebooks_utils import TensorSampler


from src.lightsbm import LightSBM

## Parameters

In [None]:
# --- Experiment configuration -------------------------------------------

# Translation direction: latents of `input_data` are mapped towards
# `target_data`. The data cell recognises 'MAN', 'WOMAN', 'ADULT', 'CHILDREN'.
input_data, target_data = 'ADULT', 'CHILDREN'

# Size passed to LightSBM as `dim`.
dim = 512

# LightSBM construction settings (forwarded as n_potentials, is_diagonal and
# S_diagonal_init respectively).
n_potentials = 10
is_diag = True
S_init = 0.1

# `eps` is forwarded to LightSBM as `epsilon` and also scales the noise
# injected into the interpolated training points.
eps = 0.1

# Optimisation settings.
# NOTE(review): the visible train(...) call passes batch_size=512 explicitly,
# so this batch_size value is not what training actually uses - confirm.
lr = 1e-3
batch_size = 128
max_iter = 20000

# Seed applied to torch / numpy before the result samples are drawn.
output_seed = 42

# Device the model and the training batches are moved to.
device = 'cuda:0'


## Data

In [None]:
import gdown
import os

# Create the data directory if needed; exist_ok avoids the separate
# isdir() check and the race between checking and creating.
os.makedirs('../data', exist_ok=True)

# Destination path -> Google Drive URL for every file this notebook loads.
urls = {
    "../data/age.npy": "https://drive.google.com/uc?id=1Vi6NzxCsS23GBNq48E-97Z9UuIuNaxPJ",
    "../data/gender.npy": "https://drive.google.com/uc?id=1SEdsmQGL3mOok1CPTBEfc_O1750fGRtf",
    "../data/latents.npy": "https://drive.google.com/uc?id=1ENhiTRsHtSjIjoRu1xYprcpNd8M9aVu8",
    "../data/test_images.npy": "https://drive.google.com/uc?id=1SjBWWlPjq-dxX4kxzW-Zn3iUR3po8Z0i",
}

# The dict keys already are complete destination paths, so they are passed
# to gdown unchanged. Every file is fetched on each run of this cell.
for name, url in urls.items():
    gdown.download(url, name, quiet=False)
    

In [None]:

def _select_indices(group, gender_labels, age_values):
    """Return the positions of the samples that belong to ``group``.

    Args:
        group: One of "MAN", "WOMAN", "ADULT" or "CHILDREN".
        gender_labels: Array of gender labels; entries are compared against
            the strings "male" / "female" after flattening.
        age_values: Array of ages; entries equal to -1 are excluded from both
            age groups (presumably a "missing age" marker - confirm).

    Returns:
        1-D integer array with the positions where the group mask is true.

    Raises:
        ValueError: If ``group`` is not one of the four supported names.
    """
    gender_flat = gender_labels.reshape(-1)
    age_flat = age_values.reshape(-1)

    if group == "MAN":
        mask = gender_flat == "male"
    elif group == "WOMAN":
        mask = gender_flat == "female"
    elif group == "ADULT":
        mask = (age_flat >= 18) & (age_flat != -1)
    elif group == "CHILDREN":
        mask = (age_flat < 18) & (age_flat != -1)
    else:
        # Previously an unknown name silently skipped every branch and only
        # failed later with a NameError on the index variables.
        raise ValueError(
            f"Unknown data group {group!r}; expected one of "
            "'MAN', 'WOMAN', 'ADULT', 'CHILDREN'"
        )

    # flatnonzero derives the positions from the mask itself, so the result
    # does not depend on a separately maintained split-size constant.
    return np.flatnonzero(mask)


train_size = 60000
# Kept for any later cell that reads it; the index selection below no longer
# needs it because positions are taken from the masks directly.
test_size = 10000

latents = np.load("../data/latents.npy")
gender = np.load("../data/gender.npy")
age = np.load("../data/age.npy")
test_inp_images = np.load("../data/test_images.npy")

# The first `train_size` rows form the training split, the remainder the test
# split; the label arrays are cut at the same position as the latents.
train_latents, test_latents = latents[:train_size], latents[train_size:]
train_gender, test_gender = gender[:train_size], gender[train_size:]
train_age, test_age = age[:train_size], age[train_size:]

# Source-domain samples.
x_inds_train = _select_indices(input_data, train_gender, train_age)
x_inds_test = _select_indices(input_data, test_gender, test_age)
x_data_train = train_latents[x_inds_train]
x_data_test = test_latents[x_inds_test]

# Target-domain samples.
y_inds_train = _select_indices(target_data, train_gender, train_age)
y_inds_test = _select_indices(target_data, test_gender, test_age)
y_data_train = train_latents[y_inds_train]
y_data_test = test_latents[y_inds_test]

X_train = torch.tensor(x_data_train)
Y_train = torch.tensor(y_data_train)

X_test = torch.tensor(x_data_test)
Y_test = torch.tensor(y_data_test)

# Samplers stay on the CPU; the training loop moves each drawn batch to the
# training device itself.
X_sampler = TensorSampler(X_train, device="cpu")
Y_sampler = TensorSampler(Y_train, device="cpu")


## Model

In [None]:

# Build the model from the notebook-level settings.
model = LightSBM(
    dim=dim,
    n_potentials=n_potentials,
    epsilon=eps,
    S_diagonal_init=S_init,
    is_diagonal=is_diag,
)

model.to(device)

# Take the learning rate from the `lr` setting rather than repeating the
# literal, so changing the parameter cell actually affects the optimiser.
opt = torch.optim.Adam(model.parameters(), lr=lr)


## Train

In [None]:


def train(model, max_iter, eps, opt, val_freq=1000, batch_size=512, safe_t=1e-2, device=device):
    """Optimise ``model`` by regressing its drift onto bridge targets.

    Every iteration draws a batch from the module-level ``X_sampler`` and
    ``Y_sampler``, builds a noisy interpolation ``x_t`` between the two
    batches at a random time ``t``, and minimises the mean squared error
    between ``model.get_drift(x_t, t)`` and ``(x_1 - x_t) / (1 - t)``.

    Args:
        model: Object exposing ``get_drift``, ``get_log_C`` and
            ``get_log_potential``.
        max_iter: Number of optimisation steps.
        eps: Scale of the noise term ``sqrt(eps * t * (1 - t))``.
        opt: Optimiser that updates the parameters of ``model``.
        val_freq: Period (in steps) of the validation hook, which is
            currently a no-op.
        batch_size: Number of samples drawn from each sampler per step.
        safe_t: ``t`` is sampled from ``[0, 1 - safe_t)`` so the division by
            ``1 - t`` in the target drift stays bounded.
        device: Device the sampled batches and ``t`` are placed on.

    Returns:
        None. Progress is shown on a tqdm bar and, when a wandb run is
        active, logged under ``loss_bm`` and ``loss_plan``.
    """
    pbar = tqdm(range(1, max_iter + 1))

    for i in pbar:
        x_0_samples = X_sampler.sample(batch_size).to(device)
        x_1_samples = Y_sampler.sample(batch_size).to(device)

        t = torch.rand([batch_size, 1], device=device) * (1 - safe_t)

        noise_scale = torch.sqrt(eps * t * (1 - t))
        x_t = x_1_samples * t + x_0_samples * (1 - t) + noise_scale * torch.randn_like(x_0_samples)

        # squeeze(-1) drops only the trailing axis, so t stays 1-D even when
        # batch_size == 1 (a bare squeeze() would produce a 0-dim tensor).
        predicted_drift = model.get_drift(x_t, t.squeeze(-1))

        target_drift = (x_1_samples - x_t) / (1 - t)

        loss = F.mse_loss(target_drift, predicted_drift)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # loss_plan is a monitoring value only: it never reaches backward(),
        # so it is evaluated without recording an autograd graph.
        with torch.no_grad():
            loss_plan = (model.get_log_C(x_0_samples) - model.get_log_potential(x_1_samples)).mean()

        # Extract the Python scalars once and reuse them for both sinks.
        loss_value = loss.item()
        loss_plan_value = loss_plan.item()

        pbar.set_description(f'Loss : {loss_value} Plan Loss: {loss_plan_value}')

        if wandb.run:
            # Log plain floats rather than live tensors.
            wandb.log({'loss_bm': loss_value, 'loss_plan': loss_plan_value})

        if i % val_freq == 0:
            # Placeholder for periodic validation; intentionally empty.
            pass
            

In [None]:

# Launch training with the model / optimiser built above.
# NOTE(review): batch_size is fixed to 512 here although the parameter cell
# defines batch_size = 128 - confirm which value is intended.
train(
    model=model,
    max_iter=max_iter,
    eps=eps,
    opt=opt,
    val_freq=1000,
    batch_size=512,
    safe_t=1e-2,
    device=device,
)


## Results Plotting

In [None]:

# Load the ALAE model whose decoder turns latents back into images.
alae_model = load_model(
    "../ALAE/configs/ffhq.yaml",
    training_artifacts_dir="../ALAE/training_artifacts/ffhq/",
)

# Fix both random generators so the selection and the samples are repeatable.
torch.manual_seed(output_seed)
np.random.seed(output_seed)

number_of_samples = 3

# Choose 10 distinct source examples among those whose test index is below
# 300 (presumably test_images.npy only covers the first 300 test samples -
# TODO confirm).
n_candidates = (x_inds_test < 300).sum()
inds_to_map = np.random.choice(np.arange(n_candidates), size=10, replace=False)

selected_test_inds = x_inds_test[inds_to_map]
latent_to_map = torch.tensor(test_latents[selected_test_inds])
inp_images = test_inp_images[selected_test_inds]

# Run the model `number_of_samples` times on the same inputs. All model
# calls happen before any decoding, exactly as in a two-pass loop, so the
# order in which random numbers are consumed is unchanged.
with torch.no_grad():
    mapped_all = [
        model(latent_to_map.to(device)) for _ in range(number_of_samples)
    ]

# Stack the repeated outputs along a new axis 1.
mapped = torch.stack(mapped_all, dim=1)

decoded_all = []
with torch.no_grad():
    for sample_idx in range(number_of_samples):
        decoded_img = decode(alae_model, mapped[:, sample_idx].cpu())
        # Rescale by 0.5 / +0.5 and 255, truncate to integers, clamp to the
        # 8-bit range (presumably decoder output lies in [-1, 1] - confirm).
        pixels = (decoded_img * 0.5 + 0.5) * 255
        pixels = pixels.type(torch.long).clamp(0, 255)
        # Move axis 1 to the end (presumably channels-first -> channels-last
        # for imshow) and convert to a uint8 numpy array.
        decoded_img = pixels.cpu().type(torch.uint8).permute(0, 2, 3, 1).numpy()
        decoded_all.append(decoded_img)

# Stack the per-sample image batches along a new axis 1.
decoded_all = np.stack(decoded_all, axis=1)


In [None]:

# Number of input examples (rows) shown in the figure.
n_pictures = 2

# One row per example: the input image followed by `number_of_samples`
# decoded samples. squeeze=False keeps `axes` two-dimensional even when
# n_pictures == 1, so the row indexing below always works.
fig, axes = plt.subplots(
    n_pictures,
    number_of_samples + 1,
    figsize=(number_of_samples + 1, n_pictures),
    dpi=200,
    squeeze=False,
)

for ind in range(n_pictures):
    ax = axes[ind]

    # Column 0: the input image; columns 1..: its decoded samples.
    ax[0].imshow(inp_images[ind])
    for k in range(number_of_samples):
        ax[k + 1].imshow(decoded_all[ind, k])

    # Hide the x axis and remove the y ticks on every panel of the row.
    for panel in ax:
        panel.get_xaxis().set_visible(False)
        panel.set_yticks([])

fig.tight_layout(pad=0.05)
fig.savefig('alae_transfer.png')
