# Train a definitive model, log it and predict on the gold standard

## Paths import 


In [1]:
import os
import sys
# Move the working directory one level up. NOTE(review): presumably so that
# the top-level `PATHS` module imported below resolves from there — confirm.
os.chdir('../')

Make sure that your current working directory (cwd) is `ReproducingAugSS/AugmentedSocialScientist/`

In [2]:
#os.getcwd() 

In [3]:
from PATHS import OFF_ASS, OFF_GS, SAVED_MODELS_PATH

## Parameters definition

In [4]:
# Hyper-parameters suffixed _OFF: these are the values handed to process()
# in the run cell of this notebook.
N_EPOCHS_OFF = 5  # used as params["nepochs"]
SAMPLER_OFF = "random"  # used as params["sampler"]
LR_OFF = 5e-5  # used as params["lr"]
BS_OFF = 32  # used as params["batch_size"]

# Hyper-parameters suffixed _ENDOEXO: same four settings for another
# configuration; not referenced in the visible part of this notebook.
N_EPOCHS_ENDOEXO = 25
SAMPLER_ENDOEXO = "sequential"
LR_ENDOEXO = 1e-5
BS_ENDOEXO = 64

# Used as params["drop_duplicates"], which process() forwards to
# prepare_experiment().
DROP_DUPLICATES = True
# NOTE(review): presumably the fraction of non-gold-standard rows to train on
# (the `percent_of_data` argument of process()) — confirm.
PERCENT_OF_DATA = 1


## General imports


In [5]:
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from TransferSociologist.data import Dataset
from TransferSociologist.models import BertSequence
from TransferSociologist.utils import regularize_seqlab
from operator import add
from functools import reduce
from copy import deepcopy
import os
import json
import logging
import sys
from torch.cuda import empty_cache


def try_eval(x):
    """Evaluate ``x`` as a Python expression, falling back to ``x`` itself.

    Args:
        x: usually the string form of a Python literal (e.g. a list of
            label tuples read back from a CSV cell); any other object is
            accepted too.

    Returns:
        ``eval(x)`` when the evaluation succeeds, otherwise ``x`` unchanged.
        Non-string inputs such as lists make ``eval`` raise ``TypeError``
        and are therefore returned as they are.
    """
    # NOTE(review): eval() executes arbitrary code. Only feed it trusted
    # data; consider ast.literal_eval if the inputs are plain literals.
    try:
        return eval(x)
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt and
        # SystemExit must still propagate.
        return x


def fill_zeros(labels, zeros, conv_dict):
    """Write the numeric id of every labelled span onto ``zeros``.

    Args:
        labels: iterable of ``(start, stop, label)`` triples, or a string
            that evaluates to one.
        zeros: list to fill; it is modified in place and also returned.
        conv_dict: mapping from label name to the number written on the
            span.

    Returns:
        ``zeros``, where every position in ``[start, stop)`` of each span
        holds ``conv_dict[label]``. The list length never changes.

    Raises:
        KeyError: if a span's label is missing from ``conv_dict``.
    """
    # NOTE(review): eval() executes arbitrary code; only use on trusted data.
    try:
        labels = eval(labels)
    except Exception:
        # Not an evaluable string (e.g. already a list): use it as given.
        # KeyboardInterrupt / SystemExit are intentionally not swallowed.
        pass
    for start_span, stop_span, lab in labels:
        # Use the real slice length rather than stop - start, so a span
        # running past the end of ``zeros`` cannot lengthen the list.
        size = len(zeros[start_span:stop_span])
        zeros[start_span:stop_span] = [conv_dict[lab]] * size
    return zeros


[nltk_data] Downloading package punkt to
[nltk_data]     /pbs/home/r/rshen/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Utility function computing true positives for our custom span-level metric

