In [13]:
#!/usr/bin/env python3
# train_combined_deberta_fixed_imports.py

import os
import json
import random
import difflib
from pathlib import Path
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm

# -----------------------
# CONFIG - edit if needed
# -----------------------
MODEL_NAME = "microsoft/deberta-v3-base"
AGENT_FILE = "/kaggle/input/primary-secondary-data-with-general/labeled_intent_dataset_agent_with_general.json"
CUSTOMER_FILE = "/kaggle/input/primary-secondary-data-with-general/labeled_intent_dataset_customer_with_general.json"
TAXONOMY_PATH = "/kaggle/input/taxonomy-json/taxonomy.json"
OUTPUT_DIR = "/kaggle/working/deberta_debug_model_fixed"

BATCH_SIZE = 16
MAX_LEN = 128
LR = 2e-5
EPOCHS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
ALPHA_SEC = 1.0  # weight for secondary loss
WARMUP_STEPS = 100

os.makedirs(OUTPUT_DIR, exist_ok=True)
random.seed(SEED)
torch.manual_seed(SEED)

# -----------------------
# Helpers
# -----------------------
def load_json_list(path: str) -> List[dict]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(f"{path} must contain a JSON array (list of objects).")
    return data

def normalize_str(s):
    if s is None:
        return ""
    return str(s).strip()

def fuzzy_match_one(item: str, choices: List[str], cutoff: float = 0.6) -> str:
    item_s = normalize_str(item)
    if not item_s:
        return None
    # exact match
    for c in choices:
        if item_s == c:
            return c
    # case-insensitive exact
    for c in choices:
        if item_s.lower() == c.lower():
            return c
    # fuzzy via difflib
    matches = difflib.get_close_matches(item_s, choices, n=1, cutoff=cutoff)
    if matches:
        return matches[0]
    # also try lowercased choices
    matches = difflib.get_close_matches(item_s.lower(), [c.lower() for c in choices], n=1, cutoff=cutoff)
    if matches:
        # find original-cased candidate
        for c in choices:
            if c.lower() == matches[0]:
                return c
    return None

# -----------------------
# Load taxonomy & build label maps
# -----------------------
with open(TAXONOMY_PATH, "r", encoding="utf-8") as f:
    taxonomy = json.load(f)

primary_labels = sorted(list(taxonomy.keys()))
primary2id = {p: i for i, p in enumerate(primary_labels)}
id2primary = {i: p for p, i in primary2id.items()}

sec_agent = sorted({s for v in taxonomy.values() for s in v.get("agent", [])})
sec_customer = sorted({s for v in taxonomy.values() for s in v.get("user", [])})
all_secondary = sorted(list(set(sec_agent + sec_customer)))
sec2id = {s: i for i, s in enumerate(all_secondary)}
id2sec = {i: s for s, i in sec2id.items()}

# allowed_map: (primary_id, speaker) -> list of allowed secondary ids
allowed_map: Dict[Tuple[int, str], List[int]] = {}
for p_name, p_id in primary2id.items():
    a_children = taxonomy[p_name].get("agent", [])
    u_children = taxonomy[p_name].get("user", [])
    allowed_map[(p_id, "agent")] = [sec2id[s] for s in a_children if s in sec2id]
    allowed_map[(p_id, "customer")] = [sec2id[s] for s in u_children if s in sec2id]

# -----------------------
# Load raw data files
# -----------------------
agent_raw = load_json_list(AGENT_FILE)
customer_raw = load_json_list(CUSTOMER_FILE)
print("Loaded agent:", len(agent_raw), "customer:", len(customer_raw))

# -----------------------
# Fuzzy map primary & secondary labels to taxonomy (best-effort)
# -----------------------
def map_records(records: List[dict]) -> List[dict]:
    mapped = []
    for r in records:
        rr = dict(r)  # copy
        prim = normalize_str(rr.get("primary_intent") or "")
        sec = normalize_str(rr.get("secondary_intent") or "")
        sp = normalize_str(rr.get("speaker") or rr.get("original_speaker") or "").lower()

        # map primary
        mapped_primary = fuzzy_match_one(prim, primary_labels, cutoff=0.6)
        if mapped_primary:
            rr["primary_intent"] = mapped_primary
        # map secondary from relevant pool
        pool = sec_agent if sp == "agent" else sec_customer
        mapped_secondary = fuzzy_match_one(sec, pool, cutoff=0.55)
        if mapped_secondary:
            rr["secondary_intent"] = mapped_secondary

        # normalize speaker
        if sp.startswith("agent"):
            rr["speaker"] = "agent"
        elif sp.startswith("cust") or sp.startswith("customer") or sp.startswith("user"):
            rr["speaker"] = "customer"
        else:
            # keep provided speaker lowercased (fallback)
            rr["speaker"] = sp if sp else rr.get("speaker","").lower() or "customer"
        mapped.append(rr)
    return mapped

