In [1]:
import math
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score
import re
import random
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizerFast, AutoModel, AutoTokenizer
from sklearn.metrics import f1_score, classification_report
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
torch.cuda.empty_cache()

In [20]:
import re
import random
from typing import List, Optional
import pandas as pd
from transformers import BertTokenizer

# 1) Loader function
def read_raw_sentences(
    path: str,
    n_max_sentences: Optional[int] = None,
    shuffle: bool = False,
    random_seed: int = 0,
    report_every: int = 10000
) -> List[str]:
    """
    Reads up to `n_max_sentences` lines from `path` (one sentence per line).
    Returns a list of the raw lines (with trailing newlines stripped).
    """
    sentences: List[str] = []
    with open(path, "r", encoding="utf-8") as f:
        if n_max_sentences is not None:
            for i in range(n_max_sentences):
                line = f.readline()
                if not line:
                    break
                sentences.append(line.rstrip("\n"))
                if (i + 1) % report_every == 0:
                    print(f"… loaded {i+1} sentences")
        else:
            for i, line in enumerate(f, start=1):
                sentences.append(line.rstrip("\n"))
                if i % report_every == 0:
                    print(f"… loaded {i} sentences")
    print(f"Done loading: {len(sentences)} sentences")
    if shuffle:
        random.seed(random_seed)
        random.shuffle(sentences)
        print("Shuffled sentences")
    return sentences

# 2) Your label extractor & pattern
pattern = re.compile(r"\w+|[^\w\s]", flags=re.UNICODE)

def extract_labels(sentence: str):
    tokens = pattern.findall(sentence)
    words, init_labels, final_labels, cap_labels = [], [], [], []
    for i, token in enumerate(tokens):
        if re.match(r"\w+", token, flags=re.UNICODE):
            # initial punctuation
            init = '¿' if i>0 and tokens[i-1]=='¿' else ''
            # final punctuation
            final = tokens[i+1] if i < len(tokens)-1 and tokens[i+1] in {'.',',','?'} else ''
            # capitalization
            if token.isupper():
                cap = 3
            elif token[0].isupper() and token[1:].islower():
                cap = 1
            elif token.islower():
                cap = 0
            else:
                cap = 2
            words.append(token)
            init_labels.append(init)
            final_labels.append(final)
            cap_labels.append(cap)
    return words, init_labels, final_labels, cap_labels

# 3) Load sentences from file
path = "es_419_validas.txt"
raw_sentences = read_raw_sentences(
    path=path,
    n_max_sentences=10000,   # or None to load all
    shuffle=True,
    random_seed=42,
    report_every=5000
)

# 4) Tokenize+label into DataFrame
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
data = []
for inst_id, sentence in enumerate(raw_sentences, start=1):
    words, init_lbls, final_lbls, cap_lbls = extract_labels(sentence)
    token_idx = 0
    for word, init_lbl, final_lbl, cap_lbl in zip(words, init_lbls, final_lbls, cap_lbls):
        subtokens = tokenizer.tokenize(word.lower())
        for i, sub in enumerate(subtokens):
            # initial only on first subtoken
            punct_init = init_lbl if i == 0 else ''
            # final only on last subtoken
            punct_final = final_lbl if i == len(subtokens)-1 else ''
            data.append([
                inst_id,
                token_idx,
                sub,
                punct_init,
                punct_final,
                cap_lbl
            ])
            token_idx += 1
    if inst_id % 500_000 == 0:
        print(f"… processed {inst_id} sentences, {len(data)} tokens so far")

df = pd.DataFrame(
    data,
    columns=["inst_id", "token_id", "token", "punt_inicial", "punt_final", "capitalizacion"]
)
print(f"Final: {df.shape[0]} tokens from {inst_id} sentences")
print(df.head())


… loaded 5000 sentences
… loaded 10000 sentences
Done loading: 10000 sentences
Shuffled sentences
Final: 82243 tokens from 10000 sentences
   inst_id  token_id   token punt_inicial punt_final  capitalizacion
0        1         0    cómo            ¿                          1
1        1         1      va                       ,               0
2        1         2   mucha                                       0
3        1         3  ##chos                       ,               0
4        1         4    todo                                       0


