<a href="https://colab.research.google.com/github/MathieuRita/LE_test/blob/master/LazImpa_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **LazImpa**

*Branche Félix*



In [5]:
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import re
import random
from PIL import Image
import json
import argparse
import torch.utils.data
import torch.nn.functional as F
from torchsummary import summary
from torchviz import make_dot

import egg.core as core
from egg.zoo.channel.position_analysis import main as position_analysis
from egg.zoo.channel.train import main as train
from egg.zoo.channel.train import get_params, loss_impatient, dump_impatient
from egg.zoo.channel.test import main as test
from egg.core import EarlyStopperAccuracy
from egg.zoo.channel.features import OneHotLoader, UniformLoader
from egg.zoo.channel.archs import Sender, Receiver
from egg.core.reinforce_wrappers import RnnReceiverImpatient
from egg.core.reinforce_wrappers import SenderImpatientReceiverRnnReinforce
from egg.core.util import dump_sender_receiver_impatient

In [6]:
def clean_npy_files(base_dir="dir_save", analysis_dir="analysis", images_dir="images_dir", epoch_min=0, n_min=0):
    """
    Delete saved artifacts whose file name is past the given epoch / size thresholds.

    A file is deleted when its *file name* contains both an ``epoch_<E>`` and a
    ``features_<N>`` token and ``E > epoch_min or N > n_min``. Files whose name
    lacks either token are never touched.

    Locations scanned (non-recursively):
      * ``*.npy`` and ``*.pth`` in the ``accuracy``, ``messages``, ``sender``
        and ``receiver`` sub-directories of ``base_dir``
      * ``*.npy`` in ``analysis_dir``
      * ``*.png`` in ``images_dir``

    Args:
        base_dir: root directory holding the four artifact sub-directories.
        analysis_dir: directory holding analysis ``.npy`` files.
        images_dir: directory holding ``.png`` figures.
        epoch_min: files with an epoch strictly greater than this are deleted.
        n_min: files with a feature count strictly greater than this are deleted.

    Returns:
        None. Each deletion (or failure) is reported with ``print``.

    WARNING: with the defaults (``epoch_min=0``, ``n_min=0``) any file with a
    positive feature count matches, i.e. practically every saved artifact.
    """
    subdirs = ("accuracy", "messages", "sender", "receiver")
    epoch_pattern = re.compile(r"epoch_(\d+)")
    n_pattern = re.compile(r"features_(\d+)")

    def should_delete(file):
        # Only the file name is inspected, so a directory that happens to be
        # called e.g. "run_epoch_3" cannot make unrelated files match.
        name = os.path.basename(file)
        epoch_match = epoch_pattern.search(name)
        n_match = n_pattern.search(name)
        if not (epoch_match and n_match):
            return False
        return int(epoch_match.group(1)) > epoch_min or int(n_match.group(1)) > n_min

    def remove_matching(directory, extension):
        # Best-effort removal: a failure on one file is reported and skipped.
        for file in glob.glob(os.path.join(directory, extension)):
            if not should_delete(file):
                continue
            try:
                os.remove(file)
                print(f"✅ Deleted: {file}")
            except OSError as e:
                print(f"⚠️ Error deleting {file}: {e}")

    # Same visiting order as before: sub-directories of base_dir first
    # (.npy then .pth for each), then the analysis files, then the images.
    targets = [
        (os.path.join(base_dir, subdir), extension)
        for subdir in subdirs
        for extension in ("*.npy", "*.pth")
    ]
    targets.append((analysis_dir, "*.npy"))
    targets.append((images_dir, "*.png"))

    for directory, extension in targets:
        remove_matching(directory, extension)

In [7]:
# WARNING: destructive. With the default thresholds (epoch_min=0, n_min=0) this
# removes every saved artifact whose name carries an epoch > 0 or a feature
# count > 0 — including epoch_0 files, as the recorded output shows.
clean_npy_files() # WARNING

✅ Deleted: dir_save\accuracy\accuracy_epoch_0_n_features_25.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_100_n_features_42.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_101_n_features_42.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_102_n_features_43.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_103_n_features_43.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_104_n_features_43.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_105_n_features_44.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_106_n_features_44.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_107_n_features_44.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_108_n_features_45.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_109_n_features_45.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_10_n_features_25.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_110_n_features_45.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_11_n_features_25.npy
✅ Deleted: dir_save\accuracy\accuracy_epoch_12_n_features_25.npy
✅ Deleted: dir_

