In [21]:
import logging
import os
import concurrent.futures
import importlib
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import *

import colorama
import jsonlines
import matplotlib.pyplot as plt
import numpy as np
import rich
import tqdm

In [22]:
##########################################################################################
# Prepare the file lists
##########################################################################################
def parse_num(path):
    """Return the single integer embedded in the file name of ``path``.

    Only the last path component is inspected, so digits appearing in parent
    directories are ignored. The name must contain exactly one run of ASCII
    digits; otherwise an ``AssertionError`` is raised whose argument is the
    number of digit runs that were found.
    """
    file_name = str(Path(path).name)
    digit_runs = [match.group(0) for match in re.finditer("[0-9]+", file_name)]
    found = len(digit_runs)
    assert found == 1, found
    (only_run,) = digit_runs
    return int(only_run)


# Everything below is read from, and written under, this run's output directory.
target_dir = Path("iterated_decoding_output/first_test/").resolve()
wd_dir_retr = target_dir.joinpath("transformed_retr")
wd_dir_gen = target_dir.joinpath("transformed_gen")

# Recreate both working directories from scratch on every execution.
# WARNING: anything already stored in them is deleted.
for path in (wd_dir_retr, wd_dir_gen):
    if path.exists():
        shutil.rmtree(path)
    path.mkdir(exist_ok=False)

# Collect the retriever and reader outputs, ordered by the number in each file name.
retr_candidates = target_dir.glob("retr_outs_*")
contexts_outputs = sorted(retr_candidates, key=parse_num)
reader_candidates = target_dir.glob("reader_outs_*")
reader_outputs = sorted(reader_candidates, key=parse_num)

# The rest of the notebook expects as many retriever files as reader files.
assert len(contexts_outputs) == len(reader_outputs)

print(f"{len(contexts_outputs) = }")
print(f"{len(reader_outputs) = }")

len(contexts_outputs) = 2
len(reader_outputs) = 2


