In [1]:
import torch
import numpy as np
from datasets import Dataset, DatasetDict
from tqdm.auto import tqdm
import os
import pickle
import conllu 
from transformers import (
    DistilBertTokenizerFast, 
    DistilBertForSequenceClassification, 
    DistilBertModel
)
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn.linear_model import LogisticRegression
import pandas as pd

print(" Step 1: Data Preparation")

file_path = "en_ewt-ud-train.conllu"

try:
    print(f"Loading data from local file: {file_path}...")
    with open(file_path, "r", encoding="utf-8") as f:
        parsed_data = conllu.parse(f.read())
    
    words = []
    pos_labels = []
    
    # getting unique UPOS tags for label map
    upos_tags = sorted(list(set(token['upos'] for sentence in parsed_data for token in sentence if 'upos' in token)))
    upos_to_id = {tag: i for i, tag in enumerate(upos_tags)}
    
    for sentence in tqdm(parsed_data, desc="Processing sentences"):
        tokens = [token['form'] for token in sentence]
        upos_ids = [upos_to_id[token['upos']] for token in sentence if 'upos' in token and token['upos'] in upos_to_id]
        
        if len(tokens) == len(upos_ids) and len(tokens) > 0:
            words.append(tokens)
            pos_labels.append(upos_ids)

    # creating dataset object
    probing_dataset = Dataset.from_dict({'tokens': words, 'labels': pos_labels})

    print(f"Total sentences in dataset: {len(probing_dataset)}")
    
    # subsetting dataset
    SUBSET_SIZE = 2000
    if len(probing_dataset) > SUBSET_SIZE:
        print(f"Subsetting dataset to {SUBSET_SIZE} sentences for faster execution.")
        probing_dataset = probing_dataset.shuffle(seed=42).select(range(SUBSET_SIZE))

    print(f"Total sentences in probing subset: {len(probing_dataset)}")
    print(f"First example tokens: {probing_dataset[0]['tokens']}")
    print(f"First example POS labels: {probing_dataset[0]['labels']}")

except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Please ensure the .conllu file path is correct and 'conllu' library is installed.")
    exit()

  from .autonotebook import tqdm as notebook_tqdm


 Step 1: Data Preparation
Loading data from local file: en_ewt-ud-train.conllu...


Processing sentences: 100%|██████████| 12544/12544 [00:00<00:00, 128410.75it/s]


Total sentences in dataset: 12544
Subsetting dataset to 2000 sentences for faster execution.
Total sentences in probing subset: 2000
First example tokens: ['A', 'key', 'question', 'is', 'how', 'they', 'acquired', 'the', 'anthrax', 'strain', 'first', 'isolated', 'by', 'the', 'Texas', 'Veterinary', 'Medical', 'Diagnostic', 'Lab', 'in', '1980', '.']
First example POS labels: [5, 0, 7, 3, 2, 10, 15, 5, 7, 7, 2, 15, 1, 5, 11, 0, 0, 0, 11, 1, 8, 12]


In [2]:
print("\n Step 2: Probing Setup ")

model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# path to fine tuned models
MODEL_DIR_FINE_TUNED = "/home/sharmajidotdev/manish/models/readability"
PROBE_RESULTS_DIR = "/home/sharmajidotdev/manish/probing_results_pos"
os.makedirs(PROBE_RESULTS_DIR, exist_ok=True)


