In [1]:
from collections import Counter
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from tqdm.auto import tqdm

from italian_csv_type_prediction.dataframe_generators import SimpleDatasetGenerator
from italian_csv_type_prediction.models import TypePredictor

In [2]:
def dataset_generation(number:int):
    """Build one dataset chunk with a non-verbose ``SimpleDatasetGenerator``.

    Parameters
    ----------
    number: int
        Size argument forwarded unchanged to ``SimpleDatasetGenerator.build``.

    Returns
    -------
    Whatever ``SimpleDatasetGenerator.build`` returns for ``number``.
    """
    quiet_generator = SimpleDatasetGenerator(verbose=False)
    return quiet_generator.build(number)

def _dataset_generation(args):
    """Adapter that lets ``Pool.imap`` call ``dataset_generation``.

    ``imap`` hands each task to the worker as a single object, so the
    positional arguments travel packed together in ``args`` and are
    unpacked here before the real call.
    """
    positional = tuple(args)
    return dataset_generation(*positional)

def parallel_dataset_generation(number:int, chunks:int, processes:int):
    """Generate a dataset in parallel and merge the per-chunk results.

    The requested ``number`` is split into ``chunks`` shares that differ by at
    most one, so the shares always sum to exactly ``number`` (a plain
    ``number//chunks`` per chunk would silently drop the remainder). Each share
    is passed to ``_dataset_generation`` in a worker process.

    Parameters
    ----------
    number: int
        Total size to generate across all chunks. Must be at least 1.
    chunks: int
        Number of tasks to split the work into. Must be at least 1; it is
        capped at ``number`` so that no task is asked for a size of zero.
    processes: int
        Upper bound on the number of worker processes. It is clamped to the
        range ``[1, chunks]``: a pool needs at least one worker and more
        workers than tasks would just sit idle.

    Returns
    -------
    Tuple ``(X, y)`` where ``X`` is the ``np.vstack`` of the per-chunk first
    elements and ``y`` the ``np.concatenate`` of the per-chunk second elements.

    Raises
    ------
    ValueError
        If ``number`` or ``chunks`` is smaller than 1.
    """
    if number < 1:
        raise ValueError(
            "number must be at least 1, got {}".format(number)
        )
    if chunks < 1:
        raise ValueError(
            "chunks must be at least 1, got {}".format(chunks)
        )

    # Never create more tasks than there are units of work to hand out.
    chunks = min(chunks, number)
    # Pool rejects a worker count below 1 (e.g. cpu_count()//2 on one core).
    processes = max(1, min(processes, chunks))

    # Spread the remainder over the first `extra` chunks so nothing is lost.
    base, extra = divmod(number, chunks)
    sizes = [
        base + 1 if index < extra else base
        for index in range(chunks)
    ]

    with Pool(processes) as p:
        Xs, ys = zip(*tqdm(
            p.imap(_dataset_generation, (
                (size, )
                for size in sizes
            )),
            desc="Creating dataset",
            total=chunks,
            leave=False
        ))
    return np.vstack(Xs), np.concatenate(ys)