In [6]:
def compute_TP(x, thres=0, factor=0.25):
    """Count the predicted spans of one row that match a gold span.

    Args:
        x: row-like object exposing ``labels`` and ``pred_labels``
            (sequences of spans whose first two items are start and stop)
            plus ``labels_str`` and ``labels_pred_str`` (per-position
            numeric label sequences).
        thres: predicted spans are kept only if ``stop - start > thres``.
        factor: a kept prediction counts as a true positive when its dot
            product with the gold sequence, taken over the union window of
            both spans, exceeds ``factor`` times the gold span length.

    Returns:
        Number of kept predicted spans matched to a gold span; each
        prediction is counted at most once.
    """
    gold_spans = x.labels
    kept_preds = [span for span in x.pred_labels if span[1] - span[0] > thres]

    true_positives = 0
    for pred_span in kept_preds:
        pred_extent = range(pred_span[0], pred_span[1] + 1)
        for gold_span in gold_spans:
            # Overlap means: the gold start or the gold stop lies inside the
            # predicted extent (stop included).
            # NOTE(review): a prediction lying strictly inside a gold span
            # is therefore never considered overlapping — confirm intended.
            if gold_span[0] not in pred_extent and gold_span[1] not in pred_extent:
                continue
            window_start = min(pred_span[0], gold_span[0])
            window_stop = max(pred_span[1], gold_span[1])
            pred_window = x.labels_pred_str[window_start:window_stop]
            gold_window = x.labels_str[window_start:window_stop]
            shared = np.dot(pred_window, gold_window).sum()
            if shared > factor * (gold_span[1] - gold_span[0]):
                true_positives += 1
                break  # one match is enough for this prediction
    return true_positives

## 1. Preparing the experiment : loading and formatting the data

In [7]:
def prepare_experiment(
    train_path, gs_path, drop_duplicates=False, percent_of_data=1
):
    """Build the training dataset and the gold-standard prediction dataset.

    Args:
        train_path: CSV path given to ``Dataset.read`` as ``data_path``.
        gs_path: gold-standard CSV path; given as ``gold_standard_path``
            for the training dataset and read on its own for the
            prediction dataset.
        drop_duplicates: when ``True`` and an ``is_control`` column exists,
            non-gold rows flagged ``is_control`` are reduced to one
            randomly sampled row per distinct ``text``.
        percent_of_data: fraction of the non-gold-standard rows kept
            (passed to ``DataFrame.sample(frac=...)``).

    Returns:
        ``(dataset, dataset_pred)``: the encoded training ``Dataset`` and
        the ``Dataset`` encoded in prediction mode from ``gs_path``.

    Note:
        Neither sampling step receives a seed.
    """
    task = "sequence_labelling"
    encoder = "CamemBert"

    dataset = Dataset()
    dataset.read(
        data_path=train_path, gold_standard_path=gs_path, data_type="csv"
    )
    dataset.df = dataset.df.rename(columns={"is_control_1": "is_control"})

    if drop_duplicates == True and "is_control" in dataset.df.columns:
        frame = dataset.df
        gold_rows = frame[frame.is_gold_standard == True]
        train_rows = frame[frame.is_gold_standard == False]
        # One random representative per text among the control rows.
        unique_controls = (
            train_rows[train_rows.is_control == True]
            .groupby(["text"])
            .apply(lambda group: group.sample(1))
            .reset_index(drop=True)
        )
        other_rows = train_rows[train_rows.is_control != True]
        train_rows = pd.concat([unique_controls, other_rows])
        dataset.df = pd.concat([train_rows, gold_rows])

    # Sub-sample the non-gold rows; gold-standard rows are always kept.
    frame = dataset.df
    gold_rows = frame[frame.is_gold_standard == True]
    train_rows = frame[frame.is_gold_standard == False]
    train_rows = train_rows.sample(frac=percent_of_data)
    dataset.df = pd.concat([train_rows, gold_rows])

    dataset.task_encode(task_type=task, bert_model=encoder)
    dataset.encode_torch(
        task_type=task,
        bert_model=encoder,
        random_seed=2018,
    )

    # Second dataset: the gold standard alone, encoded in prediction mode.
    dataset_pred = Dataset()
    dataset_pred.read(data_path=gs_path, data_type="csv")
    dataset_pred.task_encode(
        task_type=task,
        bert_model=encoder,
        pred_mode=True,
    )
    dataset_pred.encode_torch(
        task_type=task, bert_model=encoder, pred_mode=True
    )
    return dataset, dataset_pred


