In [1]:
import pandas as pd
import glob
from sklearn.metrics import classification_report, precision_recall_fscore_support

from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

In [2]:
%%time
metrics = {
    'task': [], 'task_lang': [], 'prompt_lang': [], 'prompt': [], 'model': [], 'acc': [],
    'macro_f1': [], 'macro_prec': [], 'macro_rec': [],
    'micro_f1': [], 'micro_prec': [], 'micro_rec': [],
    'weighted_f1': [], 'weighted_prec': [], 'weighted_rec': [],
}
for path in glob.glob('../outputs/*.csv'):
    if '_nusantara_text_' in path:
        task, meta = path.split('/')[-1].split('_nusantara_text_')
        try:
            prompt, model = meta[:-4].split('_')
        except:
            prompt, model, rehearse = meta[:-4].split('_')
            model = f'{model}-{rehearse}'
    elif  '_nusantara_pairs_' in path:
        task, meta = path.split('/')[-1].split('_nusantara_pairs_')
        try:
            prompt, model = meta[:-4].split('_')
        except:
            prompt, model, rehearse = meta[:-4].split('_')
            model = f'{model}-{rehearse}'
    elif 'nusa_kalimat_' in path:
        try:
            task, lang, prompt, model = path.split('/')[-1][:-4].split('nusa_kalimat_')[1].split('_')
        except:
            task, lang, prompt, model, rehearse = path.split('/')[-1][:-4].split('nusa_kalimat_')[1].split('_')
            model = f'{model}-{rehearse}'
        task = f'nusa_kalimat_{task}_{lang}'
    elif 'xnli' in path:
        try:
            task, task_lang, prompt, model = path.split('/')[-1][:-4].split('_')
        except:
            task, task_lang, prompt, model, rehearse = path.split('/')[-1][:-4].split('_')
            model = f'{model}-{rehearse}'
        task = f'{task}_{task_lang}'
    
    df = pd.read_csv(path)
    cls_report = classification_report(df['Gold'], df['Pred'], output_dict=True, zero_division=0)
    micro_f1, micro_prec, micro_rec, _ = precision_recall_fscore_support(df['Gold'], df['Pred'], average='micro')

    metrics['task'].append(task)
    metrics['task_lang'].append(task.split('_')[-1])
    metrics['prompt_lang'].append(prompt[:2])
    metrics['prompt'].append(prompt)
    metrics['model'].append(model)
    metrics['acc'].append(cls_report['accuracy'])
    metrics['macro_f1'].append(cls_report['macro avg']['f1-score'])
    metrics['macro_prec'].append(cls_report['macro avg']['precision'])
    metrics['macro_rec'].append(cls_report['macro avg']['recall'])
    metrics['micro_f1'].append(micro_f1)
    metrics['micro_prec'].append(micro_prec)
    metrics['micro_rec'].append(micro_rec)
    metrics['weighted_f1'].append(cls_report['weighted avg']['f1-score'])
    metrics['weighted_prec'].append(cls_report['weighted avg']['precision'])
    metrics['weighted_rec'].append(cls_report['weighted avg']['recall'])

CPU times: user 1min 36s, sys: 10.1 s, total: 1min 47s
Wall time: 2min 37s


In [3]:
mdf = pd.DataFrame(metrics)
mdf.groupby(['model','prompt_lang']).size()

model                     prompt_lang
bilingual-bloomz-560m     EN             165
                          ID             120
bloomz-1b1                EN             165
                          ID             120
bloomz-1b7                EN             165
                          ID             120
bloomz-3b                 EN             165
                          ID             120
bloomz-560m               EN             165
                          ID             120
checkpoint-19480          EN             114
                          ID             114
checkpoint-29220          EN             114
                          ID             114
checkpoint-38960          EN             114
                          ID             114
checkpoint-9740           EN             114
                          ID             114
monolingual-bloomz-560m   EN             165
                          ID             120
pair-bloomz-1b1           EN             120
                 

In [4]:
mdf.loc[
    (~mdf['task'].str.startswith('xnli')) & ~mdf['task'].isin(['su_emot', 'indonli', 'indolem_sentiment', 'smsa', 'emot', 'imdb_jv'])
,:].groupby(['model','prompt_lang']).size()

model                     prompt_lang
bilingual-bloomz-560m     EN             102
                          ID             102
bloomz-1b1                EN             102
                          ID             102
bloomz-1b7                EN             102
                          ID             102
bloomz-3b                 EN             102
                          ID             102
bloomz-560m               EN             102
                          ID             102
checkpoint-19480          EN             102
                          ID             102
checkpoint-29220          EN             102
                          ID             102
checkpoint-38960          EN             102
                          ID             102
checkpoint-9740           EN             102
                          ID             102
