In [4]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## initialization

In [5]:
from beam_search import create_dataset
from data import encode_line, load_vocabulary
from model import Transformer
import torch
import torch.nn as nn
import argparse

args = argparse.Namespace()
args.src_vocab = '../corpus/vocab_en_fr.txt'
args.tgt_vocab = '../corpus/vocab_en_fr.txt'
args.ckpt = './averaged_checkpoint.pt'
args.device = 'cpu'

batch_size = 16
beam_size = 5
length_penalty = 1
max_length = 256

source_vocabulary, _ = load_vocabulary(args.src_vocab)
target_vocabulary, target_vocabulary_rev = load_vocabulary(args.tgt_vocab)

bos = target_vocabulary["<s>"]
eos = target_vocabulary["</s>"]
model = Transformer(
    len(source_vocabulary),
    len(target_vocabulary),
    share_embeddings=True,
)
checkpoint = torch.load(args.ckpt)
model.load_state_dict(checkpoint["model"])
model.to(args.device)
model.eval()
n = 0
for param in model.parameters():
    # print(param.shape)
    n += param.numel()
print(n)

source_path = '../corpus/rep_test.en.tok'
dataset = create_dataset(
    source_path, source_vocabulary, batch_size, args.device)
ref_path = '../corpus/rep_test.fr'
with open(ref_path, 'r', encoding='utf-8') as file:
    ref = file.readlines()
ref[0]

209129472


'Cette fois-ci, la baisse est due à la chute des actions au Wall Street.\n'

## helper functions

In [6]:
def view(hypotheses):
    for hypo in hypotheses:
        score = hypo[0].item()
        tokens = hypo[1]
        if tokens and tokens[-1] == eos:
            tokens.pop(-1)
        tokens = [target_vocabulary_rev[token_id] for token_id in tokens]
        print(f"{score:.4f}", " ".join(tokens))

## inference analysis

In [7]:
from my_beam_search import beam_search

In [8]:
batch = next(iter(dataset))

In [9]:
from utils import *

### original version

In [10]:
# we just take small batch_size for test
# _ = beam_search(model, batch[None, 0], bos, eos)
with torch.no_grad():
    result = beam_search(
        model, batch[0:2], bos, eos, rep_penalty=False)

result is a batch of data, whose element is (score, sentence)

In [11]:
view(result[0])

-0.2325 Cette fois ￭, la chute des actions de Wall Street est responsable de la chute ￭.
-0.2389 Cette fois ￭, la chute des actions de Wall Street en est la cause ￭.
-0.2402 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est responsable de la chute ￭.
-0.2528 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street en est la cause ￭.
-0.2846 Cette fois ￭, la chute des actions de Wall Street est responsable de la baisse ￭.
-0.3051 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est responsable de la baisse ￭.
-0.3260 Cette fois ￭, la chute des actions de Wall Street est responsable de la chute des cours ￭.


Maybe what we want is 

`Cette fois ￭, la chute des actions de Wall Street est responsable de la baisse ￭.`

However, it is not the second candidate

### naive version

In [12]:
# we just take small batch_size for test
# _ = beam_search(model, batch[None, 0], bos, eos)
with torch.no_grad():
    penalized_result = beam_search(
        model, batch[0:2], bos, eos, rep_penalty=True, naive_penalty=5)

In [13]:
view(penalized_result[0])

-0.3882 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street en est responsable ￭.
-0.4981 Cette fois ￭-￭ ci ￭, la baisse des actions de Wall Street en est responsable ￭.
-0.5306 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street en est la cause ￭.
-0.7309 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est à l ￭’￭ origine de cette baisse ￭.
-0.7371 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est à l ￭'￭ origine de cette baisse ￭.
-0.7391 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est à l ￭’￭ œuvre ￭.
-0.8594 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est à l ￭’￭ origine de la baisse ￭.
-0.9045 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est à l ￭’￭ origine de cette diminution ￭.
-0.9075 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est à l ￭'￭ origine de cette diminution ￭.


### mask version

In [63]:
from my_beam_search import calc_mask


penalty_mask, leading_marker = calc_mask(target_vocabulary)

100%|██████████| 32000/32000 [00:00<00:00, 1655557.82it/s]


In [64]:
leading_marker

tensor([0., 0., 0.,  ..., 1., 0., 0.])

