In [1]:
import pdb
import re
import random
import utils
import numpy as np
import torch
import torch.nn as nn
import domain
import pandas as pd

from torch import optim
from torch import autograd
from ipywidgets import interact
from agent import *
from utils import ContextGenerator
from dialog import Dialog, DialogLogger
from models.rnn_model import RnnModel
from models.latent_clustering_model import LatentClusteringPredictionModel, BaselineClusteringModel
from agent import RnnAgent, RnnRolloutAgent, RlAgent, HierarchicalAgent
from domain import get_domain
from nltk import ngrams
from tqdm.notebook import tqdm

In [2]:
def get_agent_type(model, smart=False):
    if isinstance(model, LatentClusteringPredictionModel):
        if smart:
            return LatentClusteringRolloutAgent
        else:
            return LatentClusteringAgent
    elif isinstance(model, RnnModel):
        if smart:
            return RnnRolloutAgent
        else:
            return RnnAgent
    elif isinstance(model, BaselineClusteringModel):
        if smart:
            return BaselineClusteringRolloutAgent
        else:
            return BaselineClusteringAgent
    else:
        assert False, 'unknown model type: %s' % (model)

In [3]:
class Arguments:
    alice_model_file = 'rnn_model.th'
    alice_forward_model_file = ''
    bob_model_file = 'rnn_model.th'
    bob_forward_model_file = ''
    context_file = 'data/negotiate/selfplay.txt'
    temperature = 0.5
    pred_temperature =1.0
    verbose = True
    seed = 1
    score_threshold = 6
    max_turns = 20
    log_file = ''
    smart_alice = False
    diverse_alice = False
    rollout_bsz = 3
    rollout_count_threshold = 3
    smart_bob = False
    selection_model_file = 'selection_model.th'
    rollout_model_file = ''
    diverse_bob = False
    cuda = True
    domain = 'object_division'
    visual = False
    eps = 0.0
    data = 'data/negotiate'
    unk_threshold = 20
    bsz = 16
    validate = False
    ref_text = ''
    rl_lr = 0.002
    rl_clip = 2.0
    lr = 0.1
    gamma = 0.99
    eps = 0.5
    clip = 0.1
    momentum = 0.1
    sep_sel = True
    unk_threshold = 20
    sv_train_freq = 1
    
args = Arguments()

In [4]:
utils.use_cuda(args.cuda)
utils.set_seed(args.seed)

In [5]:
alice_model = utils.load_model(args.alice_model_file)
alice_ty = get_agent_type(alice_model, args.smart_alice)
alice = alice_ty(alice_model, args, name='Alice', train=True, diverse=args.diverse_alice)
alice.vis = args.visual

bob_model = utils.load_model(args.bob_model_file)
bob_ty = get_agent_type(bob_model, args.smart_bob)
bob = bob_ty(bob_model, args, name='Bob', train=False, diverse=args.diverse_bob)
bob.vis = False

In [6]:
dialog = Dialog([alice, bob], args)
logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
ctx_gen = ContextGenerator(args.context_file)

#dialog2 = Dialog([alice, joe], args)
domain = get_domain(args.domain)
corpus = alice_model.corpus_ty(domain, args.data, freq_cutoff=args.unk_threshold,
                               verbose=True, sep_sel=args.sep_sel)
engine = alice_model.engine_ty(alice_model, args)
alice.engine = engine

In [7]:
validset, validset_stats = corpus.valid_dataset(args.bsz)
trainset, trainset_stats = corpus.train_dataset(args.bsz)

In [8]:
alice.model.ctx_encoder[0]

MlpContextEncoder(
  (cnt_enc): Sequential(
    (0): Embedding(11, 64)
    (1): Dropout(p=0.1, inplace=False)
  )
  (val_enc): Sequential(
    (0): Embedding(11, 64)
    (1): Dropout(p=0.1, inplace=False)
  )
  (encoder): Sequential(
    (0): Linear(in_features=192, out_features=64, bias=True)
    (1): Tanh()
  )
)

In [9]:
class Arguments:
    alice_model_file = 'rnn_model.th'
    alice_forward_model_file = ''
    bob_model_file = 'rnn_model.th'
    bob_forward_model_file = ''
    context_file = 'data/negotiate/selfplay.txt'
    temperature = 0.5
    pred_temperature =1.0
    verbose = False
    seed = 1
    score_threshold = 6
    max_turns = 20
    log_file = ''
    smart_alice = False
    diverse_alice = False
    rollout_bsz = 3
    rollout_count_threshold = 3
    smart_bob = False
    selection_model_file = 'selection_model.th'
    rollout_model_file = ''
    diverse_bob = False
    cuda = True
    domain = 'object_division'
    visual = False
    eps = 0.0
    data = 'data/negotiate'
    unk_threshold = 20
    bsz = 16
    validate = False
    ref_text = ''
    rl_lr = 0.002
    rl_clip = 2.0
    lr = 0.1
    gamma = 0.99
    eps = 0.5
    clip = 0.1
    momentum = 0.1
    sep_sel = True
    unk_threshold = 20
    sv_train_freq = 1
    