monolingual-bloomz-560m   EN             102
                          ID             102
pair-bloomz-1b1           EN             102
                 

In [5]:
mdf.loc[
    (~mdf['task'].str.startswith('xnli')) & ~mdf['task'].isin(['su_emot', 'indonli', 'indolem_sentiment', 'smsa', 'emot', 'imdb_jv'])
].to_csv('raw_result.csv', index=False)

In [6]:
mdf.loc[mdf['task'].str.startswith('nusa'),:].groupby(['prompt_lang', 'model'])[['acc','macro_f1','micro_f1','weighted_f1']].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,acc,macro_f1,micro_f1,weighted_f1
prompt_lang,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
EN,bilingual-bloomz-560m,0.412686,0.296968,0.412686,0.339234
EN,bloomz-1b1,0.437423,0.3363,0.437423,0.364528
EN,bloomz-1b7,0.479848,0.330984,0.479848,0.380157
EN,bloomz-3b,0.494082,0.355584,0.494082,0.407582
EN,bloomz-560m,0.377802,0.281763,0.377802,0.298149
EN,checkpoint-19480,0.341641,0.226884,0.341641,0.228843
EN,checkpoint-29220,0.35745,0.256799,0.35745,0.26692
EN,checkpoint-38960,0.341225,0.238079,0.341225,0.237882
EN,checkpoint-9740,0.378706,0.267356,0.378706,0.282995
EN,monolingual-bloomz-560m,0.396675,0.286992,0.396675,0.329202


### Related Work
#### Instruction-TUning 

#### Bilingual Alignment

#### Continual Learning

### Method
#### Instruct-Augment
- Token-Level language denoising (TLD) -> monolingual sentence completion
- Token-Level language alignment (TLA) -> contextualized sentence completion
- Sentence-level language alignment (SLA) -> machine translation
- In addition, we explore pair language alignment (PLA) method which combination of both contextualized sentence completion and machine translation, and monolingual-bilingual learning (MBL), which is a combination of all three methods.

#### Continual Learning -> Simple Rehearsal
- R-100, R-1000, R-10000

### Experiment
- 2000 parallel data for learning new languages
- We use instruction-tuned model, BLOOMZ, with only 560M parameters due to limited time and resource.
- We cover 6 languages that is not pre-trained on BLOOMZ, i.e., A, B, C, D, E, F.

### Results & Discussion
#### No Continual Learning
- Among the 3 basic methods (TLD, TLA, and SLA), TLA performs the best over all downstream tasks
- In terms of generalization to other languages, <Can-it-helps?>
- We observe forgetting problems on the other pre-learnt languages, i.e. English and Indonesian, we suspect this happens due to catastrophic forgetting

#### With Continual Learning
- 100 and 10,000 fails to improve the performance, but 1,000 improve the performance consistently on almost all tasks
   - Requires balancing between new tasks and old tasks.
- 

### Future Works
- Scale upp to larger size models
- 

# Analyze model per task group

In [7]:
mdf.loc[mdf['task'].str.startswith('xnli_'),:].groupby(['model'])[['acc','macro_f1','micro_f1','weighted_f1']].mean()

Unnamed: 0_level_0,acc,macro_f1,micro_f1,weighted_f1
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
bilingual-bloomz-560m,0.334527,0.171678,0.334527,0.171678
bloomz-1b1,0.336177,0.178051,0.336177,0.178051
bloomz-1b7,0.354185,0.218703,0.354185,0.218703
bloomz-3b,0.356363,0.221926,0.356363,0.221926
bloomz-560m,0.340306,0.197887,0.340306,0.197887
monolingual-bloomz-560m,0.335338,0.186514,0.335338,0.186514
translation-bloomz-560m,0.334411,0.172372,0.334411,0.172372


In [8]:
mdf.loc[
    (mdf['task'].str.startswith('nusax'))
,:].groupby(['model'])[['acc','macro_f1','micro_f1','weighted_f1']].mean()

Unnamed: 0_level_0,acc,macro_f1,micro_f1,weighted_f1
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
bilingual-bloomz-560m,0.416007,0.3106,0.416007,0.333499
bloomz-1b1,0.513472,0.379785,0.513472,0.432903
bloomz-1b7,0.502708,0.359108,0.502708,0.405421
bloomz-3b,0.566771,0.421318,0.566771,0.480302
bloomz-560m,0.443611,0.314911,0.443611,0.35624
checkpoint-19480,0.403958,0.253971,0.403958,0.278848
checkpoint-29220,0.395556,0.256838,0.395556,0.289581
checkpoint-38960,0.391979,0.262773,0.391979,0.291583
checkpoint-9740,0.414931,0.282982,0.414931,0.312247
monolingual-bloomz-560m,0.368681,0.248469,0.368681,0.276945


