## Imports

In [None]:
from utils.load_data import load_llama_factory_data
from typing import Optional
import json
from functools import cache
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import seaborn as sns
import numpy as np
from utils.model_registry import MODELS

# Memoize the dataset loader for the rest of the notebook: after the first call
# for a given test-set name, later calls with the same arguments return the
# object produced by that first call instead of loading it again. All callers
# therefore share one object per name and should not mutate it.
load_llama_factory_data = cache(load_llama_factory_data)

In [None]:
# Lower-case two-letter codes of the ten evaluation languages, in table/plot order.
all_language = 'en de fr it ar he ru pl ja zh'.split()


def get_file_path(
    test_data_name,
    model_name,
    train_data_name,
    train_epoch,
    train_lr='2e-4',
    train_dropout='0',
    train_type='lora-all',
    template=None,
    is_base_model=False,
) -> str:
    """Return the path of a saved evaluation-output JSON file.

    The file name is ``{test_data_name}_{run}_{template}.json`` under
    ``eval/outputs``. ``run`` is just ``model_name`` for a base model. For a
    fine-tuned run it is ``model_name`` followed by the training-data name and
    the lr / dropout / epoch / adapter-type tags.

    When ``template`` is omitted it is taken from the model registry
    (``MODELS[model_name].default_template``) with underscores turned into
    hyphens.
    """
    if template is None:
        template = MODELS[model_name].default_template.replace('_', '-')

    if is_base_model:
        run_name = model_name
    else:
        # The lr tag is glued directly onto the training-data name (no separator).
        run_name = (
            f'{model_name}-{train_data_name}{train_lr}'
            f'-{train_dropout}-{train_epoch}-{train_type}'
        )
    return f'eval/outputs/{test_data_name}_{run_name}_{template}.json'


def load_data(
    task = 'arithmetic',
    lang = 'EN',
    model_name = 'llama-2-7b-chat',
    train_lang = 'EN',
    train_epoch = '1',
    is_base_model = False,
):
    test_data_name = f'kfrd_{task}_{lang}_test'
    train_data_name = f'kfrd-{task}-{train_lang}-train'
    
    train_lr = '2e-4'
    train_dropout = '0'
    train_type = 'lora-all'
    template = MODELS[model_name].default_template.replace('_', '-')

    file_path = get_file_path(
        test_data_name=test_data_name,
        model_name=model_name,
        train_data_name=train_data_name,
        train_lr=train_lr,
        train_dropout=train_dropout,
        train_epoch=train_epoch,
        train_type=train_type,
        template=template,
        is_base_model=is_base_model,
    )
    
    v = [t[1] for t in json.load(open(file_path))]
    assert all(y in 'ABCD' for y in v)
    return v


def load_test_data(
    task='arithmetic',
    lang='EN',
):
    """Return the KFRD test split named ``kfrd_{task}_{lang}_test``.

    Loaded through ``load_llama_factory_data``, which this notebook wraps in
    ``functools.cache``, so repeated requests for one split share an object.
    """
    return load_llama_factory_data(f'kfrd_{task}_{lang}_test')


def compute_accuracy(
    task='arithmetic',
    lang='EN',
    model_name='llama-2-7b-chat',
    train_lang='EN',
    train_epoch='1',
    return_corrects=False,
    is_base_model=False,
):
    """Score a model's saved predictions on one KFRD test set.

    Args:
        task: KFRD task name (e.g. 'arithmetic', 'symbolic', 'logical').
        lang: Language code as used in the test-set name (callers pass
            upper case, e.g. 'EN').
        model_name, train_lang, train_epoch, is_base_model: Select the
            prediction file; forwarded to ``load_data``.
        return_corrects: Also return the per-item 0/1 correctness list.

    Returns:
        ``accuracy``, or ``(accuracy, corrects)`` when ``return_corrects`` is
        true. Accuracy is the number of predictions equal to the reference
        ``'response'`` divided by the size of the test set.
    """
    references = load_test_data(task=task, lang=lang)
    predictions = load_data(
        task=task,
        lang=lang,
        model_name=model_name,
        train_lang=train_lang,
        train_epoch=train_epoch,
        is_base_model=is_base_model,
    )

    # NOTE(review): zip stops at the shorter sequence while the denominator is
    # the full test set, so missing predictions silently count as wrong.
    corrects = [int(pred == ref['response']) for pred, ref in zip(predictions, references)]
    accuracy = sum(corrects) / len(references)
    return (accuracy, corrects) if return_corrects else accuracy


