In [25]:
import os
import gc
import psutil
from pathlib import Path
import pathlib

import pandas as pd
import numpy as np
pd.set_option('display.max_rows', 100)
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_cosine_schedule_with_warmup, DataCollatorWithPadding

from scipy.spatial.distance import cosine
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from sklearn.neighbors import NearestNeighbors

device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'

In [26]:
class PATH:
    # epoch 34668
    model = '/root/autodl-nas/model/sentence-transformers/all-MiniLM-L6-v2_new_r1.1'
    model_cpt = '/root/autodl-nas/model/r1_13/checkpoint-9738'
    # input_dir = '/root/autodl-nas/data/k12'
    input_dir = '/root/autodl-nas/data/k12/cv_data/fold_0'
    
    output_dir = '/root/autodl-nas/data/k12/out'
    cv_dir = '/root/autodl-nas/data/k12/cv_data'
    pretrained_dir = '/root/autodl-nas/model/'
    content_dir = os.path.join(input_dir, 'content.csv')
    correlation_dir = os.path.join(input_dir, 'correlations.csv')
    submission_dir = os.path.join(input_dir, 'sample_submission.csv')
    topic_dir = os.path.join(input_dir, 'topics.csv')
    
    
class CFG:
    seed = 11
    fold = 0
    n_fold = 3
    model_name = 'sentence-transformers/all-MiniLM-L6-v2'
    # cpt = '/root/autodl-nas/model/checkpoint-34668/'

## Lang

In [27]:
def get_level_features(df_topic, level_cols=['title']):
    df_hier = df_topic[list(set(level_cols + ['id', 'parent', 'level', 'has_content']))]
    highest_level = df_hier['level'].max()
    
    df_level = df_hier.query('level == 0').copy(deep=True)
    level_list = list()
    for col in level_cols:
        df_level[f'{col}_level'] = df_level[f'{col}'].apply(lambda x: [x])

    for i in tqdm(range(highest_level + 1)):
        level_list.append(df_level[df_level['has_content']])
        df_level_high = df_hier.query('level == @i+1')
        df_level = df_level_high.merge(df_level, left_on='parent', right_on='id', suffixes=['', '_parent'], how='inner')
        for col in level_cols:
            df_level[f'{col}_level'] = df_level[f'{col}_level'] + df_level[f'{col}'].apply(lambda x: [x])
        for col in df_level.columns:
            if col.endswith('_parent'):
                df_level.drop(columns=col, inplace=True)
    df = pd.concat(level_list).reset_index(drop=True)
    return df[list(set(['id'] + [f'{col}_level' for col in level_cols]))]

In [28]:
def get_topic_field(d):
    title = list(filter(lambda x: pd.notna(x), d['title_level']))
    title = ' of '.join(title[-1::-1])
    title = 'No information' if title=='' else title
    title = '[TITLE] ' + title + '. '
    description = d['description'] if pd.notna(d['description']) else 'No information'
    description = '[DESCRIPTION]' + description + '. '
    field = title + description
    return field

def get_content_field(d):
    title = d['title']
    title = 'No information' if pd.isna(title) else title
    title = '[TITLE] ' + title + '. '
    description = d['description'] if pd.notna(d['description']) else 'No information'
    description = '[DESCRIPTION]' + description + '. '
    kind = '[' + d['kind'] + '] '
    field = kind + title + description
    return field

In [29]:
def prepare_language_match(path, mode='train'):
    topic = pd.read_csv(path.topic_dir)[['id', 'language']]
    content = pd.read_csv(path.content_dir)[['id', 'language']]
    if mode == 'train':
        corr = pd.read_csv(path.correlation_dir)
    elif mode == 'valid':
        corr = pd.read_csv(path.submission_dir)
    
    topic = topic.merge(corr, left_on='id', right_on='topic_id', how='right')[['id', 'language']]
    match_dict = {}
    for language in topic['language'].unique():
        match_dict[language] = (topic.query('language==@language')[['id']], content.query('language==@language')[['id']])
    return match_dict

