In [1]:
import sys, os
import shutil

sys.path.append('/Users/jasmineli/Desktop/moral-ai-irl')
sys.path.append('/Users/jasmineli/Desktop/moral-ai-irl/human_aware_rl_master')
import torch
from torch import nn
import numpy as np
import pickle
import argparse
import matplotlib.pyplot as plt
from human_aware_rl.ppo.ppo_rllib_client import run
from human_aware_rl_master.human_aware_rl.human.process_dataframes import *
from human_aware_rl.dummy.rl_agent import *
from human_aware_rl.rllib.utils import get_base_ae
from overcooked_ai_py.agents.agent import AgentPair
from human_aware_rl.irl.config_model import get_train_config
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedState

In [2]:
class TorchLinearReward(nn.Module):
    """Two-layer reward network: ``Linear -> ELU -> Linear``.

    Despite the name, the model is not purely linear: an ELU non-linearity
    sits between the two fully connected layers. The output layer has no
    activation, so the predicted rewards are unbounded.

    Args:
        n_input: Number of input features. Anything accepted by ``int()`` is
            allowed (a Python int, a one-element tensor, ...); it is coerced
            to a plain ``int`` before being handed to ``nn.Linear``.
        n_h1: Width of the hidden layer.
        n_h2: Width of the output layer (1 gives one scalar reward per row).
    """

    def __init__(self, n_input, n_h1=400, n_h2=1):
        super().__init__()
        # Callers sometimes pass a one-element tensor; nn.Linear expects an int.
        n_input = int(n_input)
        self.fc1 = nn.Linear(in_features=n_input, out_features=n_h1, bias=True)
        self.fc2 = nn.Linear(in_features=n_h1, out_features=n_h2, bias=True)
        self.act = nn.ELU()

    def forward(self, x):
        """Map a float tensor of shape ``(..., n_input)`` to ``(..., n_h2)``."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

    def get_theta(self):
        """Return a one-element list holding a detached view of ``fc1.weight``.

        Only the first layer's weight matrix is returned; biases and the
        second layer are not included.
        """
        return [self.fc1.weight.detach()]

    def get_rewards(self, states):
        """Evaluate the network without tracking gradients.

        Args:
            states: A tensor (used as is) or any array-like (NumPy array,
                list, tuple) that is converted to a ``torch.float`` tensor.

        Returns:
            Tensor of rewards that does not require grad.
        """
        if not isinstance(states, torch.Tensor):
            states = torch.tensor(states, dtype=torch.float)
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks still run.
            # Under no_grad the result is already detached from the graph.
            rewards = self(states)
        return rewards


In [3]:
def _loadEnvironment(config):
    """Build the environment described by ``config``.

    Reads ``config["environment_params"]["mdp_params"]`` and
    ``config["environment_params"]["env_params"]``, hands both to
    ``get_base_ae`` and returns the ``env`` attribute of the object it
    produces.
    """
    environment_params = config["environment_params"]
    evaluator = get_base_ae(
        environment_params["mdp_params"],
        environment_params["env_params"],
    )
    return evaluator.env

def _loadProcessedHumanData(data_path, view_traj=False):
    """Load pickled human gameplay and split it into per-game lists.

    The pickle at ``data_path`` must hold a mapping with the keys
    ``'gridworld'`` and ``'trajectory'``. Every trajectory entry is a sequence
    of per-step records with the keys ``'state'``, ``'joint_action'`` and
    ``'score'``; each state is rebuilt with ``OvercookedState.from_dict``.

    Args:
        data_path: Path of the pickle file; its existence is checked with
            ``assert``.
        view_traj: When true, print ``gridworld.state_string(state)`` for
            every step while loading.

    Returns:
        ``(states, actions, scores)``: three lists with one inner list per
        game, aligned step by step.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    assert os.path.isfile(data_path)
    with open(data_path, 'rb') as handle:
        human_data = pickle.load(handle)

    gridworld = human_data['gridworld']
    trajectory = human_data['trajectory']
    n_games = len(trajectory)

    states, actions, scores = [], [], []
    for game_idx in range(n_games):
        game_states, game_actions, game_scores = [], [], []
        for step_idx in range(len(trajectory[game_idx])):
            record = trajectory[game_idx][step_idx]
            raw_state = record['state']
            joint_action = record['joint_action']
            step_score = record['score']

            parsed_state = OvercookedState.from_dict(raw_state)
            game_states.append(parsed_state)
            game_actions.append(joint_action)
            game_scores.append(step_score)

            if view_traj:
                print(gridworld.state_string(parsed_state))

        states.append(game_states)
        actions.append(game_actions)
        scores.append(game_scores)

    # One entry per game in every output list.
    assert len(states) == len(trajectory)
    assert len(actions) == len(trajectory)
    assert len(scores) == len(trajectory)
    return states, actions, scores