agent_mapped = map_records(agent_raw)
customer_mapped = map_records(customer_raw)

# -----------------------
# Filter valid records
# -----------------------
def is_valid_record(r):
    return (
        (r.get("primary_intent") in primary2id)
        and (r.get("secondary_intent") in sec2id)
        and (str(r.get("speaker","")).lower() in ("agent","customer"))
        and ( (r.get("text") or r.get("full_text") or "").strip() != "" )
    )

agent_valid = [r for r in agent_mapped if is_valid_record(r)]
customer_valid = [r for r in customer_mapped if is_valid_record(r)]
all_records = agent_valid + customer_valid
random.shuffle(all_records)
print("Valid records -> agent:", len(agent_valid), "customer:", len(customer_valid), "total:", len(all_records))

if len(all_records) == 0:
    raise RuntimeError("No valid records after mapping/filtering. Check file contents, keys 'primary_intent' and 'secondary_intent', and taxonomy strings.")

# -----------------------
# Tokenizer & Dataset
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

class IntentDataset(Dataset):
    def __init__(self, records: List[dict], tokenizer, max_len=128):
        self.records = records
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        r = self.records[idx]
        text = r.get("text") or r.get("full_text") or ""
        enc = self.tokenizer(text, truncation=True, padding="max_length", max_length=self.max_len, return_tensors="pt")
        primary_id = primary2id.get(r.get("primary_intent"), -1)
        secondary_id = sec2id.get(r.get("secondary_intent"), -1)
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "speaker": r.get("speaker","").lower(),
            "primary_id": primary_id,
            "secondary_id": secondary_id
        }

def collate_fn(batch):
    input_ids = torch.stack([b["input_ids"] for b in batch])
    attention_mask = torch.stack([b["attention_mask"] for b in batch])
    speakers = [b["speaker"] for b in batch]
    primaries = torch.tensor([b["primary_id"] for b in batch], dtype=torch.long)
    secondaries = torch.tensor([b["secondary_id"] for b in batch], dtype=torch.long)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "speakers": speakers, "primaries": primaries, "secondaries": secondaries}

# train/val split
split_idx = int(0.8 * len(all_records))
train_records = all_records[:split_idx]
val_records = all_records[split_idx:]

train_ds = IntentDataset(train_records, tokenizer, max_len=MAX_LEN)
val_ds = IntentDataset(val_records, tokenizer, max_len=MAX_LEN)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# -----------------------
# Model
# -----------------------
class DebertaMaskedModel(nn.Module):
    def __init__(self, base_model_name, num_primaries, num_secondaries, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden = self.encoder.config.hidden_size

        # ---------- Primary Head ----------
        self.primary_head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_primaries)
        )

        # ---------- Secondary Head ----------
        self.secondary_head = nn.Sequential(
            nn.Linear(hidden, hidden * 2),
            nn.GELU(),
            nn.Dropout(dropout),

            nn.Linear(hidden * 2, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),

            nn.Linear(hidden, num_secondaries)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0, :]
        cls = self.dropout(cls)
        return self.primary_head(cls), self.secondary_head(cls)


device = torch.device(DEVICE)
model = DebertaMaskedModel(MODEL_NAME, num_primaries=len(primary_labels), num_secondaries=len(all_secondary)).to(device)

