# Comparison - Multiclass vs binary classifier for middlepage classification

In [1]:
import pandas as pd
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, f1_score
from preprocessing.train_val_test_split import get_data_files_df
from document_classification.document_classifier import DocumentClassifier
import matplotlib.pyplot as plt
from prediction.predict import predict_documents
import os

## Load models

In [None]:
# Multiclass document classifier. Its "other" prediction is interpreted as
# "middlepage" in compare_classifiers; every other prediction counts as "firstpage".
multiclass_others_clf = DocumentClassifier.load_from_path(
    "EfficientNetB0",
    # Alternative run (AlexNetDropout architecture), kept for reference:
    # "/data/dssg/occrp/data/output/document_classifier/AlexNetDropout_2022_08_16-10_43_41",
    "/data/dssg/occrp/data/output/document_classifier/EfficientNetB0_2022_08_16-10_54_00",
)

In [None]:
# Binary (first page vs. middle page) classifier used as the baseline in the comparison.
# NOTE(review): the first argument says "EfficientNetB0" but the run directory is
# named "AlexNet_..." — confirm that the architecture name matches the saved run.
binary_clf = DocumentClassifier.load_from_path(
    "EfficientNetB0", "/data/dssg/occrp/data/output/document_classifier/AlexNet_2022_08_12-19_21_46"
)

## Load test sets

In [None]:
# Document categories whose files are loaded for the "page 2" test set.
# "middle-page" and "transcripts" are deliberately left out.
labels_filter = [
    "bank-statements",
    "company-registry",
    "contracts",
    "court-documents",
    "gazettes",
    "invoices",
    # "middle-page",
    "passport-scan",
    "receipts",
    "shipping-receipts",
    # "transcripts",
]


# Build the test set in one chain:
#   1. label every file as a middle page,
#   2. pull the page number (digits right before ".jpg") out of the filename,
#   3. keep only the files whose page number is "2" (string comparison, since
#      str.extract yields strings),
#   4. rename "path" to "directory".
page_2_test_df = (
    get_data_files_df("/data/dssg/occrp/data/processed_clean", labels_filter)
    .assign(
        **{"class": "middlepage"},
        page_number=lambda frame: frame["filename"].str.extract(r"(\d+).jpg"),
    )
    .loc[lambda frame: frame["page_number"] == "2"]
    .reset_index(drop=True)
    .rename(columns={"path": "directory"})
)
page_2_test_df

In [None]:
# Dedicated middle-pages test set: every file found under the "middlepages" label
# is tagged as class "middlepage". Unlike the page-2 set, no filter on the page
# number is applied here; the number is only extracted from the filename.
middlepages_test_df = (
    get_data_files_df(
        "/data/dssg/occrp/data/processed_firstpages_vs_middle_pages/processed_clean/",
        ["middlepages"],
    )
    .assign(
        **{"class": "middlepage"},
        page_number=lambda frame: frame["filename"].str.extract(r"(\d+).jpg"),
    )
    .rename(columns={"path": "directory"})
)
middlepages_test_df

## Compare

In [None]:
def _report_confusion_matrix(prediction_df, title, output_path):
    """Plot, save and show the confusion matrix for one classifier's predictions.

    Args:
        prediction_df: frame with a "true-label" and a "predicted" column.
        title: prefix of the figure title; the accuracy (in percent) is appended.
        output_path: file the figure is written to before it is shown.
    """
    true_labels = prediction_df["true-label"]
    predicted_labels = prediction_df["predicted"]
    acc = accuracy_score(true_labels, predicted_labels)

    ConfusionMatrixDisplay.from_predictions(true_labels, predicted_labels)
    plt.title(f"{title} (acc = {100*acc:.2f}%)")
    # Save before show(): showing the figure may leave nothing left to save.
    plt.savefig(output_path)
    plt.show()


def compare_classifiers(df, output_prefix=""):
    """Compare the binary and the multiclass classifier on a set of middle pages.

    Both module-level classifiers (``binary_clf`` and ``multiclass_others_clf``)
    predict on ``df``. For each of them the accuracy is computed and a confusion
    matrix is plotted, saved as a PNG in the working directory and displayed.

    The multiclass predictions are reduced to the binary problem: "other" becomes
    "middlepage", any other predicted class becomes "firstpage", and the true
    label of every row is forced to "middlepage".
    The binary classifier's "true-label" column is used as returned by
    ``predict_from_df`` (presumably derived from ``df`` — TODO confirm).

    Args:
        df: test set passed unchanged to ``predict_from_df`` of both classifiers.
        output_prefix: string prepended to the two output file names. The default
            keeps the original names; pass a distinct prefix per test set so that
            consecutive calls do not overwrite each other's figures.
    """
    binary_prediction_df = binary_clf.predict_from_df(df)

    multiclass_prediction_df = multiclass_others_clf.predict_from_df(df)
    multiclass_prediction_df["true-label"] = "middlepage"
    multiclass_prediction_df["predicted"] = multiclass_prediction_df["predicted"].map(
        lambda u: "middlepage" if u == "other" else "firstpage"
    )

    _report_confusion_matrix(
        binary_prediction_df,
        "Binary",
        f"{output_prefix}binary_middlepages_confusion_matrix.png",
    )
    _report_confusion_matrix(
        multiclass_prediction_df,
        "Multiclass",
        f"{output_prefix}multiclass_middlepages_confusion_matrix.png",
    )