In [30]:
def prepare_match_features(topic, content, path):
    df_topic = pd.read_csv(path.topic_dir)
    df_content = pd.read_csv(path.content_dir)
    level = get_level_features(df_topic)
    df_topic = df_topic.merge(level, on='id', how='right')
    df_topic['field'] = df_topic.apply(lambda x: get_topic_field(x), axis=1)
    df_content['field'] = df_content.apply(lambda x: get_content_field(x), axis=1)
    topic = topic[['id']].merge(df_topic[['id', 'field']], on='id', how='left')
    content = content[['id']].merge(df_content[['id', 'field']], on='id', how='left')
    return topic, content

In [31]:
%%time
topic_content_match = prepare_language_match(PATH, mode='valid')

CPU times: user 10.1 s, sys: 1.41 s, total: 11.5 s
Wall time: 11.5 s


In [32]:
# %%time
for lang in topic_content_match.keys():
    print(f'{lang}\t - topics: {len(topic_content_match[lang][0])}\t - contents: {len(topic_content_match[lang][1])}')
#     topic, content = topic_content_match[lang]
#     topic, content = prepare_match_features(topic, content, PATH)

gu	 - topics: 131	 - contents: 3677
en	 - topics: 5373	 - contents: 65939
es	 - topics: 2189	 - contents: 30844
fr	 - topics: 111	 - contents: 10682
hi	 - topics: 132	 - contents: 4042
fil	 - topics: 83	 - contents: 516
pt	 - topics: 76	 - contents: 10435
bn	 - topics: 191	 - contents: 2513
as	 - topics: 18	 - contents: 641
sw	 - topics: 31	 - contents: 1447


## Calc Embeddings

In [33]:
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

CPT_PATH = PATH.model_cpt
MODEL_PATH = PATH.model
OUTPUT_PATH = os.path.join(PATH.cv_dir, f"fold_{CFG.fold}")
N_NEIGHBOR = 100

In [34]:
class PlainDataset(Dataset):

    def __init__(self, df, tokenizer, label_name="") -> None:
        super().__init__()
        self.data = df[label_name].tolist()
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        text = self.data[index]
        inputs = self.tokenizer(
                text, 
                add_special_tokens = True,
                truncation='longest_first',
                max_length = 64,
                padding = 'max_length',
                return_attention_mask = True,
                return_tensors = 'pt',
        )
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        return inputs

    def __len__(self):
        return len(self.data)

In [35]:
class Convert2Embed(object):

    def __init__(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        self.model = AutoModel.from_pretrained(CPT_PATH).cuda()

    def convert2embeddind(self, df, label_name=""):
        embed: list = []
        dataset = PlainDataset(df, tokenizer=self.tokenizer, label_name=label_name)
        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=32)
        for batch in dataloader:
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.no_grad():
                embeddings = self.model(**batch, output_hidden_states=True, return_dict=True).pooler_output
                embed.append(embeddings.cpu().clone().detach().numpy())
        embed = np.concatenate(embed, axis=0)
        return embed

    def get_embed(self):
        for lang in topic_content_match.keys():
            # topic, content = topic_content_match[lang]
            # topic, content = prepare_match_features(topic, content, PATH)
            # topic_path = os.path.join(OUTPUT_PATH, "valid", f"topic_{lang}.pqt")
            # content_path = os.path.join(OUTPUT_PATH, "valid", f"content_{lang}.pqt")
            # topic.to_parquet(topic_path)
            # content.to_parquet(content_path)
            
            for t in ["content", "topic"]:
                path = os.path.join(OUTPUT_PATH, "valid", f"{t}_{lang}.pqt")
                df = pd.read_parquet(path)
                embed = self.convert2embeddind(df, label_name=f"field")
                np.save(path.replace(".pqt", ".npy"), embed)


