In [None]:
import os, json, math, time, random, platform, argparse
from typing import Dict, Iterator, List, Optional
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
try:
    from torch.utils.data import BufferedShuffleDataset  # optional, not required
    HAS_BUF_SHUFFLE = True
except Exception:
    HAS_BUF_SHUFFLE = False

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast, get_linear_schedule_with_warmup

In [None]:
def load_local_tokenizer(path: str):
    """Build a fast tokenizer from a local tokenizer file and ensure special tokens.

    Args:
        path: Passed to ``PreTrainedTokenizerFast`` as ``tokenizer_file``.

    Returns:
        The tokenizer. ``[CLS]`` / ``[SEP]`` are assigned when the respective
        token was unset, and ``[PAD]`` is registered through
        ``add_special_tokens`` when no pad token existed.
    """
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=path)

    # CLS/SEP are only filled in when the tokenizer did not already define them.
    for attr_name, literal in (("cls_token", "[CLS]"), ("sep_token", "[SEP]")):
        if getattr(tokenizer, attr_name) is None:
            setattr(tokenizer, attr_name, literal)

    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    return tokenizer

import json, torch, random
from typing import Dict, List, Optional, Iterator
from torch.utils.data import IterableDataset, DataLoader

class InMemoryJsonlRows(Dataset):
    """
    Map-style dataset that reads an entire JSONL file into memory at construction.

    Blank (whitespace-only) lines are skipped; every other line is parsed with
    ``json.loads`` and stored as-is. Rows are expected to look like:
      {"query_tokenized":[...], "target_tokenized":[...], "target_mask":[...]}
    (the schema is not validated here).
    """

    def __init__(self, path: str):
        with open(path, "r", encoding="utf-8") as handle:
            stripped_lines = (raw.strip() for raw in handle)
            self.data: List[Dict] = [
                json.loads(text) for text in stripped_lines if text
            ]
            print('Loaded')

    def __len__(self) -> int:
        """Number of parsed rows."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        """Return the parsed row at position ``idx``."""
        return self.data[idx]

# ---------- helpers ----------
def pad_1d(seqs: List[torch.Tensor], pad_val: int) -> torch.Tensor:
    """Right-pad 1-D tensors to a common length and stack them.

    Args:
        seqs: Non-empty list of 1-D tensors.
        pad_val: Fill value appended to sequences shorter than the longest one.

    Returns:
        Tensor of shape ``[len(seqs), max_len]``.

    Raises:
        ValueError: If ``seqs`` is empty (there is no length to pad to).
    """
    if not seqs:
        raise ValueError("pad_1d() requires at least one sequence")

    max_len = max(x.size(0) for x in seqs)
    out = []
    for x in seqs:
        missing = max_len - x.size(0)
        if missing > 0:
            # new_full inherits both dtype and device from x, so the padding can
            # be concatenated even when the inputs do not live on the CPU.
            x = torch.cat([x, x.new_full((missing,), pad_val)], dim=0)
        out.append(x)
    return torch.stack(out, dim=0)

def make_collate_two_stream(cls_id: int, sep_id: int, pad_id: int,
                            max_query_len: Optional[int] = 128,
                            max_target_len: Optional[int] = 480):
    """
    Build a collate_fn for rows with "query_tokenized", "target_tokenized"
    and "target_mask" lists.

    The returned function:
      - truncates query/target to the given maximum lengths (a falsy limit,
        i.e. None or 0, disables truncation)
      - wraps each stream as [CLS] ... [SEP]
      - trims or zero-extends target_mask to the (truncated) target length and
        turns it into float labels, with 0.0 at the CLS and SEP positions
      - right-pads every field to the longest sequence in the batch

    Note: "valid_mask" is simply the target attention mask as floats, so it is
    1 on CLS/SEP as well as on text tokens and 0 only on padding.
    """
    def _long(values):
        return torch.tensor(values, dtype=torch.long)

    def collate(batch: List[Dict]) -> Dict[str, torch.Tensor]:
        query_ids, query_attn = [], []
        target_ids, target_attn = [], []
        target_labels = []

        for row in batch:
            query = row["query_tokenized"]
            target = row["target_tokenized"]
            mask = row["target_mask"]

            if max_query_len:
                query = query[:max_query_len]
            if max_target_len:
                target = target[:max_target_len]

            # One mask entry per kept target token: drop extras, zero-fill gaps.
            n_target = len(target)
            mask = mask[:n_target]
            if len(mask) < n_target:
                mask = mask + [0] * (n_target - len(mask))

            q_seq = [cls_id] + query + [sep_id]
            t_seq = [cls_id] + target + [sep_id]

            # Labels line up with t_seq; the two special positions are 0.0.
            row_labels = [0.0] + [float(bit) for bit in mask] + [0.0]

            query_ids.append(_long(q_seq))
            query_attn.append(torch.ones(len(q_seq), dtype=torch.long))
            target_ids.append(_long(t_seq))
            target_attn.append(torch.ones(len(t_seq), dtype=torch.long))
            target_labels.append(torch.tensor(row_labels, dtype=torch.float))

        t_attn = pad_1d(target_attn, 0)
        return {
            "query_input_ids": pad_1d(query_ids, pad_id),
            "query_attention_mask": pad_1d(query_attn, 0),
            "target_input_ids": pad_1d(target_ids, pad_id),
            "target_attention_mask": t_attn,
            "labels_full": pad_1d(target_labels, 0),
            # 1.0 wherever the target stream has a token (CLS/SEP included).
            "valid_mask": (t_attn == 1).float(),
        }

    return collate


In [None]:
class ConvBlock1D(nn.Module):
    """Two Conv1d layers, each followed by ReLU.

    Args:
        in_ch: Input channels of the first convolution.
        out_ch: Output channels of both convolutions.
        k: Kernel size used by both convolutions.
        p: Padding used by both convolutions; defaults to ``k // 2``.
    """

    def __init__(self, in_ch, out_ch, k=3, p=None):
        super().__init__()
        padding = k // 2 if p is None else p
        layers = [
            nn.Conv1d(in_ch, out_ch, k, padding=padding),
            nn.ReLU(),
            nn.Conv1d(out_ch, out_ch, k, padding=padding),
            nn.ReLU(),
        ]
        # Kept as a Sequential named "net" so parameter names stay net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv+ReLU stages to ``x`` (channels on dim 1)."""
        return self.net(x)

