In [1]:
from italian_csv_type_prediction.dataframe_generators import SimpleDatasetGenerator
from italian_csv_type_prediction.models import TypePredictor
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
from multiprocessing import Pool, cpu_count
from sklearn.metrics import accuracy_score, balanced_accuracy_score

In [2]:
def dataset_generation(number:int):
    """Build one dataset chunk of the requested size in the current process.

    Thin wrapper around ``SimpleDatasetGenerator.build``; the generator is
    created with ``verbose=False`` (presumably to keep worker processes quiet).

    Parameters
    ----------
    number: int
        Size argument forwarded unchanged to ``SimpleDatasetGenerator.build``.

    Returns
    -------
    Whatever ``SimpleDatasetGenerator.build`` returns (unpacked elsewhere in
    this notebook as a ``(features, labels)`` pair).
    """
    generator = SimpleDatasetGenerator(verbose=False)
    return generator.build(number)

def _dataset_generation(args):
    """Call ``dataset_generation`` with a bundled tuple of positional arguments.

    ``Pool.imap`` hands exactly one object to the worker function per task,
    so the arguments are packed into a tuple by the caller and spread back
    out here.
    """
    positional = tuple(args)
    return dataset_generation(*positional)

def parallel_dataset_generation(number:int):
    """Generate a dataset of total size ``number`` across several worker processes.

    The requested size is split into one chunk per worker, each chunk is built
    by ``dataset_generation`` in its own process, and the partial results are
    joined back together in submission order.

    Parameters
    ----------
    number: int
        Total size to request. It is split into chunks whose sizes sum to
        exactly ``number``; each chunk size is passed to
        ``dataset_generation``. Must be at least 1.

    Returns
    -------
    Tuple of (features, labels): the per-chunk feature arrays stacked with
    ``np.vstack`` and the per-chunk label arrays joined with
    ``np.concatenate``.

    Raises
    ------
    ValueError
        If ``number`` is smaller than 1.
    """
    if number < 1:
        raise ValueError(f"number must be a positive integer, got {number!r}")
    # Never start more workers than there are chunks to build.
    processes = min(cpu_count(), number)
    # Spread the remainder over the first chunks so the chunk sizes add up to
    # exactly `number` (plain floor division would silently drop up to
    # `processes - 1` units).
    base, remainder = divmod(number, processes)
    chunk_sizes = [base + 1 if i < remainder else base for i in range(processes)]
    # NOTE(review): worker functions defined in a notebook can only be shipped
    # to child processes under the "fork" start method (Linux default);
    # confirm behaviour on macOS/Windows ("spawn").
    with Pool(processes) as p:
        Xs, ys = list(zip(*tqdm(
            p.imap(_dataset_generation, ((size, ) for size in chunk_sizes)),
            desc="Creating dataset",
            total=processes,
            leave=False
        )))
    return np.vstack(Xs), np.concatenate(ys)

In [3]:
# Size requested for each of the two generated sets (train and test).
# NOTE(review): no random seed is set here, so the generated data (and the
# metrics computed from it) may differ between runs.
DATASET_SIZE = 1000

x_train, y_train = parallel_dataset_generation(DATASET_SIZE)
x_test, y_test = parallel_dataset_generation(DATASET_SIZE)

