In [174]:
import torch
from torch.nn import functional as F
from abc import ABC, abstractmethod

from model.modules.networks import MLPNetwork

class IntrinsicReward(ABC):
    """ Abstract class for an Intrinsic Reward Model. """
    
    @abstractmethod
    def init_new_episode(self):
        """
        Initialise model at start of new episode.
        """
        raise NotImplementedError
    
    @abstractmethod
    def set_train(self, device):
        """
        Set to training mode and put networks on given device.
        Inputs:
            device (str): CUDA device.
        """
        raise NotImplementedError
    
    @abstractmethod
    def set_eval(self, device):
        """
        Set to evaluation mode and put networks on given device.
        Inputs:
            device (str): CUDA device.
        """
        raise NotImplementedError
    
    @abstractmethod
    def get_reward(self, state):
        """
        Returns the reward computed from given state.
        Inputs:
            state (torch.Tensor): State used for computing reward, dim=(1, state_dim).
        """
        raise NotImplementedError
    
    @abstractmethod
    def train(self, state_batch, act_batch):
        """
        Set to evaluation mode and put networks on given device.
        Inputs:
            state_batch (torch.Tensor): Batch of states, dim=(episode_length, batch_size, state_dim).
            act_batch (torch.Tensor): Batch of actions, dim=(episode_length, batch_size, action_dim).
        """
        raise NotImplementedError
        
    @abstractmethod
    def get_params(self):
        """
        Returns state dicts of networks and optimizers.
        """
        raise NotImplementedError
        
    @abstractmethod
    def load_params(self, params):
        """
        Load parameters in networks and optimizers.
        Inputs:
            params (dict): Dictionary of state dicts.
        """
        raise NotImplementedError

class E3B:
    
    def __init__(self, input_dim, act_dim, enc_dim, hidden_dim=64, ridge=0.1, lr=1e-4):
        self.enc_dim = enc_dim
        self.ridge = ridge
        # State encoder
        self.encoder = MLPNetwork(input_dim, enc_dim, hidden_dim, norm_in=False)
        # Inverse dynamics model
        self.inv_dyn = MLPNetwork(2 * enc_dim, act_dim, hidden_dim, norm_in=False)
        # Inverse covariance matrix
        self.ridge = ridge
        self.inv_cov = torch.eye(enc_dim) * (1.0 / self.ridge)
        self.outer_product_buffer = torch.empty(enc_dim, enc_dim)
        
        # Optimizers
        self.encoder_optim = torch.optim.Adam(
            self.encoder.parameters(), 
            lr=lr)
        self.inv_dyn_optim = torch.optim.Adam(
            self.inv_dyn.parameters(), 
            lr=lr)
    
    def init_new_episode(self):
        self.inv_cov = torch.eye(self.enc_dim) * (1.0 / self.ridge)

    def set_train(self, device):
        self.encoder.train()
        self.encoder = self.encoder.to(device)
        self.inv_dyn.train()
        self.inv_dyn = self.inv_dyn.to(device)
        self.device = device

    def set_eval(self, device):
        self.encoder.eval()
        self.encoder = self.encoder.to(device)
        self.device = device
        
    def get_reward(self, state):
        """
        Inputs:
            state (torch.Tensor): dim=(1, state_dim)
        """
        # Encode state
        enc_state = self.encoder(state).squeeze().detach()
        # Compute the intrinsic reward
        u = torch.mv(self.inv_cov, enc_state)
        int_reward = torch.dot(enc_state, u).item()
        # Update covariance matrix
        torch.outer(u, u, out=self.outer_product_buffer)
        torch.add(
            self.inv_cov, self.outer_product_buffer, 
            alpha=-(1. / (1. + int_reward)), out=self.inv_cov)
        return int_reward
    
    def train(self, state_batch, act_batch):
        """
        Inputs:
            state_batch (torch.Tensor): Batch of states, dim=(episode_length, 
                batch_size, state_dim).
            act_batch (torch.Tensor): Batch of actions, dim=(episode_length, 
                batch_size, action_dim).
        """
        # Encode states
        enc_all_states_b = self.encoder(state_batch)
        enc_states_b = enc_all_states_b[:-1]
        enc_next_states_b = enc_all_states_b[1:]
        # Run inverse dynamics model
        inv_dyn_inputs = torch.cat((enc_states_b, enc_next_states_b), dim=-1)
        pred_actions = self.inv_dyn(inv_dyn_inputs)
        # Compute loss
        # index_act_batch = act_batch.max(dim=-1)[1]
        # inv_dyn_loss = F.nll_loss(
        #     F.log_softmax(torch.flatten(pred_actions, 0, 1), dim=-1),
        #     target=torch.flatten(index_act_batch, 0, 1),
        #     reduction='none')
        # inv_dyn_loss = inv_dyn_loss.view_as(index_act_batch)
        # loss = torch.sum(torch.mean(inv_dyn_loss, dim=1))
        print(pred_actions.shape, act_batch.shape)
        loss = F.mse_loss(pred_actions, act_batch)
        # Backward pass
        self.encoder_optim.zero_grad()
        self.inv_dyn_optim.zero_grad()
        loss.backward()
        self.encoder_optim.step()
        self.inv_dyn_optim.step()
        return loss
    
    def get_params(self):
        return {'encoder': self.encoder.state_dict(),
                'inv_dyn': self.inv_dyn.state_dict(),
                'encoder_optim': self.encoder_optim.state_dict(),
                'inv_dyn_optim': self.inv_dyn_optim.state_dict()}

    def load_params(self, params):
        self.encoder.load_state_dict(params['encoder'])
        self.inv_dyn.load_state_dict(params['inv_dyn'])
        self.encoder_optim.load_state_dict(params['encoder_optim'])
        self.inv_dyn_optim.load_state_dict(params['inv_dyn_optim'])

In [300]:
m = E3B(80, 6, 64)
state = torch.ones(1, 80)
m.set_eval('cpu')
#m.get_reward(state)

In [307]:
state = torch.ones(1, 80)
state[0, 69] *= 1
m.get_reward(state)

0.16666549444198608

In [268]:
m = E3B(80, 6, 64)
state_batch = torch.ones(2, 81, 32, 80)
act_batch = torch.zeros(2, 80, 32, 3)
act_batch[:, :, :, 1] = 1
act_batch = torch.cat(tuple(act_batch), dim=-1)
m.set_train('cpu')
m.train(state_batch[0], act_batch)

torch.Size([80, 32, 6]) torch.Size([80, 32, 6])


tensor(2.9394, grad_fn=<MseLossBackward0>)

In [244]:
m.train(state_batch[0], act_batch)

torch.Size([80, 32, 6]) torch.Size([80, 32, 6])


tensor(3.1116, grad_fn=<MseLossBackward0>)

In [219]:
act_batch = torch.zeros(2, 10, 4, 3)
act_batch[0] = 1
torch.cat(tuple(act_batch), dim=-1).shape

torch.Size([10, 4, 6])

In [134]:
state_batch[0].shape

torch.Size([11, 4, 10])