In [36]:
def valid():
#     with open(os.path.join(OUTPUT_PATH, "valid", "language.txt"), "r") as f:
#         valid_language = f.read().splitlines()
    recall_amount = 0
    recall_amount_total = 0
    recall_num = 0
    recall_total = {}
    f2_sum = 0
    df_list = []
    global df_pred_all
    for lang in topic_content_match.keys():
        ## debug
        # global df_pred, df_correlations
        global df_pred
        content_path = os.path.join(OUTPUT_PATH, "valid", f"content_{lang}.npy")
        topics_path = os.path.join(OUTPUT_PATH, "valid", f"topic_{lang}.npy")
        correlations_path = PATH.submission_dir
        content_array = np.load(content_path)
        topics_array = np.load(topics_path)
        model = NearestNeighbors(n_neighbors=N_NEIGHBOR, metric="cosine")
        model.fit(content_array)
        d, r = model.kneighbors(topics_array)
        df_content = pd.read_parquet(content_path.replace(".npy", ".pqt"))
        df_topics = pd.read_parquet(topics_path.replace(".npy", ".pqt"))
        df_correlations = pd.read_csv(correlations_path).astype({"topic_id": str})
        
        pred = {"topic_id": [], "content_ids": [], 'dists': []}
        for i in range(len(df_topics)):
            r_t = r[i]
            tmp = []
            for c in r_t:
                content_id = df_content.iloc[c]["id"]
                tmp.append(content_id)
            topics_id = df_topics.iloc[i]["id"]
            pred["topic_id"].append(topics_id)
            pred["content_ids"].append(tmp)
            pred['dists'].append(d[i])
        
        df_pred = pd.DataFrame(pred).astype({"topic_id": str})
        
        df_correlations['content_ids'] = df_correlations['content_ids'].apply(lambda x: list(x.split()))
        df_pred = df_pred.merge(df_correlations, on='topic_id', how='left', suffixes=['_pred', '_true'])
        df_pred['num'] = df_pred.apply(lambda x: len(x['content_ids_true']), axis=1)
        df_pred['hit'] = df_pred.apply(lambda x: len(set(x['content_ids_true']).intersection(x['content_ids_pred'])), axis=1)
        df_pred['recall'] = df_pred.apply(lambda x: x['hit'] / len(x['content_ids_true']), axis=1)
        df_pred['precision'] = df_pred.apply(lambda x: x['hit'] / len(x['content_ids_pred']), axis=1)
        df_pred['f2'] = 5*df_pred['precision']*df_pred['recall']/(4*df_pred['precision']+df_pred['recall'])
        recall = df_pred['recall'].mean()
        recall_total[lang] = recall
        recall_num += len(df_pred)
        recall_amount += df_pred['recall'].sum()
        recall_amount_total += df_pred['hit'].sum()/df_pred['num'].sum() * len(df_pred)
        f2_sum += df_pred['f2'].sum()
        df_list.append(df_pred)
        print(f"{lang}: {df_pred['hit'].sum()/df_pred['num'].sum()} - f2: {df_pred['f2'].sum()/len(df_pred)}")
        
    df_pred_all = pd.concat(df_list)
    print(f"Recall: total - {recall_amount_total/recall_num} line average - {recall_amount/recall_num} - f2 - {f2_sum/recall_num}")
    print(f"----------------Details----------------")
    for k, v in recall_total.items():
        print(f"Recall for language {k}: {v}")


In [37]:
%%time
P = Convert2Embed()
P.get_embed()