In [15]:
torch.save(penalty_mask, 'penalty_mask.pt')

In [16]:
penalty_mask = torch.load('penalty_mask.pt')

In [17]:
# we just take small batch_size for test
# _ = beam_search(model, batch[None, 0], bos, eos)
with torch.no_grad():
    penalized_result = beam_search(
        model, batch[0:2], bos, eos, rep_penalty=True, naive_penalty=2, penalty_mask=penalty_mask)

In [18]:
view(penalized_result[0])

-0.2389 Cette fois ￭, la chute des actions de Wall Street en est la cause ￭.
-0.2528 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street en est la cause ￭.
-0.2846 Cette fois ￭, la chute des actions de Wall Street est responsable de la baisse ￭.
-0.3051 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est responsable de la baisse ￭.
-0.3103 Cette fois ￭, la baisse des actions de Wall Street est responsable de la chute ￭.
-0.3454 Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est responsable de la chute ￭.
-0.3845 Cette fois ￭, la chute des actions de Wall Street est responsable de la baisse des cours ￭.


## evaluation

In [96]:
def experi(model, dataset, **kwarg):
    isMask = 'penalty_mask' in kwarg
    filename = '_'.join(
        [f'{key}_{value}' for key, value in kwarg.items() if key != 'penalty_mask'])+('_mask' if isMask else '')+'.fr.tok'
    print(filename)
    with open('./corpus/'+filename, 'w', encoding='utf-8') as file:
        with torch.no_grad():
            for batch in dataset:
                result = beam_search(model, batch, bos, eos, **kwarg)
                # result = (batch_size, ), hypo = (score, sent)
                for hypotheses in result:
                    tokens = hypotheses[0][1]
                    if tokens and tokens[-1] == eos:
                        tokens.pop(-1)
                    tokens = [target_vocabulary_rev[token_id]
                              for token_id in tokens]
                    print(" ".join(tokens), file=file, flush=True)

In [100]:
kwarg = {
    'naive_penalty': 3
}
experi(model, dataset, **kwarg)

naive_penalty_3.fr.tok


In [101]:
kwarg = {
    'naive_penalty': 3,
    'penalty_mask': penalty_mask,
}
experi(model, dataset, **kwarg)

naive_penalty_3_mask.fr.tok


KeyboardInterrupt: 

In [None]:
kwarg = {
    'naive_penalty': 3,
    'penalty_mask': penalty_mask,
    'penalty_decay': 0.9
}
experi(model, dataset, **kwarg)

In [None]:
kwarg = {
    'naive_penalty': 2
}
experi(model, dataset, file, **kwarg)

In [None]:
kwarg = {
    'naive_penalty': 2,
    'penalty_mask': penalty_mask,
}
experi(model, dataset, **kwarg)

In [None]:
kwarg = {
    'naive_penalty': 2,
    'penalty_mask': penalty_mask,
    'penalty_decay': 0.9
}
experi(model, dataset, **kwarg)

In [109]:
import os
import subprocess
import json
from wordrep.utils import BaseRepDetector