def _convertAction2Index(actions):
    """Translate joint actions into pairs of action indices.

    Args:
        actions: One sequence per game; each element is a joint action whose
            items 0 and 1 are the two players' actions.

    Returns:
        A list per game of ``[index_0, index_1]`` pairs looked up in
        ``Action.ACTION_TO_INDEX``.
    """
    def _as_key(single_action):
        # Lists are unhashable, so they are turned into tuples before the lookup.
        return tuple(single_action) if type(single_action) is list else single_action

    def _index_pair(joint):
        first, second = _as_key(joint[0]), _as_key(joint[1])
        return [Action.ACTION_TO_INDEX[first], Action.ACTION_TO_INDEX[second]]

    return [[_index_pair(joint) for joint in game] for game in actions]

def getVisitation(states, joint_action, scores, env, target_player_idx=0):
    """Count how often each reward-feature vector occurs, averaged per game.

    For every step, ``env.human_coop_state_encoding(state, action, score)`` is
    called and the entry for ``target_player_idx`` is used (as a tuple) as the
    key of the visitation table.

    Args:
        states: One sequence of states per game.
        joint_action: One sequence of joint actions per game, aligned with
            ``states``.
        scores: One sequence of scores per game, aligned with ``states``.
        env: Object providing ``human_coop_state_encoding(s, a, sc)``.
        target_player_idx: Index of the player whose encoding is counted.
            Defaults to 0, the previously hard-coded value.

    Returns:
        Dict mapping each feature tuple to its occurrence count divided by
        ``len(states)``. Empty when no steps are supplied.
    """
    num_game = len(states)
    print(f'number of games={num_game}')

    # NOTE: zip stops at the shortest input, so misaligned inputs are truncated silently.
    counts = {}
    for game, actions, score in zip(states, joint_action, scores):
        for s, a, sc in zip(game, actions, score):
            encoding = env.human_coop_state_encoding(s, a, sc)[target_player_idx]
            features = tuple(encoding)
            counts[features] = counts.get(features, 0) + 1

    # With zero games the table is empty, so no division takes place.
    return {features: count / num_game for features, count in counts.items()}

def getExpertVisitation(env, data_path):
    """Compute the visitation table of the recorded (expert) trajectories.

    Loads the trajectories stored at ``data_path``, converts their joint
    actions to index pairs and passes everything to ``getVisitation``.

    Args:
        env: Environment forwarded to ``getVisitation``.
        data_path: Path of the pickled trajectory file.

    Returns:
        Whatever ``getVisitation`` returns for the loaded data.
    """
    expert_states, raw_actions, expert_scores = _loadProcessedHumanData(data_path, view_traj=False)
    action_indices = _convertAction2Index(raw_actions)
    return getVisitation(expert_states, action_indices, expert_scores, env)

def getAgentVisitation(train_config, env):
    '''
    Run training via ``run(train_config)`` and build the visitation table of
    the resulting evaluation rollouts.

    The states, actions and sparse rewards are read from
    ``results['evaluation']``; the actions are converted to index pairs and
    everything is handed to ``getVisitation``.

    - train_config: the configuration passed to ``run``
    - env: environment forwarded to ``getVisitation``

    Returns the visitation table, or ``None`` when any step raises: the
    exception is caught, an error line is printed and nothing is re-raised,
    so callers must check for ``None``.
    '''
    try:
        evaluation = run(train_config)['evaluation']
        rollout_states = evaluation['states']
        rollout_actions = evaluation['actions']
        rollout_scores = evaluation['sparse_reward']
        action_indices = _convertAction2Index(rollout_actions)
        return getVisitation(rollout_states, action_indices, rollout_scores, env)
    except Exception as err:
        print('ERROR: could not get Agent Visitation. -->' + str(err))
        return None