class UNet1DHead(nn.Module):
    """1-D U-Net over the token axis producing one logit per input position.

    Two pooling stages (factor 2, ceil mode) go down to a bottleneck, two
    transposed convolutions go back up, with skip connections concatenated on
    the channel axis. The output is finally centre-cropped / zero-padded so its
    length equals the input length.
    """

    def __init__(self, hidden=768, dropout=0.1):
        super().__init__()
        # Encoder path
        self.enc1 = ConvBlock1D(hidden, 256)
        self.pool1 = nn.MaxPool1d(2, ceil_mode=True)
        self.enc2 = ConvBlock1D(256, 512)
        self.pool2 = nn.MaxPool1d(2, ceil_mode=True)
        self.bottleneck = ConvBlock1D(512, 1024)
        # Decoder path (dec* take upsampled features concatenated with the skip)
        self.up2 = nn.ConvTranspose1d(1024, 512, 2, stride=2)
        self.dec2 = ConvBlock1D(1024, 512)
        self.up1 = nn.ConvTranspose1d(512, 256, 2, stride=2)
        self.dec1 = ConvBlock1D(512, 256)
        self.drop = nn.Dropout(dropout)
        self.classifier = nn.Conv1d(256, 1, 1)

    @staticmethod
    def _crop_or_pad(x, T_target):
        """Centre-crop or zero-pad a [B,C,T] tensor along T to length ``T_target``."""
        _, _, current = x.shape
        delta = T_target - current
        if delta == 0:
            return x
        if delta < 0:
            offset = (current - T_target) // 2
            return x[:, :, offset:offset + T_target]
        # When the amount is odd, the extra padded element goes on the right.
        left = delta // 2
        return nn.functional.pad(x, (left, delta - left))

    def forward(self, hs_tgt):  # [B,T,H]
        """Map hidden states [B,T,H] to logits [B,T]."""
        n_positions = hs_tgt.size(1)
        feats = hs_tgt.transpose(1, 2)                      # [B,H,T]

        skip1 = self.enc1(feats)                            # [B,256,T]
        skip2 = self.enc2(self.pool1(skip1))                # [B,512,~T/2]
        deep = self.bottleneck(self.pool2(skip2))           # [B,1024,~T/4]

        # Skips are resized to the upsampled length before concatenation,
        # since ceil-mode pooling can make the two lengths differ.
        up = self.up2(deep)                                 # ~T/2
        up = self.dec2(torch.cat([up, self._crop_or_pad(skip2, up.size(-1))], dim=1))
        up = self.up1(up)                                   # ~T
        up = self.dec1(torch.cat([up, self._crop_or_pad(skip1, up.size(-1))], dim=1))

        scores = self.classifier(self.drop(up))             # [B,1,~T]
        if scores.size(-1) != n_positions:
            scores = self._crop_or_pad(scores, n_positions)
        return scores.squeeze(1)                            # [B,T]