# probing function
def get_word_level_embeddings_single_sentence(sentence_tokens, labels_for_sentence, tokenizer, model, device):
    base_bert_model = model.distilbert
    base_bert_model.eval()

    encoded_inputs = tokenizer(sentence_tokens, is_split_into_words=True, return_tensors='pt', padding=True, truncation=True, return_offsets_mapping=True, max_length=256)
    input_ids, attention_mask = encoded_inputs['input_ids'].to(device), encoded_inputs['attention_mask'].to(device)
    word_ids_list = encoded_inputs.word_ids()

    if not isinstance(word_ids_list, list) or not word_ids_list or not labels_for_sentence: return None, None
    
    with torch.no_grad():
        outputs = base_bert_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
    hidden_states = outputs.hidden_states
    sentence_hidden_states_all_layers = [hs[0] for hs in hidden_states]
    
    word_to_subword_indices = {}
    for token_idx, word_idx in enumerate(word_ids_list):
        if word_idx is None or token_idx >= attention_mask.shape[1] or attention_mask[0, token_idx].item() == 0 or word_idx < 0: continue
        if word_idx not in word_to_subword_indices: word_to_subword_indices[word_idx] = []
        word_to_subword_indices[word_idx].append(token_idx)
    if not word_to_subword_indices: return None, None
    sorted_word_indices = sorted(word_to_subword_indices.keys())
    if len(sorted_word_indices) != len(labels_for_sentence): return None, None
    
    current_sentence_word_embeddings_by_layer = [[] for _ in range(7)]
    aligned_labels_for_sentence_output = []
    
    for original_word_idx in sorted_word_indices:
        subword_token_indices = word_to_subword_indices[original_word_idx]
        if not subword_token_indices: continue
        for layer_idx in range(7):
            if any(idx >= sentence_hidden_states_all_layers[layer_idx].shape[0] for idx in subword_token_indices): return None, None
            subword_embs = sentence_hidden_states_all_layers[layer_idx][subword_token_indices, :]
            current_sentence_word_embeddings_by_layer[layer_idx].append(subword_embs.mean(dim=0).cpu().numpy())
        aligned_labels_for_sentence_output.append(labels_for_sentence[original_word_idx])
    
    processed_word_embeddings_tensors = []
    for layer_idx in range(7):
        if current_sentence_word_embeddings_by_layer[layer_idx]:
            processed_word_embeddings_tensors.append(torch.from_numpy(np.vstack(current_sentence_word_embeddings_by_layer[layer_idx])).float())
        else: return None, None
    return processed_word_embeddings_tensors, aligned_labels_for_sentence_output



 Step 2: Probing Setup 
Using device: cuda


In [4]:
print("\n Step 3: Probing Fine Tuned Models ")

probe_results_pos_ft = {f"layer_{i}": {'accuracy': [], 'f1': []} for i in range(7)}

N_SPLITS = 5
for fold_idx in range(N_SPLITS):
    print(f"\n Probing Fine-Tuned Model from Fold {fold_idx + 1} ")

    ft_model_fkgl = DistilBertForSequenceClassification.from_pretrained(
        os.path.join(MODEL_DIR_FINE_TUNED, "Fkgl", f"best_fold_{fold_idx}"),
        output_hidden_states=True
    )
    ft_model_fkgl.to(device)

    # probe for POS on FKGL-tuned model 
    print("\nExtracting embeddings from FKGL-tuned model ...")
    all_embs_pos_ft = [[] for _ in range(7)]
    all_labels_pos_ft = []
    for i in tqdm(range(len(probing_dataset['tokens'])), desc="POS Probing"):
        embs, labels = get_word_level_embeddings_single_sentence(probing_dataset['tokens'][i], probing_dataset['labels'][i], tokenizer, ft_model_fkgl, device)
        if embs:
            for l_idx, e_list in enumerate(embs):
                all_embs_pos_ft[l_idx].append(e_list)
            all_labels_pos_ft.extend(labels)
    
    concatenated_embs_pos_ft = [torch.cat(embs, dim=0) for embs in all_embs_pos_ft]
    y_pos_ft = np.array(all_labels_pos_ft)

    for l_idx in range(7):
        X_train, X_val, y_train, y_val = train_test_split(concatenated_embs_pos_ft[l_idx].cpu().numpy(), y_pos_ft, test_size=0.2, stratify=y_pos_ft)
        probe = LogisticRegression(max_iter=1000, multi_class='multinomial', n_jobs=-1).fit(X_train, y_train)
        y_pred = probe.predict(X_val)
        probe_results_pos_ft[f"layer_{l_idx}"]['accuracy'].append(accuracy_score(y_val, y_pred))
        probe_results_pos_ft[f"layer_{l_idx}"]['f1'].append(f1_score(y_val, y_pred, average='weighted'))
    
    # saving probing results for this fold
    current_fold_results = {
        'pos': probe_results_pos_ft
    }
    results_file_path = os.path.join(PROBE_RESULTS_DIR, f"fold_{fold_idx}_probing_results.pkl")
    with open(results_file_path, 'wb') as f:
        pickle.dump(current_fold_results, f)
    print(f"Saved probing results for Fold {fold_idx+1} to {results_file_path}")




 Step 3: Probing Fine Tuned Models 

 Probing Fine-Tuned Model from Fold 1 