# optimizer + scheduler
no_decay = ["bias", "LayerNorm.weight"]
grouped = [
    {"params": [p for n,p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n,p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optim = AdamW(grouped, lr=LR)
total_steps = max(1, len(train_loader) * EPOCHS)
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_steps)

primary_loss_fn = nn.CrossEntropyLoss()
secondary_loss_fn = nn.CrossEntropyLoss()

# -----------------------
# Allowed mask helper
# -----------------------
def build_allowed_mask_batch(primary_ids: torch.Tensor, speakers: List[str], num_secondary: int):
    batch = primary_ids.cpu().tolist()
    mask = torch.zeros((len(batch), num_secondary), dtype=torch.bool)
    for i, p in enumerate(batch):
        sp = speakers[i].lower()
        allowed = allowed_map.get((int(p), sp), [])
        if allowed:
            mask[i, allowed] = True
        else:
            fallback = sec_agent if sp == "agent" else sec_customer
            mask[i, [sec2id[s] for s in fallback]] = True
    return mask.to(device)

# -----------------------
# Training loop
# -----------------------
best_val_primary_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Train Epoch {epoch}")
    for batch in pbar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        speakers = batch["speakers"]
        prim_y = batch["primaries"].to(device)
        sec_y = batch["secondaries"].to(device)

        optim.zero_grad()
        primary_logits, secondary_logits = model(input_ids, attention_mask)

        loss_p = primary_loss_fn(primary_logits, prim_y)

        allowed_mask = build_allowed_mask_batch(prim_y, speakers, num_secondary=len(all_secondary))
        masked_sec_logits = secondary_logits.masked_fill(~allowed_mask, -1e9)
        loss_s = secondary_loss_fn(masked_sec_logits, sec_y)

        loss = loss_p + ALPHA_SEC * loss_s
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        scheduler.step()

        running_loss += loss.item()
    # validation
    model.eval()
    prim_preds, prim_trues = [], []
    sec_preds, sec_trues = [], []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            speakers = batch["speakers"]
            prim_y = batch["primaries"].to(device)
            sec_y = batch["secondaries"].to(device)

            primary_logits, secondary_logits = model(input_ids, attention_mask)
            prim_pred = torch.argmax(primary_logits, dim=1).cpu().tolist()
            prim_preds.extend(prim_pred)
            prim_trues.extend(prim_y.cpu().tolist())

            allowed_mask = build_allowed_mask_batch(prim_y, speakers, num_secondary=len(all_secondary))
            masked_sec_logits = secondary_logits.masked_fill(~allowed_mask, -1e9)
            sec_pred = torch.argmax(masked_sec_logits, dim=1).cpu().tolist()
            sec_preds.extend(sec_pred)
            sec_trues.extend(sec_y.cpu().tolist())

    prim_acc = accuracy_score(prim_trues, prim_preds) if prim_trues else 0.0
    sec_acc = accuracy_score(sec_trues, sec_preds) if sec_trues else 0.0
    print(f"Epoch {epoch} VALID primary_acc={prim_acc:.4f} secondary_acc={sec_acc:.4f}")

    if prim_acc > best_val_primary_acc:
        best_val_primary_acc = prim_acc
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "primary2id": primary2id,
            "sec2id": sec2id
        }, os.path.join(OUTPUT_DIR, "best_model.pt"))
        tokenizer.save_pretrained(OUTPUT_DIR)
        print("Saved best model.")

# final save
torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "final_model.pt"))
tokenizer.save_pretrained(OUTPUT_DIR)
print("Training finished. Best primary acc:", best_val_primary_acc)
print("Saved final model & tokenizer to", OUTPUT_DIR)

Loaded agent: 11583 customer: 11800
Valid records -> agent: 11583 customer: 11800 total: 23383


Train Epoch 1:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]

Epoch 1 VALID primary_acc=0.4182 secondary_acc=0.5166
Saved best model.


Train Epoch 2:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]

Epoch 2 VALID primary_acc=0.4432 secondary_acc=0.5602
Saved best model.


Train Epoch 3:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]

Epoch 3 VALID primary_acc=0.4539 secondary_acc=0.5893
Saved best model.


Train Epoch 4:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]

Epoch 4 VALID primary_acc=0.4550 secondary_acc=0.5948
Saved best model.
Training finished. Best primary acc: 0.45499251657045114
Saved final model & tokenizer to /kaggle/working/deberta_debug_model_fixed


In [18]:
#!/usr/bin/env python3
# continue_train_deberta.py

import os
import json
import random
import difflib
from pathlib import Path
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm

# -----------------------
# CONFIG - edit if needed
# -----------------------
MODEL_NAME = "microsoft/deberta-v3-base"
AGENT_FILE = "/kaggle/input/primary-secondary-data-with-general/labeled_intent_dataset_agent_with_general.json"
CUSTOMER_FILE = "/kaggle/input/primary-secondary-data-with-general/labeled_intent_dataset_customer_with_general.json"
TAXONOMY_PATH = "/kaggle/input/taxonomy-json/taxonomy.json"
CHECKPOINT_PATH = "/kaggle/working/deberta_debug_model_fixed/best_model.pt"
OUTPUT_DIR = "/kaggle/working/deberta_continued_model"

BATCH_SIZE = 16
MAX_LEN = 128
LR = 1e-5  # Lower learning rate for fine-tuning
EPOCHS = 8  # Additional 8 epochs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
ALPHA_SEC = 1.0
WARMUP_STEPS = 50

os.makedirs(OUTPUT_DIR, exist_ok=True)
random.seed(SEED)
torch.manual_seed(SEED)

# -----------------------
# Helpers
# -----------------------
def load_json_list(path: str) -> List[dict]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(f"{path} must contain a JSON array (list of objects).")
    return data

def normalize_str(s):
    if s is None:
        return ""
    return str(s).strip()