class FiLMConditioner(nn.Module):
    """Feature-wise scale-and-shift of target features driven by a query vector.

    Two linear layers map the query vector to a per-feature scale offset
    (gamma) and shift (beta); the same pair is applied at every target position
    as ``hs * (1 + gamma) + beta``.
    """

    def __init__(self, hidden=768):
        super().__init__()
        self.gamma = nn.Linear(hidden, hidden)
        self.beta = nn.Linear(hidden, hidden)

    def forward(self, hs_tgt, q_vec):  # hs_tgt [B,T,H], q_vec [B,H]
        """Return the modulated features, same shape as ``hs_tgt``."""
        # A length-1 axis is inserted at dim 1 so the terms broadcast over T.
        scale = 1.0 + self.gamma(q_vec).unsqueeze(1)   # [B,1,H]
        shift = self.beta(q_vec).unsqueeze(1)          # [B,1,H]
        return hs_tgt * scale + shift

class LabseTwoStreamUNet(nn.Module):
    """
    One shared encoder (loaded with ``AutoModel.from_pretrained``) encodes query
    and target separately.
    Query mean-pooled -> FiLM conditioning of target features -> U-Net head ->
    per-target-token logits.

    The encoder is frozen except for its last ``unfreeze_last_n`` transformer
    layers, which stay trainable (so it is NOT fully frozen by default).

    Args:
        base_model: Model name or path handed to ``AutoModel.from_pretrained``.
        dropout: Dropout probability used inside the U-Net head.
        unfreeze_last_n: How many of the final ``encoder.encoder.layer`` blocks
            remain trainable. ``0`` freezes the whole encoder. Defaults to 4.

    Raises:
        ValueError: If ``unfreeze_last_n`` is negative.
    """
    def __init__(self, base_model: str, dropout: float = 0.1, unfreeze_last_n: int = 4):
        super().__init__()
        if unfreeze_last_n < 0:
            raise ValueError("unfreeze_last_n must be >= 0")

        self.encoder = AutoModel.from_pretrained(base_model)
        for p in self.encoder.parameters():
            p.requires_grad = False
        # Guarded because a slice of [-0:] would select every layer, not none.
        if unfreeze_last_n > 0:
            for p in self.encoder.encoder.layer[-unfreeze_last_n:].parameters():
                p.requires_grad = True

        hidden = self.encoder.config.hidden_size
        self.conditioner = FiLMConditioner(hidden=hidden)
        self.head = UNet1DHead(hidden=hidden, dropout=dropout)

    @staticmethod
    def mean_pool(hs, attn):  # hs [B,T,H], attn [B,T]
        """Attention-masked mean over the token axis; returns [B,H].

        The denominator is clamped to at least 1 so rows whose mask is all
        zeros yield a zero vector instead of dividing by zero.
        """
        mask = attn.unsqueeze(-1).float()
        num = (hs * mask).sum(dim=1)
        den = mask.sum(dim=1).clamp_min(1.0)
        return num / den

    def forward(self, q_ids, q_attn, t_ids, t_attn):
        """Return per-target-position logits of shape [B,T]."""
        q_out = self.encoder(input_ids=q_ids, attention_mask=q_attn)
        q_vec = self.mean_pool(q_out.last_hidden_state, q_attn)  # [B,H]

        t_out = self.encoder(input_ids=t_ids, attention_mask=t_attn)
        t_hs  = t_out.last_hidden_state                           # [B,T,H]

        t_cond = self.conditioner(t_hs, q_vec)                    # [B,T,H]
        logits = self.head(t_cond)                                # [B,T]
        return logits


