# In [124]:
import torch
from torch.nn import functional as F

from model.modules.networks import MLPNetwork


class E3B:
    """Episodic elliptical exploration bonus with a learned state encoder.

    ``encoder`` embeds a state into ``enc_dim`` features. It is trained (see
    ``train``) through an inverse-dynamics model ``inv_dyn`` that scores, from
    two consecutive embeddings, which discrete action was taken between them.
    Within an episode, ``get_reward`` returns ``phi^T C^{-1} phi`` for the
    current embedding ``phi`` and then folds ``phi`` into ``C^{-1}`` with a
    rank-one Sherman-Morrison update; ``init_new_episode`` resets ``C^{-1}``
    to ``I / ridge``.
    """

    def __init__(self, input_dim, act_dim, enc_dim, hidden_dim=64, ridge=0.1):
        """
        Inputs:
            input_dim (int): Size of the state vector fed to the encoder.
            act_dim (int): Number of discrete actions scored by inv_dyn.
            enc_dim (int): Size of the state embedding.
            hidden_dim (int): Hidden size passed to both MLPNetworks.
            ridge (float): Ridge term; the inverse covariance starts at
                I / ridge.
        """
        self.enc_dim = enc_dim
        self.ridge = ridge
        # State encoder
        self.encoder = MLPNetwork(input_dim, enc_dim, hidden_dim, norm_in=False)
        # Inverse dynamics model: [phi(s_t), phi(s_{t+1})] -> action logits
        self.inv_dyn = MLPNetwork(2 * enc_dim, act_dim, hidden_dim, norm_in=False)
        # Inverse covariance matrix C^{-1} and a scratch buffer for u u^T
        self.inv_cov = torch.eye(enc_dim) * (1.0 / self.ridge)
        self.outer_product_buffer = torch.empty(enc_dim, enc_dim)
        # Device holding the covariance tensors; set_train / set_eval move
        # them (and the networks) and update this attribute.
        self.device = self.inv_cov.device

        # Optimizers
        self.encoder_optim = torch.optim.Adam(
            self.encoder.parameters(),
            lr=1e-4)
        self.inv_dyn_optim = torch.optim.Adam(
            self.inv_dyn.parameters(),
            lr=1e-4)

    def init_new_episode(self):
        """Reset the episodic inverse covariance to I / ridge on self.device."""
        # Created directly on self.device so that a reset after
        # set_train/set_eval does not silently fall back to the CPU.
        self.inv_cov = torch.eye(self.enc_dim, device=self.device) * (1.0 / self.ridge)

    def _cov_to_device(self):
        """Move C^{-1} and its scratch buffer to self.device."""
        # get_reward multiplies these with the encoder output, so they must
        # live on the same device as the encoder.
        self.inv_cov = self.inv_cov.to(self.device)
        self.outer_product_buffer = self.outer_product_buffer.to(self.device)

    def set_train(self, device):
        """Put both networks in train mode and move everything to `device`."""
        self.device = device
        self.encoder.train()
        self.encoder = self.encoder.to(device)
        self.inv_dyn.train()
        self.inv_dyn = self.inv_dyn.to(device)
        self._cov_to_device()

    def set_eval(self, device):
        """Put both networks in eval mode and move everything to `device`."""
        self.device = device
        self.encoder.eval()
        self.encoder = self.encoder.to(device)
        # Mirrors set_train so the two networks never end up on different
        # devices after switching modes.
        self.inv_dyn.eval()
        self.inv_dyn = self.inv_dyn.to(device)
        self._cov_to_device()

    def get_reward(self, state):
        """
        Compute the episodic bonus for one state and update C^{-1}.

        Inputs:
            state (torch.Tensor): dim=(1, state_dim)
        Outputs:
            int_reward (float): phi^T C^{-1} phi, using C^{-1} as it was
                *before* this call's update.
        """
        # The encoder is not trained here, so skip building an autograd graph.
        with torch.no_grad():
            # reshape(-1) instead of squeeze(): keeps a 1-D vector even when
            # enc_dim == 1 (squeeze() would produce a 0-d tensor).
            enc_state = self.encoder(state).reshape(-1)
        # Compute the intrinsic reward
        u = torch.mv(self.inv_cov, enc_state)
        int_reward = torch.dot(enc_state, u).item()
        # Sherman-Morrison rank-one update (C^{-1} is symmetric, so
        # u = C^{-1} phi serves as both left and right factor):
        #   C^{-1} <- C^{-1} - u u^T / (1 + phi^T C^{-1} phi)
        torch.outer(u, u, out=self.outer_product_buffer)
        self.inv_cov.add_(
            self.outer_product_buffer, alpha=-(1. / (1. + int_reward)))
        return int_reward

    def train(self, state_batch, act_batch):
        """
        Run one gradient step on the encoder and the inverse-dynamics model.

        Agents are merged into the batch axis: state_batch is split along
        dim 0 and the pieces are concatenated along their batch axis.

        Inputs:
            state_batch (torch.Tensor): dim=(nb_agents, T + 1, batch_size,
                state_dim); consecutive time steps form the transitions.
            act_batch (torch.Tensor): dim=(nb_agents, T, batch_size, act_dim);
                the argmax over the last axis is the target action index.
        Outputs:
            loss (torch.Tensor): scalar loss (cross-entropy averaged over
                agents * batch, then summed over the T time steps).
        Raises:
            ValueError: if state_batch is not 4-D.
        """
        if state_batch.dim() != 4:
            raise ValueError(
                "state_batch must be 4-D (nb_agents, T + 1, batch_size, "
                f"state_dim), got shape {tuple(state_batch.shape)}")
        # (nb_agents, L, B, D) -> (L, nb_agents * B, D)
        state_batch = torch.cat(tuple(state_batch), dim=1)
        act_batch = torch.cat(tuple(act_batch), dim=1)
        # Encode states and pair each step with its successor
        enc_all_states_b = self.encoder(state_batch)
        enc_states_b = enc_all_states_b[:-1]
        enc_next_states_b = enc_all_states_b[1:]
        # Run inverse dynamics model
        inv_dyn_inputs = torch.cat((enc_states_b, enc_next_states_b), dim=-1)
        pred_actions = self.inv_dyn(inv_dyn_inputs)
        # Compute loss (cross_entropy == nll_loss(log_softmax(.)))
        index_act_batch = act_batch.argmax(dim=-1)
        inv_dyn_loss = F.cross_entropy(
            torch.flatten(pred_actions, 0, 1),
            torch.flatten(index_act_batch, 0, 1),
            reduction='none').view_as(index_act_batch)
        loss = inv_dyn_loss.mean(dim=1).sum()
        # Backward pass
        self.encoder_optim.zero_grad()
        self.inv_dyn_optim.zero_grad()
        loss.backward()
        self.encoder_optim.step()
        self.inv_dyn_optim.step()
        return loss

    def get_params(self):
        """Return state dicts of both networks and both optimizers."""
        return {'encoder': self.encoder.state_dict(),
                'inv_dyn': self.inv_dyn.state_dict(),
                'encoder_optim': self.encoder_optim.state_dict(),
                'inv_dyn_optim': self.inv_dyn_optim.state_dict()}

    def load_params(self, params):
        """Restore networks and optimizers from a get_params() dict."""
        self.encoder.load_state_dict(params['encoder'])
        self.inv_dyn.load_state_dict(params['inv_dyn'])
        self.encoder_optim.load_state_dict(params['encoder_optim'])
        self.inv_dyn_optim.load_state_dict(params['inv_dyn_optim'])

# In [43]:
# Build a small E3B (10-d states, 3 actions, 32-d embedding) on the CPU in
# eval mode and query the exploration bonus for an all-ones state.
m = E3B(input_dim=10, act_dim=3, enc_dim=32)
m.set_eval('cpu')
state = torch.ones((1, 10))
m.get_reward(state)

# Out: 111.61341857910156

# In [48]:
# Same all-ones state with feature 5 zeroed. get_reward updates m's inverse
# covariance on every call, so this bonus is measured against all states
# queried since the last reset.
state = torch.ones((1, 10))
state[0, 5] = 0.
m.get_reward(state)

# Out: 9.89312744140625

# In [128]:
# Smoke-test one training step: 2 agents, 11 states (10 transitions),
# batch of 4, and every action one-hot at index 1 out of 3.
m = E3B(input_dim=10, act_dim=3, enc_dim=32)
m.set_train('cpu')
state_batch = torch.ones((2, 11, 4, 10))
act_batch = torch.zeros((2, 10, 4, 3))
act_batch[..., 1] = 1
m.train(state_batch, act_batch)

# Out: tensor(8.2353, grad_fn=<SumBackward0>)