In [9]:
mdf.loc[mdf['task'].str.startswith('nusa_kalimat_emot'),:].groupby(['model'])[['acc','macro_f1','micro_f1','weighted_f1']].mean()

Unnamed: 0_level_0,acc,macro_f1,micro_f1,weighted_f1
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
bilingual-bloomz-560m,0.195424,0.085927,0.195424,0.083773
bloomz-1b1,0.236643,0.113682,0.236643,0.115983
bloomz-1b7,0.265237,0.144265,0.265237,0.158464
bloomz-3b,0.272423,0.155869,0.272423,0.168327
bloomz-560m,0.239758,0.148607,0.239758,0.155788
checkpoint-19480,0.225729,0.094887,0.225729,0.104563
checkpoint-29220,0.212524,0.09885,0.212524,0.104133
checkpoint-38960,0.208306,0.108513,0.208306,0.111586
checkpoint-9740,0.233941,0.088864,0.233941,0.099745
monolingual-bloomz-560m,0.186549,0.094371,0.186549,0.090209


In [10]:
mdf.loc[mdf['task'].str.startswith('nusa_kalimat_senti'),:].groupby(['model'])[['acc','macro_f1','micro_f1','weighted_f1']].mean()

Unnamed: 0_level_0,acc,macro_f1,micro_f1,weighted_f1
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
bilingual-bloomz-560m,0.482096,0.418283,0.482096,0.4435
bloomz-1b1,0.574394,0.503596,0.574394,0.544567
bloomz-1b7,0.58954,0.504131,0.58954,0.540276
bloomz-3b,0.646664,0.559592,0.646664,0.626146
bloomz-560m,0.429758,0.420991,0.429758,0.411413
checkpoint-19480,0.34452,0.301603,0.34452,0.239669
checkpoint-29220,0.361694,0.334292,0.361694,0.286725
checkpoint-38960,0.370402,0.350739,0.370402,0.3125
checkpoint-9740,0.397725,0.357907,0.397725,0.323065
monolingual-bloomz-560m,0.47071,0.405242,0.47071,0.42141


In [12]:
mdf.loc[
    (~mdf['task'].str.startswith('xnli')) & (~mdf['task'].str.startswith('nusa')),
:].groupby(['model','prompt_lang'])[['acc','macro_f1','micro_f1','weighted_f1']].mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,acc,macro_f1,micro_f1,weighted_f1
model,prompt_lang,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
bilingual-bloomz-560m,EN,0.398891,0.278782,0.398891,0.306586
bilingual-bloomz-560m,ID,0.364847,0.264173,0.364847,0.273851
bloomz-1b1,EN,0.462512,0.327055,0.462512,0.366678
bloomz-1b1,ID,0.490545,0.373169,0.490545,0.409193
bloomz-1b7,EN,0.455369,0.320702,0.455369,0.361813
bloomz-1b7,ID,0.510823,0.424169,0.510823,0.45468
bloomz-3b,EN,0.464366,0.337186,0.464366,0.376435
bloomz-3b,ID,0.5189,0.431164,0.5189,0.463574
bloomz-560m,EN,0.455505,0.335674,0.455505,0.365215
bloomz-560m,ID,0.452641,0.373856,0.452641,0.402889


### Analyze language per task group

In [14]:
mdf.loc[(mdf['task'].str.startswith('xnli')),:].groupby(
    ['model','prompt_lang','task']
).mean().reset_index().pivot(
    ['model','prompt_lang'],'task', ['acc']
    # ['model','lang'],'task', ['acc', 'macro_f1', 'micro_f1','weighted_f1']
)

Unnamed: 0_level_0,Unnamed: 1_level_0,acc,acc,acc,acc,acc,acc,acc,acc,acc,acc,acc,acc,acc,acc,acc
Unnamed: 0_level_1,task,xnli_ar,xnli_bg,xnli_de,xnli_el,xnli_en,xnli_es,xnli_fr,xnli_hi,xnli_ru,xnli_sw,xnli_th,xnli_tr,xnli_ur,xnli_vi,xnli_zh
model,prompt_lang,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2
bilingual-bloomz-560m,EN,0.333267,0.333333,0.333333,0.333267,0.333333,0.336261,0.333666,0.3332,0.333466,0.341118,0.333333,0.333866,0.3334,0.338523,0.334531
bloomz-1b1,EN,0.336194,0.334464,0.334664,0.335729,0.33992,0.337525,0.337791,0.337591,0.334331,0.336061,0.333599,0.3334,0.335529,0.33992,0.335928
bloomz-1b7,EN,0.360546,0.342182,0.353027,0.343713,0.374983,0.370459,0.368197,0.348237,0.34837,0.352229,0.339055,0.339987,0.341251,0.369661,0.360878
bloomz-3b,EN,0.367066,0.343846,0.350499,0.344844,0.37811,0.378044,0.374917,0.35509,0.352761,0.345842,0.339321,0.340785,0.346906,0.36487,0.362542
bloomz-560m,EN,0.337392,0.334265,0.338257,0.336128,0.347505,0.35489,0.351763,0.337924,0.336394,0.33513,0.339787,0.333533,0.334997,0.345576,0.341051
monolingual-bloomz-560m,EN,0.335263,0.338989,0.335595,0.336128,0.333533,0.333533,0.333666,0.336527,0.333932,0.336128,0.333333,0.337991,0.335529,0.335329,0.334597
translation-bloomz-560m,EN,0.333866,0.333466,0.333599,0.333666,0.342582,0.3334,0.334597,0.335662,0.334597,0.333333,0.333866,0.333333,0.333067,0.333466,0.333666


