In [None]:
!pip install x-transformers

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import f1_score

from transformers import AutoTokenizer, AutoModel
from x_transformers import TransformerWrapper, Decoder, Encoder

In [None]:
TRAIN_CSV = "prachatai_train.csv"
VAL_CSV   = "prachatai_val.csv"
TEST_CSV  = "prachatai_test.csv"

MODEL_PATH = "wangchan_xt_best.pt"  # best checkpoint written during training

MAX_LEN = 256     # tokenizer truncation length (also the decoder's max_seq_len)
BATCH_SIZE = 8
EPOCHS = 40
PATIENCE = 8      # epochs without a val-F1 improvement before early stopping
LR = 2e-4

# Seed every RNG the notebook touches so weight init, dropout and DataLoader
# shuffling are reproducible under Restart & Run All.
# torch.manual_seed also seeds CUDA; numpy is seeded last so the cell shows no output.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Multi-hot target columns read from the CSVs; list order fixes the index of
# each label in the classifier output (predict() maps index -> name with it).
LABEL_COLS = [
    "politics",
    "human_rights",
    "quality_of_life",
    "international",
    "social",
    "environment",
    "economics",
    "culture",
    "labor",
    "national_security",
    "ict",
    "education",
]

In [None]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Tokenizer that matches the pretrained WangchanBERTa checkpoint used as the encoder.
tokenizer = AutoTokenizer.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")

# Padding id; collate_fn pads with it and derives the attention mask from it.
PAD_ID = tokenizer.pad_token_id

In [None]:
def load_dataset(csv_path):
    """Read one CSV split and tokenize its article bodies.

    Args:
        csv_path: path to a CSV with a ``body_text`` column and one column per
            label in LABEL_COLS (presumably 0/1 indicators — verify against data).

    Returns:
        (input_ids, labels): a list of un-padded token-id lists, each truncated
        to MAX_LEN tokens, and a float32 array of shape [n_rows, len(LABEL_COLS)].
    """
    df = pd.read_csv(csv_path)
    # fillna first: astype(str) alone turns a missing body into the literal text "nan".
    texts = df["body_text"].fillna("").astype(str).tolist()
    labels = df[LABEL_COLS].to_numpy(dtype=np.float32)

    enc = tokenizer(
        texts,
        truncation=True,
        max_length=MAX_LEN,
        padding=False,  # sequences are padded per batch in collate_fn
    )
    return enc["input_ids"], labels


X_train, y_train = load_dataset(TRAIN_CSV)
X_val, y_val     = load_dataset(VAL_CSV)
X_test, y_test   = load_dataset(TEST_CSV)