In [21]:
# Convert token strings to BERT token IDs
df["token_id_bert"] = tokenizer.convert_tokens_to_ids(df["token"].tolist())

# Group by instance to form sequences
grouped = {}
for inst_id, group in df.groupby("inst_id"):
    grouped[inst_id] = {
        "input_ids": group["token_id_bert"].tolist(),
        "init_labels": [0 if lbl=='' else 1 for lbl in group["punt_inicial"]],
        "final_labels": [0 if lbl=='' else (1 if lbl=='.' else (2 if lbl=='?' else 3))
                         for lbl in group["punt_final"]],
        "cap_labels": group["capitalizacion"].tolist(),
        "tokens": group["token"].tolist()
    }

# Create a list of instances for splitting
instances = list(grouped.values())
random.shuffle(instances)
n = len(instances)
train_split = int(0.8 * n)
val_split = int(0.9 * n)
train_data = instances[:train_split]
val_data   = instances[train_split:val_split]
test_data  = instances[val_split:]


In [22]:
from torch.nn.utils.rnn import pad_sequence

class PunctCapitalDataset(Dataset):
    def __init__(self, instances):
        self.instances = instances
    def __len__(self):
        return len(self.instances)
    def __getitem__(self, idx):
        inst = self.instances[idx]
        return (
            torch.tensor(inst["input_ids"], dtype=torch.long),
            torch.tensor(inst["init_labels"], dtype=torch.long),
            torch.tensor(inst["final_labels"], dtype=torch.long),
            torch.tensor(inst["cap_labels"], dtype=torch.long)
        )

def collate_fn(batch):
    input_ids, init_labs, final_labs, cap_labs = zip(*batch)
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    init_labs  = pad_sequence(init_labs,  batch_first=True, padding_value=-100)
    final_labs = pad_sequence(final_labs, batch_first=True, padding_value=-100)
    cap_labs   = pad_sequence(cap_labs,   batch_first=True, padding_value=-100)
    return input_ids, init_labs, final_labs, cap_labs

train_loader = DataLoader(PunctCapitalDataset(train_data), batch_size=128, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(PunctCapitalDataset(val_data), batch_size=128, shuffle=False, collate_fn=collate_fn)


In [7]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
d_model      = 256
nhead        = 8
dim_feed     = 512
num_layers   = 4
dropout      = 0.1
max_len      = 512
pad_idx      = tokenizer.pad_token_id  # usually 0
batch_size   = 32
lr           = 5e-4
epochs       = 15
patience_es  = 3


In [8]:
class CustomTransformerModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=256,
        nhead=8,
        dim_feedforward=512,
        num_layers=4,
        max_len=512,
        pad_idx=0,
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        # embeddings + positional
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_emb = nn.Embedding(max_len, d_model)
        # transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # three heads
        self.init_head  = nn.Linear(d_model, 2)
        self.final_head = nn.Linear(d_model, 4)
        self.cap_head   = nn.Linear(d_model, 4)

    def forward(self, input_ids):
        # input_ids: [B, T]
        B, T = input_ids.size()
        # token + positional
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)
        x = self.token_emb(input_ids) * math.sqrt(self.d_model)
        x = x + self.pos_emb(pos)
        # transformer expects [T, B, d_model]
        x = self.transformer(
            x.transpose(0,1),
            src_key_padding_mask=(input_ids == pad_idx)  # mask padded positions
        )
        x = x.transpose(0,1)  # [B, T, d_model]
        return (
            self.init_head(x),   # [B, T, 2]
            self.final_head(x),  # [B, T, 4]
            self.cap_head(x)     # [B, T, 4]
        )


In [9]:
class BertPunctCapModel(nn.Module):
    def __init__(self,pretrained="bert-base-multilingual-cased"):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained)
        h = self.bert.config.hidden_size
        self.h_init = nn.Linear(h,2)
        self.h_fin  = nn.Linear(h,4)
        self.h_cap  = nn.Linear(h,4)
    def forward(self,ids,mask):
        out = self.bert(input_ids=ids,attention_mask=mask).last_hidden_state
        return self.h_init(out), self.h_fin(out), self.h_cap(out)

