**Đề tài:**
Bộ khung phát hiện sớm mã độc tống tiền dựa trên phân tích hành vi kết hợp mồi nhử và tạo tăng cường truy xuất tri thức tấn công
(A framework for early ransomware detection using behavior analysis combined with decoy techniques and RAG-based attack knowledge)
- Đinh Lê Thành Công - 22520167
- Hồ Hoàng Diệp - 22520249

In [1]:
!pip install xlstm torch-geometric

Collecting xlstm
  Downloading xlstm-2.0.4-py3-none-any.whl.metadata (24 kB)
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting reportlab (from xlstm)
  Downloading reportlab-4.4.2-py3-none-any.whl.metadata (1.8 kB)
Collecting joypy (from xlstm)
  Downloading joypy-0.2.6-py2.py3-none-any.whl.metadata (812 bytes)
Collecting ftfy (from xlstm)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting mlstm_kernels (from xlstm)
  Downloading mlstm_kernels-2.0.0-py3-none-any.whl.metadata (25 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->xlstm)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->xlstm)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (

In [2]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch_geometric.data import Batch
from torch_geometric.nn import global_mean_pool, BatchNorm
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.model_selection import StratifiedShuffleSplit

# ───────────────────────────────────────────
# xLSTM imports
# ───────────────────────────────────────────
from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig
)

# ───────────────────────────────────────────
# Dataset
# ───────────────────────────────────────────
class MultiModalDataset(Dataset):
    """Paired graph / token-sequence dataset for benign vs. ransomware samples.

    Every sample couples a JSON report (flattened into a list of string
    tokens) with a pre-built graph stored as ``<sample id>.pt``.  A JSON
    report without a matching ``.pt`` file is skipped.

    Args:
        json_root: Directory holding the ``json-atb-benign-507`` and
            ``ransom-5xx-new/ransomware`` report folders.
        pt_root: Directory holding the ``benign`` and ``ransomware`` graph
            folders.
        max_seq_len: Length every token-id sequence is truncated or
            zero-padded to.

    Attributes:
        vocab: Mapping token -> integer id (0 is ``<PAD>``, 1 is ``<UNK>``).
        samples: List of ``(graph_path, tokens, label)``; label is 0 for
            benign and 1 for ransomware.
    """

    # Kept for backward compatibility only: the vocabulary is neither read
    # from nor written to this file, it is always rebuilt from the reports.
    CACHE_FILE = 'vocab.json'

    def __init__(self, json_root, pt_root, max_seq_len=1500):
        self.max_len = max_seq_len
        self.samples = []
        self.vocab = {'<PAD>': 0, '<UNK>': 1}
        print("[INFO] Building vocab from scratch…")

        # (report folder, graph folder, label)
        mapping = [
            (os.path.join(json_root, 'json-atb-benign-507'),
             os.path.join(pt_root,  'benign'),      0),
            (os.path.join(json_root, 'ransom-5xx-new', 'ransomware'),
             os.path.join(pt_root,  'ransomware'),  1)
        ]
        for jdir, pdir, label in mapping:
            if not (os.path.isdir(jdir) and os.path.isdir(pdir)):
                # Stay best-effort, but make the skipped class visible.
                print(f"[WARN] Skipping label {label}: missing folder "
                      f"{jdir!r} or {pdir!r}")
                continue
            # os.listdir returns entries in arbitrary order; sorting keeps
            # the sample order (and therefore seeded splits and token ids)
            # reproducible across runs and machines.
            for fname in sorted(os.listdir(jdir)):
                if not fname.endswith('.json'):
                    continue
                sid   = os.path.splitext(fname)[0]
                jpath = os.path.join(jdir, fname)
                ppath = os.path.join(pdir, f"{sid}.pt")
                if not os.path.isfile(ppath):
                    continue

                toks = self._extract_tokens(self._load_json(jpath))

                # Always register new tokens.  The previous check on
                # CACHE_FILE skipped this step whenever a stale vocab.json
                # existed, although that file was never loaded, which left
                # every token mapped to <UNK>.
                for t in toks:
                    self.vocab.setdefault(t, len(self.vocab))

                self.samples.append((ppath, toks, label))

    @staticmethod
    def _load_json(path):
        """Parse the JSON report at ``path`` (undecodable bytes are ignored)."""
        with open(path, 'r', encoding='utf-8', errors='ignore') as f:
            return json.load(f)

    def _extract_tokens(self, feat):
        """Flatten one report dict into an ordered list of string tokens.

        Tokens are emitted in this order: the first 1000 API calls,
        behaviour-summary values, dropped files, signature names, process
        names and network entries.  Missing sections contribute no tokens.
        """
        toks = [f"api:{call.get('api','')}"
                for call in feat.get('api_call_sequence', [])[:1000]]
        for ft, vals in feat.get('behavior_summary', {}).items():
            toks += [f"feature:{ft}:{v}" for v in vals]
        for d in feat.get('dropped_files', []):
            # Entries are either plain values or dicts carrying 'filepath'.
            path = d.get('filepath', '') if isinstance(d, dict) else d
            toks.append(f"dropped:{path}")
        toks += [f"sig:{s.get('name','')}" for s in feat.get('signatures', [])]
        toks += [f"proc:{p.get('name','')}" for p in feat.get('processes', [])]
        for proto, ents in feat.get('network', {}).items():
            for e in ents:
                if isinstance(e, dict):
                    dst  = e.get('dst') or e.get('dst_ip','')
                    port = e.get('dst_port') or e.get('port','')
                    toks.append(f"net:{proto}:{dst}:{port}")
                else:
                    toks.append(f"net:{proto}:{e}")
        return toks

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        """Return ``(graph, token_ids, label)`` for sample ``i``.

        ``token_ids`` is a ``torch.long`` tensor of length ``max_len`` and
        ``label`` a ``torch.float32`` scalar tensor.
        """
        ppath, toks, label = self.samples[i]
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load graph files that come from a trusted source.
        graph = torch.load(ppath, weights_only=False)

        # Unknown tokens map to <UNK> (1); the tail is padded with <PAD> (0).
        idxs = [self.vocab.get(t, 1) for t in toks[:self.max_len]]
        idxs += [0] * (self.max_len - len(idxs))
        seq  = torch.tensor(idxs, dtype=torch.long)
        return graph, seq, torch.tensor(label, dtype=torch.float32)

