In [3]:
import tokenizers
import transformers
from transformers import BertTokenizer, BertForMaskedLM
import sklearn
from sklearn.metrics import classification_report
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MinMaxScaler
from transformers.pipelines import pipeline
from sklearn.linear_model import Ridge
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm


In [4]:
train_df = pd.read_csv("C:/Users/bergo/OneDrive - University of Pisa/Tesi Magistrale/probing_data/probing_train.csv")
val_df = pd.read_csv("C:/Users/bergo/OneDrive - University of Pisa/Tesi Magistrale/probing_data/probing_test.csv")

In [5]:
train_dict = train_df[:500].to_dict(orient='records')
val_dict = val_df[:250].to_dict(orient='records')

In [6]:
train_dict[0]

{'id': 'isdt_tut-689',
 'sent': "Il diritto dell'usufruttuario non si estende al tesoro che si scopra durante l'usufrutto, salve le ragioni che gli possono competere come ritrovatore (932).",
 'category': 8,
 'n_tokens': 31,
 'char_per_tok': 4.81481481481482,
 'upos_dist_DET': 16.1290322580645,
 'upos_dist_ADV': 3.2258064516129,
 'upos_dist_PUNCT': 12.9032258064516,
 'upos_dist_NUM': 3.2258064516129,
 'upos_dist_PRON': 16.1290322580645,
 'upos_dist_ADP': 12.9032258064516,
 'upos_dist_PROPN': 0.0,
 'upos_dist_ADJ': 3.2258064516129,
 'upos_dist_VERB': 9.67741935483871,
 'upos_dist_NOUN': 19.3548387096774,
 'upos_dist_CCONJ': 0.0,
 'upos_dist_AUX': 3.2258064516129,
 'avg_links_len': 2.42307692307692,
 'max_links_len': 11,
 'avg_max_depth': 5,
 'dep_dist_obj': 0.0,
 'dep_dist_nsubj': 12.9032258064516,
 'subj_pre': 75.0,
 'subj_post': 25.0,
 'n_prepositional_chains': 1,
 'avg_prepositional_chain_len': 1.0,
 'avg_subordinate_chain_len': 1.0,
 'subordinate_proposition_dist': 66.6666666666667,

In [8]:
#funzione per ottenere gli embedding delle frasi.
def feature_extraction(samples, model_name):
    first_layer = 1
    last_layer = 8
    tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-italian-cased") #come tokenizzatore utilizziamo il solito giusto (bert-base-italian-uncased)? 
    model = BertForMaskedLM.from_pretrained(model_name)
    for sample in tqdm(samples, desc="Estrazione features", unit="sample"):
        encoded_sen = tokenizer(sample["sent"], padding=True, truncation=True, max_length=128, return_tensors='pt') 
        with torch.no_grad():    
            model_output = model(**encoded_sen, output_hidden_states=True)
            hidden_states = model_output.hidden_states
            for layer in range(first_layer, last_layer+1):
                layer_output = torch.squeeze(hidden_states[layer])
                cls_embedding = layer_output[0, :].cpu().detach().numpy()
                sample[f'layer_{layer}'] = {'cls_embedding': cls_embedding}
    return samples

#funzione per ottenere features e lables
def get_features_lables(samples, feature, layer):
    X = []
    y = []
    for sample in samples:
        X = sample[layer]["cls_embedding"]
        y = sample[feature]
    return X, y

#funzione per addestrare e valutare il modello
def train_eval(train_set, val_set, feature, layer):
    scaler = MinMaxScaler()
    X_train, y_train = get_features_lables(train_set, feature, layer)
    X_val, y_val = get_features_lables(val_set, feature, layer)
    scaled_X_train = scaler.fit_transform(X_train)
    scaled_X_val = scaler.transform(X_val)
    clf = sklearn.linear_model.Ridge(alpha=1.0)
    clf.fit(scaled_X_train, y_train)
    y_pred = clf.predict(scaled_X_val)
    return classification_report(y_val, y_pred, output_dict=True) 

In [14]:
final_model_path =  f"C:/Users/bergo/OneDrive - University of Pisa/Tesi Magistrale/models/ANTI_CURRICULUM/final_pretrained_model"
samples = feature_extraction(val_dict, final_model_path)

Estrazione features: 100%|██████████| 250/250 [00:09<00:00, 26.22sample/s]


In [9]:
# DEVO ITERARE ANCHE SUI CHECKPOINT CHE USIAMO
checkpoints = [2, 32, 512, 8192, 0] 
ling_features = ["n_tokens", "upos_dist_ADP", "char_per_tok"]

In [10]:
def probing_checkpoints(checkpoints, training_id, train_dict, val_dict, ling_features):
    first_layer = 1
    last_layer = 8
    results = dict()
    for n_step in checkpoints:
        checkpoint_name = f'checkpoint-step{n_step}'
        if n_step == 0:
            model_name = f"C:/Users/bergo/OneDrive - University of Pisa/Tesi Magistrale/models/{training_id}/final_pretrained_model"
            print("Inizio probing per il modello finale")
        else:
            model_name = f"C:/Users/bergo/OneDrive - University of Pisa/Tesi Magistrale/models/{training_id}/checkpoints/{checkpoint_name}"
            print(f"Inizio probing per il checkpoint {n_step}")
        
        print("Estrazione delle feature di training...")
        train_samples = feature_extraction(train_dict, model_name)  #si effettua l'estrazione delle feature per il checkpoint
        print("Estrazione delle features di validation...")
        val_samples = feature_extraction(val_dict, model_name)   
        for ling_feature in ling_features:              
            print(f'Addestramento del modello sulla feature linguistica: {ling_feature}') 
            layer_results = dict()
            feature_result = dict()
            for layer in range(first_layer, last_layer+1):
                print(f"Training for layer {layer}/{last_layer}")
                layer_result = train_eval(train_samples, val_samples, ling_feature, f'layer_{layer}')   #addestriamo il ridge per ogni layer ottenendo le metriche
                layer_results[f'layer_{layer}'] = layer_result
            feature_result[ling_feature] = {"results": layer_results}
            results[f"checkpoint{n_step}"] = feature_result
    return results
        

In [11]:
results = probing_checkpoints(checkpoints, "ANTI_CURRICULUM", train_dict, val_dict, ling_features)

Inizio probing per il checkpoint 2
Estrazione delle feature di training...


Estrazione features: 100%|██████████| 500/500 [00:21<00:00, 22.81sample/s]


Estrazione delle features di validation...


Estrazione features: 100%|██████████| 250/250 [00:11<00:00, 21.55sample/s]


Addestramento del modello sulla feature linguistica: n_tokens
Training for layer 1/8


ValueError: Expected 2D array, got 1D array instead:
array=[-2.1315079   2.5852509   0.8696965  -0.05993515  1.0990232   0.25199714
  1.0477796   0.6888469  -0.3031923   0.57987154  1.1484693   0.01575219
 -0.19587785  0.8250259  -0.28469267  1.1389377   0.2756512   1.8663802
 -0.53986233  0.6415098   0.34137812 -0.3464678   0.12808228 -1.4836715
  2.1205585   0.01961643  1.8258386   0.17781517 -0.6131355   0.7713341
  0.7724039  -0.4233626  -0.46201268  0.31877062  1.3882495   0.0500122
 -0.5601297  -0.31841275 -0.9053752   0.16928726 -1.3684188  -1.2948371
 -1.0609267   0.9092711   0.69348264 -0.02238305  0.27656123 -1.9663436
  1.1278688  -0.19321296  0.6869232  -1.4053665   0.6567181  -0.1321012
  0.2503742  -0.41009977 -0.6352198   1.3595886  -1.0228201  -0.1098899
  1.5206287  -0.24586315  0.5683745   1.3556516  -1.4159044   0.6217237
  0.6803576   0.17879528 -1.2983041  -0.32037452  0.61688304 -0.03982257
  0.2194626  -0.86114943 -1.3924428   0.54315287  1.1556351  -0.5546677
 -1.4090166   1.0658685  -0.3178617   1.0904107   0.772633   -0.27664983
  0.48406556 -0.17786855 -0.5552606  -0.02859019 -0.13913646 -0.370938
 -0.5773676  -1.2583542   0.51607805  0.5681857  -2.824658    0.8460992
  0.7603097   0.28983587 -1.9392209   0.2712168   1.1557989  -0.55029196
  1.0123689   0.06755079 -0.85215044  1.6900904   0.60460097 -1.5209166
 -0.36043766 -1.3951175   0.41214955  0.19910051  0.13087957  0.7612219
  0.36053273 -1.7659683  -0.6871994   0.74903375  1.0689613  -0.88474464
  0.8538169   0.8749668   0.43222013  0.73958707 -0.1142486  -0.74093926
 -2.494575    0.01482429  0.46978465 -0.6193824  -0.27584842  1.0859346
  0.71510965 -1.2793081  -0.02888424 -0.5597114  -0.51816374 -0.6552042
  0.02585092 -1.0246291  -0.7863578  -0.714702    1.1415294  -2.447301
 -0.65717787 -3.0756285   0.13754526 -0.33781958  1.6758555  -0.06363913
  1.1446282  -0.20447041 -1.4962667  -0.01498348  0.4629819  -1.2987927
 -0.6306813  -1.2097439  -0.93509805 -0.36197978 -1.5682865   0.4663133
 -0.14535345  0.28429675 -0.20041187  0.3192833   0.8618528   0.8399709
  0.5730316  -2.0949395  -0.7641507  -0.35881627 -0.19166067 -1.5288199
 -1.0515965   0.32365137  1.2290727  -1.5471276   1.7798424  -0.57943594
  1.239905   -0.78818196  0.2934345  -0.77484107 -1.9284226  -0.6072789
 -1.0302314  -0.92619336  0.7389993  -0.99731266  0.6872716  -0.23841058
 -0.9451939   0.05250425 -0.26899523  1.015855    2.235073    0.52772677
  1.2305936  -0.675693    0.13309199  0.34937277 -0.39567295  2.3384748
  0.43896678 -0.48593757  0.79565966  0.23780777 -0.5198643  -1.4581479
 -0.4991036  -0.55402845  1.6768126   0.67543864  0.2901484  -0.28103185
  2.0802653   1.5713841   0.8643136   1.3073233  -1.0909017  -0.23238957
  1.0269688  -1.1977359   0.95418036  0.8490027   1.0432409   1.1800563
 -1.2815168   1.213067   -1.1532068  -0.1866039  -1.9774449   0.87895215
  0.24810779 -0.20059954 -0.33959317  0.95733905 -0.35022283 -0.07024485
  0.7544405  -0.06140806 -0.04188168 -0.6633407   1.1840569   0.05816631
 -1.6159526   0.66268766 -0.23083846 -0.14517394  1.1469204   0.34799567
  1.4266642   0.31334278 -0.28295827 -1.1108409  -0.8835279   2.5834026
  0.87787735  0.60583746 -0.54754007  0.6377837  -0.5601931  -0.1955181
 -0.43250823 -0.2745196  -1.5457213   0.89529586  0.2582714   1.8737602
 -1.0582879   0.3474347   0.596482   -1.5612183  -0.40969077 -0.5706491
  1.7546692   1.0853792   0.40227976 -0.9536477   0.03482755  0.11100516
  0.06262214 -0.46291402  1.8026655  -0.74820435 -0.12988245 -2.2585475
 -0.24199292  1.9089705  -1.8475819   0.14677253 -0.1074628  -0.30475357
 -1.6002564   0.2691185   2.0681632  -1.1319175   1.7944973   0.3593319
 -0.9820287   0.38270625 -0.17628457  0.8467525   0.96995956 -0.8724843
 -0.22167179 -0.15352143 -1.1596437  -1.1336795   1.6857513  -0.06962411
 -1.433948    2.1456652  -0.4888481  -0.35367203  1.3283356   0.6612864
  0.9682756  -1.7218527   1.1928122  -1.5798819   0.63979894  1.1655827
 -0.18953459 -0.29073763 -0.2786006   0.6659778  -0.29394937 -0.18743546
  0.3852296  -0.01903921  0.6541177   0.32475945 -0.44165185  0.38651502
 -0.43014315 -0.55975664 -1.9981103  -1.2685267   0.16594425 -0.00500451
  1.028428   -0.9489211   0.18921047  1.4825013  -0.35193118 -0.9148419
  1.4117302  -0.18791197 -0.10850836 -1.4724474   0.11098368 -1.3436909
  1.7686806   0.8103731   0.5634731   1.0007045   1.0941757   0.3885739
 -1.2542454   1.7050391   1.7659668   0.97419995 -1.4393148  -0.15077412
 -0.20392825  1.2165229  -0.20432417 -0.9894576  -0.30093786 -1.7140422
  1.438269    1.1720381  -0.7884632  -0.4539971  -0.92923534  1.0242844
  1.4706647   0.9541518  -0.40100345  0.7072374  -0.3209165  -0.6545255
  1.3048483   1.7719132  -0.92017967 -0.6811888  -0.45541736  0.7656937
  0.4459724   0.35462174  1.7438623  -0.8897294  -0.26854998  1.691194
  0.3646269  -1.717905    1.04682     0.12194919  0.38123417 -0.9247314
 -1.4976519  -0.02474353 -0.57785183 -2.0275788  -0.48910365 -0.08354808
  0.93419427  0.81332195  0.35338688 -2.1042473   0.07245155  0.28052035
  1.2526983  -2.5040176   1.0658971  -1.0559217   1.1001824   1.0868396
 -0.08016798  0.45124474 -1.6627975   0.5373041  -0.41170666 -1.0421872
  1.3074819   0.25275132  0.66100025  0.84878033  0.6081676  -0.43544185
 -0.56882     0.21461087 -1.7406875   1.1791891   0.8716639   0.43179655
  0.2305155   1.6167414  -2.0586524   1.0186273  -1.8040993  -0.02628365
  1.0299834  -0.963757   -2.1501434  -0.3194373   0.46855178  0.12777597
 -0.15847397 -0.27317542  0.2992189  -1.5315675  -0.4296473   2.411139
 -0.4602697  -0.1070887  -0.9749655  -0.3112521  -0.8309874  -0.39027452
 -0.8181054   0.6713733  -0.95700055  1.7765068  -1.1362334   0.71948075
 -0.535767    0.27906173 -2.8360462   0.2817213   0.80584335 -0.6841869
 -0.60343087 -0.2996819  -0.10911417 -0.2782065   1.3123758  -0.44260928
 -1.0154886   1.1067688   0.14916731 -0.47244233 -1.4030291  -1.1196274
  1.397554    0.08674887  0.5141819  -0.43476057  0.3046904  -1.3075932
 -0.02327117 -0.850188    0.08484893 -0.3344713  -1.541361    2.0147092
  0.16958544 -1.2637361  -0.6364124   0.55694467 -0.27479953  0.12477673
  0.48269817  0.7332328   0.25257754  0.8060364  -0.0684315  -0.58036786
  1.3048286  -0.9681178 ].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

In [None]:
import json

In [None]:
with open("C:/Users/bergo/OneDrive - University of Pisa/Tesi Magistrale/models/ANTI_CURRICULUM/probing.txt", "w") as f:
    json.dump(results, f)