# In [None]:
%load_ext autoreload

%autoreload 2

import torch

import os
import time
import json
import numpy as np
from collections import defaultdict
from speaker import Speaker

from utils import read_vocab,write_vocab,build_vocab,Tokenizer,padding_idx,timeSince, read_img_features, read_graph_features, read_graph_features_parallel
import utils
from env import R2RBatch
from eval import Evaluation
from param import args
from agent import ActiveExplore_v1

import warnings
warnings.filterwarnings("ignore")

# Pin this process to GPU 0.  NOTE(review): this only takes effect if CUDA has
# not been initialised yet (torch is already imported above) — confirm nothing
# touches the GPU before this line.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


from tensorboardX import SummaryWriter

# Every relative path used below (snap/, tasks/, img_features/) is resolved from
# the parent of the notebook's starting directory.  Move up only once, so that
# re-running this cell does not keep walking further up the directory tree.
print('current directory', os.getcwd())
if not globals().get('_MOVED_TO_PARENT_DIR', False):
    os.chdir('..')
    _MOVED_TO_PARENT_DIR = True
print('current directory', os.getcwd())


# Experiment configuration: overrides applied on top of `param.args`.
args.name = 'active'
args.attn = 'soft'
args.train = 'listener'
args.featdropout = 0.4  # previously set to 0.3 and then immediately overridden to 0.4
args.angle_feat_size = 128
args.feedback = 'sample'
args.ml_weight = 0.2
args.sub_out = 'max'
args.dropout = 0.5
args.optim = 'adam'
args.lr = 1e-4
args.iters = 200000  # previously set to 80000 and then immediately overridden
args.maxAction = 35
args.batchSize = 64

args.self_train = True
# Alternative augmentation file: 'tasks/R2R/data/aug_paths_unseenvalid.json'
args.aug = 'tasks/R2R/data/aug_paths.json'
args.speaker = 'snap/speaker/state_dict/best_val_unseen_bleu'

# Re-derive args.optimizer (an optimizer class, not an instance) from
# args.optim, which is overridden in the configuration above.
if args.optim == 'rms':
    print("Optimizer: Using RMSProp")
    args.optimizer = torch.optim.RMSprop
elif args.optim == 'adam':
    print("Optimizer: Using Adam")
    args.optimizer = torch.optim.Adam
elif args.optim == 'sgd':
    print("Optimizer: sgd")
    args.optimizer = torch.optim.SGD
elif args.optim == 'adabound':
    print("Optimizer: adabound")
    # `adabound` is not imported anywhere else in this file; import it lazily
    # so this branch works instead of failing with NameError.
    import adabound
    args.optimizer = adabound.AdaBound
else:
    # Fail loudly rather than silently keeping whatever args.optimizer held.
    raise ValueError(
        "Unknown optimizer %r (expected 'rms', 'adam', 'sgd' or 'adabound')"
        % (args.optim,)
    )



# Log / snapshot directory for this run: snap/<args.name>.
log_dir = 'snap/%s' % args.name
os.makedirs(log_dir, exist_ok=True)

TRAIN_VOCAB = 'tasks/R2R/data/train_vocab.txt'
TRAINVAL_VOCAB = 'tasks/R2R/data/trainval_vocab.txt'

IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'

# Map args.features to its precomputed feature file.  Previously only
# 'imagenet' was handled; any other value left `features` unbound and crashed
# later with a NameError.
_FEATURE_FILES = {
    'imagenet': IMAGENET_FEATURES,
    'places365': PLACE365_FEATURES,
}
if args.features not in _FEATURE_FILES:
    raise ValueError(
        "Unsupported args.features %r; expected one of %s"
        % (args.features, sorted(_FEATURE_FILES))
    )
features = _FEATURE_FILES[args.features]

# Fast-train mode reads the "-fast" sibling file, e.g.
# img_features/ResNet-152-imagenet-fast.tsv.
if args.fast_train:
    name, ext = os.path.splitext(features)
    features = name + "-fast" + ext

feedback_method = args.feedback  # original note: 'teacher' or 'sample'

print(args)