def collate_fn(batch):
    """Merge ``(graph, seq, label)`` samples into one mini-batch.

    Returns a tuple ``(batched_graphs, stacked_seqs, stacked_labels)`` where
    the graphs are combined with ``Batch.from_data_list`` and the sequences
    and labels are stacked along a new leading dimension.
    """
    graph_items, seq_items, label_items = zip(*batch)
    merged_graphs = Batch.from_data_list(graph_items)
    seq_tensor = torch.stack(seq_items)
    label_tensor = torch.stack(label_items)
    return merged_graphs, seq_tensor, label_tensor

# ───────────────────────────────────────────
# Encoders
# ───────────────────────────────────────────
class GCNEncoder(nn.Module):
    """Two-layer GCN encoder producing one vector per graph.

    Each layer is convolution -> batch norm -> ReLU -> dropout; node
    embeddings are then averaged per graph with ``global_mean_pool``.

    Args:
        in_feats: Number of input node features.
        hidden: Width of both convolution layers (also ``output_dim``).
        drop: Dropout probability applied after each layer.
    """

    def __init__(self, in_feats, hidden=64, drop=0.3):
        super().__init__()
        self.conv1 = GCNConv(in_feats, hidden)
        self.bn1 = BatchNorm(hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.bn2 = BatchNorm(hidden)
        self.drop = drop
        self.output_dim = hidden

    def forward(self, x, ei, batch):
        """Encode node features ``x`` with edge index ``ei``; pool by ``batch``."""
        # Layers are looked up at call time so subclasses that swap
        # conv1/conv2 reuse this forward pass unchanged.
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x, ei)))
            x = F.dropout(x, self.drop, training=self.training)
        return global_mean_pool(x, batch)

class GATEncoder(GCNEncoder):
    """GCNEncoder variant whose two convolutions are 4-head GAT layers.

    The base constructor first builds the GCN layers and batch norms; the
    convolutions are then replaced by ``GATConv`` layers whose heads are
    not concatenated (``concat=False``), so the output width stays
    ``hidden``.
    """

    def __init__(self, in_feats, hidden=64, drop=0.3):
        super().__init__(in_feats, hidden, drop)
        attn_opts = dict(heads=4, concat=False)
        self.conv1 = GATConv(in_feats, hidden, **attn_opts)
        self.conv2 = GATConv(hidden, hidden, **attn_opts)