def fuzzy_match_one(item: str, choices: List[str], cutoff: float = 0.6) -> str:
    item_s = normalize_str(item)
    if not item_s:
        return None
    for c in choices:
        if item_s == c:
            return c
    for c in choices:
        if item_s.lower() == c.lower():
            return c
    matches = difflib.get_close_matches(item_s, choices, n=1, cutoff=cutoff)
    if matches:
        return matches[0]
    matches = difflib.get_close_matches(item_s.lower(), [c.lower() for c in choices], n=1, cutoff=cutoff)
    if matches:
        for c in choices:
            if c.lower() == matches[0]:
                return c
    return None

# -----------------------
# Load taxonomy & build label maps
# -----------------------
with open(TAXONOMY_PATH, "r", encoding="utf-8") as f:
    taxonomy = json.load(f)

primary_labels = sorted(list(taxonomy.keys()))
primary2id = {p: i for i, p in enumerate(primary_labels)}
id2primary = {i: p for p, i in primary2id.items()}

sec_agent = sorted({s for v in taxonomy.values() for s in v.get("agent", [])})
sec_customer = sorted({s for v in taxonomy.values() for s in v.get("user", [])})
all_secondary = sorted(list(set(sec_agent + sec_customer)))
sec2id = {s: i for i, s in enumerate(all_secondary)}
id2sec = {i: s for s, i in sec2id.items()}

allowed_map: Dict[Tuple[int, str], List[int]] = {}
for p_name, p_id in primary2id.items():
    a_children = taxonomy[p_name].get("agent", [])
    u_children = taxonomy[p_name].get("user", [])
    allowed_map[(p_id, "agent")] = [sec2id[s] for s in a_children if s in sec2id]
    allowed_map[(p_id, "customer")] = [sec2id[s] for s in u_children if s in sec2id]

# -----------------------
# Load raw data files
# -----------------------
agent_raw = load_json_list(AGENT_FILE)
customer_raw = load_json_list(CUSTOMER_FILE)
print("Loaded agent:", len(agent_raw), "customer:", len(customer_raw))

# -----------------------
# Fuzzy map primary & secondary labels
# -----------------------
def map_records(records: List[dict]) -> List[dict]:
    mapped = []
    for r in records:
        rr = dict(r)
        prim = normalize_str(rr.get("primary_intent") or "")
        sec = normalize_str(rr.get("secondary_intent") or "")
        sp = normalize_str(rr.get("speaker") or rr.get("original_speaker") or "").lower()

        mapped_primary = fuzzy_match_one(prim, primary_labels, cutoff=0.6)
        if mapped_primary:
            rr["primary_intent"] = mapped_primary
        pool = sec_agent if sp == "agent" else sec_customer
        mapped_secondary = fuzzy_match_one(sec, pool, cutoff=0.55)
        if mapped_secondary:
            rr["secondary_intent"] = mapped_secondary

        if sp.startswith("agent"):
            rr["speaker"] = "agent"
        elif sp.startswith("cust") or sp.startswith("customer") or sp.startswith("user"):
            rr["speaker"] = "customer"
        else:
            rr["speaker"] = sp if sp else rr.get("speaker","").lower() or "customer"
        mapped.append(rr)
    return mapped

agent_mapped = map_records(agent_raw)
customer_mapped = map_records(customer_raw)

# -----------------------
# Filter valid records
# -----------------------
def is_valid_record(r):
    return (
        (r.get("primary_intent") in primary2id)
        and (r.get("secondary_intent") in sec2id)
        and (str(r.get("speaker","")).lower() in ("agent","customer"))
        and ( (r.get("text") or r.get("full_text") or "").strip() != "" )
    )

agent_valid = [r for r in agent_mapped if is_valid_record(r)]
customer_valid = [r for r in customer_mapped if is_valid_record(r)]
all_records = agent_valid + customer_valid
random.shuffle(all_records)
print("Valid records -> agent:", len(agent_valid), "customer:", len(customer_valid), "total:", len(all_records))

if len(all_records) == 0:
    raise RuntimeError("No valid records after mapping/filtering.")

# -----------------------
# Tokenizer & Dataset
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

class IntentDataset(Dataset):
    def __init__(self, records: List[dict], tokenizer, max_len=128):
        self.records = records
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        r = self.records[idx]
        text = r.get("text") or r.get("full_text") or ""
        enc = self.tokenizer(text, truncation=True, padding="max_length", max_length=self.max_len, return_tensors="pt")
        primary_id = primary2id.get(r.get("primary_intent"), -1)
        secondary_id = sec2id.get(r.get("secondary_intent"), -1)
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "speaker": r.get("speaker","").lower(),
            "primary_id": primary_id,
            "secondary_id": secondary_id
        }

