In [1]:
import os
import gc
import psutil
from pathlib import Path
import pathlib

import pandas as pd
import numpy as np
pd.set_option('display.max_rows', 100)
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_cosine_schedule_with_warmup, DataCollatorWithPadding

from scipy.spatial.distance import cosine
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from sklearn.neighbors import NearestNeighbors

# Always a torch.device (the previous CPU fallback was the bare string 'cpu').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class PATH:
    """Filesystem locations used by this notebook.

    All attributes are plain strings. The four CSV entries are built from
    ``input_dir`` when the class body runs; despite their ``_dir`` suffix
    they are file paths, not directories.
    """

    # Tokenizer directory and fine-tuned checkpoint (aliased later as
    # MODEL_PATH / CPT_PATH).
    model = '/root/autodl-nas/model/sentence-transformers/all-MiniLM-L6-v2_new_r1.1'
    model_cpt = '/root/autodl-nas/model/f0r3/checkpoint-8708'

    # Data locations; input_dir and cv_dir currently point at the same fold.
    # input_dir = '/root/autodl-nas/data/k12'
    input_dir = '/root/autodl-tmp/data/k12/cv_split_new/valid/fold_4'
    cv_dir = '/root/autodl-tmp/data/k12/cv_split_new/valid/fold_4'
    output_dir = '/root/autodl-tmp/data/k12/out'
    # pretrained_dir = '/root/autodl-nas/model/'

    # CSV files located inside input_dir.
    topic_dir = os.path.join(input_dir, 'topics.csv')
    content_dir = os.path.join(input_dir, 'content.csv')
    correlation_dir = os.path.join(input_dir, 'correlations.csv')
    submission_dir = os.path.join(input_dir, 'sample_submission.csv')
    
    
class CFG:
    """Run-level settings kept as class attributes.

    NOTE(review): no cell shown in this notebook reads these values, and
    ``fold`` does not match the ``fold_4`` directory used in PATH; confirm
    whether they are still needed.
    """

    model_name = 'sentence-transformers/all-MiniLM-L6-v2'
    n_fold = 3
    fold = 1
    seed = 11


## Language matching

In [3]:
def get_level_features(df_topic, level_cols=None):
    """Collect, for every topic that has content, the values of ``level_cols``
    along its path from the level-0 root down to the topic itself.

    The hierarchy is walked one level at a time: level ``i + 1`` rows are
    inner-joined to the already-processed level ``i`` rows through
    ``parent == id`` and each row appends its own value to the list inherited
    from its parent.

    Parameters
    ----------
    df_topic : pandas.DataFrame
        Must provide the columns ``id``, ``parent``, ``level``,
        ``has_content`` (used as a boolean mask) and every name in
        ``level_cols``.
    level_cols : iterable of str, optional
        Columns to accumulate along the path. Defaults to ``['title']``.

    Returns
    -------
    pandas.DataFrame
        Columns ``id`` followed by one ``<col>_level`` list column per entry
        of ``level_cols``, in that fixed order, restricted to rows whose
        ``has_content`` is true.
    """
    # A None sentinel avoids a shared mutable default and lets callers pass any iterable.
    level_cols = ['title'] if level_cols is None else list(level_cols)
    # dict.fromkeys de-duplicates while keeping a deterministic order
    # (set() ordering of strings changes between interpreter runs).
    base_cols = list(dict.fromkeys(level_cols + ['id', 'parent', 'level', 'has_content']))
    out_cols = list(dict.fromkeys(['id'] + [f'{col}_level' for col in level_cols]))

    df_hier = df_topic[base_cols]
    if df_hier.empty:
        # max() of an empty column is NaN, which cannot drive range() below.
        return pd.DataFrame(columns=out_cols)
    highest_level = int(df_hier['level'].max())

    # Seed the walk with the roots: each path starts as a one-element list.
    df_level = df_hier[df_hier['level'] == 0].copy(deep=True)
    for col in level_cols:
        df_level[f'{col}_level'] = df_level[col].apply(lambda x: [x])

    level_list = []
    for i in tqdm(range(highest_level + 1)):
        level_list.append(df_level[df_level['has_content']])
        df_level_high = df_hier[df_hier['level'] == i + 1]
        # Parent columns that clash with the child's get the '_parent' suffix;
        # the accumulated '<col>_level' columns exist only on the parent side.
        df_level = df_level_high.merge(
            df_level, left_on='parent', right_on='id', suffixes=('', '_parent'), how='inner'
        )
        for col in level_cols:
            df_level[f'{col}_level'] = df_level[f'{col}_level'] + df_level[col].apply(lambda x: [x])
        df_level = df_level.drop(columns=[c for c in df_level.columns if c.endswith('_parent')])

    df = pd.concat(level_list).reset_index(drop=True)
    return df[out_cols]

