## Data exploration - LM-KBC challenge

In [None]:
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
from pathlib import Path
from typing import List, Dict, Union



def read_lm_kbc_jsonl(file_path: Union[str, Path]) -> List[Dict]:
    """
    Reads a LM-KBC jsonl file and returns a list of dictionaries.

    The file is decoded as UTF-8 regardless of the platform default, and
    blank (whitespace-only) lines are skipped instead of being passed to
    the JSON parser.

    Args:
        file_path: path to the jsonl file
    Returns:
        list of dictionaries, each possibly has the following keys:
        - "SubjectEntity": str
        - "Relation": str
        - "ObjectEntities":
            None or List[List[str]] (can be omitted for the test input)
    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Explicit encoding: entity names may contain non-ASCII characters and the
    # locale default encoding is not UTF-8 on every platform.
    with open(file_path, "r", encoding="utf-8") as f:
        # A trailing or stray empty line is not a record; skip it.
        return [json.loads(line) for line in f if line.strip()]


def read_lm_kbc_jsonl_to_df(file_path: Union[str, Path]) -> pd.DataFrame:
    """
    Loads a LM-KBC jsonl file into a pandas dataframe.

    Each record returned by `read_lm_kbc_jsonl` becomes one row; the
    record keys become the columns.
    """
    return pd.DataFrame(read_lm_kbc_jsonl(file_path))

### Load dev set

In [None]:
dev = read_lm_kbc_jsonl_to_df('dev_checkup.jsonl')

In [None]:
# Per relation, count the non-null entries in each remaining column.
# Note: count() does not deduplicate, so these are row counts, not unique counts.
dev.groupby('Relation').count()

In [None]:
dev[dev['Relation']=='PersonProfession'][:2]

### Load train set

In [None]:
train_set = read_lm_kbc_jsonl_to_df('data/train.jsonl')

In [None]:
# Number of entries in each row's ObjectEntities list.
train_set['counts'] = train_set['ObjectEntities'].apply(len)

In [None]:
grouped = train_set.groupby('Relation')

In [None]:
grouped.count()

In [None]:
train_set.groupby('Relation').hist(by=train_set['Relation'])

In [None]:
# One histogram of `counts` per relation, arranged in a 6x2 grid of subplots.
train_set.counts.hist(by=train_set['Relation'],
                      grid=True,
                      layout=(6,2),
                      figsize=(10,15),
                      xlabelsize=10,
                      ylabelsize=10,
                      xrot=15)

# tight_layout() adjusts the *current* figure, so it must run after the
# histogram figure has been created and before it is written to disk.
plt.tight_layout()
plt.savefig('relation_statistics_train_set.pdf',
           dpi=150)

In [None]:
grouped = train_set.groupby('Relation')

In [None]:
(grouped.counts.value_counts()
   .unstack().plot.bar(width=1, stacked=False))

In [None]:
train_set.groupby('Relation').hist(sharex=True,
                                   sharey=True)

#### 
    

In [None]:
# def count_anwers(list_of_lists):

In [None]:
import seaborn as sns

In [None]:
train_set.groupby('Relation').hist()

In [None]:
train_set.groupby('Relation').count()

In [None]:
# NOTE(review): groupby() returns a DataFrameGroupBy, which does not appear to
# provide reset_index(); this line likely raises AttributeError — confirm the
# intended operation (e.g. an aggregation before reset_index).
train_set_selaki = train_set.groupby('Relation').reset_index(drop=True)

In [None]:
# NOTE(review): `tryout` is not assigned in any visible cell; evaluating it
# raises NameError unless it is defined elsewhere — confirm or remove.
tryout

### Get data statistics

- 

In [None]:
dev_set = read_lm_kbc_jsonl_to_df('data/dev.jsonl')

In [None]:
dev_set[dev_set['Relation'] == 'PersonProfession']['ObjectEntities']

In [None]:
# Number of entries in each row's ObjectEntities list.
dev_set['counts'] = dev_set['ObjectEntities'].apply(len)

In [None]:
# One histogram of `counts` per relation, arranged in a 6x2 grid of subplots.
dev_set.counts.hist(by=dev_set['Relation'],
                      grid=True,
                      layout=(6,2),
                      figsize=(10,15),
                      xlabelsize=10,
                      ylabelsize=10,
                      xrot=15)

# tight_layout() adjusts the *current* figure, so it must run after the
# histogram figure has been created and before it is written to disk.
plt.tight_layout()
plt.savefig('relation_statistics_dev_set.pdf',
           dpi=150)

In [None]:
# NOTE(review): the variable is named person_employer, but the filter selects
# rows whose Relation is 'PersonCauseOfDeath' — confirm which relation is intended.
person_employer = train_set[train_set['Relation'] == 'PersonCauseOfDeath']

In [None]:
pd.set_option("display.max_rows", 100)

In [None]:
person_employer[0:100]

In [None]:
person_employer_dev = dev[dev['Relation'] == 'PersonEmployer']

In [None]:
pd.set_option('display.max_colwidth', 255)

person_employer_dev

In [None]:
temp = dev_set.groupby('Relation').agg({'counts': ['mean', 'std']})

In [None]:
# Order relations by mean answer count; break ties on the standard deviation.
# (The sort key list previously repeated ('counts', 'mean'), which made the
# second key a no-op.)
temp = temp.sort_values(by=[('counts', 'mean'), ('counts', 'std')])

In [None]:
temp.to_latex('rel_means_std.tex')

In [None]:
temp_stats = temp.agg({'counts': ['mean', 'std']})

In [None]:
temp_stats

In [None]:
temp

In [None]:
temp.to_csv('avg_num_answers_per_rel_type.csv')

ADA

| Relation                             | p     |r   |  f1|
|--------------------------------------|:------|:------|:------|
|ChemicalCompoundElement    |0.256  |0.225  |0.231|
|CompanyParentOrganization  |0.120  |0.120  |0.120|
|CountryBordersWithCountry  |0.066  |0.040  |0.046|
|CountryOfficialLanguage    |0.142  |0.145  |0.133|
|PersonCauseOfDeath         |0.160  |0.160  |0.160|
|PersonEmployer             |0.000  |0.000  |0.000|
|PersonInstrument           |0.297  |0.352  |0.270|
|PersonLanguage             |0.331  |0.702  |0.394|
|PersonPlaceOfDeath         |0.040  |0.040  |0.040|
|PersonProfession           |0.281  |0.134  |0.156|
|RiverBasinsCountry         |0.365  |0.349  |0.313|
|StateSharesBorderState     |0.102  |0.060  |0.066|
|*** Average ***            |0.180  |0.194  |0.161|

BABBAGE

| Relation                             | p     |r   |  f1|
|--------------------------------------|:------|:------|:------|
|ChemicalCompoundElement    |0.357  |0.240  |0.275|
|CompanyParentOrganization  |0.080  |0.080  |0.080|
|CountryBordersWithCountry  |0.206  |0.171  |0.170|
|CountryOfficialLanguage    |0.686  |0.629  |0.605|
|PersonCauseOfDeath         |0.040  |0.040  |0.040|
|PersonEmployer             |0.012  |0.017  |0.014|
|PersonInstrument           |0.507  |0.463  |0.457|
|PersonLanguage             |0.689  |0.657  |0.636|
|PersonPlaceOfDeath         |0.000  |0.000  |0.000|
|PersonProfession           |0.513  |0.219  |0.286|
|RiverBasinsCountry         |0.700  |0.558  |0.578|
|StateSharesBorderState     |0.117  |0.078  |0.088|
|*** Average ***            |0.325  |0.263  |0.269|

CURIE

| Relation                             | p     |r   |  f1|
|--------------------------------------|:------|:------|:------|
|ChemicalCompoundElement    |0.532  |0.521  |0.513|
|CompanyParentOrganization  |0.140  |0.140  |0.140|
|CountryBordersWithCountry  |0.517  |0.487  |0.462|
|CountryOfficialLanguage    |0.658  |0.768  |0.664|
|PersonCauseOfDeath         |0.040  |0.040  |0.040|
|PersonEmployer             |0.043  |0.067  |0.050|
|PersonInstrument           |0.318  |0.421  |0.326|
|PersonLanguage             |0.752  |0.833  |0.739|
|PersonPlaceOfDeath         |0.000  |0.000  |0.000|
|PersonProfession           |0.683  |0.312  |0.383|
|RiverBasinsCountry         |0.598  |0.711  |0.604|
|StateSharesBorderState     |0.255  |0.195  |0.198|
|*** Average ***            |0.378  |0.375  |0.343|

Da-Vinci


| Relation                             | p     |r   |  f1|
|--------------------------------------|:------|:------|:------|
|ChemicalCompoundElement    |0.905  |0.894  |0.894|
|CompanyParentOrganization  |0.685  |0.700  |0.688|
|CountryBordersWithCountry  |0.830  |0.794  |0.792|
|CountryOfficialLanguage    |0.824  |0.840  |0.788|
|PersonCauseOfDeath         |0.600  |0.590  |0.593|
|PersonEmployer             |0.276  |0.335  |0.270|
|PersonInstrument           |0.589  |0.570  |0.551|
|PersonLanguage             |0.755  |0.936  |0.797|
|PersonPlaceOfDeath         |0.820  |0.820  |0.820|
|PersonProfession           |0.735  |0.526  |0.582|
|RiverBasinsCountry         |0.824  |0.851  |0.820|
|StateSharesBorderState     |0.638  |0.472  |0.532|
|*** Average ***            |0.707  |0.694  |0.677|

### Scaling plot

In [None]:
data = {'Model': ['Ada (1024)', 'Babbage (2048)', 'Curie (4096)', 'Da-Vinci (12288)'],
        'F1 Score': [0.161,0.269, 0.343, 0.677]}

In [None]:
data = pd.DataFrame(data)

In [None]:
%matplotlib inline

In [None]:
model_types = data.Model.to_list()
performances = data['F1 Score'].to_list()

In [None]:
# Plot F1 score against model and save the figure as a PDF.
sns.set_style("darkgrid")
sns.set(rc = {'figure.figsize':(15,8)})

# Line plot with one marker per row of `data`.
ax = sns.lineplot(data=data, x='Model', y='F1 Score',
            lw=2, marker='o')
ax.set_title("Performance per Model Size", fontsize=15)

# Re-apply the labels seaborn already set, only to change their font weight and size.
ax.set_xlabel(ax.get_xlabel(), fontdict={'weight': 'bold', 'size':16})
ax.set_ylabel(ax.get_ylabel(), fontdict={'weight': 'bold','size':16})
# Rotate the x tick labels by 45 degrees and enlarge them.
ax.tick_params(axis='x', labelrotation=45, labelsize=20)
# Lay out the finished figure before writing it to disk.
plt.tight_layout()
plt.savefig('performance_per_model_size.pdf')