HBox(children=(FloatProgress(value=0.0, description='Creating dataset', max=16.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='Creating dataset', max=16.0, style=ProgressStyle(descript…

In [4]:
# Create the type predictor and fit it on the generated training arrays.
model = TypePredictor()

model.fit(x_train, y_train)

In [5]:
# Predict on both the held-out and the training arrays using the wrapped
# estimator stored in the private `_model` attribute.
# NOTE(review): reaching into `_model` presumably bypasses the dataframe-level
# API (`predict_dataframe`) because x_test/x_train are already feature
# arrays — confirm there is no public array-level predict.
y_pred = model._model.predict(x_test)
y_train_pred = model._model.predict(x_train)

In [6]:
# Held-out metrics, displayed as (accuracy, balanced accuracy).
tuple(metric(y_test, y_pred) for metric in (accuracy_score, balanced_accuracy_score))

(0.9985902498749762, 0.9985586708842794)

In [7]:
# Training-set metrics, displayed as (accuracy, balanced accuracy), for
# comparison with the held-out scores.
tuple(metric(y_train, y_train_pred) for metric in (accuracy_score, balanced_accuracy_score))

(0.9996572193810486, 0.9996817463551818)

In [8]:
# Draw one small sample dataframe (max_rows=10) plus its label frame so the
# model's cell-level predictions can be inspected by eye.
X, y = SimpleDatasetGenerator().generate_simple_dataframe(
    max_rows=10,
)

In [9]:
# Label frame returned alongside X by generate_simple_dataframe; compare it
# with the output of model.predict_dataframe(X) in the following cell.
y

Unnamed: 0,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,Name,NameSurname,SurnameName,Surname,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean
0,ItalianVAT,CadastreCode,Document,Tax,,Address,ItalianZIPCode,ProvinceCode,,Municipality,...,Name,NameSurname,SurnameName,Surname,,EMail,PhoneNumber,Date,,
1,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,,...,Name,,SurnameName,,String,EMail,PhoneNumber,,BiologicalSex,Boolean
2,ItalianVAT,CadastreCode,Document,,Plate,,ItalianZIPCode,,Region,,...,Name,Error,SurnameName,,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean
3,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,,ProvinceCode,Region,Municipality,...,Name,,SurnameName,Surname,String,EMail,PhoneNumber,Date,,Boolean
4,,CadastreCode,Error,Tax,,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,Name,NameSurname,,Surname,String,EMail,,Date,BiologicalSex,Boolean
5,ItalianVAT,CadastreCode,Document,Error,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,,NameSurname,SurnameName,,String,EMail,Error,,,Boolean
6,Error,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,,,...,Name,NameSurname,SurnameName,,String,,PhoneNumber,Date,BiologicalSex,Boolean
7,,CadastreCode,Document,,,Address,ItalianZIPCode,ProvinceCode,Error,Municipality,...,Name,NameSurname,SurnameName,Surname,,EMail,PhoneNumber,Date,BiologicalSex,
8,ItalianVAT,CadastreCode,Document,Error,,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,,NameSurname,SurnameName,Surname,String,,,,BiologicalSex,Boolean
9,,CadastreCode,,Tax,Plate,Address,ItalianZIPCode,,Region,Municipality,...,Name,NameSurname,SurnameName,Surname,String,EMail,PhoneNumber,,BiologicalSex,Boolean


In [10]:
# Predict a type label for every cell of X; the result has the same layout as
# the label frame `y` displayed in the previous cell.
model.predict_dataframe(X)

Unnamed: 0,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,Name,NameSurname,SurnameName,Surname,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean
0,ItalianVAT,CadastreCode,Document,Tax,,Address,ItalianZIPCode,ProvinceCode,,Municipality,...,Name,NameSurname,SurnameName,Surname,,EMail,PhoneNumber,Date,,
1,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,,...,Name,,SurnameName,,String,EMail,PhoneNumber,,BiologicalSex,Boolean
2,ItalianVAT,CadastreCode,Document,,Plate,,ItalianZIPCode,,Region,,...,Name,NameSurname,SurnameName,,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean
3,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,,ProvinceCode,Region,Municipality,...,Name,,SurnameName,Surname,String,EMail,PhoneNumber,Date,,Boolean
4,,CadastreCode,Error,Tax,,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,Name,NameSurname,,Surname,String,EMail,,Date,BiologicalSex,Boolean
5,ItalianVAT,CadastreCode,Document,Error,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,,NameSurname,SurnameName,,String,EMail,Error,,,Boolean
6,Error,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,,,...,Name,NameSurname,SurnameName,,String,,PhoneNumber,Date,BiologicalSex,Boolean
7,,CadastreCode,Document,,,Address,ItalianZIPCode,ProvinceCode,Error,Municipality,...,Name,NameSurname,SurnameName,Surname,,EMail,PhoneNumber,Date,BiologicalSex,
8,ItalianVAT,CadastreCode,Document,Error,,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,,NameSurname,SurnameName,Surname,String,,,,BiologicalSex,Boolean
9,,CadastreCode,,Tax,Plate,Address,ItalianZIPCode,,Region,Municipality,...,Name,NameSurname,SurnameName,Surname,String,EMail,PhoneNumber,,BiologicalSex,Boolean


In [11]:
from collections import Counter

# Boolean selector for the held-out samples the model got wrong.
mask = y_test != y_pred

# Decode the true and the predicted class ids of those samples back to their
# string type names with the embedder's fitted encoder.
true_labels, predicted_labels = (
    model._embedder._encoder.inverse_transform(labels[mask])
    for labels in (y_test, y_pred)
)

# Frequency of each (true type, predicted type) confusion pair.
Counter(zip(true_labels, predicted_labels))

Counter({('ItalianZIPCode', 'Error'): 20,
         ('Error', 'Boolean'): 1,
         ('Error', 'SurnameName'): 73,
         ('CountryCode', 'NaN'): 62,
         ('NameSurname', 'Error'): 90,
         ('Error', 'Document'): 148,
         ('SurnameName', 'Error'): 118,
         ('Error', 'Surname'): 2,
         ('Error', 'CountryCode'): 22,
         ('Error', 'Tax'): 4,
         ('NaN', 'CountryCode'): 107,
         ('String', 'Error'): 36,
         ('Error', 'NameSurname'): 50,
         ('Error', 'ItalianVAT'): 12,
         ('Name', 'Error'): 20,
         ('Error', 'ProvinceCode'): 26,
         ('Error', 'String'): 7,
         ('Error', 'ItalianZIPCode'): 6,
         ('Error', 'ItalianFiscalCode'): 21,
         ('SurnameName', 'String'): 3,
         ('CountryCode', 'ProvinceCode'): 20,
         ('Error', 'NaN'): 3,
         ('Document', 'Error'): 25,
         ('Date', 'Error'): 6,
         ('Error', 'Name'): 8,
         ('Error', 'Country'): 3,
         ('Surname', 'Error'): 14,
       