class SageEncoder(GCNEncoder):
    """GCNEncoder variant whose two convolutions are GraphSAGE layers.

    The base constructor builds the GCN layers and batch norms; only the
    convolutions are replaced here, the forward pass is inherited.
    """

    def __init__(self, in_feats, hidden=64, drop=0.3):
        super().__init__(in_feats, hidden, drop)
        layer_dims = ((in_feats, hidden), (hidden, hidden))
        self.conv1 = SAGEConv(*layer_dims[0])
        self.conv2 = SAGEConv(*layer_dims[1])

class GINEncoder(nn.Module):
    """Two-layer GIN encoder producing one vector per graph.

    Each GIN convolution wraps a Linear -> ReLU -> Linear MLP and is
    followed by batch norm, ReLU and dropout; node embeddings are averaged
    per graph with ``global_mean_pool``.

    Args:
        in_feats: Number of input node features.
        hidden: Width of both layers (also ``output_dim``).
        drop: Dropout probability applied after each layer.
    """

    def __init__(self, in_feats, hidden=64, drop=0.3):
        super().__init__()
        first_mlp = nn.Sequential(
            nn.Linear(in_feats, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )
        second_mlp = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )
        self.g1 = GINConv(first_mlp)
        self.g2 = GINConv(second_mlp)
        self.bn1 = BatchNorm(hidden)
        self.bn2 = BatchNorm(hidden)
        self.drop = drop
        self.output_dim = hidden

    def forward(self, x, ei, batch):
        """Encode node features ``x`` with edge index ``ei``; pool by ``batch``."""
        for conv, norm in ((self.g1, self.bn1), (self.g2, self.bn2)):
            x = F.relu(norm(conv(x, ei)))
            x = F.dropout(x, self.drop, training=self.training)
        return global_mean_pool(x, batch)

class xLSTMEncoder(nn.Module):
    """Token-sequence encoder: embedding followed by a one-block xLSTM stack.

    Args:
        vocab_size: Number of rows in the embedding table (index 0 is the
            padding index).
        embed: Embedding width, also the stack's ``embedding_dim`` and this
            encoder's ``output_dim``.
        seq_len: Passed to the stack as ``context_length``.
    """

    def __init__(self, vocab_size, embed=128, seq_len=1500):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed, padding_idx=0)

        mlstm_block_cfg = mLSTMBlockConfig(
            mlstm=mLSTMLayerConfig(
                conv1d_kernel_size=4,
                qkv_proj_blocksize=4,
                num_heads=4,
            )
        )
        slstm_block_cfg = sLSTMBlockConfig(
            slstm=sLSTMLayerConfig(
                backend="vanilla",
                num_heads=4,
                conv1d_kernel_size=4,
                bias_init="powerlaw_blockdependent",
            ),
            feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
        )
        # NOTE(review): with num_blocks=1 and slstm_at=[0] the single block
        # is presumably the sLSTM one - confirm against the xlstm docs.
        stack_cfg = xLSTMBlockStackConfig(
            mlstm_block=mlstm_block_cfg,
            slstm_block=slstm_block_cfg,
            context_length=seq_len,
            num_blocks=1,
            embedding_dim=embed,
            slstm_at=[0],
        )
        self.xlstm = xLSTMBlockStack(stack_cfg)
        self.output_dim = embed

    def forward(self, seq):
        """Embed ``seq``, run the stack and average over dimension 1.

        No padding mask is applied, so padded positions take part in the
        mean.  Assumes ``seq`` is (batch, seq_len) - TODO confirm.
        """
        hidden_states = self.xlstm(self.embed(seq))
        return hidden_states.mean(dim=1)

class LSTMEncoder(nn.Module):
    """Single-layer LSTM sequence encoder (kept for the other combos).

    Args:
        vocab_size: Number of rows in the embedding table (index 0 is the
            padding index).
        embed: Embedding width.
        hidden: LSTM hidden size, exposed as ``output_dim``.
    """

    def __init__(self, vocab_size, embed=128, hidden=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embed, hidden_size=hidden,
                            batch_first=True)
        self.output_dim = hidden

    def forward(self, seq):
        """Return the LSTM's final hidden state for each sequence in ``seq``."""
        embedded = self.embed(seq)
        _outputs, final_states = self.lstm(embedded)
        last_hidden = final_states[0]
        # Drop the leading num_layers dimension (size 1 here).
        return last_hidden.squeeze(0)

