In [1]:
from italian_csv_type_prediction.dataframe_generators import SimpleDatasetGenerator
from italian_csv_type_prediction.models import TypePredictor
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
from multiprocessing import Pool, cpu_count
from sklearn.metrics import accuracy_score, balanced_accuracy_score

In [2]:
def dataset_generation(number:int):
    """Build a dataset with a fresh ``SimpleDatasetGenerator``.

    Parameters
    ----------
    number: int
        Forwarded unchanged as the first argument of
        ``SimpleDatasetGenerator.build``.

    Returns
    -------
    Whatever ``SimpleDatasetGenerator.build`` returns; the caller in this
    notebook unpacks it as an ``(X, y)`` pair.
    """
    # A new generator is created on every call, so each worker process
    # gets its own instance; verbose output is always disabled here.
    generator = SimpleDatasetGenerator()
    return generator.build(number, verbose=False)

def _dataset_generation(args):
    """Adapter that lets ``dataset_generation`` be used with ``Pool.imap``.

    ``imap`` hands each task a single object, so the positional arguments
    travel as one iterable (a 1-tuple in this notebook) and are unpacked here.
    """
    positional = tuple(args)
    return dataset_generation(*positional)

def parallel_dataset_generation(number:int):
    """Generate datasets in parallel worker processes and stack the results.

    The requested ``number`` is split into ``min(cpu_count()*5, number)``
    tasks. Each task calls ``dataset_generation`` with its share, and the
    shares always sum exactly to ``number``: the remainder of the division
    is spread one-by-one over the first tasks instead of being dropped.

    Parameters
    ----------
    number: int
        Total amount to request from ``dataset_generation`` across all
        tasks. Must be strictly positive.

    Returns
    -------
    Tuple ``(X, y)`` where ``X`` is ``np.vstack`` of the per-task feature
    blocks and ``y`` is ``np.concatenate`` of the per-task label blocks.

    Raises
    ------
    ValueError
        If ``number`` is not strictly positive (there would be nothing to
        generate and nothing to stack).
    """
    # Fail early with a clear message; previously this surfaced as an
    # opaque tuple-unpacking ValueError after a pool had been created.
    if number <= 0:
        raise ValueError(
            "number must be a strictly positive integer, got {!r}".format(number)
        )

    # More tasks than cores so the progress bar advances smoothly, but
    # never more tasks than requested items.
    processes = min(cpu_count()*5, number)

    # Even split: the first `remainder` tasks take one extra item so that
    # sum(chunk_sizes) == number.
    base, remainder = divmod(number, processes)
    chunk_sizes = [
        base + 1 if index < remainder else base
        for index in range(processes)
    ]

    with Pool(cpu_count()) as p:
        # imap preserves task order; zip(*...) transposes the stream of
        # (X, y) pairs into one tuple of Xs and one tuple of ys.
        Xs, ys = list(zip(*tqdm(
            p.imap(_dataset_generation, (
                (size, )
                for size in chunk_sizes
            )),
            desc="Creating dataset",
            total=processes,
            leave=False
        )))
    return np.vstack(Xs), np.concatenate(ys)

In [3]:
# Build the training and the test set with two separate generation runs,
# each requesting 1000 from parallel_dataset_generation.
# NOTE(review): no random seed is set in the visible cells, so the two sets
# (and the metrics below) presumably differ between runs — confirm.
x_train, y_train = parallel_dataset_generation(1000)
x_test, y_test = parallel_dataset_generation(1000)