In [4]:
def get_topic_field(d):
    """Render one topic row as the text fed to the encoder.

    ``d['title_level']`` is the root-to-topic list of titles; missing entries
    are skipped and the remainder is joined leaf-first with ' of '.
    A missing title path or description is replaced by 'No information'.
    """
    present_titles = [t for t in d['title_level'] if pd.notna(t)]
    breadcrumb = ' of '.join(reversed(present_titles))
    if breadcrumb == '':
        breadcrumb = 'No information'

    raw_description = d['description']
    if not pd.notna(raw_description):
        raw_description = 'No information'

    title_part = '[TITLE] ' + breadcrumb + '. '
    description_part = '[DESCRIPTION]' + raw_description + '. '
    return title_part + description_part

def get_content_field(d):
    """Render one content row as the text fed to the encoder.

    The result is '[<kind>] [TITLE] <title>. [DESCRIPTION]<description>. ',
    with 'No information' standing in for a missing title or description.
    """
    title_text = d['title'] if pd.notna(d['title']) else 'No information'
    description_text = 'No information' if not pd.notna(d['description']) else d['description']

    pieces = (
        '[', d['kind'], '] ',
        '[TITLE] ', title_text, '. ',
        '[DESCRIPTION]', description_text, '. ',
    )
    # join (like '+') rejects non-string pieces instead of formatting them.
    return ''.join(pieces)

In [5]:
def prepare_language_match(path, mode='train'):
    """Group the topics to be matched and all contents by language.

    Parameters
    ----------
    path : object
        Provides ``topic_dir`` and ``content_dir`` (CSV files with ``id`` and
        ``language`` columns) plus ``correlation_dir`` (read when
        ``mode='train'``) or ``submission_dir`` (read when ``mode='valid'``);
        the latter two must have a ``topic_id`` column.
    mode : {'train', 'valid'}
        Selects which file lists the topics of interest.

    Returns
    -------
    dict
        ``{language: (topic_ids, content_ids)}`` where both values are
        single-column (``id``) DataFrames. Only languages of topics listed in
        the selected file appear as keys.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'train' nor 'valid'.
    """
    # Validate before touching the large CSVs; previously an unknown mode only
    # failed later with an UnboundLocalError on `corr`.
    if mode == 'train':
        corr_file = path.correlation_dir
    elif mode == 'valid':
        corr_file = path.submission_dir
    else:
        raise ValueError(f"mode must be 'train' or 'valid', got {mode!r}")

    # Only two columns are needed, so skip parsing the rest.
    wanted = ['id', 'language']
    topic = pd.read_csv(path.topic_dir, usecols=wanted)[wanted]
    content = pd.read_csv(path.content_dir, usecols=wanted)[wanted]
    corr = pd.read_csv(corr_file)

    # Right join: keep exactly the topics listed in the selected file.
    topic = topic.merge(corr, left_on='id', right_on='topic_id', how='right')[wanted]

    match_dict = {}
    for language in topic['language'].unique():
        match_dict[language] = (
            topic.loc[topic['language'] == language, ['id']],
            content.loc[content['language'] == language, ['id']],
        )
    return match_dict

In [6]:
def prepare_match_features(topic, content, path):
    """Attach the encoder input text (``field``) to the given topic/content ids.

    The full topic and content CSVs are read from ``path``; topic rows are
    restricted to those returned by ``get_level_features`` and rendered with
    ``get_topic_field``, content rows are rendered with ``get_content_field``.
    Each input frame is then left-joined on ``id``, so ids without a rendered
    text keep a missing ``field``.

    Returns a ``(topic, content)`` pair of DataFrames with columns
    ``id`` and ``field``.
    """
    all_topics = pd.read_csv(path.topic_dir)
    all_contents = pd.read_csv(path.content_dir)

    # Right join keeps only the rows present in the level table.
    title_paths = get_level_features(all_topics)
    all_topics = all_topics.merge(title_paths, on='id', how='right')

    all_topics['field'] = all_topics.apply(get_topic_field, axis=1)
    all_contents['field'] = all_contents.apply(get_content_field, axis=1)

    topic_fields = topic[['id']].merge(all_topics[['id', 'field']], on='id', how='left')
    content_fields = content[['id']].merge(all_contents[['id', 'field']], on='id', how='left')
    return topic_fields, content_fields

In [7]:
%%time
# Build {language: (topic ids, content ids)} from the validation-fold files;
# the later cells read this dict as a notebook-level global.
topic_content_match = prepare_language_match(PATH, mode='valid')

CPU times: user 9 s, sys: 1.08 s, total: 10.1 s
Wall time: 10.1 s


In [8]:
# %%time
# Report how many topics and contents were matched for each language.
for lang, (lang_topics, lang_contents) in topic_content_match.items():
    print(f'{lang}\t - topics: {len(lang_topics)}\t - contents: {len(lang_contents)}')
