In [None]:
# Relative paths to the train and test CSV files loaded below.
TRAIN_PATH = '../../data/documents_functions_data/document_function_train.csv'
TEST_PATH = '../../data/documents_functions_data/document_function_test.csv'

In [None]:
import pandas as pd
import plotly.express as px
import spacy
from flair.data import Corpus
from flair.data import Sentence
from flair.datasets import SentenceDataset
from flair.models import TARSClassifier
from flair.trainers import ModelTrainer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import accuracy_score
from sklearn.metrics import plot_confusion_matrix, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder

In [None]:
# Read both CSV splits, then pull out the text column (model input) and the
# label column (prediction target) of each.
df, df_test = (pd.read_csv(csv_path) for csv_path in (TRAIN_PATH, TEST_PATH))

X_train, y_train = df["title"], df["documentFunction"]
X_test, y_test = df_test["title"], df_test["documentFunction"]


In [None]:
# Class distribution of the training labels.
px.histogram(y_train)

In [None]:
# Show the training titles (the text input used by every model in this notebook).
df.title

## Baseline: majority class (argmax)

In [None]:
# Majority-class baseline: always predict the most frequent training label.
# The label is derived from the data instead of being hard-coded, so the
# baseline stays correct if the class distribution of the training set changes.
majority_label = df.documentFunction.mode().iloc[0]
accuracy_score(df.documentFunction, [majority_label] * len(df))

## Baseline: logistic regression on BoW

In [None]:
def prep_lemmas(df):
    """Turn every title in ``df`` into a space-joined string of lemmas.

    Each title is run through the module-level ``nlp`` callable; tokens flagged
    as stop words are dropped and the remaining lemmas are lower-cased.

    Args:
        df: frame with a ``title`` column of strings.

    Returns:
        Series with one lemma string per title, indexed like ``df.title``.
    """
    def _lemma_string(title):
        kept = [token.lemma_.lower() for token in nlp(title) if not token.is_stop]
        return ' '.join(kept)

    return df.title.apply(_lemma_string)

# spaCy pipeline that `prep_lemmas` looks up as the global `nlp`.
nlp = spacy.load('en_core_web_sm')
X = prep_lemmas(df)

# Build the estimator once and hand that same instance to the pipeline, rather
# than constructing a second, identical LogisticRegression that is never used.
classifier = LogisticRegression(penalty="l2", multi_class="multinomial")

pipeline = Pipeline(
    [
        ("preprocessing", CountVectorizer()),
        ("classifier", classifier),
    ]
)

# Bag-of-words counts over the lemmatised titles -> multinomial logistic regression.
pipeline = pipeline.fit(X, df.documentFunction)

#### Train acc

In [None]:
# Score the fitted pipeline on the lemmatised training titles held in `X`.
y_pred = pipeline.predict(X)
accuracy_score(y_true=y_train, y_pred=y_pred)

#### Test acc

In [None]:
# Use dedicated names for the test split: rebinding `X` / `y_pred` here would
# make the train-accuracy cell score the wrong features if it were re-run
# after this cell.
X_test_lemmas = prep_lemmas(df_test)

y_pred_test = pipeline.predict(X_test_lemmas)
accuracy_score(y_test, y_pred_test)

## TARS: zero shot (TODO)

In [None]:
tars = TARSClassifier.load('tars-base')
# Candidate labels for zero-shot prediction: the distinct document functions
# that occur in the test split.
classes = list(df_test.documentFunction.drop_duplicates())


def predict_zero_shot(titles):
    """Predict one document function per title with the zero-shot TARS model.

    Args:
        titles: iterable of title strings (e.g. a pandas Series).

    Returns:
        pandas Series of predicted class names, aligned with the index of
        ``titles`` when it has one. Entries are ``None`` where the model
        assigned no label.
    """
    prepared = [Sentence(title) for title in titles]
    if prepared:
        # Predictions are attached to the Sentence objects in place;
        # multi_label=False requests a single class per title.
        tars.predict_zero_shot(prepared, classes, multi_label=False)
    values = [sent.labels[0].value if sent.labels else None for sent in prepared]
    return pd.Series(values, index=getattr(titles, "index", None), dtype=object)


# Smoke test on a single training title; the labelled sentence is the cell output.
# TODO: evaluate predict_zero_shot on the test titles.
s = Sentence(df.title[1])
tars.predict_zero_shot(s, classes)
s

## TARS: few shot

In [None]:
def create_set(sentences: list, labels: list, label_name: str):
    """Wrap parallel texts and labels into a ``SentenceDataset``.

    Args:
        sentences: texts, one per example.
        labels: label values, paired with ``sentences`` by position.
        label_name: label type under which each label is attached.

    Returns:
        ``SentenceDataset`` built from one labelled ``Sentence`` per pair.
    """
    labelled = []
    for text, tag in zip(sentences, labels):
        labelled.append(Sentence(text).add_label(label_name, tag))
    return SentenceDataset(labelled)

# Label type shared by both datasets, the label dictionary and the TARS task.
label_type = 'Document function'

train_set = create_set(X_train, y_train, label_type)
test_set = create_set(X_test, y_test, label_type)
corpus = Corpus(train=train_set, test=test_set)
label_dict = corpus.make_label_dictionary(label_type=label_type)

# Register the new task on the already loaded TARS model and make it active.
tars.add_and_switch_to_new_task(
    "Document function classification",
    label_dictionary=label_dict,
    label_type=label_type,
)

trainer = ModelTrainer(tars, corpus)
trainer.train(
    base_path='resources/taggers/policy_doc_class',  # path to store the model artifacts
    learning_rate=0.02,  # use very small learning rate
    mini_batch_size=16,  # small mini-batch size since corpus is tiny
    max_epochs=10,  # terminate after 10 epochs
    train_with_dev=True,
)

# Reload the final checkpoint written under the base path above.
trained = TARSClassifier.load('resources/taggers/policy_doc_class/final-model.pt')

In [None]:
# Placeholder for sentences the model leaves unlabelled. It is a string so the
# prediction column keeps a single type alongside the real class names.
NO_PREDICTION = "<no prediction>"

test_sents = X_test.apply(Sentence)
# The list holds the same Sentence objects as `test_sents`, so labels written
# by `predict` are visible through the Series afterwards.
trained.predict(list(test_sents))
# Guard against sentences without any predicted label: indexing `labels[0]`
# unconditionally would raise IndexError for them.
predictions = test_sents.apply(
    lambda sent: sent.labels[0].value if sent.labels else NO_PREDICTION
)

In [None]:
# Fix the class order once so the matrix rows/columns and tick labels agree.
labels = y_test.unique()
matrix_display = ConfusionMatrixDisplay(
    confusion_matrix(y_test, predictions, labels=labels), display_labels=labels
).plot(xticks_rotation="vertical", cmap="Blues", colorbar=False)
# `plot` returns the display object, whose `figure_` attribute is the Matplotlib
# figure. Working through it avoids relying on a `plt` name that this notebook
# never imports.
matrix_display.figure_.tight_layout()
matrix_display.figure_.savefig('confusion_matrix.pdf')

## Bigger labels for TARS (full sentences as labels) (TODO)

## Using first sentences from documents as training data (TODO)

## RoBERTa Baseline (TODO)