# ───────────────────────────────────────────
# Classifiers & wrappers
# ───────────────────────────────────────────
class MLPClassifier(nn.Module):
    """MLP head that maps a feature vector to a single logit.

    Args:
        in_dim: Size of the input feature vector.
        hiddens: Widths of the hidden layers; any sequence (list or tuple)
            is accepted.  Each hidden layer is Linear -> ReLU -> Dropout.
        drop: Dropout probability used after every hidden layer.
    """

    # The default is an immutable tuple instead of a shared mutable list.
    def __init__(self, in_dim, hiddens=(128, 64), drop=0.3):
        super().__init__()
        # Unpacking works for lists and tuples alike, whereas
        # ``[in_dim] + hiddens`` raised TypeError for a tuple.
        dims = [in_dim, *hiddens]
        layers = []
        for fan_in, fan_out in zip(dims, dims[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(drop)]
        layers.append(nn.Linear(dims[-1], 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return one logit per row of ``x`` (the size-1 dim 1 is squeezed)."""
        return self.mlp(x).squeeze(1)

class GraphOnly(nn.Module):
    """Classifier that uses only the graph branch of a ``(graph, seq)`` pair."""

    def __init__(self, enc):
        super().__init__()
        self.enc = enc
        self.fc = MLPClassifier(enc.output_dim)

    def forward(self, g, s):
        """Return logits for graph batch ``g``; ``s`` is accepted but ignored."""
        graph_vec = self.enc(g.x, g.edge_index, g.batch)
        return self.fc(graph_vec)

class SeqOnly(nn.Module):
    """Classifier that uses only the sequence branch of a ``(graph, seq)`` pair."""

    def __init__(self, enc):
        super().__init__()
        self.enc = enc
        self.fc = MLPClassifier(enc.output_dim)

    def forward(self, g, s):
        """Return logits for token batch ``s``; ``g`` is accepted but ignored."""
        seq_vec = self.enc(s)
        return self.fc(seq_vec)

class MultiModal(nn.Module):
    """Late-fusion classifier over a graph encoder and a sequence encoder.

    The two encoder outputs are concatenated along dimension 1 and fed to
    an MLP head with hidden widths ``[hid, hid // 2]``.
    """

    def __init__(self, genc, senc, hid=128):
        super().__init__()
        self.genc = genc
        self.senc = senc
        fused_dim = genc.output_dim + senc.output_dim
        self.fc = MLPClassifier(fused_dim, [hid, hid//2])

    def forward(self, g, s):
        """Return logits for graph batch ``g`` and token batch ``s``."""
        graph_vec = self.genc(g.x, g.edge_index, g.batch)
        seq_vec = self.senc(s)
        return self.fc(torch.cat([graph_vec, seq_vec], 1))

# ───────────────────────────────────────────
# Train & eval
# ───────────────────────────────────────────
def train_epoch(model, loader, crit, opt, dev):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Module called as ``model(graphs, seqs)`` and returning logits.
        loader: Iterable of ``(graphs, seqs, labels)`` batches.
        crit: Loss taking ``(logits, labels)``.
        opt: Optimizer over the model parameters.
        dev: Device every batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen; a prediction is
        positive when ``sigmoid(logit) > 0.5``.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    for graphs, seqs, targets in loader:
        graphs = graphs.to(dev)
        seqs = seqs.to(dev)
        targets = targets.to(dev)

        opt.zero_grad()
        logits = model(graphs, seqs)
        batch_loss = crit(logits, targets)
        batch_loss.backward()
        opt.step()

        batch_size = targets.size(0)
        # Weight by batch size so the final mean is per sample.
        loss_sum += batch_loss.item() * batch_size
        hits = (torch.sigmoid(logits) > .5) == targets
        n_correct += hits.sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, n_correct / n_seen

def metrics(model, loader, crit, dev):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module called as ``model(graphs, seqs)`` and returning logits.
        loader: Iterable of ``(graphs, seqs, labels)`` batches.
        crit: Loss taking ``(logits, labels)``.
        dev: Device every batch is moved to.

    Returns:
        Dict with ``loss`` (per-sample mean), ``acc``, ``tpr``, ``fpr`` and
        ``f1``; a prediction is positive when ``sigmoid(logit) > 0.5``.
    """
    model.eval()
    tot, p, y = 0, [], []
    with torch.no_grad():
        for g, s, l in loader:
            g, s, l = g.to(dev), s.to(dev), l.to(dev)
            logit = model(g, s)
            tot += crit(logit, l).item() * l.size(0)
            # Integer class ids keep the label set explicit for sklearn.
            p += (torch.sigmoid(logit) > .5).long().cpu().tolist()
            y += l.long().cpu().tolist()
    # Pin the label set so the matrix is always 2x2; without ``labels`` a
    # split where only one class occurs yields a 1x1 matrix and the
    # four-way unpack below raises ValueError.
    tn, fp, fn, tp = confusion_matrix(y, p, labels=[0, 1]).ravel()
    n = tp + tn + fp + fn
    # The 1e-9 terms keep the rates finite when a class is absent.
    return {'loss': tot/n, 'acc': (tp+tn)/n, 'tpr': tp/(tp+fn+1e-9),
            'fpr': fp/(fp+tn+1e-9), 'f1': f1_score(y, p)}

def run(name, model, tl, vl, te, lr, ep, patience, dev, ckpt=None):
    """Train ``model`` with early stopping on validation F1, then test it.

    Args:
        name: Tag used as the prefix of every log line.
        model: Module called as ``model(graphs, seqs)``.
        tl, vl, te: Train, validation and test loaders.
        lr: Adam learning rate.
        ep: Maximum number of epochs.
        patience: Number of consecutive non-improving epochs tolerated
            before stopping.
        dev: Device used for training and evaluation.
        ckpt: Optional path; the best state dict is saved there each time
            the validation F1 improves.
    """
    # Local import so this function does not depend on the file header.
    import copy

    crit = nn.BCEWithLogitsLoss()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # Start below any reachable F1 so the first epoch always records a
    # state; starting at 0 left best_state as None (and crashed the final
    # load_state_dict) whenever validation F1 never rose above 0.
    best, bad, best_state = float('-inf'), 0, None
    for i in range(1, ep+1):
        tloss, tacc = train_epoch(model, tl, crit, opt, dev)
        v = metrics(model, vl, crit, dev)
        print(f"[{name}] Ep{i:02d} | TrainL {tloss:.4f} A {tacc:.4f} "
              f"| ValL {v['loss']:.4f} A {v['acc']:.4f} F1 {v['f1']:.4f}")
        if v['f1'] > best:
            best, bad = v['f1'], 0
            # state_dict() returns references to the live parameters, so
            # it must be deep-copied; otherwise later epochs overwrite the
            # "best" weights in place.
            best_state = copy.deepcopy(model.state_dict())
            if ckpt:
                torch.save(best_state, ckpt)
        else:
            bad += 1
            if bad >= patience:
                print(f"[{name}] Early stop")
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    t = metrics(model, te, crit, dev)
    print(f"[{name}] TEST → L {t['loss']:.4f} A {t['acc']:.4f} "
          f"TPR {t['tpr']:.4f} FPR {t['fpr']:.4f} F1 {t['f1']:.4f}")
    if ckpt and best_state is not None:
        print(f"[{name}] Saved best to {ckpt}\n")

# ───────────────────────────────────────────
# Main
# ───────────────────────────────────────────
def main():
    """Build the dataset, split it 70/15/15 and train every model variant.

    Runs GCN-only, xLSTM-only and then every graph-encoder x
    sequence-encoder combination; only the GCN+xLSTM combination writes a
    checkpoint (``best_gcn_xlstm.pt``).
    """
    json_root = "/kaggle/input"
    pt_root = "/kaggle/input/1000-final/1000"  # adjust the API setting here
    max_len = 2000
    bs, lr, ep, patience = 8, 1e-3, 20, 5
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ds = MultiModalDataset(json_root, pt_root, max_len)
    labels = [sample[2] for sample in ds.samples]

    # 15% test first, then 17.647% of the remaining 85% (= 15% of the
    # total) for validation.
    outer = StratifiedShuffleSplit(1, test_size=.15, random_state=42)
    tv_idx, test_idx = next(outer.split(range(len(ds)), labels))
    inner = StratifiedShuffleSplit(1, test_size=.17647, random_state=42)
    y_tv = [labels[i] for i in tv_idx]
    tr_rel, va_rel = next(inner.split(tv_idx, y_tv))
    tr_idx = [tv_idx[i] for i in tr_rel]
    va_idx = [tv_idx[i] for i in va_rel]

    def make_loader(indices, shuffle):
        return DataLoader(Subset(ds, indices), batch_size=bs,
                          shuffle=shuffle, collate_fn=collate_fn)

    tl = make_loader(tr_idx, True)
    vl = make_loader(va_idx, False)
    te = make_loader(test_idx, False)

    # Node-feature width is read from the first sample's graph.
    gfeat = ds[0][0].x.size(1)
    vocab_size = len(ds.vocab)

    # GCN-only
    gcn_only = GraphOnly(GCNEncoder(gfeat).to(dev)).to(dev)
    run("GCN", gcn_only, tl, vl, te, lr, ep, patience, dev)

    # xLSTM-only
    xlstm_only = SeqOnly(xLSTMEncoder(vocab_size, seq_len=max_len).to(dev)).to(dev)
    run("xLSTM", xlstm_only, tl, vl, te, lr, ep, patience, dev)

    # combos
    g_encoders = {'gcn': GCNEncoder, 'gat': GATEncoder,
                  'sage': SageEncoder, 'gin': GINEncoder}
    s_encoders = {'xlstm': xLSTMEncoder, 'lstm': LSTMEncoder}

    for gn, gc in g_encoders.items():
        for sn, sc in s_encoders.items():
            print()
            genc = gc(gfeat).to(dev)
            # Only the xLSTM encoder takes the sequence length.
            if sn == 'xlstm':
                senc = sc(vocab_size, seq_len=max_len)
            else:
                senc = sc(vocab_size)
            senc = senc.to(dev)
            ckpt = 'best_gcn_xlstm.pt' if (gn, sn) == ('gcn', 'xlstm') else None
            run(f"{gn.upper()}+{sn.upper()}",
                MultiModal(genc, senc).to(dev),
                tl, vl, te, lr, ep, patience, dev, ckpt)

# Run the experiments only when this code is executed directly.
if __name__ == "__main__":
    main()


[INFO] Building vocab from scratch…
[GCN] Ep01 | TrainL 0.3817 A 0.8755 | ValL 0.0952 A 0.9548 F1 0.9576
[GCN] Ep02 | TrainL 0.0949 A 0.9710 | ValL 0.0243 A 0.9935 F1 0.9937
[GCN] Ep03 | TrainL 0.0987 A 0.9751 | ValL 0.0779 A 0.9871 F1 0.9875
[GCN] Ep04 | TrainL 0.0566 A 0.9820 | ValL 0.0407 A 0.9871 F1 0.9875
[GCN] Ep05 | TrainL 0.1053 A 0.9710 | ValL 0.0562 A 0.9871 F1 0.9875
[GCN] Ep06 | TrainL 0.0864 A 0.9710 | ValL 0.0298 A 0.9935 F1 0.9937
[GCN] Ep07 | TrainL 0.0602 A 0.9820 | ValL 0.0425 A 0.9935 F1 0.9937
[GCN] Early stop
[GCN] TEST → L 0.0479 A 0.9871 TPR 0.9747 FPR 0.0000 F1 0.9872
[xLSTM] Ep01 | TrainL 0.3881 A 0.8575 | ValL 0.1666 A 0.8645 F1 0.8467
[xLSTM] Ep02 | TrainL 0.0952 A 0.9640 | ValL 0.0835 A 0.9806 F1 0.9806
[xLSTM] Ep03 | TrainL 0.0415 A 0.9834 | ValL 0.0251 A 0.9871 F1 0.9873
[xLSTM] Ep04 | TrainL 0.0070 A 0.9986 | ValL 0.0303 A 0.9935 F1 0.9937
[xLSTM] Ep05 | TrainL 0.0264 A 0.9945 | ValL 0.2596 A 0.9032 F1 0.9133
[xLSTM] Ep06 | TrainL 0.0161 A 0.9945 | ValL 0

In [None]:
# train_xran.py  –  PyTorch ≥2.0
import os, json, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, f1_score
from tqdm import tqdm

# ---------- HYPER ----------
# Per-section token budgets; their sum is the fixed model input length.
MAX_API = 500
MAX_DLL = 10
MAX_MUTEX = 10
SEQ_LEN = MAX_API + MAX_DLL + MAX_MUTEX      # 520
EMB_DIM = 128
# Training schedule.
BATCH = 64
EPOCHS = 20
PATIENCE = 5
LR = 1e-3
# One seed for torch and for the dataset splits.
SEED = 42
torch.manual_seed(SEED)

# ---------- DATASET ----------
class XRanDataset(Dataset):
    """Token-sequence dataset built from benign and ransomware JSON reports.

    Args:
        benign_dir: Folder of benign ``.json`` reports (label 0).
        ransom_dir: Folder of ransomware ``.json`` reports (label 1).

    Attributes:
        vocab: Mapping token -> integer id (0 is ``<PAD>``, 1 is ``<UNK>``).
        samples: List of ``(tokens, label)`` tuples.

    Raises:
        FileNotFoundError: If either folder does not exist.
    """

    def __init__(self, benign_dir, ransom_dir):
        self.samples = []
        self.vocab, nxt = {'<PAD>':0,'<UNK>':1}, 2
        for lbl, root in [(0, benign_dir), (1, ransom_dir)]:
            if not os.path.isdir(root):
                raise FileNotFoundError(root)
            # os.listdir returns entries in arbitrary order; sorting makes
            # the sample order, the token ids and the SEED-based splits
            # reproducible across runs and machines.
            for fn in tqdm(sorted(os.listdir(root)), desc=f"Parse {root}"):
                if not fn.endswith('.json'):
                    continue
                with open(os.path.join(root, fn), 'r', encoding='utf-8',
                          errors='ignore') as f:
                    feat = json.load(f)
                toks = self._tokens(feat)
                for t in toks:
                    if t not in self.vocab:
                        self.vocab[t] = nxt
                        nxt += 1
                self.samples.append((toks, lbl))

    def _tokens(self, feat):
        """Return API, DLL and mutex tokens, each capped to its own budget."""
        apis = [f"api:{c.get('api','')}"
                for c in feat.get('api_call_sequence', [])[:MAX_API]]
        beh  = feat.get('behavior_summary', {})
        # Slice before formatting so only the kept entries are built.
        dlls = [f"dll:{d}"   for d in beh.get('dll_loaded', [])[:MAX_DLL]]
        mtx  = [f"mutex:{m}" for m in beh.get('mutex',      [])[:MAX_MUTEX]]
        return apis + dlls + mtx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        """Return ``(token_ids, label)``; ids are zero-padded to ``SEQ_LEN``."""
        toks, lbl = self.samples[i]
        ids = [self.vocab.get(t, 1) for t in toks] + [0]*(SEQ_LEN - len(toks))
        # Explicit dtype: embedding lookups need integer indices.
        return (torch.tensor(ids, dtype=torch.long),
                torch.tensor(lbl, dtype=torch.float32))

def collate(batch):
    """Stack ``(seq, label)`` samples into ``(seq_batch, label_batch)``."""
    seq_items, label_items = zip(*batch)
    seq_batch = torch.stack(seq_items)
    label_batch = torch.stack(label_items)
    return seq_batch, label_batch

# ---------- MODEL ----------
class XRanCNN(nn.Module):
    """1-D CNN over embedded token ids that outputs one logit per sequence.

    Args:
        vocab: Number of rows in the embedding table (index 0 is padding).
    """

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, EMB_DIM, padding_idx=0)
        self.c1 = nn.Conv1d(EMB_DIM, 128, kernel_size=5, padding=2)
        self.c2 = nn.Conv1d(128, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(2)
        # Two stride-2 poolings shrink the length to SEQ_LEN // 4, with 64
        # channels left after the second convolution.
        flat_features = (SEQ_LEN//4)*64
        self.fc1 = nn.Linear(flat_features, 64)
        self.drop = nn.Dropout(0.5)
        self.out = nn.Linear(64, 1)

    def forward(self, seq):
        """Return logits for token-id batch ``seq``."""
        # Conv1d expects channels before length, hence the transpose.
        feats = self.emb(seq).transpose(1, 2)
        for conv in (self.c1, self.c2):
            feats = self.pool(F.relu(conv(feats)))
        hidden = F.relu(self.fc1(feats.flatten(1)))
        return self.out(self.drop(hidden)).squeeze(1)

# ---------- HELPERS ----------
def train_epoch(m, ld, crit, opt, dev):
    """Train ``m`` for one pass over ``ld``.

    Args:
        m: Module called as ``m(seqs)`` and returning logits.
        ld: Iterable of ``(seqs, labels)`` batches.
        crit: Loss taking ``(logits, labels)``.
        opt: Optimizer over the model parameters.
        dev: Device every batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen; a prediction is
        positive when ``sigmoid(logit) > 0.5``.
    """
    m.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    for seqs, targets in ld:
        seqs = seqs.to(dev)
        targets = targets.to(dev)

        opt.zero_grad()
        logits = m(seqs)
        batch_loss = crit(logits, targets)
        batch_loss.backward()
        opt.step()

        batch_size = targets.size(0)
        # Weight by batch size so the final mean is per sample.
        loss_sum += batch_loss.item() * batch_size
        hits = (torch.sigmoid(logits) > .5) == targets
        n_correct += hits.sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, n_correct / n_seen

@torch.no_grad()
def eval_epoch(m, ld, crit, dev):
    """Evaluate ``m`` on ``ld`` without gradient tracking.

    Args:
        m: Module called as ``m(seqs)`` and returning logits.
        ld: Iterable of ``(seqs, labels)`` batches.
        crit: Loss taking ``(logits, labels)``.
        dev: Device every batch is moved to.

    Returns:
        Dict with ``loss`` (per-sample mean), ``acc``, ``tpr``, ``fpr`` and
        ``f1``; a prediction is positive when ``sigmoid(logit) > 0.5``.
    """
    m.eval()
    tot, p, y = 0, [], []
    for s, l in ld:
        s, l = s.to(dev), l.to(dev)
        yhat = m(s)
        tot += crit(yhat, l).item() * l.size(0)
        # Integer class ids keep the label set explicit for sklearn.
        p += (torch.sigmoid(yhat) > .5).long().cpu().tolist()
        y += l.long().cpu().tolist()
    # Pin the label set so the matrix is always 2x2; without ``labels`` a
    # split where only one class occurs yields a 1x1 matrix and the
    # four-way unpack below raises ValueError.
    tn, fp, fn, tp = confusion_matrix(y, p, labels=[0, 1]).ravel()
    n = tp + tn + fp + fn
    # The 1e-9 terms keep the rates finite when a class is absent.
    return {'loss': tot/n, 'acc': (tp+tn)/n,
            'tpr': tp/(tp+fn+1e-9), 'fpr': fp/(fp+tn+1e-9),
            'f1': f1_score(y, p)}

# ---------- MAIN ----------
def main():
    """Train the XRan CNN with early stopping and report test metrics.

    Splits the dataset 70/15/15 (stratified, seeded with ``SEED``), keeps
    the weights of the epoch with the best validation accuracy, evaluates
    them on the test split and saves them to ``best_xran_cnn.pt``.
    """
    # Local import so this function does not depend on the file header.
    import copy

    benign_dir = '/kaggle/input/json-atb-benign-507'
    ransom_dir = '/kaggle/input/ransom-5xx-new/ransomware'
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ds = XRanDataset(benign_dir, ransom_dir)
    labels = [lbl for _, lbl in ds.samples]

    # 15% test first, then 17.647% of the remaining 85% (= 15% of the
    # total) for validation.
    outer = StratifiedShuffleSplit(1, test_size=.15, random_state=SEED)
    tv_idx, te_idx = next(outer.split(range(len(ds)), labels))
    inner = StratifiedShuffleSplit(1, test_size=.17647, random_state=SEED)
    y_tv = [labels[i] for i in tv_idx]
    tr_rel, va_rel = next(inner.split(tv_idx, y_tv))
    tr_idx = [tv_idx[i] for i in tr_rel]
    va_idx = [tv_idx[i] for i in va_rel]

    tl = DataLoader(Subset(ds, tr_idx), batch_size=BATCH, shuffle=True,
                    collate_fn=collate)
    vl = DataLoader(Subset(ds, va_idx), batch_size=BATCH, shuffle=False,
                    collate_fn=collate)
    te = DataLoader(Subset(ds, te_idx), batch_size=BATCH, shuffle=False,
                    collate_fn=collate)

    model = XRanCNN(len(ds.vocab)).to(dev)
    crit  = nn.BCEWithLogitsLoss()
    opt   = torch.optim.Adam(model.parameters(), lr=LR)

    # Start below any reachable accuracy so the first epoch always records
    # a state; starting at 0 could leave best_state as None.
    best_acc, bad, best_state = float('-inf'), 0, None
    for ep in range(1, EPOCHS+1):
        trL, trA = train_epoch(model, tl, crit, opt, dev)
        v = eval_epoch(model, vl, crit, dev)
        print(f"Ep{ep:02d} | TrL {trL:.4f} A {trA:.4f} | "
              f"VaL {v['loss']:.4f} A {v['acc']:.4f}")
        if v['acc'] > best_acc:
            best_acc, bad = v['acc'], 0
            # state_dict() returns references to the live parameters, so
            # it must be deep-copied; otherwise later epochs overwrite the
            # "best" weights in place and the last epoch gets saved.
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= PATIENCE:
                print("Early stopping.")
                break

    if best_state is None:
        # No epoch ran (EPOCHS < 1): fall back to the current weights.
        best_state = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)
    t = eval_epoch(model, te, crit, dev)
    print(f"TEST → Loss {t['loss']:.4f} Acc {t['acc']:.4f} "
          f"TPR {t['tpr']:.4f} FPR {t['fpr']:.4f} F1 {t['f1']:.4f}")
    torch.save(best_state, 'best_xran_cnn.pt')
    print("Saved best checkpoint to best_xran_cnn.pt")

# Run training only when this code is executed directly.
if __name__ == "__main__":
    main()