#     topic, content = topic_content_match[lang]
#     topic, content = prepare_match_features(topic, content, PATH)

en	 - topics: 3272	 - contents: 65939
es	 - topics: 1297	 - contents: 30844
gu	 - topics: 85	 - contents: 3677
pt	 - topics: 43	 - contents: 10435
hi	 - topics: 76	 - contents: 4042
fr	 - topics: 79	 - contents: 10682
bn	 - topics: 83	 - contents: 2513
fil	 - topics: 38	 - contents: 516
sw	 - topics: 17	 - contents: 1447
as	 - topics: 10	 - contents: 641


## Compute embeddings

In [9]:
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Notebook-level aliases read by Convert2Embed and valid() below.
CPT_PATH = PATH.model_cpt  # checkpoint handed to AutoModel.from_pretrained
MODEL_PATH = PATH.model  # directory handed to AutoTokenizer.from_pretrained
OUTPUT_PATH = os.path.join(PATH.cv_dir)  # single-argument join: evaluates to PATH.cv_dir unchanged
N_NEIGHBOR = 100  # k for the nearest-neighbour search in valid(); rebound by later cells

In [10]:
class PlainDataset(Dataset):
    """Dataset that tokenizes one text column of a DataFrame on the fly.

    Parameters
    ----------
    df : pandas.DataFrame
        Source frame.
    tokenizer : callable
        Called as ``tokenizer(text, ...)`` and expected to return a mapping
        of tensors with a leading batch dimension of 1.
    label_name : str
        Name of the text column in ``df`` (the empty default only works if
        such a column exists).
    max_length : int, optional
        Every item is truncated / padded to this many tokens. Defaults to
        128, the value that used to be hard-coded.
    """

    def __init__(self, df, tokenizer, label_name="", max_length=128) -> None:
        super().__init__()
        self.data = df[label_name].tolist()
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, index):
        """Return the tokenizer output for row ``index`` without the batch axis."""
        encoded = self.tokenizer(
            self.data[index],
            add_special_tokens=True,
            truncation='longest_first',
            max_length=self.max_length,
            # Fixed-length padding lets the default collate function stack items.
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' yields (1, max_length) tensors; drop the batch axis.
        return {key: tensor.squeeze(0) for key, tensor in encoded.items()}

    def __len__(self):
        """Number of rows taken from the source column."""
        return len(self.data)

In [11]:
class Convert2Embed(object):
    """Encode the per-language topic/content texts with the fine-tuned checkpoint.

    The tokenizer is loaded from the notebook-level ``MODEL_PATH`` and the
    encoder from ``CPT_PATH``. The encoder runs on CUDA when available and
    on CPU otherwise.
    """

    def __init__(self) -> None:
        # Fall back to CPU instead of failing in .cuda() on hosts without a GPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        self.model = AutoModel.from_pretrained(CPT_PATH).to(self.device)
        # Inference only: make sure dropout is disabled.
        self.model.eval()

    def convert2embeddind(self, df, label_name="", batch_size=32):
        """Return one embedding row per row of ``df``, in the same order.

        Parameters
        ----------
        df : pandas.DataFrame
            Frame holding the texts.
        label_name : str
            Name of the text column.
        batch_size : int, optional
            Inference batch size (defaults to 32, the previous fixed value).

        Returns
        -------
        numpy.ndarray
            The model's ``pooler_output`` for every row, concatenated along
            axis 0.
        """
        dataset = PlainDataset(df, tokenizer=self.tokenizer, label_name=label_name)
        # A sequential sampler keeps embeddings aligned with the rows of df.
        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
        embed = []
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                # Hidden states are not requested: only pooler_output is consumed.
                # NOTE(review): the load log recorded in this notebook lists the
                # pooler weights as newly initialized and the checkpoint's
                # `mlp.dense.*` weights as unused, so pooler_output may come from
                # an untrained layer - confirm this is the intended embedding.
                outputs = self.model(**batch, return_dict=True)
                embed.append(outputs.pooler_output.detach().cpu().numpy())
        return np.concatenate(embed, axis=0)

    def get_embed(self):
        """Embed the ``field`` column of every per-language parquet file.

        For each language in the notebook-level ``topic_content_match`` and
        each of "content" / "topic", reads
        ``<OUTPUT_PATH>/valid/<kind>_<lang>.pqt`` and saves the embeddings
        next to it with a ``.npy`` extension.
        """
        for lang in topic_content_match.keys():
            #
            # topic, content = topic_content_match[lang]
            # topic, content = prepare_match_features(topic, content, PATH)
            # topic_path = os.path.join(OUTPUT_PATH, "valid", f"topic_{lang}.pqt")
            # content_path = os.path.join(OUTPUT_PATH, "valid", f"content_{lang}.pqt")
            # topic.to_parquet(topic_path)
            # content.to_parquet(content_path)
            #

            for t in ["content", "topic"]:
                path = os.path.join(OUTPUT_PATH, "valid", f"{t}_{lang}.pqt")
                df = pd.read_parquet(path)
                embed = self.convert2embeddind(df, label_name="field")
                # Swap only the file extension; str.replace would also touch a
                # ".pqt" occurring earlier in the path.
                np.save(os.path.splitext(path)[0] + ".npy", embed)


In [12]:
def valid():
    """Evaluate nearest-neighbour retrieval per language and print recall / F2.

    For every language in the notebook-level ``topic_content_match`` this
    loads ``content_<lang>.npy`` / ``topic_<lang>.npy`` and the matching
    ``.pqt`` id tables from ``<OUTPUT_PATH>/valid``, retrieves the
    ``N_NEIGHBOR`` closest contents (cosine distance) for each topic and
    scores them against the ground truth in ``PATH.submission_dir``.

    Side effects
    ------------
    Sets the globals ``df_pred`` (frame of the last language processed) and
    ``df_pred_all`` (concatenation over all languages) and prints the
    per-language and overall metrics. Returns ``None``.

    The ``.pqt`` and ``.npy`` files of a language are expected to be
    row-aligned.
    """
    global df_pred, df_pred_all

    # The ground truth is the same file for every language: read and split it
    # once instead of once per loop iteration.
    df_correlations = pd.read_csv(PATH.submission_dir).astype({"topic_id": str})
    df_correlations['content_ids'] = df_correlations['content_ids'].apply(lambda x: x.split())

    recall_amount = 0
    recall_amount_total = 0
    recall_num = 0
    recall_total = {}
    f2_sum = 0
    df_list = []
    for lang in topic_content_match.keys():
        content_path = os.path.join(OUTPUT_PATH, "valid", f"content_{lang}.npy")
        topics_path = os.path.join(OUTPUT_PATH, "valid", f"topic_{lang}.npy")
        content_array = np.load(content_path)
        topics_array = np.load(topics_path)
        df_content = pd.read_parquet(content_path.replace(".npy", ".pqt"))
        df_topics = pd.read_parquet(topics_path.replace(".npy", ".pqt"))

        # kneighbors refuses k larger than the number of fitted samples, so
        # clamp it for languages with few contents.
        n_neighbors = min(N_NEIGHBOR, len(content_array))
        model = NearestNeighbors(n_neighbors=n_neighbors, metric="cosine")
        model.fit(content_array)
        d, r = model.kneighbors(topics_array)

        # Map the (n_topics, k) neighbour positions to content ids in one
        # indexing step instead of one .iloc lookup per neighbour.
        content_ids = df_content["id"].to_numpy()
        df_pred = pd.DataFrame({
            "topic_id": df_topics["id"].to_numpy(),
            "content_ids": content_ids[r].tolist(),
            "dists": list(d),
        }).astype({"topic_id": str})

        df_pred = df_pred.merge(df_correlations, on='topic_id', how='left', suffixes=('_pred', '_true'))
        df_pred['num'] = df_pred['content_ids_true'].map(len)
        df_pred['hit'] = [
            len(set(true_ids).intersection(pred_ids))
            for true_ids, pred_ids in zip(df_pred['content_ids_true'], df_pred['content_ids_pred'])
        ]
        df_pred['recall'] = df_pred['hit'] / df_pred['num']
        df_pred['precision'] = df_pred['hit'] / df_pred['content_ids_pred'].map(len)
        # 0/0 when nothing was hit: count it as F2 = 0, as calculate_f2 does.
        df_pred['f2'] = (
            5 * df_pred['precision'] * df_pred['recall']
            / (4 * df_pred['precision'] + df_pred['recall'])
        ).fillna(0)

        recall = df_pred['recall'].mean()
        recall_total[lang] = recall
        recall_num += len(df_pred)
        recall_amount += df_pred['recall'].sum()
        # Pooled (micro) recall of this language, weighted by its topic count.
        recall_amount_total += df_pred['hit'].sum()/df_pred['num'].sum() * len(df_pred)
        f2_sum += df_pred['f2'].sum()
        df_list.append(df_pred)
        print(f"{lang}: {df_pred['hit'].sum()/df_pred['num'].sum()} - f2: {df_pred['f2'].sum()/len(df_pred)}")

    df_pred_all = pd.concat(df_list)
    print(f"Recall: total - {recall_amount_total/recall_num} line average - {recall_amount/recall_num} - f2 - {f2_sum/recall_num}")
    print("----------------Details----------------")
    for k, v in recall_total.items():
        print(f"Recall for language {k}: {v}")


In [13]:
%%time
# Load tokenizer and checkpoint, then write one .npy embedding file next to
# every per-language .pqt file.
P = Convert2Embed()
P.get_embed()

Some weights of the model checkpoint at /root/autodl-nas/model/f0r3/checkpoint-4976 were not used when initializing BertModel: ['lm_head.transform.LayerNorm.weight', 'lm_head.transform.dense.weight', 'mlp.dense.weight', 'lm_head.transform.LayerNorm.bias', 'lm_head.transform.dense.bias', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.decoder.bias', 'mlp.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at /root/autodl-nas/model/f0r3/checkpoint-4976 and are newly initialized: ['bert.pooler.dense.bias', '

CPU times: user 55.1 s, sys: 1min 12s, total: 2min 7s
Wall time: 2min 6s


In [14]:
def get_pred_by_thresh(df, thresh, n_default=1):
    """Select the predicted content ids whose distance is at most ``thresh``.

    ``df`` is a single prediction record (anything indexable by the keys
    'dists' and 'content_ids_pred', the two sequences being parallel).
    When no distance passes the threshold, the first ``n_default``
    predictions are returned instead.
    """
    candidates = df['content_ids_pred']
    kept = [candidates[pos] for pos, dist in enumerate(df['dists']) if dist <= thresh]
    return kept if kept else candidates[:n_default]

def calculate_f2(df_pred_all, true_label='content_ids_true', pred_label='pred_by_thresh') -> float:
    """Compute the mean per-row F2 score of predicted vs. true id lists.

    Parameters
    ----------
    df_pred_all : pandas.DataFrame
        Must hold list-like columns ``true_label`` and ``pred_label``.
        The columns 'hit', 'recall', 'precision' and 'f2' are (over)written
        in place.
    true_label, pred_label : str
        Names of the ground-truth and prediction columns.

    Returns
    -------
    float
        Mean of the per-row F2 scores. Rows without any hit, or with an empty
        prediction / truth list, score 0.
    """
    true_lists = df_pred_all[true_label]
    pred_lists = df_pred_all[pred_label]
    df_pred_all['hit'] = [
        len(set(pred).intersection(true)) for pred, true in zip(pred_lists, true_lists)
    ]
    # An empty list gives 0/0 = NaN; treat it as a zero score rather than failing.
    df_pred_all['recall'] = (df_pred_all['hit'] / true_lists.map(len)).fillna(0.0)
    df_pred_all['precision'] = (df_pred_all['hit'] / pred_lists.map(len)).fillna(0.0)
    # F-beta with beta = 2; NaN arises when precision and recall are both 0.
    df_pred_all['f2'] = (
        5 * df_pred_all['precision'] * df_pred_all['recall']
        / (4 * df_pred_all['precision'] + df_pred_all['recall'])
    ).fillna(0)
    return float(df_pred_all['f2'].mean())

def optimize_f2(df_pred_all, init_thresh=0.2, init_step=0.01) -> float:
    """Score ``df_pred_all`` at the distance threshold ``init_thresh``.

    Writes the thresholded predictions to the 'pred_by_thresh' column (via
    ``get_pred_by_thresh``) and returns the mean F2 from ``calculate_f2``.

    Parameters
    ----------
    df_pred_all : pandas.DataFrame
        Prediction frame; modified in place.
    init_thresh : float
        Threshold applied to the neighbour distances. It used to be ignored
        in favour of a hard-coded 0.2, which is still the default.
    init_step : float
        Accepted for interface compatibility; no threshold search is
        implemented yet, so it is currently unused.

    Returns
    -------
    float
        Mean F2 at ``init_thresh`` (previously nothing was returned).
    """
    thresh = init_thresh
    df_pred_all['pred_by_thresh'] = df_pred_all.apply(lambda x: get_pred_by_thresh(x, thresh), axis=1)
    return calculate_f2(df_pred_all)

In [15]:
# Inspect the ground-truth file that valid() scores against (topic_id cast to str, as in valid()).
pd.read_csv(PATH.submission_dir).astype({"topic_id": str})

Unnamed: 0,topic_id,content_ids
0,t_0012a45fa09c,c_dde078b8ea7a
1,t_001bcbb22694,c_1d9dfc709413
2,t_00260f878951,c_86f126d7f1f8
3,t_003e944a4758,c_8f6966ad85f6
4,t_00535b89fd1d,c_189799d6b9a3 c_33e862f552f5 c_4f1702d3ffce c...
...,...,...
4995,t_ffba5459a977,c_04a421dba8aa c_787a7a2e7217 c_a46e0ec1377b c...
4996,t_ffc6ba0459d6,c_877d4e87d0e8
4997,t_ffc71a181765,c_d7be46af5e2b c_ea92fdd9899a
4998,t_ffdc013937fc,c_c27c5e711e25


In [16]:
# 275 383

In [17]:
# df_pred_all['pred_by_thresh'] = df_pred_all.apply(lambda x: get_pred_by_thresh(x, 0.4), axis=1)
# print(calculate_f2(df_pred_all))

# df_pred_all = df_pred_all[['topic_id', 'content_ids_pred', 'dists', 'content_ids_true']]
# # df_pred_all['labels'] = df_pred_all.apply(lambda x: label_pred(x), axis=1)
# for i in range(170, 450, 2):
#     thresh = i/1000
#     df_pred_all['pred_by_thresh'] = df_pred_all.apply(lambda x: get_pred_by_thresh(x, thresh, 3), axis=1)
#     f2 = calculate_f2(df_pred_all)
#     print(i, f2)

# df_pred_all.explode(['content_ids_pred', 'dists', 'labels'])

170 0.4024688386835773
172 0.4031549520762678
174 0.40409479083512645
176 0.40498780130113327
178 0.4056221467042033
180 0.40624726560527924
182 0.40721841030008704
184 0.409052690435736
186 0.41030785013923904
188 0.4105493930238
190 0.4121692194096475
192 0.4132842857948191
194 0.4144896801795744
196 0.41615635613829677
198 0.41766006453443766
200 0.41965067582784593
202 0.4202251120409235
204 0.42037800850695833
206 0.4214418561517112
208 0.4235262148368645
210 0.42532777889037154
212 0.42699839812736545
214 0.42848119762586157
216 0.43026228612064904
218 0.43200040637419174
220 0.4332824488732244
222 0.43483586366382354
224 0.4367978856794664
226 0.43781717992134717
228 0.4391767284118734
230 0.44009180614394006
232 0.4417736674332854
234 0.4430857341640051
236 0.44381462956587736
238 0.44525578177706604
240 0.4471523710828244
242 0.44870426476357084
244 0.4497284873353738
246 0.45163057757041214
248 0.4530941896947352
250 0.4548706386758581
252 0.45620181582970554
254 0.45695022797459456
256 0.4576392577793879
258 0.4577774966032547
260 0.45801281237339675
262 0.45868644081494886
264 0.458481439820201
266 0.45903038240513294
268 0.4595054445782289
270 0.4594602890972266
272 0.4594325829137616
274 0.45987696587445226
276 0.4602119501917246
278 0.46098732313376983
280 0.46121830781870066
282 0.46157880182145566
284 0.460916378887469
286 0.46063599518269166
288 0.46015462896328635
290 0.46026819515777456
292 0.46005084723586476
294 0.45913039790317867
296 0.45844901040602476
298 0.45750240076191034
300 0.45692206443176586
302 0.45608328843132245
304 0.4549603767536112
306 0.45329443276876796
308 0.4517214889907147
310 0.45083376456284796
312 0.44961959295741877
314 0.448034045429716
316 0.4461080456862716
318 0.44414396537646245
320 0.4425980612628054
322 0.44092277734608776
324 0.4389668410680323
326 0.4367121022990652
328 0.43420323722183946
330 0.4315968321759618
332 0.4294783204714232
334 0.4271767339470532
336 0.42463422374234044
338 0.42189842384377607
340 0.41922310125359186
342 0.4163151008820663
344 0.4132511312819868
346 0.4100375863146136
348 0.4068978474694664
350 0.4041449159496747
352 0.4010115542362085
354 0.3977875787529075
356 0.3948931385115175
358 0.3913653717224516
360 0.3877346850668936
362 0.38452566672457694
364 0.38094351006016375
366 0.37740241893768
368 0.3736814676972883
370 0.36987613243759815
372 0.36632612945386933
374 0.36227639122395994
376 0.35868013462650117
378 0.35507737565132824
380 0.3512169514101543
382 0.3472301898209479
384 0.343266006151315
386 0.3389849776348872
388 0.33505053612718283
390 0.33077591982865007
392 0.32659240791737504
394 0.3223306480778996
396 0.3180902050053218
398 0.3139647617144388
400 0.30977561550307586
402 0.3056369682307033
404 0.3015421054155209
406 0.29711072715398107
408 0.2929429504219797
410 0.2888552872196978
412 0.2845286862129868
414 0.28014803099027047
416 0.2759176750171542
418 0.2719160264420336
420 0.26765219196053575
422 0.2634692270432641
424 0.2592592306452557
426 0.25531781926012703
428 0.25119991693635824
430 0.24713812191438406
432 0.24296433979588497
434 0.23907443839564776
436 0.23522140710825673
438 0.2312847831007713
440 0.2273921890591881
442 0.22357551167977419
444 0.2195538619615006
​

In [18]:
%%time
# valid() reads the notebook-level N_NEIGHBOR, so rebinding it here sets k for this run.
N_NEIGHBOR = 6
valid()

en: 0.4559003806697206 - f2: 0.4833704268895825
es: 0.35022450288646567 - f2: 0.21913522024447865
gu: 0.2724252491694352 - f2: 0.23763295528001407
pt: 0.44 - f2: 0.42079420714530763
hi: 0.3786127167630058 - f2: 0.3801770600084109
fr: 0.35795454545454547 - f2: 0.40121214750673617
bn: 0.3130841121495327 - f2: 0.2769304323521191
fil: 0.5663716814159292 - f2: 0.5758738277919866
sw: 0.21428571428571427 - f2: 0.23828981097132673
as: 0.13043478260869565 - f2: 0.12142857142857144
Recall: total - 0.41950633154210637 line average - 0.5303243860248875 - f2 - 0.40296444592032865
----------------Details----------------
Recall for language en: 0.6203178383818506
Recall for language es: 0.32840179181736623
Recall for language gu: 0.2917647058823529
Recall for language pt: 0.5617571059431524
Recall for language hi: 0.4580357142857143
Recall for language fr: 0.46225457962241284
Recall for language bn: 0.4523001095290253
Recall for language fil: 0.812155388471178
Recall for language sw: 0.31280255692020

In [19]:
%%time
# Same evaluation with k = 5 neighbours per topic.
N_NEIGHBOR = 5
valid()

en: 0.4160913607329505 - f2: 0.47876447655758103
es: 0.3197562540089801 - f2: 0.21829933678309532
gu: 0.2425249169435216 - f2: 0.22790988065019624
pt: 0.4177777777777778 - f2: 0.42992759881479914
hi: 0.3439306358381503 - f2: 0.3616690479796676
fr: 0.32386363636363635 - f2: 0.3909319033061913
bn: 0.29439252336448596 - f2: 0.2914584729633973
fil: 0.5575221238938053 - f2: 0.6098800372847561
sw: 0.17857142857142858 - f2: 0.21187478274665966
as: 0.13043478260869565 - f2: 0.13247863247863248
Recall: total - 0.38328765880159 line average - 0.49849305861456544 - f2 - 0.39963489080405784
----------------Details----------------
Recall for language en: 0.5842774428277642
Recall for language es: 0.3050552933744915
Recall for language gu: 0.26640522875816997
Recall for language pt: 0.5471576227390181
Recall for language hi: 0.4065737259816207
Recall for language fr: 0.43456987704940053
Recall for language bn: 0.4427710843373494
Recall for language fil: 0.7989974937343358
Recall for language sw: 0.2

In [20]:
%%time
# Same evaluation with k = 10 neighbours per topic.
N_NEIGHBOR = 10
valid() # earlier run (differs slightly from the output below): Recall: total - 0.5210813346754415 line average - 0.6065079786898661 - f2 - 0.3860233037535843

en: 0.5608103748628944 - f2: 0.46272820117882496
es: 0.4448364336112893 - f2: 0.20802799228846244
gu: 0.3488372093023256 - f2: 0.24870841676919606
pt: 0.5555555555555556 - f2: 0.43300800312186266
hi: 0.476878612716763 - f2: 0.38613192168729654
fr: 0.44507575757575757 - f2: 0.4147858252609478
bn: 0.4439252336448598 - f2: 0.28472678330934886
fil: 0.5929203539823009 - f2: 0.4612155794482969
sw: 0.3125 - f2: 0.2741075286584825
as: 0.17391304347826086 - f2: 0.11904761904761904
Recall: total - 0.5206593220636823 line average - 0.6067739233780436 - f2 - 0.3865482930540207
----------------Details----------------
Recall for language en: 0.7038321016058207
Recall for language es: 0.3824609009169534
Recall for language gu: 0.37594771241830066
Recall for language pt: 0.6714285714285715
Recall for language hi: 0.5526733500417711
Recall for language fr: 0.5476154750354304
Recall for language bn: 0.5780941949616649
Recall for language fil: 0.8304577232555072
Recall for language sw: 0.3969528837175896

In [21]:
%%time
# Same evaluation with k = 20 neighbours per topic.
N_NEIGHBOR = 20
valid() # earlier run (differs slightly from the output below): Recall: total - 0.6391622861430047 line average - 0.6950282194812909 - f2 - 0.32047399962118284


en: 0.6777211432995677 - f2: 0.38085379935036406
es: 0.5657472738935215 - f2: 0.17824560676119847
gu: 0.42524916943521596 - f2: 0.21000445632798578
pt: 0.6577777777777778 - f2: 0.3787142947112167
hi: 0.5982658959537572 - f2: 0.3362397030643935
fr: 0.571969696969697 - f2: 0.39182906401298795
bn: 0.6495327102803738 - f2: 0.2591635481319216
fil: 0.6814159292035398 - f2: 0.32920013100934165
sw: 0.5089285714285714 - f2: 0.3084005455328985
as: 0.391304347826087 - f2: 0.15610119047619048
Recall: total - 0.6397464165142329 line average - 0.6983979313213164 - f2 - 0.32176119820923443
----------------Details----------------
Recall for language en: 0.7888184412213224
Recall for language es: 0.48446455775869807
Recall for language gu: 0.4392623716153128
Recall for language pt: 0.7819029162052417
Recall for language hi: 0.6489087301587301
Recall for language fr: 0.6703501328929474
Recall for language bn: 0.7444786137557221
Recall for language fil: 0.8787033372905949
Recall for language sw: 0.569149

In [22]:
%%time
# Same evaluation with k = 50 neighbours per topic.
N_NEIGHBOR = 50 # earlier run (differs slightly from the output below): total - 0.7559275033172171 line average - 0.7874367900873033 - f2 - 0.20525149868153938
valid()

en: 0.7930189044454481 - f2: 0.2409580600065519
es: 0.6706221937139192 - f2: 0.11463885439010225
gu: 0.5249169435215947 - f2: 0.13933127963978473
pt: 0.7511111111111111 - f2: 0.25180457217313895
hi: 0.7427745664739884 - f2: 0.2270816321019149
fr: 0.6988636363636364 - f2: 0.28224314086587227
bn: 0.8598130841121495 - f2: 0.17045192995076275
fil: 0.7964601769911505 - f2: 0.18613522878495656
sw: 0.7232142857142857 - f2: 0.25994204921657704
as: 0.6956521739130435 - f2: 0.13098975345916392
Recall: total - 0.7548025580395137 line average - 0.7866962783619373 - f2 - 0.20525541905547828
----------------Details----------------
Recall for language en: 0.8624029975351576
Recall for language es: 0.6012772964358071
Recall for language gu: 0.5270028011204482
Recall for language pt: 0.8447120708748614
Recall for language hi: 0.7577798663324978
Recall for language fr: 0.7857700484617834
Recall for language bn: 0.9028154576347346
Recall for language fil: 0.9396121883656509
Recall for language sw: 0.8010

In [23]:
%%time
# Same evaluation with k = 100 neighbours per topic.
N_NEIGHBOR = 100
valid() # earlier run (differs slightly from the output below): total - 0.8259401867088655 line average - 0.8492008962245811 - f2 - 0.13077495168822426

en: 0.855539067036583 - f2: 0.1522766517385763
es: 0.7517639512508018 - f2: 0.07455100386142097
gu: 0.5913621262458472 - f2: 0.08951118714790071
pt: 0.8533333333333334 - f2: 0.1722736989288002
hi: 0.8352601156069365 - f2: 0.14956109219285033
fr: 0.7878787878787878 - f2: 0.19300351114150155
bn: 0.9299065420560748 - f2: 0.10320044239721249
fil: 0.7964601769911505 - f2: 0.10352226326069704
sw: 0.875 - f2: 0.19115635721928978
as: 0.9565217391304348 - f2: 0.09890109890109891
Recall: total - 0.8237861852632786 line average - 0.8478493721424663 - f2 - 0.1306620297259688
----------------Details----------------
Recall for language en: 0.9003655397521046
Recall for language es: 0.7206168320822046
Recall for language gu: 0.5988982259570494
Recall for language pt: 0.9080472499077151
Recall for language hi: 0.8248955722639933
Recall for language fr: 0.8420515941960468
Recall for language bn: 0.9498174516246805
Recall for language fil: 0.9396121883656509
Recall for language sw: 0.9061265531853767
Re

In [138]:
%%time
# Same evaluation with k = 200 neighbours per topic.
N_NEIGHBOR = 200
valid()

en: 0.9058003742176914 - f2: 0.09010325178925496
es: 0.8277742142398974 - f2: 0.04475314405750251
gu: 0.6411960132890365 - f2: 0.052337540522346473
pt: 0.9111111111111111 - f2: 0.1035415025895308
hi: 0.9017341040462428 - f2: 0.08964752868743232
fr: 0.8333333333333334 - f2: 0.11687854124602665
bn: 0.9719626168224299 - f2: 0.057759794506208405
fil: 0.9292035398230089 - f2: 0.06261366114090322
sw: 0.9821428571428571 - f2: 0.12363072381358618
as: 0.9565217391304348 - f2: 0.05205115592879189
Recall: total - 0.8815381644259791 line average - 0.8968322742628836 - f2 - 0.07752117522458835
----------------Details----------------
Recall for language en: 0.9345313642413996
Recall for language es: 0.8091058050500656
Recall for language gu: 0.6790756302521008
Recall for language pt: 0.9478036175710594
Recall for language hi: 0.8687238930659984
Recall for language fr: 0.8750671878519979
Recall for language bn: 0.9728915662650602
Recall for language fil: 0.9759002770083102
Recall for language sw: 0.9

shuffle?

groupby lang shuffle

In [139]:

!free -h

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
              total        used        free      shared  buff/cache   available
Mem:          755Gi        68Gi        17Gi       4.0Gi       670Gi       678Gi
Swap:         1.9Gi        37Mi       1.9Gi