In [15]:
mdf.loc[(mdf['task'].str.startswith('nusax_senti')),:].groupby(
    ['model','prompt_lang','task']
).mean().reset_index().pivot(
    ['model','prompt_lang'],'task', ['macro_f1']
    # ['model','lang'],'task', ['acc', 'macro_f1', 'micro_f1','weighted_f1']
)

Unnamed: 0_level_0,Unnamed: 1_level_0,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1
Unnamed: 0_level_1,task,nusax_senti_ace,nusax_senti_ban,nusax_senti_bbc,nusax_senti_bjn,nusax_senti_bug,nusax_senti_eng,nusax_senti_ind,nusax_senti_jav,nusax_senti_mad,nusax_senti_min,nusax_senti_nij,nusax_senti_sun
model,prompt_lang,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
bilingual-bloomz-560m,EN,0.315518,0.275554,0.27888,0.303059,0.313309,0.309889,0.340177,0.313045,0.310223,0.315155,0.286747,0.321823
bilingual-bloomz-560m,ID,0.336973,0.31006,0.284049,0.26264,0.339395,0.307028,0.358439,0.287833,0.338471,0.377364,0.304228,0.264546
bloomz-1b1,EN,0.380609,0.378106,0.234341,0.392077,0.273218,0.505247,0.513932,0.408898,0.367496,0.374212,0.357249,0.343619
bloomz-1b1,ID,0.352163,0.353032,0.286701,0.39459,0.305605,0.511367,0.519056,0.420736,0.377147,0.380827,0.358189,0.326425
bloomz-1b7,EN,0.304023,0.305983,0.190856,0.351441,0.188839,0.535469,0.525077,0.374699,0.242941,0.326095,0.253138,0.258653
bloomz-1b7,ID,0.371426,0.365934,0.305545,0.422342,0.294169,0.567076,0.575015,0.374032,0.3525,0.422248,0.36784,0.34325
bloomz-3b,EN,0.428186,0.42456,0.284666,0.485016,0.204782,0.549672,0.554653,0.476144,0.341572,0.448357,0.365394,0.345356
bloomz-3b,ID,0.445181,0.420825,0.34044,0.473319,0.317071,0.538665,0.547843,0.459485,0.385363,0.46111,0.412813,0.401158
bloomz-560m,EN,0.274943,0.284812,0.259356,0.326441,0.244536,0.510651,0.488298,0.315051,0.288302,0.34499,0.334962,0.286018
bloomz-560m,ID,0.267869,0.275729,0.281985,0.279106,0.269705,0.382463,0.407593,0.296512,0.262206,0.288725,0.302909,0.284699


In [17]:
mdf.loc[(mdf['task'].str.startswith('nusa_kalimat_emot')),:].groupby(
    ['model','prompt_lang','task']
).mean().reset_index().pivot(
    ['model','prompt_lang'],'task', ['macro_f1']
    # ['model','lang'],'task', ['acc', 'macro_f1', 'micro_f1','weighted_f1']
)