args = Arguments()

In [10]:
utils.use_cuda(args.cuda)
utils.set_seed(args.seed)

In [11]:
alice_model = utils.load_model(args.alice_model_file)
alice_ty = get_agent_type(alice_model, args.smart_alice)
alice = alice_ty(alice_model, args, name='Alice', train=True, diverse=args.diverse_alice)
alice.vis = args.visual

bob_model = utils.load_model(args.bob_model_file)
bob_ty = get_agent_type(bob_model, args.smart_bob)
bob = bob_ty(bob_model, args, name='Bob', train=False, diverse=args.diverse_bob)
bob.vis = False

In [12]:
dialog = Dialog([alice, bob], args)
logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
ctx_gen = ContextGenerator(args.context_file)

#dialog2 = Dialog([alice, joe], args)
domain = get_domain(args.domain)
corpus = alice_model.corpus_ty(domain, args.data, freq_cutoff=args.unk_threshold,
                               verbose=True, sep_sel=args.sep_sel)
engine = alice_model.engine_ty(alice_model, args)
alice.engine = engine

In [13]:
validset, validset_stats = corpus.valid_dataset(args.bsz)
trainset, trainset_stats = corpus.train_dataset(args.bsz)

In [14]:
n = 0
rew_freq = 2
all_rewards = []
norm_reward = 0
args.sv_train_freq = 1
args.verbose = False
utt_reward = 0
for ctxs in tqdm(ctx_gen.iter()):
    if args.sv_train_freq > 0 and n % args.sv_train_freq == 0:
        batch = random.choice(trainset)
        engine.model.train()
        engine.train_batch(batch, reward=utt_reward)
        engine.model.eval()
    if n % rew_freq == 0:
        logger.dump('=' * 80)
        conv, agree, rewards = dialog.run(ctxs, logger)
        #dialog2.run(ctxs, logger)
        logger.dump('=' * 80)
        logger.dump('')
        
        # compute context rewards
        reward, partner_reward = rewards
        diff = reward - partner_reward
        all_rewards.append(diff)
        r = (diff - np.mean(all_rewards)) / max(1e-4, np.std(all_rewards))
        g = r
        rewards = []
        for _ in alice.logprobs:
            rewards.append(g)
            g = g * args.gamma
        ctx_norm_reward = 0
        for lp, r in zip(alice.logprobs, rewards):
            ctx_norm_reward -= lp.tolist()[0] * r
        #print('context reward:', ctx_norm_reward)
        
        # compute utterance rewards
        utt_reward = 0
        for utterance in conv:
            unigrams = pd.Series(ngrams(utterance, 1))
            if len(conv) < 2:
                utt_reward -= 0.5
                continue
            utt_reward += unigrams.count() - 8 if unigrams.count() < 8 else 0
            bigrams = pd.Series(ngrams(utterance, 2))
            utt_reward -= bigrams.value_counts().std()
            trigrams = pd.Series(ngrams(utterance, 3))
            utt_reward -= trigrams.value_counts().std()
        #print('utterance reward:', utt_reward)
        utt_reward = max(-2.0, utt_reward)
            
               
        #input()
    n += 1

In [15]:
len(list(ctx_gen.iter())), len(ctx_gen.ctxs)

(4086, 4086)

In [16]:
n = 0
rew_freq = 2
all_rewards = []
norm_reward = 0
args.sv_train_freq = 1
args.verbose = False
utt_reward = 0
for ctxs in tqdm(ctx_gen.iter(), total=len(ctx_gen.ctxs)):
    if args.sv_train_freq > 0 and n % args.sv_train_freq == 0:
        batch = random.choice(trainset)
        engine.model.train()
        engine.train_batch(batch, reward=utt_reward)
        engine.model.eval()
    if n % rew_freq == 0:
        logger.dump('=' * 80)
        conv, agree, rewards = dialog.run(ctxs, logger)
        #dialog2.run(ctxs, logger)
        logger.dump('=' * 80)
        logger.dump('')
        
        # compute context rewards
        reward, partner_reward = rewards
        diff = reward - partner_reward
        all_rewards.append(diff)
        r = (diff - np.mean(all_rewards)) / max(1e-4, np.std(all_rewards))
        g = r
        rewards = []
        for _ in alice.logprobs:
            rewards.append(g)
            g = g * args.gamma
        ctx_norm_reward = 0
        for lp, r in zip(alice.logprobs, rewards):
            ctx_norm_reward -= lp.tolist()[0] * r
        #print('context reward:', ctx_norm_reward)
        
        # compute utterance rewards
        utt_reward = 0
        for utterance in conv:
            unigrams = pd.Series(ngrams(utterance, 1))
            if len(conv) < 2:
                utt_reward -= 0.5
                continue
            utt_reward += unigrams.count() - 8 if unigrams.count() < 8 else 0
            bigrams = pd.Series(ngrams(utterance, 2))
            utt_reward -= bigrams.value_counts().std()
            trigrams = pd.Series(ngrams(utterance, 3))
            utt_reward -= trigrams.value_counts().std()
        #print('utterance reward:', utt_reward)
        utt_reward = max(-2.0, utt_reward)
            
               
        #input()
    n += 1