def eval_score(directory='./corpus/', ref_path='../corpus/rep_test.fr'):
    c = "sacrebleu"  # Replace with your desired Bash command
    rep_detector = BaseRepDetector('fr')
    for filename in os.listdir(directory):
        if os.path.isfile(os.path.join(directory, filename)):
            # Process the file
            if filename.endswith('fr'):
                path = os.path.join(directory, filename)
                command = f'{c} {path} -i {ref_path}'
                # Execute the command
                process = subprocess.Popen(
                    command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                # Wait for the command to finish and capture the output
                output, _ = process.communicate()
                # Decode the output
                output = output.decode('utf-8')
                o = json.loads(output)
                reps = rep_detector.detect_corpus(path, vis=False)
                print(filename, f"bleu={o['score']}, rep={len(reps)}")

In [110]:
eval_score()

100%|██████████| 739/739 [00:06<00:00, 121.97it/s]


naive_penalty_3_no_mask.fr bleu=21.7, rep=501


100%|██████████| 739/739 [00:03<00:00, 209.06it/s]


naive_penalty_2_no_mask.fr bleu=23.5, rep=558


100%|██████████| 739/739 [00:03<00:00, 207.62it/s]


naive_penalty_2.fr bleu=26.1, rep=607


100%|██████████| 739/739 [00:03<00:00, 223.40it/s]

naive_penalty_3.fr bleu=25.8, rep=567





### bleu

In [None]:
!sacrebleu ../corpus/rep_test.en.out.avg -i ../corpus/rep_test.fr

{
 "name": "BLEU",
 "score": 26.5,
 "signature": "nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0",
 "verbose_score": "57.6/32.4/20.3/13.0 (BP = 1.000 ratio = 1.000 hyp_len = 28898 ref_len = 28891)",
 "nrefs": "1",
 "case": "mixed",
 "eff": "no",
 "tok": "13a",
 "smooth": "exp",
 "version": "2.0.0"
}
[0m

In [None]:
!sacrebleu ./corpus/naive_penalty_2_no_mask.fr -i ../corpus/rep_test.fr

{
 "name": "BLEU",
 "score": 23.5,
 "signature": "nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0",
 "verbose_score": "55.0/29.2/17.5/10.8 (BP = 1.000 ratio = 1.031 hyp_len = 28898 ref_len = 28028)",
 "nrefs": "1",
 "case": "mixed",
 "eff": "no",
 "tok": "13a",
 "smooth": "exp",
 "version": "2.0.0"
}
[0m

In [None]:
!sacrebleu ./corpus/naive_penalty_2.fr -i ../corpus/rep_test.fr

{
 "name": "BLEU",
 "score": 26.1,
 "signature": "nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0",
 "verbose_score": "57.2/32.0/20.0/12.8 (BP = 1.000 ratio = 1.003 hyp_len = 28898 ref_len = 28813)",
 "nrefs": "1",
 "case": "mixed",
 "eff": "no",
 "tok": "13a",
 "smooth": "exp",
 "version": "2.0.0"
}
[0m

In [None]:
!sacrebleu ./corpus/naive_penalty_3_no_mask.fr -i ../corpus/rep_test.fr

{
 "name": "BLEU",
 "score": 21.7,
 "signature": "nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0",
 "verbose_score": "52.9/27.2/16.0/9.7 (BP = 1.000 ratio = 1.036 hyp_len = 28898 ref_len = 27883)",
 "nrefs": "1",
 "case": "mixed",
 "eff": "no",
 "tok": "13a",
 "smooth": "exp",
 "version": "2.0.0"
}
[0m

In [None]:
!sacrebleu ./corpus/naive_penalty_3.fr -i ../corpus/rep_test.fr

{
 "name": "BLEU",
 "score": 25.8,
 "signature": "nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.0.0",
 "verbose_score": "57.0/31.6/19.7/12.6 (BP = 1.000 ratio = 1.001 hyp_len = 28898 ref_len = 28867)",
 "nrefs": "1",
 "case": "mixed",
 "eff": "no",
 "tok": "13a",
 "smooth": "exp",
 "version": "2.0.0"
}
[0m

### rep rate

In [None]:
from wordrep.utils import BaseRepDetector

In [None]:
rep_detector = BaseRepDetector('fr')

In [None]:
reps = rep_detector.detect_corpus('../corpus/rep_test.en.out.avg', vis=False)
len(reps)

100%|██████████| 739/739 [00:06<00:00, 120.52it/s]


739

In [None]:
reps = rep_detector.detect_corpus('./corpus/naive_penalty_2.fr', vis=False)
len(reps)

100%|██████████| 739/739 [00:03<00:00, 223.68it/s]


607

In [None]:
reps = rep_detector.detect_corpus(
    './corpus/naive_penalty_2_no_mask.fr', vis=False)
len(reps)

100%|██████████| 739/739 [00:03<00:00, 228.90it/s]


558

In [None]:
reps = rep_detector.detect_corpus('./corpus/naive_penalty_3.fr', vis=False)
len(reps)

100%|██████████| 739/739 [00:03<00:00, 224.27it/s]


567

In [None]:
reps = rep_detector.detect_corpus(
    './corpus/naive_penalty_3_no_mask.fr', vis=False)
len(reps)

100%|██████████| 739/739 [00:03<00:00, 211.85it/s]


501

In [None]:
penalty_mask[target_vocabulary['b']]

tensor(1.)

In [46]:
import spacy
nlp = spacy.load('fr_core_news_lg')

In [58]:
nlp('i')[0].pos_

'ADJ'