In [116]:
import os
import numpy as np
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import torch.distributions as dist

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

from matplotlib.colors import LogNorm
from scipy.stats import multivariate_normal

from IPython.display import Image

from botorch.models import SingleTaskGP
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch.utils.transforms import standardize, normalize, unnormalize
from botorch.optim import optimize_acqf
from botorch import fit_gpytorch_mll
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.sampling.normal import SobolQMCNormalSampler
import warnings

In [117]:
%config InlineBackend.figure_format='retina'

In [118]:
force_cpu = False
if not force_cpu:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
else:
    device = torch.device("cpu")

print("Using device", device)

Using device cuda


In [119]:
# try:
#     from google.colab import drive
#     drive.mount('/content/drive')
#     root = '/content/drive/MyDrive/Colab Notebooks'
# except:
root = '.'

fig_folder = f"{root}/figures"
backup_folder = f"{root}/backup"

os.listdir(root)
for f in fig_folder, backup_folder:
    os.makedirs(f, exist_ok=True)

print("Backup folder:", backup_folder)
print("Figures folder:", fig_folder)

Backup folder: ./backup
Figures folder: ./figures


In [120]:
import gymnasium as gym
env = gym.make('CartPole-v1')

for _ in range(10):
    observation, info = env.reset(seed=0)
    # print(observation)
    t = 0
    terminated, truncated = False, False
    while not terminated and not truncated:
        action = env.action_space.sample()  # agent policy that uses the observation and info
        observation, reward, terminated, truncated, info = env.step(action)
        t += 1
        
env.close()

In [125]:
isinstance(torch.zeros(2), torch.TensorType)

False

# Using reinforce

In [197]:
def eval_trajectory(env, trajectory):
    
    observation, info = env.reset(seed=0)
    t = 0
    terminated, truncated = False, False
    while not terminated and not truncated:
        try:
            action = trajectory[t]
            # action = env.action_space.sample()  # agent policy that uses the observation and info
            observation, reward, terminated, truncated, info = env.step(action)
            t += 1
        except IndexError:
            break
    return t


def eval_policy(env, policy, p=None, kl_div_factor=0.001, n_sample=10):
    
    q = dist.Binomial(probs=torch.sigmoid(policy))
    trajectories = q.sample((n_sample, ))
    
    loss, p_sum = 0, 0

    for traj in trajectories:

        reward = eval_trajectory(trajectory=traj.long().cpu().numpy(), env=env)

        p_traj = q.log_prob(traj).sum().exp()
        loss -= p_traj * reward
        p_sum += p_traj

    loss /= p_sum
    
    if p is not None:
        kl_div = dist.kl_divergence(q, p).sum()
        loss += kl_div_factor * kl_div
    
    return loss

In [198]:
n_steps = 50
epochs = 1000 

param = nn.Parameter(torch.zeros(n_steps))
optimizer = optim.Adam([param, ], lr=0.2)

n_sample = 10

p = dist.Binomial(probs=torch.ones(n_steps)*0.5)

with tqdm(total=epochs) as pbar:
    for epoch in range(1, epochs + 1):
        
        optimizer.zero_grad()
        
        loss = eval_policy(env=env, policy=param, p=p)

        loss.backward()
        optimizer.step()

        pbar.update()
        pbar.set_postfix({"loss": f"{loss.item()/len(x):.4f}"})
        

  0%|          | 0/1000 [00:00<?, ?it/s]

In [199]:
n_sample = 10000
sum_reward = 0
with torch.no_grad():
    q = dist.Binomial(probs=torch.sigmoid(param))
    trajectories = q.sample((n_sample, ))
    for traj in trajectories:

        reward = eval_trajectory(trajectory=traj.long().numpy(), env=env)
        sum_reward += reward
print("After training", sum_reward/n_sample)

n_sample = 10000
sum_reward = 0
with torch.no_grad():
    q = dist.Binomial(probs=torch.sigmoid(torch.zeros(n_steps)))
    trajectories = q.sample((n_sample, ))
    for traj in trajectories:
        reward = eval_trajectory(trajectory=traj.long().numpy(), env=env)
        sum_reward += reward
