In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  
import matplotlib.pyplot as plt
from class_snapbot import Snapbot4EnvClass
from class_dlpg import DeepLatentPolicyGradientClass

# ---- Configuration ---------------------------------------------------------
# Single source of truth for the device: use the first GPU when one is
# available, otherwise fall back to the CPU so the cell still runs on machines
# without CUDA.
DEVICE      = 'cuda:0' if torch.cuda.is_available() else 'cpu'
C_DIM       = 3                                             # condition dimension
WEIGHT_PATH = "dlpg/14/weights/dlpg_model_weights_200.pth"  # checkpoint to load

# ---- Environment and pretrained prior ---------------------------------------
env        = Snapbot4EnvClass(VERBOSE=True, condition=False, render_mode=None)
PriorDLPG  = DeepLatentPolicyGradientClass(
                                        name     = 'PriorDLPG',
                                        x_dim    = 160,              # input dimension
                                        c_dim    = C_DIM,            # condition dimension
                                        z_dim    = 32,               # latent dimension
                                        h_dims   = [128, 128],       # hidden dimensions of encoder (and decoder)
                                        actv_enc = nn.ReLU(),        # encoder activation
                                        actv_dec = nn.ReLU(),        # decoder activation
                                        actv_q   = nn.Softplus(),    # q activation
                                        actv_out = None,             # output activation
                                        var_max  = -1,               # maximum variance
                                        device   = DEVICE
                                        )
PriorDLPG.to(PriorDLPG.device)
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
PriorDLPG.load_state_dict(torch.load(WEIGHT_PATH, map_location=DEVICE))

# ---- Build the training set by sampling the prior ---------------------------
n_sample = 1000
# One-hot condition with index 1 set for every sample.
c = torch.zeros(size=(n_sample, C_DIM), device=PriorDLPG.device)
c[:, 1] = 1

# The samples are used purely as data, so no autograd graph is needed.
with torch.no_grad():
    x_train = PriorDLPG.sample_x(
                                c             = c,
                                n_sample      = n_sample,
                                SKIP_Z_SAMPLE = True
                                ).reshape(n_sample, 8, -1)

# Keep every 4th entry along the last axis, then flatten each sample again
# (160 -> 8 x 20 -> 8 x 5 -> 40 values, matching the shape printed below).
x_train = x_train[:, :, ::4].detach()
x_train = x_train.reshape(n_sample, -1)
print("x_train shape: {}".format(x_train.shape))

Snapbot(4legs) Environment
Obs Dim: [103] Act Dim: [8] dt:[0.02] Condition:[False]
ctrl_coef:[0] body_coef:[0] jump_coef:[0] vel_coef:[0] head_coef:[0]
x_train shape: torch.Size([1000, 40])