Extracting embeddings from FKGL-tuned model ...


POS Probing: 100%|██████████| 2000/2000 [00:28<00:00, 71.27it/s]


Saved probing results for Fold 1 to /home/sharmajidotdev/manish/probing_results_pos/fold_0_probing_results.pkl

 Probing Fine-Tuned Model from Fold 2 

Extracting embeddings from FKGL-tuned model ...


POS Probing: 100%|██████████| 2000/2000 [00:26<00:00, 74.44it/s]


Saved probing results for Fold 2 to /home/sharmajidotdev/manish/probing_results_pos/fold_1_probing_results.pkl

 Probing Fine-Tuned Model from Fold 3 

Extracting embeddings from FKGL-tuned model ...


POS Probing: 100%|██████████| 2000/2000 [00:26<00:00, 75.69it/s]


Saved probing results for Fold 3 to /home/sharmajidotdev/manish/probing_results_pos/fold_2_probing_results.pkl

 Probing Fine-Tuned Model from Fold 4 

Extracting embeddings from FKGL-tuned model ...


POS Probing: 100%|██████████| 2000/2000 [00:27<00:00, 73.47it/s]


Saved probing results for Fold 4 to /home/sharmajidotdev/manish/probing_results_pos/fold_3_probing_results.pkl

 Probing Fine-Tuned Model from Fold 5 

Extracting embeddings from FKGL-tuned model ...


POS Probing: 100%|██████████| 2000/2000 [00:15<00:00, 133.22it/s]


Saved probing results for Fold 5 to /home/sharmajidotdev/manish/probing_results_pos/fold_4_probing_results.pkl


In [6]:
import pickle
import os


file_path = os.path.join("probing_results_pos/fold_0_probing_results.pkl") 

try:
    with open(file_path, 'rb') as f:
        loaded_results = pickle.load(f)

    # printing contents of dictionary
    print("\n Successfully loaded results from pickle file.")
    print(loaded_results)

except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found.")
except Exception as e:
    print(f"An error occurred while loading the pickle file: {e}")


 Successfully loaded results from pickle file.
{'pos': {'layer_0': {'accuracy': [0.8802304426925409], 'f1': [0.8797588971955954]}, 'layer_1': {'accuracy': [0.9422377198302001], 'f1': [0.9419389612018275]}, 'layer_2': {'accuracy': [0.9537598544572469], 'f1': [0.9534581902499958]}, 'layer_3': {'accuracy': [0.9601273499090358], 'f1': [0.959783903429134]}, 'layer_4': {'accuracy': [0.9604305639781686], 'f1': [0.9603801453620106]}, 'layer_5': {'accuracy': [0.9584596725288054], 'f1': [0.9581044484651879]}, 'layer_6': {'accuracy': [0.9536082474226805], 'f1': [0.9533427155725861]}}}


In [7]:
import pickle
import os


file_path = os.path.join("probing_results_pos/fold_1_probing_results.pkl") 

try:
    with open(file_path, 'rb') as f:
        loaded_results = pickle.load(f)

    # printing contents of dictionary
    print("\n Successfully loaded results from pickle file.")
    print(loaded_results)