def evaluate_sqa_accuracy(
    model_name: str, lang: str,
    facts=False, return_corrects=False, is_base_model=False,
) -> 'float | tuple[float, list[int]]':
    """Score a model's saved StrategyQA dev-set predictions in one language.

    Args:
        model_name: Model key used in the output-file name (and looked up in
            ``MODELS`` for the default template).
        lang: Language code; upper-cased before use.
        facts: Prompt variant to score: ``False`` means no facts, 'one' means
            one fact, 'two' means two facts, and any other truthy value means
            all facts.
        return_corrects: Also return the per-item 0/1 correctness list.
        is_base_model: Score the base model's outputs instead of the
            ``train_epoch='4'`` fine-tuned run.

    Returns:
        The accuracy over the dev set, or ``(accuracy, corrects)`` when
        ``return_corrects`` is true. It never returns ``None``.
    """
    lang = lang.upper()

    if facts == 'one':
        test_data_name = f'sqa_one_fact_dev_{lang}'
        train_data_name = 'sqa-one-fact-train'
    elif facts == 'two':
        test_data_name = f'sqa_two_fact_dev_{lang}'
        train_data_name = 'sqa-two-fact-train'
    elif facts:
        test_data_name = f'sqa_facts_dev_{lang}'
        train_data_name = 'sqa-facts-train'
    else:
        test_data_name = f'sqa_dev_{lang}'
        train_data_name = 'sqa-train'

    path = get_file_path(
        test_data_name=test_data_name,
        model_name=model_name,
        train_data_name=train_data_name,
        train_epoch='4',
        is_base_model=is_base_model,
    )
    testset = load_llama_factory_data(test_data_name)
    # Decode as UTF-8 (the JSON standard encoding) rather than the platform
    # locale, since outputs for ar/he/ru/ja/zh may contain non-ASCII text.
    # The `with` block also closes the handle instead of leaking it.
    with open(path, encoding='utf-8') as f:
        hypset = json.load(f)

    # hyp[1] is the predicted answer, compared against the reference 'output'.
    corrects = [int(hyp[1] == ref['output']) for hyp, ref in zip(hypset, testset)]
    accuracy = sum(corrects) / len(testset)
    if return_corrects:
        return accuracy, corrects
    return accuracy

## Accuracy Before Finetuning

In [None]:
# Base (not fine-tuned) LLaMA 2 accuracy on StrategyQA: one column per
# facts setting, one row per evaluation language.
facts_map = {False: 'NF', 'one': 'WF-1', 'two': 'WF-2', True: 'WF-all'}

d = {}
for facts, setting in facts_map.items():
    d[setting] = {
        lang: evaluate_sqa_accuracy('llama-2-7b-chat', lang, facts=facts, is_base_model=True)
        for lang in all_language
    }

base_accuracy_sqa_llama2_df = pd.DataFrame(d)
# display(base_accuracy_sqa_llama2_df.T.mul(100).round(2))
# display(base_accuracy_sqa_llama2_df.T.mul(100).round(2))

In [None]:
display_name = {
    'llama-2-7b-chat': 'LLaMA 2',
    'qwen1.5-7b-chat': 'Qwen 1.5',
    'bloomz-7b1-mt': 'BLOOMZ',
    'mistral-7b-instruct-v0.1': 'Mistral',
}