def collate_fn(batch):
    input_ids = torch.stack([b["input_ids"] for b in batch])
    attention_mask = torch.stack([b["attention_mask"] for b in batch])
    speakers = [b["speaker"] for b in batch]
    primaries = torch.tensor([b["primary_id"] for b in batch], dtype=torch.long)
    secondaries = torch.tensor([b["secondary_id"] for b in batch], dtype=torch.long)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "speakers": speakers, "primaries": primaries, "secondaries": secondaries}

# train/val split
split_idx = int(0.8 * len(all_records))
train_records = all_records[:split_idx]
val_records = all_records[split_idx:]

train_ds = IntentDataset(train_records, tokenizer, max_len=MAX_LEN)
val_ds = IntentDataset(val_records, tokenizer, max_len=MAX_LEN)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# -----------------------
# Model
# -----------------------
class DebertaMaskedModel(nn.Module):
    def __init__(self, base_model_name, num_primaries, num_secondaries, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden = self.encoder.config.hidden_size

        self.primary_head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_primaries)
        )

        self.secondary_head = nn.Sequential(
            nn.Linear(hidden, hidden * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden * 2, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_secondaries)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0, :]
        cls = self.dropout(cls)
        return self.primary_head(cls), self.secondary_head(cls)


device = torch.device(DEVICE)
model = DebertaMaskedModel(MODEL_NAME, num_primaries=len(primary_labels), num_secondaries=len(all_secondary)).to(device)

# -----------------------
# Load checkpoint
# -----------------------
print(f"Loading checkpoint from {CHECKPOINT_PATH}")
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint["model_state"])
starting_epoch = checkpoint.get("epoch", 0)
print(f"Loaded model from epoch {starting_epoch}")

# Verify label mappings match
if checkpoint.get("primary2id") != primary2id:
    print("WARNING: primary2id mismatch between checkpoint and current taxonomy!")
if checkpoint.get("sec2id") != sec2id:
    print("WARNING: sec2id mismatch between checkpoint and current taxonomy!")

# -----------------------
# Optimizer + Scheduler
# -----------------------
no_decay = ["bias", "LayerNorm.weight"]
grouped = [
    {"params": [p for n,p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n,p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optim = AdamW(grouped, lr=LR)
total_steps = max(1, len(train_loader) * EPOCHS)
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_steps)

primary_loss_fn = nn.CrossEntropyLoss()
secondary_loss_fn = nn.CrossEntropyLoss()

# -----------------------
# Allowed mask helper
# -----------------------
def build_allowed_mask_batch(primary_ids: torch.Tensor, speakers: List[str], num_secondary: int):
    batch = primary_ids.cpu().tolist()
    mask = torch.zeros((len(batch), num_secondary), dtype=torch.bool)
    for i, p in enumerate(batch):
        sp = speakers[i].lower()
        allowed = allowed_map.get((int(p), sp), [])
        if allowed:
            mask[i, allowed] = True
        else:
            fallback = sec_agent if sp == "agent" else sec_customer
            mask[i, [sec2id[s] for s in fallback]] = True
    return mask.to(device)

# -----------------------
# Training loop
# -----------------------
best_val_primary_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    actual_epoch = starting_epoch + epoch
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Train Epoch {actual_epoch}")
    
    for batch in pbar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        speakers = batch["speakers"]
        prim_y = batch["primaries"].to(device)
        sec_y = batch["secondaries"].to(device)

        optim.zero_grad()
        primary_logits, secondary_logits = model(input_ids, attention_mask)

        loss_p = primary_loss_fn(primary_logits, prim_y)

        allowed_mask = build_allowed_mask_batch(prim_y, speakers, num_secondary=len(all_secondary))
        masked_sec_logits = secondary_logits.masked_fill(~allowed_mask, -1e9)
        loss_s = secondary_loss_fn(masked_sec_logits, sec_y)

        loss = loss_p + ALPHA_SEC * loss_s
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        scheduler.step()

        running_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    
    avg_train_loss = running_loss / len(train_loader)
    
    # Validation
    model.eval()
    prim_preds, prim_trues = [], []
    sec_preds, sec_trues = [], []
    val_loss = 0.0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            speakers = batch["speakers"]
            prim_y = batch["primaries"].to(device)
            sec_y = batch["secondaries"].to(device)

            primary_logits, secondary_logits = model(input_ids, attention_mask)
            
            loss_p = primary_loss_fn(primary_logits, prim_y)
            allowed_mask = build_allowed_mask_batch(prim_y, speakers, num_secondary=len(all_secondary))
            masked_sec_logits = secondary_logits.masked_fill(~allowed_mask, -1e9)
            loss_s = secondary_loss_fn(masked_sec_logits, sec_y)
            val_loss += (loss_p + ALPHA_SEC * loss_s).item()
            
            prim_pred = torch.argmax(primary_logits, dim=1).cpu().tolist()
            prim_preds.extend(prim_pred)
            prim_trues.extend(prim_y.cpu().tolist())

            sec_pred = torch.argmax(masked_sec_logits, dim=1).cpu().tolist()
            sec_preds.extend(sec_pred)
            sec_trues.extend(sec_y.cpu().tolist())

    avg_val_loss = val_loss / len(val_loader)
    prim_acc = accuracy_score(prim_trues, prim_preds) if prim_trues else 0.0
    sec_acc = accuracy_score(sec_trues, sec_preds) if sec_trues else 0.0
    
    print(f"\nEpoch {actual_epoch} Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  Primary Acc: {prim_acc:.4f}")
    print(f"  Secondary Acc: {sec_acc:.4f}")

    if prim_acc > best_val_primary_acc:
        best_val_primary_acc = prim_acc
        torch.save({
            "epoch": actual_epoch,
            "model_state": model.state_dict(),
            "primary2id": primary2id,
            "sec2id": sec2id,
            "primary_acc": prim_acc,
            "secondary_acc": sec_acc
        }, os.path.join(OUTPUT_DIR, "best_model.pt"))
        tokenizer.save_pretrained(OUTPUT_DIR)
        print(f"  ✓ Saved new best model with primary acc: {prim_acc:.4f}")

# Final save
torch.save({
    "epoch": starting_epoch + EPOCHS,
    "model_state": model.state_dict(),
    "primary2id": primary2id,
    "sec2id": sec2id
}, os.path.join(OUTPUT_DIR, "final_model.pt"))
tokenizer.save_pretrained(OUTPUT_DIR)

print("\n" + "="*60)
print("Training finished!")
print(f"Best primary accuracy: {best_val_primary_acc:.4f}")
print(f"Final model saved to: {OUTPUT_DIR}")
print("="*60)

Loaded agent: 11583 customer: 11800
Valid records -> agent: 11583 customer: 11800 total: 23383




Loading checkpoint from /kaggle/working/deberta_debug_model_fixed/best_model.pt
Loaded model from epoch 4


Train Epoch 5:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]


Epoch 5 Summary:
  Train Loss: 2.2172
  Val Loss: 2.8732
  Primary Acc: 0.4529
  Secondary Acc: 0.5948
  ✓ Saved new best model with primary acc: 0.4529


Train Epoch 6:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]


