<a href="https://colab.research.google.com/github/HanbumKo/DRL-course/blob/main/15_World_Models/15_World_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# World Models

The code is adapted from the [ctallec/world-models](https://github.com/ctallec/world-models) repository.

In [1]:
# System package: xvfb, installed through apt.
!apt update
!apt install xvfb
# Python packages used by the notebook (gym with the box2d extra, cma,
# pyvirtualdisplay, gym-notebook-wrapper, ray).
!pip install gym[box2d]
!pip install cma
!pip install pyvirtualdisplay
!pip install gym-notebook-wrapper
!pip install ray
# Download the weight files that are loaded later as "vae.pt" and "mdrnn.pt".
!wget https://github.com/HanbumKo/DRL-course/raw/main/15_World_Models/weight/vae.pt
!wget https://github.com/HanbumKo/DRL-course/raw/main/15_World_Models/weight/mdrnn.pt

[33m0% [Working][0m            Hit:1 http://security.ubuntu.com/ubuntu bionic-security InRelease
Hit:2 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease
Ign:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  InRelease
Ign:4 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  InRelease
Hit:5 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  Release
Hit:6 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  Release
Hit:7 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic InRelease
Hit:8 http://archive.ubuntu.com/ubuntu bionic InRelease
Hit:9 http://archive.ubuntu.com/ubuntu bionic-updates InRelease
Hit:10 http://ppa.launchpad.net/cran/libgit2/ubuntu bionic InRelease
Get:11 http://archive.ubuntu.com/ubuntu bionic-backports InRelease [74.6 kB]
Hit:13 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu bionic InRelease
Hit:15 http://ppa.lau

In [2]:
import gym
import sys
import ray
import cv2
import cma
import time
import torch
import random
import gnwrapper
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from gym import wrappers
from torchvision import transforms
from torch.distributions.normal import Normal
from torch.multiprocessing import Process, Queue
from os import mkdir, unlink, listdir, getpid
from os.path import join, exists

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [3]:
def show_image(image):
    """ Display one image in the cell output, replacing what was shown before.

    :args image: anything ``plt.imshow`` accepts, or a torch tensor. Tensors
        are detached from the autograd graph and copied to the CPU first.
    """
    if isinstance(image, torch.Tensor):
        # A bare .numpy() raises for tensors that require grad or live on a GPU.
        image = image.detach().cpu().numpy()
    clear_output(True)
    plt.imshow(image)
    plt.show()

In [4]:
def show_two_images(image1, image2):
    """ Display two images side by side, replacing the previous cell output.

    The images are concatenated along axis 1, so they must agree on every
    other axis.

    :args image1: left image, numpy array or torch tensor
    :args image2: right image, numpy array or torch tensor
    """
    if isinstance(image1, torch.Tensor):
        # .cpu() is a no-op for CPU tensors and makes CUDA tensors convertible.
        image1 = image1.detach().cpu().numpy()
    if isinstance(image2, torch.Tensor):
        image2 = image2.detach().cpu().numpy()
    image = np.concatenate((image1, image2), axis=1)
    clear_output(True)
    plt.imshow(image)
    plt.show()

In [5]:
# Create the CarRacing-v0 environment and wrap it with gnwrapper's LoopAnimation.
env = gym.make("CarRacing-v0")
env = gnwrapper.LoopAnimation(env)



# VAE rollout data

In [6]:
# Number of environment steps to record for VAE training.
rollout_length = 10000
# One 64x64x3 frame per step (channels last) and one 3-component action per step.
observations = np.zeros((rollout_length, 64, 64, 3))
actions = np.zeros((rollout_length, 3))

In [7]:
# Fill the buffers by driving the environment with randomly sampled actions.
obs = env.reset()
for i in range(rollout_length):
    action = env.action_space.sample() # Take random action
    # Store the current frame, downsized to 64x64, together with the action taken from it.
    obs = cv2.resize(obs, dsize=(64, 64), interpolation=cv2.INTER_CUBIC)
    observations[i, :, :, :] = obs
    actions[i, :] = action
    obs, rew, done, _ = env.step(action)
    # Start a new episode when the current one ends so collection continues.
    if done:
        obs = env.reset()

# Reorder to channels-first (N, 3, 64, 64) and divide by 255.
# Note: this rebinds `observations`, so re-running this cell without first
# re-running the buffer-allocation cell fails on the shape mismatch.
observations = np.transpose(observations, (0, 3, 1, 2)) / 255.
np.save('observations.npy', observations)
np.save('actions.npy', actions)

In [8]:
# Reload the arrays from the paths the rollout cell saved them to (the working
# directory, not a "data/" sub-directory) and print their shapes as a sanity check.
obs_data = np.load('observations.npy')
act_data = np.load('actions.npy')
print(obs_data.shape)
print(act_data.shape)

# V model

In [9]:
class Decoder(nn.Module):
    """ VAE decoder: maps a latent vector to an image.

    A linear layer expands the latent to 1024 features, which are treated as a
    1x1 feature map and upsampled by four stride-2 transposed convolutions
    (spatial size 1 -> 5 -> 13 -> 30 -> 64).

    :args img_channels: number of channels of the reconstructed image
    :args latent_size: dimension of the latent vector
    """
    def __init__(self, img_channels, latent_size):
        super(Decoder, self).__init__()
        self.latent_size = latent_size
        self.img_channels = img_channels

        self.fc1 = nn.Linear(latent_size, 1024)
        self.deconv1 = nn.ConvTranspose2d(1024, 128, 5, stride=2)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
        self.deconv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
        self.deconv4 = nn.ConvTranspose2d(32, img_channels, 6, stride=2)

    def forward(self, x): # pylint: disable=arguments-differ
        """ Decode latents.

        :args x: (batch, latent_size) torch tensor
        :returns: (batch, img_channels, 64, 64) torch tensor with values in [0, 1]
        """
        x = F.relu(self.fc1(x))
        # (batch, 1024) -> (batch, 1024, 1, 1) so the transposed convs can upsample it.
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = F.relu(self.deconv3(x))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        reconstruction = torch.sigmoid(self.deconv4(x))
        return reconstruction

In [10]:
class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
    """ VAE encoder: four stride-2 convolutions followed by two linear heads.

    The heads expect 2*2*256 flattened features, which is what the
    convolutions produce for 64x64 inputs.

    :args img_channels: number of channels of the input image
    :args latent_size: dimension of the latent vector
    """
    def __init__(self, img_channels, latent_size):
        super().__init__()
        self.latent_size = latent_size
        self.img_channels = img_channels

        self.conv1 = nn.Conv2d(img_channels, 32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)

        flat_features = 2 * 2 * 256
        self.fc_mu = nn.Linear(flat_features, latent_size)
        self.fc_logsigma = nn.Linear(flat_features, latent_size)

    def forward(self, x): # pylint: disable=arguments-differ
        """ Encode images.

        :args x: (batch, img_channels, 64, 64) torch tensor
        :returns: (mu, logsigma), each a (batch, latent_size) torch tensor
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.relu(conv(features))
        # Collapse channel and spatial axes into one feature axis per sample.
        flat = features.reshape(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logsigma(flat)

In [11]:
class VAE(nn.Module):
    """ Variational Autoencoder built from an Encoder and a Decoder.

    :args img_channels: number of image channels
    :args latent_size: dimension of the latent vector
    """
    def __init__(self, img_channels, latent_size):
        super().__init__()
        self.encoder = Encoder(img_channels, latent_size)
        self.decoder = Decoder(img_channels, latent_size)

    def forward(self, x): # pylint: disable=arguments-differ
        """ Encode, sample a latent, and decode it.

        :args x: batch of images accepted by the encoder
        :returns: (recon_x, mu, logsigma)
        """
        mu, logsigma = self.encoder(x)
        # Reparameterisation: z = mu + sigma * noise, with noise ~ N(0, I).
        std = torch.exp(logsigma)
        noise = torch.randn_like(std)
        z = mu + noise * std
        return self.decoder(z), mu, logsigma

In [12]:
# Reconstruction + KL divergence losses summed over all elements and batch
def vae_loss_function(recon_x, x, mu, logsigma):
    """ VAE loss function.

    :args recon_x: reconstructed batch, same shape as x
    :args x: input batch
    :args mu: latent means
    :args logsigma: log of the latent standard deviations, same shape as mu
    :returns: scalar tensor, summed squared reconstruction error plus the KL
        divergence to a standard normal (both summed, not averaged)
    """
    # The reconstruction term is a summed squared error (not a cross-entropy).
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), with log(sigma^2) = 2 * logsigma
    KLD = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
    return recon_loss + KLD

In [13]:
# Build the VAE for 3-channel images with a 32-dimensional latent.
vae = VAE(3, 32).to(device)
# NOTE: the name `optimizer` is rebound to the MDRNN optimizer further down the
# notebook; run the VAE training cells before that cell is executed.
optimizer = optim.Adam(vae.parameters())
batch_size = 128
# Load the weights from vae.pt, mapping them onto the selected device.
vae.load_state_dict(torch.load("vae.pt", map_location=device))

<All keys matched successfully>

In [14]:
def train_vae():
    """ Run one optimisation step of the VAE on a random minibatch.

    Uses the notebook globals ``vae``, ``optimizer``, ``observations``,
    ``rollout_length``, ``batch_size`` and ``device``.
    """
    vae.train()
    # Sample batch_size frame indices (with replacement) from the rollout.
    idxs = np.random.randint(0, rollout_length, size=batch_size)
    data = torch.as_tensor(observations[idxs], dtype=torch.float32).to(device)

    optimizer.zero_grad()
    recon_batch, mu, logsigma = vae(data)
    # The loss is defined as vae_loss_function; `loss_function` does not exist.
    loss = vae_loss_function(recon_batch, data, mu, logsigma)
    loss.backward()
    optimizer.step()


In [15]:
def test_vae():
    """ Print the VAE loss per frame over the whole rollout.

    Uses the notebook globals ``vae``, ``observations`` and ``device``. All
    frames are pushed through the model in a single forward pass.
    """
    vae.eval()
    with torch.no_grad():
        data = torch.as_tensor(observations, dtype=torch.float32).to(device)
        recon_batch, mu, logsigma = vae(data)
        # The loss is defined as vae_loss_function; `loss_function` does not exist.
        test_loss = vae_loss_function(recon_batch, data, mu, logsigma).item() / data.shape[0]
        print("Test loss :", test_loss)

In [16]:
# Train the VAE for 100000 minibatch steps; every 1000 steps print the step
# index and the loss over the full rollout.
for epoch in range(100000):
    train_vae()
    if epoch % 1000 == 0:
        print(epoch)
        test_vae()

In [17]:
# Save the VAE weights. This writes to the same "vae.pt" path the weights were
# loaded from, so the earlier file is overwritten.
torch.save(vae.state_dict(), "vae.pt")

In [18]:
# Show ten randomly chosen frames next to their VAE reconstructions,
# pausing one second between pairs.
data = torch.as_tensor(observations, dtype=torch.float32).to(device)
with torch.no_grad():
    output, _, _ = vae(data)
    # Back to channels-last on the CPU, the layout plt.imshow expects.
    data = data.cpu().permute(0, 2, 3, 1)
    output = output.cpu().permute(0, 2, 3, 1)
    for _ in range(10):
        # randrange excludes the upper bound; random.randint(0, n) includes n
        # and could index one past the last frame.
        i = random.randrange(data.shape[0])
        show_two_images(data[i], output[i])
        time.sleep(1)


# M model

In [19]:
class MDRNNCell(nn.Module):
    """ MDRNN model for one step forward.

    An LSTM cell consumes the concatenated action and latent; a linear layer
    maps its hidden state to the parameters of a Gaussian mixture over the
    next latent plus two extra scalars.

    :args latents: dimension of the latent vector
    :args actions: dimension of the action vector
    :args hiddens: number of LSTM hidden units
    :args gaussians: number of mixture components
    """
    def __init__(self, latents, actions, hiddens, gaussians):
        super().__init__()
        self.latents = latents
        self.actions = actions
        self.hiddens = hiddens
        self.gaussians = gaussians

        # Per component: `latents` means, `latents` log-sigmas and 1 mixture
        # logit; plus 2 trailing scalars (returned as r and d).
        self.gmm_linear = nn.Linear(hiddens, (2*latents + 1)*gaussians + 2)
        self.rnn = nn.LSTMCell(latents + actions, hiddens)

    def forward(self, action, latent, hidden): # pylint: disable=arguments-differ
        """ One step forward.

        :args action: (BSIZE, ASIZE) torch tensor
        :args latent: (BSIZE, LSIZE) torch tensor
        :args hidden: (h, c) tuple of (BSIZE, RSIZE) torch tensors
        :returns: mus (BSIZE, N_GAUSS, LSIZE), sigmas (BSIZE, N_GAUSS, LSIZE),
            logpi (BSIZE, N_GAUSS), r (BSIZE,), d (BSIZE,) and the next
            (h, c) hidden state
        """
        in_al = torch.cat([action, latent], dim=1)

        next_hidden = self.rnn(in_al, hidden)
        out_rnn = next_hidden[0]

        out_full = self.gmm_linear(out_rnn)

        stride = self.gaussians * self.latents

        mus = out_full[:, :stride]
        mus = mus.view(-1, self.gaussians, self.latents)

        # The network outputs log-sigmas; exponentiate to get positive sigmas.
        sigmas = out_full[:, stride:2 * stride]
        sigmas = sigmas.view(-1, self.gaussians, self.latents)
        sigmas = torch.exp(sigmas)

        pi = out_full[:, 2 * stride:2 * stride + self.gaussians]
        pi = pi.view(-1, self.gaussians)
        # torch.nn.functional is imported as `F`; a lowercase `f` is undefined here.
        logpi = F.log_softmax(pi, dim=-1)

        r = out_full[:, -2]

        d = out_full[:, -1]

        return mus, sigmas, logpi, r, d, next_hidden

In [20]:
class MDRNN(nn.Module):
    """ MDRNN model applied to whole sequences.

    An LSTM reads the concatenated actions and latents; a linear layer turns
    each output step into Gaussian-mixture parameters plus two extra scalars.

    :args latents: dimension of the latent vector
    :args actions: dimension of the action vector
    :args hiddens: number of LSTM hidden units
    :args gaussians: number of mixture components
    """
    def __init__(self, latents, actions, hiddens, gaussians):
        super().__init__()
        self.latents = latents
        self.actions = actions
        self.hiddens = hiddens
        self.gaussians = gaussians

        self.gmm_linear = nn.Linear(hiddens, (2*latents + 1)*gaussians + 2)
        self.rnn = nn.LSTM(latents + actions, hiddens)

    def forward(self, actions, latents):
        """ Multi-step forward.

        :args actions: (SEQ_LEN, BSIZE, ASIZE) torch tensor
        :args latents: (SEQ_LEN, BSIZE, LSIZE) torch tensor
        :returns: mus and sigmas of shape (SEQ_LEN, BSIZE, N_GAUSS, LSIZE),
            logpi of shape (SEQ_LEN, BSIZE, N_GAUSS), and rs, ds of shape
            (SEQ_LEN, BSIZE)
        """
        n_steps, n_batch = actions.shape[:2]

        rnn_in = torch.cat((actions, latents), dim=-1)
        rnn_out, _ = self.rnn(rnn_in)
        params = self.gmm_linear(rnn_out)

        # Layout of the last axis: [means | log-sigmas | mixture logits | r | d].
        n_mix = self.gaussians * self.latents
        mix_shape = (n_steps, n_batch, self.gaussians, self.latents)

        mus = params[..., :n_mix].view(*mix_shape)
        sigmas = params[..., n_mix:2 * n_mix].view(*mix_shape).exp()

        logits = params[..., 2 * n_mix:2 * n_mix + self.gaussians]
        logpi = F.log_softmax(logits.view(n_steps, n_batch, self.gaussians), dim=-1)

        rs = params[..., -2]
        ds = params[..., -1]

        return mus, sigmas, logpi, rs, ds

In [21]:
@ray.remote
def mdrnn_rollout(n_steps=1000):
    """ Collect one random-action CarRacing rollout (runs as a Ray task).

    The environment is not reset when ``done`` is reported, so every step up
    to ``n_steps`` is recorded from the same ``reset()`` call.

    :args n_steps: number of environment steps to record (default 1000)
    :returns: tuple (o, a, r, d, o2) of numpy arrays
        o:  (n_steps, 64, 64, 3) frames before each step, divided by 255
        a:  (n_steps, 3) sampled actions
        r:  (n_steps,) rewards
        d:  (n_steps,) done flags
        o2: (n_steps, 64, 64, 3) frames after each step, divided by 255
    """
    o = np.zeros((n_steps, 64, 64, 3))
    a = np.zeros((n_steps, 3))
    r = np.zeros((n_steps, ))
    d = np.zeros((n_steps, ))
    o2 = np.zeros((n_steps, 64, 64, 3))

    env = gym.make("CarRacing-v0")
    env = gnwrapper.LoopAnimation(env)

    try:
        # Resize the first frame once; later frames are resized as next_obs.
        obs = cv2.resize(env.reset(), dsize=(64, 64), interpolation=cv2.INTER_CUBIC)
        for i in range(n_steps):
            action = env.action_space.sample() # Take random action
            next_obs, rew, done, _ = env.step(action)
            next_obs = cv2.resize(next_obs, dsize=(64, 64), interpolation=cv2.INTER_CUBIC)

            o[i, :, :, :] = obs / 255.
            a[i, :] = action
            r[i] = rew
            d[i] = done
            o2[i, :, :, :] = next_obs / 255.

            obs = next_obs
    finally:
        # Release the environment even if a step raises, so tasks do not leak it.
        env.close()

    return o, a, r, d, o2

In [22]:
# Start the local Ray runtime used by the rollout tasks. ignore_reinit_error
# keeps the cell re-runnable: a second ray.init() call is otherwise an error.
ray.init(ignore_reinit_error=True)

2021-05-11 08:20:21,467	INFO services.py:1269 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'metrics_export_port': 65105,
 'node_id': '4ad3d15f5d06f2510f571d1f42eece9849f234777b4254b3f8943d39',
 'node_ip_address': '172.28.0.2',
 'object_store_address': '/tmp/ray/session_2021-05-11_08-20-19_568358_4254/sockets/plasma_store',
 'raylet_ip_address': '172.28.0.2',
 'raylet_socket_name': '/tmp/ray/session_2021-05-11_08-20-19_568358_4254/sockets/raylet',
 'redis_address': '172.28.0.2:6379',
 'session_dir': '/tmp/ray/session_2021-05-11_08-20-19_568358_4254',
 'webui_url': '127.0.0.1:8265'}

In [23]:
# Build the MDRNN (32-d latents, 3-d actions, 256 hidden units, 5 mixture
# components) and load its weights from mdrnn.pt.
mdrnn = MDRNN(latents=32, actions=3, hiddens=256, gaussians=5)
mdrnn.to(device)
mdrnn.load_state_dict(torch.load("mdrnn.pt", map_location=device))
# NOTE: this rebinds the global `optimizer` that the VAE training cells use;
# after this cell runs, train_vae() would step the MDRNN optimizer instead.
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)

In [24]:
def gmm_loss(batch, mus, sigmas, logpi, reduce=True): # pylint: disable=too-many-arguments
    """ Computes the gmm loss.
    Compute minus the log probability of batch under the GMM model described
    by mus, sigmas, pi. Precisely, with bs1, bs2, ... the sizes of the batch
    dimensions (several batch dimensions are useful when you have both a batch
    axis and a time step axis), gs the number of mixtures and fs the number of
    features.
    :args batch: (bs1, bs2, *, fs) torch tensor
    :args mus: (bs1, bs2, *, gs, fs) torch tensor
    :args sigmas: (bs1, bs2, *, gs, fs) torch tensor
    :args logpi: (bs1, bs2, *, gs) torch tensor
    :args reduce: if not reduce, the mean in the following formula is omitted
        and a (bs1, bs2, *) tensor is returned
    :returns:
    loss(batch) = - mean_{i1=0..bs1, i2=0..bs2, ...} log(
        sum_{k=1..gs} pi[i1, i2, ..., k] * N(
            batch[i1, i2, ..., :] | mus[i1, i2, ..., k, :], sigmas[i1, i2, ..., k, :]))
    NOTE: The loss is not reduced along the feature dimension (i.e. it should scale ~linearly
    with fs).
    """
    # Insert a mixture axis so batch broadcasts against (..., gs, fs).
    batch = batch.unsqueeze(-2)
    normal_dist = Normal(mus, sigmas)
    # Per-component log joint: log pi_k + sum over features of log N(x | mu_k, sigma_k).
    g_log_probs = logpi + torch.sum(normal_dist.log_prob(batch), dim=-1)

    # Numerically stable log-sum-exp over the mixture axis only. Reducing just
    # that axis keeps size-1 batch axes intact; squeezing every singleton axis
    # would make the result broadcast to the wrong shape when a batch axis is 1.
    log_prob = torch.logsumexp(g_log_probs, dim=-1)
    if reduce:
        return - torch.mean(log_prob)
    return - log_prob

In [25]:
def to_latent(obs, next_obs):
    """ Encode frames into sampled VAE latents.

    Uses the notebook global ``vae``. Only the encoder is evaluated, and it
    runs under ``torch.no_grad()``: the returned latents carry no autograd
    graph, so nothing is back-propagated into the VAE.

    :args obs: (N, 64, 64, 3) torch tensor, channels last
    :args next_obs: (N, 64, 64, 3) torch tensor, channels last
    :returns: (latent_obs, latent_next_obs), each (N, latent_size)
    """
    # The encoder expects channels-first input.
    obs = obs.permute(0, 3, 1, 2)
    next_obs = next_obs.permute(0, 3, 1, 2)

    with torch.no_grad():
        obs_mu, obs_logsigma = vae.encoder(obs)
        next_obs_mu, next_obs_logsigma = vae.encoder(next_obs)

        # Sample z = mu + sigma * eps with eps ~ N(0, I).
        latent_obs = obs_mu + obs_logsigma.exp() * torch.randn_like(obs_mu)
        latent_next_obs = next_obs_mu + next_obs_logsigma.exp() * torch.randn_like(next_obs_mu)

    return latent_obs, latent_next_obs

In [26]:
def get_loss(latent_obs, action, reward, terminal, latent_next_obs):
    """ Compute the MDRNN training loss for a batch of sequences.

    Uses the notebook global ``mdrnn``. The loss is the sum of the GMM
    negative log-likelihood of the next latent, the binary cross-entropy of
    the terminal logits and the mean squared error of the reward, divided by
    LSIZE + 2.

    :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
    :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
    :args reward: (BSIZE, SEQ_LEN) torch tensor
    :args terminal: (BSIZE, SEQ_LEN) torch tensor
    :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
    :returns: scalar torch tensor
    """
    # The MDRNN is sequence-first, so swap the batch and time axes.
    latent_obs, action,\
        reward, terminal,\
        latent_next_obs = [arr.transpose(1, 0)
                            for arr in [latent_obs, action,
                                        reward, terminal,
                                        latent_next_obs]]
    mus, sigmas, logpi, rs, ds = mdrnn(action, latent_obs)
    gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
    bce = F.binary_cross_entropy_with_logits(ds, terminal)
    mse = F.mse_loss(rs, reward)
    # One unit per latent feature (the GMM term is not reduced over features)
    # plus one each for the terminal and reward terms; equals the previously
    # hard-coded 34 for 32-dimensional latents.
    scale = latent_obs.shape[-1] + 2
    loss = (gmm + bce + mse) / scale
    return loss


In [27]:
def train(n_workers=4):
    """ Collect rollouts in parallel and take one MDRNN step per rollout.

    Uses the notebook globals ``mdrnn``, ``optimizer`` and ``device``. After
    all rollouts are processed the mean loss is printed and the weights are
    saved to "mdrnn.pt".

    :args n_workers: number of Ray rollout tasks to launch. The notebook never
        defined this name, so the default is a placeholder — TODO confirm the
        intended value.
    :raises ValueError: if n_workers is smaller than 1
    """
    if n_workers < 1:
        raise ValueError("n_workers must be at least 1")

    mdrnn.train()
    rollout_ops = [mdrnn_rollout.remote() for _ in range(n_workers)]
    res = ray.get(rollout_ops)
    loss_sum = 0.
    for o, a, r, d, o2 in res:
        o = torch.as_tensor(o, dtype=torch.float32).to(device)
        a = torch.as_tensor(a, dtype=torch.float32).to(device)
        r = torch.as_tensor(r, dtype=torch.float32).to(device)
        d = torch.as_tensor(d, dtype=torch.float32).to(device)
        o2 = torch.as_tensor(o2, dtype=torch.float32).to(device)

        # With the default rollout length:
        # o: (1000, 64, 64, 3)
        # a: (1000, 3)
        # r: (1000, )
        # d: (1000, )
        # o2: (1000, 64, 64, 3)

        latent_obs, latent_next_obs = to_latent(o, o2)
        # latent_obs: (1000, 32)
        # latent_next_obs: (1000, 32)

        # Add a leading batch axis of size 1: each rollout is one sequence.
        latent_obs = latent_obs.unsqueeze(0)
        a = a.unsqueeze(0)
        r = r.unsqueeze(0)
        d = d.unsqueeze(0)
        latent_next_obs = latent_next_obs.unsqueeze(0)

        loss = get_loss(latent_obs, a, r, d, latent_next_obs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
    print()
    print("="*60)
    print(loss_sum/n_workers)
    print("="*60)
    print()

    torch.save(mdrnn.state_dict(), "mdrnn.pt")


In [28]:
# Run MDRNN training iterations; each train() call collects fresh rollouts,
# updates the model and saves the weights. The iteration count is large
# enough that the cell is normally stopped by interrupting the kernel.
for _ in range(10000000):
    train()