In [6]:
# Two-level topic taxonomy: each level-1 category maps to its list of level-2 topics.
# Declaration order matters: class ids are assigned by walking this dict in order.
# Level-2 names are not unique on their own ("Research" appears under both
# "Science" and "Education"), so a class is identified by the (level_1, level_2) pair.
TOPIC_HIERARCHY = {
    "Politics": ["India", "UK", "USA", "China", "Russia", "Global"],
    "Sports": ["Cricket", "Football", "Basketball", "Tennis", "Olympics"],
    "Technology": ["Artificial Intelligence", "Machine Learning", "Software Development", "Cybersecurity", "Blockchain"],
    "Business": ["Startups", "Finance", "Stock Market", "Economy", "E-commerce"],
    "Entertainment": ["Movies", "TV Shows", "Music", "Celebrities", "OTT Platforms"],
    "Science": ["Physics", "Biology", "Space", "Climate", "Research"],
    "Health": ["Fitness", "Nutrition", "Mental Health", "Diseases", "Medicine"],
    "Education": ["Exams", "Universities", "Online Courses", "Careers", "Research"],
    # Catch-all bucket for conversations without a substantive topic.
    "General": ["Chitchat", "Greetings", "Meta", "Clarification", "Other"]
}


def flatten_messages(messages):
    """Render a conversation as plain text, one "Role: content" line per turn.

    Each message is a mapping with "role" and "content" string entries. The role
    is capitalized, the content is stripped of surrounding whitespace, and the
    turns are joined with newlines in their original order.
    """
    return "\n".join(
        f"{turn['role'].capitalize()}: {turn['content'].strip()}"
        for turn in messages
    )
    
# Class id -> (level_1, level_2). Ids follow TOPIC_HIERARCHY declaration order:
# all level-2 topics of the first category, then the second category, and so on.
ID2LABEL = dict(enumerate(
    (l1, l2)
    for l1, l2_list in TOPIC_HIERARCHY.items()
    for l2 in l2_list
))

# Inverse lookup: (level_1, level_2) -> class id.
LEVEL2_LABEL2ID = {label: class_id for class_id, label in ID2LABEL.items()}

# A repeated (level_1, level_2) pair would collapse in the inverse map, making
# NUM_CLASSES smaller than the largest id; fail loudly instead of mis-sizing the head.
assert len(LEVEL2_LABEL2ID) == len(ID2LABEL), "duplicate (level_1, level_2) pair in TOPIC_HIERARCHY"

NUM_CLASSES = len(LEVEL2_LABEL2ID)

print(f"Number of classes: {NUM_CLASSES}")

Number of classes: 46


In [7]:
import json
import torch
from torch.utils.data import Dataset

class TopicHierarchyDataset(Dataset):
    """Conversation-level topic classification dataset read from a JSONL file.

    Every non-blank line must be a JSON object with:
      * "messages": the conversation turns, passed to `flatten_messages`;
      * "labels" -> "topic" -> "level_1" / "level_2": the topic pair, which is
        looked up in `LEVEL2_LABEL2ID` to obtain the integer class id.

    All samples are parsed eagerly in the constructor and kept in memory as
    {"text": str, "label": int} dicts.
    """

    def __init__(self, jsonl_path):
        """Parse `jsonl_path` into `self.samples`.

        Raises:
            json.JSONDecodeError: a non-blank line is not valid JSON.
            KeyError: a required field is missing, or the (level_1, level_2)
                pair is not a known label (the message names file and line).
        """
        self.samples = []

        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                # Blank lines (typically trailing ones) are not valid JSON and
                # carry no sample, so skip them rather than crash in json.loads.
                if not line.strip():
                    continue

                data = json.loads(line)
                text = flatten_messages(data["messages"])

                topic = data["labels"]["topic"]
                label_key = (topic["level_1"], topic["level_2"])
                try:
                    label_id = LEVEL2_LABEL2ID[label_key]
                except KeyError:
                    # Same exception type as a plain dict miss, but say where
                    # the offending record is so the data can be fixed.
                    raise KeyError(
                        f"{jsonl_path}:{line_no}: unknown topic label {label_key!r}"
                    ) from None

                self.samples.append({
                    "text": text,
                    "label": label_id
                })

    def __len__(self):
        """Number of parsed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample `idx` as {"text": str, "label": 0-d torch.long tensor}."""
        item = self.samples[idx]
        return {
            "text": item["text"],
            "label": torch.tensor(item["label"], dtype=torch.long)
        }