In [None]:
class ThaiDataset(Dataset):
    """Map-style dataset pairing ragged token-id sequences with label rows.

    Args:
        X: sequence of token-id lists (one per example, lengths may differ).
        y: 2-D array-like of labels, one row per example.
    """

    def __init__(self, X, y):
        # Token ids stay ragged Python lists; labels become one float32 tensor up front.
        self.X = X
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return (token_ids tensor, label row tensor) for example ``idx``."""
        token_ids = torch.tensor(self.X[idx])
        target = self.y[idx]
        return token_ids, target

def collate_fn(batch):
    """Pad a batch of (token_ids, labels) pairs and move it to ``device``.

    Returns:
        (input_ids [B, L_max], attention_mask [B, L_max] with 1 on non-PAD_ID
        positions and 0 elsewhere, stacked labels [B, ...]).
    """
    token_seqs = [seq for seq, _ in batch]
    label_rows = [lab for _, lab in batch]

    input_ids = pad_sequence(token_seqs, batch_first=True, padding_value=PAD_ID)
    # Any position that holds PAD_ID is treated as padding.
    attention_mask = input_ids.ne(PAD_ID).long()
    labels = torch.stack(label_rows)

    return input_ids.to(device), attention_mask.to(device), labels.to(device)

def _make_loader(X, y, shuffle):
    """Wrap one split in a DataLoader that batches with collate_fn."""
    return DataLoader(
        ThaiDataset(X, y),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )


train_loader = _make_loader(X_train, y_train, shuffle=True)  # reshuffled each epoch
val_loader = _make_loader(X_val, y_val, shuffle=False)
test_loader = _make_loader(X_test, y_test, shuffle=False)

In [None]:
# Pretrained WangchanBERTa backbone, used as a frozen feature extractor:
# requires_grad_(False) turns off gradients for every parameter at once.
encoder = AutoModel.from_pretrained("airesearch/wangchanberta-base-att-spm-uncased")
encoder.requires_grad_(False);

In [None]:
class WangchanXTClassifier(nn.Module):
    """Multi-label classifier: WangchanBERTa encoder + x-transformers cross-attention stack.

    The module-level ``encoder`` produces contextual token states. A 3-layer
    non-causal x-transformers stack with cross-attention reads those states.
    Its queries are the embedding of a single placeholder token plus
    TransformerWrapper's positional embedding. The stack's outputs are
    mean-pooled over non-padding positions and projected to one sigmoid
    probability per label in LABEL_COLS.
    """

    def __init__(self):
        super().__init__()

        self.encoder = encoder
        self.hidden = encoder.config.hidden_size  # 768 for the base checkpoint

        # Encoder, not Decoder: x-transformers' Decoder hard-codes causal=True and
        # asserts if `causal` is passed, so Decoder(causal=False) cannot be built.
        # Encoder is the non-causal variant; cross_attend=True adds a
        # cross-attention sub-layer over the WangchanBERTa states.
        self.decoder = TransformerWrapper(
            num_tokens=1,          # one placeholder token id (0); queries differ only by position
            max_seq_len=MAX_LEN,
            attn_layers=Encoder(
                dim=self.hidden,
                depth=3,
                heads=8,
                cross_attend=True,
            ),
        )

        self.classifier = nn.Linear(self.hidden, len(LABEL_COLS))

    def forward(self, input_ids, attention_mask):
        """Return per-label probabilities.

        Args:
            input_ids: integer token ids, [B, L] (padded).
            attention_mask: [B, L], nonzero for real tokens and 0 for padding.

        Returns:
            Tensor [B, len(LABEL_COLS)] of sigmoid probabilities.
        """
        bool_mask = attention_mask.bool()

        # ----- Encoder -----
        # The encoder is frozen in this notebook; skip building its autograd
        # graph (saves activation memory) unless some parameter was unfrozen.
        train_encoder = any(p.requires_grad for p in self.encoder.parameters())
        with torch.set_grad_enabled(train_encoder and torch.is_grad_enabled()):
            enc = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state            # [B, L, hidden]

        # ----- Decoder -----
        # TransformerWrapper embeds token *ids* with nn.Embedding, so it needs an
        # integer tensor; a float [B, L, hidden] tensor would be rejected.
        query_ids = torch.zeros_like(input_ids, dtype=torch.long)

        dec = self.decoder(
            x=query_ids,
            mask=bool_mask,          # keep padded query slots out of self-attention
            context=enc,
            context_mask=bool_mask,  # keep padded encoder states out of cross-attention
            return_embeddings=True,
        )                            # [B, L, hidden]

        # Mean over real positions only; clamp guards an all-padding row.
        weights = attention_mask.unsqueeze(-1).to(dec.dtype)
        pooled = (dec * weights).sum(1) / weights.sum(1).clamp(min=1)

        return torch.sigmoid(self.classifier(pooled))

In [None]:
model = WangchanXTClassifier().to(device)

# The model returns sigmoid probabilities, hence BCELoss rather than BCEWithLogitsLoss.
criterion = nn.BCELoss()

# Only trainable parameters go to the optimizer; the frozen encoder is excluded.
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LR
)

# Loss scaling only matters for CUDA fp16 autocast. Enable it explicitly on CUDA
# instead of relying on GradScaler's "CUDA not available" warning-and-disable path.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

In [None]:
class EarlyStopping:
    """Signal a stop after `patience` evaluations without a new best score.

    Each new best writes ``model.state_dict()`` to disk.

    Args:
        patience: consecutive non-improving ``step`` calls that trigger a stop.
        path: checkpoint file for the best weights; defaults to MODEL_PATH.
    """

    def __init__(self, patience, path=None):
        # -inf rather than 0.0, so the first evaluation is always checkpointed.
        # With 0.0, a run whose val F1 never exceeds 0 saves nothing, and the
        # later torch.load(MODEL_PATH) fails or silently loads a stale file.
        self.best = float("-inf")
        self.counter = 0
        self.patience = patience
        self.path = MODEL_PATH if path is None else path

    def step(self, score, model):
        """Record one validation score; return True when training should stop."""
        if score > self.best:
            self.best = score
            self.counter = 0
            torch.save(model.state_dict(), self.path)
            return False
        self.counter += 1
        return self.counter >= self.patience


In [None]:
early_stop = EarlyStopping(PATIENCE)

# Mixed precision is only used on CUDA; elsewhere autocast stays disabled.
use_amp = device.type == "cuda"

# ---------------- TRAIN ----------------
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    n_batches = 0

    for Xb, maskb, yb in train_loader:
        optimizer.zero_grad()

        with torch.autocast(device_type=device.type, enabled=use_amp):
            preds = model(Xb, maskb)

        # BCELoss raises inside CUDA autocast regions (its fp16 backward is unsafe),
        # so compute it outside the region on float32 probabilities.
        loss = criterion(preds.float(), yb)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        n_batches += 1

    # ----- VALIDATION -----
    model.eval()
    yt, yp = [], []

    with torch.no_grad():
        for Xb, maskb, yb in val_loader:
            preds = (model(Xb, maskb) > 0.5).int()
            yt.append(yb.cpu().numpy())
            yp.append(preds.cpu().numpy())

    yt = np.vstack(yt)
    yp = np.vstack(yp)
    # zero_division=0 gives the same score as sklearn's default, minus the warning spam
    # for labels with no predicted positives.
    val_f1 = f1_score(yt, yp, average="macro", zero_division=0)

    # Report the mean per-batch loss so the number is comparable across dataset sizes.
    print(
        f"Epoch {epoch+1:03d} | "
        f"Loss {total_loss / max(n_batches, 1):.4f} | "
        f"Val F1 {val_f1:.4f}"
    )

    if early_stop.step(val_f1, model):
        print("⏹ Early stopping")
        break

In [None]:
# Restore the best checkpoint written by EarlyStopping. map_location lets a
# checkpoint saved on another device (e.g. GPU) load on the current one.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

yt, yp = [], []

with torch.no_grad():
    for Xb, maskb, yb in test_loader:
        preds = (model(Xb, maskb) > 0.5).int()
        yt.append(yb.cpu().numpy())
        yp.append(preds.cpu().numpy())

yt = np.vstack(yt)
yp = np.vstack(yp)

# zero_division=0 keeps sklearn's default value (0) for labels with no positives
# but suppresses the UndefinedMetricWarning for each one.
print("\nFINAL TEST RESULTS")
print("Macro F1:", f1_score(yt, yp, average="macro", zero_division=0))
print("Micro F1:", f1_score(yt, yp, average="micro", zero_division=0))

for i, label in enumerate(LABEL_COLS):
    print(label, f1_score(yt[:, i], yp[:, i], zero_division=0))

In [None]:
def predict(text, threshold=0.5):
    """Classify a single text.

    Returns (label, probability) pairs for every label whose probability is
    >= threshold, sorted from most to least probable.
    """
    batch = tokenizer(text, truncation=True, max_length=MAX_LEN, return_tensors="pt")
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        scores = model(input_ids, attention_mask)[0].cpu().numpy()

    hits = [
        (label, float(score))
        for label, score in zip(LABEL_COLS, scores)
        if score >= threshold
    ]
    hits.sort(key=lambda pair: pair[1], reverse=True)
    return hits

# Smoke-test the trained classifier on two short Thai sentences.
print("\nPREDICT EXAMPLES")
for sample in (
    "รัฐบาลไทยประกาศนโยบายด้านสิ่งแวดล้อมใหม่",
    "แรงงานเรียกร้องสิทธิ์การทำงาน",
):
    print(predict(sample))