## 2. Running the experiment :

In [8]:
def _store_class_scores(perf_dic, scores, inv_conv_dict, prefix=""):
    """Copy per-class precision/recall/F1/support from ``scores`` into ``perf_dic``."""
    for i in range(len(scores[0])):
        name = inv_conv_dict[i]
        perf_dic[f"prec_{prefix}{name}"] = float(scores[0][i])
        perf_dic[f"rec_{prefix}{name}"] = float(scores[1][i])
        perf_dic[f"F1_{prefix}{name}"] = float(scores[2][i])
        perf_dic[f"supp_{prefix}{name}"] = float(scores[3][i])


def _span_scores(df, tp_col, tpfp_col, tpfn_col):
    """Return (precision, recall, F1) from the summed count columns of ``df``."""
    true_positives = df[tp_col].sum()
    precision = true_positives / df[tpfp_col].sum()
    recall = true_positives / df[tpfn_col].sum()
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


def run_experiment(dataset, dataset_pred, batch_size, lr, sampler, nepochs):
    """Fit a ``BertSequence`` model, predict on ``dataset_pred`` and score it.

    Args:
        dataset: training dataset; ``train``, ``test``, ``conversion_dict``
            and ``tokenizer`` are read from it.
        dataset_pred: prediction dataset; ``pred`` is fed to the model and
            ``df`` receives the predictions and the metric columns.
        batch_size, lr, sampler, nepochs: forwarded to
            ``BertSequence.fit_evaluate`` (``lr`` as ``learning_rate``).

    Returns:
        ``(perf_dic, dataset_pred, clf)``: the metrics dictionary, the
        prediction dataset as returned by ``regularize_seqlab`` with the
        extra columns, and the fitted classifier.

    Raises:
        AssertionError: if filling the per-position label sequences changed
            their length.
    """
    clf = BertSequence()
    random_seed = np.random.randint(2021)

    # The second returned value (best-run scores) is not reported.
    perfs, _best_perfs, epoch_best = clf.fit_evaluate(
        dataset.train,
        dataset.test,
        batch_size=batch_size,
        sampler=sampler,
        nepochs=nepochs,
        random_seed=random_seed,
        learning_rate=lr,
    )
    perf_dic = {
        "batch_size": batch_size,
        "lr": lr,
        "sampler": sampler,
        "nepochs": nepochs,
        "best epoch": int(epoch_best),
        "random_seed": int(random_seed),
        "train_size": len(dataset.train[0]),
    }
    # Numeric class id -> label name, to name the per-class metrics.
    inv_conv_dict = {
        idx: name for name, idx in dataset.conversion_dict.items()
    }
    _store_class_scores(perf_dic, perfs, inv_conv_dict)

    truncated_labels, truncated_logits = clf.predict(dataset_pred.pred)
    dataset_pred.df["truncated_labels"] = truncated_labels
    dataset_pred.df["truncated_logits"] = truncated_logits
    dataset_pred = regularize_seqlab(dataset_pred, dataset.tokenizer)

    # Per-position label sequences, one entry per element of ``sents``,
    # for both gold and predicted spans.
    preds = dataset_pred.df
    preds["labels_str"] = preds.sents.apply(lambda x: [0] * len(x))
    preds["labels_pred_str"] = preds.sents.apply(lambda x: [0] * len(x))
    preds["labels_str_len"] = preds["labels_str"].apply(len)
    preds["labels_str"] = preds.apply(
        lambda x: fill_zeros(x.labels, x.labels_str, dataset.conversion_dict),
        axis=1,
    )
    preds["labels_pred_str"] = preds.apply(
        lambda x: fill_zeros(
            x.pred_labels, x.labels_pred_str, dataset.conversion_dict
        ),
        axis=1,
    )
    preds["labels_str_len2"] = preds["labels_str"].apply(len)
    preds["labels_pred_str_len2"] = preds["labels_pred_str"].apply(len)

    # Sanity checks: filling must never change a sequence's length.
    true_len_ok = (preds["labels_str_len2"] == preds["labels_str_len"]).mean()
    assert true_len_ok == 1, f"problem in true labels fill {true_len_ok}"
    pred_len_ok = (
        preds["labels_pred_str_len2"] == preds["labels_str_len"]
    ).mean()
    assert pred_len_ok == 1, f"problem in predicted labels fill {pred_len_ok}"

    # Flatten all rows into one sequence (linear, unlike repeated list +).
    true = [value for row in preds["labels_str"].values for value in row]
    pred = [value for row in preds["labels_pred_str"].values for value in row]

    perfs_char = precision_recall_fscore_support(true, pred)
    _store_class_scores(perf_dic, perfs_char, inv_conv_dict, prefix="char_")

    # Span-level metrics need the span lists as Python objects.
    dataset_pred.df["pred_labels"] = dataset_pred.df["pred_labels"].apply(
        try_eval
    )
    dataset_pred.df["labels"] = dataset_pred.df["labels"].apply(try_eval)

    dataset_pred.df["TP_25"] = dataset_pred.df.apply(
        lambda x: compute_TP(x), axis=1
    )
    dataset_pred.df["TP_50"] = dataset_pred.df.apply(
        lambda x: compute_TP(x, factor=0.5), axis=1
    )
    dataset_pred.df["TPFP"] = dataset_pred.df.pred_labels.apply(len)
    dataset_pred.df["TPFN"] = dataset_pred.df.labels.apply(len)

    dataset_pred.df["TP_thres4_25"] = dataset_pred.df.apply(
        lambda x: compute_TP(x, thres=4), axis=1
    )
    dataset_pred.df["TP_thres4_50"] = dataset_pred.df.apply(
        lambda x: compute_TP(x, thres=4, factor=0.5), axis=1
    )
    dataset_pred.df["TPFP_thres4"] = dataset_pred.df.pred_labels.apply(
        lambda x: len(list(filter(lambda y: y[1] - y[0] > 4, x)))
    )
    # NOTE(review): unlike TPFP_thres4, gold spans are not length-filtered
    # here, so this equals TPFN — confirm this is intended.
    dataset_pred.df["TPFN_thres4"] = dataset_pred.df.labels.apply(len)

    # (key suffix, TP column, TP+FP column, TP+FN column, F1 key).
    # "F1_span_T4" has no "_25" suffix; kept as is because consumers of the
    # result dictionary may rely on that key.
    span_metric_specs = (
        ("25", "TP_25", "TPFP", "TPFN", "F1_span_25"),
        ("T4_25", "TP_thres4_25", "TPFP_thres4", "TPFN_thres4", "F1_span_T4"),
        ("50", "TP_50", "TPFP", "TPFN", "F1_span_50"),
        (
            "T4_50",
            "TP_thres4_50",
            "TPFP_thres4",
            "TPFN_thres4",
            "F1_span_T4_50",
        ),
    )
    for suffix, tp_col, tpfp_col, tpfn_col, f1_key in span_metric_specs:
        precision, recall, f1 = _span_scores(
            dataset_pred.df, tp_col, tpfp_col, tpfn_col
        )
        perf_dic[f"prec_span_{suffix}"] = precision
        perf_dic[f"rec_span_{suffix}"] = recall
        perf_dic[f1_key] = f1

    return perf_dic, dataset_pred, clf