def collate_fn(batch):
    """Merge a list of dataset items into a single batch dict.

    Returns {"texts": list of the items' "text" values (left untokenized),
             "labels": the items' "label" tensors stacked along a new dim 0}.
    """
    texts, labels = [], []
    for item in batch:
        texts.append(item["text"])
        labels.append(item["label"])
    return {"texts": texts, "labels": torch.stack(labels)}


In [8]:
from torch.utils.data import DataLoader, random_split
import torch

# Load every labelled conversation, then carve out train / val / test subsets.
dataset = TopicHierarchyDataset("/kaggle/input/finaltagging/data1.jsonl")
total_size = len(dataset)

train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1

# Train and val sizes are floored; the test split absorbs the rounding
# remainder so the three sizes always sum to total_size.
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - (train_size + val_size)

# Fixed seed so the same examples land in each split on every run.
generator = torch.Generator().manual_seed(42)

split_lengths = [train_size, val_size, test_size]
train_ds, val_ds, test_ds = random_split(dataset, split_lengths, generator=generator)


In [9]:
# Settings shared by all three loaders; only the training loader reshuffles.
_common_loader_args = {"batch_size": 8, "collate_fn": collate_fn}

train_loader = DataLoader(train_ds, shuffle=True, **_common_loader_args)
val_loader = DataLoader(val_ds, shuffle=False, **_common_loader_args)
test_loader = DataLoader(test_ds, shuffle=False, **_common_loader_args)


In [10]:
# Report how many examples landed in each split.
for split_name, split in (("Train", train_ds), ("Val", val_ds), ("Test", test_ds)):
    print(f"{split_name}:", len(split))

# Pull one collated batch to confirm the structure the loader yields.
batch = next(iter(train_loader))
print(batch.keys())

Train: 404
Val: 50
Test: 52
dict_keys(['texts', 'labels'])


In [11]:
# Draw a fresh training batch and show its first example with the decoded label.
batch = next(iter(train_loader))
first_text = batch["texts"][0]
label_id = batch["labels"][0].item()

print("===== BATCH SAMPLE =====")
print(first_text)
print("\n===== LABEL =====")
print("Label ID:", label_id)
print("Decoded Label:", ID2LABEL[label_id])


===== BATCH SAMPLE =====
User: Hey there! How's your day going so far?
Assistant: It's been pretty good, thanks for asking! Just trying to clear my inbox. How about yours?
User: Mine's been a bit hectic, but I'm looking forward to unwinding tonight.
Assistant: Oh, sounds like a good plan. Do you have anything fun lined up?
User: Yeah, I was thinking of finally watching that new sci-fi movie everyone's raving about. My friend Sarah mentioned it earlier.
Assistant: The one with the incredible special effects? I've heard it's epic. Are you going to watch it by yourself?
User: Nah, she's joining me.

===== LABEL =====
Label ID: 41
Decoded Label: ('General', 'Chitchat')


In [12]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# Hugging Face model id of the pretrained encoder the classifier is built on.
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"

# Load the tokenizer and the encoder weights for MODEL_NAME (fetched from the
# hub on first use). `base_model` is later wrapped by the classifier module.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model = AutoModel.from_pretrained(MODEL_NAME)


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

2026-01-03 12:36:48.162987: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767443808.570664     106 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767443808.665242     106 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767443809.669173     106 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767443809.669203     106 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767443809.669205     106 computation_placer.cc:177] computation placer alr

model.safetensors:   0%|          | 0.00/1.19G [00:00<?, ?B/s]