In [11]:
model = CustomTransformerModel(
    vocab_size=tokenizer.vocab_size,
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feed,
    num_layers=num_layers,
    max_len=max_len,
    pad_idx=pad_idx,
    dropout=dropout
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = Adam(model.parameters(), lr=lr)
#scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=1, verbose=True)




In [12]:
from sklearn.metrics import f1_score, classification_report

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

for epoch in range(1, 1+2):  # e.g. 5 epochs
    model.train()
    running_loss = 0.0
    n_batches = 0
    for input_ids, init_labs, final_labs, cap_labs in train_loader:
        input_ids = input_ids.to(device)
        init_labs  = init_labs.to(device)
        final_labs = final_labs.to(device)
        cap_labs   = cap_labs.to(device)

        optimizer.zero_grad()
        init_logits, final_logits, cap_logits = model(input_ids)

        loss_init  = criterion(init_logits.view(-1, 2),  init_labs.view(-1))
        loss_final = criterion(final_logits.view(-1, 4), final_labs.view(-1))
        loss_cap   = criterion(cap_logits.view(-1, 4),   cap_labs.view(-1))
        loss = loss_init + loss_final + loss_cap
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    avg_train_loss = running_loss / n_batches
    print(f"Epoch {epoch} — Train loss: {avg_train_loss:.4f}")

    # --- Validation ---
    model.eval()
    val_loss = 0.0
    n_val_batches = 0

    all_init_trues,  all_init_preds  = [], []
    all_final_trues, all_final_preds = [], []
    all_cap_trues,   all_cap_preds   = [], []

    with torch.no_grad():
        for input_ids, init_labs, final_labs, cap_labs in val_loader:
            input_ids = input_ids.to(device)
            init_labs  = init_labs.to(device)
            final_labs = final_labs.to(device)
            cap_labs   = cap_labs.to(device)

            init_logits, final_logits, cap_logits = model(input_ids)

            # compute val loss
            loss_init  = criterion(init_logits.view(-1, 2),  init_labs.view(-1))
            loss_final = criterion(final_logits.view(-1, 4), final_labs.view(-1))
            loss_cap   = criterion(cap_logits.view(-1, 4),   cap_labs.view(-1))
            loss = loss_init + loss_final + loss_cap
            val_loss += loss.item()
            n_val_batches += 1

            # get predictions
            init_preds  = init_logits.argmax(dim=-1)
            final_preds = final_logits.argmax(dim=-1)
            cap_preds   = cap_logits.argmax(dim=-1)

            # mask out padding (-100)
            mask_init  = (init_labs.view(-1)  != -100)
            mask_final = (final_labs.view(-1) != -100)
            mask_cap   = (cap_labs.view(-1)   != -100)

            all_init_trues.extend(init_labs.view(-1)[mask_init].cpu().tolist())
            all_init_preds.extend(init_preds.view(-1)[mask_init].cpu().tolist())
            all_final_trues.extend(final_labs.view(-1)[mask_final].cpu().tolist())
            all_final_preds.extend(final_preds.view(-1)[mask_final].cpu().tolist())
            all_cap_trues.extend(cap_labs.view(-1)[mask_cap].cpu().tolist())
            all_cap_preds.extend(cap_preds.view(-1)[mask_cap].cpu().tolist())

    avg_val_loss = val_loss / n_val_batches
    print(f"Epoch {epoch} — Val loss:   {avg_val_loss:.4f}")

    # Compute macro-F1
    f1_init_macro  = f1_score(all_init_trues,  all_init_preds,  average='macro', zero_division=0)
    f1_final_macro = f1_score(all_final_trues, all_final_preds, average='macro', zero_division=0)
    f1_cap_macro   = f1_score(all_cap_trues,   all_cap_preds,   average='macro', zero_division=0)
    print(f"Epoch {epoch} — F1 (macro): init={f1_init_macro:.3f}, final={f1_final_macro:.3f}, cap={f1_cap_macro:.3f}")

    # Per-class F1 reports
    print("\nInitial punctuation per-class F1:")
    print(classification_report(all_init_trues, all_init_preds, labels=[0,1], target_names=['no-¿','¿'], zero_division=0))

    print("Final punctuation per-class F1:")
    print(classification_report(all_final_trues, all_final_preds,
                                labels=[0,1,2,3],
                                target_names=['none','.', '?', ','], zero_division=0))

    print("Capitalization per-class F1:")
    print(classification_report(all_cap_trues, all_cap_preds,
                                labels=[0,1,2,3],
                                target_names=['lower','Initial','Mixed','ALLCAP'], zero_division=0))

    print("-"*60) 


Epoch 1 — Train loss: 0.7682
Epoch 1 — Val loss:   0.6193
Epoch 1 — F1 (macro): init=0.791, final=0.549, cap=0.798

Initial punctuation per-class F1:
              precision    recall  f1-score   support

        no-¿       0.99      1.00      0.99     80562
           ¿       0.75      0.48      0.59      1894

    accuracy                           0.98     82456
   macro avg       0.87      0.74      0.79     82456
weighted avg       0.98      0.98      0.98     82456

Final punctuation per-class F1:
              precision    recall  f1-score   support

        none       0.90      0.97      0.93     69267
           .       0.60      0.42      0.50      7645
           ?       0.64      0.36      0.46      1945
           ,       0.74      0.19      0.31      3599

    accuracy                           0.87     82456
   macro avg       0.72      0.49      0.55     82456
weighted avg       0.86      0.87      0.86     82456

Capitalization per-class F1:
              precision    

In [13]:
from sklearn.metrics import f1_score

model.eval()
output_rows = []

# For metric accumulation
all_init_trues,  all_init_preds  = [], [] 
all_final_trues, all_final_preds = [], []
all_cap_trues,   all_cap_preds   = [], []

idx_map_init    = {0:'', 1:'¿'}
idx_map_final   = {0:'', 1:'.', 2:'?', 3:','}

for inst_id, instance in enumerate(test_data):
    # prepare inputs
    input_ids = torch.tensor(instance["input_ids"], dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        init_logits, final_logits, cap_logits = model(input_ids)

    # get token-level preds
    init_pred  = init_logits.argmax(dim=-1).squeeze(0).cpu().tolist()
    final_pred = final_logits.argmax(dim=-1).squeeze(0).cpu().tolist()
    cap_pred   = cap_logits.argmax(dim=-1).squeeze(0).cpu().tolist()

    # retrieve true labels
    init_true  = instance["init_labels"]
    final_true = instance["final_labels"]
    cap_true   = instance["cap_labels"]
    tokens     = instance["tokens"]

    # sanity check
    assert len(init_pred)==len(init_true)==len(tokens)

    # accumulate and record
    for token_idx, token in enumerate(tokens):
        # append to CSV rows
        output_rows.append({
            "instancia_id": inst_id,
            "token_id":     token_idx,
            "token":        token,
            "punt_inicial": idx_map_init[init_pred[token_idx]],
            "punt_final":   idx_map_final[final_pred[token_idx]],
            "capitalizacion": cap_pred[token_idx]
        })
        # accumulate for metrics
        all_init_trues.append(init_true[token_idx])
        all_init_preds.append(init_pred[token_idx])
        all_final_trues.append(final_true[token_idx])
        all_final_preds.append(final_pred[token_idx])
        all_cap_trues.append(cap_true[token_idx])
        all_cap_preds.append(cap_pred[token_idx])

# build and save DataFrame
output_df = pd.DataFrame(output_rows)
output_df.to_csv("predictions.csv", index=False)
print("Wrote predictions.csv")

# compute and print macro-F1 for each task
f1_init  = f1_score(all_init_trues,  all_init_preds,  average="macro", zero_division=0)
f1_final = f1_score(all_final_trues, all_final_preds, average="macro", zero_division=0)
f1_cap   = f1_score(all_cap_trues,   all_cap_preds,   average="macro", zero_division=0)

print(f"Test set performance:")
print(f"  • Initial punctuation F1-macro: {f1_init:.4f}")
print(f"  • Final punctuation   F1-macro: {f1_final:.4f}")
print(f"  • Capitalization      F1-macro: {f1_cap:.4f}")


Wrote predictions.csv
Test set performance:
  • Initial punctuation F1-macro: 0.7848
  • Final punctuation   F1-macro: 0.5956
  • Capitalization      F1-macro: 0.8345


# Bert model

In [23]:
class BertPunctCapModel(nn.Module):
    def __init__(self,pretrained="bert-base-multilingual-cased"):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained)
        h = self.bert.config.hidden_size
        self.h_init = nn.Linear(h,2)
        self.h_fin  = nn.Linear(h,4)
        self.h_cap  = nn.Linear(h,4)
    def forward(self,ids,mask):
        out = self.bert(input_ids=ids,attention_mask=mask).last_hidden_state
        return self.h_init(out), self.h_fin(out), self.h_cap(out)

In [24]:
model =BertPunctCapModel().to(device)

criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = Adam(model.parameters(), lr=lr)
#scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=1, verbose=True)


In [25]:
from sklearn.metrics import f1_score, classification_report

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

for epoch in range(1, 1+2):  # e.g. 2 epochs
    # —— TRAIN —— 
    model.train()
    running_loss = 0.0
    for input_ids, init_labs, final_labs, cap_labs in train_loader:
        # move to device
        input_ids = input_ids.to(device)
        init_labs  = init_labs.to(device)
        final_labs = final_labs.to(device)
        cap_labs   = cap_labs.to(device)

        # build attention mask (1 for real tokens, 0 for pad)
        attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

        optimizer.zero_grad()
        # forward pass: note the two arguments
        init_logits, final_logits, cap_logits = model(input_ids, attention_mask)

        # compute losses
        loss_init  = criterion(init_logits.view(-1, 2),  init_labs.view(-1))
        loss_final = criterion(final_logits.view(-1, 4), final_labs.view(-1))
        loss_cap   = criterion(cap_logits.view(-1, 4),   cap_labs.view(-1))
        loss = loss_init + loss_final + loss_cap
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch} — Train loss: {avg_train_loss:.4f}")

    # —— VALIDATION —— 
    model.eval()
    val_loss = 0.0

    all_init_trues,  all_init_preds  = [], []
    all_final_trues, all_final_preds = [], []
    all_cap_trues,   all_cap_preds   = [], []

    with torch.no_grad():
        for input_ids, init_labs, final_labs, cap_labs in val_loader:
            input_ids = input_ids.to(device)
            init_labs  = init_labs.to(device)
            final_labs = final_labs.to(device)
            cap_labs   = cap_labs.to(device)

            attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)
            init_logits, final_logits, cap_logits = model(input_ids, attention_mask)

            # validation loss
            loss_init  = criterion(init_logits.view(-1, 2),  init_labs.view(-1))
            loss_final = criterion(final_logits.view(-1, 4), final_labs.view(-1))
            loss_cap   = criterion(cap_logits.view(-1, 4),   cap_labs.view(-1))
            val_loss  += (loss_init + loss_final + loss_cap).item()

            # predictions
            init_preds  = init_logits.argmax(dim=-1)
            final_preds = final_logits.argmax(dim=-1)
            cap_preds   = cap_logits.argmax(dim=-1)

            # mask out padding positions
            mask_init  = init_labs.view(-1)  != -100
            mask_final = final_labs.view(-1) != -100
            mask_cap   = cap_labs.view(-1)   != -100

            all_init_trues.extend(init_labs.view(-1)[mask_init].cpu().tolist())
            all_init_preds.extend(init_preds.view(-1)[mask_init].cpu().tolist())
            all_final_trues.extend(final_labs.view(-1)[mask_final].cpu().tolist())
            all_final_preds.extend(final_preds.view(-1)[mask_final].cpu().tolist())
            all_cap_trues.extend(cap_labs.view(-1)[mask_cap].cpu().tolist())
            all_cap_preds.extend(cap_preds.view(-1)[mask_cap].cpu().tolist())

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch} — Val loss:   {avg_val_loss:.4f}")

    # compute macro-F1
    f1_init = f1_score(all_init_trues, all_init_preds, average='macro', zero_division=0)
    f1_final = f1_score(all_final_trues, all_final_preds, average='macro', zero_division=0)
    f1_cap = f1_score(all_cap_trues, all_cap_preds, average='macro', zero_division=0)
    print(f"Epoch {epoch} — F1 (macro): init={f1_init:.3f}, final={f1_final:.3f}, cap={f1_cap:.3f}")

    # per-class reports (optional)
    print("Initial punctuation per-class:")
    print(classification_report(all_init_trues, all_init_preds, labels=[0,1], target_names=['no-¿','¿'], zero_division=0))
    print("Final punctuation per-class:")
    print(classification_report(all_final_trues, all_final_preds, labels=[0,1,2,3], target_names=['none','.','? ',','], zero_division=0))
    print("Capitalization per-class:")
    print(classification_report(all_cap_trues, all_cap_preds, labels=[0,1,2,3], target_names=['lower','Initial','Mixed','ALLCAP'], zero_division=0))
    print("-"*50)