In [None]:
def _bin_to_spans(bin_arr):
    """
    Convert a 0/1 iterable into half-open index spans [start, end) covering each
    maximal run of truthy values.
    Example: [0,1,1,0,1] -> [(1,3), (4,5)]

    Works for any iterable (lists, tuples, generators, iterators): the end of a
    trailing run is taken from the number of items consumed, not from len().
    """
    spans = []
    run_start = None   # index where the current run began, None when outside a run
    consumed = 0       # number of items seen so far
    for i, v in enumerate(bin_arr):
        consumed = i + 1
        if v:
            if run_start is None:
                run_start = i
        elif run_start is not None:
            spans.append((run_start, i))
            run_start = None
    if run_start is not None:
        # Input ended while still inside a run: close it at the end of input.
        spans.append((run_start, consumed))
    return spans

def _exact_span_prf1(pred_spans, gold_spans):
    """
    Span-level precision/recall/F1 under exact boundary matching.

    Both inputs are de-duplicated as sets; a predicted (s, e) is a true positive
    only if the identical pair occurs in gold. Returns
    (tp, fp, fn, prec, rec, f1); a 1e-9 term in each denominator keeps the
    ratios defined (as 0.0) when there is nothing to count.
    """
    eps = 1e-9
    predicted = set(pred_spans)
    reference = set(gold_spans)

    tp = len(predicted & reference)
    fp = len(predicted) - tp   # predicted spans with no exact gold match
    fn = len(reference) - tp   # gold spans that were not predicted exactly

    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    return tp, fp, fn, prec, rec, f1

def _span_metrics_from_logits_row(logits_row, labels_row, t_attn_row, thr=0.5, smooth_k=0):
    """
    Exact-match span statistics for one sample.

    Steps:
      - take L = sum of the attention row as the target length; this treats the
        attended positions as a prefix [0, L) with CLS first and SEP last
      - keep only positions 1 .. L-2 (the text between CLS and SEP)
      - sigmoid the logits, optionally median-smooth them (only when smooth_k
        is odd and >= 3; any other value leaves the probabilities untouched)
      - threshold at ``thr``, convert prediction and gold to spans, compare

    Returns (tp, fp, fn, prec, rec, f1); all zeros when L < 3, i.e. when there
    is no text position between the two specials.
    """
    import torch

    n_attended = int(t_attn_row.sum().item())
    if n_attended < 3:
        return 0, 0, 0, 0.0, 0.0, 0.0

    text_region = slice(1, n_attended - 1)
    probs = torch.sigmoid(logits_row[text_region])
    gold = labels_row[text_region]

    median_enabled = bool(smooth_k) and smooth_k >= 3 and smooth_k % 2 == 1
    if median_enabled:
        half = smooth_k // 2
        # Edge values are replicated so the filtered output keeps its length.
        extended = torch.nn.functional.pad(
            probs.unsqueeze(0).unsqueeze(0), (half, half), mode='replicate'
        )[0, 0]
        probs = extended.unfold(0, smooth_k, 1).median(dim=1).values

    predicted_bits = (probs >= thr).to(torch.int).tolist()
    gold_bits = gold.to(torch.int).tolist()

    return _exact_span_prf1(_bin_to_spans(predicted_bits), _bin_to_spans(gold_bits))

def span_eval_batch(logits, labels_full, target_attention_mask, thr=0.5, smooth_k=0):
    """
    Micro-averaged span metrics for a batch.

    Per-row tp/fp/fn counts are summed first and the ratios are computed once
    from the totals. Returns (prec, rec, f1, tp, fp, fn).
    """
    eps = 1e-9
    tp = fp = fn = 0
    for row in range(logits.size(0)):
        row_stats = _span_metrics_from_logits_row(
            logits[row], labels_full[row], target_attention_mask[row],
            thr=thr, smooth_k=smooth_k,
        )
        # Only the counts are needed; the per-row ratios are ignored.
        tp += row_stats[0]
        fp += row_stats[1]
        fn += row_stats[2]

    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    return prec, rec, f1, tp, fp, fn