# Base-model (no fine-tuning) accuracy for every benchmark x model x language.
# Each entry maps a benchmark label to a scorer(model, lang) -> accuracy; the
# resulting columns are a (benchmark, model display name) MultiIndex.
_base_benchmarks = (
    ('StrategyQA NF', lambda model, lang: evaluate_sqa_accuracy(model, lang, facts=False, is_base_model=True)),
    ('StrategyQA WF', lambda model, lang: evaluate_sqa_accuracy(model, lang, facts=True, is_base_model=True)),
    ('Arithmetic', lambda model, lang: compute_accuracy('arithmetic', lang.upper(), model_name=model, is_base_model=True)),
    ('Symbolic', lambda model, lang: compute_accuracy('symbolic', lang.upper(), model_name=model, is_base_model=True)),
    ('Logical', lambda model, lang: compute_accuracy('logical', lang.upper(), model_name=model, is_base_model=True)),
)

d = {}
for benchmark, score in _base_benchmarks:
    for model_name, shown in display_name.items():
        d[benchmark, shown] = {lang: score(model_name, lang) for lang in all_language}

base_accuracy_df = pd.DataFrame(d)
# display(base_accuracy_df.T.mul(100).round(2))
# display(base_accuracy_df.T.mul(100).round(2))

In [None]:
def plot_radar(ax: plt.Axes, df: pd.DataFrame, linestyle: str, alpha: float = 1.0):
    """Draw one closed radar trace per column of ``df`` on the polar axes ``ax``.

    ``df`` holds fractions (rows are spokes, columns are traces). They are
    plotted as percentages on a fixed 0-100 radial scale, and the first spoke
    label is upper-cased. Trace colours come, in order, from the notebook-level
    ``colors`` list. Because ``zip`` stops at the shorter sequence, columns
    beyond ``len(colors)`` are not drawn.
    """
    pct = df * 100
    spoke_names = list(pct.index)
    spoke_names[0] = spoke_names[0].upper()

    # Evenly spaced spokes starting at angle pi/2. The first angle is repeated
    # so every trace closes, then all angles are wrapped back into [0, 2*pi).
    offset = np.pi / 2
    theta = np.linspace(offset, 2 * np.pi + offset, len(spoke_names), endpoint=False).tolist()
    theta = np.array(theta + theta[:1])
    theta[theta >= 2 * np.pi] -= 2 * np.pi

    for name, color in zip(pct.columns, colors):
        radii = pct[name].tolist()
        ax.plot(
            theta, radii + radii[:1],
            label=name, c=to_rgba(color, alpha=alpha), linestyle=linestyle,
            linewidth=2, marker='.',
        )

    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_xticks(theta[:-1])
    ax.set_xticklabels(spoke_names)
    

# Plot settings shared by the radar cells below.
language = all_language
display_name = {
    'llama-2-7b-chat': 'LLaMA 2',
    'qwen1.5-7b-chat': 'Qwen 1.5',
    'bloomz-7b1-mt': 'BLOOMZ',
    'mistral-7b-instruct-v0.1': 'Mistral',
}
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']

# Base-model StrategyQA accuracy per language: solid = WF, dashed = NF.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), subplot_kw={'polar': True})
ax: plt.Axes
for setting, linestyle in (('StrategyQA WF', '-'), ('StrategyQA NF', '--')):
    plot_radar(ax, base_accuracy_df[setting], linestyle, alpha=1.0)
# Label-only legend: entries are matched to traces in drawing order (the WF lines).
fig.legend(list(display_name.values()), ncol=2, loc='upper center', bbox_to_anchor=(0.5, 0.0))
# fig.tight_layout()
fig.show()
# fig.savefig('pic/sqa-acc-init.pdf', bbox_inches='tight')

## StrategyQA