Unnamed: 0_level_0,Unnamed: 1_level_0,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1
Unnamed: 0_level_1,task,nusa_kalimat_emot_abs,nusa_kalimat_emot_bew,nusa_kalimat_emot_bhp,nusa_kalimat_emot_btk,nusa_kalimat_emot_jav,nusa_kalimat_emot_mad,nusa_kalimat_emot_mak,nusa_kalimat_emot_min,nusa_kalimat_emot_mui,nusa_kalimat_emot_rej,nusa_kalimat_emot_sun
model,prompt_lang,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
bilingual-bloomz-560m,EN,0.084318,0.073395,0.080324,0.080378,0.077748,0.074239,0.07802,0.081006,0.081779,0.082591,0.079572
bilingual-bloomz-560m,ID,0.089178,0.097604,0.086638,0.086288,0.116028,0.092074,0.081392,0.083563,0.084014,0.077163,0.123082
bloomz-1b1,EN,0.069118,0.069421,0.073321,0.069421,0.069421,0.069421,0.069421,0.069421,0.072131,0.079715,0.069421
bloomz-1b1,ID,0.182568,0.151744,0.154125,0.144808,0.13452,0.186581,0.130951,0.134195,0.20111,0.129159,0.171014
bloomz-1b7,EN,0.069118,0.069421,0.073321,0.070018,0.069421,0.069421,0.069421,0.069421,0.072131,0.079715,0.070264
bloomz-1b7,ID,0.243732,0.231712,0.20025,0.206199,0.204042,0.200819,0.1704,0.22742,0.3044,0.191295,0.211891
bloomz-3b,EN,0.072755,0.070554,0.073366,0.069421,0.070018,0.069421,0.069421,0.069957,0.07347,0.079715,0.069421
bloomz-3b,ID,0.275993,0.267299,0.216056,0.240311,0.233591,0.23436,0.185902,0.242028,0.302838,0.193764,0.249458
bloomz-560m,EN,0.076038,0.072969,0.07883,0.072351,0.07359,0.078075,0.071022,0.07174,0.074063,0.089272,0.073781
bloomz-560m,ID,0.233863,0.230298,0.21659,0.207444,0.211685,0.210093,0.181321,0.230795,0.29426,0.195955,0.225312


In [19]:
mdf.loc[(mdf['task'].str.startswith('nusa_kalimat_senti')),:].groupby(
    ['model','prompt_lang','task']
).mean().reset_index().pivot(
    ['model','prompt_lang'],'task', ['macro_f1']
    # ['model','lang'],'task', ['acc', 'macro_f1', 'micro_f1','weighted_f1']
)

Unnamed: 0_level_0,Unnamed: 1_level_0,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1,macro_f1
Unnamed: 0_level_1,task,nusa_kalimat_senti_abs,nusa_kalimat_senti_bew,nusa_kalimat_senti_bhp,nusa_kalimat_senti_btk,nusa_kalimat_senti_jav,nusa_kalimat_senti_mad,nusa_kalimat_senti_mak,nusa_kalimat_senti_min,nusa_kalimat_senti_mui,nusa_kalimat_senti_rej,nusa_kalimat_senti_sun
model,prompt_lang,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
bilingual-bloomz-560m,EN,0.522304,0.51369,0.476884,0.503581,0.501866,0.500476,0.471746,0.512422,0.52964,0.498079,0.509487
bilingual-bloomz-560m,ID,0.352246,0.318484,0.309224,0.328738,0.288265,0.351379,0.30171,0.384188,0.358944,0.38093,0.287937
bloomz-1b1,EN,0.566364,0.626431,0.453768,0.490919,0.582676,0.490652,0.491664,0.651486,0.670936,0.512261,0.587814
bloomz-1b1,ID,0.464774,0.473891,0.422487,0.441964,0.448232,0.417179,0.397657,0.491894,0.525744,0.415787,0.454533
bloomz-1b7,EN,0.617435,0.629431,0.604459,0.576428,0.623538,0.570905,0.534322,0.64397,0.652007,0.54656,0.615527
bloomz-1b7,ID,0.432954,0.440561,0.357676,0.368534,0.396965,0.375659,0.360178,0.4524,0.510935,0.366691,0.413748
bloomz-3b,EN,0.57735,0.6109,0.577276,0.557401,0.594904,0.552941,0.525474,0.634489,0.615064,0.553253,0.594937
bloomz-3b,ID,0.570595,0.581326,0.473651,0.496212,0.541089,0.504376,0.479032,0.610004,0.617576,0.487612,0.555565
bloomz-560m,EN,0.442492,0.471191,0.354347,0.385986,0.437558,0.384482,0.372906,0.498828,0.591102,0.387029,0.463946
bloomz-560m,ID,0.409556,0.417616,0.404772,0.387102,0.393838,0.382162,0.371099,0.423875,0.49492,0.382163,0.404822


### Analyze per task

In [16]:
mdf.groupby(
    ['model','lang','task']
).mean().reset_index().pivot(
    ['model','lang'],'task', ['acc', 'macro_f1', 'micro_f1','weighted_f1']
)

