In [1]:
import nltk
# Fetch NLTK's "punkt" resource; the recorded cell output shows this is a no-op
# when the package is already up to date.
# NOTE(review): presumably needed by `word_tokenize` used later — newer NLTK
# releases may look for "punkt_tab" instead; confirm against the installed version.
nltk.download("punkt")

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

## Data Preprocessing

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import json
from collections import Counter
from nltk.tokenize import word_tokenize
from tqdm import tqdm


# -------- Tokenizer --------
def nltk_tokenizer(text):
    """Lower-case ``text`` and split it into tokens with ``word_tokenize``.

    Args:
        text: The string to tokenize.

    Returns:
        Whatever ``word_tokenize`` returns for the lower-cased string.
    """
    lowered = text.lower()
    return word_tokenize(lowered)


# -------- Read JSON (for DBPedia) --------
def _read_json_records(path):
    """Return the parsed content of the UTF-8 encoded JSON file at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_dbpedia_data(train_path, val_path, test_path=None):
    """Load the train/validation (and optionally test) splits from JSON files.

    Each file must hold a list of records. Train and validation records need a
    ``"text"`` and a ``"label"`` entry (the label is converted with ``int``);
    test records need ``"text"`` and may carry an ``"id"``.

    Args:
        train_path: Path of the training JSON file.
        val_path: Path of the validation JSON file.
        test_path: Optional path of the test JSON file. When ``None`` the
            returned test lists are empty.

    Returns:
        A 7-tuple ``(train_texts, train_labels, val_texts, val_labels,
        test_texts, test_ids, num_classes)`` where ``num_classes`` is the
        largest label seen in train/val plus one.

    Raises:
        ValueError: If neither the train nor the validation file contains a
            record, so ``num_classes`` cannot be inferred.
    """
    # ---- train ----
    train_data = _read_json_records(train_path)
    train_texts = [item["text"] for item in train_data]
    train_labels = [int(item["label"]) for item in train_data]

    # ---- val ----
    val_data = _read_json_records(val_path)
    val_texts = [item["text"] for item in val_data]
    val_labels = [int(item["label"]) for item in val_data]

    # ---- test ----
    if test_path is not None:
        test_data = _read_json_records(test_path)
        test_texts = [item["text"] for item in test_data]
        # Fall back to the record's position when it carries no explicit id.
        test_ids = [item.get("id", i) for i, item in enumerate(test_data)]
    else:
        test_texts, test_ids = [], []

    if not train_labels and not val_labels:
        raise ValueError(
            "cannot infer num_classes: train and val files contain no labelled records"
        )

    # Take the max per split instead of concatenating the two lists first;
    # default=-1 lets one of the splits be empty.
    num_classes = max(max(train_labels, default=-1), max(val_labels, default=-1)) + 1

    return train_texts, train_labels, val_texts, val_labels, test_texts, test_ids, num_classes


# -------- Vocab --------
class Vocab:
    """Token <-> index mapping built from tokenized texts.

    Index 0 is always ``"<unk>"`` and index 1 is always ``"<pad>"``; the
    remaining entries are the corpus tokens whose count is at least
    ``min_freq``, in first-seen order.

    Attributes:
        itos: List mapping index -> token.
        stoi: Dict mapping token -> index.
    """

    def __init__(self, tokens_list, min_freq=1):
        """Count tokens over ``tokens_list`` and build both lookup tables.

        Args:
            tokens_list: Iterable of token sequences.
            min_freq: Minimum count a token needs to get its own index.
        """
        counter = Counter()
        for tokens in tokens_list:
            counter.update(tokens)

        specials = ["<unk>", "<pad>"]
        # Skip corpus tokens that collide with the specials: appending them a
        # second time would move stoi["<pad>"]/stoi["<unk>"] off indices 1/0
        # and leave itos longer than stoi.
        self.itos = specials + [
            tok for tok, freq in counter.items()
            if freq >= min_freq and tok not in specials
        ]
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def __len__(self):
        """Return the number of entries, specials included."""
        return len(self.itos)

    def __getitem__(self, token):
        """Return the index of ``token``, or the ``"<unk>"`` index if unknown."""
        return self.stoi.get(token, self.stoi["<unk>"])


# -------- Dataset --------
class TextDataset(Dataset):
    """Dataset that tokenizes texts up front and yields fixed-length id tensors.

    Every item is a ``torch.long`` tensor of exactly ``max_len`` vocabulary
    indices: longer token sequences are cut, shorter ones are filled with the
    vocabulary's ``"<pad>"`` index. When labels were supplied, items are
    ``(ids, label)`` pairs; otherwise only ``ids`` is returned.
    """

    def __init__(self, texts, labels=None, vocab=None, max_len=128):
        """Tokenize ``texts`` and pick the vocabulary.

        Args:
            texts: Iterable of raw strings.
            labels: Optional sequence of labels indexed like ``texts``.
            vocab: Vocabulary to reuse; a new ``Vocab`` is built from the
                tokenized texts when this is ``None``.
            max_len: Length every returned id tensor is cut/padded to.
        """
        self.tokens_list = list(map(nltk_tokenizer, tqdm(texts, desc="Tokenizing")))
        self.labels = labels
        self.max_len = max_len
        self.vocab = vocab if vocab is not None else Vocab(self.tokens_list)

    def __len__(self):
        """Return the number of tokenized texts."""
        return len(self.tokens_list)

    def __getitem__(self, idx):
        """Return the padded/truncated ids of item ``idx`` (plus its label, if any)."""
        window = self.tokens_list[idx][:self.max_len]
        encoded = [self.vocab[tok] for tok in window]

        shortfall = self.max_len - len(encoded)
        if shortfall > 0:
            encoded.extend([self.vocab["<pad>"]] * shortfall)

        id_tensor = torch.tensor(encoded, dtype=torch.long)

        if self.labels is None:
            return id_tensor
        return id_tensor, torch.tensor(int(self.labels[idx]), dtype=torch.long)


# -------- DataLoader --------
def get_dataloaders(train_texts, train_labels, val_texts, val_labels,
                    batch_size=32, max_len=128):
    """Build the train/validation ``DataLoader`` pair over a shared vocabulary.

    The vocabulary is created by the training ``TextDataset`` and reused for
    the validation one. Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader, vocab)``.
    """
    train_set = TextDataset(train_texts, train_labels, max_len=max_len)
    shared_vocab = train_set.vocab
    val_set = TextDataset(val_texts, val_labels, vocab=shared_vocab, max_len=max_len)

    loaders = (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size, shuffle=False),
    )
    return (*loaders, shared_vocab)


# -------- Test Loader --------
def get_test_loader(test_texts, vocab, batch_size=32, max_len=128):
    """Wrap unlabeled ``test_texts`` in a non-shuffling ``DataLoader``.

    The texts are encoded with the given ``vocab`` through a label-free
    ``TextDataset``, so each batch contains id tensors only.
    """
    return DataLoader(
        TextDataset(test_texts, labels=None, vocab=vocab, max_len=max_len),
        batch_size=batch_size,
        shuffle=False,
    )


## 1. Basic SSM

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- basic SSM ---
class DiagonalSSM(nn.Module):
    """Single-input diagonal state-space layer that returns its state sequence.

    Continuous-time model, independently for each of the ``state_size`` states::

        x'(t) = a * x(t) + b * u(t),      a = -(softplus(A_raw) + 1e-4) < 0

    It is discretised with a zero-order hold and step ``dt = exp(log_dt)``::

        A_bar = exp(dt * a)
        B_bar = (A_bar - 1) / a * b

    and unrolled as ``x_t = A_bar * x_{t-1} + B_bar * u_t`` starting from a
    zero state. The layer outputs every ``x_t``.
    """

    def __init__(self, state_size: int, dt_init: float = 1.0, init_scale: float = 0.1):
        """Create the learnable parameters.

        Args:
            state_size: Number of independent (diagonal) state dimensions.
            dt_init: Initial discretisation step; must be positive because it
                is stored as ``log(dt)``.
            init_scale: Standard deviation of the normal initialisation of
                ``A_raw`` and ``B``.

        Raises:
            ValueError: If ``state_size`` or ``dt_init`` is not positive.
        """
        super().__init__()
        if state_size <= 0:
            raise ValueError(f"state_size must be positive, got {state_size}")
        if dt_init <= 0:
            raise ValueError(f"dt_init must be positive, got {dt_init}")

        self.state_size = state_size
        # Unconstrained decay parameter; mapped to a strictly negative rate in
        # _discretize so the recurrence cannot blow up.
        self.A_raw = nn.Parameter(torch.randn(state_size) * init_scale)
        # Input-to-state coupling, one weight per state dimension.
        self.B = nn.Parameter(torch.randn(state_size) * init_scale)
        # Step size is learned in log-space so it stays positive.
        self.log_dt = nn.Parameter(torch.log(torch.tensor(float(dt_init))))

    def _discretize(self):
        """Return the zero-order-hold pair ``(A_bar, B_bar)``, each of shape ``(state_size,)``."""
        dt = torch.exp(self.log_dt)
        # The small offset keeps the rate away from 0 so the division below is safe.
        rate = -(F.softplus(self.A_raw) + 1e-4)
        A_bar = torch.exp(dt * rate)
        # expm1 computes exp(z) - 1 accurately for small |z|.
        B_bar = torch.expm1(dt * rate) / rate * self.B
        return A_bar, B_bar

    def forward(self, u):
        """Run the recurrence over the sequence dimension.

        Args:
            u: Input of shape ``(batch, seq_len, 1)`` (``(batch, seq_len)`` is
                also accepted and treated as having a trailing singleton dim).

        Returns:
            Tensor of shape ``(batch, seq_len, state_size)`` with the state
            after every time step.
        """
        if u.dim() == 2:
            u = u.unsqueeze(-1)

        A_bar, B_bar = self._discretize()
        # Match the parameter dtype so a half-precision input (e.g. under
        # autocast) does not accumulate the recurrence in low precision.
        drive = u.to(A_bar.dtype) * B_bar  # (batch, seq_len, state_size)

        batch, seq_len, _ = drive.shape
        if seq_len == 0:
            return drive

        state = drive.new_zeros(batch, self.state_size)
        states = []
        for t in range(seq_len):
            state = A_bar * state + drive[:, t]
            states.append(state)
        return torch.stack(states, dim=1)




# --- Block ---
class SSMBlock(nn.Module):
    """Residual block: scalar projection -> ``DiagonalSSM`` -> gated projection -> LayerNorm.

    The block squeezes the ``in_dim`` features to one channel, feeds that to
    the SSM, maps the SSM output (``state_size`` features) back to ``in_dim``,
    multiplies the result by its own sigmoid, applies dropout, adds the block
    input and normalises.
    """

    def __init__(self, in_dim: int, state_size: int, dropout: float = 0.1):
        """Create the sub-modules.

        Args:
            in_dim: Feature size of the block input and output.
            state_size: State size handed to ``DiagonalSSM``.
            dropout: Dropout probability applied to the gated branch.
        """
        super().__init__()
        self.proj_in = nn.Linear(in_dim, 1)
        self.ssm = DiagonalSSM(state_size)
        self.proj_out = nn.Linear(state_size, in_dim)
        self.gate = nn.Sigmoid()
        self.norm = nn.LayerNorm(in_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor fed through ``LayerNorm(in_dim)``."""
        branch = self.proj_out(self.ssm(self.proj_in(x)))
        # Self-gating: the branch is scaled by its own sigmoid.
        branch = self.dropout(self.gate(branch) * branch)
        return self.norm(x + branch)