In [None]:
def plot_radar(
    ax: plt.Axes, df: pd.DataFrame,
    alpha: float, colors: list[str],
    linestyle: str = '-',
):
    """Draw one closed radar trace per column of ``df`` on the polar axes ``ax``.

    ``df`` holds fractions (rows are spokes, columns are traces) and is plotted
    as percentages on a fixed 0-100 radial scale. The first spoke label is
    upper-cased. Column ``i`` is drawn in ``colors[i]`` at opacity ``alpha``
    with the given ``linestyle``.

    Raises:
        AssertionError: if ``colors`` has fewer entries than ``df`` has columns.
    """
    pct = df * 100

    spoke_names = list(pct.index)
    spoke_names[0] = spoke_names[0].upper()

    # Evenly spaced spokes starting at angle pi/2. The first angle is repeated
    # to close each trace, and angles are wrapped back into [0, 2*pi).
    offset = np.pi / 2
    theta = np.linspace(offset, 2 * np.pi + offset, len(spoke_names), endpoint=False).tolist()
    theta = np.array(theta + theta[:1])
    theta[theta >= 2 * np.pi] -= 2 * np.pi

    assert len(colors) >= len(pct.columns)
    for name, color in zip(pct.columns, colors):
        radii = pct[name].tolist()
        ax.plot(
            theta, radii + radii[:1],
            label=name, c=to_rgba(color, alpha=alpha), linestyle=linestyle,
            linewidth=2, marker='.',
        )

    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_xticks(theta[:-1])
    ax.set_xticklabels(spoke_names)

### All Models

In [None]:
def create_data_sqa(facts=False):
    """Build the fine-tuned StrategyQA table for the models in ``display_name``.

    Rows are the languages in the notebook-level ``language``; columns are the
    models' display names. The notebook-level ``print_xltr`` flag selects the
    cell values:

    * falsy: the model's accuracy in that language;
    * truthy: the share of items answered correctly in English that are also
      correct in that language, rescaled as ``(x - 0.5) * 2`` and clipped at
      0, so 0.5 maps to 0 and 1 maps to 1.
    """
    pick = int(print_xltr)  # index 0 = accuracy, 1 = per-item 0/1 corrects
    table = {}
    for model_name, shown in display_name.items():
        table[shown] = [
            evaluate_sqa_accuracy(
                model_name=model_name,
                lang=lang,
                facts=facts,
                return_corrects=True,
            )[pick]
            for lang in language
        ]
    table['Language'] = language
    df = pd.DataFrame(table).set_index("Language")
    if not print_xltr:
        return df

    def retained(column):
        # Fraction of English-correct items also correct in each language.
        en = column['en']
        en_total = sum(en)
        return pd.Series(
            [sum(a & b for a, b in zip(en, column[lang])) / en_total for lang in column.index],
            index=column.index,
        )

    df = df.apply(retained, axis=0)
    # 0.5 presumably reflects StrategyQA's binary answer space (chance level).
    return (df - 0.5).clip(lower=0.0) * 2
    
    
# Models compared in the all-models StrategyQA radar charts.
language = all_language
display_name = {
    'llama-2-7b-chat': 'LLaMA 2',
    'qwen1.5-7b-chat': 'Qwen 1.5',
    'bloomz-7b1-mt': 'BLOOMZ',
    'mistral-7b-instruct-v0.1': 'Mistral',
}
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']

def draw():
    """Radar chart of fine-tuned StrategyQA results: solid = WF, dashed = NF."""
    fig, ax = plt.subplots(1, 1, figsize=(4, 4), subplot_kw={'polar': True})
    for facts, linestyle in ((True, '-'), (False, '--')):
        plot_radar(ax, create_data_sqa(facts=facts), 1.0, colors, linestyle)
    fig.legend(list(display_name.values()), loc='upper center', ncol=2, bbox_to_anchor=(0.5, 0))
    # fig.tight_layout()
    fig.show()
    # output_name = 'sqa-xltr.pdf' if print_xltr else 'sqa-acc.pdf'
    # fig.savefig(os.path.join('pic', output_name), bbox_inches='tight')


# First the XLTR view, then plain accuracy (draw() reads the global flag).
for print_xltr in (True, False):
    draw()

### LLaMA 2 

In [None]:
# LLaMA 2 only: compare the four facts settings against each other.
display_name = {
    'llama-2-7b-chat': 'LLaMA 2',
}