## Putting things together & running

In [9]:
def process(params, paths, percent_of_data=1):
    """Prepare the data, then train and evaluate with the given settings.

    Args:
        params: dict with the keys ``"drop_duplicates"``, ``"batch_size"``,
            ``"lr"``, ``"sampler"`` and ``"nepochs"``.
        paths: pair ``(train_path, gs_path)``.
        percent_of_data: forwarded to ``prepare_experiment``.

    Returns:
        The ``(perf_dic, dataset_pred, clf)`` triple produced by
        ``run_experiment``.
    """
    train_csv, gold_csv = paths
    train_set, gold_set = prepare_experiment(
        train_csv, gold_csv, params["drop_duplicates"], percent_of_data
    )
    metrics, gold_set, model = run_experiment(
        train_set,
        gold_set,
        params["batch_size"],
        params["lr"],
        params["sampler"],
        params["nepochs"],
    )
    return metrics, gold_set, model

In [10]:
# Free cached GPU memory before starting a new training run.
empty_cache()

# Training data and gold standard for this run, plus its hyper-parameters
# (all taken from the constants defined in the parameters cell).
tpath = OFF_ASS
gs_path = OFF_GS
params = {
    "batch_size": BS_OFF,
    "nepochs": N_EPOCHS_OFF,
    "lr": LR_OFF,
    "sampler": SAMPLER_OFF,
    "drop_duplicates": DROP_DUPLICATES,
}
# Use the configured constant rather than a second hard-coded value.
percent = PERCENT_OF_DATA