print("Random", sum_reward/n_sample)

After training 46.2959
Random 21.3538


In [200]:
with torch.no_grad():
    reward =  - eval_policy(env=env, policy=param, n_sample=1000)
    print("After training", reward)
    reward = - eval_policy(env=env, policy=torch.zeros(n_steps), n_sample=1000)
    print("Random", reward)

After training tensor(49.8929)
Random tensor(21.4770)


## Close loop

### Vae model

In [201]:
class VAE(nn.Module):
    def __init__(self):
        super().__init__()

        self.enc1 = nn.Linear(n_steps, 40)
        self.enc2 = nn.Linear(40, 20)
        self.mu = nn.Linear(20, 2)
        self.logvar = nn.Linear(20, 2)
        
        self.dec1 = nn.Linear(2, 20)
        self.dec2 = nn.Linear(20, 40)
        self.dec3 = nn.Linear(40, n_steps)

    def encode(self, x):
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        mu = self.mu(x)
        logvar = self.logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        z = F.relu(self.dec1(z))
        z = F.relu(self.dec2(z))
        x = torch.sigmoid(self.dec3(z))
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

### Vae loss

In [210]:
# Reconstruction + KL divergence losses summed over all elements and batch
def vae_loss(recon_x, x, mu, logvar):
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
    
    return bce + kl

### Regression Model

In [211]:
class Regression(nn.Module):
    def __init__(self):
        super().__init__()
        self.pred1 = nn.Linear(2, 128)
        self.pred2 = nn.Linear(128, 128)
        self.pred3 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.pred1(x))
        x = F.relu(self.pred2(x))
        x = self.pred3(x)        
        return x

### Regression loss

In [217]:
def regression_loss(pred, y):
    loss = nn.MSELoss(reduction="sum")
    # output = loss(pred.squeeze(), y.float())
    return loss(pred, y)

### GP related stuff

In [220]:
def fit_vae(
        vae, reg, train_x, train_y, 
        batch_size=64,
        epochs=20,
        weight_reg_loss=20):
    
    class TrainDataset(Dataset): 
        def __init__(self):
            super().__init__()
        def __len__(self):
            return len(train_y)
        def __getitem__(self, idx):
            return train_x[idx], train_y[idx]
        
    train_loader = DataLoader(TrainDataset(), batch_size=batch_size)
    
    vae.to(device)
    reg.to(device)
    vae.train()
    reg.train()

    optimizer = optim_class(list(vae.parameters()) + list(reg.parameters()), **optim_kwargs)
    
    with tqdm(total=epochs, leave=False) as pbar:
        
        for epoch in range(1, epochs + 1):
            
            train_loss = 0
                        
            for batch_idx, (x, y) in enumerate(train_loader):

                x, y = train_x, train_y
                x.to(device)
                y.to(device)

                optimizer.zero_grad()

                recon_x, mu, logvar = vae(x)
                l_vae = vae_loss(recon_x=recon_x, x=x, mu=mu, logvar=logvar)

                pred = reg(mu)
                pred.to(device)
                y.to(device)
                l_reg = weight_reg_loss*regression_loss(y=y, pred=pred)

                loss = l_vae + l_reg
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
            
            pbar.update()

#             pbar.set_postfix({
#                 "loss VAE": f"{l_vae.item()/len(x):.4f}",
#                 "loss classifier": f"{l_class.item()/len(x):.4f}"})

In [221]:
# seed=1
# torch.manual_seed(seed)

dtype=torch.float

d = 2


vae = VAE().to(device)
reg = Regression().to(device)

SMOKE_TEST = False

acq_opt__batch_size = 3
acq_opt__num_restarts = 10
acq_opt__raw_samples = 256

qmc_sampler__sample_shape = 2048
qmc_sampler__seed = 123

n_batch = 100

n_sample_start = 15

max_n = 20