## I - Train LazImpa

____________________
This section provides the code to train the agents, using the hyperparameters needed to run LazImpa. *To test the other agent models presented in the paper, change the `impatient` and `reg` parameters. You can also experiment with the other hyperparameters.*


**Data saved:**

Current messages and accuracy by input are saved at each training episode in `dir_save/messages` and `dir_save/accuracy` and the weights of the agents are saved every 50 epochs in `dir_save/sender` and `dir_save/receiver`.
_____________________

In [12]:
# Variables for training
# These module-level names are read by get_train_args() / get_test_args();
# the trailing "default" notes are the values the author lists as defaults.

vocab_size = 20 # default : 40
max_length = 15 # default : 30
n_features = 25 # default : 100
n_epochs = 51 # default : 501
batch_size = 512 # default : 512
length_cost = 0.0  # default : 0.0
lr = 0.001 # default : 0.001
sender_hidden = 250 # default : 250
receiver_hidden = 600 # default : 600
receiver_embedding = 100 # default : 100
sender_embedding = 10 # default : 10
sender_entropy_coeff = 2.0 # default : 2.0
batches_per_epoch = 100 # default : 100
early_stopping_thr = 0.99   # default : 0.99
impatient = True
# Largest multiple of 50 that is <= n_epochs - 1. The text above states that
# weights are saved every 50 epochs, so this presumably names the last
# checkpoint written by training — confirm against the training script.
epoch = 50*((n_epochs-1)//50)
# Checkpoint paths for that epoch and the current n_features.
sender_weights = f"dir_save/sender/sender_weights_epoch_{epoch}_n_features_{n_features}.pth"
receiver_weights = f"dir_save/receiver/receiver_weights_epoch_{epoch}_n_features_{n_features}.pth"
sender_cell = "lstm"
receiver_cell = "lstm"
sender_num_layers = 1
receiver_num_layers = 1


In [13]:
def get_train_args():
    """
    Build the command-line style argument list passed to the training entry points.

    Every value is read from the notebook's module-level configuration
    (vocab_size, max_length, n_features, n_epochs, impatient, sender_cell, ...)
    at call time, so re-assigning one of those globals changes the next call's
    result. The agent cell types, layer counts and the ``impatient`` flag now
    come from the same configuration variables that get_test_args() uses,
    instead of being hard-coded a second time here.

    Returns:
        list[str]: ``--name=value`` strings.
    """
    args = [
        "--dir_save=dir_save",
        f"--impatient={impatient}",
        "--reg=True",
        f"--vocab_size={vocab_size}",
        f"--max_len={max_length}",
        f"--n_features={n_features}",
        # NOTE(review): the recorded training run shows print_message=True in
        # the parsed Namespace although "False" is passed here — presumably the
        # parser converts with bool(), for which any non-empty string is True.
        # Confirm in get_params before relying on this flag.
        "--print_message=False",
        "--random_seed=7",
        "--probs=powerlaw",
        # Kept as "--n_epoch": the recorded Namespace shows it is parsed into
        # n_epochs (presumably via argparse prefix matching — TODO confirm).
        f"--n_epoch={n_epochs}",
        f"--batch_size={batch_size}",
        f"--length_cost={length_cost}",
        f"--sender_cell={sender_cell}",
        f"--receiver_cell={receiver_cell}",
        f"--sender_hidden={sender_hidden}",
        f"--receiver_hidden={receiver_hidden}",
        f"--receiver_embedding={receiver_embedding}",
        f"--sender_embedding={sender_embedding}",
        f"--batches_per_epoch={batches_per_epoch}",
        f"--lr={lr}",
        f"--sender_entropy_coeff={sender_entropy_coeff}",
        f"--sender_num_layers={sender_num_layers}",
        f"--receiver_num_layers={receiver_num_layers}",
        f"--early_stopping_thr={early_stopping_thr}",
    ]
    return args

In [14]:
# WARNING: long-running cell — the author's note gives about 35 minutes.
train(get_train_args()) # WARNING  35 minutes

🧠 PyTorch version: 2.6.0+cu118
🖥️  CUDA available: True
🚀 CUDA version: 11.8
🧪 cuDNN version: 90100
🧠 Number of GPUs: 1
🔹 GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
   - Memory Allocated: 147.25 MB
   - Memory Cached:    262.00 MB
🧰 Python version: 3.13.2
True
PARAMETERS
Namespace(n_features=25, batches_per_epoch=100, dim_dataset=10240, force_eos=0, sender_hidden=250, receiver_hidden=600, receiver_num_layers=1, sender_num_layers=1, receiver_num_heads=8, sender_num_heads=8, sender_embedding=10, receiver_embedding=100, causal_sender=False, causal_receiver=False, sender_generate_style='in-place', sender_cell='lstm', receiver_cell='lstm', sender_entropy_coeff=2.0, receiver_entropy_coeff=0.1, probs='powerlaw', length_cost=0.0, name='model', early_stopping_thr=0.99, dir_save='dir_save', unigram_pen=0.0, impatient=True, print_message=True, reg=True, random_seed=7, checkpoint_dir=None, preemptable=False, checkpoint_freq=0, validation_freq=1, n_epochs=51, load_from_checkpoint=None, no_cuda=False

In [15]:

# Build sender / receiver / game / trainer objects from the training options
# inside the notebook, then print a layer summary of the sender and build its
# autograd graph. Nothing is trained in this cell.
opts = get_params(get_train_args())

# True only when the parsed options set force_eos to 1.
force_eos = opts.force_eos == 1

# Probabilities proportional to 1/rank over the n_features inputs, normalised to sum to 1.
probs = 1 / np.arange(1, opts.n_features+1, dtype=np.float32)
probs /= probs.sum()

train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                            batches_per_epoch=opts.batches_per_epoch, probs=probs)

# single batches with 1s on the diag
test_loader = UniformLoader(opts.n_features)

# The bare Sender is wrapped by RnnSenderReinforce and the name is rebound to the wrapper.
sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
sender = core.RnnSenderReinforce(sender,
                            opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                            cell=opts.sender_cell, max_len=opts.max_len, num_layers=opts.sender_num_layers,
                            force_eos=force_eos)

# NOTE(review): Receiver is given n_features=receiver_hidden and n_hidden=vocab_size,
# unlike the Sender call above — confirm this mapping against Receiver's signature.
receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
receiver = RnnReceiverImpatient(receiver, opts.vocab_size, opts.receiver_embedding,
                                opts.receiver_hidden, cell=opts.receiver_cell,
                                num_layers=opts.receiver_num_layers, max_len=opts.max_len, n_features=opts.n_features)

game = SenderImpatientReceiverRnnReinforce(sender, receiver, loss_impatient, sender_entropy_coeff=opts.sender_entropy_coeff,
                                        receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                        length_cost=opts.length_cost,unigram_penalty=opts.unigram_pen,reg=opts.reg)

optimizer = core.build_optimizer(game.parameters())

# The trainer is constructed but not run in this cell.
trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                        validation_data=test_loader, callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

# Print a per-layer summary of the wrapped sender for one input of size n_features.
summary(sender, input_size=(opts.n_features,))

# Simulate a single input sample
x = torch.randn(1, opts.n_features).to(opts.device)

# Forward pass to create the graph
out = sender(x)

# Create and render the graph
dot = make_dot(out, params=dict(sender.named_parameters()))
#dot.view()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 250]           6,500
            Sender-2                  [-1, 250]               0
          LSTMCell-3     [[-1, 250], [-1, 250]]               0
            Linear-4                   [-1, 20]           5,020
         Embedding-5                   [-1, 10]             200
          LSTMCell-6     [[-1, 250], [-1, 250]]               0
            Linear-7                   [-1, 20]           5,020
         Embedding-8                   [-1, 10]             200
          LSTMCell-9     [[-1, 250], [-1, 250]]               0
           Linear-10                   [-1, 20]           5,020
        Embedding-11                   [-1, 10]             200
         LSTMCell-12     [[-1, 250], [-1, 250]]               0
           Linear-13                   [-1, 20]           5,020
        Embedding-14                   

### Perform the analytical tests

In [16]:
# Settings read by get_test_args() for the analysis runs. These repeat the
# values assigned in the training configuration cell.
impatient = True
# Largest multiple of 50 that is <= n_epochs - 1 (epoch used in the checkpoint names below).
epoch = 50*((n_epochs-1)//50)
sender_weights = f"dir_save/sender/sender_weights_epoch_{epoch}_n_features_{n_features}.pth"
receiver_weights = f"dir_save/receiver/receiver_weights_epoch_{epoch}_n_features_{n_features}.pth"
sender_cell = "lstm"
receiver_cell = "lstm"
sender_num_layers = 1
receiver_num_layers = 1

In [17]:
def get_test_args():
    """
    Assemble the ``--name=value`` argument list used by the analysis entry points.

    All values are read from the notebook's module-level settings at call time
    (weight paths, vocabulary size, agent sizes, ...); results are written
    under ``analysis/``.

    Returns:
        list[str]: one ``--name=value`` string per option, in a fixed order.
    """
    # Insertion order of the dict fixes the order of the emitted flags.
    options = {
        "impatient": impatient,
        "save_dir": "analysis/",
        "receiver_weights": receiver_weights,
        "sender_weights": sender_weights,
        "vocab_size": vocab_size,
        "max_len": max_length,
        "n_features": n_features,
        "sender_cell": sender_cell,
        "receiver_cell": receiver_cell,
        "sender_hidden": sender_hidden,
        "receiver_hidden": receiver_hidden,
        "receiver_embedding": receiver_embedding,
        "sender_embedding": sender_embedding,
        "sender_num_layers": sender_num_layers,
        "receiver_num_layers": receiver_num_layers,
    }
    return [f"--{flag}={value}" for flag, value in options.items()]

In [18]:
# Run the test entry point on the checkpoints named in get_test_args();
# the recorded output lists one "input -> message -> output" line per input.
test(get_test_args())

Namespace(n_features=25, batches_per_epoch=1000, dim_dataset=10240, force_eos=0, sender_hidden=250, receiver_hidden=600, receiver_num_layers=1, sender_num_layers=1, receiver_num_heads=8, sender_num_heads=8, sender_embedding=10, receiver_embedding=100, causal_sender=False, causal_receiver=False, sender_generate_style='in-place', sender_cell='lstm', receiver_cell='lstm', sender_entropy_coeff=0.1, receiver_entropy_coeff=0.1, probs='uniform', length_cost=0.0, name='model', early_stopping_thr=0.9999, receiver_weights='dir_save/receiver/receiver_weights_epoch_50_n_features_25.pth', sender_weights='dir_save/sender/sender_weights_epoch_50_n_features_25.pth', save_dir='analysis/', impatient=True, unigram_pen=0.0, random_seed=1390851128, checkpoint_dir=None, preemptable=False, checkpoint_freq=0, validation_freq=1, n_epochs=10, load_from_checkpoint=None, no_cuda=False, batch_size=32, optimizer='adam', lr=0.01, vocab_size=20, max_len=15, tensorboard=False, tensorboard_dir='runs/', cuda=True, devic

Impatient score=59
input: 0 -> message: 0 -> output: 0
input: 1 -> message: 9,0 -> output: 1
input: 2 -> message: 6,17,0 -> output: 2
input: 3 -> message: 2,17,17,0 -> output: 3
input: 4 -> message: 12,9,17,0 -> output: 4
input: 5 -> message: 12,10,17,0 -> output: 5
input: 6 -> message: 1,1,6,0 -> output: 6
input: 7 -> message: 3,0 -> output: 7
input: 8 -> message: 6,0 -> output: 8
input: 9 -> message: 1,0 -> output: 9
input: 10 -> message: 17,0 -> output: 10
input: 11 -> message: 2,8,0 -> output: 11
input: 12 -> message: 19,10,13,0 -> output: 12
input: 13 -> message: 7,17,0 -> output: 13
input: 14 -> message: 1,9,0 -> output: 14
input: 15 -> message: 7,1,18,0 -> output: 15
input: 16 -> message: 17,9,8,0 -> output: 16
input: 17 -> message: 17,7,0 -> output: 17
input: 18 -> message: 7,0 -> output: 18
input: 19 -> message: 7,2,8,0 -> output: 19
input: 20 -> message: 10,8,0 -> output: 20
input: 21 -> message: 17,8,3,0 -> output: 10
input: 22 -> message: 2,9,8,0 -> output: 22
input: 23 -> 

In [171]:
# Run the position analysis entry point with the same arguments as the test run.
position_analysis(get_test_args())

Namespace(n_features=25, batches_per_epoch=1000, dim_dataset=10240, force_eos=0, sender_hidden=250, receiver_hidden=600, receiver_num_layers=1, sender_num_layers=1, receiver_num_heads=8, sender_num_heads=8, sender_embedding=10, receiver_embedding=100, causal_sender=False, causal_receiver=False, sender_generate_style='in-place', sender_cell='lstm', receiver_cell='lstm', sender_entropy_coeff=0.1, receiver_entropy_coeff=0.1, probs='uniform', length_cost=0.0, name='model', early_stopping_thr=0.9999, receiver_weights='dir_save/receiver/receiver_weights_epoch_50_n_features_25.pth', sender_weights='dir_save/sender/sender_weights_epoch_50_n_features_25.pth', save_dir='analysis/', impatient=True, unigram_pen=0.0, random_seed=1390851128, checkpoint_dir=None, preemptable=False, checkpoint_freq=0, validation_freq=1, n_epochs=10, load_from_checkpoint=None, no_cuda=False, batch_size=32, optimizer='adam', lr=0.01, vocab_size=20, max_len=15, tensorboard=False, tensorboard_dir='runs/', cuda=True, devic

## II - Add new words

In [140]:
class DynamicPermutation:
    """
    Permutation of 1..n that can grow by one element at a time.

    ``permutation[i]`` is the value (used in this notebook as a frequency
    rank) attached to position ``i``; new positions are always appended.
    """

    def __init__(self):
        # Starts empty; call initialize() or insert() to populate.
        self.permutation = []

    def initialize(self, n):
        """Reset to the identity permutation [1, 2, ..., n]."""
        self.permutation = [value for value in range(1, n + 1)]

    def insert(self, nb_add):
        """
        Append ``nb_add`` new elements, one after the other.

        For each new element a value is drawn uniformly from 1..n, where n is
        the current length. Every existing value greater than or equal to the
        drawn one is incremented by 1, and the drawn value is appended, so the
        list remains a permutation of 1..n+1. On an empty permutation the
        first element is simply 1 (no random draw).

        NOTE(review): the draw never yields n+1, so a new element cannot
        receive the largest value — confirm this is intended.
        """
        for _ in range(nb_add):
            if not self.permutation:
                self.permutation.append(1)
                continue

            new_value = random.randint(1, len(self.permutation))  # both bounds inclusive
            # Shift existing values in place (same list object), then append.
            self.permutation[:] = [
                value + 1 if value >= new_value else value
                for value in self.permutation
            ]
            self.permutation.append(new_value)

    def get_permutation(self):
        """Return the underlying list (not a copy)."""
        return self.permutation

In [73]:
def load_and_train(params,probs,nb_ep,sender_weights,receiver_weights):
    """
    Resume training from saved agent weights for ``nb_ep`` additional epochs.

    Args:
        params: argument list handed to ``get_params`` (e.g. ``get_train_args()``).
            Its ``n_epochs`` value is treated as the number of epochs already
            trained and is used as the offset in every saved file name.
        probs: sampling probabilities passed to the training ``OneHotLoader``.
        nb_ep: number of extra epochs to run.
        sender_weights: path of the sender state dict to load.
        receiver_weights: path of the receiver state dict to load.

    Side effects:
        After each epoch, saves the messages and the accuracy vector returned
        by ``dump_impatient`` under ``<dir_save>/messages`` and
        ``<dir_save>/accuracy``. After the last epoch, saves both agents'
        state dicts under ``<dir_save>/sender`` and ``<dir_save>/receiver``
        with epoch ``prev_ep + nb_ep`` in the name, then calls ``core.close()``.
    """
    opts = get_params(params)
    device = opts.device
    # Derived from the parsed options here, so the function no longer depends
    # on a `force_eos` global left behind by another notebook cell.
    force_eos = opts.force_eos == 1
    prev_ep = opts.n_epochs
    print("Previous epochs: "+str(prev_ep))

    train_loader = OneHotLoader(n_features=opts.n_features, batch_size=opts.batch_size,
                                batches_per_epoch=opts.batches_per_epoch, probs=probs)
    test_loader = UniformLoader(opts.n_features)

    sender = Sender(n_features=opts.n_features, n_hidden=opts.sender_hidden)
    sender = core.RnnSenderReinforce(sender,
                                opts.vocab_size, opts.sender_embedding, opts.sender_hidden,
                                cell=opts.sender_cell, max_len=opts.max_len, num_layers=opts.sender_num_layers,
                                force_eos=force_eos)

    receiver = Receiver(n_features=opts.receiver_hidden, n_hidden=opts.vocab_size)
    receiver = RnnReceiverImpatient(receiver, opts.vocab_size, opts.receiver_embedding,
                                    opts.receiver_hidden, cell=opts.receiver_cell,
                                    num_layers=opts.receiver_num_layers, max_len=opts.max_len, n_features=opts.n_features)

    # Restore both agents before the game / optimizer are built around them.
    sender.load_state_dict(torch.load(sender_weights,map_location=torch.device(device)))
    receiver.load_state_dict(torch.load(receiver_weights,map_location=torch.device(device)))

    game = SenderImpatientReceiverRnnReinforce(sender, receiver, loss_impatient, sender_entropy_coeff=opts.sender_entropy_coeff,
                                           receiver_entropy_coeff=opts.receiver_entropy_coeff,
                                           length_cost=opts.length_cost,unigram_penalty=opts.unigram_pen,reg=opts.reg)

    optimizer = core.build_optimizer(game.parameters())

    trainer = core.Trainer(game=game, optimizer=optimizer, train_data=train_loader,
                           validation_data=test_loader, callbacks=[EarlyStopperAccuracy(opts.early_stopping_thr)])

    # Shared suffix of every artifact written for a given absolute epoch.
    def suffix(abs_epoch):
        return f"_epoch_{abs_epoch}_n_features_{opts.n_features}"

    for epoch in range(1,1+nb_ep):
        print("Epoch: "+str(epoch))

        trainer.train(n_epochs=1)
        acc_vec,messages=dump_impatient(trainer.game, opts.n_features, device, False,epoch)
        all_messages = [message.cpu().numpy() for message in messages]

        # dtype=object: the messages are stored as an object array, as before.
        np.save(f"{opts.dir_save}/messages/messages{suffix(epoch+prev_ep)}.npy",
                np.array(all_messages, dtype=object), allow_pickle=True)
        np.save(f"{opts.dir_save}/accuracy/accuracy{suffix(epoch+prev_ep)}.npy", acc_vec)

    # Weights are written once, after the final epoch.
    torch.save(sender.state_dict(), f"{opts.dir_save}/sender/sender_weights{suffix(nb_ep+prev_ep)}.pth")
    torch.save(receiver.state_dict(), f"{opts.dir_save}/receiver/receiver_weights{suffix(nb_ep+prev_ep)}.pth")

    core.close()

In [47]:
def parse(file_name):
    """
    Extract ``(n_features, epoch)`` from a weights file name.

    Expected shape of the name (directories are ignored):
        ``<prefix>_epoch_<epoch>_n_features_<n>.pth``
    e.g. ``dir_save/sender/sender_weights_epoch_50_n_features_25.pth`` -> ``(25, 50)``.

    The previous implementation split on "_" and read index -2, which is the
    literal word "features" for such names, so it always raised ValueError.

    Raises:
        ValueError: if the file name does not contain the expected tokens.
    """
    match = re.search(r"epoch_(\d+)_n_features_(\d+)", os.path.basename(file_name))
    if match is None:
        raise ValueError(f"Cannot parse epoch / n_features from file name: {file_name!r}")
    nb_ep = int(match.group(1))
    n = int(match.group(2))
    return n, nb_ep

def update_n(file_name,n_old,n_new):
    """
    Return ``file_name`` with its feature count changed from ``n_old`` to ``n_new``.

    Only the ``n_features_<n_old>`` token is rewritten, so an epoch number (or
    directory name) that happens to contain the same digits is left alone,
    e.g. ``..._epoch_50_n_features_50.pth`` -> ``..._epoch_50_n_features_51.pth``.

    If the name has no ``n_features_<n_old>`` token, the legacy behaviour is
    kept as a fallback: every occurrence of ``str(n_old)`` is replaced.
    """
    # (?!\d) stops "n_features_25" from matching inside "n_features_250".
    token = re.compile(rf"(n_features_){re.escape(str(n_old))}(?!\d)")
    # A function replacement avoids any backreference ambiguity with digits.
    renamed, hits = token.subn(lambda m: m.group(1) + str(n_new), file_name)
    if hits:
        return renamed
    return file_name.replace(str(n_old),str(n_new))

def add_words(n_features,sender_weights,receiver_weights,nb_words=20):
    """
    Enlarge saved sender / receiver checkpoints so they cover ``nb_words`` more inputs.

    * Sender: ``agent.fc1.weight`` gets ``nb_words`` extra columns
      (concatenation along dim 1).
    * Receiver: ``hidden_to_output.weight`` gets ``nb_words`` extra rows and
      ``hidden_to_output.bias`` gets ``nb_words`` extra zeros (dim 0).

    New weight entries are drawn from a standard normal scaled by 0.01. The
    hidden sizes are read from the loaded tensors themselves rather than from
    notebook-level variables, so the function works for any checkpoint.

    Args:
        n_features: feature count encoded in the input file names.
        sender_weights: path of the sender state dict to enlarge.
        receiver_weights: path of the receiver state dict to enlarge.
        nb_words: number of inputs to add.

    Side effects:
        Saves the enlarged state dicts under the names returned by
        ``update_n(path, n_features, n_features + nb_words)``; the input files
        are not modified.
    """
    device  = "cuda" if torch.cuda.is_available() else "cpu"

    # Sender: one extra input column per new word.
    sender_state = torch.load(sender_weights,map_location=torch.device(device))
    old_weight = sender_state['agent.fc1.weight']  # (sender_hidden, n_features) per the original author's note
    extra_columns = torch.randn(old_weight.shape[0],nb_words).to(device)*0.01
    sender_state['agent.fc1.weight'] = torch.cat((old_weight,extra_columns),1)
    torch.save(sender_state,update_n(sender_weights,n_features,n_features+nb_words))

    # Receiver: one extra output row (and bias entry) per new word.
    receiver_state = torch.load(receiver_weights,map_location=torch.device(device))
    old_weight = receiver_state['hidden_to_output.weight']  # (n_features, receiver_hidden) per the original author's note
    old_bias = receiver_state['hidden_to_output.bias']  # (n_features,)
    extra_rows = torch.randn(nb_words,old_weight.shape[1]).to(device)*0.01
    receiver_state['hidden_to_output.weight'] = torch.cat((old_weight,extra_rows),0)
    receiver_state['hidden_to_output.bias'] = torch.cat((old_bias,torch.zeros(nb_words).to(device)),0)
    torch.save(receiver_state,update_n(receiver_weights,n_features,n_features+nb_words))

In [172]:
# WARNING: destructive — removes saved artifacts whose name carries an
# epoch > 50 or a feature count > 25, i.e. results of earlier incremental runs.
clean_npy_files(epoch_min=50,n_min = 25) # WARNING
# Restart from the 25-feature checkpoint saved at epoch 50. n_epochs is reset
# to 50 here (the training config used 51); get_train_args() reads this global.
n_features = 25
n_epochs = 50
sender_weights = f"dir_save/sender/sender_weights_epoch_{n_epochs}_n_features_{n_features}.pth"
receiver_weights = f"dir_save/receiver/receiver_weights_epoch_{n_epochs}_n_features_{n_features}.pth"
# Start from the identity permutation 1..n_features.
frequence = DynamicPermutation()
frequence.initialize(n_features)
print(frequence.get_permutation())

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]


In [173]:
# Grow the input space one word at a time, 20 times, training 3 epochs after each addition.
for _ in range(20):
    nb_add = 1
    # Write enlarged checkpoints for n_features + nb_add inputs (same epoch in the name).
    add_words(n_features,sender_weights,receiver_weights,nb_add)
    n_features+=nb_add
    frequence.insert(nb_add)
    # Point at the enlarged checkpoints just written.
    sender_weights = f"dir_save/sender/sender_weights_epoch_{n_epochs}_n_features_{n_features}.pth"
    receiver_weights = f"dir_save/receiver/receiver_weights_epoch_{n_epochs}_n_features_{n_features}.pth"
    
    # Sampling probabilities proportional to 1/value over the current permutation,
    # normalised to sum to 1 (the original French comment called them "uniform over the new words").
    probs = 1 / np.array(frequence.get_permutation())
    probs /= probs.sum()
    
    # Train on the enlarged vocabulary
    nb_ep = 3
    load_and_train(get_train_args(),probs,nb_ep,sender_weights,receiver_weights)
    n_epochs += nb_ep
    # load_and_train saved weights under the new epoch count; follow them.
    sender_weights = f"dir_save/sender/sender_weights_epoch_{n_epochs}_n_features_{n_features}.pth"
    receiver_weights = f"dir_save/receiver/receiver_weights_epoch_{n_epochs}_n_features_{n_features}.pth"

Previous epochs: 50
Epoch: 1
Impatient score=68
Epoch: 2
Impatient score=64
Epoch: 3
Impatient score=57
Previous epochs: 53
Epoch: 1
Impatient score=79
Epoch: 2
Impatient score=75
Epoch: 3
Impatient score=74
Previous epochs: 56
Epoch: 1
Impatient score=154
Epoch: 2
Impatient score=128
Epoch: 3
Impatient score=130
Previous epochs: 59
Epoch: 1
Impatient score=270
Epoch: 2
Impatient score=203
Epoch: 3
Impatient score=256
Previous epochs: 62
Epoch: 1
Impatient score=287
Epoch: 2
Impatient score=287
Epoch: 3
Impatient score=314
Previous epochs: 65
Epoch: 1
Impatient score=332
Epoch: 2
Impatient score=309
Epoch: 3
Impatient score=274
Previous epochs: 68
Epoch: 1
Impatient score=342
Epoch: 2
Impatient score=305
Epoch: 3
Impatient score=349
Previous epochs: 71
Epoch: 1
Impatient score=364
Epoch: 2
Impatient score=349
Epoch: 3
Impatient score=361
Previous epochs: 74
Epoch: 1
Impatient score=374
Epoch: 2
Impatient score=370
Epoch: 3
Impatient score=343
Previous epochs: 77
Epoch: 1
Impatient scor

In [169]:
# Train on the new words: 20 more epochs from the latest checkpoint, reusing
# the last `probs`, then point the weight paths at the checkpoint just saved.
nb_ep = 20
load_and_train(get_train_args(),probs,nb_ep,sender_weights,receiver_weights)
n_epochs += nb_ep
sender_weights = f"dir_save/sender/sender_weights_epoch_{n_epochs}_n_features_{n_features}.pth"
receiver_weights = f"dir_save/receiver/receiver_weights_epoch_{n_epochs}_n_features_{n_features}.pth"

Previous epochs: 130
Epoch: 1
Impatient score=435
Epoch: 2
Impatient score=412
Epoch: 3
Impatient score=456
Epoch: 4
Impatient score=435
Epoch: 5
Impatient score=460
Epoch: 6
Impatient score=431
Epoch: 7
Impatient score=455
Epoch: 8
Impatient score=444
Epoch: 9
Impatient score=459
Epoch: 10
Impatient score=456
Epoch: 11
Impatient score=443
Epoch: 12
Impatient score=449
Epoch: 13
Impatient score=442
Epoch: 14
Impatient score=447
Epoch: 15
Impatient score=446
Epoch: 16
Impatient score=465
Epoch: 17
Impatient score=479
Epoch: 18
Impatient score=462
Epoch: 19
Impatient score=418
Epoch: 20
Impatient score=424


In [168]:

# Re-run both analyses. get_test_args() reads the current globals, so this
# evaluates whichever n_features / weight paths were set last.
test(get_test_args())
position_analysis(get_test_args())

Namespace(n_features=45, batches_per_epoch=1000, dim_dataset=10240, force_eos=0, sender_hidden=250, receiver_hidden=600, receiver_num_layers=1, sender_num_layers=1, receiver_num_heads=8, sender_num_heads=8, sender_embedding=10, receiver_embedding=100, causal_sender=False, causal_receiver=False, sender_generate_style='in-place', sender_cell='lstm', receiver_cell='lstm', sender_entropy_coeff=0.1, receiver_entropy_coeff=0.1, probs='uniform', length_cost=0.0, name='model', early_stopping_thr=0.9999, receiver_weights='dir_save/receiver/receiver_weights_epoch_130_n_features_45.pth', sender_weights='dir_save/sender/sender_weights_epoch_130_n_features_45.pth', save_dir='analysis/', impatient=True, unigram_pen=0.0, random_seed=1390851128, checkpoint_dir=None, preemptable=False, checkpoint_freq=0, validation_freq=1, n_epochs=10, load_from_checkpoint=None, no_cuda=False, batch_size=32, optimizer='adam', lr=0.01, vocab_size=20, max_len=15, tensorboard=False, tensorboard_dir='runs/', cuda=True, dev