In [None]:
!pip uninstall -y enum34
!pip install --upgrade pip
!pip install farm

In [1]:
# NOTE(review): `!` runs this line in a child shell, where `set` is the shell
# builtin; it is not expected to change Jupyter's iopub data-rate limit, which
# is a server option (pass it at launch, e.g.
# `jupyter notebook --NotebookApp.iopub_data_rate_limit=2000000.0`, or set it in
# the Jupyter config file). Confirm and move/remove.
!set NotebookApp.iopub_data_rate_limit=2000000.0

In [None]:
import logging
from pathlib import Path

from farm.data_handler.data_silo import DataSilo
from farm.data_handler.processor import TextClassificationProcessor
from farm.modeling.optimization import initialize_optimizer
from farm.infer import Inferencer
from farm.modeling.adaptive_model import AdaptiveModel
from farm.modeling.language_model import LanguageModel
from farm.modeling.prediction_head import MultiLabelTextClassificationHead
from farm.modeling.tokenization import Tokenizer
from farm.train import Trainer
from farm.utils import set_all_seeds, MLFlowLogger, initialize_device_settings
from pprint import PrettyPrinter
import os.path
import numpy as np
import tensorflow as tf
import farm
import torch

In [3]:
# Directory of the saved DistilBERT model ("6epochs_final") that the next cell
# hands to Inferencer.load.
save_dir = Path("saved_models") / "distilbert" / "6epochs_final"

In [None]:
# Rebuild a FARM Inferencer from the model files stored under save_dir.
model = Inferencer.load(
    save_dir,
)

In [None]:
# Run the loaded model over the 600k prediction TSV. The result is iterated
# further down as a list whose entries each carry a 'predictions' list.
prediction_tsv = "./DataSetCreation/df_prediction_cols_600k.tsv"
result_600k_test = model.inference_from_file(file=prediction_tsv)

In [11]:
import pandas as pd
import numpy as np

In [12]:
# Label names as stored in column '0' of the saved label CSV (presumably the
# model's output order — confirm against the training run).
label_list = pd.read_csv('labels_new_distilbert.csv')['0'].tolist()

In [13]:
# Display the loaded label vocabulary (30 names, in CSV order).
label_list

['Germany',
 'Asia',
 'North_Europe',
 'Eastern_Europe',
 'Western_Europe',
 'South_Europe',
 'North_America',
 'South_America',
 'Africa',
 'Oceania',
 'married',
 'partnership',
 'unmarried',
 'divorced',
 'below1500',
 'to4000',
 'over4000',
 '40.0',
 '30.0',
 '50.0',
 '60.0',
 '-999.0',
 '20.0',
 '70.0',
 'employed',
 'public_office',
 'unemployed',
 'student',
 'self-employed',
 'other_employment']

In [14]:
from keras.utils import np_utils

In [15]:
# Label name -> position in each prediction vector. The imputation loop below
# turns this into a position -> name list, so the order here must match the
# model's output head (and label_list) — TODO confirm.
labels_index = {
    name: position
    for position, name in enumerate([
        # nationality
        'Germany', 'Asia', 'North_Europe', 'Eastern_Europe', 'Western_Europe',
        'South_Europe', 'North_America', 'South_America', 'Africa', 'Oceania',
        # marital_status
        'married', 'partnership', 'unmarried', 'divorced',
        # income_group
        'below1500', 'to4000', 'over4000',
        # age_decade
        '40.0', '30.0', '50.0', '60.0', '-999.0', '20.0', '70.0',
        # employment_group
        'employed', 'public_office', 'unemployed', 'student', 'self-employed',
        'other_employment',
    ])
}

# Attribute group -> (first, last) positions in labels_index. Both ends are
# INCLUSIVE: e.g. nationality runs from 'Germany' (0) to 'Oceania' (9).
groups_labels = dict(
    nationality=(0, 9),
    marital_status=(10, 13),
    income_group=(14, 16),
    age_decade=(17, 23),
    employment_group=(24, 29),
)

In [16]:
# Intentionally empty: the 600k-row frame is loaded (with display options and
# prob_* columns) in the "Read original dataset" section below, which
# overwrites anything read here. Reading it in this cell as well only doubled
# the load time.

