## utils

---

> Internship neural networks
>
> Group 4: Reinforcement learning
>
> Deadline 28.02.23 23:59

---

In [None]:
import random

import numpy as np
import torch
from torch.utils.data import Dataset

In [None]:
class Trajectory(object):
    """Growable record of one episode, kept as three parallel lists.

    ``trajDict`` maps "observations", "actions" and "rtgs" to lists that
    grow and shrink together through ``push`` and ``pop``.  The "rtgs" list
    holds the reward of each step exactly as it was pushed;
    ``compute_rtgs`` sums a tail of that list on demand.
    """

    def __init__(self):
        self.trajDict = {
            "observations": [],
            "actions": [],
            "rtgs": [],
        }

    def push(self, observation, action, reward):
        """Append one transition; the observation is stored flattened."""
        record = self.trajDict
        record["observations"].append(observation.flatten())
        record["actions"].append(action)
        record["rtgs"].append(reward)

    def pop(self):
        """Remove the most recently pushed transition."""
        for column in ("observations", "actions", "rtgs"):
            del self.trajDict[column][-1]

    def compute_rtgs(self, i):
        """Return the sum of the stored rewards from step ``i`` to the end."""
        stored = np.array(self.trajDict["rtgs"])
        return np.sum(stored[i:len(self)])

    def reset(self):
        """Discard every stored transition by re-running the initializer."""
        self.__init__()

    def __len__(self):
        """Number of stored transitions (length of the observation list)."""
        return len(self.trajDict["observations"])

In [None]:
class TrajectoryDataset(Dataset):
    """Dataset of fixed-length trajectory windows for the decision transformer.

    Each item is built from one trajectory dict with the keys
    "observations", "actions" and "rtgs".  Trajectories at least
    ``context_len`` steps long yield a randomly positioned window of exactly
    ``context_len`` steps; shorter ones are zero-padded at the end up to
    ``context_len`` and come with a mask marking the real steps.
    """

    def __init__(self, trajectories, context_len):
        """Store the trajectories and convert their fields to np.arrays.

        :trajectories: sequence of dicts with "observations", "actions" and
            "rtgs" entries; NOTE: the dicts are converted in place, so the
            caller's objects hold np.arrays afterwards
        :context_len: number of steps in every returned window
        """
        self.context_len = context_len
        self.trajectories = trajectories

        for traj in self.trajectories:
            for key in ("actions", "rtgs", "observations"):
                traj[key] = np.array(traj[key])

    def __len__(self):
        """Number of trajectories (one item per trajectory)."""
        return len(self.trajectories)

    @staticmethod
    def _pad_with_zeros(tensor, padding_len):
        """Append ``padding_len`` all-zero rows along dim 0, keeping the dtype."""
        filler = torch.zeros([padding_len] + list(tensor.shape[1:]),
                             dtype=tensor.dtype)
        return torch.cat([tensor, filler], dim=0)

    def __getitem__(self, idx):
        """Create the tensors for one trajectory window.

        :idx: index of traj
        :return: tuple ``(timesteps, states, actions, returns_to_go,
            traj_mask)``, each with ``context_len`` entries along dim 0.
            ``states`` and ``returns_to_go`` are float32, ``actions`` keeps
            the dtype of the stored array, ``traj_mask`` is long with 1 for
            real steps and 0 for padding.
        """
        traj = self.trajectories[idx]
        traj_len = len(traj['observations'])

        if traj_len >= self.context_len:
            # sample random start index to slice the trajectory
            # (randint is inclusive on both ends, so the last window fits)
            si = random.randint(0, traj_len - self.context_len)
            window = slice(si, si + self.context_len)

            states = torch.from_numpy(traj['observations'][window])
            actions = torch.from_numpy(traj['actions'][window])
            returns_to_go = torch.from_numpy(traj['rtgs'][window])
            timesteps = torch.arange(start=si, end=si + self.context_len, step=1)

            # all ones since no padding
            traj_mask = torch.ones(self.context_len, dtype=torch.long)
        else:
            # need padding because the traj is too short
            padding_len = self.context_len - traj_len

            states = self._pad_with_zeros(
                torch.from_numpy(traj['observations']), padding_len)
            actions = self._pad_with_zeros(
                torch.from_numpy(traj['actions']), padding_len)
            returns_to_go = self._pad_with_zeros(
                torch.from_numpy(traj['rtgs']), padding_len)
            timesteps = torch.arange(start=0, end=self.context_len, step=1)

            # mask for fillers that are not in the traj
            traj_mask = torch.cat([torch.ones(traj_len, dtype=torch.long),
                                   torch.zeros(padding_len, dtype=torch.long)],
                                  dim=0)

        states = states.to(torch.float32)
        returns_to_go = returns_to_go.to(torch.float32)

        return timesteps, states, actions, returns_to_go, traj_mask

In [None]:
def _play_eval_episode(agent, opponent, player, env,
                       max_test_ep_len, render, loss_penalty):
    '''Play one game and return (episode_reward, rounds_played, agent_won).'''
    running_state = env.reset()
    agent.reset_agent()
    running_reward = 0
    rounds_played = 0

    for t in range(max_test_ep_len):
        rounds_played += 1

        # one round = a p1 move followed by a p2 move
        for side in ("p1", "p2"):
            available_actions = env.get_available_actions()
            agent_to_move = side == player

            if agent_to_move:
                action = agent.select_action(t, running_reward, running_state,
                                             available_actions, againstDQN=False)
                running_state, running_reward = env.make_move(action, side)
            else:
                action = opponent.select_action(running_state, available_actions,
                                                training=False)
                # the opponent's reward is not tracked
                running_state, _ = env.make_move(action, side)

            if render:
                env.render()

            if env.isDone:
                # a game that ends on a move is credited to whoever made it
                if agent_to_move:
                    return running_reward, rounds_played, True
                return -loss_penalty, rounds_played, False

    # ran out of rounds without env.isDone: no reward, no win
    return 0, rounds_played, False


def evaluate_on_env(agent, opponent, player, env,
                    num_eval_ep=100, max_test_ep_len=21, render=False,
                    loss_penalty=10):
    '''
    evaluate_on_env plays games and gives back the avg reward, avg ep len and avg winrate

    A game that ends on the agent's move counts as a win and adds the reward
    returned by that move; a game that ends on the opponent's move subtracts
    loss_penalty. Games that hit max_test_ep_len without ending add nothing.

    :agent: our agent (decision transformer)
    :opponent: the agent of the opponent
    :player: player number of our agent, either "p1" or "p2"
    :env: the game environment
    :num_eval_ep: how many games to play
    :max_test_ep_len: maximum length of a game, counted in rounds (one move per side)
    :render: if you want to render the games
    :loss_penalty: amount subtracted from the total reward for every lost game
    :return: a dictionary with the avg reward, avg ep len and avg winrate
    :raises ValueError: if player is neither "p1" nor "p2"
    '''
    if player not in ("p1", "p2"):
        # otherwise the opponent would silently play both sides
        raise ValueError(f'player must be "p1" or "p2", got {player!r}')

    total_reward = 0
    total_timesteps = 0
    wins = 0

    for _ in range(num_eval_ep):
        episode_reward, rounds_played, agent_won = _play_eval_episode(
            agent, opponent, player, env, max_test_ep_len, render, loss_penalty)
        total_reward += episode_reward
        total_timesteps += rounds_played
        wins += int(agent_won)

    results = {}
    results['avg_reward'] = total_reward / num_eval_ep
    results['avg_ep_len'] = total_timesteps / num_eval_ep
    results['win_rate'] = wins / num_eval_ep

    return results