Epoch 1 — Train loss: 1.5490
Epoch 1 — Val loss:   1.2930
Epoch 1 — F1 (macro): init=0.494, final=0.228, cap=0.219
Initial punctuation per-class:
              precision    recall  f1-score   support

        no-¿       0.98      1.00      0.99      8075
           ¿       0.00      0.00      0.00       204

    accuracy                           0.98      8279
   macro avg       0.49      0.50      0.49      8279
weighted avg       0.95      0.98      0.96      8279

Final punctuation per-class:
              precision    recall  f1-score   support

        none       0.84      1.00      0.91      6948
           .       0.00      0.00      0.00       757
          ?        0.00      0.00      0.00       205
           ,       0.00      0.00      0.00       369

    accuracy                           0.84      8279
   macro avg       0.21      0.25      0.23      8279
weighted avg       0.70      0.84      0.77      8279

Capitalization per-class:
              precision    recall  f1

KeyboardInterrupt: 

In [None]:
from sklearn.metrics import f1_score

model.eval()
output_rows = []

# For metric accumulation
all_init_trues,  all_init_preds  = [], [] 
all_final_trues, all_final_preds = [], []
all_cap_trues,   all_cap_preds   = [], []

idx_map_init    = {0:'', 1:'¿'}
idx_map_final   = {0:'', 1:'.', 2:'?', 3:','}

