In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
!pip install s2sphere
from s2sphere import CellId
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset





In [8]:
!unzip fine_tuned_model.zip -d fine_tuned_model

Archive:  fine_tuned_model.zip
  inflating: fine_tuned_model/tokenizer.json  
  inflating: fine_tuned_model/__MACOSX/._tokenizer.json  
  inflating: fine_tuned_model/model.safetensors  
  inflating: fine_tuned_model/__MACOSX/._model.safetensors  
  inflating: fine_tuned_model/tokenizer_config.json  
  inflating: fine_tuned_model/__MACOSX/._tokenizer_config.json  
  inflating: fine_tuned_model/spiece.model  
  inflating: fine_tuned_model/__MACOSX/._spiece.model  
  inflating: fine_tuned_model/special_tokens_map.json  
  inflating: fine_tuned_model/__MACOSX/._special_tokens_map.json  
  inflating: fine_tuned_model/generation_config.json  
  inflating: fine_tuned_model/__MACOSX/._generation_config.json  
  inflating: fine_tuned_model/config.json  
  inflating: fine_tuned_model/__MACOSX/._config.json  


In [None]:
model_path = "fine_tuned_model"  # folder with config.json and pytorch_model.bin
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model.eval()  # inference mode

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (wo): 

In [None]:
df = pd.read_csv('df_for_model_with_hierarchical.csv')
#df = df.sample(n=1000000, random_state=42)

In [None]:
# split

df_train, df_val = train_test_split(df, test_size=0.2, random_state=42)
#df_val = df_val[:1000]# now just doing this for 10000 instances

In [None]:
# dataset loader 

class TextDataset(Dataset):
    def __init__(self, texts):
        self.texts = ["predict cell token: " + t for t in texts]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

def collate_fn(batch):
    return tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)

# setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# prepare dataset and dataloader
inputs = df_val["text_combined"].tolist()
dataset = TextDataset(inputs)
val_dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

pred_labels = []

# generate predicitons 
print("Generating predictions...")
for batch in tqdm(val_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        output_ids = model.generate(**batch, max_length=16)

    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    pred_labels.extend(decoded)


In [None]:
df_val['pred_labels'] = pred_labels

### Accuracy 

In [None]:

labels = df_val["hierarchical_label"].tolist()
predictions = df_val["pred_labels"].tolist()

labels_str = list(map(str, labels))
preds_str = list(map(str, predictions))

# Compute accuracy
accuracy = accuracy_score(labels_str, preds_str)
print(f"\n✅ Finished! Validation Accuracy: {accuracy:.4f}")


✅ Finished! Validation Accuracy: 0.5492


### Hierarchical Metrics

In [None]:
# HIERARCHICAL EVALUATION FROM HIERARCHICAL LABELS
def hierarchical_precision(true_labels, pred_labels):
    total_overlap = 0
    total_predicted_levels = 0

    for t, p in zip(true_labels, pred_labels):
        true_prefixes = {t[:i+1] for i in range(len(t))}
        pred_prefixes = {p[:i+1] for i in range(len(p))}

        total_overlap += len(pred_prefixes & true_prefixes)
        total_predicted_levels += len(pred_prefixes)

    return total_overlap / total_predicted_levels if total_predicted_levels > 0 else 0.0


def hierarchical_recall(true_labels, pred_labels):
    total_overlap = 0
    total_true_levels = 0

    for t, p in zip(true_labels, pred_labels):
        true_prefixes = {t[:i+1] for i in range(len(t))}
        pred_prefixes = {p[:i+1] for i in range(len(p))}

        total_overlap += len(pred_prefixes & true_prefixes)
        total_true_levels += len(true_prefixes)

    return total_overlap / total_true_levels if total_true_levels > 0 else 0.0


def hierarchical_f1(hP, hR):
    return 2 * hP * hR / (hP + hR) if (hP + hR) > 0 else 0.0

# Ensure the columns are strings
true_labels = df_val["hierarchical_label"].astype(str)
pred_labels = df_val["pred_labels"].astype(str)

# Compute hierarchical metrics
hP = hierarchical_precision(true_labels, pred_labels)
hR = hierarchical_recall(true_labels, pred_labels)
hF = hierarchical_f1(hP, hR)

# Display results
print(f"Hierarchical Precision (hP): {hP:.4f}")
print(f"Hierarchical Recall (hR): {hR:.4f}")
print(f"Hierarchical F1 (hF): {hF:.4f}")



Hierarchical Precision (hP): 0.7824
Hierarchical Recall (hR): 0.7973
Hierarchical F1 (hF): 0.7898


### Hierarchical precision per level


In [None]:
def hierarchical_precision_per_level(true_labels, pred_labels, max_level=10):
    correct_by_level = [0] * (max_level + 1)
    total_by_level = [0] * (max_level + 1)

    for t, p in zip(true_labels, pred_labels):
        for lvl in range(min(len(t), len(p))):
            if t[lvl] == p[lvl]:
                correct_by_level[lvl] += 1
            total_by_level[lvl] += 1

    precision_by_level = {}
    for lvl in range(max_level + 1):
        if total_by_level[lvl] == 0:
            precision_by_level[f"Level {lvl} Precision"] = None
        else:
            precision_by_level[f"Level {lvl} Precision"] = correct_by_level[lvl] / total_by_level[lvl]

    return precision_by_level


# Compute precision at each level
level_precisions = hierarchical_precision_per_level(true_labels, pred_labels)

# Print nicely
for lvl, score in level_precisions.items():
    print(f"{lvl}: {score:.4f}" if score is not None else f"{lvl}: No predictions")


Level 0 Precision: 0.9725
Level 1 Precision: 0.9500
Level 2 Precision: 0.9221
Level 3 Precision: 0.8807
Level 4 Precision: 0.8403
Level 5 Precision: 0.8013
Level 6 Precision: 0.7646
Level 7 Precision: 0.7382
Level 8 Precision: 0.7407
Level 9 Precision: 0.7689
Level 10 Precision: 0.8568


### SHAP


In [None]:
# Able to see which token is more important when doing predictions!!
def lime_style_token_importance(text, model, tokenizer, device, max_length=16):
    full_input =  text
    inputs = tokenizer([full_input], return_tensors="pt", truncation=True, padding=True).to(device)

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=max_length)
    original_label = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(f"Original Prediction: {original_label}")

    tokens = text.split()
    results = []

    for i in range(len(tokens)):
        perturbed = tokens[:i] + ["[MASK]"] + tokens[i+1:]
        perturbed_text = " ".join(perturbed)
        full_perturbed_input = "predict cell token: " + perturbed_text

        inputs = tokenizer([full_perturbed_input], return_tensors="pt", truncation=True, padding=True).to(device)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_length=max_length)
        new_label = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        changed = new_label != original_label
        results.append((tokens[i], changed, new_label))

    print("\n🔍 Token influence (✔ = changed prediction):\n")
    for word, changed, new_output in results:
        mark = "✔" if changed else " "
        print(f"{word:15} → {new_output:15} {mark}")

    return results