In [None]:
# Compare both classifiers on the dedicated middle-pages test set.
compare_classifiers(middlepages_test_df)

In [None]:
# Compare both classifiers on the second pages of the first-page document categories.
# Note: this writes to the same PNG file names as the previous call.
compare_classifiers(page_2_test_df)

## Intersection of test sets

In [33]:
def _read_model_test_split(run_dir):
    """Read ``model_inputs/test.txt`` of a training run and tag each row with its parent directory name.

    The file is read as space-separated "<file_path> <doc type>" rows, all as strings.
    The added "dir" column holds the name of the directory that directly contains the file.
    """
    split_df = pd.read_csv(
        f"{run_dir}/model_inputs/test.txt",
        sep=" ",
        names=["file_path", "doc type str"],
        dtype=str,
    )
    split_df["dir"] = split_df["file_path"].apply(lambda u: os.path.basename(os.path.dirname(u)))
    return split_df


def test_full_clf(binary, multiclass, model_name="EfficientNetB4", data_root="/data/dssg/occrp"):
    """Accuracy of ``predict_documents`` on files present in the test splits of both runs.

    The rows of the multiclass run's test split are kept when their parent
    directory name also occurs in the binary run's test split. The kept files are
    predicted with ``predict_documents`` and compared against the multiclass
    run's labels (``model_inputs/labels.csv``, joined on its "index" column).

    Args:
        binary: name of the binary classifier's run directory under
            ``<data_root>/data/output/document_classifier``.
        multiclass: name of the multiclass classifier's run directory in the same place.
        model_name: second argument passed to ``predict_documents``
            (presumably the model architecture — TODO confirm). Defaults to the
            previously hard-coded "EfficientNetB4".
        data_root: directory prepended to the file paths of the test split and
            under which the run directories live. Defaults to the previously
            hard-coded "/data/dssg/occrp".

    Returns:
        The accuracy score of the predictions against the "label" column.
    """
    runs_dir = f"{data_root}/data/output/document_classifier"
    binary_test_df = _read_model_test_split(f"{runs_dir}/{binary}")
    multiclass_test_df = _read_model_test_split(f"{runs_dir}/{multiclass}")
    multiclass_labels_df = pd.read_csv(f"{runs_dir}/{multiclass}/model_inputs/labels.csv")

    # drop=True: do not keep the old positions as an "index" column, which would
    # clash with the "index" key column of the labels table in the merge below.
    in_both_splits = multiclass_test_df["dir"].isin(binary_test_df["dir"])
    test_intersection_df = multiclass_test_df[in_both_splits].reset_index(drop=True)
    print("test set size", len(test_intersection_df))
    test_intersection_df["file_path"] = data_root + test_intersection_df["file_path"]

    test_intersection_df["prediction"] = predict_documents(test_intersection_df["file_path"].to_list(), model_name)
    # Keep only the first element of each prediction (presumably the predicted label — TODO confirm).
    test_intersection_df["prediction"] = test_intersection_df["prediction"].apply(lambda u: u[0])

    # The split file was read as strings; the labels table is joined on the integer value.
    test_intersection_df["doc type str"] = test_intersection_df["doc type str"].astype("int")
    test_intersection_df = test_intersection_df.merge(
        multiclass_labels_df, left_on="doc type str", right_on="index", how="left"
    )

    return accuracy_score(test_intersection_df["label"], test_intersection_df["prediction"])

In [34]:
# Accuracy on the files shared by the test splits of the two EfficientNetB4 runs.
test_full_clf("EfficientNetB4_2022_08_17-13_58_56", "EfficientNetB4_2022_08_16-11_34_48")

test set size 55
55




Found 55 validated image filenames.
Found 54 validated image filenames.






0.7636363636363637

In [None]:
# Same evaluation, using the test split of the EfficientNetB4BW run for the intersection.
# NOTE(review): test_full_clf passes "EfficientNetB4" to predict_documents regardless
# of the run names given here — confirm that this is the intended model.
test_full_clf("EfficientNetB4BW_2022_08_17-15_55_41", "EfficientNetB4_2022_08_16-11_34_48")