Epoch 6 Summary:
  Train Loss: 1.8931
  Val Loss: 2.9371
  Primary Acc: 0.4458
  Secondary Acc: 0.5974


Train Epoch 7:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]


Epoch 7 Summary:
  Train Loss: 1.7558
  Val Loss: 3.0370
  Primary Acc: 0.4419
  Secondary Acc: 0.5989


Train Epoch 8:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]


Epoch 8 Summary:
  Train Loss: 1.7690
  Val Loss: 3.1021
  Primary Acc: 0.4383
  Secondary Acc: 0.5993


Train Epoch 9:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]


Epoch 9 Summary:
  Train Loss: 1.7566
  Val Loss: 3.1542
  Primary Acc: 0.4409
  Secondary Acc: 0.5980


Train Epoch 10:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]


Epoch 10 Summary:
  Train Loss: 1.6414
  Val Loss: 3.1919
  Primary Acc: 0.4366
  Secondary Acc: 0.6008


Train Epoch 11:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]


Epoch 11 Summary:
  Train Loss: 1.5608
  Val Loss: 3.2175
  Primary Acc: 0.4364
  Secondary Acc: 0.6021


Train Epoch 12:   0%|          | 0/1170 [00:00<?, ?it/s]

Validation:   0%|          | 0/293 [00:00<?, ?it/s]


Epoch 12 Summary:
  Train Loss: 1.4944
  Val Loss: 3.2312
  Primary Acc: 0.4353
  Secondary Acc: 0.5982

Training finished!
Best primary accuracy: 0.4529
Final model saved to: /kaggle/working/deberta_continued_model


In [23]:
#!/usr/bin/env python3
# test_model_inference.py

import os
import json
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from typing import Dict, List, Tuple

# -----------------------
# CONFIG
# -----------------------
MODEL_NAME = "microsoft/deberta-v3-base"
CHECKPOINT_PATH = "/kaggle/working/deberta_continued_model/best_model.pt"  # or use deberta_debug_model_fixed/best_model.pt
TAXONOMY_PATH = "/kaggle/input/taxonomy-json/taxonomy.json"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------
# Model Definition (same as training)
# -----------------------
class DebertaMaskedModel(nn.Module):
    def __init__(self, base_model_name, num_primaries, num_secondaries, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden = self.encoder.config.hidden_size

        self.primary_head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_primaries)
        )

        self.secondary_head = nn.Sequential(
            nn.Linear(hidden, hidden * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden * 2, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_secondaries)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0, :]
        cls = self.dropout(cls)
        return self.primary_head(cls), self.secondary_head(cls)