def setup():
    """Seed torch's CPU and CUDA RNGs and create any missing vocab files.

    The training vocabulary is built from the 'train' split and the
    train+val vocabulary from 'train', 'val_seen' and 'val_unseen'.  A file
    that already exists is left untouched.
    """
    seed = 1
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    vocab_targets = (
        (TRAIN_VOCAB, ['train']),
        (TRAINVAL_VOCAB, ['train', 'val_seen', 'val_unseen']),
    )
    for vocab_path, vocab_splits in vocab_targets:
        if os.path.exists(vocab_path):
            continue
        write_vocab(build_vocab(splits=vocab_splits), vocab_path)
# Seed the RNGs and make sure the vocabulary files exist before loading data.
setup()

vocab = read_vocab(TRAIN_VOCAB)
tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

feat_dict = read_img_features(features)

# Collect the key prefix before the first "_" (presumably the scan id) of
# every feature entry.
print('start extract keys...')
featurized_scans = {key.split("_")[0] for key in feat_dict.keys()}
print('keys extracted...')

# Augmentation data path comes from the configuration above.
aug_path = args.aug

# Training environment plus the two augmentation environments.
train_env = R2RBatch(feat_dict, tokenizer=tok, batch_size=args.batchSize,
                     splits=['train'])
aug_env = R2RBatch(feat_dict, tokenizer=tok, batch_size=args.batchSize,
                   splits=[aug_path], name='aug')
aug_unseen_env = R2RBatch(feat_dict, tokenizer=tok, batch_size=args.batchSize,
                          splits=['tasks/R2R/data/aug_paths_unseenvalid.json'],
                          name='aug')

# Dataset statistics for the training split.
stats = train_env.get_statistics()
print("The training data_size is : %d" % train_env.size())
print("The average instruction length of the dataset is %0.4f." % (stats['length']))
print("The average action length of the dataset is %0.4f." % (stats['path']))

# One (environment, evaluator) pair per split to evaluate on.
val_envs = {
    split: (
        R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split],
                 tokenizer=tok),
        Evaluation([split], featurized_scans, tok),
    )
    for split in ('train', 'val_seen', 'val_unseen')
}


# In [None]:
import imp
import traceback


# Checkpoint of the listener to evaluate.
args.load = 'snap/agent/state_dict/best_val_unseen'

# Autograd anomaly detection is a debugging aid that slows execution
# considerably.  It was left permanently on in this evaluation-only cell, so it
# is now opt-in: set this to True only while hunting NaNs or in-place errors.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)


def test(train_env, tok, n_iters, log_every=100, val_envs={}, aug_env=None):
    writer = SummaryWriter(logdir=log_dir)
    
    listner = ActiveExplore_v1(train_env, "", tok, episode_len=args.maxAction)
    
    speaker = None
    if args.self_train:
        speaker = Speaker(train_env, listner, tok)
        if args.speaker is not None:
            print("Load the speaker from %s." % args.speaker)
            speaker.load(args.speaker)

    start_iter = 0
    if args.load is not None:
        print("LOAD THE listener from %s" % args.load)
        start_iter = listner.load(os.path.join(args.load))
        start_iter = 0
        
    ths = np.ones([args.maxAction])*3
    ths[0] = 0.456
    ths[1] = 0.485
    ths[2] = 0.493
    ths[3] = 0.584
    ths[4] = 0.574
    ths[5] = 0.418
    ths[6] = 0.675
        

    start = time.time()

    best_val = {'val_seen': {"accu": 0., "state":"", 'update':False},
                'val_unseen': {"accu": 0., "state":"", 'update':False}}
    if args.fast_train:
        log_every = 40
    try:
        loss_str = ""
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env
            listner.logs = defaultdict(list)

            # Get validation loss under the same conditions as training
            iters = None if args.fast_train or env_name != 'train' else 20     # 20 * 64 = 1280
#             iters = 5

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=iters, train_exp=True, ths=ths)
            result = listner.get_results()
            score_summary, score_details = evaluator.score(result)


            with open('traj.json','w') as fp:
                json.dump(result,fp,indent=4)


            loss_str += ", %s \n" % env_name
     
            for metric,val in score_summary.items():
                loss_str += ', %s: %.3f' % (metric, val)
            loss_str += '\n'


            print(loss_str)
        torch.cuda.empty_cache()


    except Exception as e:
        del listner
        traceback.print_exc()



# Free PyTorch's cached CUDA memory left over from earlier cells, then evaluate.
torch.cuda.empty_cache()
test(
    train_env,
    tok,
    args.iters,
    log_every=20,
    val_envs=val_envs,
    aug_env=[aug_env, aug_unseen_env],
)