Unnamed: 0_level_0,Unnamed: 1_level_0,acc,acc,acc,acc,acc,acc,acc,acc,acc,acc,...,weighted_f1,weighted_f1,weighted_f1,weighted_f1,weighted_f1,weighted_f1,weighted_f1,weighted_f1,weighted_f1,weighted_f1
Unnamed: 0_level_1,task,emot,imdb_jv,indolem_sentiment,indonli,nusa_kalimat_emot_abs,nusa_kalimat_emot_bew,nusa_kalimat_emot_bhp,nusa_kalimat_emot_btk,nusa_kalimat_emot_jav,nusa_kalimat_emot_mad,...,xnli_es,xnli_fr,xnli_hi,xnli_ru,xnli_sw,xnli_th,xnli_tr,xnli_ur,xnli_vi,xnli_zh
model,lang,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
bilingual-bloomz-560m,EN,0.223485,0.463267,0.643917,0.328381,0.20963,0.206667,0.214815,0.211333,0.206333,0.207667,...,0.178586,0.168004,0.166633,0.167303,0.197756,0.166675,0.168187,0.167943,0.185416,0.173591
bilingual-bloomz-560m,ID,0.186364,0.50004,0.495219,0.333591,0.17037,0.189,0.165185,0.18,0.205,0.185333,...,,,,,,,,,,
bloomz-1b1,EN,0.229545,0.40212,0.768546,0.379317,0.208889,0.21,0.224444,0.21,0.21,0.21,...,0.181538,0.182368,0.183016,0.168904,0.190199,0.168059,0.167555,0.175736,0.199187,0.176093
bloomz-1b1,ID,0.299242,0.461,0.754698,0.345103,0.284444,0.258,0.251852,0.244,0.240333,0.279,...,,,,,,,,,,
bloomz-1b7,EN,0.231818,0.33196,0.760303,0.40993,0.208889,0.21,0.224444,0.210333,0.21,0.21,...,0.246327,0.242461,0.209639,0.20763,0.217091,0.205035,0.195838,0.199565,0.243632,0.227814
bloomz-1b7,ID,0.487879,0.396333,0.636334,0.404528,0.328889,0.330667,0.291111,0.301333,0.305667,0.296333,...,,,,,,,,,,
bloomz-3b,EN,0.247727,0.31712,0.772173,0.412309,0.211111,0.210667,0.224444,0.21,0.210333,0.21,...,0.250262,0.24415,0.220768,0.219688,0.203474,0.207429,0.19912,0.207793,0.237553,0.229303
bloomz-3b,ID,0.411364,0.379587,0.716782,0.464017,0.348148,0.347,0.322222,0.321333,0.327667,0.316333,...,,,,,,,,,,
bloomz-560m,EN,0.227273,0.435507,0.702275,0.368705,0.211852,0.210333,0.224444,0.21,0.210333,0.210667,...,0.236853,0.236757,0.195337,0.179174,0.181047,0.21444,0.173873,0.185542,0.216245,0.192473
bloomz-560m,ID,0.37197,0.4744,0.567755,0.341694,0.280741,0.268667,0.24963,0.246667,0.254667,0.251333,...,,,,,,,,,,


In [17]:
mdf.groupby(
    ['model','lang','task']
).mean().reset_index().pivot(
    ['model','lang'],'task', ['acc', 'macro_f1', 'micro_f1','weighted_f1']
).to_csv('result_zero_shot.csv', index=True)

In [21]:
dfs = []
for path in glob.glob('../metrics/*.csv'):
    lang, model = path[:-4].split('_')[2:4]
    df = pd.read_csv(path)
    df['lang'] = lang[:2]
    df['model'] = model
    dfs.append(df)
