# Install dependencies

In [None]:
# Install the pinned torch release used by the rest of this notebook.
# NOTE(review): per the recorded output below, this uninstalls the preinstalled
# torch 1.11.0+cu113 and leaves torchvision / torchtext / torchaudio declaring a
# requirement on torch==1.11.0 — confirm none of them is needed afterwards.
!pip install torch==1.6.0

Collecting torch==1.6.0
  Downloading torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl (748.8 MB)
[K     |████████████████████████████████| 748.8 MB 15 kB/s 
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 1.11.0+cu113
    Uninstalling torch-1.11.0+cu113:
      Successfully uninstalled torch-1.11.0+cu113
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.12.0+cu113 requires torch==1.11.0, but you have torch 1.6.0 which is incompatible.
torchtext 0.12.0 requires torch==1.11.0, but you have torch 1.6.0 which is incompatible.
torchaudio 0.11.0+cu113 requires torch==1.11.0, but you have torch 1.6.0 which is incompatible.[0m
Successfully installed torch-1.6.0


# Setup training

In [None]:
# Clone repository
# The `egg` and `example` directories are moved next to the notebook so that the
# later `import egg...` statements resolve from the working directory.
# NOTE(review): not idempotent — on a re-run `git clone` and `mv` presumably fail
# because `Lazimpa`, `./egg` and `./example` already exist.
! git clone https://github.com/Daetheys/Lazimpa.git
! mv "./Lazimpa/egg" "./egg"
! mv "./Lazimpa/example" "./example"

Cloning into 'Lazimpa'...
remote: Enumerating objects: 2453, done.[K
remote: Total 2453 (delta 0), reused 0 (delta 0), pack-reused 2453[K
Receiving objects: 100% (2453/2453), 39.15 MiB | 11.03 MiB/s, done.
Resolving deltas: 100% (1530/1530), done.


In [None]:
# Output directories used by the run. `-p` on both commands keeps the cell
# re-runnable: plain `mkdir analysis` errors out when the directory already exists.
! mkdir -p dir_save/{accuracy,messages,sender,receiver}
! mkdir -p analysis

In [None]:
# This repository provides the files read later in the notebook:
# 'NLPMVA/distribution (corrected).txt' and the checkpoint directory 'NLPMVA/trainingscompo'.
!git clone https://github.com/Daetheys/NLPMVA

Cloning into 'NLPMVA'...
remote: Enumerating objects: 181, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 181 (delta 0), reused 1 (delta 0), pack-reused 178[K
Receiving objects: 100% (181/181), 287.11 MiB | 12.18 MiB/s, done.
Resolving deltas: 100% (35/35), done.
Checking out files: 100% (82/82), done.


# Distribution generation

In [None]:
import numpy as np

# Every line of the file is evaluated as one Python expression; element [0] of the
# result is used as the label tuple and element [1] as the sampling probability
# (see the two arrays built below).
distribution = []
with open('NLPMVA/distribution (corrected).txt') as f:
    for line in f:
        if not line.strip():
            # The previous implementation relied on eval('') raising SyntaxError to
            # detect the end of the data, i.e. it stopped at the first blank line or
            # at EOF. Keep that stop condition, but make it explicit so that a
            # malformed data line raises instead of silently truncating the data.
            break
        # SECURITY(review): eval() executes arbitrary code from the cloned file.
        # If the lines are plain literals, ast.literal_eval would be the safe
        # replacement — confirm the file format before switching.
        distribution.append(eval(line))
distribution_proba = np.array([d[1] for d in distribution])
distribution_lbl = np.array([d[0] for d in distribution])
distribution_lbl.max()

38

In [None]:
import json
import argparse
import numpy as np
import itertools
import torch.utils.data
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import egg.core as core
#from scipy.stats import entropy
from egg.core import EarlyStopperAccuracy
from egg.zoo.channel.features import OneHotLoader, UniformLoader, OneHotLoaderCompositionality, TestLoaderCompositionality
from egg.zoo.channel.archs import Sender, Receiver
from egg.core.reinforce_wrappers import RnnReceiverImpatient, RnnReceiverImpatientCompositionality, RnnReceiverCompositionality
from egg.core.reinforce_wrappers import SenderImpatientReceiverRnnReinforce, CompositionalitySenderImpatientReceiverRnnReinforce, CompositionalitySenderReceiverRnnReinforce
from egg.core.util import dump_sender_receiver_impatient, dump_sender_receiver_impatient_compositionality, dump_sender_receiver_compositionality

from egg.core.trainers import CompoTrainer
from egg.zoo.channel.train_compositionality import *

In [None]:
class _DistributionIterator:
    """Iterator producing one epoch of batches sampled from the custom distribution.

    The distribution is read from the module-level globals ``distribution``,
    ``distribution_proba`` (sampling probabilities) and ``distribution_lbl``
    (one row of integer labels per entry).

    Each call to ``__next__`` draws ``batch_size`` entries and encodes them as
    concatenated one-hot vectors: for label column ``i`` the value ``v`` sets
    column ``v + i * n_features`` to 1. The input width is fixed at
    ``2 * n_features``, so at most two label columns fit.

    Args:
        n_features: number of distinct values per label column.
        n_batches_per_epoch: number of batches yielded before ``StopIteration``.
        batch_size: number of sampled entries per batch.
        seed: seed of the private ``RandomState`` used for sampling; two iterators
            built with the same seed yield the same batches.

    Yields:
        ``(inputs, labels)`` where ``inputs`` is a float tensor of shape
        ``(batch_size, 2 * n_features)`` and ``labels`` is ``torch.zeros(1)``.
    """
    def __init__(self, n_features, n_batches_per_epoch, batch_size, seed=None):
        self.n_batches_per_epoch = n_batches_per_epoch
        self.n_features = n_features
        self.batch_size = batch_size

        self.batches_generated = 0
        self.random_state = np.random.RandomState(seed)

    def __iter__(self):
        return self

    def __next__(self):
        if self.batches_generated >= self.n_batches_per_epoch:
            raise StopIteration()

        # Sample through the iterator's own RandomState. The global
        # np.random.choice used previously ignored `seed` entirely.
        idxs = self.random_state.choice(len(distribution), self.batch_size, p=distribution_proba)
        batch_data = np.zeros((self.batch_size, self.n_features*2,))

        rows = np.arange(self.batch_size)
        for i in range(distribution_lbl.shape[1]):
            # Label column i occupies the i-th block of n_features columns.
            batch_data[rows, distribution_lbl[idxs, i] + i*self.n_features] = 1

        self.batches_generated += 1
        return torch.from_numpy(batch_data).float(), torch.zeros(1)

In [None]:
class DistributionLoader(torch.utils.data.DataLoader):
    """
    >>> probs = np.ones(8) / 8
    >>> data_loader = OneHotLoader(n_features=8, batches_per_epoch=3, batch_size=2, probs=probs, seed=1)
    >>> epoch_1 = []
    >>> for batch in data_loader:
    ...     epoch_1.append(batch)
    >>> [b[0].size() for b in epoch_1]
    [torch.Size([2, 8]), torch.Size([2, 8]), torch.Size([2, 8])]
    >>> data_loader_other = OneHotLoader(n_features=8, batches_per_epoch=3, batch_size=2, probs=probs)
    >>> all_equal = True
    >>> for a, b in zip(data_loader, data_loader_other):
    ...     all_equal = all_equal and (a[0] == b[0]).all()
    >>> all_equal.item()
    0
    """
    def __init__(self, n_features, batches_per_epoch, batch_size, seed=None):
        self.seed = seed
        self.batches_per_epoch = batches_per_epoch
        self.n_features = n_features
        self.batch_size = batch_size

    def __iter__(self):
        if self.seed is None:
            seed = np.random.randint(0, 2 ** 32)
        else:
            seed = self.seed

        return _DistributionIterator(n_features=self.n_features, n_batches_per_epoch=self.batches_per_epoch,
                               batch_size=self.batch_size,  seed=seed)

In [None]:
class DistributionUniformLoader(torch.utils.data.DataLoader):
    def __init__(self, n_features):
        idxs = np.arange(len(distribution))
        batch_data = np.zeros((len(distribution),n_features*2,))
        
        for i in range(distribution_lbl.shape[1]):
            batch_data[np.arange(len(distribution)),distribution_lbl[idxs,i]+i*n_features] = 1

        self.batch = torch.from_numpy(batch_data).float(), torch.zeros(1)

    def __iter__(self):
        return iter([self.batch])

# Define parameters

In [None]:
import time
import os
import torch
class Config:
    """Plain namespace of hyper-parameters and run options consumed by `train_config`.

    All values are class attributes, so they are evaluated once when the class
    body runs (this includes the `random_seed` draw and the `custom_dist`
    branch). Callers override individual settings on an instance.
    """

    #IMPORTANT PARAMETERS
    impatient = False #Impatient Listener
    reg = False #Lazy Speaker
    # Drawn once at class-definition time, so every Config() instance shares it.
    random_seed = np.random.randint(0,100) #Seed used for the training
    lr = 3e-4 #Learning rate of both neural networks
    batch_size = 512 #Batch size for the training
    n_epochs = 101 #Nb of epochs for the training
    batches_per_epoch = 10 #Nb of batch per epoch

    vocab_size = 100 #Number of words in the speaker's vocabulary
    max_len = 2 #Maximum number of words the speaker can use to describe the combination of concepts given as input
    n_values = 100 #Nb of different concepts
    n_attributes = 2 #Nb of concepts that will be combined (and need to be guessed at the same time)

    custom_dist = False #Use Eva's distribution instead of the one of the paper

    checkpoint_path = "NLPMVA/trainingscompo"
    nb_children_per_language = 3

    # `train_config` reads `opts.compute_metrics` unconditionally; without a default
    # any caller that forgets to set it hits AttributeError after the whole training run.
    compute_metrics = False #Compute the compositionality metrics at the end of training

    #Eva's parameters
    if custom_dist:
        # Sizes are derived from the loaded distribution (module-level `distribution_lbl`).
        n_values = distribution_lbl.max()+1
        n_attributes = 2

        vocab_size = n_values
        max_len = 2

    #LESS IMPORTANT PARAMETERS
    # NOTE(review): stays at 100 even when `custom_dist` changes `n_values`;
    # confirm that is intended for the receiver built from `n_features`.
    n_features = 100

    receiver_hidden = 450
    receiver_num_layers = 1
    receiver_num_heads = 1
    receiver_embedding = 100
    receiver_cell = 'gru'
    receiver_entropy_coeff = 0.1

    sender_hidden = 600
    sender_num_layers = 1
    sender_num_heads = 1
    sender_embedding = 100
    sender_cell = 'gru'
    sender_entropy_coeff = 0.4

    length_cost = 0.
    name = 'model'
    early_stopping_thr = 0.99

    dir_save = 'dir_save'#'expe_'+str(time.time()).split('.')[0]
    checkpoint_dir = None#os.path.join('expe_'+str(time.time()).split('.')[0],'checkpoint')

    unigram_pen = 0.

    force_eos = False

    optimizer_class = torch.optim.Adam
    validation_freq = 1
    device = 'cuda:0'

    load_from_checkpoint = None
    checkpoint_freq = 0
    preemptable = False

    probs = 'uniform'
    probs_attributes = 'uniform'

    att_weights = [1,1,1]

# Dump compositionality accuracy

In [None]:
def dump_compositionality(game, n_attributes, n_values, device, gs_mode,epoch):
    """Run the game on every attribute combination and score the receiver per attribute.

    The evaluation set is the full cartesian product of ``n_values`` values over
    ``n_attributes`` attributes, each row being the concatenation of one one-hot
    vector per attribute.

    Args:
        game: game object forwarded to ``dump_sender_receiver_compositionality``.
        n_attributes: number of attributes per input.
        n_values: number of values each attribute can take.
        device: device forwarded to the dump helper.
        gs_mode: forwarded as the ``gs`` argument of the dump helper.
        epoch: unused; kept for interface compatibility with the callers.

    Returns:
        ``(acc_vec, messages)``: ``acc_vec`` is a ``(n_values**n_attributes,
        n_attributes)`` array with 1 where the receiver output equals the input
        attribute value and 0 elsewhere; ``messages`` is returned unchanged
        from the dump helper.
    """
    val = np.arange(n_values)
    combination = list(itertools.product(val, repeat=n_attributes))
    n_inputs = len(combination)

    # Vectorised one-hot encoding: indexing the identity matrix with an
    # (n_inputs, n_attributes) index tensor yields one one-hot row per attribute,
    # and the reshape concatenates them. This replaces one torch.cat per attribute
    # per input (20k small concatenations per call with the default config).
    value_idx = torch.from_numpy(
        np.array(combination, dtype=np.int64).reshape(n_inputs, n_attributes))
    inputs = torch.eye(n_values)[value_idx].reshape(n_inputs, n_attributes * n_values)

    dataset = [[inputs, None]]

    _, messages, _, receiver_outputs, _ = \
        dump_sender_receiver_compositionality(game, dataset, gs=gs_mode, device=device, variable_length=True)

    acc_vec = np.zeros(((n_values**n_attributes), n_attributes))

    # Only rows that correspond to an input combination are scored.
    for i in range(min(len(receiver_outputs), n_inputs)):
        for j, target in enumerate(combination[i]):
            if receiver_outputs[i][j] == target:
                acc_vec[i, j] = 1

    return acc_vec, messages

def custom_dump_compositionality(game, n_attributes, n_values, device, gs_mode,epoch):
    """Run the game on every entry of the custom distribution and score the receiver.

    Inputs come from ``DistributionUniformLoader(n_values)``; targets are the
    rows of the module-level ``distribution_lbl`` array, in the same order.

    Args:
        game: game object forwarded to ``dump_sender_receiver_compositionality``.
        n_attributes: number of attributes per input (sets the width of ``acc_vec``).
        n_values: number of values per attribute.
        device: device forwarded to the dump helper.
        gs_mode: forwarded as the ``gs`` argument of the dump helper.
        epoch: unused; kept for interface compatibility with the callers.

    Returns:
        ``(acc_vec, messages)``: ``acc_vec[i, j]`` is 1 when output ``j`` of
        entry ``i`` equals ``distribution_lbl[i][j]``; ``messages`` is returned
        unchanged from the dump helper.
    """
    dataset = DistributionUniformLoader(n_values)

    _, messages, _, receiver_outputs, _ = \
        dump_sender_receiver_compositionality(game, dataset, gs=gs_mode, device=device, variable_length=True)

    # NOTE(review): acc_vec has n_values**n_attributes rows, but only as many rows
    # as there are receiver outputs can ever be set. Callers that average acc_vec
    # therefore divide by the full product size — confirm this is intended.
    acc_vec = np.zeros(((n_values**n_attributes), n_attributes))

    for i in range(min(len(receiver_outputs), acc_vec.shape[0])):
        targets = distribution_lbl[i]
        for j in range(len(targets)):
            if receiver_outputs[i][j] == targets[j]:
                acc_vec[i, j] = 1

    return acc_vec, messages

def dump_impatient_compositionality(game, n_attributes, n_values, device, gs_mode,epoch):
    """Score the receiver on every attribute combination, using the impatient dump helper.

    The evaluation set is the full cartesian product of ``n_values`` values over
    ``n_attributes`` attributes, each row being the concatenation of one one-hot
    vector per attribute.

    Args:
        game: game object forwarded to ``dump_sender_receiver_impatient_compositionality``.
        n_attributes: number of attributes per input.
        n_values: number of values each attribute can take.
        device: device forwarded to the dump helper.
        gs_mode: forwarded as the ``gs`` argument of the dump helper.
        epoch: unused; kept for interface compatibility with the callers.

    Returns:
        ``(acc_vec, messages)``: ``acc_vec`` is a ``(n_values**n_attributes,
        n_attributes)`` array with 1 where the receiver output equals the input
        attribute value; ``messages`` is returned unchanged from the dump helper.
    """
    val = np.arange(n_values)
    combination = list(itertools.product(val, repeat=n_attributes))
    n_inputs = len(combination)

    # Vectorised one-hot encoding (identity-matrix lookup + reshape) instead of
    # one torch.cat per attribute per input.
    value_idx = torch.from_numpy(
        np.array(combination, dtype=np.int64).reshape(n_inputs, n_attributes))
    inputs = torch.eye(n_values)[value_idx].reshape(n_inputs, n_attributes * n_values)

    dataset = [[inputs, None]]

    _, messages, _, receiver_outputs, _ = \
        dump_sender_receiver_impatient_compositionality(game, dataset, gs=gs_mode, device=device, variable_length=True)

    acc_vec = np.zeros(((n_values**n_attributes), n_attributes))

    # Bounded by n_inputs so an unexpected extra output cannot index past
    # `combination` / `acc_vec` (same guard as the non-impatient variant).
    for i in range(min(len(receiver_outputs), n_inputs)):
        for j, target in enumerate(combination[i]):
            if receiver_outputs[i][j] == target:
                acc_vec[i, j] = 1

    return acc_vec, messages

def custom_dump_impatient_compositionality(game, n_attributes, n_values, device, gs_mode,epoch):
    """Score the receiver on every entry of the custom distribution (impatient dump helper).

    Inputs come from ``DistributionUniformLoader(n_values)``; targets are the
    rows of the module-level ``distribution_lbl`` array, in the same order.

    Args:
        game: game object forwarded to ``dump_sender_receiver_impatient_compositionality``.
        n_attributes: number of attributes per input (sets the width of ``acc_vec``).
        n_values: number of values per attribute.
        device: device forwarded to the dump helper.
        gs_mode: forwarded as the ``gs`` argument of the dump helper.
        epoch: unused; kept for interface compatibility with the callers.

    Returns:
        ``(acc_vec, messages)``: ``acc_vec[i, j]`` is 1 when output ``j`` of
        entry ``i`` equals ``distribution_lbl[i][j]``; ``messages`` is returned
        unchanged from the dump helper.
    """
    dataset = DistributionUniformLoader(n_values)

    _, messages, _, receiver_outputs, _ = \
        dump_sender_receiver_impatient_compositionality(game, dataset, gs=gs_mode, device=device, variable_length=True)

    # NOTE(review): acc_vec has n_values**n_attributes rows, but only as many rows
    # as there are receiver outputs can ever be set. Callers that average acc_vec
    # therefore divide by the full product size — confirm this is intended.
    acc_vec = np.zeros(((n_values**n_attributes), n_attributes))

    # Bounded by the number of rows so an extra output cannot index past acc_vec.
    for i in range(min(len(receiver_outputs), acc_vec.shape[0])):
        output = receiver_outputs[i]
        targets = distribution_lbl[i]
        for j in range(len(output)):
            if output[j] == targets[j]:
                acc_vec[i, j] = 1

    return acc_vec, messages

# Pipeline from the git repository

In [None]:
def entropy_dict(freq_table):
    """Return the Shannon entropy, in bits, of a table mapping outcome -> count.

    Entries with a zero count contribute nothing (the limit of ``p*log(p)`` at
    0) instead of producing NaN. An empty table has entropy 0.
    """
    H = 0
    n = sum(freq_table.values())

    for freq in freq_table.values():
        if freq == 0:
            continue
        p = freq / n
        H += -p * np.log(p)
    # Natural log above; divide by ln(2) to express the result in bits.
    return H / np.log(2)

def _hashable_tensor(t):
    """Convert ``t`` into a hashable key: ints/tuples pass through, tensors become scalars or tuples."""
    if isinstance(t, (tuple, int)):
        return t

    try:
        return t.item()
    except (ValueError, RuntimeError):
        # `.item()` on a multi-element tensor raises ValueError on older PyTorch
        # releases and RuntimeError on recent ones; both mean "not a scalar".
        return tuple(t.view(-1).tolist())

def entropy(messages):
    """Return the empirical Shannon entropy, in bits, of the items in ``messages``."""
    from collections import Counter

    freq_table = Counter(_hashable_tensor(m) for m in messages)

    return entropy_dict(freq_table)

def mutual_info(xs, ys):
    """Return the empirical mutual information I(X;Y) = H(X) + H(Y) - H(X,Y), in bits.

    ``xs`` and ``ys`` are paired element-wise; pairing stops at the shorter one.
    """
    e_x = entropy(xs)
    e_y = entropy(ys)

    xys = [(_hashable_tensor(x), _hashable_tensor(y)) for x, y in zip(xs, ys)]

    e_xy = entropy(xys)

    return e_x + e_y - e_xy

In [None]:
from typing import Any, Dict, Optional

import torch

from egg.core.util import move_to


class Batch:
    """Container for one batch: sender input, labels, receiver input and auxiliary input.

    The four fields can be read as attributes, by position (``batch[0]`` ..
    ``batch[3]``) or by iterating over the batch, always in the order
    ``sender_input, labels, receiver_input, aux_input``.
    """

    def __init__(
        self,
        sender_input: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        receiver_input: Optional[torch.Tensor] = None,
        aux_input: Optional[Dict[Any, Any]] = None,
    ):
        self.sender_input = sender_input
        self.labels = labels
        self.receiver_input = receiver_input
        self.aux_input = aux_input

    def __getitem__(self, idx):
        """Return the field at position ``idx`` (0..3).

        >>> b = Batch(torch.Tensor([1]), torch.Tensor([2]), torch.Tensor([3]), {})
        >>> b[0]
        tensor([1.])
        >>> b[3]
        {}
        >>> b[6]
        Traceback (most recent call last):
            ...
        IndexError: Trying to access a wrong index in the batch
        """
        fields = (self.sender_input, self.labels, self.receiver_input, self.aux_input)
        for position, value in enumerate(fields):
            if idx == position:
                return value
        raise IndexError("Trying to access a wrong index in the batch")

    def __iter__(self):
        """Iterate over the four fields in positional order.

        >>> batch = Batch(torch.ones(2, 2), torch.zeros(2, 2))
        >>> it = iter(batch)
        >>> torch.allclose(next(it), torch.ones(2, 2))
        True
        >>> torch.allclose(next(it), torch.zeros(2, 2))
        True
        """
        fields = [self.sender_input, self.labels, self.receiver_input, self.aux_input]
        return iter(fields)

    def to(self, device: torch.device):
        """Move all (nested) tensors of the batch to ``device`` via ``move_to``.

        The batch is updated in place: each field is replaced by the result of
        ``move_to`` and the same ``Batch`` instance is returned.
        """
        for field_name in ("sender_input", "labels", "receiver_input", "aux_input"):
            setattr(self, field_name, move_to(getattr(self, field_name), device))
        return self

In [None]:
import json

import torch
from scipy import spatial
from scipy.stats import spearmanr

import egg.core as core

try:
    import editdistance  # package to install https://pypi.org/project/editdistance/0.3.1/
except ImportError:
    print(
        "Please install editdistance package: `pip install editdistance`. "
        "It is used for calculating topographic similarity."
    )


def ask_sender(n_attributes, n_values, dataset, sender, device):
    """Query ``sender`` once per dataset row and collect attributes, messages and inputs.

    Args:
        n_attributes, n_values: each row is viewed as ``(n_attributes, n_values)``
            and the arg-max along the last axis gives the attribute values.
        dataset: indexable collection of input rows.
        sender: callable invoked with a batch of one row; the first element of
            its result is taken as the message.
        device: device the row is moved to before calling ``sender``.

    Returns:
        ``(attributes, strings, meanings)`` stacked along a new first dimension.
    """
    attributes, strings, meanings = [], [], []

    for row in range(len(dataset)):
        meaning = dataset[row]

        attributes.append(meaning.view(n_attributes, n_values).argmax(dim=-1))
        meanings.append(meaning.to(device))

        # Inference only: no gradient bookkeeping while querying the sender.
        with torch.no_grad():
            string, *_rest = sender(meaning.unsqueeze(0).to(device))
        strings.append(string.squeeze(0))

    return (
        torch.stack(attributes, dim=0),
        torch.stack(strings, dim=0),
        torch.stack(meanings, dim=0),
    )


def information_gap_representation(meanings, representations):
    """Average, over non-constant representation columns, of the normalised mutual-information gap.

    For each column of ``representations`` the mutual information with every
    column of ``meanings`` is computed; the gap between the two largest values,
    divided by the column's entropy, is that column's score. Columns with zero
    entropy are skipped.

    Returns:
        The sum of the gaps divided by the number of non-constant columns, as a
        Python float.
    """
    n_positions = representations.size(1)
    n_factors = meanings.size(1)

    gaps = torch.zeros(n_positions)
    non_constant_positions = 0.0

    for j in range(n_positions):
        column = representations[:, j]
        ranked_mi = sorted(
            (mutual_info(meanings[:, i], column) for i in range(n_factors)),
            reverse=True,
        )
        # The column entropy is only available when there is at least one meaning column.
        h_j = entropy(column) if n_factors > 0 else None

        if h_j > 0.0:
            gaps[j] = (ranked_mi[0] - ranked_mi[1]) / h_j
            non_constant_positions += 1

    return (gaps.sum() / non_constant_positions).item()


def information_gap_position(n_attributes, n_values, dataset, sender, device):
    """Information gap between the input attributes and the message positions.

    Queries the sender through ``ask_sender`` and scores the resulting messages
    against the attributes with ``information_gap_representation``.
    """
    sender_view = ask_sender(n_attributes, n_values, dataset, sender, device)
    attributes, strings, _unused_meanings = sender_view
    return information_gap_representation(attributes, strings)


def histogram(strings, vocab_size):
    """Count, for every row of ``strings``, how often each symbol ``0..vocab_size-1`` occurs.

    Returns:
        A ``(strings.size(0), vocab_size)`` tensor on the same device as
        ``strings``, in the default floating-point dtype.
    """
    n_messages = strings.size(0)
    counts = torch.zeros(n_messages, vocab_size, device=strings.device)

    for symbol in range(vocab_size):
        # Boolean match mask summed along the last axis = occurrences per row.
        counts[:, symbol] = (strings == symbol).sum(dim=-1)

    return counts


def information_gap_vocab(n_attributes, n_values, dataset, sender, device, vocab_size):
    """Information gap between the input attributes and per-symbol counts of the messages.

    The messages obtained from ``ask_sender`` are turned into symbol histograms;
    the column for symbol 0 is dropped before scoring with
    ``information_gap_representation``.
    """
    attributes, strings, _unused_meanings = ask_sender(
        n_attributes, n_values, dataset, sender, device
    )

    symbol_counts = histogram(strings, vocab_size)
    counts_without_symbol_zero = symbol_counts[:, 1:]
    return information_gap_representation(attributes, counts_without_symbol_zero)


def edit_dist(_list):
    """Return the normalised edit distance for every unordered pair of items in ``_list``.

    Pairs are visited in the order (0,1), (0,2), ..., (1,2), ...; each distance
    is ``editdistance.eval(a, b)`` divided by the length of the first item of
    the pair.
    """
    n_items = len(_list)
    # Normalized edit distance (same in our case as length is fixed)
    return [
        editdistance.eval(_list[first], _list[second]) / len(_list[first])
        for first in range(n_items - 1)
        for second in range(first + 1, n_items)
    ]


def cosine_dist(_list):
    """Return the cosine distance for every unordered pair of vectors in ``_list``.

    Pairs are visited in the order (0,1), (0,2), ..., (1,2), ...
    """
    n_items = len(_list)
    return [
        spatial.distance.cosine(_list[first], _list[second])
        for first in range(n_items - 1)
        for second in range(first + 1, n_items)
    ]


def topographic_similarity(n_attributes, n_values, dataset, sender, device):
    """Spearman correlation between pairwise message distances and pairwise input distances.

    Messages are compared with ``edit_dist`` and inputs with ``cosine_dist``;
    both lists enumerate the same pairs in the same order.
    """
    _unused_attributes, strings, meanings = ask_sender(
        n_attributes, n_values, dataset, sender, device
    )

    # Plain Python lists of symbols, one list per message.
    messages_as_lists = [[symbol.item() for symbol in message] for message in strings]

    message_distances = edit_dist(messages_as_lists)
    input_distances = cosine_dist(meanings.cpu().numpy())

    return spearmanr(message_distances, input_distances).correlation


class Metrics(core.Callback):
    """Callback computing compositionality metrics of the sender on a fixed dataset.

    ``dump_stats`` reads the game from ``self.trainer.game``, so a trainer must
    be attached to the instance before it is called.

    Args:
        dataset: inputs the sender is queried on.
        device: device used when querying the sender.
        n_attributes, n_values: layout of each input row.
        vocab_size: vocabulary size used for the bag-of-symbols metric.
        freq: the metrics are dumped every ``freq`` epochs; a value <= 0
            disables the per-epoch dump.
    """

    def __init__(self, dataset, device, n_attributes, n_values, vocab_size, freq=1):
        self.dataset = dataset
        self.device = device
        self.n_attributes = n_attributes
        self.n_values = n_values
        self.vocab_size = vocab_size
        self.freq = freq
        self.epoch = 0

    def dump_stats(self):
        """Compute the three metrics, print them as one JSON line and return them as a dict."""
        game = self.trainer.game
        # The game is put in eval mode for the measurements, then back in train mode.
        game.eval()

        sender_args = (self.n_attributes, self.n_values, self.dataset, game.sender, self.device)

        positional_disent = information_gap_position(*sender_args)
        bos_disent = information_gap_vocab(*sender_args, self.vocab_size)
        topo_sim = topographic_similarity(*sender_args)

        output = {
            "epoch": self.epoch,
            "positional_disent": positional_disent,
            "bag_of_symbol_disent": bos_disent,
            "topographic_sim": topo_sim,
        }

        print(json.dumps(output), flush=True)

        game.train()

        return output

    def on_train_end(self):
        # Final dump intentionally disabled.
        #self.dump_stats()
        pass

    def on_epoch_end(self, *stuff):
        self.epoch += 1

        if self.freq > 0 and self.epoch % self.freq == 0:
            self.dump_stats()

# Training

In [None]:
import sys
class HiddenPrints:
    """Context manager that redirects ``sys.stdout`` to the null device inside the ``with`` block.

    On exit the stream currently installed as ``sys.stdout`` is closed and the
    stream that was active on entry is restored. ``sys.stderr`` is untouched.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, mode='w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the replacement stream first, then put the saved one back.
        sys.stdout.close()
        sys.stdout = self._original_stdout

In [None]:
def train_config(opts):
    """Train a new receiver against a frozen, checkpointed sender and report accuracy.

    The sender weights are loaded from ``opts.checkpoint_path`` and excluded from
    gradient updates; a fresh receiver is built and the game is trained for
    ``opts.n_epochs`` single-epoch steps. After each step the accuracy matrix
    returned by the matching ``*dump*compositionality`` helper is averaged.

    Args:
        opts: options object (e.g. a ``Config`` instance). Besides the model and
            training settings, ``opts.compute_metrics`` is read here; the
            driver cell sets it for every run.

    Returns:
        ``(accs, metric_vals)``: ``accs`` is the list of per-epoch mean
        accuracies; ``metric_vals`` is the dict returned by
        ``Metrics.dump_stats`` when ``opts.compute_metrics`` is truthy, else ``None``.
    """

    import egg
    # Make the options visible to the egg package through its module-level slot.
    egg.core.util.common_opts = opts
    device = opts.device

    # True only when opts.force_eos compares equal to 1 (so True also counts).
    force_eos = opts.force_eos == 1

    # Distribution of the inputs
    # NOTE(review): `probs` / `probs_attributes` stay undefined when the option
    # matches none of the branches below; they are only used by the
    # non-custom-distribution loader.
    if opts.probs=="uniform":
        probs=[]
        probs_by_att = np.ones(opts.n_values)
        probs_by_att /= probs_by_att.sum()
        for i in range(opts.n_attributes):
            probs.append(probs_by_att)

    if opts.probs=="entropy_test":
        probs=[]
        for i in range(opts.n_attributes):
            probs_by_att = np.ones(opts.n_values)
            # Value 0 of attribute i gets weight 1+i before normalisation.
            probs_by_att[0]=1+(1*i)
            probs_by_att /= probs_by_att.sum()
            probs.append(probs_by_att)

    if opts.probs_attributes=="uniform":
        probs_attributes=[1]*opts.n_attributes

    if opts.probs_attributes=="uniform_indep":
        probs_attributes=[]
        probs_attributes=[0.2]*opts.n_attributes

    if opts.probs_attributes=="echelon":
        probs_attributes=[]
        # Every iteration assigns the same fixed 4-element list.
        for i in range(opts.n_attributes):
            #probs_attributes.append(1.-(0.2)*i)
            #probs_attributes.append(0.7+0.3/(i+1))
            probs_attributes=[1.,0.95,0.9,0.85]

    if opts.custom_dist:
        # Sample training batches from the custom distribution; validate on all its entries.
        train_loader = DistributionLoader(n_features=opts.n_values, batch_size=opts.batch_size,batches_per_epoch=opts.batches_per_epoch)
        #OneHotLoaderCompositionality
        # single batches with 1s on the diag
        test_loader = DistributionUniformLoader(n_features=opts.n_values)
    else:
        #TestLoaderCompositionality
        train_loader = OneHotLoaderCompositionality(n_values=opts.n_values, n_attributes=opts.n_attributes, batch_size=opts.batch_size*opts.n_attributes,
                                                    batches_per_epoch=opts.batches_per_epoch, probs=probs, probs_attributes=probs_attributes)

        # single batches with 1s on the diag
        test_loader = TestLoaderCompositionality(n_values=opts.n_values,n_attributes=opts.n_attributes)
    ### SENDER ###

    sender = Sender(n_features=opts.n_attributes*opts.n_values, n_hidden=opts.sender_hidden)

    sender = core.RnnSenderReinforce(sender,opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                    cell=opts.sender_cell, max_len=opts.max_len, num_layers=opts.sender_num_layers,
                                    force_eos=force_eos)

    # The sender is pre-trained: load its weights and freeze every parameter.
    sender.load_state_dict(torch.load(opts.checkpoint_path))
    for p in sender.parameters():
        p.requires_grad = False
    ### RECEIVER ###

    # NOTE(review): this instance is replaced in both branches below; removing it
    # may alter the torch RNG stream seen by later initialisations — verify before deleting.
    receiver = Receiver(n_features=opts.n_values, n_hidden=opts.receiver_hidden)

    if not opts.impatient:
        receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)
        receiver = RnnReceiverCompositionality(receiver, opts.vocab_size, opts.receiver_embedding,
                                            opts.receiver_hidden, cell=opts.receiver_cell,
                                            num_layers=opts.receiver_num_layers, max_len=opts.max_len, n_attributes=opts.n_attributes, n_values=opts.n_values)
    else:
        receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
        # If impatient 1
        receiver = RnnReceiverImpatientCompositionality(receiver, opts.vocab_size, opts.receiver_embedding,
                                            opts.receiver_hidden, cell=opts.receiver_cell,
                                            num_layers=opts.receiver_num_layers, max_len=opts.max_len, n_attributes=opts.n_attributes, n_values=opts.n_values)


    # The game class and loss are chosen to match the receiver type.
    if not opts.impatient:
        game = CompositionalitySenderReceiverRnnReinforce(sender, receiver, loss_compositionality, sender_entropy_coeff=opts.sender_entropy_coeff,
                                            n_attributes=opts.n_attributes,n_values=opts.n_values,receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                            length_cost=opts.length_cost,unigram_penalty=opts.unigram_pen,reg=opts.reg)
    else:
        game = CompositionalitySenderImpatientReceiverRnnReinforce(sender, receiver, loss_impatient_compositionality, sender_entropy_coeff=opts.sender_entropy_coeff,
                                            n_attributes=opts.n_attributes,n_values=opts.n_values,att_weights=opts.att_weights,receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                            length_cost=opts.length_cost,unigram_penalty=opts.unigram_pen,reg=opts.reg)


    optimizer = opts.optimizer_class(game.parameters(),lr=opts.lr)

    trainer = CompoTrainer(n_attributes=opts.n_attributes,n_values=opts.n_values,game=game, optimizer=optimizer, train_data=train_loader,
                            validation_data=test_loader, callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Not read anywhere else in this function.
    curr_accs=[0]*7

    # Overrides whatever attribute weights the game was constructed with.
    game.att_weights=[1]*(game.n_attributes)

    accs = []

    for epoch in range(int(opts.n_epochs)):

        if epoch%50==0:
          print("Epoch: "+str(epoch))

        # Train one epoch at a time with stdout silenced, then evaluate.
        with HiddenPrints():
            trainer.train(n_epochs=1)
        if not opts.impatient:
            if opts.custom_dist:
                acc_vec,messages=custom_dump_compositionality(trainer.game, opts.n_attributes, opts.n_values, device, False,epoch)
            else:
                acc_vec,messages=dump_compositionality(trainer.game, opts.n_attributes, opts.n_values, device, False,epoch)
        else:
            if opts.custom_dist:
                acc_vec,messages=custom_dump_impatient_compositionality(trainer.game, opts.n_attributes, opts.n_values, device, False,epoch)
            else:
                acc_vec,messages=dump_impatient_compositionality(trainer.game, opts.n_attributes, opts.n_values, device, False,epoch)

        # ADDITION TO SAVE MESSAGES
        #print(acc_vec.T)
        # Mean over all entries of the accuracy matrix (inputs x attributes).
        accs.append(np.mean(acc_vec))

    metric_vals = None
    if opts.compute_metrics:
      #Create dataset for metric computation
      if opts.custom_dist:
          #Dataset creation
          # Same one-hot encoding as DistributionUniformLoader: every distribution entry once.
          idxs = np.arange(len(distribution))
          batch_data = np.zeros((len(distribution),opts.n_values*2,))
          for i in range(distribution_lbl.shape[1]):
              batch_data[np.arange(len(distribution)),distribution_lbl[idxs,i]+i*opts.n_values] = 1
          dataset = torch.from_numpy(batch_data).float()
      else:
          # Keep at most the first 1000 rows of the first test batch (`c` is only a counter).
          l = []
          c = 0
          data = next(iter(test_loader))[0]
          for idx,i in enumerate(data):
              c += 1
              if idx < 1000:
                  l.append(i[None])
          dataset = torch.cat(l,axis=0)
          print("metric dataset",dataset.shape)

      # The callback is used stand-alone: the trainer is attached by hand and
      # dump_stats is called directly, so `freq` plays no role here.
      metric = Metrics(dataset, opts.device, opts.n_attributes, opts.n_values, opts.vocab_size, freq=20)
      metric.trainer = trainer
      metric_vals = metric.dump_stats()

    core.close()

    return accs,metric_vals



In [None]:
print(torch.cuda.is_available())
opts = Config()
# data[i][j] = [accs, metrics_val] for child run j trained on sender checkpoint i.
data = []
# sorted(): os.listdir returns entries in arbitrary order, which made the index i
# (and therefore the layout of the saved results) non-reproducible between runs.
checkpoint_names = sorted(os.listdir(opts.checkpoint_path))
for i,checkpoint_name in enumerate(checkpoint_names):
    print('---',i)
    data.append([])
    for j in range(Config.nb_children_per_language):#68
        # Fresh options for every child run; only the checkpoint differs per language.
        opts = Config()
        opts.nb_children_per_language = 3
        #opts.n_epochs = 1
        opts.checkpoint_path = os.path.join(Config.checkpoint_path, checkpoint_name)
        # The compositionality metrics depend on the frozen sender only, so they
        # are computed for the first child of each language.
        opts.compute_metrics = j==0
        accs,metrics_val = train_config(opts)
        data[i].append([accs,metrics_val])
        print(accs[-1])

import pickle
with open('datastd.p','wb') as results_file:
    pickle.dump(data,results_file)

# Colab-only: triggers a browser download of the results file.
from google.colab import files
files.download("datastd.p")

True
--- 0
Epoch: 0


	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
  eos_positions = (message[i, :] == 0).nonzero()


Epoch: 50
Epoch: 100
metric dataset torch.Size([1000, 200])
{"epoch": 0, "positional_disent": 0.9702141284942627, "bag_of_symbol_disent": 0.9445710182189941, "topographic_sim": 0.9729554940844154}
0.93505
Epoch: 0
Epoch: 50
Epoch: 100
0.93505
Epoch: 0
Epoch: 50
Epoch: 100
0.9349
--- 1
Epoch: 0
Epoch: 50
Epoch: 100
metric dataset torch.Size([1000, 200])
{"epoch": 0, "positional_disent": 0.9960936307907104, "bag_of_symbol_disent": 0.9553039073944092, "topographic_sim": 0.9956404687760368}
0.93145
Epoch: 0
Epoch: 50
Epoch: 100
0.9316
Epoch: 0
Epoch: 50
Epoch: 100
0.93165
--- 2
Epoch: 0
Epoch: 50
Epoch: 100
metric dataset torch.Size([1000, 200])
{"epoch": 0, "positional_disent": 0.9936408996582031, "bag_of_symbol_disent": 0.9490382671356201, "topographic_sim": 0.9906570885661963}
0.94505
Epoch: 0
Epoch: 50
Epoch: 100
0.9451
Epoch: 0
Epoch: 50
Epoch: 100
0.945


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# Preprocess data

In [None]:
import pickle
import numpy as np

In [None]:
# NOTE(review): the recorded output of this cell is a FileNotFoundError. The training
# cell in this notebook writes 'datastd.p'; 'data(4).p' presumably refers to a
# separately downloaded/uploaded results file — confirm which file is intended.
# pickle.load executes code embedded in the file: only load results you produced yourself.
with open('data(4).p','rb') as  f:
    data = pickle.load(f)

FileNotFoundError: ignored

In [None]:
# One entry per language (sender checkpoint):
#   compo_scores      -> the three compositionality metrics stored with the first child run
#   learn_scores_mean -> mean over child runs of the summed per-epoch accuracies
#   learn_scores_var  -> variance of the same sums
METRIC_KEYS = ('positional_disent', 'bag_of_symbol_disent', 'topographic_sim')

compo_scores = []
learn_scores_mean = []
learn_scores_var = []
for children in data:
    first_child_metrics = children[0][1]
    compo_scores.append([first_child_metrics[key] for key in METRIC_KEYS])

    summed_accs = [np.sum(child[0]) for child in children]
    learn_scores_mean.append(np.mean(summed_accs))
    learn_scores_var.append(np.var(summed_accs))

compo_scores = np.array(compo_scores)
learn_scores_mean = np.array(learn_scores_mean)
learn_scores_var = np.array(learn_scores_var)

In [None]:
# Standardise: centre on the mean and divide by the standard deviation
# (the previous version divided by the variance, which is not a z-score).
learn_scores_mean = (learn_scores_mean-learn_scores_mean.mean())/learn_scores_mean.std()

In [None]:
# Min-max scale to [0, 1] so the scores can be fed directly to a colormap.
# When every score is identical the range is 0; map everything to 0 instead of
# dividing by zero and propagating NaN into the plot colors.
score_min = learn_scores_mean.min()
score_range = learn_scores_mean.max() - score_min
if score_range == 0:
    learn_scores_mean = np.zeros_like(learn_scores_mean, dtype=float)
else:
    learn_scores_mean = (learn_scores_mean - score_min) / score_range

# Plot

In [None]:
# One line per language: its three compositionality metrics, then its scaled learning score.
for language_idx, language_metrics in enumerate(compo_scores):
    print(language_metrics,learn_scores_mean[language_idx])

In [None]:
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits import mplot3d

fig = plt.figure()
ax = plt.axes(projection='3d')

# plt.get_cmap is available across matplotlib versions, whereas
# matplotlib.cm.get_cmap is deprecated and removed in recent releases.
cmap = plt.get_cmap('jet')

# One point per language, positioned by its three compositionality metrics and
# colored by its scaled learning score.
for i in range(len(compo_scores)):
    color = cmap(learn_scores_mean[i])
    # color= rather than c=: a single RGBA tuple passed as `c` is ambiguous
    # (it can be read as four scalar values to colormap) and triggers a warning.
    ax.scatter(*compo_scores[i],color=color)

# Axis order follows the column order used when compo_scores was built.
ax.set_xlabel('positional_disent')
ax.set_ylabel('bag_of_symbol_disent')
ax.set_zlabel('topographic_sim')
plt.show()