In [1]:
import pdb
import re
import random
import utils
import numpy as np
import torch
import torch.nn as nn
import domain
import pandas as pd

from torch import optim
from torch import autograd
from ipywidgets import interact
from agent import *
from utils import ContextGenerator
from dialog import Dialog, DialogLogger
from models.rnn_model import RnnModel
from models.latent_clustering_model import LatentClusteringPredictionModel, BaselineClusteringModel
from agent import RnnAgent, RnnRolloutAgent, RlAgent, HierarchicalAgent
from domain import get_domain
from nltk import ngrams
from tqdm.notebook import tqdm

In [2]:
def get_agent_type(model, smart=False):
    """Return the agent class that should drive a loaded model.

    Args:
        model: a model instance, e.g. the object returned by
            ``utils.load_model``. Supported types are
            LatentClusteringPredictionModel, RnnModel and
            BaselineClusteringModel.
        smart: if True, return the rollout variant of the agent class
            (``*RolloutAgent``); otherwise the plain agent class.

    Returns:
        The agent class (not an instance) matching ``model``.

    Raises:
        TypeError: if ``model`` is none of the supported model types.
    """
    # Order matters: isinstance also matches subclasses, so the first
    # matching branch wins. Agent names are only looked up on the branch
    # that is taken.
    if isinstance(model, LatentClusteringPredictionModel):
        return LatentClusteringRolloutAgent if smart else LatentClusteringAgent
    if isinstance(model, RnnModel):
        return RnnRolloutAgent if smart else RnnAgent
    if isinstance(model, BaselineClusteringModel):
        return BaselineClusteringRolloutAgent if smart else BaselineClusteringAgent
    # An explicit raise (instead of `assert False`) still fires under
    # `python -O`, where asserts are stripped and this function would
    # otherwise fall through and return None. The type name is reported
    # rather than the model itself, whose repr can be very large.
    raise TypeError('unknown model type: %s' % type(model).__name__)

In [3]:
class Arguments:
    """Paths and hyperparameters for this notebook, read via ``args`` below.

    Settings are plain class attributes. Each name is defined exactly once:
    in a class body only the last assignment to a name survives, so a
    repeated name silently discards the earlier value.
    """

    # --- Model files -----------------------------------------------------
    alice_model_file = 'rnn_model.th'  # loaded with utils.load_model
    alice_forward_model_file = ''
    bob_model_file = 'rnn_model.th'  # loaded with utils.load_model
    bob_forward_model_file = ''
    selection_model_file = 'selection_model.th'
    rollout_model_file = ''

    # --- Agents ----------------------------------------------------------
    smart_alice = False  # True selects the rollout agent class for Alice
    diverse_alice = False  # passed as `diverse=` to Alice's constructor
    smart_bob = False  # True selects the rollout agent class for Bob
    diverse_bob = False  # passed as `diverse=` to Bob's constructor
    visual = False  # copied onto alice.vis
    rollout_bsz = 3
    rollout_count_threshold = 3
    temperature = 0.5
    pred_temperature = 1.0

    # --- Dialogue / environment ------------------------------------------
    context_file = 'data/negotiate/selfplay.txt'  # read by ContextGenerator
    domain = 'object_division'  # passed to get_domain
    max_turns = 20
    score_threshold = 6

    # --- Data / corpus ---------------------------------------------------
    data = 'data/negotiate'  # corpus data directory
    unk_threshold = 20  # passed to the corpus as freq_cutoff
    sep_sel = True  # passed to the corpus as sep_sel
    bsz = 16  # batch size for the train/valid datasets

    # --- Optimisation / RL -----------------------------------------------
    rl_lr = 0.002
    rl_clip = 2.0
    lr = 0.1
    gamma = 0.99
    eps = 0.5  # single definition; 0.5 was already the effective value
    clip = 0.1
    momentum = 0.1
    sv_train_freq = 1

    # --- Runtime / logging -----------------------------------------------
    cuda = True  # passed to utils.use_cuda
    seed = 1  # passed to utils.set_seed
    verbose = False  # passed to DialogLogger
    log_file = ''  # passed to DialogLogger
    validate = False
    ref_text = ''


args = Arguments()

In [4]:
# Runtime setup, done before the next cell loads any model:
# select the compute device first, then seed via utils.set_seed.
device_flag, seed_value = args.cuda, args.seed
utils.use_cuda(device_flag)
utils.set_seed(seed_value)

In [5]:
def _build_agent(model, smart, diverse, name, train, vis):
    """Instantiate the agent class matching ``model`` and set its ``vis`` flag.

    Returns ``(agent_class, agent)`` so the notebook keeps both the ``*_ty``
    class and the agent instance as globals.
    """
    agent_cls = get_agent_type(model, smart)
    agent = agent_cls(model, args, name=name, train=train, diverse=diverse)
    agent.vis = vis
    return agent_cls, agent


# Alice is constructed with train=True and Bob with train=False; only
# Alice's `vis` follows args.visual.
alice_model = utils.load_model(args.alice_model_file)
alice_ty, alice = _build_agent(alice_model, smart=args.smart_alice,
                               diverse=args.diverse_alice, name='Alice',
                               train=True, vis=args.visual)

bob_model = utils.load_model(args.bob_model_file)
bob_ty, bob = _build_agent(bob_model, smart=args.smart_bob,
                           diverse=args.diverse_bob, name='Bob',
                           train=False, vis=False)

In [6]:
# Put Alice and Bob into one Dialog, set up logging, and create a context
# generator over args.context_file.
dialog = Dialog([alice, bob], args)
logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
ctx_gen = ContextGenerator(args.context_file)

# NOTE(review): this rebinds `domain`, shadowing the `domain` module imported
# in the first cell. The name is kept because later cells may rely on it;
# rename (e.g. `domain_obj`) if nothing downstream does.
domain = get_domain(args.domain)
corpus = alice_model.corpus_ty(domain, args.data, freq_cutoff=args.unk_threshold,
                               verbose=True, sep_sel=args.sep_sel)
# Alice's engine is built from her model's own engine_ty and attached to her.
engine = alice_model.engine_ty(alice_model, args)
alice.engine = engine

In [7]:
# Batch both corpus splits with args.bsz. The right-hand tuple is evaluated
# left to right, so the validation split is still built before the training
# split.
(validset, validset_stats), (trainset, trainset_stats) = (
    corpus.valid_dataset(args.bsz),
    corpus.train_dataset(args.bsz),
)

In [8]:
import wandb

# Record every setting on `args` as the run's config, so each W&B run keeps
# the hyperparameters it was produced with. `Arguments` stores its settings
# as class attributes, so `vars(args)` would be empty; `dir` + `getattr`
# sees the class attributes plus any per-instance overrides.
run_config = {
    key: getattr(args, key)
    for key in dir(args)
    if not key.startswith('_') and not callable(getattr(args, key))
}
wandb.init(project="goal-based-negotiating-agents", config=run_config)

W&B Run: https://app.wandb.ai/tropdeep/goal-based-negotiating-agents/runs/rcrhg38y