In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import warnings
warnings.filterwarnings("ignore")

### Load data

In [3]:
import sys
sys.path.append("runormas/")

In [4]:
from modules.data.read import DataReader

  class IteratorBase(collections.Iterator, trackable.Trackable,
  class DatasetV2(collections.Iterable, tracking_base.Trackable,


In [5]:
reader = DataReader(
    path="data/public_test/",
    output_dir="data/public_test_v6",
    tokenizer_name="sberbank-ai/rugpt3xl",
    part="test",
    answer_sep="<answer>",
    start_sep="<start>",
    local_rank=0,
    word_size=1,
    add_start_sep=True,
)

  and should_run_async(code)


Add answer_sep: <answer>
Add start_sep <start>


In [6]:
reader.prc(is_save=False)

Parsing files from test part on 0...: 100%|██████████| 4370/4370 [00:21<00:00, 203.03it/s]
Parsing files from test part on 0...: 100%|██████████| 536/536 [00:03<00:00, 174.26it/s]
Making raw files...: 100%|██████████| 4370/4370 [00:00<00:00, 6126.57it/s]
Making raw files...: 100%|██████████| 536/536 [00:01<00:00, 361.92it/s]


In [7]:
import os
from tqdm import tqdm
from collections import defaultdict


def get_progress(path, reader):
    """Collect stored predictions from ``path`` and compare them with the source spans.

    For every data part and document known to ``reader.lm_prefixes`` the file
    ``<path>/<data_part>/<name>.norm`` is read (when it exists); its non-blank,
    stripped lines are taken as predictions in annotation order.

    Returns a 5-tuple ``(total, done, predicted_anns, errors, preds)``:

    * ``total[data_part]`` - number of prefixes expected for the part;
    * ``done[data_part]`` - number of prefixes that have a prediction line;
    * ``predicted_anns[data_part][name]`` - indices whose prediction is equal
      to the stripped source span ``text[start:stop]``;
    * ``errors[data_part]`` - number of prefixes missing from existing files;
    * ``preds[data_part][name]`` - the predictions that were read, in order.
    """
    total, done, errors = defaultdict(int), defaultdict(int), defaultdict(int)
    predicted_anns, preds = defaultdict(dict), defaultdict(dict)

    for data_part in reader.lm_prefixes:
        part_dir = os.path.join(path, data_part)
        prefixes_by_name = reader.lm_prefixes[data_part]

        for name in list(prefixes_by_name.keys()):
            n_expected = len(prefixes_by_name[name])
            total[data_part] += n_expected
            collected = preds[data_part][name] = []
            matched = predicted_anns[data_part][name] = []

            norm_path = os.path.join(part_dir, f"{name}.norm")
            if not os.path.exists(norm_path):
                # No prediction file for this document: nothing to compare.
                continue

            with open(norm_path, 'r', encoding='utf-8') as file:
                stripped_lines = (raw.strip() for raw in file.read().split("\n"))
                predicted_norms = [line for line in stripped_lines if line]
            n_available = len(predicted_norms)

            for idx in range(n_expected):
                ann = reader.anns[data_part][name][idx]
                if idx == n_available:
                    # The file ran out of lines: every remaining prefix is missing.
                    errors[data_part] += n_expected - idx
                    break
                done[data_part] += 1
                pred = predicted_norms[idx]
                text = reader.texts[data_part][name]
                start, stop = (int(bound) for bound in ann.split())
                if text[start:stop].strip() == pred:
                    matched.append(idx)
                collected.append(pred)
    return total, done, predicted_anns, errors, preds
            

In [8]:
#fix_no_beamsv8_baseline_fix

In [61]:
total, done, predicted_anns, errors, no_beams_preds = get_progress("test_pred/fix_no_beams_v21", reader)

In [62]:
# total, done, predicted_anns, errors, baseline_preds = get_progress("test_pred/baseline", reader)
# NOTE: this binds `baseline_preds` to the very same object as `no_beams_preds`
# (no copy is made), so any in-place change made through one name is visible
# through the other.
baseline_preds = no_beams_preds

In [79]:
# total, done, predicted_anns, errors, fixed_preds = get_progress("test_pred/beams16/", reader)
# Read the predictions stored under test_pred/beams_v2/ into `fixed_preds`.
# This also rebinds `total`, `done`, `predicted_anns` and `errors`, replacing
# the values produced by the earlier get_progress call.
total, done, predicted_anns, errors, fixed_preds = get_progress("test_pred/beams_v2/", reader)

In [34]:
done

defaultdict(int, {'named': 115135, 'generic': 88869})

In [35]:
total

defaultdict(int, {'named': 115904, 'generic': 89877})

In [36]:
# Count, over all data parts and documents, the predictions that get_progress
# recorded in `predicted_anns`, i.e. those equal to the stripped source span.
tp = 0
for data_part in predicted_anns:
    for name in predicted_anns[data_part]:
        tp += len(predicted_anns[data_part][name])

In [37]:
tp

68648

In [38]:
errors

defaultdict(int, {'named': 737, 'generic': 72})

In [39]:
sum(done.values()) / sum(total.values())

0.9913646060617841

In [40]:
# Scale `tp` from the items that have a prediction (`done`) up to all expected
# items (`total`): rate on done items multiplied by the total item count.
tp / sum(done.values()) * sum(total.values())

69245.96619674124

### Test on train and valid data

In [16]:
import sys
sys.path.append("runormas/")

In [17]:
import os
os.environ["USE_DEEPSPEED"] = "1"

In [18]:
from src.xl_wrapper import RuGPT3XL

In [19]:
model = RuGPT3XL.from_pretrained(
    model_name_or_path="sberbank-ai/rugpt3xl",
    seq_len=512,
    weights_path="models/xl/v7/20000/mp_rank_00_model_states.pt",
    deepspeed_config_path="runormas/src/deepspeed_config/gpt3_xl_sparse_2048.json"
)

> initializing model parallel with size 1
Use alternating sparse & dense attention layers
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234


In [20]:
# Replace the model's tokenizer with the one held by the reader.
# NOTE(review): presumably needed because the reader registered the
# answer/start separators on its tokenizer at construction - confirm.
model.tokenizer = reader.tokenizer

In [18]:
import os
from tqdm import tqdm
import torch

In [19]:
def filter_results(nr):
    """Cut every generated string at its first end-of-sequence marker.

    Args:
        nr: iterable of generated strings.

    Returns:
        A list with one string per input. Each string is truncated just before
        the earliest occurrence of ``"<|endoftext|>"`` or ``"</s>"``; a string
        that contains neither marker is returned unchanged.

    The previous implementation sliced with the raw result of ``str.find``,
    which is ``-1`` when the marker is absent, so it silently dropped the last
    character of the text. It also looked up ``"</s>"`` in the untruncated
    string instead of the already truncated one.
    """
    stop_markers = ("<|endoftext|>", "</s>")
    cleaned = []
    for x in nr:
        # Keep only the positions of markers that are actually present.
        cut_points = [pos for pos in (x.find(marker) for marker in stop_markers) if pos != -1]
        cleaned.append(x[:min(cut_points)] if cut_points else x)
    return cleaned


def generate(model, text, additional_len=32, num_beams=5, do_sample=None):
    """Generate a continuation of ``text`` and return the first cleaned result.

    The prompt length is measured with ``model.tokenizer.encode`` and capped at
    ``2048 - additional_len``; ``max_length`` passed to ``model.generate`` is
    that capped length plus ``additional_len``. The generated sequences are
    passed through ``filter_results`` and the first one is returned.
    """
    prompt_tokens = len(model.tokenizer.encode(text))
    prompt_budget = 2048 - additional_len
    prompt_len = prompt_tokens if prompt_tokens < prompt_budget else prompt_budget

    generation_kwargs = dict(
        text=text,
        max_length=prompt_len + additional_len,
        do_sample=do_sample,
        num_beams=num_beams,
        eos_token_id=model.tokenizer.eos_token_id,
        num_return_sequences=1,
    )
    # Inference only: run generation with gradient tracking disabled.
    with torch.no_grad():
        sequences = model.generate(**generation_kwargs)
    return filter_results(sequences)[0]



In [None]:
# Build `fixed_preds`: for every annotation keep the stored prediction unless it
# is identical to the source span, in which case the model is asked again.
# NOTE(review): `preds` is only assigned inside get_progress in the visible
# cells; if it is undefined at top level, the bare `except` below catches the
# NameError and every item is regenerated - confirm that this is intended.
fixed_preds = defaultdict(dict)
for data_part in reader.lm_prefixes:

    for name in tqdm(reader.lm_prefixes[data_part], total=len(reader.lm_prefixes[data_part])):
        fixed_preds[data_part][name] = []
        for idx, lm_prefix in tqdm(enumerate(reader.lm_prefixes[data_part][name])):
            text = reader.texts[data_part][name]
            ann = reader.anns[data_part][name][idx]
            # The annotation is stored as "start stop"; from here on `ann` is
            # the stripped source span itself.
            start, stop = list(map(int, ann.split()))
            ann = text[start:stop].strip()
            try:
                gen_res = preds[data_part][name][idx]
            except KeyboardInterrupt:
                # A manual interrupt is re-raised as StopIteration.
                raise StopIteration
            except:
                # Any other failure (e.g. a missing prediction) is reported and
                # treated as "prediction equals the source span".
                print("Error at", name)
                gen_res = ann
            if gen_res == ann:
                try:
                    gen_res = generate(model, lm_prefix, num_beams=None, do_sample=False)
                    gen_res = gen_res.split(reader.answer_sep)
                except KeyboardInterrupt:
                    raise StopIteration
                except:
                    print("Error at", name)
                    gen_res = [ann]
                # A single piece means the answer separator was not found (or
                # generation failed): fall back to the source span. Otherwise
                # take the piece that follows the first separator.
                if len(gen_res) == 1:
                    gen_res = text[start:stop].strip()
                else:
                    gen_res = gen_res[1].strip()
            fixed_preds[data_part][name].append(gen_res)


In [162]:
# Regenerate a single sample. `lm_prefix`, `text`, `start` and `stop` are not
# set in this cell: they are whatever values an earlier loop cell left behind.
gen_res = generate(model, lm_prefix, num_beams=None, do_sample=False)
gen_res = gen_res.split(reader.answer_sep)
# One piece means the answer separator is absent: use the source span instead.
if len(gen_res) == 1:
    gen_res = text[start:stop].strip()
else:
    gen_res = gen_res[1].strip()
gen_res

'постановлять'

In [None]:
# Bind `lm_prefix`, `text`, `ann`, `start` and `stop` to the first prediction of
# the first document of the first data part; both loops are left right after.
# If that first document has no predictions the inner loop body never runs and
# these names are not (re)assigned.
# NOTE(review): `preds` is not assigned at top level in the visible cells.
for data_part in preds:
    for name in preds[data_part]:
        for idx, pred in enumerate(preds[data_part][name]):
            lm_prefix = reader.lm_prefixes[data_part][name][idx]
            text = reader.texts[data_part][name]
            ann = reader.anns[data_part][name][idx]
            start, stop = list(map(int, ann.split()))
            ann = text[start:stop].strip()
            break
        break

In [53]:
from pymorphy2 import MorphAnalyzer

In [54]:
morph = MorphAnalyzer()

  args, varargs, kw, default = inspect.getargspec(cls.__init__)


In [64]:
from string import punctuation
# Rebind `punctuation` from the `string` module's str to a set of single
# characters, extended with the guillemets « and ».
punctuation = set(punctuation + "«»")


def fix_verbs(data_part, name, idx, fixed_preds, ann):
    """Replace a single-word "generic" prediction by its normal form if it is a verb.

    The prediction ``fixed_preds[data_part][name][idx]`` is parsed with the
    module-level ``morph`` analyzer; when the first parse has POS ``VERB`` or
    ``INFN`` the prediction is overwritten in place with ``normal_form``.
    ``ann`` is accepted for signature parity with the other fixers and unused.
    Returns ``None``.
    """
    doc_preds = fixed_preds[data_part][name]
    is_single_word = len(doc_preds[idx].split()) == 1
    if not (is_single_word and data_part == "generic"):
        return
    best_parse = morph.parse(doc_preds[idx])[0]
    if best_parse.tag.POS in ("VERB", "INFN"):
        doc_preds[idx] = best_parse.normal_form
        
def fix_title(data_part, name, idx, fixed_preds, ann):
    """Title-case the prediction when the source span ``ann`` is title-cased.

    When ``ann.istitle()`` holds, ``fixed_preds[data_part][name][idx]`` is
    overwritten with its ``str.title()`` form. The return value is ``1`` only
    if the prediction is still not title-cased after that (for instance when it
    has no cased characters), otherwise ``0``. A non-title ``ann`` leaves the
    prediction untouched and returns ``0``.
    """
    if not ann.istitle():
        return 0
    doc_preds = fixed_preds[data_part][name]
    titled = doc_preds[idx].title()
    doc_preds[idx] = titled
    return 0 if titled.istitle() else 1


def fix_upper_names(data_part, name, idx, fixed_preds, ann, ends=("ом", "а", "у", "е")):
    """Upper-case a prediction that is the lower-cased stem of an upper-case ``ann``.

    For each ending in ``ends``: if ``ann`` ends with it, the stem is ``ann``
    without that trailing ending. When the stem is upper-case and its
    lower-cased form equals the current prediction, the prediction is replaced
    in place by its upper-case form.

    Args:
        data_part, name, idx: location of the prediction inside ``fixed_preds``.
        fixed_preds: nested mapping ``data_part -> name -> list`` of predictions.
        ann: the source span the prediction was produced for.
        ends: iterable of endings to try, in order.

    Returns:
        ``1`` if the prediction was upper-cased, ``0`` otherwise.

    Only the trailing ending is removed now. The previous ``ann.replace(x, "")``
    deleted every occurrence of the ending inside ``ann``, which could produce a
    stem that is not a prefix of ``ann`` at all. The default for ``ends`` is a
    tuple so that no mutable object is shared between calls.
    """
    pred = fixed_preds[data_part][name][idx]
    for ending in ends:
        if not ann.endswith(ending):
            continue
        # Slice by length so that an empty ending leaves `ann` intact.
        stem = ann[:len(ann) - len(ending)]
        if stem.lower() == pred and stem.isupper():
            fixed_preds[data_part][name][idx] = pred.upper()
            return 1
    return 0


def fix_e(data_part, name, idx, fixed_preds, ann, enabled=False):
    """Optionally replace "ё" with "е" in a "generic" prediction.

    The rule used to sit behind an unconditional ``return 0`` and was therefore
    unreachable. It is now controlled by ``enabled``; the default ``False``
    keeps the previous observable behaviour (no change, returns ``0``).

    Args:
        data_part, name, idx: location of the prediction inside ``fixed_preds``.
        fixed_preds: nested mapping ``data_part -> name -> list`` of predictions.
        ann: the source span the prediction was produced for.
        enabled: apply the rule when true.

    Returns:
        ``1`` if the prediction was rewritten, ``0`` otherwise.
    """
    if not enabled:
        return 0
    pred = fixed_preds[data_part][name][idx]
    # Only rewrite when the source span itself does not use "ё".
    if "ё" in pred and "ё" not in ann and data_part == "generic":
        fixed_preds[data_part][name][idx] = pred.replace("ё", "е")
        return 1
    return 0


def fix_special_tokens(data_part, name, idx, fixed_preds, reader):
    """Remove service tokens from a prediction, in place.

    The tokens are ``"<s>"``, ``"</s>"``, ``reader.answer_sep`` and
    ``reader.start_sep``. If none of them occurs in the prediction nothing is
    changed and ``0`` is returned. Otherwise, for each token in turn, its
    occurrences (with and without surrounding spaces) are replaced by a single
    space and the result is stripped; ``1`` is returned.
    """
    specials = ["<s>", "</s>", reader.answer_sep, reader.start_sep]
    doc_preds = fixed_preds[data_part][name]
    cleaned = doc_preds[idx]
    if not any([token in cleaned for token in specials]):
        return 0
    for token in specials:
        # The " token" pattern is applied twice on purpose: the first pass can
        # expose a new match when two tokens are adjacent.
        for pattern in (f" {token} ", f"{token} ", f" {token}", f" {token}", f"{token}"):
            cleaned = cleaned.replace(pattern, " ")
        cleaned = cleaned.strip()
    doc_preds[idx] = cleaned
    return 1


def fix_words_count(data_part, name, idx, fixed_preds, ann):
    """Fall back to the baseline prediction when the word count differs from ``ann``.

    Words are counted with ``str.split()``. On a mismatch the prediction is
    replaced in place by the entry at the same position of the module-level
    ``baseline_preds`` and ``1`` is returned; otherwise ``0``.
    """
    doc_preds = fixed_preds[data_part][name]
    same_word_count = len(doc_preds[idx].split()) == len(ann.split())
    if same_word_count:
        return 0
    doc_preds[idx] = baseline_preds[data_part][name][idx]
    return 1


def fix_punct(data_part, name, idx, fixed_preds, ann):
    """Trim stray punctuation around a prediction, in place.

    Leading characters found in the module-level ``punctuation`` are dropped
    while the prediction does not start (case-insensitively) with the first
    character of ``ann``; trailing ones are dropped the same way against the
    last character of ``ann``. A non-empty result is stored back. If the stored
    prediction still does not start with the first character of ``ann`` it is
    replaced by the entry from the module-level ``baseline_preds``.

    Returns ``True`` when anything was trimmed or the baseline was used,
    ``False`` otherwise.
    """
    doc_preds = fixed_preds[data_part][name]
    candidate = doc_preds[idx]
    first_ch = ann[0].lower()
    last_ch = ann[-1].lower()
    trimmed = False

    while candidate and not candidate.lower().startswith(first_ch) and candidate[0] in punctuation:
        candidate = candidate[1:]
        trimmed = True
    while candidate and not candidate.lower().endswith(last_ch) and candidate[-1] in punctuation:
        candidate = candidate[:-1]
        trimmed = True

    if candidate:
        doc_preds[idx] = candidate
    if doc_preds[idx].lower().startswith(first_ch):
        return trimmed
    # Still starts with something unexpected: use the baseline prediction.
    doc_preds[idx] = baseline_preds[data_part][name][idx]
    return True


def fix_is_title(data_part, name, idx, fixed_preds, ann):
    """Capitalise the first character of a prediction when ``ann`` starts capitalised.

    If the first character of ``ann`` is title-cased and the first character of
    the prediction is not, the prediction is rewritten in place with its first
    character title-cased (the rest is left as is).

    Returns:
        ``1`` if the prediction was rewritten, ``0`` otherwise (including when
        the check itself failed, e.g. on an empty string; the failure is
        printed together with the baseline prediction).

    The handler catches ``Exception`` instead of using a bare ``except`` so
    that ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed while
    this runs inside a long loop.
    """
    gen_res = fixed_preds[data_part][name][idx]
    try:
        if not gen_res[0].istitle() and ann[0].istitle():
            fixed_preds[data_part][name][idx] = gen_res[0].title() + gen_res[1:]
            return 1
    except Exception:
        # Best effort: report the offending item and leave it unchanged.
        print(data_part, ann, fixed_preds[data_part][name][idx], baseline_preds[data_part][name][idx], sep="|")
    return 0


def fix_end_sym(data_part, name, idx, fixed_preds, ann):
    """Append the trailing punctuation of ``ann`` to a prediction that lacks one.

    When the last character of ``ann`` is in the module-level ``punctuation``
    and the last character of the prediction is not, that character is appended
    to the prediction in place and ``1`` is returned; otherwise ``0``.
    """
    doc_preds = fixed_preds[data_part][name]
    ann_tail = ann[-1]
    if ann_tail not in punctuation or doc_preds[idx][-1] in punctuation:
        return 0
    doc_preds[idx] += ann_tail
    return 1

In [66]:
# Post-process every prediction in `fixed_preds` in place and collect counters.
#
# For each annotation the prediction is first made non-empty (falling back to
# `no_beams_preds`, then `baseline_preds`, then the source span), then run
# through the fix_* helpers, then checked against two sanity bounds (relative
# length and character overlap with the source span). Predictions that fail a
# bound are replaced by the baseline, or regenerated when `is_generate` is set.
#
# Fixes compared with the previous version of this cell:
#   * `is_title_errors + fix_is_title(...)` and `end_sym_errors + fix_end_sym(...)`
#     discarded their result; they are `+=` now, so the counters are updated;
#   * `is_title_errors` was initialised twice;
#   * the re-check after regeneration used a literal 0.6 instead of
#     `intersect_bound`;
#   * the probes for a missing prediction catch IndexError instead of everything.
fixed_preds_anns = 0
fixed_predicted_anns = defaultdict(dict)
errors = 0
is_title_errors = 0
compare_baseline = defaultdict(dict)
token_errors = 0
errors_len = 0
intersect_bound = 0.6  # minimal share of prediction characters found in the source span
intersect_errors = 0
e_errors = 0
is_generate = False  # regenerate with the model instead of falling back to the baseline
upper_errors = 0
word_count_err = 0
start_sym_err = 0
bound_len = 0.5  # minimal len(prediction) / len(source span)
char_len_errors = 0
end_sym_errors = 0
for data_part in reader.lm_prefixes:

    for name in tqdm(reader.lm_prefixes[data_part], total=len(reader.lm_prefixes[data_part])):
        fixed_predicted_anns[data_part][name] = []
        compare_baseline[data_part][name] = []
        for idx, lm_prefix in enumerate(reader.lm_prefixes[data_part][name]):
            text = reader.texts[data_part][name]
            ann = reader.anns[data_part][name][idx]
            # The annotation is stored as "start stop"; from here on `ann` is
            # the stripped source span itself.
            start, stop = list(map(int, ann.split()))
            ann = text[start:stop].strip()

            # Pad prediction lists that are shorter than the annotation list.
            try:
                _ = baseline_preds[data_part][name][idx]
            except IndexError:
                baseline_preds[data_part][name].append(ann)
            try:
                _ = fixed_preds[data_part][name][idx]
            except IndexError:
                fixed_preds[data_part][name].append(no_beams_preds[data_part][name][idx])

            # Make sure the prediction is not empty, trying the fallbacks in order.
            if not len(fixed_preds[data_part][name][idx]):
                fixed_preds[data_part][name][idx] = no_beams_preds[data_part][name][idx]
            if not len(fixed_preds[data_part][name][idx]):
                fixed_preds[data_part][name][idx] = baseline_preds[data_part][name][idx]
            if not len(fixed_preds[data_part][name][idx]):
                fixed_preds[data_part][name][idx] = ann

            token_errors += fix_special_tokens(data_part, name, idx, fixed_preds, reader)
            is_title_errors += fix_title(data_part, name, idx, fixed_preds, ann)
            upper_errors += fix_upper_names(data_part, name, idx, fixed_preds, ann)
            e_errors += fix_e(data_part, name, idx, fixed_preds, ann)
            is_title_errors += fix_is_title(data_part, name, idx, fixed_preds, ann)
            start_sym_err += fix_punct(data_part, name, idx, fixed_preds, ann)
            end_sym_errors += fix_end_sym(data_part, name, idx, fixed_preds, ann)
            word_count_err += fix_words_count(data_part, name, idx, fixed_preds, ann)

            # A prediction much shorter than the source span is replaced by the baseline.
            if len(fixed_preds[data_part][name][idx]) / len(ann) < bound_len:
                char_len_errors += 1
                fixed_preds[data_part][name][idx] = baseline_preds[data_part][name][idx]
            if fixed_preds[data_part][name][idx] == ann:
                fixed_preds_anns += 1
                fixed_predicted_anns[data_part][name].append(
                    {"ann": ann, "gen_res": fixed_preds[data_part][name][idx],
                     "text": text, "lm_prefix": lm_prefix})

            # Share of distinct prediction characters that also occur in the source span.
            intersect = len(set(fixed_preds[data_part][name][idx]).intersection(ann))
            intersect /= max(1, len(set(fixed_preds[data_part][name][idx])))
            if intersect < intersect_bound:
                intersect_errors += 1

            if intersect < intersect_bound and is_generate:
                # print(ann, fixed_preds[data_part][name][idx], baseline_preds[data_part][name][idx], sep="|")
                gen_res = generate(model, lm_prefix, num_beams=None, do_sample=False)
                gen_res = gen_res.split(reader.answer_sep)
                # One piece means the answer separator is absent: use the source span.
                if len(gen_res) == 1:
                    gen_res = text[start:stop].strip()
                else:
                    gen_res = gen_res[1].strip()
                fixed_preds[data_part][name][idx] = gen_res

                fix_special_tokens(data_part, name, idx, fixed_preds, reader)
                fix_title(data_part, name, idx, fixed_preds, ann)
                fix_upper_names(data_part, name, idx, fixed_preds, ann)
                fix_e(data_part, name, idx, fixed_preds, ann)
                word_count_err += fix_words_count(data_part, name, idx, fixed_preds, ann)
                is_title_errors += fix_is_title(data_part, name, idx, fixed_preds, ann)
                start_sym_err += fix_punct(data_part, name, idx, fixed_preds, ann)
                end_sym_errors += fix_end_sym(data_part, name, idx, fixed_preds, ann)
                intersect = len(set(fixed_preds[data_part][name][idx]).intersection(ann))
                intersect /= max(1, len(set(fixed_preds[data_part][name][idx])))
                if intersect < intersect_bound:
                    print(data_part, ann, fixed_preds[data_part][name][idx], baseline_preds[data_part][name][idx], sep="|")
            elif intersect < intersect_bound:
                fixed_preds[data_part][name][idx] = baseline_preds[data_part][name][idx]
            fix_verbs(data_part, name, idx, fixed_preds, ann)

            # Remember every item whose final prediction differs from the baseline.
            if fixed_preds[data_part][name][idx] != baseline_preds[data_part][name][idx]:
                compare_baseline[data_part][name].append(
                    {"ann": ann,
                     "gen_res": fixed_preds[data_part][name][idx],
                     "text": text,
                     "lm_prefix": lm_prefix,
                     "baseline": baseline_preds[data_part][name][idx]
                    })

100%|██████████| 4370/4370 [00:01<00:00, 2823.53it/s]
100%|██████████| 536/536 [00:02<00:00, 223.96it/s]


In [628]:
# Display the counters collected by the post-processing loop.
# NOTE(review): `end_sym_errors` is initialised by that loop as well but is not
# reported here.
{
    "is_title_errors": is_title_errors,
    "token_errors": token_errors,
    "intersect_errors": intersect_errors,
    "e_errors": e_errors,
    "upper_errors": upper_errors,
    "word_count_err": word_count_err,
    "char_len_errors": char_len_errors,
    "start_sym_err": start_sym_err
}

{'is_title_errors': 0,
 'token_errors': 0,
 'intersect_errors': 48,
 'e_errors': 0,
 'upper_errors': 0,
 'word_count_err': 0,
 'char_len_errors': 0,
 'start_sym_err': 10}

In [625]:
# errors

In [92]:
morph.parse("горит")[0].tag.POS

'VERB'

In [None]:
["ADJF", "NOUN"]

In [89]:
single_counts = 0

In [93]:
# Replace every single-word "generic" prediction whose first pymorphy2 parse is
# an ADJF, NOUN or VERB by that parse's normal form (in place in `fixed_preds`).
# NOTE(review): `from_baseline_count` is set to 0 here and never incremented in
# this cell. `single_counts` is initialised in a separate cell, so re-running
# only this cell keeps adding to the previous value.
from_baseline_count = 0
for data_part in ["generic"]:
    for name in list(reader.lm_prefixes[data_part]):
        for idx, lm_prefix in enumerate(reader.lm_prefixes[data_part][name]):
            # `text`, `ann`, `start` and `stop` are computed but not used by the
            # normalisation below.
            text = reader.texts[data_part][name]
            ann = reader.anns[data_part][name][idx]
            start, stop = list(map(int, ann.split()))
            ann = text[start:stop].strip()
            gen_res = fixed_preds[data_part][name][idx]
            if len(gen_res.split()) == 1:
                parsed = morph.parse(gen_res)[0]
                if parsed.tag.POS in ["ADJF", "NOUN", "VERB"]:
                    single_counts += 1
                    fixed_preds[data_part][name][idx] = parsed.normal_form

In [94]:
single_counts

93029

In [107]:
# Count the "named" items whose prediction has a different number of
# title-cased words than the source span.
# NOTE: this rebinds `errors` to a plain int; earlier cells used the same name
# for the per-part dict returned by get_progress.
errors = 0
for data_part in ["named"]:
    for name in list(reader.lm_prefixes[data_part]):
        for idx, lm_prefix in enumerate(reader.lm_prefixes[data_part][name]):
            text = reader.texts[data_part][name]
            ann = reader.anns[data_part][name][idx]
            # "start stop" offsets -> stripped source span.
            start, stop = list(map(int, ann.split()))
            ann = text[start:stop].strip()
            gen_res = fixed_preds[data_part][name][idx]
            if sum([x.istitle() for x in gen_res.split()]) != sum([x.istitle() for x in ann.split()]):
                errors += 1
            # print(data_part, ann, fixed_preds[data_part][name][idx], baseline_preds[data_part][name][idx], sep="|")

In [109]:
errors

89

In [110]:
errors / total["named"]

0.0007678768636112644

In [70]:
from_baseline_count/sum(total.values())

0.9659152205500022

In [95]:
path = "test_pred/beams_v2_single_word_fix.2/"

In [96]:
# Write the final predictions: one file `<path>/<data_part>/<name>.norm` per
# document, one prediction per line, in annotation order.
for data_part in fixed_preds:
    store_dir = os.path.join(path, data_part)
    if not os.path.exists(store_dir):
        os.makedirs(store_dir, exist_ok=True)
    names = list(fixed_preds[data_part].keys())

    # NOTE: the f-string below interpolates the literal 0, so the progress bar
    # is always labelled "Predict on 0".
    for name in tqdm(names, total=len(names), leave=True, desc=f"Predict on {0}"):
        with open(os.path.join(store_dir, f"{name}.norm"), 'w', encoding='utf-8') as file:
            for pred in fixed_preds[data_part][name]:
                file.write(f"{pred}\n")

Predict on 0: 100%|██████████| 4370/4370 [00:25<00:00, 168.92it/s]
Predict on 0: 100%|██████████| 536/536 [00:01<00:00, 347.99it/s]


In [443]:
1

1