pd.concat(dfs).groupby(['model', 'lang']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy,f1_score
model,lang,Unnamed: 2_level_1,Unnamed: 3_level_1
bilingual-bloomz-560m,EN,0.410617,0.29424
bilingual-bloomz-560m,ID,0.326184,0.250878
bloomz-1b1,EN,0.412547,0.292133
bloomz-1b1,ID,0.455926,0.337494
bloomz-1b7,EN,0.442906,0.29924
bloomz-1b7,ID,0.440505,0.354371
bloomz-3b,EN,0.453281,0.317125
bloomz-3b,ID,0.503425,0.408722
bloomz-560m,EN,0.376052,0.264769
bloomz-560m,ID,0.381165,0.318804


# 32-bit

In [19]:
metrics = {
    'task': [], 'lang': [], 'prompt': [], 'model': [], 'acc': [],
    'macro_f1': [], 'macro_prec': [], 'macro_rec': [],
    'micro_f1': [], 'micro_prec': [], 'micro_rec': []
}
for path in glob.glob('outputs_32bit/*.csv'):
    if '_nusantara_text_' in path:
        task, meta = path.split('/')[-1].split('_nusantara_text_')
    else:
        task, meta = path.split('/')[-1].split('_nusantara_pairs_')
    prompt, model = meta[:-4].split('_')
    
    df = pd.read_csv(path)
    cls_report = classification_report(df['Gold'], df['Pred'], output_dict=True)
    micro_f1, micro_prec, micro_rec, _ = precision_recall_fscore_support(df['Gold'], df['Pred'], average='micro')

    metrics['task'].append(task)
    metrics['lang'].append(prompt[:2])
    metrics['prompt'].append(prompt)
    metrics['model'].append(model)
    metrics['acc'].append(cls_report['accuracy'])
    metrics['macro_f1'].append(cls_report['macro avg']['f1-score'])
    metrics['macro_prec'].append(cls_report['macro avg']['precision'])
    metrics['macro_rec'].append(cls_report['macro avg']['recall'])
    metrics['micro_f1'].append(micro_f1)
    metrics['micro_prec'].append(micro_prec)
    metrics['micro_rec'].append(micro_rec)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

In [20]:
pd.DataFrame(metrics).groupby(['model','lang']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,acc,macro_f1,macro_prec,macro_rec,micro_f1,micro_prec,micro_rec
model,lang,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
bloomz-1b1,EN,0.522649,0.379011,0.404832,0.457861,0.522649,0.522649,0.522649
bloomz-1b1,ID,0.528781,0.39276,0.464153,0.466697,0.528781,0.528781,0.528781
bloomz-1b7,EN,0.514377,0.36896,0.419022,0.451026,0.514377,0.514377,0.514377
bloomz-1b7,ID,0.526789,0.411978,0.508149,0.480274,0.526789,0.526789,0.526789
bloomz-560m,EN,0.475819,0.332466,0.379742,0.431091,0.475819,0.475819,0.475819
bloomz-560m,ID,0.443664,0.297815,0.377854,0.402763,0.443664,0.443664,0.443664
mt0-base,EN,0.380048,0.226484,0.218003,0.358409,0.380048,0.380048,0.380048
mt0-base,ID,0.373831,0.231813,0.253617,0.353473,0.373831,0.373831,0.373831
mt0-large,EN,0.426327,0.293602,0.403439,0.397829,0.426327,0.426327,0.426327
mt0-large,ID,0.37983,0.233884,0.230554,0.359814,0.37983,0.37983,0.37983


In [5]:
def get_logprobs(prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to('cuda')
    input_ids, output_ids = inputs["input_ids"], inputs["input_ids"][:, 1:]
    
    outputs = model(**inputs, labels=input_ids)
    logits = outputs.logits
    
    logprobs = torch.gather(F.log_softmax(logits, dim=2), 2, output_ids.unsqueeze(2))
    
    return logprobs.mean()

In [6]:
model = AutoModelForSeq2SeqLM.from_pretrained('bigscience/mt0-base')
tokenizer = AutoTokenizer.from_pretrained('bigscience/mt0-base')

In [None]:
inputs = tokenizer('The emotion of "I am really glad that you are here. Thank you!" is', return_tensors='pt')
input_ids, output_ids = inputs["input_ids"], inputs["input_ids"][:, 1:]
outputs = model(**inputs, labels=input_ids)
x1 = F.log_softmax(outputs.logits, dim=2)
y1 = torch.gather(F.log_softmax(outputs.logits, dim=2), 2, output_ids.unsqueeze(2)).mean()

In [None]:
model = AutoModelForCausalLM.from_pretrained('bigscience/bloomz-560m')
tokenizer = AutoTokenizer.from_pretrained('bigscience/mt0-base')

In [43]:
pd.concat(dfs).groupby(
    ['model','lang','index']
).mean().reset_index().pivot(
    ['model','lang'], 'index', ['accuracy', 'f1_score']
).to_csv('result_zero_shot.csv', index=True)

# Final Project

In [7]:
fp_df = mdf.loc[
    (~mdf['task'].str.startswith('xnli')) & ~mdf['task'].isin(['su_emot', 'indonli', 'indolem_sentiment', 'smsa', 'emot', 'imdb_jv'])
]

In [12]:
fp_df.sort_values(['task', 'model']).head(30)

Unnamed: 0,task,task_lang,prompt_lang,prompt,model,acc,macro_f1,macro_prec,macro_rec,micro_f1,micro_prec,micro_rec,weighted_f1,weighted_prec,weighted_rec
526,nusa_kalimat_emot_abs,abs,ID,ID3,bilingual-bloomz-560m,0.157778,0.083115,0.07977,0.205731,0.157778,0.157778,0.157778,0.076699,0.084122,0.157778
671,nusa_kalimat_emot_abs,abs,EN,EN,bilingual-bloomz-560m,0.208889,0.103098,0.152474,0.208018,0.208889,0.208889,0.208889,0.097275,0.158016,0.208889
1193,nusa_kalimat_emot_abs,abs,ID,ID,bilingual-bloomz-560m,0.146667,0.062306,0.100973,0.203905,0.146667,0.146667,0.146667,0.050782,0.111085,0.146667
1356,nusa_kalimat_emot_abs,abs,EN,EN2,bilingual-bloomz-560m,0.211111,0.080739,0.108559,0.204222,0.211111,0.211111,0.211111,0.080334,0.09042,0.211111
2107,nusa_kalimat_emot_abs,abs,EN,EN3,bilingual-bloomz-560m,0.208889,0.069118,0.041778,0.2,0.208889,0.208889,0.208889,0.07219,0.043635,0.208889
2260,nusa_kalimat_emot_abs,abs,ID,ID2,bilingual-bloomz-560m,0.206667,0.122115,0.095712,0.236241,0.206667,0.206667,0.206667,0.123576,0.102352,0.206667
0,nusa_kalimat_emot_abs,abs,EN,EN2,bloomz-1b1,0.208889,0.069118,0.041778,0.2,0.208889,0.208889,0.208889,0.07219,0.043635,0.208889
1611,nusa_kalimat_emot_abs,abs,ID,ID2,bloomz-1b1,0.233333,0.148112,0.113794,0.250385,0.233333,0.233333,0.233333,0.130836,0.098166,0.233333
1755,nusa_kalimat_emot_abs,abs,ID,ID3,bloomz-1b1,0.32,0.218037,0.413753,0.290312,0.32,0.32,0.32,0.222516,0.454712,0.32
2034,nusa_kalimat_emot_abs,abs,EN,EN,bloomz-1b1,0.208889,0.069118,0.041778,0.2,0.208889,0.208889,0.208889,0.07219,0.043635,0.208889


In [13]:
fp_df.groupby(['task', 'model']).size()

task                   model                   
nusa_kalimat_emot_abs  bilingual-bloomz-560m       6
                       bloomz-1b1                  6
                       bloomz-1b7                  6
                       bloomz-3b                   6
                       bloomz-560m                 6
                                                  ..
nusax_senti_sun        pair-bloomz-560m-R-100      6
                       pair-bloomz-560m-R-1000     6
                       pair-bloomz-560m-R-10000    6
                       random-bloomz-560m          6
                       translation-bloomz-560m     6
Length: 544, dtype: int64

### Per Task Result

### Catastrophic Forgetting

In [45]:
cf_models = ['pair-bloomz-560m','checkpoint-9740','checkpoint-19480', 'checkpoint-29220', 'checkpoint-38960']

In [46]:
cf_df = fp_df.loc[fp_df['model'].isin(cf_models)]

In [47]:
cf_df

Unnamed: 0,task,lang,prompt,model,acc,macro_f1,macro_prec,macro_rec,micro_f1,micro_prec,micro_rec,weighted_f1,weighted_prec,weighted_rec
1,nusax_senti_ace,ID,ID3,checkpoint-29220,0.337500,0.225947,0.221954,0.297321,0.337500,0.337500,0.337500,0.256942,0.252985,0.337500
5,nusa_kalimat_emot_mui,ID,ID3,checkpoint-19480,0.177778,0.102929,0.077812,0.200319,0.177778,0.177778,0.177778,0.102981,0.081851,0.177778
11,nusa_kalimat_senti_jav,ID,ID2,checkpoint-9740,0.342500,0.318982,0.509669,0.504498,0.342500,0.342500,0.342500,0.266883,0.597522,0.342500
19,nusa_kalimat_emot_mad,ID,ID,checkpoint-38960,0.240000,0.218514,0.242729,0.250257,0.240000,0.240000,0.240000,0.217395,0.250623,0.240000
21,nusa_kalimat_emot_btk,EN,EN2,checkpoint-19480,0.209000,0.069148,0.041842,0.199048,0.209000,0.209000,0.209000,0.072605,0.043934,0.209000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4045,nusa_kalimat_emot_bew,EN,EN2,checkpoint-29220,0.210000,0.069421,0.042000,0.200000,0.210000,0.210000,0.210000,0.072893,0.044100,0.210000
4049,nusa_kalimat_senti_btk,EN,EN,checkpoint-29220,0.419167,0.418872,0.516633,0.515850,0.419167,0.419167,0.419167,0.413490,0.604494,0.419167
4057,nusa_kalimat_senti_bhp,EN,EN,checkpoint-9740,0.534000,0.483029,0.492370,0.490873,0.534000,0.534000,0.534000,0.554453,0.589801,0.534000
4059,nusa_kalimat_senti_rej,EN,EN2,checkpoint-38960,0.446000,0.445501,0.510793,0.511339,0.446000,0.446000,0.446000,0.451955,0.587439,0.446000