bkp_vae =     "vae_close_loop-gym__vae.p"
bkp_reg =     "vae_close_loop-gym__reg.p"
bkp_train_h = "vae_close_loop-gym__train_h.p"
bkp_gp =      "vae_close_loop-gym__gp.p"

if False:
    vae.load_state_dict(torch.load(bkp_vae))
    classifier_latent2digit.load_state_dict(torch.load(bkp_cls_latent2digit))
    train_h = torch.load(bkp_train_h)

else:

    warnings.simplefilter("ignore")

    best_observed = []

    # --------------------------------
    train_z = torch.rand(n_sample_start, d, device=device, dtype=dtype)
    print(train_z)

    with torch.no_grad():
        train_x = vae.decode(train_z)
        train_y = torch.tensor([
            - eval_policy(
                env=env,
                policy=pol) for pol in train_x])  
    # ------------------------------


    state_dict = None
    # run N_BATCH rounds of BayesOpt after the initial random batch
    for iteration in tqdm(range(n_batch)):
        
        if len(train_x) > max_n:
            train_x = train_x[-max_n:]
            train_y = train_y[-max_n:]
            train_z = train_z[-max_n:]

        # Fit the VAE (+ latent2digit classifier)
        fit_vae(
            vae=vae, 
            reg=reg,
            train_x=train_x, 
            train_y=train_y)

        with torch.no_grad():
            train_x = decode(train_z, vae_model=vae)
            train_y = torch.tensor([
                - eval_policy(
                    env=env,
                    policy=pol) for pol in train_x])  

        # fit the GP model
        # normalize(train_h, bounds=bounds), 
        # standardize(train_obj) 
        model = SingleTaskGP(train_X=train_x, train_Y=train_obj)
        # if state_dict is not None:
        #     model.load_state_dict(state_dict)
        mll = ExactMarginalLogLikelihood(model.likelihood, model)
        mll.to(train_x)
        fit_gpytorch_mll(mll)

        # define the qNEI acquisition module using a QMC sampler
        qmc_sampler = SobolQMCNormalSampler(**kwargs_qmc_sampler)
        qEI = qExpectedImprovement(
            model=model, 
            sampler=qmc_sampler, 
            best_f=train_y.max())

        # optimize and get new observation
        # optimize
        candidates, _ = optimize_acqf(
            acq_function=acq_func,
            bounds=torch.stack([
                torch.zeros(d, dtype=dtype, device=device), 
                torch.ones(d, dtype=dtype, device=device),
            ]),
            q=acq_opt__batch_size,
            num_restarts=acq_opt__num_restarts,
            raw_samples=acq_opt__raw_samples,
        )

        # observe new values 
        with torch.no_grad():
            new_z = candidates.detach()
            new_x = vae.decode(new_z)
            new_y = torch.tensor([
                - eval_policy(
                    env=env,
                    policy=pol) for pol in new_x])  

        # update training points
        train_z = torch.cat((train_z, new_z))
        train_x = torch.cat((train_x, new_x))
        train_y = torch.cat((train_y, new_y))

        # update progress
        best_value = train_y.max().item()
        best_observed.append(best_value)

        # state_dict = model.state_dict()

    torch.save(vae.state_dict(), bkp_vae)
    torch.save(reg.state_dict(), bkp_reg)
    # torch.save(model.state_dict(), bkp_gp)
    # torch.save(classifier_latent2digit.state_dict(), bkp_cls_latent2digit)
    torch.save(train_z, bkp_train_z)

tensor([[0.3015, 0.5563],
        [0.2475, 0.6147],
        [0.1296, 0.0821],
        [0.8089, 0.3794],
        [0.5686, 0.1558],
        [0.6450, 0.8953],
        [0.0925, 0.6217],
        [0.0777, 0.2138],
        [0.6687, 0.7374],
        [0.0517, 0.2388],
        [0.4127, 0.6627],
        [0.1501, 0.3584],
        [0.8628, 0.4697],
        [0.3907, 0.4794],
        [0.2759, 0.0350]], device='cuda:0')


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!