def token_f1_precision_recall(logits, labels_full, valid_mask, thr=0.5):
    """
    Token-level F1 / precision / recall over positions where ``valid_mask`` is 1.

    Logits go through a sigmoid and are thresholded at ``thr``; both predictions
    and labels are multiplied by ``valid_mask`` so masked-out positions
    contribute nothing. Returns (f1, prec, rec) as Python floats — note that F1
    comes first.
    """
    eps = 1e-9
    with torch.no_grad():
        predicted = (torch.sigmoid(logits) >= thr).float() * valid_mask
        actual = labels_full * valid_mask
        tp = (actual * predicted).sum().item()
        fp = ((1 - actual) * predicted).sum().item()
        fn = (actual * (1 - predicted)).sum().item()

    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    return f1, prec, rec

def bench_throughput(loader, model, steps=50):
    """Measure forward-only throughput of ``model`` on batches from ``loader``.

    Runs 3 untimed warm-up batches, then times ``steps`` forward passes under
    ``torch.no_grad()``. The loader is restarted whenever it is exhausted. Only
    the forward pass is measured; the model's train/eval mode is left as-is.

    Args:
        loader: Iterable of batch dicts containing "query_input_ids",
            "query_attention_mask", "target_input_ids", "target_attention_mask".
        model: Module called as ``model(q_ids, q_attn, t_ids, t_attn)``.
        steps: Number of timed batches.

    Returns:
        ``(tokens_per_second, batches_per_second)``, where the token count of a
        batch is ``B * (Tq + Tt)`` using padded lengths.
    """
    device = next(model.parameters()).device
    on_cuda = device.type == "cuda"
    it = iter(loader)

    def _next_batch():
        nonlocal it
        try:
            return next(it)
        except StopIteration:
            it = iter(loader)
            return next(it)

    def _forward(b):
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=on_cuda):
            model(b["query_input_ids"].to(device),
                  b["query_attention_mask"].to(device),
                  b["target_input_ids"].to(device),
                  b["target_attention_mask"].to(device))

    def _sync():
        # CUDA kernels are asynchronous, so wait for them before reading the
        # clock. Skipped on CPU, where torch.cuda.synchronize() would raise.
        if on_cuda:
            torch.cuda.synchronize(device)

    # warmup a few
    for _ in range(3):
        _forward(_next_batch())
    _sync()

    t0 = time.perf_counter()
    batches = 0
    tokens = 0
    for _ in range(steps):
        b = _next_batch()
        B = b["target_input_ids"].shape[0]
        Tq = b["query_input_ids"].shape[1]
        Tt = b["target_input_ids"].shape[1]
        _forward(b)
        batches += 1
        tokens += B * (Tq + Tt)
    _sync()

    # Guard against a zero reading from the clock on extremely short runs.
    dt = max(time.perf_counter() - t0, 1e-9)
    return tokens / dt, batches / dt


In [None]:
# ---- Run configuration: paths and hyper-parameters ----
tokenizer_path = '/content/drive/MyDrive/labse_tokenizer.json'
model_name = "sentence-transformers/LaBSE"
train_path = "/content/drive/MyDrive/dataset-yeshibish-labse-train.jsonl"
val_path = "/content/drive/MyDrive/dataset-yeshibish-labse-val.jsonl"
batch_size = 128  # training batch size; the validation loader below uses 32
max_epochs = 15
lr = 2e-4
weight_decay = 5e-4
dropout = 0.3
max_query_len = 480  # overrides the collate default of 128
max_target_len = 480
early_stop_patience = 4  # epochs without span-F1 improvement before stopping

# OS-aware DataLoader settings: on Windows everything runs in the main process
# (no workers, no pinning); elsewhere 2 persistent workers with pinned memory.
is_windows = platform.system().lower().startswith("win")
num_workers = 0 if is_windows else 2
pin_memory = not is_windows
persistent = not is_windows  # only True together with num_workers == 2

tok = load_local_tokenizer(tokenizer_path)
CLS, SEP, PAD = tok.cls_token_id, tok.sep_token_id, tok.pad_token_id

# Both splits are loaded fully into memory (each prints 'Loaded').
train_ds = InMemoryJsonlRows(train_path)
val_ds   = InMemoryJsonlRows(val_path)