# -----------------------
# Load Model & Taxonomy
# -----------------------
print("Loading checkpoint...")
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)

primary2id = checkpoint["primary2id"]
sec2id = checkpoint["sec2id"]
id2primary = {i: p for p, i in primary2id.items()}
id2sec = {i: s for s, i in sec2id.items()}

print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
print(f"Primary intents: {len(primary2id)}")
print(f"Secondary intents: {len(sec2id)}")

# Load taxonomy for allowed mappings
with open(TAXONOMY_PATH, "r", encoding="utf-8") as f:
    taxonomy = json.load(f)

# Build allowed secondary intents per primary
allowed_map: Dict[Tuple[int, str], List[int]] = {}
for p_name, p_id in primary2id.items():
    a_children = taxonomy[p_name].get("agent", [])
    u_children = taxonomy[p_name].get("user", [])
    allowed_map[(p_id, "agent")] = [sec2id[s] for s in a_children if s in sec2id]
    allowed_map[(p_id, "customer")] = [sec2id[s] for s in u_children if s in sec2id]

# Initialize model
device = torch.device(DEVICE)
model = DebertaMaskedModel(
    MODEL_NAME, 
    num_primaries=len(primary2id), 
    num_secondaries=len(sec2id)
).to(device)

model.load_state_dict(checkpoint["model_state"])
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
print("Model loaded successfully!\n")

# -----------------------
# Inference Function
# -----------------------
def predict(text: str, speaker: str = "customer", max_len: int = 128, top_k: int = 3):
    """
    Predict primary and secondary intents for a given text.
    
    Args:
        text: Input text to classify
        speaker: "agent" or "customer"
        max_len: Maximum token length
        top_k: Number of top predictions to return
    """
    speaker = speaker.lower()
    if speaker not in ["agent", "customer"]:
        speaker = "customer"
    
    # Tokenize
    encoding = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=max_len,
        return_tensors="pt"
    )
    
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    
    # Predict
    with torch.no_grad():
        primary_logits, secondary_logits = model(input_ids, attention_mask)
    
    # Get primary predictions
    primary_probs = torch.softmax(primary_logits, dim=1)[0]
    primary_pred_id = torch.argmax(primary_probs).item()
    primary_pred = id2primary[primary_pred_id]
    primary_conf = primary_probs[primary_pred_id].item()
    
    # Get top-k primary predictions
    top_primary_probs, top_primary_ids = torch.topk(primary_probs, min(top_k, len(primary2id)))
    top_primaries = [
        (id2primary[idx.item()], prob.item()) 
        for idx, prob in zip(top_primary_ids, top_primary_probs)
    ]
    
    # Get secondary predictions (masked by allowed intents)
    allowed_secondary_ids = allowed_map.get((primary_pred_id, speaker), [])
    
    if allowed_secondary_ids:
        # Create mask for allowed secondaries
        mask = torch.zeros(len(sec2id), dtype=torch.bool, device=device)
        mask[allowed_secondary_ids] = True
        masked_logits = secondary_logits[0].masked_fill(~mask, -1e9)
    else:
        masked_logits = secondary_logits[0]
    
    secondary_probs = torch.softmax(masked_logits, dim=0)
    secondary_pred_id = torch.argmax(secondary_probs).item()
    secondary_pred = id2sec[secondary_pred_id]
    secondary_conf = secondary_probs[secondary_pred_id].item()
    
    # Get top-k secondary predictions
    top_secondary_probs, top_secondary_ids = torch.topk(secondary_probs, min(top_k, len(sec2id)))
    top_secondaries = [
        (id2sec[idx.item()], prob.item()) 
        for idx, prob in zip(top_secondary_ids, top_secondary_probs)
        if prob.item() > -1e8  # Filter out masked values
    ]
    
    return {
        "text": text,
        "speaker": speaker,
        "primary_intent": primary_pred,
        "primary_confidence": primary_conf,
        "secondary_intent": secondary_pred,
        "secondary_confidence": secondary_conf,
        "top_primary_predictions": top_primaries,
        "top_secondary_predictions": top_secondaries,
        "allowed_secondaries": [id2sec[i] for i in allowed_secondary_ids] if allowed_secondary_ids else []
    }