text = "Specimen: Batis crypta Fjeldsa, Bowie & Kiure, 2006. Collected by Andersen, Thorkild on 1948-10-23 in Uluguru Mts., Unkn, Tanzania. Discipline: Zoology."

print("True LAbel: 2011313213")
lime_style_token_importance(text, model, tokenizer, device)


True LAbel: 2011313213
🎯 Original Prediction: 2011313213

🔍 Token influence (✔ = changed prediction):

Specimen:       → 2011313213       
Batis           → 2011313213       
crypta          → 2011313213       
Fjeldsa,        → 2011313213       
Bowie           → 2011313213       
&               → 2011313213       
Kiure,          → 2011313213       
2006.           → 2011313213       
Collected       → 2011313213       
by              → 2011313213       
Andersen,       → 2011313213       
Thorkild        → 2011313213       
on              → 2011313213       
1948-10-23      → 2011313213       
in              → 2011313213       
Uluguru         → 2003200111      ✔
Mts.,           → 2011313213       
Unkn,           → 2011313213       
Tanzania.       → 3203112023      ✔
Discipline:     → 2011313213       
Zoology.        → 2011313213       


[('Specimen:', False, '2011313213'),
 ('Batis', False, '2011313213'),
 ('crypta', False, '2011313213'),
 ('Fjeldsa,', False, '2011313213'),
 ('Bowie', False, '2011313213'),
 ('&', False, '2011313213'),
 ('Kiure,', False, '2011313213'),
 ('2006.', False, '2011313213'),
 ('Collected', False, '2011313213'),
 ('by', False, '2011313213'),
 ('Andersen,', False, '2011313213'),
 ('Thorkild', False, '2011313213'),
 ('on', False, '2011313213'),
 ('1948-10-23', False, '2011313213'),
 ('in', False, '2011313213'),
 ('Uluguru', True, '2003200111'),
 ('Mts.,', False, '2011313213'),
 ('Unkn,', False, '2011313213'),
 ('Tanzania.', True, '3203112023'),
 ('Discipline:', False, '2011313213'),
 ('Zoology.', False, '2011313213')]

In [None]:
from IPython.display import display, HTML

def visualize_token_importance(text, importance_results):
    html = ""
    for token, changed, _ in importance_results:
        if changed:
            # Red for important (changed prediction when masked)
            html += f'<span style="background-color: rgba(255,0,0,0.4); padding:2px; margin:1px;">{token}</span> '
        else:
            # Light gray for unimportant
            html += f'<span style="background-color: rgba(200,200,200,0.2); padding:2px; margin:1px;">{token}</span> '

    display(HTML(html))


In [None]:
text = df_val.iloc[0]['text_combined']
print(f"True Label: {df_val.iloc[0]['hierarchical_label']}")
results = lime_style_token_importance(text, model, tokenizer, device)
visualize_token_importance(text, results)

True Label: 1022213221
🎯 Original Prediction: 1022213211