Some weights of the model checkpoint at /root/autodl-nas/model/r1_13/checkpoint-9738 were not used when initializing BertModel: ['lm_head.transform.dense.weight', 'mlp.dense.bias', 'lm_head.transform.dense.bias', 'lm_head.decoder.bias', 'mlp.dense.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.transform.LayerNorm.bias', 'lm_head.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at /root/autodl-nas/model/r1_13/checkpoint-9738 and are newly initialized: ['bert.pooler.dense.weight

CPU times: user 1min 23s, sys: 1.59 s, total: 1min 24s
Wall time: 1min 22s


In [38]:
def get_pred_by_thresh(df, thresh, n_default=1):
    pred_by_thresh = list()
    for i, d in enumerate(df['dists']):
        if d<=thresh:
            pred_by_thresh.append(df['content_ids_pred'][i])
    if not pred_by_thresh:
        pred_by_thresh = df['content_ids_pred'][:n_default]
    return pred_by_thresh

def calculate_f2(df_pred_all, true_label='content_ids_true', pred_label='pred_by_thresh') -> int:
    df_pred_all['hit'] = df_pred_all.apply(lambda x: len(set(x[pred_label]).intersection(x[true_label])), axis=1)
    df_pred_all['recall'] = df_pred_all.apply(lambda x: x['hit']/len(x[true_label]), axis=1)
    df_pred_all['precision'] = df_pred_all.apply(lambda x: x['hit']/len(x[pred_label]), axis=1)
    df_pred_all['f2'] = (5*df_pred_all['precision']*df_pred_all['recall']/(4*df_pred_all['precision']+df_pred_all['recall'])).fillna(0)
    return df_pred_all['f2'].mean()

def optimize_f2(df_pred_all, init_thresh=0.2, init_step=0.01) -> int:
    thresh = init_thresh
    df_pred_all['pred_by_thresh'] = df_pred_all.apply(lambda x: get_pred_by_thresh(x, 0.2), axis=1)
    f2 = calculate_f2(df_pred_all)

In [39]:
# 275 383

In [40]:
# df_pred_all['pred_by_thresh'] = df_pred_all.apply(lambda x: get_pred_by_thresh(x, 0.4), axis=1)
# print(calculate_f2(df_pred_all))

# df_pred_all = df_pred_all[['topic_id', 'content_ids_pred', 'dists', 'content_ids_true']]
# # df_pred_all['labels'] = df_pred_all.apply(lambda x: label_pred(x), axis=1)
# for i in range(170, 450, 2):
#     thresh = i/1000
#     df_pred_all['pred_by_thresh'] = df_pred_all.apply(lambda x: get_pred_by_thresh(x, thresh, 3), axis=1)
#     f2 = calculate_f2(df_pred_all)
#     print(i, f2)

# df_pred_all.explode(['content_ids_pred', 'dists', 'labels'])

170 0.4024688386835773
172 0.4031549520762678
174 0.40409479083512645
176 0.40498780130113327
178 0.4056221467042033
180 0.40624726560527924
182 0.40721841030008704
184 0.409052690435736
186 0.41030785013923904
188 0.4105493930238
190 0.4121692194096475
192 0.4132842857948191
194 0.4144896801795744
196 0.41615635613829677
198 0.41766006453443766
200 0.41965067582784593
202 0.4202251120409235
204 0.42037800850695833
206 0.4214418561517112
208 0.4235262148368645
210 0.42532777889037154
212 0.42699839812736545
214 0.42848119762586157
216 0.43026228612064904
218 0.43200040637419174
220 0.4332824488732244
222 0.43483586366382354
224 0.4367978856794664
226 0.43781717992134717
228 0.4391767284118734
230 0.44009180614394006
232 0.4417736674332854
234 0.4430857341640051
236 0.44381462956587736
238 0.44525578177706604
240 0.4471523710828244
242 0.44870426476357084
244 0.4497284873353738
246 0.45163057757041214
248 0.4530941896947352
250 0.4548706386758581
252 0.45620181582970554
254 0.45695022797459456
256 0.4576392577793879
258 0.4577774966032547
260 0.45801281237339675
262 0.45868644081494886
264 0.458481439820201
266 0.45903038240513294
268 0.4595054445782289
270 0.4594602890972266
272 0.4594325829137616
274 0.45987696587445226
276 0.4602119501917246
278 0.46098732313376983
280 0.46121830781870066
282 0.46157880182145566
284 0.460916378887469
286 0.46063599518269166
288 0.46015462896328635
290 0.46026819515777456
292 0.46005084723586476
294 0.45913039790317867
296 0.45844901040602476
298 0.45750240076191034
300 0.45692206443176586
302 0.45608328843132245
304 0.4549603767536112
306 0.45329443276876796
308 0.4517214889907147
310 0.45083376456284796
312 0.44961959295741877
314 0.448034045429716
316 0.4461080456862716
318 0.44414396537646245
320 0.4425980612628054
322 0.44092277734608776
324 0.4389668410680323
326 0.4367121022990652
328 0.43420323722183946
330 0.4315968321759618
332 0.4294783204714232
334 0.4271767339470532
336 0.42463422374234044
338 0.42189842384377607
340 0.41922310125359186
342 0.4163151008820663
344 0.4132511312819868
346 0.4100375863146136
348 0.4068978474694664
350 0.4041449159496747
352 0.4010115542362085
354 0.3977875787529075
356 0.3948931385115175
358 0.3913653717224516
360 0.3877346850668936
362 0.38452566672457694
364 0.38094351006016375
366 0.37740241893768
368 0.3736814676972883
370 0.36987613243759815
372 0.36632612945386933
374 0.36227639122395994
376 0.35868013462650117
378 0.35507737565132824
380 0.3512169514101543
382 0.3472301898209479
384 0.343266006151315
386 0.3389849776348872
388 0.33505053612718283
390 0.33077591982865007
392 0.32659240791737504
394 0.3223306480778996
396 0.3180902050053218
398 0.3139647617144388
400 0.30977561550307586
402 0.3056369682307033
404 0.3015421054155209
406 0.29711072715398107
408 0.2929429504219797
410 0.2888552872196978
412 0.2845286862129868
414 0.28014803099027047
416 0.2759176750171542
418 0.2719160264420336
420 0.26765219196053575
422 0.2634692270432641
424 0.2592592306452557
426 0.25531781926012703
428 0.25119991693635824
430 0.24713812191438406
432 0.24296433979588497
434 0.23907443839564776
436 0.23522140710825673
438 0.2312847831007713
440 0.2273921890591881
442 0.22357551167977419
444 0.2195538619615006
​

In [41]:
%%time
N_NEIGHBOR = 6
valid()

gu: 0.13582677165354332 - f2: 0.11966925839183284
en: 0.3658066071859175 - f2: 0.3919186712598838
es: 0.24667540737965912 - f2: 0.16398126419158054
fr: 0.29172141918528255 - f2: 0.34113255163769163
hi: 0.20483870967741935 - f2: 0.2021483142316011
fil: 0.5422222222222223 - f2: 0.49821351264896346
pt: 0.38190954773869346 - f2: 0.3659902029283047
bn: 0.14225941422594143 - f2: 0.10952561109629172
as: 0.14492753623188406 - f2: 0.12700563513571642
sw: 0.1165644171779141 - f2: 0.10499548540349872
Recall: total - 0.3227459834711503 line average - 0.42257588727041817 - f2 - 0.3168071147841835
----------------Details----------------
Recall for language gu: 0.15069639748225525
Recall for language en: 0.5106737642420355
Recall for language es: 0.2509608240772359
Recall for language fr: 0.405689314084567
Recall for language hi: 0.237771164021164
Recall for language fil: 0.6928427997705106
Recall for language pt: 0.4718885281385282
Recall for language bn: 0.16709503411074092
Recall for language as: 

In [42]:
%%time
N_NEIGHBOR = 5
valid()

gu: 0.11614173228346457 - f2: 0.1108722993701157
en: 0.33498111084317983 - f2: 0.3885571468838667
es: 0.23000561902978087 - f2: 0.1675327693416328
fr: 0.2628120893561104 - f2: 0.3277426884341225
hi: 0.1774193548387097 - f2: 0.18345470800212874
fil: 0.52 - f2: 0.5179257588896142
pt: 0.35678391959798994 - f2: 0.36596929019995544
bn: 0.12761506276150628 - f2: 0.10735863479874244
as: 0.14492753623188406 - f2: 0.1357241327942252
sw: 0.09815950920245399 - f2: 0.09141621492364496
Recall: total - 0.2965139253290117 line average - 0.3976897902244021 - f2 - 0.3150750394395113
----------------Details----------------
Recall for language gu: 0.13392482478460782
Recall for language en: 0.4800535200680752
Recall for language es: 0.23902816934166288
Recall for language fr: 0.37369914135652016
Recall for language hi: 0.20570406445406447
Recall for language fil: 0.6710413080895008
Recall for language pt: 0.4516243259664313
Recall for language bn: 0.15274472473425355
Recall for language as: 0.17543859649

In [43]:
%%time
N_NEIGHBOR = 10
valid() # Recall: total - 0.49870318143289205 line average - 0.5800252324028706 - f2 - 0.3687912849848095


gu: 0.18110236220472442 - f2: 0.12461782405585814
en: 0.45137046861184793 - f2: 0.3788072441053249
es: 0.30230380221015174 - f2: 0.15236436008049092
fr: 0.36136662286465177 - f2: 0.3416748853066222
hi: 0.2564516129032258 - f2: 0.20835327388740452
fil: 0.6133333333333333 - f2: 0.43246466854956545
pt: 0.44221105527638194 - f2: 0.35114859783866736
bn: 0.16736401673640167 - f2: 0.09958438600261821
as: 0.2608695652173913 - f2: 0.17334751853675703
sw: 0.17177914110429449 - f2: 0.12809894234277586
Recall: total - 0.3972581053379692 line average - 0.48782840780643766 - f2 - 0.30465557136961435
----------------Details----------------
Recall for language gu: 0.18263597926112993
Recall for language en: 0.5892591211340802
Recall for language es: 0.28967090660601086
Recall for language fr: 0.46774943251723433
Recall for language hi: 0.2892824250912486
Recall for language fil: 0.7631812966150315
Recall for language pt: 0.5320184932027038
Recall for language bn: 0.19187688402348088
Recall for languag

In [44]:
%%time
N_NEIGHBOR = 20
valid() # Recall: total - 0.6107718178846423 line average - 0.6663599709233686 - f2 - 0.3068963565307041

gu: 0.22244094488188976 - f2: 0.11064560969141123
en: 0.5617313720761996 - f2: 0.32037751832679573
es: 0.37628769432477993 - f2: 0.1232934550448457
fr: 0.4704336399474376 - f2: 0.3131679522449647
hi: 0.3467741935483871 - f2: 0.20149420203731702
fil: 0.6622222222222223 - f2: 0.29587321981900294
pt: 0.5100502512562815 - f2: 0.295590830952673
bn: 0.22384937238493724 - f2: 0.08752327878636781
as: 0.37681159420289856 - f2: 0.16529631633798297
sw: 0.22085889570552147 - f2: 0.11786997724497723
Recall: total - 0.49419571020184905 line average - 0.5652046442179235 - f2 - 0.2564486935734837
----------------Details----------------
Recall for language gu: 0.2205070564451842
Recall for language en: 0.6823018792684383
Recall for language es: 0.3346764812169921
Recall for language fr: 0.5415687223685159
Recall for language hi: 0.38228698044874515
Recall for language fil: 0.7846528973034999
Recall for language pt: 0.588352035391509
Recall for language bn: 0.25287164842138665
Recall for language as: 0.

In [45]:
%%time
N_NEIGHBOR = 50
valid()

gu: 0.31299212598425197 - f2: 0.08589057552610163
en: 0.6924282613937787 - f2: 0.2128634386408154
es: 0.47874133732908786 - f2: 0.08313422904622883
fr: 0.59526938239159 - f2: 0.23554027319602194
hi: 0.4532258064516129 - f2: 0.1461030266808471
fil: 0.8355555555555556 - f2: 0.17918331731701276
pt: 0.5954773869346733 - f2: 0.20127385663394282
bn: 0.34518828451882844 - f2: 0.06622440610841619
as: 0.6086956521739131 - f2: 0.15822316465005876
sw: 0.37423312883435583 - f2: 0.11490463474931141
Recall: total - 0.616482377429802 line average - 0.6631366583259692 - f2 - 0.17175840061055478
----------------Details----------------
Recall for language gu: 0.2966699509588219
Recall for language en: 0.7841729773022118
Recall for language es: 0.42370964750662
Recall for language fr: 0.641633774420152
Recall for language hi: 0.48244150185326656
Recall for language fil: 0.8824727481353989
Recall for language pt: 0.6499800637958532
Recall for language bn: 0.34647786768205624
Recall for language as: 0.6386

In [46]:
%%time
N_NEIGHBOR = 100
valid()

gu: 0.3838582677165354 - f2: 0.06167456579286359
en: 0.7810465396672294 - f2: 0.13946628340349818
es: 0.5583442592245739 - f2: 0.05573199637761451
fr: 0.6885676741130092 - f2: 0.1692763514264554
hi: 0.535483870967742 - f2: 0.10042036444613164
fil: 0.9111111111111111 - f2: 0.10822716173038485
pt: 0.6582914572864321 - f2: 0.13413946372102253
bn: 0.4393305439330544 - f2: 0.046863186496774435
as: 0.8840579710144928 - f2: 0.1317272608638015
sw: 0.5214723926380368 - f2: 0.09768715558465284
Recall: total - 0.7027980985551304 line average - 0.7326370925224693 - f2 - 0.11337755570937746
----------------Details----------------
Recall for language gu: 0.3715427244957177
Recall for language en: 0.8462934159186151
Recall for language es: 0.5079205998340828
Recall for language fr: 0.7344904850838803
Recall for language hi: 0.5611461266544386
Recall for language fil: 0.943631669535284
Recall for language pt: 0.7006071048834205
Recall for language bn: 0.4111772171981596
Recall for language as: 0.89074

In [47]:
%%time
N_NEIGHBOR = 200
valid()

gu: 0.45866141732283466 - f2: 0.040237891384655645
en: 0.8550357688288722 - f2: 0.08462437563424251
es: 0.6486233377036899 - f2: 0.03528622085094099
fr: 0.7700394218134035 - f2: 0.10921526732968498
hi: 0.6080645161290322 - f2: 0.06295383842076561
fil: 0.9644444444444444 - f2: 0.0608427832418007
pt: 0.7462311557788944 - f2: 0.08523448395486598
bn: 0.5418410041841004 - f2: 0.031034916642755945
as: 0.9855072463768116 - f2: 0.08235975052351248
sw: 0.6809815950920245 - f2: 0.07194327263229684
Recall: total - 0.7821080813653752 line average - 0.8032735671638902 - f2 - 0.06944216324792452
----------------Details----------------
Recall for language gu: 0.45040878277117735
Recall for language en: 0.9013983612722337
Recall for language es: 0.616537514775718
Recall for language fr: 0.8061895722938034
Recall for language hi: 0.6269500017837613
Recall for language fil: 0.9759036144578314
Recall for language pt: 0.7775504101161995
Recall for language bn: 0.4842971600825004
Recall for language as: 0.

shuffle?

groupby lang shuffle

In [48]:

!free -h

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
              total        used        free      shared  buff/cache   available
Mem:          375Gi        56Gi        13Gi       1.2Gi       306Gi       315Gi
Swap:            0B          0B          0B


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
3.2G	/root/.local/share/Trash