except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found.")
except Exception as e:
    print(f"An error occurred while loading the pickle file: {e}")


 Successfully loaded results from pickle file.
{'pos': {'layer_0': {'accuracy': [0.8802304426925409, 0.8815949060036385], 'f1': [0.8797588971955954, 0.880546425174543]}, 'layer_1': {'accuracy': [0.9422377198302001, 0.9393571861734384], 'f1': [0.9419389612018275, 0.9384915825646749]}, 'layer_2': {'accuracy': [0.9537598544572469, 0.9525469981807155], 'f1': [0.9534581902499958, 0.9521631452497487]}, 'layer_3': {'accuracy': [0.9601273499090358, 0.9602789569436022], 'f1': [0.959783903429134, 0.9599074663969538]}, 'layer_4': {'accuracy': [0.9604305639781686, 0.9610369921164342], 'f1': [0.9603801453620106, 0.960693320135188]}, 'layer_5': {'accuracy': [0.9584596725288054, 0.9628562765312311], 'f1': [0.9581044484651879, 0.9627879243352037]}, 'layer_6': {'accuracy': [0.9536082474226805, 0.9599757428744694], 'f1': [0.9533427155725861, 0.9596529286234082]}}}


In [8]:
import pickle
import os


file_path = os.path.join("probing_results_pos/fold_2_probing_results.pkl") 

try:
    with open(file_path, 'rb') as f:
        loaded_results = pickle.load(f)

    # printing contents of dictionary
    print("\n Successfully loaded results from pickle file.")
    print(loaded_results)