def print_prediction(result: dict):
    """Pretty print prediction results"""
    print("="*70)
    print(f"Input: {result['text']}")
    print(f"Speaker: {result['speaker'].upper()}")
    print("-"*70)
    print(f"PRIMARY INTENT: {result['primary_intent']} (confidence: {result['primary_confidence']:.2%})")
    print(f"SECONDARY INTENT: {result['secondary_intent']} (confidence: {result['secondary_confidence']:.2%})")
    print("-"*70)
    
    print("\nTop Primary Predictions:")
    for i, (intent, conf) in enumerate(result['top_primary_predictions'], 1):
        print(f"  {i}. {intent}: {conf:.2%}")
    
    print("\nTop Secondary Predictions:")
    for i, (intent, conf) in enumerate(result['top_secondary_predictions'], 1):
        print(f"  {i}. {intent}: {conf:.2%}")
    
    if result['allowed_secondaries']:
        print(f"\nAllowed secondaries for this primary: {len(result['allowed_secondaries'])}")
    print("="*70 + "\n")

# -----------------------
# Interactive Testing
# -----------------------
def interactive_mode():
    """Run interactive testing mode"""
    print("\n" + "="*70)
    print("INTERACTIVE TESTING MODE")
    print("="*70)
    print("Enter 'quit' or 'exit' to stop")
    print("Enter 'examples' to see example inputs\n")
    
    while True:
        text = input("Enter text to classify: ").strip()
        
        if text.lower() in ['quit', 'exit', 'q']:
            print("Exiting...")
            break
        
        if text.lower() == 'examples':
            print("\nExample inputs:")
            print("  - I want to cancel my subscription")
            print("  - What's my account balance?")
            print("  - Let me transfer you to a specialist")
            print("  - I need help with my order")
            print("  - Thank you for your help!\n")
            continue
        
        if not text:
            print("Please enter some text.\n")
            continue
        
        speaker = input("Speaker (agent/customer) [default: customer]: ").strip().lower()
        if not speaker:
            speaker = "customer"
        
        result = predict(text, speaker)
        print()
        print_prediction(result)

# -----------------------
# Batch Testing
# -----------------------
def batch_test(examples: List[Tuple[str, str]]):
    """Test multiple examples at once"""
    print("\n" + "="*70)
    print("BATCH TESTING")
    print("="*70 + "\n")
    
    for text, speaker in examples:
        result = predict(text, speaker)
        print_prediction(result)

# -----------------------
# Main
# -----------------------
if __name__ == "__main__":
    # Example test cases
    test_examples = [
        ("I want to cancel my subscription", "customer"),
        ("Let me check your account details", "agent"),
        ("What's my current balance?", "customer"),
        ("I'll transfer you to our billing department", "agent"),
        ("Thank you so much for your help!", "customer"),
    ]
    
    print("\n" + "="*70)
    print("MODEL TESTING SCRIPT")
    print("="*70)
    print("\nChoose a mode:")
    print("1. Interactive mode (enter custom inputs)")
    print("2. Batch test with example inputs")
    print("3. Single prediction")
    
    choice = input("\nEnter choice (1/2/3) [default: 1]: ").strip()
    
    if choice == "2":
        batch_test(test_examples)
    elif choice == "3":
        text = input("Enter text: ").strip()
        speaker = input("Enter speaker (agent/customer) [default: customer]: ").strip().lower() or "customer"
        result = predict(text, speaker)
        print()
        print_prediction(result)
    else:
        interactive_mode()

Loading checkpoint...
Loaded model from epoch 5
Primary intents: 17
Secondary intents: 98




Model loaded successfully!


MODEL TESTING SCRIPT

Choose a mode:
1. Interactive mode (enter custom inputs)
2. Batch test with example inputs
3. Single prediction



Enter choice (1/2/3) [default: 1]:  2



BATCH TESTING

Input: I want to cancel my subscription
Speaker: CUSTOMER
----------------------------------------------------------------------
PRIMARY INTENT: Cancellation_Policy (confidence: 83.30%)
SECONDARY INTENT: Inquire_About_Cancellation_Options (confidence: 76.43%)
----------------------------------------------------------------------

Top Primary Predictions:
  1. Cancellation_Policy: 83.30%
  2. Service_Switching: 8.65%
  3. Billing & Refunds: 1.52%

Top Secondary Predictions:
  1. Inquire_About_Cancellation_Options: 76.43%
  2. Ask_About_Reinstatement: 13.11%
  3. Request_Policy_Clarification: 10.46%

Allowed secondaries for this primary: 3

Input: Let me check your account details
Speaker: AGENT
----------------------------------------------------------------------
PRIMARY INTENT: General_Conversation (confidence: 31.78%)
SECONDARY INTENT: Ask_for_Clarification (confidence: 49.91%)
----------------------------------------------------------------------

Top Primary Predict