In [None]:
# Use half of the available cores, but never fewer than one: on a single-core
# machine cpu_count()//2 evaluates to 0, which would make both `processes` and
# `chunks` zero and the worker pool could not be created.
processes = max(1, cpu_count()//2)
# Six tasks per worker process.
chunks = processes*6
# Train and test sets are generated independently by the same generator.
x_train, y_train = parallel_dataset_generation(10000, chunks, processes)
x_test, y_test = parallel_dataset_generation(1000, chunks, processes)

HBox(children=(IntProgress(value=0, description='Creating dataset', max=36, style=ProgressStyle(description_wi…



In [12]:
# Instantiate the type predictor with its default settings.
model = TypePredictor()

# Train it on the generated training split.
model.fit(x_train, y_train)

In [13]:
# Predict with the estimator wrapped by the TypePredictor, on both splits, so
# that test and train scores can be compared in the following cells.
# NOTE(review): `_model` is a private attribute -- confirm whether TypePredictor
# exposes a public prediction method that accepts these arrays.
wrapped_estimator = model._model
y_pred = wrapped_estimator.predict(x_test)
y_train_pred = wrapped_estimator.predict(x_train)

In [14]:
# Test-set scores, displayed as (accuracy, balanced accuracy).
test_accuracy = accuracy_score(y_test, y_pred)
test_balanced_accuracy = balanced_accuracy_score(y_test, y_pred)
test_accuracy, test_balanced_accuracy

(0.9921307895409784, 0.9934967586922895)

In [15]:
# Training-set scores, displayed as (accuracy, balanced accuracy).
train_accuracy = accuracy_score(y_train, y_train_pred)
train_balanced_accuracy = balanced_accuracy_score(y_train, y_train_pred)
train_accuracy, train_balanced_accuracy

(0.9929827277066161, 0.9942316356903012)

In [16]:
# Draw one small example frame (max_rows=20) together with its label frame,
# used below for a visual comparison against the model's predictions.
sample_generator = SimpleDatasetGenerator()
X, y = sample_generator.generate_simple_dataframe(max_rows=20)

In [17]:
# Display the label frame that the generator returned alongside X.
y

Unnamed: 0,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
0,ItalianFiscalCode,,CadastreCode,Document,Error,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
1,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,,Error,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
2,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Error,Error,
3,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Error,Name,,SurnameName,NameSurname
4,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,,,Region,...,String,EMail,,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
5,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Error,BiologicalSex,Boolean,,Surname,Error,NameSurname
6,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,,Name,Surname,,NameSurname
7,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
8,ItalianFiscalCode,Error,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
9,ItalianFiscalCode,ItalianVAT,CadastreCode,,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,,PhoneNumber,Error,BiologicalSex,Boolean,Name,Surname,,NameSurname


In [18]:
# Display the model's predictions for the sample frame X, to be compared by eye
# with the label frame `y` displayed in the previous cell.
model.predict_dataframe(X)

Unnamed: 0,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
0,ItalianFiscalCode,,CadastreCode,Document,Error,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
1,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
2,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Company,Company,
3,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Error,Name,,SurnameName,NameSurname
4,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,,,Region,...,String,EMail,,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
5,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Error,BiologicalSex,Boolean,,Surname,Company,NameSurname
6,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,,Name,Surname,,NameSurname
7,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
8,ItalianFiscalCode,Error,CadastreCode,Document,Tax,Plate,Address,ItalianZIPCode,,Region,...,String,EMail,PhoneNumber,Date,BiologicalSex,Boolean,Name,Surname,SurnameName,NameSurname
9,ItalianFiscalCode,ItalianVAT,CadastreCode,,Tax,Plate,Address,ItalianZIPCode,ProvinceCode,Region,...,String,,PhoneNumber,Error,BiologicalSex,Boolean,Name,Surname,,NameSurname


In [19]:
# Tally every test-set misclassification as a (true label, predicted label)
# pair. `Counter` is imported in the notebook's import cell at the top rather
# than here, so all dependencies are declared in one place.
mask = y_test != y_pred

# Translate the encoded targets back to readable type names using the encoder
# held by the model's embedder.
# NOTE(review): `_embedder` / `_encoder` are private attributes -- confirm
# whether a public accessor exists.
label_encoder = model._embedder._encoder
true_labels = label_encoder.inverse_transform(y_test[mask])
predicted_labels = label_encoder.inverse_transform(y_pred[mask])

Counter(zip(true_labels, predicted_labels))

Counter({('ItalianZIPCode', 'Error'): 32,
         ('Error', 'Company'): 2983,
         ('Error', 'Document'): 618,
         ('Error', 'ProvinceCode'): 26,
         ('NaN', 'Company'): 142,
         ('Company', 'Error'): 452,
         ('NaN', 'Error'): 1,
         ('Error', 'ItalianZIPCode'): 118,
         ('Error', 'CountryCode'): 30,
         ('NaN', 'CountryCode'): 38,
         ('Company', 'NaN'): 233,
         ('Error', 'Date'): 2,
         ('Error', 'PhoneNumber'): 38,
         ('Integer', 'NaN'): 4,
         ('Float', 'NaN'): 3,
         ('CountryCode', 'NaN'): 20,
         ('ItalianVAT', 'Error'): 2,
         ('CountryCode', 'Company'): 3,
         ('Error', 'String'): 42,
         ('PhoneNumber', 'String'): 5,
         ('NaN', 'Float'): 2,
         ('Float', 'Integer'): 62,
         ('SurnameName', 'Company'): 3,
         ('Error', 'Integer'): 1,
         ('ItalianZIPCode', 'Integer'): 13,
         ('Error', 'Municipality'): 5,
         ('Company', 'String'): 54,
         ('Nam