# --- Text Classifier ---
class SSMTextClassifier(nn.Module):
    """Text classifier: embedding -> stack of ``SSMBlock`` -> mean pooling -> linear head."""

    def __init__(self, vocab_size: int, num_classes: int, state_size: int, num_layers: int,
                 emb_dim: int = 128, dropout: float = 0.1, pad_idx: int = 0):
        """Create the embedding, the block stack and the classification head.

        Args:
            vocab_size: Number of rows of the embedding table.
            num_classes: Number of output logits.
            state_size: State size passed to every ``SSMBlock``.
            num_layers: How many ``SSMBlock`` layers to stack.
            emb_dim: Embedding / feature width.
            dropout: Dropout probability for the blocks and before the head.
            pad_idx: Index handed to ``nn.Embedding`` as ``padding_idx``.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.blocks.append(SSMBlock(emb_dim, state_size, dropout))
        self.head = nn.Linear(emb_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        """Return class logits for a batch of token-id sequences."""
        hidden = self.emb(input_ids)
        for layer in self.blocks:
            hidden = layer(hidden)
        # Average over dim 1 before classifying.
        return self.head(self.dropout(hidden.mean(dim=1)))

In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
import sys


def train_SSM(train_texts, train_labels, val_texts, val_labels,
              num_classes, epochs, batch_size, max_len,
              emb_dim, state_size, num_layers, lr,
              device="cuda"):
    """Train an ``SSMTextClassifier`` and report train/validation metrics per epoch.

    The function builds the data loaders and vocabulary with
    ``get_dataloaders``, weights the cross-entropy loss with "balanced" class
    weights computed from ``train_labels``, and optimises with Adam using
    gradient-norm clipping at 1.0. Mixed precision (autocast + ``GradScaler``)
    is enabled only when ``device`` is a CUDA device.

    Args:
        train_texts, train_labels: Training texts and integer labels.
        val_texts, val_labels: Validation texts and integer labels.
        num_classes: Number of output classes of the model.
        epochs: Number of passes over the training loader.
        batch_size: Batch size of both loaders.
        max_len: Token length every sample is cut/padded to.
        emb_dim: Embedding width of the model.
        state_size: SSM state size of the model.
        num_layers: Number of SSM blocks.
        lr: Adam learning rate.
        device: Device the model and batches are moved to.

    Returns:
        ``(model, vocab)``: the trained model and the vocabulary built from
        the training texts.
    """
    train_loader, val_loader, vocab = get_dataloaders(
        train_texts, train_labels, val_texts, val_labels,
        batch_size=batch_size, max_len=max_len
    )

    pad_idx = vocab.stoi["<pad>"]
    model = SSMTextClassifier(
        vocab_size=len(vocab),
        num_classes=num_classes,
        emb_dim=emb_dim,
        state_size=state_size,
        num_layers=num_layers,
        dropout=0.1,
        pad_idx=pad_idx
    ).to(device)

    classes = np.unique(train_labels)
    balanced = compute_class_weight(
        class_weight="balanced",
        classes=classes,
        y=train_labels
    )
    # CrossEntropyLoss needs one weight per class. Scatter the balanced weights
    # into a full-length vector so a class missing from the training labels
    # (weight 1.0) does not make the weight tensor too short.
    weight_vec = np.ones(num_classes, dtype=np.float32)
    weight_vec[classes.astype(np.int64)] = balanced
    class_weights = torch.tensor(weight_vec, dtype=torch.float).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Mixed precision only makes sense on CUDA; on other devices both the
    # scaler and autocast are constructed disabled and act as pass-throughs.
    use_amp = torch.device(device).type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # --- training ---
    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        pbar = tqdm(train_loader, desc=f"Train {epoch+1}/{epochs}", unit="batch")

        for xb, yb in pbar:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()

            with torch.amp.autocast("cuda", enabled=use_amp):
                logits = model(xb)
                loss = criterion(logits, yb)

            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale here; unscale
            # them first so the 1.0 clipping threshold applies to true norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)

            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "acc": f"{(correct / total):.4f}"
            })

        # --- validation ---
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            pbar_val = tqdm(val_loader, desc=f"Valid {epoch+1}/{epochs}", unit="batch")
            for xb, yb in pbar_val:
                xb, yb = xb.to(device), yb.to(device)
                with torch.amp.autocast("cuda", enabled=use_amp):
                    logits = model(xb)
                    loss = criterion(logits, yb)
                val_loss += loss.item()
                preds = logits.argmax(dim=1)
                val_correct += (preds == yb).sum().item()
                val_total += yb.size(0)

        # max(..., 1) keeps the summary printable when a split is empty.
        train_acc = correct / max(total, 1)
        val_acc = val_correct / max(val_total, 1)
        avg_train_loss = total_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    return model, vocab

In [5]:
def main_SSM(train_path, valid_path, epochs, batch_size, max_len,
             emb_dim, state_size, num_layers, lr):
    """Load the train/validation JSON files and train the basic SSM classifier.

    No test file is loaded. All hyper-parameters are forwarded unchanged to
    ``train_SSM``.

    Returns:
        ``(model, vocab)`` as returned by ``train_SSM``.
    """
    (train_texts, train_labels,
     val_texts, val_labels,
     _unused_test_texts, _unused_test_ids,
     num_classes) = load_dbpedia_data(train_path, valid_path, None)

    hyper_params = {
        "num_classes": num_classes,
        "epochs": epochs,
        "batch_size": batch_size,
        "max_len": max_len,
        "emb_dim": emb_dim,
        "state_size": state_size,
        "num_layers": num_layers,
        "lr": lr,
    }
    trained_model, vocabulary = train_SSM(
        train_texts, train_labels, val_texts, val_labels, **hyper_params
    )
    return trained_model, vocabulary


In [None]:
# Locations of the train/validation splits inside the Kaggle input directory.
train_path = "/kaggle/input/dataset-llm/dataset/train.json"
valid_path = "/kaggle/input/dataset-llm/dataset/val.json"

# Train the basic SSM classifier; every argument is passed by keyword.
model, vocab = main_SSM(
    train_path=train_path,
    valid_path=valid_path,
    epochs=5,
    batch_size=16,
    max_len=64,
    emb_dim=128,
    state_size=32,
    num_layers=3,
    lr=5e-4,
)

## 2. FFT-based Convolutional SSM

In [None]:
# --- FFT-based Convolutional SSM (Simplified S4) ---
import torch
import torch.nn as nn
import torch.nn.functional as F

class FFTSSMBlock(nn.Module):
    """Residual block that applies a per-channel SSM as an FFT long convolution.

    Each of the ``in_dim`` channels owns one scalar SSM with parameters
    ``A_raw``, ``B`` and ``C``. ``compute_kernel`` turns them into a
    convolution kernel ``K[d, t] = 0.1 * C_d * B_d * exp(t * A_d)`` with
    ``A_d = -softplus(A_raw_d)``. ``forward`` convolves the input causally
    with that kernel through the FFT, then applies the output projection,
    sigmoid self-gating, dropout, the residual connection and LayerNorm.
    """

    def __init__(self, seq_len: int, in_dim: int, dropout: float = 0.1):
        """Create the SSM parameters and the surrounding layers.

        Args:
            seq_len: Nominal sequence length; stored only, because the kernel
                is built for the actual input length in ``forward``.
            in_dim: Feature size of the block input/output and number of
                independent SSM channels.
            dropout: Dropout probability applied to the gated branch.
        """
        super().__init__()
        self.seq_len = seq_len
        self.in_dim = in_dim

        # One scalar state-space model per feature channel.
        self.A_raw = nn.Parameter(torch.randn(in_dim))
        self.B = nn.Parameter(torch.randn(in_dim))
        self.C = nn.Parameter(torch.randn(in_dim))
        self.proj_out = nn.Linear(in_dim, in_dim)

        self.norm = nn.LayerNorm(in_dim)
        self.dropout = nn.Dropout(dropout)
        self.gate = nn.Sigmoid()

    def compute_kernel(self, L):
        """Return the convolution kernel of length ``L`` with shape ``(in_dim, L)``."""
        t = torch.arange(L, device=self.A_raw.device, dtype=torch.float32)
        # softplus keeps the decay rate negative, so the kernel decays with t.
        A = -F.softplus(self.A_raw)
        decay = torch.exp(torch.outer(t, A))
        K = decay * (self.C * self.B) * 0.1
        return K.T

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``(batch, length, in_dim)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        residual = x
        length = x.shape[1]

        # --- causal convolution via FFT ---
        kernel = self.compute_kernel(length)            # (in_dim, length)
        signal = x.transpose(1, 2).float()              # (batch, in_dim, length)
        # Zero-padding to 2 * length turns the circular FFT product into a
        # linear convolution; its first `length` samples depend only on
        # current and past inputs.
        n_fft = 2 * length
        spectrum = torch.fft.rfft(signal, n=n_fft) * torch.fft.rfft(kernel.float(), n=n_fft)
        y = torch.fft.irfft(spectrum, n=n_fft)[..., :length]
        y = y.transpose(1, 2).to(x.dtype)               # (batch, length, in_dim)

        # --- gating, dropout, residual ---
        y = self.proj_out(y)
        y = self.gate(y) * y
        y = self.dropout(y)
        return self.norm(residual + y)

class S4TextClassifier(nn.Module):
    """Text classifier: embedding -> stack of ``FFTSSMBlock`` -> mean pooling -> linear head."""

    def __init__(self, vocab_size: int, num_classes: int, seq_len: int,
                 emb_dim=128, num_layers=2, dropout=0.1, pad_idx=0):
        """Create the embedding, the block stack and the classification head.

        Args:
            vocab_size: Number of rows of the embedding table.
            num_classes: Number of output logits.
            seq_len: Sequence length passed to every ``FFTSSMBlock``.
            emb_dim: Embedding / feature width.
            num_layers: How many ``FFTSSMBlock`` layers to stack.
            dropout: Dropout probability for the blocks and before the head.
            pad_idx: Index handed to ``nn.Embedding`` as ``padding_idx``.
        """
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)

        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            self.blocks.append(FFTSSMBlock(seq_len, emb_dim, dropout))

        self.head = nn.Linear(emb_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        """Return class logits for a batch of token-id sequences."""
        # embedding lookup
        hidden = self.emb(input_ids)

        # run the stacked SSM blocks in order
        for layer in self.blocks:
            hidden = layer(hidden)

        # average over dim 1, then dropout and the classifier head
        return self.head(self.dropout(hidden.mean(dim=1)))

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight

def train_S4(
    train_texts, train_labels,
    val_texts, val_labels,
    num_classes,
    epochs=3,
    batch_size=32,
    max_len=128,
    emb_dim=256,
    num_layers=4,
    lr=1e-3,
    device="cuda"
):
    """Train an ``S4TextClassifier`` and report train/validation metrics per epoch.

    The function builds the data loaders and vocabulary with
    ``get_dataloaders``, weights the cross-entropy loss with "balanced" class
    weights computed from ``train_labels``, and optimises with Adam using
    gradient-norm clipping at 1.0. Mixed precision (autocast + ``GradScaler``)
    is enabled only when ``device`` is a CUDA device.

    Args:
        train_texts, train_labels: Training texts and integer labels.
        val_texts, val_labels: Validation texts and integer labels.
        num_classes: Number of output classes of the model.
        epochs: Number of passes over the training loader.
        batch_size: Batch size of both loaders.
        max_len: Token length every sample is cut/padded to; also passed to
            the model as ``seq_len``.
        emb_dim: Embedding width of the model.
        num_layers: Number of FFT SSM blocks.
        lr: Adam learning rate.
        device: Device the model and batches are moved to.

    Returns:
        ``(model, vocab)``: the trained model and the vocabulary built from
        the training texts.
    """
    train_loader, val_loader, vocab = get_dataloaders(
        train_texts, train_labels, val_texts, val_labels,
        batch_size=batch_size, max_len=max_len
    )

    pad_idx = vocab.stoi["<pad>"]
    model = S4TextClassifier(
        vocab_size=len(vocab),
        num_classes=num_classes,
        seq_len=max_len,
        emb_dim=emb_dim,
        num_layers=num_layers,
        dropout=0.1,
        pad_idx=pad_idx
    ).to(device)

    classes = np.unique(train_labels)
    balanced = compute_class_weight(
        class_weight="balanced",
        classes=classes,
        y=train_labels
    )
    # CrossEntropyLoss needs one weight per class. Scatter the balanced weights
    # into a full-length vector so a class missing from the training labels
    # (weight 1.0) does not make the weight tensor too short.
    weight_vec = np.ones(num_classes, dtype=np.float32)
    weight_vec[classes.astype(np.int64)] = balanced
    class_weights = torch.tensor(weight_vec, dtype=torch.float).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Mixed precision only makes sense on CUDA; on other devices both the
    # scaler and autocast are constructed disabled and act as pass-throughs.
    use_amp = torch.device(device).type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    for epoch in range(epochs):
        # --- training ---
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        pbar = tqdm(train_loader, desc=f"Train {epoch+1}/{epochs}", unit="batch")

        for xb, yb in pbar:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()

            with torch.amp.autocast("cuda", enabled=use_amp):
                logits = model(xb)
                loss = criterion(logits, yb)

            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale here; unscale
            # them first so the 1.0 clipping threshold applies to true norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)

            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "acc": f"{(correct / total):.4f}"
            })

        # --- validation ---
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            pbar_val = tqdm(val_loader, desc=f"Valid {epoch+1}/{epochs}", unit="batch")
            for xb, yb in pbar_val:
                xb, yb = xb.to(device), yb.to(device)
                with torch.amp.autocast("cuda", enabled=use_amp):
                    logits = model(xb)
                    loss = criterion(logits, yb)
                val_loss += loss.item()
                preds = logits.argmax(dim=1)
                val_correct += (preds == yb).sum().item()
                val_total += yb.size(0)

        # max(..., 1) keeps the summary printable when a split is empty.
        train_acc = correct / max(total, 1)
        val_acc = val_correct / max(val_total, 1)
        avg_train_loss = total_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    return model, vocab


In [None]:
def main_S4(train_path, val_path, test_path, epochs, batch_size, max_len,
            emb_dim, num_layers, lr):
    """Load the JSON splits and train the FFT-based SSM classifier.

    ``test_path`` is forwarded to ``load_dbpedia_data``, but the loaded test
    texts and ids are not used by this function. All hyper-parameters are
    forwarded unchanged to ``train_S4``.

    Returns:
        ``(model, vocab)`` as returned by ``train_S4``.
    """
    (train_texts, train_labels,
     val_texts, val_labels,
     _unused_test_texts, _unused_test_ids,
     num_classes) = load_dbpedia_data(train_path, val_path, test_path)

    hyper_params = {
        "num_classes": num_classes,
        "epochs": epochs,
        "batch_size": batch_size,
        "max_len": max_len,
        "emb_dim": emb_dim,
        "num_layers": num_layers,
        "lr": lr,
    }
    trained_model, vocabulary = train_S4(
        train_texts, train_labels, val_texts, val_labels, **hyper_params
    )
    return trained_model, vocabulary


In [None]:
# Locations of the train/validation splits inside the Kaggle input directory.
train_path = "/kaggle/input/dataset-llm/dataset/train.json"
valid_path = "/kaggle/input/dataset-llm/dataset/val.json"

# Train the FFT-based SSM classifier without a test split; every argument is
# passed by keyword.
model, vocab = main_S4(
    train_path=train_path,
    val_path=valid_path,
    test_path=None,
    epochs=5,
    batch_size=16,
    max_len=64,
    emb_dim=128,
    num_layers=3,
    lr=5e-4,
)