collate = make_collate_two_stream(CLS, SEP, PAD, max_query_len=max_query_len, max_target_len=max_target_len)

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,                 # rows are already in memory, so shuffling needs no file I/O
    collate_fn=collate,           # built by make_collate_two_stream above
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=persistent
)

val_loader = DataLoader(
    val_ds,
    batch_size=32,
    shuffle=False,                # keep validation order deterministic
    collate_fn=collate,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=persistent
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LabseTwoStreamUNet(base_model=model_name, dropout=dropout).to(device)
# Optimize only parameters with requires_grad=True: the FiLM conditioner, the
# U-Net head and the encoder layers that LabseTwoStreamUNet leaves unfrozen.
optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad],
                            lr=lr, weight_decay=weight_decay)

# An "epoch" here is a fixed number of optimizer steps, not one pass over the
# data: the training loop draws this many batches and restarts the loader when
# it runs out.
# NOTE(review): train_ds is an in-memory map-style dataset, so len(train_loader)
# would give the exact number of batches per pass if a true epoch is wanted.
steps_per_epoch = 1000  # safe default; logging will still be useful

total_steps = max_epochs * steps_per_epoch
best_span_f1   = 0.0  # best validation span-F1 seen so far

# Linear warmup over the first 10% of steps (at least 10), then linear decay.
sched = get_linear_schedule_with_warmup(optim, num_warmup_steps=max(10, total_steps // 10),
                                        num_training_steps=total_steps)

# reduction="none" keeps per-token losses so they can be masked before averaging.
bce = nn.BCEWithLogitsLoss(reduction="none")
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# quick throughput probe (best effort: any failure is printed and skipped)
try:
    tps, bps = bench_throughput(train_loader, model, steps=30)
    print(f"[BENCH] ~{tps:.0f} tokens/s, {bps:.2f} batches/s on this setup.")
except Exception as e:
    print(f"[BENCH] Skipped ({e})")

bad_epochs = 0
train_iter = iter(train_loader)

# Destination of the best-so-far checkpoint (overwritten on every improvement).
ckpt_path = "/content/drive/MyDrive/labse_unet_best_dropout_03_last_4_unfr.pt"
# Loop-invariant, so created once up front instead of once per epoch.
os.makedirs("checkpoints", exist_ok=True)


def _fit_logits_to(logits, T):
    """Trim or zero-pad ``logits`` along dim 1 so it has exactly ``T`` columns."""
    cur = logits.size(1)
    if cur == T:
        return logits
    if cur > T:
        return logits[:, :T]
    pad = torch.zeros((logits.size(0), T - cur), device=logits.device)
    return torch.cat([logits, pad], dim=1)


for epoch in range(1, max_epochs+1):
    # ---- Training: a fixed number of steps, restarting the loader as needed ----
    model.train()
    with tqdm(total=steps_per_epoch, desc=f"Epoch {epoch}/{max_epochs} [train]", leave=True) as pbar:
        for step_in_epoch in range(steps_per_epoch):
            try:
                batch = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)  # restart
                batch = next(train_iter)
            q_ids  = batch["query_input_ids"].to(device)
            q_attn = batch["query_attention_mask"].to(device)
            t_ids  = batch["target_input_ids"].to(device)
            t_attn = batch["target_attention_mask"].to(device)
            labels = batch["labels_full"].to(device)
            valid  = batch["valid_mask"].to(device)

            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                logits = model(q_ids, q_attn, t_ids, t_attn)  # [B,T]
                # crop/pad logits to labels size (rare in practice)
                logits = _fit_logits_to(logits, labels.size(1))
                # Per-token BCE averaged over valid positions only.
                loss_tok = bce(logits, labels) * valid
                denom = valid.sum().clamp_min(1.0)
                loss = loss_tok.sum() / denom

            optim.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scale_before = scaler.get_scale()
            scaler.step(optim); scaler.update()
            # GradScaler skips optim.step() when it finds inf/NaN gradients and
            # then lowers its scale; only advance the LR schedule when the
            # optimizer really stepped.
            if scaler.get_scale() >= scale_before:
                sched.step()

            f1, p, r = token_f1_precision_recall(logits, labels, valid, thr=0.5)
            pbar.set_postfix(loss=f"{loss.item():.4f}", F1=f"{f1:.4f}", P=f"{p:.4f}", R=f"{r:.4f}")
            pbar.update(1)

    # ---- Validation epoch (single pass over val_loader) ----
    model.eval()
    val_loss_sum, val_tok_sum = 0.0, 0.0
    agg_f1 = agg_p = agg_r = 0.0
    nb = 0
    span_tp = span_fp = span_fn = 0

    with torch.no_grad():
        for batch in val_loader:
            q_ids  = batch["query_input_ids"].to(device)
            q_attn = batch["query_attention_mask"].to(device)
            t_ids  = batch["target_input_ids"].to(device)
            t_attn = batch["target_attention_mask"].to(device)
            labels = batch["labels_full"].to(device)
            valid  = batch["valid_mask"].to(device)

            logits = model(q_ids, q_attn, t_ids, t_attn)
            logits = _fit_logits_to(logits, labels.size(1))

            loss_tok = bce(logits, labels) * valid
            denom = valid.sum().clamp_min(1.0)
            loss = loss_tok.sum() / denom
            # Accumulate a token-weighted sum so val_loss is a per-token mean.
            val_loss_sum += loss.item() * denom.item()
            val_tok_sum  += denom.item()

            # Token metrics are averaged per batch (macro over batches).
            f1, p, r = token_f1_precision_recall(logits, labels, valid, thr=0.5)
            agg_f1 += f1; agg_p += p; agg_r += r; nb += 1

            # Span counts are summed over all batches (micro average).
            *_, tp, fp, fn = span_eval_batch(
                logits, labels, t_attn, thr=0.5, smooth_k=3  # try k=3; set 0 to disable
            )
            span_tp += tp; span_fp += fp; span_fn += fn

    val_loss = val_loss_sum / max(1.0, val_tok_sum)
    val_f1 = (agg_f1 / nb) if nb else 0.0
    val_p  = (agg_p  / nb) if nb else 0.0
    val_r  = (agg_r  / nb) if nb else 0.0
    span_prec = span_tp / (span_tp + span_fp + 1e-9)
    span_rec  = span_tp / (span_tp + span_fn + 1e-9)
    span_f1   = 2 * span_prec * span_rec / (span_prec + span_rec + 1e-9)

    # The validation loss is appended so the start of the line keeps its format.
    print(f"[VAL] token-F1 {val_f1:.4f} (P {val_p:.4f} R {val_r:.4f})  "
        f"| span-F1 {span_f1:.4f} (P {span_prec:.4f} R {span_rec:.4f})"
        f" | loss {val_loss:.4f}")

    # Early stopping + checkpoint, both driven by validation span-F1.
    if span_f1 > best_span_f1 + 1e-4:
        best_span_f1 = span_f1
        bad_epochs = 0
        state = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optim": optim.state_dict(),
            "sched": sched.state_dict(),
            "scaler": scaler.state_dict(),
            "best_val_f1": best_span_f1,
            "config": { "MODEL_NAME": 'LaBse-Unet', "dropout": dropout }
        }
        torch.save(state, ckpt_path)
    else:
        bad_epochs += 1
        if bad_epochs >= early_stop_patience:
            print("Early stopping.")
            break