In [23]:
def count_lines(path):
    """Return the number of newline characters in the file at ``path``.

    This matches what ``wc -l`` reports (a final line without a trailing
    newline is not counted) but is computed in pure Python, so it needs no
    external ``wc`` binary, spawns no process, and cannot misread a path that
    starts with ``-`` as a command-line option.

    Args:
        path: Anything accepted by ``open`` (``str`` or ``os.PathLike``).

    Returns:
        The count of ``\\n`` bytes in the file, as an ``int``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    chunk_size = 1 << 20  # read in 1 MiB pieces so large files are never fully in memory
    newline_count = 0
    with open(path, "rb") as fin:
        for chunk in iter(lambda: fin.read(chunk_size), b""):
            newline_count += chunk.count(b"\n")
    return newline_count


def job(packed):
    """Load one JSON-lines file and flatten its records into a single list.

    ``packed`` is an ``(index, path)`` pair. The index must equal the number
    embedded in the file name of ``path`` (as extracted by ``parse_num``);
    otherwise an ``AssertionError`` carrying both values is raised.

    Every record read from the file is itself iterated, and its items are
    appended, in order, to the returned list.
    """
    expected_index, source_path = packed
    actual_index = parse_num(source_path)
    assert expected_index == actual_index, (expected_index, actual_index)

    with jsonlines.open(source_path) as reader:
        flattened = [item for record in reader for item in record]
    return flattened
            

# One flattened list per reader output file, in the same order as reader_outputs.
reader_gens = [job(indexed_path) for indexed_path in enumerate(reader_outputs)]

In [32]:
##########################################################################################
# Clean and move the data
##########################################################################################

# Copy the reference targets next to the predictions, unless a copy is already there.
orig_targets = Path("../../GAR/data/nq-answer/val.target").resolve()
wd_targets = wd_dir_gen / "val_targets.txt"

if not wd_targets.exists():
    shutil.copy(orig_targets, wd_targets)

# Write one prediction file per reader output, one prediction per line.
for i, lines in enumerate(tqdm.tqdm(reader_gens)):
    # Explicit UTF-8: without it the locale's default encoding is used, which can
    # fail or mis-encode non-ASCII predictions depending on the machine.
    with open(wd_dir_gen / f"val_predictions-{i}.txt", "w", encoding="utf-8") as fout:
        for line in lines:
            fout.write(line + "\n")

# Copy the retriever outputs unchanged into their working directory.
for path in tqdm.tqdm(contexts_outputs):
    print(path)
    shutil.copy(path, wd_dir_retr / path.name)

# Re-list what was just written, ordered by the number embedded in each file name.
contexts_outputs_wd = sorted(wd_dir_retr.glob("retr_outs_*"), key=parse_num)
reader_outputs_wd = sorted(wd_dir_gen.glob("val_predictions-*"), key=parse_num)

print(f"{len(contexts_outputs_wd) = }")
print(f"{len(reader_outputs_wd) = }")

100%|██████████| 2/2 [00:00<00:00, 187.25it/s]
100%|██████████| 2/2 [00:00<00:00, 55.79it/s]

/home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/first_test/retr_outs_0.jsonl
/home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/first_test/retr_outs_1.jsonl
len(contexts_outputs_wd) = 2
len(reader_outputs_wd) = 2





In [43]:
# Show which interpreter the shell's `python` resolves to; the next line launches the script with it.
!which python
# Run compute_rouge.py (looked up relative to the notebook's working directory) on the
# directory that holds val_targets.txt and the val_predictions-*.txt files.
!python compute_rouge.py {str(wd_dir_gen)}

/home/mila/g/gagnonju/condaless/bin/python
[33mStacking [0m[33m/home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/fi[0m
[33mrst_test/transformed_gen/[0m[33mrouge_1.txt[0m
[32mStarting [0m[32m/home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_decoding_output/fi[0m
[32mrst_test/transformed_gen/[0m[32mrouge_1.txt[0m
python -m rouge_score.rouge --[33mtarget_filepattern[0m=[35m/home/mila/g/gagnonju/IteratedD[0m
[35mecoding/jobs/iterated_decoding_output/first_test/transformed_gen/[0m[95mval_targets.txt[0m
 --[33mprediction_filepattern[0m=[35m/home/mila/g/gagnonju/IteratedDecoding/jobs/iterated_d[0m
[35mecoding_output/first_test/transformed_gen/[0m[95mval_predictions-1.txt[0m 
--[33muse_stemmer[0m=[3;92mTrue[0m --[33moutput_filename[0m=[35m/home/mila/g/gagnonju/IteratedDecoding/jobs[0m
[35m/iterated_decoding_output/first_test/transformed_gen/[0m[95mrouge_1.txt[0m
I1001 04:09:09.112025 139703116732032 io.py:108] Reading targets fr

In [44]:
# Gather the entries of wd_dir_gen whose names start with "rouge", ordered by the
# number embedded in each name.
rouge_candidates = wd_dir_gen.glob("rouge*")
paths_rouge = sorted(rouge_candidates, key=parse_num)
print(f"{len(paths_rouge) = }")

len(paths_rouge) = 2


In [26]:
def read_file(path: Path):
    """Return the entire content of the text file at ``path`` as one string.

    The file is opened in text mode with ``open``'s default encoding and
    newline handling.
    """
    with open(path) as handle:
        text = handle.read()
    return text

def read_jsonl(path: Path):
    """Read a JSON-lines file and return its records as a list.

    Args:
        path: Location of the file; its string form must end with ``.jsonl``.

    Returns:
        A list with one entry per record yielded by ``jsonlines``, in file order.

    Raises:
        AssertionError: If ``path`` does not end with ``.jsonl`` (the offending
            path is the assertion argument).
    """
    assert str(path).endswith(".jsonl"), path
    # The file is read once, by jsonlines. An earlier extra full-text read whose
    # result was never used has been removed: it doubled the I/O and memory use.
    with jsonlines.open(path) as fin:
        return list(fin)


In [41]:
# Compare the ids in every later retriever output against those in the first one.
f0_path = contexts_outputs_wd[0]
f0 = read_jsonl(f0_path)

for path_f1 in sorted(set(contexts_outputs_wd) - {f0_path}, key=parse_num):
    f1 = read_jsonl(path_f1)
    # assert len(f0) == len(f1), (len(f0), len(f1))
    # NOTE: with the length check above disabled, zip() below silently stops at
    # the end of the shorter file, so trailing records of the longer one are ignored.

    goods = 0
    bads = 0
    for i, (f0_l, f1_l) in enumerate(zip(f0, f1)):
        batch_size = len(f0_l["ids"])
        assert batch_size == len(f1_l["ids"]), (batch_size, len(f1_l["ids"]))
        for index in range(batch_size):
            set_0 = {int(x) for x in f0_l["ids"][index]}
            set_1 = {int(x) for x in f1_l["ids"][index]}
            set_or = set_0 | set_1
            set_and = set_0 & set_1
            # Intersection-over-union of the two id sets. When both are empty the
            # union is empty too; report 0 overlap instead of dividing by zero.
            ratio = len(set_and) / len(set_or) if set_or else 0.0
            # The printed position assumes every record holds batch_size entries.
            print(index + batch_size * i, f"{ratio:0.0%}")
            if len(set_and) != 0:
                goods += 1
            else:
                bads += 1

    total = goods + bads
    if total:
        print(f"{parse_num(path_f1)}: {goods / total:.2%} that have more than absolute 0 overlap")
    else:
        # Nothing was compared (e.g. an empty file); avoid a ZeroDivisionError.
        print(f"{parse_num(path_f1)}: no entries to compare")


0 25%
1 43%
2 0%
3 67%
4 67%
5 0%
6 0%
7 43%
8 11%
9 0%
10 43%
11 25%
12 25%
13 0%
14 0%
15 0%
16 43%
17 0%
18 67%
19 11%
20 0%
21 67%
22 43%
23 0%
24 43%
25 25%
26 25%
27 43%
28 11%
29 25%
30 25%
31 11%
32 25%
33 25%
34 0%
35 67%
36 25%
37 25%
38 11%
39 0%
40 25%
41 43%
42 67%
43 11%
44 67%
45 43%
46 43%
47 67%
48 43%
49 11%
50 11%
51 0%
52 43%
53 25%
54 0%
55 25%
56 43%
57 25%
58 100%
59 11%
60 0%
61 67%
62 11%
63 11%
64 25%
65 11%
66 25%
67 11%
68 25%
69 11%
70 43%
71 11%
72 25%
73 43%
74 43%
75 25%
76 11%
77 25%
78 0%
79 25%
80 0%
81 0%
82 25%
83 11%
84 11%
85 0%
86 67%
87 43%
88 43%
89 25%
90 11%
91 11%
92 25%
93 25%
94 43%
95 67%
96 67%
97 11%
98 100%
99 100%
100 43%
101 43%
102 0%
103 25%
104 0%
105 0%
106 0%
107 0%
108 11%
109 25%
110 0%
111 0%
112 100%
113 11%
114 43%
115 25%
116 25%
117 0%
118 100%
119 43%
120 43%
121 43%
122 25%
123 43%
124 0%
125 11%
126 25%
127 25%
128 25%
129 67%
130 11%
131 25%
132 67%
133 43%
134 11%
135 11%
136 11%
137 11%
138 25%
139 0%
140 25%
141 25