for inst_id, instance in enumerate(test_data):
    # prepare inputs
    input_ids = torch.tensor(instance["input_ids"], dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        init_logits, final_logits, cap_logits = model(input_ids)

    # get token-level preds
    init_pred  = init_logits.argmax(dim=-1).squeeze(0).cpu().tolist()
    final_pred = final_logits.argmax(dim=-1).squeeze(0).cpu().tolist()
    cap_pred   = cap_logits.argmax(dim=-1).squeeze(0).cpu().tolist()

    # retrieve true labels
    init_true  = instance["init_labels"]
    final_true = instance["final_labels"]
    cap_true   = instance["cap_labels"]
    tokens     = instance["tokens"]

    # sanity check
    assert len(init_pred)==len(init_true)==len(tokens)

    # accumulate and record
    for token_idx, token in enumerate(tokens):
        # append to CSV rows
        output_rows.append({
            "instancia_id": inst_id,
            "token_id":     token_idx,
            "token":        token,
            "punt_inicial": idx_map_init[init_pred[token_idx]],
            "punt_final":   idx_map_final[final_pred[token_idx]],
            "capitalizacion": cap_pred[token_idx]
        })
        # accumulate for metrics
        all_init_trues.append(init_true[token_idx])
        all_init_preds.append(init_pred[token_idx])
        all_final_trues.append(final_true[token_idx])
        all_final_preds.append(final_pred[token_idx])
        all_cap_trues.append(cap_true[token_idx])
        all_cap_preds.append(cap_pred[token_idx])

# build and save DataFrame
output_df = pd.DataFrame(output_rows)
output_df.to_csv("predictions.csv", index=False)
print("Wrote predictions.csv")

# compute and print macro-F1 for each task
f1_init  = f1_score(all_init_trues,  all_init_preds,  average="macro", zero_division=0)
f1_final = f1_score(all_final_trues, all_final_preds, average="macro", zero_division=0)
f1_cap   = f1_score(all_cap_trues,   all_cap_preds,   average="macro", zero_division=0)

print(f"Test set performance:")
print(f"  • Initial punctuation F1-macro: {f1_init:.4f}")
print(f"  • Final punctuation   F1-macro: {f1_final:.4f}")
print(f"  • Capitalization      F1-macro: {f1_cap:.4f}")