In [2]:
import random
class ConditionalVariationalAutoEncoder(nn.Module):
    """
        Fully-connected conditional variational autoencoder (CVAE).

        The encoder maps the concatenation [x, c] to the mean and variance of a
        diagonal Gaussian over z; the decoder maps [z, c] back to x. Passing
        c=None skips the concatenation, which is how the model is used as a
        plain VAE (together with c_dim=0).

        Layers are kept in the plain dict `self.layers`; their weights and
        biases are registered through `self.cvae_parameters`, so state_dict
        keys have the form 'cvae_parameters.<layer>_w' / '..._b'.
    """
    def __init__(
        self,
        name     = 'DLPG',              
        x_dim    = 784,              # input dimension
        c_dim    = 10,               # condition dimension
        z_dim    = 16,               # latent dimension
        h_dims   = [64,32],          # hidden dimensions of encoder (and decoder)
        actv_enc = nn.ReLU(),        # encoder activation
        actv_dec = nn.ReLU(),        # decoder activation
        actv_out = None,             # output activation
        var_max  = None,             # maximum variance
        device   = 'cpu'
        ):
        """
            Initialize

            Builds all layers and then calls init_params(), which re-seeds the
            global `random`, NumPy and torch generators with seed 0.
            var_max=None makes the latent variance exp(.) (unbounded);
            otherwise it is var_max*sigmoid(.).
        """
        super(ConditionalVariationalAutoEncoder,self).__init__()
        self.name = name
        self.x_dim    = x_dim
        self.c_dim    = c_dim
        self.z_dim    = z_dim
        # Copy so the shared default list (or the caller's list) is never aliased
        self.h_dims   = list(h_dims)
        self.actv_enc = actv_enc
        self.actv_dec = actv_dec
        self.actv_out = actv_out
        self.var_max  = var_max
        self.device   = device
        # Initialize layers
        self.init_layers()
        self.init_params()
                
    def init_layers(self):
        """
            Initialize layers

            Fills self.layers (insertion order: encoder, z heads, decoder,
            output) and registers every Linear weight/bias in
            self.cvae_parameters.
        """
        self.layers = {}
        
        # Encoder part
        h_dim_prev = self.x_dim + self.c_dim
        for h_idx,h_dim in enumerate(self.h_dims):
            self.layers['enc_%02d_lin'%(h_idx)]  = \
                nn.Linear(h_dim_prev,h_dim,bias=True)
            self.layers['enc_%02d_actv'%(h_idx)] = \
                self.actv_enc
            h_dim_prev = h_dim
        self.layers['z_mu_lin']  = nn.Linear(h_dim_prev,self.z_dim,bias=True)
        self.layers['z_var_lin'] = nn.Linear(h_dim_prev,self.z_dim,bias=True)
        
        # Decoder part (hidden sizes mirrored)
        h_dim_prev = self.z_dim + self.c_dim
        for h_idx,h_dim in enumerate(self.h_dims[::-1]):
            self.layers['dec_%02d_lin'%(h_idx)]  = \
                nn.Linear(h_dim_prev,h_dim,bias=True)
            self.layers['dec_%02d_actv'%(h_idx)] = \
                self.actv_dec
            h_dim_prev = h_dim
        self.layers['out_lin'] = nn.Linear(h_dim_prev,self.x_dim,bias=True)
        
        # Append parameters: self.layers is a plain dict, so the Linear layers
        # are not registered as sub-modules; expose their tensors explicitly.
        self.param_dict = {}
        for key,layer in self.layers.items():
            if isinstance(layer,nn.Linear):
                self.param_dict[key+'_w'] = layer.weight
                self.param_dict[key+'_b'] = layer.bias
        self.cvae_parameters = nn.ParameterDict(self.param_dict)

    def _encoder_trunk(self,x,c,hidden_out=None):
        """
            Shared encoder body: [x,c] (or x when c is None) through every
            hidden linear+activation pair. Returns the last hidden output.
            If `hidden_out` is a list, a detached copy of each activation
            output is appended to it.
        """
        net = torch.cat((x,c),dim=1) if c is not None else x
        for h_idx in range(len(self.h_dims)):
            net = self.layers['enc_%02d_lin'%(h_idx)](net)
            net = self.layers['enc_%02d_actv'%(h_idx)](net)
            if hidden_out is not None:
                hidden_out.append(net.detach())
        return net
        
    def xc_to_z_mu(
        self,
        x = torch.randn(2,784),
        c = torch.randn(2,10)
        ):
        """
            x and c to z_mu
        """
        return self.layers['z_mu_lin'](self._encoder_trunk(x=x,c=c))
    
    def xc_to_h(
        self,
        x = torch.randn(2,784),
        c = torch.randn(2,10)
        ):
        """
            x and c to the list of encoder hidden activations
            (one detached tensor per entry of h_dims, taken after the
            activation function)
        """
        hidden_value = []
        self._encoder_trunk(x=x,c=c,hidden_out=hidden_value)
        return hidden_value

    def xc_to_z_var(
        self,
        x = torch.randn(2,784),
        c = torch.randn(2,10)
        ):
        """
            x and c to z_var
            (exp(.) when var_max is None, otherwise var_max*sigmoid(.))
        """
        net = self.layers['z_var_lin'](self._encoder_trunk(x=x,c=c))
        if self.var_max is None:
            return torch.exp(net)
        return self.var_max*torch.sigmoid(net)
    
    def zc_to_x_recon(
        self,
        z = torch.randn(2,16),
        c = torch.randn(2,10)
        ):
        """
            z and c to x_recon
        """
        net = torch.cat((z,c),dim=1) if c is not None else z
        for h_idx in range(len(self.h_dims)):
            net = self.layers['dec_%02d_lin'%(h_idx)](net)
            net = self.layers['dec_%02d_actv'%(h_idx)](net)
        net = self.layers['out_lin'](net)
        if self.actv_out is not None:
            net = self.actv_out(net)
        return net
    
    def xc_to_z_sample(
        self,
        x = torch.randn(2,784),
        c = torch.randn(2,10)
        ):
        """
            x and c to z_sample
            (reparameterization: z_mu + sqrt(z_var)*eps with eps ~ N(0,I))
        """
        z_mu,z_var = self.xc_to_z_mu(x=x,c=c),self.xc_to_z_var(x=x,c=c)
        # Put the noise on the device z_mu actually lives on, so sampling also
        # works when the module was moved without updating self.device.
        eps_sample = torch.randn(
            size=z_mu.shape,dtype=torch.float32).to(z_mu.device)
        # 1e-10 keeps sqrt (and its gradient) finite when z_var is zero
        z_sample   = z_mu + torch.sqrt(z_var+1e-10)*eps_sample
        return z_sample
    
    def xc_to_x_recon(
        self,
        x             = torch.randn(2,784),
        c             = torch.randn(2,10), 
        STOCHASTICITY = True
        ):
        """
            x and c to x_recon
            (STOCHASTICITY=True decodes a sampled z, False decodes z_mu)
        """
        if STOCHASTICITY:
            z_sample = self.xc_to_z_sample(x=x,c=c)
        else:
            z_sample = self.xc_to_z_mu(x=x,c=c)
        return self.zc_to_x_recon(z=z_sample,c=c)
    
    def sample_x(
        self,
        c             = torch.randn(5,10),
        n_sample      = 5,
        SKIP_Z_SAMPLE = False
        ):
        """
            Sample x
            Draws z ~ N(0,I) of shape (n_sample,z_dim) and decodes it with c.
            Returns x only when SKIP_Z_SAMPLE is True, else (x, z_sample).
        """
        z_sample = torch.randn(
            size=(n_sample,self.z_dim),dtype=torch.float32).to(self.device)
        x_sample = self.zc_to_x_recon(z=z_sample,c=c)
        if SKIP_Z_SAMPLE:
            return x_sample
        return x_sample,z_sample
    
    def init_params(self,seed=0):
        """
            Initialize parameters
            Side effect: re-seeds the global `random`, NumPy and torch RNGs.
        """
        # Fix random seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Init
        for layer in self.layers.values():
            if isinstance(layer,nn.Linear):
                nn.init.normal_(layer.weight,mean=0.0,std=0.01)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer,nn.BatchNorm2d):
                nn.init.constant_(layer.weight,1.0)
                nn.init.constant_(layer.bias,0.0)
            elif isinstance(layer,nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                nn.init.zeros_(layer.bias)
    
    def loss_recon(
        self,
        x               = torch.randn(2,784),
        c               = torch.randn(2,10),
        LOSS_TYPE       = 'L1+L2',
        recon_loss_gain = 1.0,
        STOCHASTICITY   = True
        ):
        """
            Recon loss
            LOSS_TYPE: 'L1'/'MAE', 'L2'/'MSE' or 'L1+L2'/'EN' (mean of both).
            Errors are averaged over features (dim 1), then over the batch.
            Raises ValueError for any other LOSS_TYPE.
        """
        x_recon = self.xc_to_x_recon(x=x,c=c,STOCHASTICITY=STOCHASTICITY)
        diff    = x-x_recon
        if LOSS_TYPE in ('L1','MAE'):
            errs = torch.mean(torch.abs(diff),dim=1)
        elif LOSS_TYPE in ('L2','MSE'):
            errs = torch.mean(torch.square(diff),dim=1)
        elif LOSS_TYPE in ('L1+L2','EN'):
            errs = torch.mean(
                0.5*(torch.abs(diff)+torch.square(diff)),dim=1)
        else:
            # ValueError is still an Exception, so existing handlers keep working
            raise ValueError("VAE:[%s] Unknown loss_type:[%s]"%
                             (self.name,LOSS_TYPE))
        return recon_loss_gain*torch.mean(errs)
    
    def loss_kl(
        self,
        x = torch.randn(2,784),
        c = torch.randn(2,10)
        ):
        """
            KLD loss
            KL( N(z_mu, z_var) || N(0,I) ), summed over latent dimensions and
            averaged over the batch.
        """
        z_mu     = self.xc_to_z_mu(x=x,c=c)
        z_var    = self.xc_to_z_var(x=x,c=c)
        # z_var can underflow to exactly 0 (exp / sigmoid of a very negative
        # value); clamp before the log so the loss stays finite. Same floor as
        # the 1e-10 used in xc_to_z_sample.
        z_logvar = torch.log(torch.clamp(z_var,min=1e-10))
        errs     = 0.5*torch.sum(z_var + z_mu**2 - 1.0 - z_logvar,dim=1)
        return torch.mean(errs)
        
    def loss_total(
        self,
        x               = torch.randn(2,784),
        c               = torch.randn(2,10),
        LOSS_TYPE       = 'L1+L2',
        recon_loss_gain = 1.0,
        STOCHASTICITY   = True,
        beta            = 1.0
        ):
        """
            Total loss
            Returns (recon + beta*KL, info); info['loss_kl_out'] already
            includes the beta factor.
        """
        loss_recon_out = self.loss_recon(
            x               = x,
            c               = c,
            LOSS_TYPE       = LOSS_TYPE,
            recon_loss_gain = recon_loss_gain,
            STOCHASTICITY   = STOCHASTICITY
        )
        loss_kl_out    = beta*self.loss_kl(
            x = x,
            c = c
        )
        loss_total_out = loss_recon_out + loss_kl_out
        info           = {'loss_recon_out' : loss_recon_out,
                          'loss_kl_out'    : loss_kl_out,
                          'loss_total_out' : loss_total_out,
                          'beta'           : beta}
        return loss_total_out,info

    def update(
        self,
        x  = torch.randn(2,784),
        c  = torch.randn(2,10),
        lr = 0.001,
        max_iter   = 2400,
        batch_size = 100,
        print_every = 10
        ):
        """
            Train with Adam on random mini-batches using the 'L2' total loss.
            Each iteration draws up to `batch_size` distinct rows of x (and c,
            unless c is None) via np.random and moves them to self.device.
            Progress is printed every `print_every` iterations; pass None or 0
            to silence it. Returns the list of per-iteration total losses.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        n_x       = x.shape[0]
        loss_list = []
        # Nothing inside the loop changes the mode, so set it once
        self.train()
        for n_iter in range(max_iter):
            rand_idx = np.random.permutation(n_x)[:batch_size]
            x_batch  = x[rand_idx, :].to(self.device)
            if c is not None:
                c_batch = c[rand_idx, :].to(self.device)
            else:
                c_batch = None
            total_loss, _ = self.loss_total(x=x_batch, c=c_batch, LOSS_TYPE='L2')
            loss_list.append(total_loss.item())
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            if print_every and (n_iter+1) % print_every == 0:
                print("{}/{} Clear, Loss: {}".format(n_iter+1, max_iter, total_loss.item()))
        
        return loss_list

In [3]:
# Unconditional VAE (c_dim=0, trained with c=None) that compresses each
# 40-dim row of x_train into a 2-D latent code.
#
# Reading the loss log: with var_max=0.1 the encoder variance is
# var_max*sigmoid(.) < 0.1, so the KL term (beta=1 by default) cannot drop
# below z_dim * 0.5*(0.1 - 1 - ln 0.1) = 2 * 0.7013 ~= 1.4026. A total loss
# sitting at ~1.4026 therefore means the reconstruction error is close to zero.
VAE = ConditionalVariationalAutoEncoder(
    name='CVAE',
    x_dim=40,            # input dimension
    c_dim=0,             # condition dimension
    z_dim=2,             # latent dimension
    h_dims=[32, 32],     # hidden dimensions of encoder (and decoder)
    actv_enc=nn.ReLU(),  # encoder activation
    actv_dec=nn.ReLU(),  # decoder activation
    actv_out=None,       # output activation
    var_max=0.1,         # maximum variance
    device='cpu',
)
# `loss` is the list of per-iteration total losses returned by update().
loss = VAE.update(
    x=x_train,
    c=None,
    lr=0.01,
    max_iter=500,
    batch_size=200,
)

10/500 Clear, Loss: 1.7095259428024292
20/500 Clear, Loss: 1.4097232818603516
30/500 Clear, Loss: 1.406043529510498
40/500 Clear, Loss: 1.404691219329834
50/500 Clear, Loss: 1.4027942419052124
60/500 Clear, Loss: 1.402838110923767
70/500 Clear, Loss: 1.402674674987793
80/500 Clear, Loss: 1.402874231338501
90/500 Clear, Loss: 1.4036569595336914
100/500 Clear, Loss: 1.4037344455718994
110/500 Clear, Loss: 1.4028949737548828
120/500 Clear, Loss: 1.4027738571166992
130/500 Clear, Loss: 1.4026340246200562
140/500 Clear, Loss: 1.402604103088379
150/500 Clear, Loss: 1.4056934118270874
160/500 Clear, Loss: 1.4037137031555176
170/500 Clear, Loss: 1.4028632640838623
180/500 Clear, Loss: 1.4026100635528564
190/500 Clear, Loss: 1.4026598930358887
200/500 Clear, Loss: 1.4026132822036743
210/500 Clear, Loss: 1.4025967121124268
220/500 Clear, Loss: 1.4028358459472656
230/500 Clear, Loss: 1.402712106704712
240/500 Clear, Loss: 1.402601957321167
250/500 Clear, Loss: 1.4026130437850952
260/500 Clear, Lo

In [4]:
# Reconstruct one training row through the VAE using a sampled latent code.
# x_train is stored on the GPU (see its repr further down) while VAE was built
# with device='cpu', hence the explicit .cpu().
VAE.xc_to_x_recon(x=x_train[110].cpu(), c=None, STOCHASTICITY=True)

tensor([-0.3186, -0.0230, -0.3659, -0.0636, -0.3244, -0.0975, -0.1925, -0.2511,
        -0.1959,  0.0181, -0.5489,  0.4296, -0.9305,  0.5267, -0.9182,  0.3270,
        -0.6384, -0.2574, -0.4520, -0.5038, -0.3987, -0.5641, -0.5327, -0.1845,
        -0.4308,  0.2725, -0.2853,  0.0968, -0.3812, -0.1691, -0.4881, -0.1950,
        -0.3558,  0.0300, -0.1175,  0.1186, -0.2124,  0.0561, -0.3226, -0.0352],
       grad_fn=<AddBackward0>)

In [24]:
# Display one raw training row for comparison with the reconstruction above.
# NOTE(review): this cell shows row 721 while the reconstruction cell encodes
# row 110, and its execution count (In [24]) is out of sequence — confirm the
# intended index and re-run the notebook top to bottom.
x_train[721]

tensor([-0.3184, -0.0233, -0.3655, -0.0642, -0.3229, -0.0995, -0.1925, -0.2505,
        -0.1967,  0.0174, -0.5478,  0.4247, -0.9259,  0.5211, -0.9129,  0.3253,
        -0.6349, -0.2561, -0.4483, -0.5023, -0.3952, -0.5650, -0.5300, -0.1869,
        -0.4309,  0.2683, -0.2869,  0.0921, -0.3818, -0.1703, -0.4887, -0.1936,
        -0.3573,  0.0313, -0.1180,  0.1168, -0.2125,  0.0555, -0.3228, -0.0354],
       device='cuda:0')