paths = tpath, gs_path
# Experiment name = training file name without "_train" and ".csv".
exp_name = os.path.basename(tpath).replace("_train", "").replace(".csv", "")
p, dataset_pred, clf = process(params, paths, percent)
p["exp_name"] = exp_name
p["percent_of_data"] = percent


Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag : ▁x plique ▁un ▁de ▁ses ▁amis . not found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  c

Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  correctly found
Tag:  corr

Using gold standard


Some weights of the model checkpoint at camembert-base were not used when initializing CamembertForTokenClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing CamembertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CamembertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of CamembertForTokenClassification were not initialized from the model checkpoint at camembert-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream tas


Training...
  Batch    40  of    197.    Elapsed: 0:00:27.
  Batch    80  of    197.    Elapsed: 0:00:53.
  Batch   120  of    197.    Elapsed: 0:01:20.
  Batch   160  of    197.    Elapsed: 0:01:47.

  Average training loss: 0.13
  Training epoch took: 0:02:11

Running Validation...


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

         0.0       0.99      1.00      0.99    216987
         1.0       0.00      0.00      0.00      2190

    accuracy                           0.99    219177
   macro avg       0.50      0.50      0.50    219177
weighted avg       0.98      0.99      0.99    219177


Training...
  Batch    40  of    197.    Elapsed: 0:00:27.
  Batch    80  of    197.    Elapsed: 0:00:53.
  Batch   120  of    197.    Elapsed: 0:01:20.
  Batch   160  of    197.    Elapsed: 0:01:47.

  Average training loss: 0.08
  Training epoch took: 0:02:11

Running Validation...
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00    216987
         1.0       0.84      0.76      0.80      2190

    accuracy                           1.00    219177
   macro avg       0.92      0.88      0.90    219177
weighted avg       1.00      1.00      1.00    219177


Training...
  Batch    40  of    197.    Elapsed: 0:00:27.


In [11]:
# Save the trained classifier under SAVED_MODELS_PATH/off_ASS.
clf.save(os.path.join(SAVED_MODELS_PATH, 'off_ASS'))

In [12]:
# Display the metrics dictionary returned by process().
p

{'batch_size': 32,
 'lr': 5e-05,
 'sampler': 'random',
 'nepochs': 5,
 'best epoch': 3,
 'random_seed': 210,
 'train_size': 6274,
 'prec_O': 0.9982988022461343,
 'rec_O': 0.9979261430408274,
 'F1_O': 0.9981124378591035,
 'supp_O': 216987.0,
 'prec_off': 0.8018494055482166,
 'rec_off': 0.8315068493150685,
 'F1_off': 0.8164088769334229,
 'supp_off': 2190.0,
 'prec_char_O': 0.9948128380626766,
 'rec_char_O': 0.9927885632799502,
 'F1_char_O': 0.9937996698588998,
 'supp_char_O': 307151.0,
 'prec_char_off': 0.7795800577171857,
 'rec_char_off': 0.8312818336162988,
 'F1_char_off': 0.8046012427463666,
 'supp_char_off': 9424.0,
 'prec_span_25': 0.7981927710843374,
 'rec_span_25': 0.9298245614035088,
 'F1_span_25': 0.8589951377633712,
 'prec_span_T4_25': 0.8259493670886076,
 'rec_span_T4_25': 0.9157894736842105,
 'F1_span_T4': 0.8685524126455906,
 'prec_span_50': 0.7831325301204819,
 'rec_span_50': 0.9122807017543859,
 'F1_span_50': 0.8427876823338736,
 'prec_span_T4_50': 0.810126582278481,
 'rec