def draw():
    """Radar chart of LLaMA 2 on StrategyQA under each facts setting.

    In accuracy mode (``print_xltr`` falsy) the base-model scores are overlaid
    as faded traces.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4), subplot_kw={'polar': True})
    # create_data_sqa yields a single (LLaMA 2) column here; colour by setting.
    for i, facts in enumerate((False, 'one', 'two', True)):
        plot_radar(ax, create_data_sqa(facts=facts), 1.0, [colors[i]], '-')
    if not print_xltr:
        plot_radar(ax, base_accuracy_sqa_llama2_df, 0.3, colors, '-')
    fig.legend(['NF', 'WF-1', 'WF-2', 'WF-all'], loc='upper center', ncol=2, bbox_to_anchor=(0.5, 0))
    # fig.tight_layout()
    fig.show()
    # output_name = 'sqa-llama2-xltr.pdf' if print_xltr else 'sqa-llama2-acc.pdf'
    # fig.savefig(os.path.join('pic', output_name), bbox_inches='tight')

# First the XLTR view, then plain accuracy (draw() reads the global flag).
for print_xltr in (True, False):
    draw()

## KFRD Main

In [None]:
def create_data(
    task='arithmetic',
    train_epoch='1',
):
    """Build the fine-tuned KFRD table for ``task`` over the models in ``display_name``.

    Rows are the languages in the notebook-level ``language``; columns are the
    models' display names. The notebook-level ``print_xltr`` flag selects the
    cell values:

    * falsy: the model's accuracy in that language;
    * truthy: the share of English-correct items also correct in that
      language, chance-corrected as ``(x - 0.25) / 0.75`` and clipped at 0.
    """
    pick = int(print_xltr)  # index 0 = accuracy, 1 = per-item 0/1 corrects
    table = {}
    for model_name, shown in display_name.items():
        table[shown] = [
            compute_accuracy(
                lang=lang.upper(),
                model_name=model_name,
                task=task,
                train_epoch=train_epoch,
                return_corrects=True,
            )[pick]
            for lang in language
        ]
    table['Language'] = language
    df = pd.DataFrame(table).set_index("Language")
    if not print_xltr:
        return df

    def retained(column):
        # Fraction of English-correct items also correct in each language.
        en = column['en']
        en_total = sum(en)
        return pd.Series(
            [sum(a & b for a, b in zip(en, column[lang])) / en_total for lang in column.index],
            index=column.index,
        )

    df = df.apply(retained, axis=0)
    # 0.25 presumably is chance level for the four answer options A-D.
    return (df - 0.25).clip(lower=0.0) / (1 - 0.25)
    

def plot_radar(ax, df, title):
    """Radar chart of fine-tuned KFRD results for one task, titled ``title``.

    ``df`` holds fractions (rows are languages, columns are models) and is
    plotted as percentages on a fixed 0-100 radial scale; the first spoke label
    is upper-cased. Colours come from the notebook-level ``colors`` list.

    When the notebook-level ``print_xltr`` is falsy, the matching base-model
    accuracies from ``base_accuracy_df`` are overlaid as faded traces. The
    benchmark is chosen by the first word of ``title``
    (e.g. "Arithmetic Reasoning" selects "Arithmetic").

    ``df`` itself is not modified.
    """
    # Scale a copy: `df *= 100` would rescale the caller's DataFrame in place.
    df = df * 100

    labels = list(df.index)
    labels[0] = labels[0].upper()
    num_vars = len(labels)

    # Evenly spaced spokes starting at angle pi/2. The first angle is repeated
    # to close each trace, and angles are wrapped back into [0, 2*pi).
    delta = np.pi / 2
    angles = np.linspace(0 + delta, 2 * np.pi + delta, num_vars, endpoint=False).tolist()
    angles += angles[:1]
    angles = np.array(angles)
    angles[angles >= 2 * np.pi] -= 2 * np.pi

    for model, c in zip(df.columns, colors):
        values = df[model].tolist()
        values += values[:1]
        ax.plot(
            angles, values, label=model,
            c=c, linewidth=2, marker='.',
        )

    if not print_xltr:
        # base_accuracy_df has (benchmark, model) columns; select this task's block once.
        base = base_accuracy_df[title.split()[0]] * 100
        for model, c in zip(base.columns, colors):
            values = base[model].tolist()
            values += values[:1]
            ax.plot(
                angles, values, label=model,
                c=to_rgba(c, alpha=0.4), linewidth=2, marker='.',
            )

    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels)

    ax.set_title(title)


# Models compared in the main KFRD radar charts.
language = all_language
display_name = {
    'llama-2-7b-chat': 'LLaMA 2',
    'qwen1.5-7b-chat': 'Qwen 1.5',
    'bloomz-7b1-mt': 'BLOOMZ',
    'mistral-7b-instruct-v0.1': 'Mistral',
}
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']


def draw():
    """One radar panel per KFRD task; arithmetic uses the train_epoch='4' runs."""
    fig, axs = plt.subplots(1, 3, figsize=(12, 4), squeeze=False, subplot_kw={'polar': True})
    panels = (
        ('arithmetic', '4', "Arithmetic Reasoning"),
        ('symbolic', '1', "Symbolic Reasoning"),
        ('logical', '1', "Logical Reasoning"),
    )
    for ax, (task, epoch, title) in zip(axs[0], panels):
        plot_radar(ax, create_data(task, train_epoch=epoch), title)
    # fig.tight_layout()
    fig.legend(list(display_name.values()), ncol=4, loc='lower center', bbox_to_anchor=(0.5, 0.0))
    fig.show()
    # output_name = 'main-xltr.pdf' if print_xltr else 'main-acc.pdf'
    # fig.savefig(os.path.join('pic', output_name), bbox_inches='tight')

# First the XLTR view, then plain accuracy (draw() reads the global flag).
for print_xltr in (True, False):
    draw()

## KFRD Different Train Language

In [None]:
def create_data(task: str) -> pd.DataFrame:
    """Per-language results of LLaMA 2 fine-tuned on each language in ``train_langs``.

    Rows are the evaluation languages (``all_language``); columns are the
    training languages from the notebook-level ``train_langs``. Arithmetic uses
    the ``train_epoch='4'`` runs and the other tasks use ``'1'``. With the
    notebook-level ``print_xltr`` falsy, cells are accuracies. Otherwise each
    cell is the share of items answered correctly in English that are also
    correct in that language.
    """
    language = all_language
    epoch = '4' if task == 'arithmetic' else '1'
    pick = int(print_xltr)  # index 0 = accuracy, 1 = per-item 0/1 corrects

    table = {
        train_lang: [
            compute_accuracy(
                lang=lang.upper(),
                model_name='llama-2-7b-chat',
                train_lang=train_lang,
                task=task,
                train_epoch=epoch,
                return_corrects=True,
            )[pick]
            for lang in language
        ]
        for train_lang in train_langs
    }
    table['Language'] = language
    df = pd.DataFrame(table).set_index("Language")
    if not print_xltr:
        return df

    def retained(column):
        en = column['en']
        en_total = sum(en)
        return pd.Series(
            [sum(a & b for a, b in zip(en, column[lang])) / en_total for lang in column.index],
            index=column.index,
        )

    # NOTE(review): unlike the KFRD-main create_data, which rescales by the 0.25
    # chance level, these XLTR values are not chance-corrected. Confirm this is
    # intended for the training-language figure.
    return df.apply(retained, axis=0)


def plot_radar(ax: plt.Axes, df: pd.DataFrame, title):
    """Draw one closed radar trace per column of ``df`` and title the axes.

    ``df`` holds fractions (rows are spokes, columns are traces) plotted as
    percentages on a fixed 0-100 radial scale. Spoke labels are used as-is,
    and trace colours come from the axes' colour cycle.
    """
    pct = df * 100
    spoke_names = list(pct.index)

    # Evenly spaced spokes starting at angle pi/2. The first angle is repeated
    # to close each trace, and angles are wrapped back into [0, 2*pi).
    offset = np.pi / 2
    theta = np.linspace(offset, 2 * np.pi + offset, len(spoke_names), endpoint=False).tolist()
    theta = np.array(theta + theta[:1])
    theta[theta >= 2 * np.pi] -= 2 * np.pi

    for name in pct.columns:
        radii = pct[name].tolist()
        ax.plot(theta, radii + radii[:1], label=name, linewidth=2, marker='.')

    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_xticks(theta[:-1])
    ax.set_xticklabels(spoke_names)
    ax.set_title(title)

def draw():
    """One radar panel per KFRD task, one trace per training language."""
    sns.set_palette('tab10')
    fig, axs = plt.subplots(1, 3, figsize=(12, 4), subplot_kw={'polar': True})
    panels = (('arithmetic', "Arithmetic"), ('symbolic', "Symbolic"), ('logical', "Logical"))
    for ax, (task, title) in zip(axs, panels):
        plot_radar(ax, create_data(task), title)
    # fig.tight_layout()
    fig.legend(train_langs, ncol=5, loc='lower center', bbox_to_anchor=(0.5, 0))
    fig.show()
    # output_name = 'langs-xltr.pdf' if print_xltr else 'langs-acc.pdf'
    # fig.savefig(os.path.join('pic', output_name), bbox_inches='tight')

# Training languages compared (read by create_data and the legend).
train_langs = ['EN', 'DE', 'ZH', 'AR', 'HE']
# First the XLTR view, then plain accuracy (draw() reads the global flag).
for print_xltr in (True, False):
    draw()

## KFRD SFT & CPT Models

### Accuracy

In [None]:
def get_averaged_accuracy_for_model_and_lang_pair(
    model_name: str,
    train_lang: str,
    lang: str,
):
    """Mean KFRD accuracy of ``model_name`` over the arithmetic, symbolic and logical tasks.

    ``train_lang`` and ``lang`` select the training and evaluation languages.
    Arithmetic is read from the ``train_epoch='4'`` runs. The other tasks use
    ``compute_accuracy``'s default epoch.
    """
    task_settings = (
        {'task': 'arithmetic', 'train_epoch': '4'},
        {'task': 'symbolic'},
        {'task': 'logical'},
    )
    total = sum(
        compute_accuracy(**setting, lang=lang, train_lang=train_lang, model_name=model_name)
        for setting in task_settings
    )
    return total / 3
    
    
def draw_pic3(
    ax: plt.Axes, model_names, alter_lang,
    target_label='Hebrew', tick_labels=None,
):
    """Grouped bars of English vs ``alter_lang`` accuracy for each model.

    Each bar is the KFRD accuracy averaged over the three tasks
    (``get_averaged_accuracy_for_model_and_lang_pair``) for models trained on
    English data.

    Args:
        ax: Axes to draw on.
        model_names: Models to compare; one pair of bars each.
        alter_lang: Code of the non-English evaluation language (e.g. 'HE').
        target_label: Legend label for the ``alter_lang`` bars. Defaults to
            'Hebrew', the language this helper is used for; pass another
            label when calling it for a different ``alter_lang``.
        tick_labels: x tick labels, one per model. Defaults to the
            notebook-level ``labels`` without its second entry.
    """
    if tick_labels is None:
        tick_labels = [labels[0]] + labels[2:]

    x = np.arange(len(model_names))
    width = 0.3
    # Left bar of each pair: English; right bar: the target language.
    for shift, eval_lang, bar_label in ((-0.5, 'EN', 'English'), (0.5, alter_lang, target_label)):
        ax.bar(
            x + shift * width,
            [
                get_averaged_accuracy_for_model_and_lang_pair(model_name, 'EN', eval_lang)
                for model_name in model_names
            ],
            width, label=bar_label,
        )
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels)
    ax.grid(axis='y')
    ax.set_title(alter_lang)


def draw_pic4(
    ax: plt.Axes, model_names, alter_lang,
    target_label='Arabic', tick_labels=None,
):
    """Grouped bars of English vs ``alter_lang`` accuracy for each model.

    Each bar is the KFRD accuracy averaged over the three tasks
    (``get_averaged_accuracy_for_model_and_lang_pair``) for models trained on
    English data.

    Args:
        ax: Axes to draw on.
        model_names: Models to compare; one pair of bars each.
        alter_lang: Code of the non-English evaluation language (e.g. 'AR').
        target_label: Legend label for the ``alter_lang`` bars. Defaults to
            'Arabic', the language this helper is used for; pass another
            label when calling it for a different ``alter_lang``.
        tick_labels: x tick labels, one per model. Defaults to the
            notebook-level ``labels``.
    """
    if tick_labels is None:
        tick_labels = labels

    x = np.arange(len(model_names))
    width = 0.3
    # Left bar of each pair: English; right bar: the target language.
    for shift, eval_lang, bar_label in ((-0.5, 'EN', 'English'), (0.5, alter_lang, target_label)):
        ax.bar(
            x + shift * width,
            [
                get_averaged_accuracy_for_model_and_lang_pair(model_name, 'EN', eval_lang)
                for model_name in model_names
            ],
            width, label=bar_label,
        )
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels)
    ax.grid(axis='y')
    ax.set_title(alter_lang)
    
    
# x tick names of the model variants, in plotting order (draw_pic3 skips 'SFT').
labels = ['Vanilla', 'SFT', 'CPT', 'CPT+SFT']

# English / target-language bar colours.
colors = ['#74a9cf', '#fc8d59']
sns.set_palette(sns.color_palette(colors))

arabic_variants = [
    'llama-2-7b-chat',
    'llama-2-7b-chat-arabic-lora',
    'sambalingo-arabic-base',
    'sambalingo-arabic-chat',
]
hebrew_variants = ['mistral-7b-instruct-v0.1', 'dictalm-2', 'dictalm-2-instruct']

fig, axs = plt.subplots(1, 2, figsize=(6, 4), sharey='all')
draw_pic4(axs[0], arabic_variants, 'AR')
draw_pic3(axs[1], hebrew_variants, 'HE')

axs[1].legend(['English', 'Target Language'], loc='lower right', ncol=1)
# fig.tight_layout()
fig.show()
# output_name = 'cptft-acc.pdf'
# fig.savefig(os.path.join('pic', output_name), bbox_inches='tight')

### XLTR

In [None]:
def get_transition_ratio_for_model(
    model_name: str,
    alter_lang: str,
):
    v = [
        [
            compute_accuracy(
                task='arithmetic', train_epoch='4',
                lang=lang, model_name=model_name, return_corrects=True,
            )[1],
            compute_accuracy(
                task='symbolic',
                lang=lang, model_name=model_name, return_corrects=True,
            )[1],
            compute_accuracy(
                task='logical',
                lang=lang, model_name=model_name, return_corrects=True,
            )[1],
        ]
        for lang in ('EN', alter_lang)
    ]

    
    v = [
        max(0, sum(x & y for x, y in zip(en_corrects, alter_corrects)) / sum(en_corrects) * 2 - 1)
        for en_corrects, alter_corrects in zip(*v)
    ]
    return sum(v) / len(v)


# Cross-lingual transfer of the SFT / CPT model variants, averaged over the
# three KFRD tasks: Arabic models on the left, Hebrew models on the right.
fig, axs = plt.subplots(1, 2, figsize=(6, 4), sharey='all')
axs: list[plt.Axes]

xltr_panels = (
    (axs[0], 'AR', [
        'llama-2-7b-chat',
        'llama-2-7b-chat-arabic-lora',
        'sambalingo-arabic-base',
        'sambalingo-arabic-chat',
    ], labels),
    # Three Hebrew models, so the 'SFT' tick name is skipped.
    (axs[1], 'HE', [
        'mistral-7b-instruct-v0.1',
        'dictalm-2', 'dictalm-2-instruct',
    ], labels[0:1] + labels[2:]),
)
for panel, target, models, tick_names in xltr_panels:
    positions = list(range(len(models)))
    panel.bar(
        positions,
        [get_transition_ratio_for_model(model_name, target) for model_name in models],
        width=0.5,
    )
    panel.set_xticks(positions)
    panel.set_xticklabels(tick_names)
    panel.set_title(target)
    panel.grid(axis='y')

# fig.tight_layout()
fig.show()
# output_name = 'cptft-xltr.pdf'
# fig.savefig(os.path.join('pic', output_name), bbox_inches='tight')