Loaded
Loaded


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/804 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.88G [00:00<?, ?B/s]

[BENCH] ~128505 tokens/s, 1.34 batches/s on this setup.


Epoch 1/15 [train]: 100%|██████████| 1000/1000 [23:44<00:00,  1.42s/it, F1=0.6682, P=0.8802, R=0.5386, loss=0.0460]


[VAL] token-F1 0.6728 (P 0.7517 R 0.6147)  | span-F1 0.1912 (P 0.2032 R 0.1805)


Epoch 2/15 [train]: 100%|██████████| 1000/1000 [23:38<00:00,  1.42s/it, F1=0.8251, P=0.8365, R=0.8141, loss=0.0250]


[VAL] token-F1 0.8178 (P 0.8674 R 0.7776)  | span-F1 0.5107 (P 0.5239 R 0.4982)


Epoch 3/15 [train]: 100%|██████████| 1000/1000 [23:39<00:00,  1.42s/it, F1=0.9067, P=0.9452, R=0.8712, loss=0.0162]


[VAL] token-F1 0.8583 (P 0.9096 R 0.8167)  | span-F1 0.6096 (P 0.6369 R 0.5845)


Epoch 4/15 [train]: 100%|██████████| 1000/1000 [23:40<00:00,  1.42s/it, F1=0.8378, P=0.8389, R=0.8366, loss=0.0205]