In [13]:
class QwenTopicClassifier(nn.Module):
    """Sequence classifier: a pretrained encoder followed by a two-layer MLP head.

    The encoder's token states are averaged using the attention mask as weights,
    and the pooled vector is mapped to `num_classes` logits.
    """

    def __init__(self, base_model, num_classes):
        """Wrap `base_model` and build the head sized from its `config.hidden_size`."""
        super().__init__()
        self.encoder = base_model

        width = base_model.config.hidden_size
        head_layers = [
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(width, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def mean_pool(self, last_hidden_state, attention_mask):
        """Average token vectors along dim 1, weighting each by its mask value."""
        weights = attention_mask.unsqueeze(-1).float()
        weighted_sum = (last_hidden_state * weights).sum(dim=1)
        # Clamp so a fully masked row divides by a tiny number rather than zero.
        token_counts = weights.sum(dim=1).clamp(min=1e-9)
        return weighted_sum / token_counts

    def forward(self, input_ids, attention_mask, labels=None):
        """Return {"loss": ..., "logits": ...}; loss is None unless labels are given."""
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        sentence_vec = self.mean_pool(encoded.last_hidden_state, attention_mask)
        logits = self.classifier(sentence_vec)

        if labels is None:
            loss = None
        else:
            # Cross-entropy with default settings, computed on raw logits.
            loss = nn.functional.cross_entropy(logits, labels)

        return {"loss": loss, "logits": logits}


In [14]:
# Derive the class count from the label map instead of restating a literal, so
# the classifier head always matches the ids assigned to the dataset labels.
NUM_CLASSES = len(LEVEL2_LABEL2ID)

model = QwenTopicClassifier(
    base_model=base_model,
    num_classes=NUM_CLASSES
)

# Use the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


QwenTopicClassifier(
  (encoder): Qwen3Model(
    (embed_tokens): Embedding(151669, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_

In [15]:
def tokenize_batch(batch):
    """Tokenize a collated batch and attach its labels.

    `batch` is a dict with "texts" and "labels". The texts are passed to the
    notebook-level `tokenizer` with padding and truncation enabled,
    max_length=512 and return_tensors="pt". The labels are then stored on the
    tokenizer's result under "labels", and that result is returned.
    """
    tokenizer_options = {
        "padding": True,
        "truncation": True,
        "max_length": 512,
        "return_tensors": "pt",
    }
    encoded = tokenizer(batch["texts"], **tokenizer_options)
    encoded["labels"] = batch["labels"]
    return encoded


In [16]:
from torch.optim import AdamW
from tqdm import tqdm

def _run_epoch(loader, description):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    loss_sum = 0
    for raw_batch in tqdm(loader, desc=description):
        inputs = tokenize_batch(raw_batch)
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        step_loss = model(**inputs)["loss"]
        step_loss.backward()
        optimizer.step()

        loss_sum += step_loss.item()
    return loss_sum / len(loader)


# Every parameter (encoder and head) is handed to the optimizer.
optimizer = AdamW(model.parameters(), lr=2e-5)

model.train()

for epoch in range(3):
    mean_loss = _run_epoch(train_loader, f"Epoch {epoch+1}")
    print(f"Epoch {epoch+1} Loss: {mean_loss:.4f}")


Epoch 1: 100%|██████████| 51/51 [00:09<00:00,  5.15it/s]


Epoch 1 Loss: 1.8210


Epoch 2: 100%|██████████| 51/51 [00:08<00:00,  6.07it/s]


Epoch 2 Loss: 0.1500


Epoch 3: 100%|██████████| 51/51 [00:08<00:00,  6.13it/s]

Epoch 3 Loss: 0.0340





In [17]:
# Accuracy on the validation split.
model.eval()

correct = total = 0
with torch.no_grad():
    for raw_batch in val_loader:
        inputs = tokenize_batch(raw_batch)
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        predicted = model(**inputs)["logits"].argmax(dim=-1)
        gold = inputs["labels"]

        correct += (predicted == gold).sum().item()
        total += gold.size(0)

print("Validation Accuracy:", correct / total)


Validation Accuracy: 0.98


In [19]:
# Accuracy on the held-out test split.
model.eval()

hits_per_batch = []
sizes_per_batch = []

with torch.no_grad():
    for raw_batch in test_loader:
        inputs = tokenize_batch(raw_batch)
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        predicted = model(**inputs)["logits"].argmax(dim=-1)
        gold = inputs["labels"]

        hits_per_batch.append((predicted == gold).sum().item())
        sizes_per_batch.append(gold.size(0))

correct = sum(hits_per_batch)
total = sum(sizes_per_batch)

test_accuracy = correct / total
print(f"Test Accuracy: {correct}/{total} : {test_accuracy:.4f}")


Test Accuracy: 50/52 : 0.9615