except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found.")
except Exception as e:
    print(f"An error occurred while loading the pickle file: {e}")


 Successfully loaded results from pickle file.
{'pos': {'layer_0': {'accuracy': [0.8802304426925409, 0.8815949060036385, 0.8793208004851425], 'f1': [0.8797588971955954, 0.880546425174543, 0.8788097335507091]}, 'layer_1': {'accuracy': [0.9422377198302001, 0.9393571861734384, 0.9413280776228017], 'f1': [0.9419389612018275, 0.9384915825646749, 0.9409352127779149]}, 'layer_2': {'accuracy': [0.9537598544572469, 0.9525469981807155, 0.9517889630078835], 'f1': [0.9534581902499958, 0.9521631452497487, 0.9516092023485575]}, 'layer_3': {'accuracy': [0.9601273499090358, 0.9602789569436022, 0.9572468162522741], 'f1': [0.959783903429134, 0.9599074663969538, 0.9569656408945049]}, 'layer_4': {'accuracy': [0.9604305639781686, 0.9610369921164342, 0.9601273499090358], 'f1': [0.9603801453620106, 0.960693320135188, 0.9598882325164818]}, 'layer_5': {'accuracy': [0.9584596725288054, 0.9628562765312311, 0.9573984232868405], 'f1': [0.9581044484651879, 0.9627879243352037, 0.957101907944027]}, 'layer_6': {'accu

In [9]:
import pickle
import os


file_path = os.path.join("probing_results_pos/fold_3_probing_results.pkl") 

try:
    with open(file_path, 'rb') as f:
        loaded_results = pickle.load(f)

    # printing contents of dictionary
    print("\n Successfully loaded results from pickle file.")
    print(loaded_results)

except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found.")
except Exception as e:
    print(f"An error occurred while loading the pickle file: {e}")


 Successfully loaded results from pickle file.
{'pos': {'layer_0': {'accuracy': [0.8802304426925409, 0.8815949060036385, 0.8793208004851425, 0.8811400848999393], 'f1': [0.8797588971955954, 0.880546425174543, 0.8788097335507091, 0.880366713703215]}, 'layer_1': {'accuracy': [0.9422377198302001, 0.9393571861734384, 0.9413280776228017, 0.9402668283808369], 'f1': [0.9419389612018275, 0.9384915825646749, 0.9409352127779149, 0.9399116702148234]}, 'layer_2': {'accuracy': [0.9537598544572469, 0.9525469981807155, 0.9517889630078835, 0.9489084293511219], 'f1': [0.9534581902499958, 0.9521631452497487, 0.9516092023485575, 0.9483849269658877]}, 'layer_3': {'accuracy': [0.9601273499090358, 0.9602789569436022, 0.9572468162522741, 0.9572468162522741], 'f1': [0.959783903429134, 0.9599074663969538, 0.9569656408945049, 0.9570051947462127]}, 'layer_4': {'accuracy': [0.9604305639781686, 0.9610369921164342, 0.9601273499090358, 0.959824135839903], 'f1': [0.9603801453620106, 0.960693320135188, 0.9598882325164

In [10]:
import pickle
import os


file_path = os.path.join("probing_results_pos/fold_4_probing_results.pkl") 

try:
    with open(file_path, 'rb') as f:
        loaded_results = pickle.load(f)

    # printing contents of dictionary
    print("\n Successfully loaded results from pickle file.")
    print(loaded_results)

except FileNotFoundError:
    print(f"Error: The file '{file_path}' was not found.")
except Exception as e:
    print(f"An error occurred while loading the pickle file: {e}")


 Successfully loaded results from pickle file.
{'pos': {'layer_0': {'accuracy': [0.8802304426925409, 0.8815949060036385, 0.8793208004851425, 0.8811400848999393, 0.8843238326258338], 'f1': [0.8797588971955954, 0.880546425174543, 0.8788097335507091, 0.880366713703215, 0.8834425117376086]}, 'layer_1': {'accuracy': [0.9422377198302001, 0.9393571861734384, 0.9413280776228017, 0.9402668283808369, 0.9457246816252274], 'f1': [0.9419389612018275, 0.9384915825646749, 0.9409352127779149, 0.9399116702148234, 0.9455145641673119]}, 'layer_2': {'accuracy': [0.9537598544572469, 0.9525469981807155, 0.9517889630078835, 0.9489084293511219, 0.9564887810794421], 'f1': [0.9534581902499958, 0.9521631452497487, 0.9516092023485575, 0.9483849269658877, 0.9561760695640499]}, 'layer_3': {'accuracy': [0.9601273499090358, 0.9602789569436022, 0.9572468162522741, 0.9572468162522741, 0.9607337780473014], 'f1': [0.959783903429134, 0.9599074663969538, 0.9569656408945049, 0.9570051947462127, 0.960375656334104]}, 'layer_

In [5]:
print("\n Step 4: Final Summary ")
avg_probe_acc_pos_ft = [np.mean(probe_results_pos_ft[f"layer_{i}"]['accuracy']) for i in range(7)]
avg_probe_f1_pos_ft = [np.mean(probe_results_pos_ft[f"layer_{i}"]['f1']) for i in range(7)]

print("Final Probing Results on Fine-Tuned Models (Average across 5 folds)")
print("------------------------------------------------------------------")
print("POS Tagging Probe Accuracy:", avg_probe_acc_pos_ft)
print("POS Tagging Probe F1-Score:", avg_probe_f1_pos_ft)


 Step 4: Final Summary 
Final Probing Results on Fine-Tuned Models (Average across 5 folds)
------------------------------------------------------------------
POS Tagging Probe Accuracy: [np.float64(0.8813220133414191), np.float64(0.9417828987265009), np.float64(0.9526986052152819), np.float64(0.9591267434808974), np.float64(0.9600363856882959), np.float64(0.9592783505154638), np.float64(0.9580958156458459)]
POS Tagging Probe F1-Score: [np.float64(0.8805848562723341), np.float64(0.9413583981853106), np.float64(0.9523583068756478), np.float64(0.9588075723601819), np.float64(0.9598266796501911), np.float64(0.959023447162339), np.float64(0.9578212560911048)]