In [17]:
-2.0 * args.gamma

-1.98

In [18]:
import wandb

wandb.init(project="goal-based-negotiating-agents")

W&B Run: https://app.wandb.ai/tropdeep/goal-based-negotiating-agents/runs/b965pgjv

In [19]:
n = 0
rew_freq = 2
all_rewards = []
norm_reward = 0
args.sv_train_freq = 1
args.verbose = False
utt_reward = 0
for ctxs in tqdm(ctx_gen.iter(), total=len(ctx_gen.ctxs)):
    if args.sv_train_freq > 0 and n % args.sv_train_freq == 0:
        batch = random.choice(trainset)
        engine.model.train()
        engine.train_batch(batch, reward=utt_reward)
        engine.model.eval()
    if n % rew_freq == 0:
        logger.dump('=' * 80)
        conv, agree, rewards = dialog.run(ctxs, logger)
        #dialog2.run(ctxs, logger)
        logger.dump('=' * 80)
        logger.dump('')
        
        # compute context rewards
        reward, partner_reward = rewards
        diff = reward - partner_reward
        all_rewards.append(diff)
        r = (diff - np.mean(all_rewards)) / max(1e-4, np.std(all_rewards))
        g = r
        rewards = []
        for _ in alice.logprobs:
            rewards.append(g)
            g = g * args.gamma
        ctx_norm_reward = 0
        for lp, r in zip(alice.logprobs, rewards):
            ctx_norm_reward -= lp.item() * r
        #print('context reward:', ctx_norm_reward)
        
        # compute utterance rewards
        utt_reward = 0
        for utterance in conv:
            unigrams = pd.Series(ngrams(utterance, 1))
            if len(conv) < 2:
                utt_reward -= 0.5
                continue
            utt_reward += unigrams.count() - 8 if unigrams.count() < 8 else 0
            bigrams = pd.Series(ngrams(utterance, 2))
            utt_reward -= bigrams.value_counts().std()
            trigrams = pd.Series(ngrams(utterance, 3))
            utt_reward -= trigrams.value_counts().std()
        #print('utterance reward:', utt_reward)
        utt_reward = max(-2.0, utt_reward) * args.gamma
        
        # logs
        wandb.log({'utterance-reward': utt_reward,
                   'ctx-norm-reward': ctx_norm_reward})
            
               
        #input()
    n += 1

HBox(children=(IntProgress(value=0, max=4086), HTML(value='')))

In [20]:
ctx_norm_reward

-5.858900724264033e-06

In [21]:
wandb.watch(alice_model)

[<wandb.wandb_torch.TorchGraph at 0x7fdf7075a510>]

In [22]:
n = 0
rew_freq = 2
all_rewards = []
norm_reward = 0
args.sv_train_freq = 1
args.verbose = False
utt_reward = 0
for ctxs in tqdm(ctx_gen.iter(), total=len(ctx_gen.ctxs)):
    if args.sv_train_freq > 0 and n % args.sv_train_freq == 0:
        batch = random.choice(trainset)
        loss = engine.model.train()
        engine.train_batch(batch, reward=utt_reward)
        engine.model.eval()
        wandb.log({"loss": loss})
    if n % rew_freq == 0:
        logger.dump('=' * 80)
        conv, agree, rewards = dialog.run(ctxs, logger)
        #dialog2.run(ctxs, logger)
        logger.dump('=' * 80)
        logger.dump('')
        
        # compute context rewards
        reward, partner_reward = rewards
        diff = reward - partner_reward
        all_rewards.append(diff)
        r = (diff - np.mean(all_rewards)) / max(1e-4, np.std(all_rewards))
        g = r
        rewards = []
        for _ in alice.logprobs:
            rewards.append(g)
            g = g * args.gamma
        ctx_norm_reward = 0
        for lp, r in zip(alice.logprobs, rewards):
            ctx_norm_reward -= lp.item() * r
        #print('context reward:', ctx_norm_reward)
        
        # compute utterance rewards
        utt_reward = 0
        for utterance in conv:
            unigrams = pd.Series(ngrams(utterance, 1))
            if len(conv) < 2:
                utt_reward -= 0.5
                continue
            utt_reward += unigrams.count() - 8 if unigrams.count() < 8 else 0
            bigrams = pd.Series(ngrams(utterance, 2))
            utt_reward -= bigrams.value_counts().std()
            trigrams = pd.Series(ngrams(utterance, 3))
            utt_reward -= trigrams.value_counts().std()
        #print('utterance reward:', utt_reward)
        utt_reward = max(-2.0, utt_reward) * args.gamma
        
        # logs
        wandb.log({'utterance-reward': utt_reward,
                   'ctx-norm-reward': ctx_norm_reward})
            
               
        #input()
    n += 1

HBox(children=(IntProgress(value=0, max=4086), HTML(value='')))