[VAL] token-F1 0.8761 (P 0.8760 R 0.8798)  | span-F1 0.6433 (P 0.6455 R 0.6410)


Epoch 5/15 [train]: 100%|██████████| 1000/1000 [23:35<00:00,  1.42s/it, F1=0.9037, P=0.9426, R=0.8679, loss=0.0137]


[VAL] token-F1 0.8832 (P 0.8836 R 0.8861)  | span-F1 0.6681 (P 0.6654 R 0.6708)


Epoch 6/15 [train]: 100%|██████████| 1000/1000 [23:36<00:00,  1.42s/it, F1=0.9160, P=0.9398, R=0.8935, loss=0.0093]


[VAL] token-F1 0.8980 (P 0.9221 R 0.8780)  | span-F1 0.7125 (P 0.7276 R 0.6981)


Epoch 7/15 [train]: 100%|██████████| 1000/1000 [23:34<00:00,  1.41s/it, F1=0.9270, P=0.9538, R=0.9016, loss=0.0087]


[VAL] token-F1 0.8973 (P 0.9070 R 0.8909)  | span-F1 0.7151 (P 0.7215 R 0.7089)


Epoch 8/15 [train]: 100%|██████████| 1000/1000 [23:35<00:00,  1.42s/it, F1=0.9499, P=0.9641, R=0.9361, loss=0.0067]


[VAL] token-F1 0.9060 (P 0.9278 R 0.8880)  | span-F1 0.7318 (P 0.7433 R 0.7207)


Epoch 9/15 [train]: 100%|██████████| 1000/1000 [23:37<00:00,  1.42s/it, F1=0.9375, P=0.9334, R=0.9417, loss=0.0086]


[VAL] token-F1 0.9030 (P 0.9012 R 0.9074)  | span-F1 0.7277 (P 0.7205 R 0.7350)


Epoch 10/15 [train]: 100%|██████████| 1000/1000 [23:33<00:00,  1.41s/it, F1=0.9760, P=0.9797, R=0.9723, loss=0.0047]


[VAL] token-F1 0.9092 (P 0.9297 R 0.8923)  | span-F1 0.7460 (P 0.7550 R 0.7372)


Epoch 11/15 [train]: 100%|██████████| 1000/1000 [23:40<00:00,  1.42s/it, F1=0.9625, P=0.9778, R=0.9476, loss=0.0055]


[VAL] token-F1 0.9131 (P 0.9211 R 0.9076)  | span-F1 0.7518 (P 0.7532 R 0.7504)


Epoch 12/15 [train]: 100%|██████████| 1000/1000 [23:28<00:00,  1.41s/it, F1=0.9657, P=0.9947, R=0.9383, loss=0.0036]


[VAL] token-F1 0.9125 (P 0.9370 R 0.8920)  | span-F1 0.7593 (P 0.7726 R 0.7465)


Epoch 13/15 [train]: 100%|██████████| 1000/1000 [23:36<00:00,  1.42s/it, F1=0.9842, P=0.9776, R=0.9909, loss=0.0021]


[VAL] token-F1 0.9132 (P 0.9483 R 0.8833)  | span-F1 0.7600 (P 0.7820 R 0.7391)


Epoch 14/15 [train]: 100%|██████████| 1000/1000 [23:36<00:00,  1.42s/it, F1=0.9749, P=0.9784, R=0.9714, loss=0.0027]


[VAL] token-F1 0.9160 (P 0.9277 R 0.9074)  | span-F1 0.7639 (P 0.7679 R 0.7600)


Epoch 15/15 [train]: 100%|██████████| 1000/1000 [23:33<00:00,  1.41s/it, F1=0.9728, P=0.9577, R=0.9884, loss=0.0033]


[VAL] token-F1 0.9172 (P 0.9394 R 0.8987)  | span-F1 0.7720 (P 0.7829 R 0.7614)