def getStatesAndGradient(expert_sv, agent_sv):
    """Turn two visitation tables into network inputs and per-state weights.

    For every feature tuple seen by the agent or the expert the weight is
    ``agent_sv[state] - expert_sv[state]``, with a missing entry counted as 0.

    Args:
        expert_sv: Dict mapping feature tuples to expert visitation values.
        agent_sv: Dict mapping feature tuples to agent visitation values.

    Returns:
        ``(states, grad)``: ``states`` is a float tensor with one row per
        feature tuple (agent entries first, then expert-only entries, each in
        insertion order); ``grad`` is a float tensor of shape ``(n, 1)`` with
        the matching differences. Neither input dict is modified.
    """
    # mu_agent - mu_expert, starting from a copy of the agent's table.
    difference = dict(agent_sv)
    for features, expert_value in expert_sv.items():
        difference[features] = difference.get(features, 0.0) - expert_value

    # Stack the feature tuples into the network input and align the weights with them.
    states = torch.stack(
        [torch.tensor(features, dtype=torch.float) for features in difference]
    )
    grad = torch.tensor(list(difference.values()), dtype=torch.float).unsqueeze(dim=1)
    return states, grad

def load_checkpoint(file_path):
    """Unpickle and return the checkpoint stored at ``file_path``.

    The path must point to an existing file (checked with ``assert``).
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted checkpoints.
    assert os.path.isfile(file_path)
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

In [4]:
# Run configuration.
trial = '_notebook'
# Pickled expert trajectories. Used only for a fresh run, where the loader
# asserts that the file exists -- an empty path therefore fails.
data_path = ''
# Checkpoint to resume from; leave empty to start a fresh run.
resume_from = ''

# init
n_epochs = 10
i = 1  # epoch counter (the resume branch sets it to the stored epoch + 1)

if not resume_from:
    # directory to save results
    cwd = os.getcwd()
    save_dir = f'{cwd}/result/human/T{trial}'
    # makedirs also creates the intermediate result/human directories and
    # tolerates an existing directory, which os.mkdir does not.
    os.makedirs(save_dir, exist_ok=True)
    # make a copy of the config file
    path = os.path.join(save_dir, 'config.py')
    shutil.copy('config_model.py', path)

    print('initiating models and optimizers...')
    reward_obs_shape = torch.tensor([15])       # change if reward shape changed.
    # nn.Linear expects an integer feature count, not a tensor.
    reward_model = TorchLinearReward(int(reward_obs_shape))
    optim = torch.optim.SGD(reward_model.parameters(), lr=0.02, momentum=0.9, weight_decay=0.9)

    print('loading training configurations...')
    config = get_train_config()
    print(config['training_params']['num_gpus'])
    print(config['model_params']['use_lstm'])

    # The trainer queries the current reward network through this hook.
    config['environment_params']['multi_agent_params']['custom_reward_func'] = reward_model.get_rewards

    print('getting expert trajectory and state visitation...')
    env = _loadEnvironment(config)
    expert_state_visit = getExpertVisitation(env, data_path)
    print('complete')
else:
    # NOTE(review): save_dir is not defined on this branch -- confirm no later cell needs it when resuming.
    print(f'loading model checkpoint from {resume_from}...')
    checkpoint = load_checkpoint(resume_from)

    print('retrieving reward model and optimizer...')
    reward_model = checkpoint["reward_model"]
    optim = checkpoint['optimizer']

    print('loading configurations...')
    config = checkpoint['config']
    env = _loadEnvironment(config)
    i = checkpoint['current_epoch'] + 1 # advance to the next epoch

    print('getting expert trajectory and state visitation...')
    expert_state_visit = checkpoint['expert_svf']
    print('complete')

loading model checkpoint from /home/jasmine/moral-ai-irl/human_aware_rl_master/human_aware_rl/irl/result/human/T7/epoch=15.checkpoint...
retrieving reward model and optimizer...
loading configurations...
getting expert trajectory and state visitation...
complete


Train 1 epoch:

In [5]:
# expert_state_visit = getExpertVisitation(env, data_path)
# Sanity check: size of the expert table, value range of every parameter
# tensor, and the current reward of each expert-visited feature vector.
n_expert_states = len(expert_state_visit)
print(f'{n_expert_states} states in expert state visitations')

print('model parameters:')
for layer_idx, layer_param in enumerate(reward_model.parameters()):
    print(f'layer {layer_idx}')
    print(f'maximum value={layer_param.max()}')
    print(f'minimum value={layer_param.min()}')

for expert_state in expert_state_visit:
    print(reward_model.get_rewards(torch.tensor(expert_state, dtype=torch.float)))

765 states in expert state visitations
model parameters:
layer 0
maximum value=1566055.5
minimum value=-94658.109375
layer 1
maximum value=247052.03125
minimum value=-2.5441884994506836
layer 2
maximum value=11.985937118530273
minimum value=-2285666.0
layer 3
maximum value=-5.366361141204834
minimum value=-5.366361141204834
tensor([-1.2796e+15])
tensor([-1.1999e+15])
tensor([-1.1874e+15])
tensor([-1.0944e+15])
tensor([-1.0781e+15])
tensor([-1.0015e+15])
tensor([-9.4878e+14])
tensor([-8.7223e+14])
tensor([-8.2530e+14])
tensor([-7.4875e+14])
tensor([-8.3619e+14])
tensor([-8.5533e+14])
tensor([-8.5279e+14])
tensor([-8.6917e+14])
tensor([-8.6639e+14])
tensor([-7.8469e+14])
tensor([-1.0493e+15])
tensor([-1.0004e+15])
tensor([-1.0798e+15])
tensor([-1.0329e+15])
tensor([-1.1095e+15])
tensor([-1.1860e+15])
tensor([-1.1755e+15])
tensor([-1.2575e+15])
tensor([-1.2517e+15])
tensor([-1.3366e+15])
tensor([-1.1233e+15])
tensor([-1.0468e+15])
tensor([-9.9790e+14])
tensor([-9.5098e+14])
tensor([-8.744

tensor([-7.7413e+14])
tensor([-1.0172e+15])
tensor([-1.0938e+15])
tensor([-1.0368e+15])
tensor([-1.0204e+15])
tensor([-1.0998e+15])
tensor([-1.0835e+15])
tensor([-1.1600e+15])
tensor([-1.1073e+15])
tensor([-1.1838e+15])
tensor([-1.1896e+15])
tensor([-1.2717e+15])
tensor([-1.3566e+15])
tensor([-1.3038e+15])
tensor([-1.3230e+15])
tensor([-1.2387e+15])
tensor([-1.2551e+15])
tensor([-1.1702e+15])
tensor([-1.0615e+15])
tensor([-9.2775e+14])
tensor([-1.2231e+15])
tensor([-1.3025e+15])
tensor([-1.2154e+15])
tensor([-1.1305e+15])
tensor([-1.0457e+15])
tensor([-8.7469e+14])
tensor([-8.6835e+14])
tensor([-7.8665e+14])
tensor([-1.1645e+15])
tensor([-8.4234e+14])
tensor([-8.2596e+14])
tensor([-9.1021e+14])
tensor([-9.2354e+14])
tensor([-9.2273e+14])
tensor([-8.4364e+14])
tensor([-1.2124e+15])
tensor([-1.1960e+15])
tensor([-1.1195e+15])
tensor([-1.1031e+15])
tensor([-1.0214e+15])
tensor([-9.6865e+14])
tensor([-8.8375e+14])
tensor([-8.3683e+14])
tensor([-8.3365e+14])
tensor([-8.4286e+14])
tensor([-8

In [6]:
# Expose only GPU 2 to CUDA for the training run started below.
# NOTE(review): the device index is machine-specific -- confirm it exists on this host.
os.environ["CUDA_VISIBLE_DEVICES"]="2"
# assert config['training_params']['num_gpus'] == 0
# assert config['model_params']['use_lstm'] == False

agent_state_visit = getAgentVisitation(config, env)
# getAgentVisitation prints the error and returns None when training or the
# rollout fails; stop here with a clear message instead of failing later on
# len(None).
if agent_state_visit is None:
    raise RuntimeError('getAgentVisitation returned None; see the error printed above.')

DummyPolicy: layout=coop_experiment_1, agent=MAIDumbAgentRightCoop
DummyPolicy: layout=coop_experiment_1, agent=MAIDumbAgentRightCoop
0: ep rew mean=-3.936447637957651e+16, max=-3.671758776932762e+16, min=-4.233600199294976e+16
10: ep rew mean=-4.000261160289139e+16, max=-3.733422078741709e+16, min=-4.26547084305367e+16
20: ep rew mean=-4.068567448333046e+16, max=-3.838553590281011e+16, min=-4.305073166142669e+16
30: ep rew mean=-4.105756303717314e+16, max=-4.003360786717082e+16, min=-4.330363906582118e+16
40: ep rew mean=-4.109929647693354e+16, max=-4.037789741167411e+16, min=-4.330526658999091e+16
50: ep rew mean=-4.108333778249575e+16, max=-3.961533012901888e+16, min=-4.331991484438938e+16
60: ep rew mean=-4.112484017052229e+16, max=-4.057168754402918e+16, min=-4.332316962429338e+16
70: ep rew mean=-4.11049648440544e+16, max=-4.056384318891622e+16, min=-4.331207451580826e+16
80: ep rew mean=-4.111718843302145e+16, max=-4.003579481083085e+16, min=-4.331991457595392e+16
90: ep rew mea

In [7]:
# Number of distinct feature vectors the agent visited, then the full table.
print(len(agent_state_visit), agent_state_visit, sep='\n')


75
{(4.0, 3.0, 8.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0): 1.0, (4.0, 3.0, 7.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0): 0.3, (4.0, 3.0, 7.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0): 0.06, (4.0, 3.0, 7.0, 3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0): 0.78, (4.0, 3.0, 7.0, 3.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0): 0.9, (4.0, 3.0, 6.0, 3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0): 0.96, (4.0, 3.0, 6.0, 3.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0): 0.78, (4.0, 3.0, 7.0, 3.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0): 1.04, (4.0, 3.0, 7.0, 4.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0): 0.64, (4.0, 3.0, 7.0, 5.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0): 0.5, (4.0, 3.0, 7.0, 6.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0): 0.44, (4.0, 3.0, 8.0, 6.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0): 0.52, (4.0

In [8]:
# Build the network inputs (one row per visited feature vector) and the
# matching per-state weights (agent visitation minus expert visitation).
states, grad_r = getStatesAndGradient(expert_sv=expert_state_visit, agent_sv=agent_state_visit)
print(states, grad_r, sep='\n')

tensor([[ 4.,  3.,  8.,  ...,  0.,  0.,  0.],
        [ 4.,  3.,  7.,  ...,  0.,  0.,  0.],
        [ 4.,  3.,  7.,  ...,  0.,  0.,  0.],
        ...,
        [ 2.,  1.,  8.,  ...,  0.,  0.,  1.],
        [ 3.,  3., 10.,  ...,  0.,  0.,  0.],
        [ 3.,  3.,  9.,  ...,  0.,  0.,  0.]])
tensor([[ 1.0000e+00],
        [ 3.0000e-01],
        [ 6.0000e-02],
        [ 7.7740e-01],
        [ 7.7773e-01],
        [ 9.5324e-01],
        [ 5.9842e-01],
        [ 1.0364e+00],
        [ 6.3740e-01],
        [ 4.9428e-01],
        [ 4.3636e-01],
        [ 5.1792e-01],
        [ 1.1938e+00],
        [ 6.9688e-01],
        [ 5.7844e-01],
        [ 5.5896e-01],
        [ 4.9792e-01],
        [ 7.9022e-01],
        [ 2.4664e-01],
        [ 4.1948e-01],
        [ 3.1103e-01],
        [ 5.1507e-01],
        [ 3.1883e-01],
        [ 3.1896e-01],
        [ 3.6716e-01],
        [ 5.5877e-01],
        [ 1.1948e-01],
        [ 1.6000e-01],
        [ 1.9184e-01],
        [ 2.8739e-01],
        [ 2.1948e-01

In [9]:
# Rewards of all visited feature vectors under the current parameters.
# Gradients are tracked here because this tensor is back-propagated next.
# Call the module rather than .forward so registered hooks are not skipped.
reward = reward_model(states)
print(f'rewards = {reward}')

rewards = tensor([[-9.7662e+14],
        [-8.9491e+14],
        [-9.7367e+14],
        [-1.0530e+15],
        [-1.0522e+15],
        [-9.6589e+14],
        [-9.6508e+14],
        [-1.0501e+15],
        [-1.1295e+15],
        [-1.2061e+15],
        [-1.2826e+15],
        [-1.3647e+15],
        [-1.4496e+15],
        [-1.3653e+15],
        [-1.2804e+15],
        [-1.2007e+15],
        [-1.1242e+15],
        [-1.0468e+15],
        [-8.9687e+14],
        [-9.7644e+14],
        [-1.2818e+15],
        [-1.4488e+15],
        [-1.1233e+15],
        [-8.9769e+14],
        [-9.7563e+14],
        [-1.0493e+15],
        [-1.1162e+15],
        [-1.1889e+15],
        [-1.3645e+15],
        [-1.2796e+15],
        [-1.0343e+15],
        [-1.0179e+15],
        [-9.5256e+14],
        [-1.1287e+15],
        [-1.2053e+15],
        [-1.3639e+15],
        [-1.2693e+15],
        [-1.3513e+15],
        [-1.4362e+15],
        [-1.0476e+15],
        [-1.1999e+15],
        [-1.0502e+15],
        [-1.4324e+15],
 

In [10]:
# gradient descent
# One manual parameter update of the reward network.
optim.zero_grad()  # clear gradients left over from any earlier backward pass
# `reward` is not a scalar, so backward needs an explicit upstream gradient.
# grad_r (agent visitation minus expert visitation) weights each row, i.e. the
# parameters receive the gradient of sum(grad_r * reward).
reward.backward(gradient=grad_r)
optim.step()  # apply the optimizer update using the gradients just computed

In [11]:
# Dump every parameter tensor of the reward network, one after another.
print('model parameters:', *reward_model.parameters(), sep='\n')

model parameters:
Parameter containing:
tensor([[-0.1776, -0.0067, -0.1200,  ...,  0.2136, -0.1268,  0.0154],
        [-0.1743,  0.1899,  0.0041,  ...,  0.0039,  0.1239,  0.0600],
        [-0.0362,  0.2701, -0.1467,  ...,  0.0422, -0.1213, -0.0333],
        ...,
        [ 0.0495,  0.0919, -0.2338,  ...,  0.1424, -0.1996, -0.1917],
        [ 0.0207,  0.2336,  0.2489,  ..., -0.0929, -0.1018,  0.2491],
        [-0.0982,  0.0123, -0.1801,  ..., -0.0764, -0.1403, -0.1676]],
       requires_grad=True)
Parameter containing:
tensor([-9.4050e-02, -1.9384e-02,  1.7323e-01,  1.1227e-01,  8.8740e-02,
         2.2495e-01, -8.2278e-03, -1.3150e-01,  1.7644e-01, -5.5884e-02,
         1.9842e-02,  4.0598e-02, -1.8294e-01,  1.4501e-01, -1.5486e-01,
         7.4371e-02,  1.1939e-03, -1.0314e-01,  1.7847e-02,  1.0529e-01,
         2.4615e-01, -4.9848e-02, -2.0044e-01, -1.8187e-01, -1.3087e-01,
         1.9807e-01,  4.1375e-02,  3.0948e-02, -1.5169e-01, -2.1891e-01,
        -2.3173e-01,  1.0497e-01,  1.73

In [12]:
# Re-evaluate the same feature vectors to inspect the current rewards.
# Call the module rather than .forward so registered hooks are not skipped.
reward = reward_model(states)
print(f'rewards = {reward}')

rewards = tensor([[ -9.1838],
        [-10.4790],
        [-12.2273],
        [-14.4247],
        [-14.1154],
        [-14.4948],
        [-13.8008],
        [-16.2893],
        [-18.1002],
        [-19.8660],
        [-20.0323],
        [-16.5329],
        [-17.6014],
        [-14.0788],
        [-14.8102],
        [-14.5213],
        [-14.4973],
        [-14.8755],
        [-16.6379],
        [-18.4745],
        [-20.6102],
        [-20.7840],
        [-20.8070],
        [-20.8197],
        [-10.4281],
        [ -8.9988],
        [-13.7246],
        [-13.3956],
        [-14.1525],
        [-14.5515],
        [-14.9178],
        [-15.8546],
        [-17.6738],
        [-19.4545],
        [-20.4160],
        [-20.4694],
        [-20.4429],
        [-18.8878],
        [-16.8653],
        [-17.6302],
        [-16.1726],
        [-16.5063],
        [-14.6494],
        [-11.3701],
        [ -9.4943],
        [ -7.0167],
        [ -5.1615],
        [ -5.9708],
        [ -6.6312],
        [ 