HBox(children=(FloatProgress(value=0.0, description='Creating dataset', max=80.0, style=ProgressStyle(descript…

HBox(children=(FloatProgress(value=0.0, description='Creating dataset', max=80.0, style=ProgressStyle(descript…

In [4]:
# Instantiate the predictor with its default arguments.
model = TypePredictor()

# Fit on the generated training features and labels.
model.fit(x_train, y_train)

In [5]:
# Predict labels for both the test and the training features so that the
# two scores can be compared in the following cells.
# NOTE(review): this calls the private `_model` attribute directly rather
# than a public TypePredictor method — presumably intentional because the
# features are already embedded; confirm.
y_pred = model._model.predict(x_test)
y_train_pred = model._model.predict(x_train)

In [6]:
# Test-set scores, displayed as the tuple (accuracy, balanced accuracy).
accuracy_score(y_test, y_pred), balanced_accuracy_score(y_test, y_pred)

(0.9868033722001138, 0.988544848702392)

In [7]:
# Training-set scores, displayed as the tuple (accuracy, balanced accuracy),
# for comparison with the test-set scores above.
accuracy_score(y_train, y_train_pred), balanced_accuracy_score(y_train, y_train_pred)

(0.988926406744825, 0.9909282548348708)

In [8]:
# Generate one example dataframe `X` together with `y`, the frame of
# expected type labels for it (displayed in the next cell).
X, y = SimpleDatasetGenerator().generate_simple_dataframe()

In [9]:
# Display the expected label frame of the example dataframe.
y

Unnamed: 0,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,Name,Surname,String,EMail,PhoneNumber,Currency,Date,BiologicalSex,Boolean,NumericId
0,,ItalianVAT,CadastreCode,Document,,Error,,ProvinceCode,Region,,...,,,String,,,Error,Error,,Boolean,NumericId
1,ItalianFiscalCode,ItalianVAT,CadastreCode,,Error,,ItalianZIPCode,ProvinceCode,Error,Municipality,...,Name,Surname,Error,,PhoneNumber,Currency,Date,BiologicalSex,Boolean,NumericId
2,,ItalianVAT,Error,,Plate,Address,,Error,Region,,...,Name,Surname,String,EMail,PhoneNumber,Currency,,,,NumericId
3,ItalianFiscalCode,,,,Error,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,Name,Surname,String,EMail,PhoneNumber,Currency,,Error,Boolean,NumericId
4,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Plate,Address,ItalianZIPCode,,Error,,...,Name,Surname,String,EMail,PhoneNumber,Currency,Date,BiologicalSex,Boolean,NumericId
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
86,ItalianFiscalCode,,Error,,Plate,Address,ItalianZIPCode,Error,,Error,...,Name,Surname,String,EMail,PhoneNumber,Currency,Date,BiologicalSex,,NumericId
87,ItalianFiscalCode,ItalianVAT,CadastreCode,,Plate,Error,ItalianZIPCode,ProvinceCode,Region,Municipality,...,,Surname,,EMail,PhoneNumber,Currency,,,Boolean,NumericId
88,ItalianFiscalCode,ItalianVAT,,,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,,Surname,String,,PhoneNumber,,Date,BiologicalSex,Error,NumericId
89,ItalianFiscalCode,ItalianVAT,CadastreCode,Error,Plate,Address,,ProvinceCode,Region,Municipality,...,,Surname,String,,PhoneNumber,Currency,Error,Error,Boolean,NumericId


In [10]:
# Run the fitted model on the example dataframe; the displayed result can be
# compared visually with the expected labels `y` shown in the previous cell.
model.predict_dataframe(X)

Unnamed: 0,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,Name,Surname,String,EMail,PhoneNumber,Currency,Date,BiologicalSex,Boolean,NumericId
0,,ItalianVAT,CadastreCode,Document,,Error,,ProvinceCode,Region,,...,,,String,,,Error,Error,,Boolean,NumericId
1,ItalianFiscalCode,ItalianVAT,CadastreCode,,Error,,ItalianZIPCode,ProvinceCode,Error,Municipality,...,Name,Surname,Error,,PhoneNumber,Currency,Date,BiologicalSex,Boolean,NumericId
2,,ItalianVAT,Error,,Plate,Address,,Error,Region,,...,Name,Surname,String,EMail,PhoneNumber,Currency,,,,NumericId
3,ItalianFiscalCode,,,,Error,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,Name,Surname,String,EMail,PhoneNumber,Currency,,Error,Boolean,NumericId
4,ItalianFiscalCode,ItalianVAT,CadastreCode,Document,Plate,Address,ItalianZIPCode,,Error,,...,Name,Surname,String,EMail,PhoneNumber,Currency,Date,BiologicalSex,Boolean,NumericId
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
86,ItalianFiscalCode,,Error,,Plate,Address,ItalianZIPCode,Error,,Error,...,Name,Surname,String,EMail,PhoneNumber,Currency,Date,BiologicalSex,,NumericId
87,ItalianFiscalCode,ItalianVAT,CadastreCode,,Plate,Error,ItalianZIPCode,ProvinceCode,Region,Municipality,...,,Surname,,EMail,PhoneNumber,Currency,,,Boolean,NumericId
88,ItalianFiscalCode,ItalianVAT,,,Plate,Address,ItalianZIPCode,ProvinceCode,Region,Municipality,...,,Surname,String,,PhoneNumber,,Date,BiologicalSex,Error,NumericId
89,ItalianFiscalCode,ItalianVAT,CadastreCode,Error,Plate,Address,,ProvinceCode,Region,Municipality,...,,Surname,String,,PhoneNumber,Currency,Error,Error,Boolean,NumericId


In [11]:
# NOTE(review): this import sits in a late cell; consider moving it to the
# import cell at the top of the notebook with the other imports.
from collections import Counter

# Boolean mask selecting the test samples whose prediction is wrong.
mask = y_test != y_pred

# Map the encoded labels of the misclassified samples back to their original
# label values via the embedder's encoder.
# NOTE(review): `_embedder` and `_encoder` are private attributes of the
# predictor — confirm there is no public accessor for this.
true_labels = model._embedder._encoder.inverse_transform(y_test[mask])
predicted_labels = model._embedder._encoder.inverse_transform(y_pred[mask])

# Count how often each (true label, predicted label) confusion occurs.
Counter(zip(true_labels, predicted_labels))

Counter({('Error', 'Address'): 275,
         ('Error', 'CountryCode'): 176,
         ('Error', 'Name'): 1617,
         ('Error', 'Surname'): 1576,
         ('Error', 'String'): 2236,
         ('Error', 'ItalianVAT'): 66,
         ('Municipality', 'Error'): 295,
         ('Year', 'Error'): 190,
         ('Name', 'Error'): 1249,
         ('EMail', 'Error'): 674,
         ('BiologicalSex', 'Error'): 248,
         ('Error', 'BiologicalSex'): 74,
         ('Error', 'CadastreCode'): 283,
         ('Error', 'Document'): 444,
         ('NaN', 'CountryCode'): 222,
         ('Error', 'NumericId'): 163,
         ('ItalianVAT', 'Error'): 29,
         ('Address', 'String'): 106,
         ('Error', 'ProvinceCode'): 128,
         ('Integer', 'ItalianZIPCode'): 38,
         ('String', 'Address'): 149,
         ('String', 'Error'): 311,
         ('EMail', 'Name'): 44,
         ('CountryCode', 'NaN'): 132,
         ('Error', 'Integer'): 294,
         ('ProvinceCode', 'Error'): 307,
         ('Error', 'I