In [14]:
from italian_csv_type_prediction.dataframe_generators import SimpleDatasetGenerator
from italian_csv_type_prediction.models import TypePredictor
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
from multiprocessing import Pool, cpu_count
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.ensemble import RandomForestClassifier

In [2]:
def dataset_generation(number:int):
    return SimpleDatasetGenerator().build(number, verbose=False)

def _dataset_generation(args):
    return dataset_generation(*args)

def parallel_dataset_generation(number:int):
    processes = min(cpu_count()*5, number)
    with Pool(cpu_count()) as p:
        Xs, ys = list(zip(*tqdm(
            p.imap(_dataset_generation, (
                (number//processes, )
                for _ in range(processes)
            )),
            desc="Creating dataset",
            total=processes,
            leave=False
        )))
    return np.vstack(Xs), np.concatenate(ys)

In [3]:
x_train, y_train = parallel_dataset_generation(1000)
x_test, y_test = parallel_dataset_generation(1000)

HBox(children=(FloatProgress(value=0.0, description='Creating dataset', max=80.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='Creating dataset', max=80.0, style=ProgressStyle(descript…

In [4]:
model = TypePredictor()

model.fit(x_train, y_train)

In [15]:
forest = RandomForestClassifier()

forest.fit(x_train, y_train)

RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=100,
                       n_jobs=None, oob_score=False, random_state=None,
                       verbose=0, warm_start=False)

In [5]:
y_pred = model._model.predict(x_test)
y_train_pred = model._model.predict(x_train)

In [16]:
y_pred = forest.predict(x_test)
y_train_pred = forest.predict(x_train)

In [17]:
accuracy_score(y_test, y_pred), balanced_accuracy_score(y_test, y_pred)

(0.9939285880950121, 0.9956782806113973)

In [18]:
accuracy_score(y_train, y_train_pred), balanced_accuracy_score(y_train, y_train_pred)

(0.9973542309124945, 0.9986382891993849)

In [8]:
X, y = SimpleDatasetGenerator().generate_simple_dataframe()

In [9]:
y

Unnamed: 0,ItalianFiscalCode,CadastreCode,Document,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,Year,...,Name,Surname,String,EMail,PhoneNumber,Currency,Date,BiologicalSex,Boolean,NumericId
0,ItalianFiscalCode,CadastreCode,Document,Error,Address,ItalianZIPCode,ProvinceCode,Region,Error,Error,...,Name,,String,EMail,PhoneNumber,Currency,,BiologicalSex,Boolean,NumericId
1,ItalianFiscalCode,Error,Error,Plate,,ItalianZIPCode,,Region,Municipality,,...,Name,Error,String,EMail,PhoneNumber,Currency,Date,BiologicalSex,Boolean,
2,ItalianFiscalCode,,,,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,,...,Name,Surname,,EMail,PhoneNumber,Currency,Date,,Boolean,NumericId
3,ItalianVAT,CadastreCode,Document,Plate,Address,ItalianZIPCode,Error,,Municipality,Year,...,,,Error,Error,PhoneNumber,Currency,,BiologicalSex,Boolean,NumericId
4,ItalianFiscalCode,,,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,Year,...,,,,,PhoneNumber,Currency,Date,,,
5,ItalianVAT,CadastreCode,Document,Error,Address,ItalianZIPCode,ProvinceCode,Region,Error,,...,Name,Surname,String,EMail,PhoneNumber,Error,Date,BiologicalSex,Boolean,NumericId
6,ItalianFiscalCode,CadastreCode,,Plate,Address,Error,,Region,Municipality,,...,Name,Surname,Error,EMail,PhoneNumber,Currency,,BiologicalSex,Boolean,NumericId
7,ItalianVAT,CadastreCode,Document,Plate,Address,,ProvinceCode,,Municipality,Year,...,Name,,String,EMail,PhoneNumber,Currency,Date,Error,Boolean,NumericId
8,ItalianVAT,CadastreCode,,Error,Address,ItalianZIPCode,ProvinceCode,,Error,,...,,Surname,Error,,PhoneNumber,Currency,Date,Error,Boolean,NumericId
9,ItalianVAT,,Document,Plate,Address,ItalianZIPCode,,,Municipality,Year,...,,,,EMail,PhoneNumber,,Date,BiologicalSex,Boolean,NumericId


In [10]:
(model.predict_dataframe(X) == y).any()

ItalianFiscalCode    True
CadastreCode         True
Document             True
Plate                True
Address              True
ItalianZIPCode       True
ProvinceCode         True
Region               True
Municipality         True
Year                 True
Integer              True
Float                True
Country              True
CountryCode          True
Name                 True
Surname              True
String               True
EMail                True
PhoneNumber          True
Currency             True
Date                 True
BiologicalSex        True
Boolean              True
NumericId            True
dtype: bool

In [19]:
from collections import Counter

mask = y_test != y_pred

true_labels = model._embedder._encoder.inverse_transform(y_test[mask])
predicted_labels = model._embedder._encoder.inverse_transform(y_pred[mask])

Counter(zip(true_labels, predicted_labels))

Counter({('Error', 'ProvinceCode'): 65,
         ('Error', 'Name'): 537,
         ('Error', 'Surname'): 445,
         ('Error', 'String'): 834,
         ('Error', 'Address'): 90,
         ('Error', 'NumericId'): 143,
         ('Surname', 'Error'): 113,
         ('ItalianZIPCode', 'Error'): 9,
         ('String', 'Error'): 499,
         ('Error', 'CountryCode'): 86,
         ('CountryCode', 'NaN'): 102,
         ('String', 'Name'): 1,
         ('EMail', 'Error'): 11,
         ('Error', 'Document'): 145,
         ('Error', 'ItalianZIPCode'): 38,
         ('Name', 'Error'): 133,
         ('Error', 'Integer'): 155,
         ('ProvinceCode', 'Error'): 8,
         ('Boolean', 'Error'): 18,
         ('ItalianVAT', 'Error'): 20,
         ('ItalianFiscalCode', 'Error'): 13,
         ('Municipality', 'Error'): 4,
         ('Plate', 'Error'): 14,
         ('Error', 'ItalianVAT'): 24,
         ('Currency', 'Error'): 20,
         ('Integer', 'Error'): 19,
         ('CountryCode', 'Error'): 6,
     