In [1]:
# flake8: noqa: E128
from asyncore import write
import argparse
import os
import random
import time
from distutils.util import strtobool
import spacy
nlp = spacy.load("en_core_web_sm")

import gym
import wandb
import numpy as np
import transformers
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from referential_game_env import ReferentialGameEnv
from speaker import Speaker
from tom_speaker import TOMSpeaker
from coco_speaker import COCOSpeaker
from metrics.metrics import Fluency, SemanticSimilarity, sentence_length, num_nouns
from metrics.analysis import pos_count, get_overlap
from metrics.compute_bleu import compute_bleu

  from .autonotebook import tqdm as notebook_tqdm
Using custom data configuration ChristophSchuhmann--MS_COCO_2017_URL_TEXT-14fff710bb0bfd5b
Reusing dataset parquet (/home/david/.cache/huggingface/datasets/parquet/ChristophSchuhmann--MS_COCO_2017_URL_TEXT-14fff710bb0bfd5b/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)
100%|██████████| 1/1 [00:00<00:00, 451.92it/s]


### Parse the default arguments (a notebook has no command line, so every option takes its default value)

In [2]:
def parse_args(argv=()):
    """Build the experiment's argument parser and parse ``argv``.

    A notebook has no command line of its own, so by default an empty argument
    list is parsed and every option takes its default value; the notebook then
    overrides individual attributes on the returned namespace.

    Args:
        argv: command-line tokens to parse, e.g. ``["--seed", "3"]``. Defaults to
            no tokens. Pass ``None`` to let argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option (dashes in option
        names become underscores, e.g. ``--exp-name`` -> ``args.exp_name``).

    Raises:
        SystemExit: raised by argparse for an unknown option or an invalid value.
    """

    def str2bool(value):
        """Parse a yes/no string; accepts the same spellings as ``distutils.util.strtobool``.

        Defined locally because ``distutils`` is deprecated (PEP 632) and was
        removed in Python 3.12.
        """
        lowered = value.lower()
        if lowered in ("y", "yes", "t", "true", "on", "1"):
            return True
        if lowered in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError(f"invalid truth value {value!r}")

    # Boolean flags use nargs='?' with const=True: a bare `--flag` means True,
    # `--flag yes/no` sets it explicitly, and omitting it keeps the default.
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp-name', type=str, default="evaluation",
        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="ReferentialGame-v0",
        help='the id of the gym environment')
    parser.add_argument('--learning-rate', type=float, default=0.0,
        help='the learning rate of the optimizer')
    parser.add_argument('--seed', type=int, default=1,
        help='seed of the experiment')
    parser.add_argument('--total-timesteps', type=int, default=10000,
        help='total timesteps of the experiments')
    parser.add_argument('--torch-deterministic', type=str2bool, default=True, nargs='?', const=True,
        help='if toggled off, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--cuda', type=str2bool, default=True, nargs='?', const=True,
        help='if toggled, cuda will be enabled by default')
    parser.add_argument('--track', type=str2bool, default=False, nargs='?', const=True,
        help='if toggled, this experiment will be tracked with Weights and Biases')
    parser.add_argument('--wandb-project-name', type=str, default="ToM-Language-Acquisition-Eval",
        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument('--captions-file', type=str, default="data/test_org",
        help="file to get auxiliary captions from")
    parser.add_argument('--capture-video', type=str2bool, default=False, nargs='?', const=True,
        help='whether to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--less-logging', type=str2bool, default=False, nargs='?', const=True,
        help='logs every 1000 timesteps instead of every timestep (recommended for performance)')

    # Algorithm specific arguments
    parser.add_argument('--num-envs', type=int, default=4,
        help='the number of parallel game environments')
    parser.add_argument('--num-steps', type=int, default=128,
        help='the number of steps to run in each environment per policy rollout')
    parser.add_argument('--anneal-lr', type=str2bool, default=True, nargs='?', const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--exp-decay', type=float, default=0.994)
    parser.add_argument('--gae', type=str2bool, default=True, nargs='?', const=True,
        help='Use GAE for advantage computation')
    parser.add_argument('--gamma', type=float, default=1.0,
        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
        help='the lambda for the general advantage estimation')
    parser.add_argument('--num-minibatches', type=int, default=4,
        help='the number of mini-batches')
    parser.add_argument('--update-epochs', type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument('--norm-adv', type=str2bool, default=True, nargs='?', const=True,
        help="Toggles advantages normalization")
    parser.add_argument('--clip-coef', type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument('--clip-vloss', type=str2bool, default=True, nargs='?', const=True,
        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--ent-coef', type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument('--vf-coef', type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
        help='the maximum norm for the gradient clipping')
    parser.add_argument('--target-kl', type=float, default=None,
        help='the target KL divergence threshold')

    parser.add_argument('--supervised-coef', type=float, default=0.01, help='the ratio of supervised loss')
    parser.add_argument('--length-pen', type=float, default=0.0, help='length penalty')

    # tom arguments
    parser.add_argument('--use-coco', type=str2bool, default=False, nargs='?',
        const=True, help='toggle usage of COCOSpeaker')
    parser.add_argument('--use-tom', type=str2bool, default=False, nargs='?',
        const=True, help='toggle usage of theory of mind')
    parser.add_argument('--sigma', type=float, default=0.0, help="exploration sigma value for ToM speaker")
    parser.add_argument('--tom-weight', type=float, default=1.0,
        help="If using a ToM speaker, what weight to give to ToM listener ranking")
    parser.add_argument('--tom-losscoef', type=float, default=0.1, help="coef for tom loss")
    parser.add_argument('--separate-training', type=str2bool, default=False, nargs='?',
        const=True, help="Separate ToM Listener training from rest of network")
    parser.add_argument('--beam-size', type=int, default=25,
        help="number of candidates to generate for ToM listener")
    parser.add_argument('--beam-search', type=str2bool, default=False, nargs='?',
        const=True, help='use beam search instead of sampling')
    parser.add_argument('--tom-anneal', type=str2bool, default=False, nargs='?',
        const=True, help='toggle anneal of ToM listener influence')
    parser.add_argument('--tom-anneal-start', type=float, default=0.2,
        help="fraction of updates that must pass to start using ToM listener")
    parser.add_argument('--sigma-decay', type=str2bool, default=False, nargs='?',
        const=True, help='toggle decay of the exploration sigma towards --sigma-low')
    parser.add_argument('--sigma-decay-end', type=float, default=1.0,
        help="fraction of updates that must pass to converge to final sigma value")
    parser.add_argument('--sigma-low', type=float, default=0.1,
        help="final sigma value to converge to")
    parser.add_argument('--gold-standard', type=str2bool, default=False, nargs='?',
        const=True, help='give ToM speaker access to gold standard ToM listener')

    # Environment specific arguments
    parser.add_argument('--vocabulary-size', type=int, default=200,
        help='vocabulary size of speaker')
    parser.add_argument('--max-len', type=int, default=20,
        help='maximum utterance length')
    parser.add_argument('--game-file-path', type=str)

    parser.add_argument('--theta-1', type=float, default=.4, help='theta 1')
    parser.add_argument('--theta-2', type=float, default=.9, help='theta 2')
    parser.add_argument('--model-path', type=str, default=None, help='the path of the model')
    parser.add_argument('--n-distr', type=int, default=2)
    parser.add_argument('--distribution', type=str, default='uniform', help='uniform or zipf')

    parser.add_argument('--sup-coef-decay', action='store_true', help='decay supervised coeff')
    parser.add_argument('--D_img', type=int, default=2048)
    parser.add_argument('--pretrained-path', type=str, default=None,
        help='load in the wandb path for a pretrained model if you want to run in evaluation mode')

    parser.add_argument('--render-html', type=str2bool, default=False, nargs='?', const=True,
        help="whether to save HTML images")
    parser.add_argument('--run-name', type=str, default="test",
        help="run name to save HTML files under")
    parser.add_argument('--render-every-N', type=int, default=5000,
        help="render an HTML file every N updates")

    args = parser.parse_args(None if argv is None else list(argv))
    # fmt: on
    return args


# There is no command line in a notebook, so parse_args() parses an empty
# argument list and every option takes its parser default. The next cell
# overrides selected values on this namespace.
args = parse_args()

### Override the arguments that would normally be set on the command line (edit these values to change the experiment)

In [3]:
# The arguments you normally set within the command provided in the readme.
# Only names that parse_args() defines may appear here: setattr on an argparse
# Namespace silently creates new attributes, so a typo would otherwise go unnoticed.
CLI_OVERRIDES = {
    "total_timesteps": 10000,
    "supervised_coef": 0.01,
    "game_file_path": "data/game_file_dev.pt",
    # "game_file_path": "data/new_game_file_with_high_similarity.pt",
    "exp_name": "test1",
    "captions_file": "data/test_org",
    "less_logging": True,
    "use_coco": True,
    "beam_size": 5,  # number of samples to draw for each target (parser default: 25)
    "beam_search": False,
    "sigma": 0.0,
    "seed": 517,
    "tom_weight": 1.0,
    "pretrained_path": "andy_files",
    "max_len": 12,  # parser default: 20
}
for _name, _value in CLI_OVERRIDES.items():
    if not hasattr(args, _name):
        raise AttributeError(f"parse_args() defines no argument named {_name!r}")
    setattr(args, _name, _value)

# Settings without a command-line flag; they are read only by later cells.
args.prune_size = args.beam_size  # the amount to prune the beams to
# Diverse beam search related
args.diverse = False
args.diverse_G = 5  # the number of groups
# Weight of the diversity term in diverse beam search.
# NOTE(review): the original comment ("higher = , lower = more diversity") was
# truncated; confirm the direction of this effect in TOMSpeaker.sample.
args.diverse_multiplier = 0.01
args.batch_size = 1
args.render = True
args.hard = False
#
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer's parameters in place and hand the same layer back.

    ``layer.weight`` receives an orthogonal initialisation scaled by ``std``.
    Every entry of ``layer.bias`` is set to ``bias_const``. Returning the layer
    lets the call wrap a constructor inline, e.g. inside ``nn.Sequential``.
    """
    weight, bias = layer.weight, layer.bias
    torch.nn.init.orthogonal_(weight, std)
    torch.nn.init.constant_(bias, bias_const)
    return layer


class Agent(nn.Module):
    """Actor-critic with two independent Tanh MLPs over the flattened observation.

    Both networks have two hidden layers of width 64. The critic outputs one
    value; the actor outputs ``envs.single_action_space.n`` logits.
    NOTE(review): the evaluation cells below drive a ``TOMSpeaker``, not this class.
    """

    @staticmethod
    def _tanh_mlp(in_features, out_features, out_std):
        """Linear-Tanh-Linear-Tanh-Linear stack; every Linear goes through layer_init."""
        return nn.Sequential(
            layer_init(nn.Linear(in_features, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, out_features), std=out_std),
        )

    def __init__(self, envs):
        super().__init__()
        obs_dim = np.array(envs.single_observation_space.shape).prod()
        # The critic is built before the actor so random initialisation draws
        # happen in the same order as before.
        self.critic = self._tanh_mlp(obs_dim, 1, 1.0)
        self.actor = self._tanh_mlp(obs_dim, envs.single_action_space.n, 0.01)

    def get_value(self, x):
        """Critic output for observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Return (action, log-prob, entropy, value); sample the action when none is given."""
        dist = Categorical(logits=self.actor(x))
        chosen = dist.sample() if action is None else action
        return chosen, dist.log_prob(chosen), dist.entropy(), self.critic(x)



### Initialization of everything needed happens in the next few cells

In [4]:
print("Starting!")
print("Parsed the provided arguments!")
################################################################################
# Setup Experiment and Logger                                                  #
################################################################################
# Timestamped run identifier.
# NOTE(review): no visible cell reads run_name; wandb below is named after args.exp_name.
run_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
    print("Initialising wanDB since tracking is on")
    # `wandb` is already imported in the first cell.
    wandb.init(
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        sync_tensorboard=True,
        config=vars(args),
        name=args.exp_name,
        monitor_gym=True,
        save_code=True,
    )
################################################################################
# Seeding                                                                      #
################################################################################
print("Seeding numpy and torch")
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)  # per the PyTorch docs this seeds all devices
torch.backends.cudnn.deterministic = args.torch_deterministic

Starting!
Parsed the provided arguments!
Seeding numpy and torch


In [5]:
################################################################################
# Referential Game Environments                                                #
################################################################################
print("Creating the referential games!")
# envs and dev_envs are built from identical settings; a single kwargs dict
# keeps the two constructor calls from drifting apart.
env_kwargs = dict(
    max_len=args.max_len,
    eos_id=3,
    noop_penalty=0.5,
    length_penalty=args.length_pen,
    batch_size=args.batch_size,
    n_distr=args.n_distr,
    game_file_path=args.game_file_path,
    theta_1=args.theta_1,
    theta_2=args.theta_2,
    distribution=args.distribution,
    model_path=args.model_path,
    captions_file=args.captions_file,
    hard=args.hard,
)
envs = ReferentialGameEnv(**env_kwargs)
# NOTE(review): dev_envs is not used by any visible cell.
dev_envs = ReferentialGameEnv(**env_kwargs)
# Index-to-word table used below to decode the speaker's token ids into text.
# NOTE(review): torch.load may unpickle this file; only load an "i2w" you trust.
i2w = torch.load("i2w")

Creating the referential games!


100%|██████████| 381/381 [00:02<00:00, 140.13it/s]
100%|██████████| 381/381 [00:02<00:00, 140.45it/s]


In [6]:
################################################################################
# Device                                                                       #
################################################################################
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
print("Succesfully created the referential games, now loading the speaker and listeners models!")
# Load necessary components such as the agent and tokenizer
if args.pretrained_path is None:
    # Fail early with a clear message rather than a `str + None` TypeError.
    raise ValueError(
        "args.pretrained_path must name a run directory under wandb/ "
        "(set it in the arguments cell)"
    )
model_dir = os.path.join("wandb", args.pretrained_path, "files")
speaker_path = os.path.join(model_dir, "speaker_model.pt")
listener_path = os.path.join(model_dir, "tom_listener.pt")
tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")
print(tokenizer.vocab_size)
print(args.gold_standard)
agent = TOMSpeaker(maxlen=args.max_len, vocabsize=tokenizer.vocab_size,
                   sigma=args.sigma, beam_size=args.beam_size, tom_weight=args.tom_weight,
                   use_pretrained=args.gold_standard, beam_search=args.beam_search,
                   loaded_model_paths=(speaker_path, listener_path), use_coco=args.use_coco,
                   # NOTE(review): 200 equals the --vocabulary-size default;
                   # confirm whether this should follow args.vocabulary_size.
                   word_list=list(range(200))).to(device)

Succesfully created the referential games, now loading the speaker and listeners models!
50265
False


## From here on, the experiments are run

In [7]:
import matplotlib.pyplot as plt
################################################################################
# Evaluation                                                                   #
################################################################################
total_games = 10
accuracy = []
entropies = []
unigrams = []
bigrams = []
trigrams = []
# reset() supplies the observation for the first game; each later observation
# comes from the envs.step(...) call at the end of the previous iteration.
obs = envs.reset()
with torch.no_grad():
    for _ in range(total_games):
        # This is to prepare the images for the models
        B = obs["images"].shape[0]
        next_images = torch.Tensor(
            obs["images"][range(B), :]
        ).to(device)

        # This gets the next target in the reference game
        next_target = torch.Tensor(obs["goal"]).long().to(device)

        # The Speaker samples sentences from the images and the target;
        # sentences_tensor holds one sentence per game.
        sentences_tensor, logp, entropy, temp_unigrams, temp_bigrams, temp_trigrams = agent.sample(
            next_images, next_target,
            beam_size=args.beam_size, prune_size=args.prune_size,
            diverse=args.diverse, diverse_G=args.diverse_G,
            diverse_multiplier=args.diverse_multiplier, render=args.render,
        )

        # These are the generated sentences but as actual sentences, so decoded
        generated_sentences = [
            ' '.join([i2w[token_id.item()] for token_id in sentence])
            for sentence in sentences_tensor
        ]

        # The env step function makes the listener pick a target based on the given sentences.
        # Note: with args.render set to False the images are no longer shown, which is
        # faster if you only want to see the accuracy.
        obs = envs.step(sentences_tensor.cpu().numpy(), render=args.render)
        game_accuracy = obs["accuracy"]

        # This simply appends the accuracy of this game to the accuracy list of all games
        accuracy.append(game_accuracy)
        entropies.append(entropy)
        # NOTE(review): element [1] of each n-gram result from agent.sample is what gets
        # averaged; confirm which statistic that index holds.
        unigrams.append(temp_unigrams[1])
        bigrams.append(temp_bigrams[1])
        trigrams.append(temp_trigrams[1])
        if args.render:
            print("\n\n\n")

for label, values in (
    ("accuracy", accuracy),
    ("entropy", entropies),
    ("unigram count", unigrams),
    ("bigram count", bigrams),
    ("trigram count", trigrams),
):
    print(f"The {label} of a total of", total_games, "games is:", sum(values) / len(values))

print("Done!")

Using beam search
<BOS> a baseball player holding a bat on a <UNK> outside of people open
<BOS> a baseball player holding a bat on top of a field . <EOS>
<BOS> a baseball player holding a bat on a <UNK> outside of food .
<BOS> a baseball player holding a bat on a <UNK> outside of people sitting
<BOS> a baseball player holding a bat on a area <EOS> <PAD> <PAD> <PAD>





0,1,2
a building with a very large clock on the side of it,a baseball player holding a bat on a field .,a person on some skis in the snow .
,(GOAL) (RESULT),






Using beam search
<BOS> a sink and bed a laptop is and white dog . together ball
<BOS> a sink and bed a white skateboard <EOS> <PAD> <PAD> <PAD> <PAD> <PAD>
<BOS> a sink and bed a white skateboard . area sign sidewalk . <EOS>
<BOS> a sink and bed a white skateboard . area sign sidewalk . area
<BOS> a sink and bed a white skateboard . area sign sidewalk . glass





0,1,2
a laptop computer sitting on top of a wooden desk .,two giraffe standing next to each other in a field .,a small room has a bed and desk with a laptop .
,,(GOAL) (RESULT)






Using beam search
<BOS> a group of a street under beach flying umbrella together glass the ocean
<BOS> a man riding a wave on top of a white glass city computer
<BOS> a person sitting on a beach with an open umbrella <EOS> <PAD> <PAD>
<BOS> a man riding a wave on top of a surfboard . <EOS> <PAD>
<BOS> a group of a street under beach flying umbrella together glass to it





0,1,2
a person that is holding some food in her hand .,a pizza sitting on top of a table .,a person standing on top of a beach holding an umbrella .
,,(GOAL) (RESULT)






Using beam search
<BOS> a computer is on top of a table . <EOS> <PAD> <PAD> <PAD>
<BOS> an open laptop that is on sitting on top of it near pizza
<BOS> an open laptop that is on sitting on top of her bowl table
<BOS> an open laptop that is on sitting on top of her glass <EOS>
<BOS> an open laptop that is on sitting on top of another laptop .





0,1,2
a cat that is sitting in a sink,a laptop computer sitting on top of a wooden desk .,a couple of people standing next to each other .
,(GOAL) (RESULT),






Using beam search
<BOS> a pizza sitting on top of a table eating traffic skis sandwich airplane
<BOS> a pizza sitting on top of a table eating food large airplane group
<BOS> a pizza sitting on top of a table eating traffic skis table road
<BOS> a pizza on top of a bowl <EOS> <PAD> <PAD> <PAD> <PAD> <PAD>
<BOS> a pizza sitting on top of a table eating traffic skis sandwich horse





0,1,2
a person is flying a kite in a field,a group of people with a cell phone the side of a building .,a pizza sitting on top of a white plate on a wooden table .
,,(GOAL) (RESULT)






Using beam search
<BOS> a young men playing a game of frisbee . room . area <EOS>
<BOS> a young men playing a game of frisbee . room . together <EOS>
<BOS> a young men playing a game of frisbee in <UNK> the child picture
<BOS> a young men playing a game of frisbee in <UNK> the child wall
<BOS> a young men playing a game of frisbee in <UNK> the child bunch





0,1,2
a man flying through while riding a skateboard .,a young boy stands at bat in a park .,a street sign on a road near a building
,(GOAL) (RESULT),






Using beam search
<BOS> a group of people standing on the outside of a building together big
<BOS> a group of people standing on the outside of a building . area
<BOS> a group of people standing outside <EOS> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
<BOS> a group of people standing on the outside of a building together clock
<BOS> a group of people standing on the outside of a building together picture





0,1,2
a white toilet in a very small bathroom .,a city sidewalk with people walking up and down,a white sink and toilet in a small room .
,(GOAL) (RESULT),






Using beam search
<BOS> a red stop sign sitting on the side of the bowl walking large
<BOS> a red stop sign sitting on the side of the glass open <BOS>
<BOS> a red stop sign sitting on the side of the bowl walking woman
<BOS> a red stop sign sitting next fence outside <EOS> <PAD> <PAD> <PAD> <PAD>
<BOS> a man is standing near the street area . <EOS> <PAD> <PAD> <PAD>





0,1,2
a blue and white street sign next to a white building .,a man flying through the air while riding a skateboard .,a group of men standing next to each other .
(GOAL) (RESULT),,






Using beam search
<BOS> a group of young men playing a game of frisbee together the tree
<BOS> a group of people are on the park it <EOS> <PAD> <PAD> <PAD>
<BOS> a group of men in <UNK> into skis . together view trees large
<BOS> a group of young men playing a game of frisbee <EOS> <PAD> <PAD>
<BOS> a group of young men playing a game of frisbee bowl is large





0,1,2
a computer desk with a laptop computer on top of it .,a young man is holding a blue .,a group of people on a grass field together playing a game with a frisbee .
,,(GOAL) (RESULT)






Using beam search
<BOS> a man walking a dog under in the open in tree . <EOS>
<BOS> a man walking a dog down a road while holding an umbrella .
<BOS> a woman a wooden fence walking in front surfboard around eating together glass
<BOS> a man walking a dog under in the open in tree walking large
<BOS> a man walking a dog under in the table it &apos;s open car





0,1,2
a small dog sitting on a wooden chair .,a group of people standing next to a building with a large dog,a green white with a sandwich on top of it .
(RESULT),(GOAL),






The accuracy of a total of 10 games is: 0.9
The entropy of a total of 10 games is: 0.6449990168213844
The unigram count of a total of 10 games is: 0.35821612045428913
The bigram count of a total of 10 games is: 0.4237533493056002
The trigram count of a total of 10 games is: 0.40675907580837756
Done!