In [17]:
# Flatten the per-batch inference output into two parallel lists: the input
# text of each prediction and its full probability vector.
contexts = [
    pred['context']
    for batch in result_600k_test
    for pred in batch['predictions']
]
predictions = [
    pred['probability']
    for batch in result_600k_test
    for pred in batch['predictions']
]

## Read the original dataset, including the hand-assigned labels

In [18]:
import pandas as pd
import numpy as np

# Never truncate rows/columns and use the full width when displaying frames.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)

# Load the original dataset and add one NaN-filled confidence column per
# attribute group; the imputation loop fills prob_<group> for imputed rows.
test_df_600k_lp = pd.read_csv(
    './DataSetCreation/df_newtext_na_600k_all_cols.csv', sep='\t'
).assign(
    prob_income_group=np.nan,
    prob_employment_group=np.nan,
    prob_marital_status=np.nan,
    prob_nationality=np.nan,
    prob_age_decade=np.nan,
)

In [19]:
# Stringify ages so they compare equal to the string labels in labels_index
# (a float column gives e.g. '40.0'); note that NaN becomes the string 'nan'.
test_df_600k_lp['age_decade'] = test_df_600k_lp['age_decade'].astype(str)

In [20]:
# Store the stringified ages as a pandas categorical (categories inferred from the data).
test_df_600k_lp['age_decade'] = test_df_600k_lp['age_decade'].astype(pd.CategoricalDtype())

In [21]:
# Normalise the "missing" placeholders so every group uses `missing_<group>`,
# which is what the imputation loop matches on. Assign the result back rather
# than calling `df.col.replace(..., inplace=True)`: that is chained assignment,
# which pandas 2.2 warns about and which does nothing under Copy-on-Write
# (the default from pandas 3.0).
test_df_600k_lp['income_group'] = test_df_600k_lp['income_group'].replace(
    {'missing_income': 'missing_income_group'})
test_df_600k_lp['employment_group'] = test_df_600k_lp['employment_group'].replace(
    {'missing_employment': 'missing_employment_group'})
# age_decade was made categorical in the previous cell; rename the category
# directly (Series.replace on categoricals is deprecated). Mapping keys that are
# not existing categories are ignored.
test_df_600k_lp['age_decade'] = test_df_600k_lp['age_decade'].cat.rename_categories(
    {'-999.0': 'missing_age_decade'})

In [26]:
# Fix the misspelled source column so it matches the 'marital_status' key of groups_labels.
test_df_600k_lp = test_df_600k_lp.rename(columns={'maritial_status': 'marital_status'})

In [27]:
# Impute missing attribute values from the model output. For each row whose
# <group> column still holds the "missing_<group>" placeholder, write the most
# probable label within that group, and its probability into prob_<group>.
# _labels[i] names position i of every prediction vector (assumes labels_index
# is ordered like the model's output head — TODO confirm).
_labels = list(labels_index.keys())
treshold_prob = 0.5  # NOTE(review): not applied below; use as a confidence cutoff or remove


# NOTE(review): assumes prediction cix belongs to DataFrame row label cix
# (default RangeIndex, same row order as the inference TSV) — confirm.
for cix in range(len(contexts)):
    row_probs = predictions[cix]
    for g, (first, last) in groups_labels.items():
        if test_df_600k_lp.at[cix, g] != 'missing_' + g:
            continue
        # groups_labels ranges are inclusive of `last`, so the slice must end at
        # last + 1; ending at `last` silently excluded the final label of every
        # group ('Oceania', 'divorced', 'over4000', '70.0', 'other_employment').
        group_probs = row_probs[first:last + 1]
        best = int(np.argmax(group_probs))
        # NOTE(review): age_decade is categorical; a predicted label that is not
        # already one of its categories (e.g. '-999.0', renamed to
        # 'missing_age_decade' earlier) may be rejected by pandas — confirm.
        test_df_600k_lp.at[cix, g] = _labels[first + best]
        test_df_600k_lp.at[cix, 'prob_' + g] = group_probs[best]


In [28]:
# Persist the frame with imputed labels and prob_* confidences as a TSV.
test_df_600k_lp.to_csv(
    './DataSetCreation/test_overwritten_labels_600k.tsv',
    sep='\t',
    header=True,
)