🔍 Token influence (✔ = changed prediction):

Specimen:       → 1022213211       
Sherardia       → 1022213211       
arvensis        → 1022213211       
L..             → 1022213211       
Collected       → 1022213211       
by              → 1022213211       
Willing,R.      → 1022213211       
on              → 1022213011      ✔
2006-04-26      → 1022213211       
in              → 1022213211       
Kilkís,         → 1022211203      ✔
W               → 1022213211       
Myriofyto,      → 1022213011      ✔
Unkn,           → 1022213211       
Greece.         → 1022213211       
Discipline:     → 1022213211       
Botany.         → 1022213213      ✔


In [None]:
from IPython.display import display, HTML

from IPython.display import display, HTML
import difflib

def token_difference_score(original, new):
    return 1 - difflib.SequenceMatcher(None, original, new).ratio()  # Between 0 and 1

def visualize_token_importance_gradient(text, importance_results, original_label):
    html = ""
    for token, _, new_label in importance_results:
        diff_score = token_difference_score(original_label, new_label)
        intensity = min(max(diff_score, 0.0), 1.0)  # Clamp to [0, 1]
        red = int(255 * intensity)
        html += f'<span style="background-color: rgba({red},0,0,{0.3 + 0.4 * intensity:.2f}); padding:2px; margin:1px;">{token}</span> '
    display(HTML(html))


# slightly modified analysis function that returns the original label
def lime_style_token_importance(text, model, tokenizer, device, max_length=16):
    inputs = tokenizer([text], return_tensors="pt", truncation=True, padding=True).to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=max_length)
    original_label = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    tokens = text.split()
    results = []

    for i in range(len(tokens)):
        perturbed = tokens[:i] + ["[MASK]"] + tokens[i+1:]
        perturbed_text = " ".join(perturbed)
        full_perturbed_input = "predict cell token: " + perturbed_text

        inputs = tokenizer([full_perturbed_input], return_tensors="pt", truncation=True, padding=True).to(device)
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_length=max_length)
        new_label = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        changed = new_label != original_label
        results.append((tokens[i], changed, new_label))

    return original_label, results


In [None]:
for idx, row in df_val.iterrows():
    text = row['text_combined']
    true_label = str(row['hierarchical_label'])  # make sure it's string for comparison

    predicted_label, importance_results = lime_style_token_importance(text, model, tokenizer, device)

    if predicted_label == true_label:
        print(f"\n✅ Row {idx} — Match! Predicted: {predicted_label} | True: {true_label}")
        visualize_token_importance(text, importance_results)
    else:
        None



✅ Row 620810 — Match! Predicted: 2011313213 | True: 2011313213



✅ Row 580633 — Match! Predicted: 32002030001 | True: 32002030001



✅ Row 6379 — Match! Predicted: 21330033011 | True: 21330033011



✅ Row 366511 — Match! Predicted: 21330033033 | True: 21330033033



✅ Row 693950 — Match! Predicted: 21331302303 | True: 21331302303



✅ Row 707398 — Match! Predicted: 32033311021 | True: 32033311021



✅ Row 350203 — Match! Predicted: 33002133203 | True: 33002133203



✅ Row 206407 — Match! Predicted: 1022002031 | True: 1022002031



✅ Row 580309 — Match! Predicted: 430100001 | True: 430100001



✅ Row 678515 — Match! Predicted: 13121200133 | True: 13121200133



✅ Row 650817 — Match! Predicted: 21330013321 | True: 21330013321



✅ Row 625207 — Match! Predicted: 21233023001 | True: 21233023001



✅ Row 684929 — Match! Predicted: 21330013031 | True: 21330013031



✅ Row 817676 — Match! Predicted: 21221030103 | True: 21221030103



✅ Row 826148 — Match! Predicted: 21300132131 | True: 21300132131



✅ Row 692727 — Match! Predicted: 13123122133 | True: 13123122133



✅ Row 358268 — Match! Predicted: 21232231033 | True: 21232231033



✅ Row 820423 — Match! Predicted: 13212203011 | True: 13212203011



✅ Row 122168 — Match! Predicted: 311323003 | True: 311323003



✅ Row 845839 — Match! Predicted: 3211013103 | True: 3211013103



✅ Row 964995 — Match! Predicted: 10032022121 | True: 10032022121



✅ Row 510604 — Match! Predicted: 10002120021 | True: 10002120021



✅ Row 432991 — Match! Predicted: 21330003103 | True: 21330003103



✅ Row 226634 — Match! Predicted: 21232320001 | True: 21232320001



✅ Row 863844 — Match! Predicted: 21323310033 | True: 21323310033



✅ Row 216671 — Match! Predicted: 10221233123 | True: 10221233123



✅ Row 195516 — Match! Predicted: 2131310221 | True: 2131310221



✅ Row 852641 — Match! Predicted: 2011012223 | True: 2011012223



✅ Row 836275 — Match! Predicted: 12303331333 | True: 12303331333


KeyboardInterrupt: 