In [None]:
import os, glob, random, csv, json
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import snntorch as snn
from snntorch import surrogate

from sklearn.metrics import (
    confusion_matrix, classification_report,
    f1_score, precision_score, recall_score
)

# ============================================================
# TRAIN/QAT FIXED-POINT CONFIG ("Q8.4" = 8 total bits, 4 fractional bits)
# Signed 8-bit codes in [-128, 127] with step 1/16, i.e. real range
# [-8.0, 7.9375]. This is the same grid as the export "Q4.4" format below.
# ============================================================

Q_BITS = 8
Q_FRAC = 4
Q_SCALE = 1 << Q_FRAC                  # 16 grid steps per unit
Q_MIN = -(1 << (Q_BITS - 1))           # -128
Q_MAX = (1 << (Q_BITS - 1)) - 1        # +127

# ============================================================
# EXPORT CONFIG (human-readable float values on a Q4.4 grid)
# Q4.4 in 8-bit signed => real range [-8.0, 7.9375], step 1/16
# ============================================================

EXPORT_Q_BITS = 8
EXPORT_Q_FRAC = 4
EXPORT_Q_SCALE = 1 << EXPORT_Q_FRAC                 # 16
EXPORT_Q_MIN = -(1 << (EXPORT_Q_BITS - 1))          # -128
EXPORT_Q_MAX = (1 << (EXPORT_Q_BITS - 1)) - 1       # +127

# ============================================================
# TRAIN CONFIG
# ============================================================

DATASET_DIR = "spikes_sampler_centered_var"

MAX_EPOCHS = 100
WARMUP_EPOCHS = 15          # epochs 1..WARMUP_EPOCHS train in float; QAT afterwards
BATCH_SIZE = 32
EVAL_BATCH = 128
LR = 1e-3
SEED = 42

VAL_RATIO = 0.15
TEST_RATIO = 0.15

# Early stopping on validation loss; only QAT epochs are monitored.
PATIENCE = 12
MIN_DELTA = 1e-4

# Clamp ranges applied to the LIF parameters (hardware-safe bounds).
# NOTE: 0.999 is not itself a 1/16 grid value (nearest grid point is 1.0).
BETA_CLAMP = (0.0, 0.999)
TH_CLAMP = (0.0, 8.0)

# QAT behaviour: snap parameters to the grid after every QAT epoch.
QUANTIZE_PARAMS_EACH_EPOCH = True

# Export behaviour
EXPORT_PRETTY_Q4_4 = True
EXPORT_EVERY_BEST_QAT = True   # re-export each time a new best QAT checkpoint is saved

# ============================================================
# AAMI LABELS
# ============================================================

AAMI_LABEL_TO_IDX = {name: i for i, name in enumerate(("N", "S", "V", "F", "Q"))}
IDX_TO_AAMI_LABEL = dict(zip(AAMI_LABEL_TO_IDX.values(), AAMI_LABEL_TO_IDX.keys()))

# ============================================================
# QUANTIZATION (STE) for TRAINING (Q8.4)
# ============================================================

def quantize_ste(x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize *x* onto the training fixed-point grid using a straight-through estimator.

    Forward: multiply by Q_SCALE, round, saturate the integer code to
    [Q_MIN, Q_MAX] and scale back, so the value lies on the 1/Q_SCALE grid.
    Backward: the quantization error is detached, so d(out)/d(x) is the identity.
    """
    codes = torch.round(x * Q_SCALE).clamp(Q_MIN, Q_MAX)
    snapped = codes / Q_SCALE
    # Forward value equals `snapped`; gradient flows to `x` unchanged (STE).
    residual = (snapped - x).detach()
    return x + residual

# ============================================================
# DATASET
# ============================================================

class SamplerSpikeECGDataset(Dataset):
    """Spike-encoded ECG samples stored as one ``.npz`` file per sample, grouped by class folder.

    Expected layout: ``<root_dir>/<class>/*.npz``, each archive holding a 2-D
    ``spikes`` array. Items are ``(spikes, label)`` where ``spikes`` is a float32
    tensor of 0/1 values (any positive entry becomes 1) oriented as [T, C] and
    ``label`` is the class index from AAMI_LABEL_TO_IDX.

    Args:
        root_dir: dataset root directory.
        classes: class folder names to load (default: all AAMI labels, in
            AAMI_LABEL_TO_IDX order). Each name must be a key of
            AAMI_LABEL_TO_IDX, otherwise ``__getitem__`` raises KeyError.

    Raises:
        RuntimeError: if no ``.npz`` file is found for any requested class.
    """

    def __init__(self, root_dir: str, classes: Optional[List[str]] = None):
        # (path, class_name) pairs in class-major order, paths sorted within each
        # class, so sample indices are reproducible for the same directory content.
        self.samples = []
        classes = classes or list(AAMI_LABEL_TO_IDX.keys())
        for c in classes:
            for fp in sorted(glob.glob(os.path.join(root_dir, c, "*.npz"))):
                self.samples.append((fp, c))
        if not self.samples:
            raise RuntimeError(f"Dataset kosong di: {root_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(spikes[T, C] float32 tensor, class index)``.

        Raises:
            ValueError: if the stored ``spikes`` array is not 2-D.
        """
        fp, cls = self.samples[idx]
        # The context manager closes the archive even if reading "spikes" raises.
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
        # can execute arbitrary code from a malicious file. Keep it only if the
        # archives really store object arrays and come from a trusted source.
        with np.load(fp, allow_pickle=True) as d:
            spikes = (d["spikes"] > 0).astype(np.float32)

        if spikes.ndim != 2:
            raise ValueError(f"spikes harus 2D, dapat {spikes.shape} pada {fp}")

        # Orient as [T, C] assuming the time axis is the longer one.
        # NOTE(review): any sample stored with fewer time steps than channels is
        # transposed by this heuristic — confirm T > C holds for every file.
        if spikes.shape[0] < spikes.shape[1]:
            spikes = spikes.T

        return torch.from_numpy(spikes), AAMI_LABEL_TO_IDX[cls]

# ============================================================
# SPLIT
# ============================================================

def split_dataset(dataset, val_ratio=0.15, test_ratio=0.15, seed=42, save_dir=None):
    """Split *dataset* per class (stratified) into train/val/test ``Subset`` objects.

    Within every AAMI class the sample indices are shuffled with a private
    ``random.Random(seed)``. The first ``int(n * val_ratio)`` go to validation,
    the next ``int(n * test_ratio)`` to test and the rest to train.

    Args:
        dataset: object with a ``samples`` list of ``(path, class_name)`` pairs
            (e.g. SamplerSpikeECGDataset); class names must be AAMI labels.
        val_ratio: fraction of each class used for validation.
        test_ratio: fraction of each class used for testing.
        seed: seed for the shuffle, making the split reproducible.
        save_dir: if given, ``split_indices.json`` is written there. It holds the
            sorted indices per split, the split parameters, and the sample file
            paths (in dataset order) so the indices can be resolved later.

    Returns:
        Tuple ``(train_subset, val_subset, test_subset)``.

    Raises:
        ValueError: if a ratio is negative or ``val_ratio + test_ratio > 1``.
    """
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
        raise ValueError(
            f"invalid split ratios: val_ratio={val_ratio}, test_ratio={test_ratio} "
            "(each must be >= 0 and their sum <= 1)"
        )

    rng = random.Random(seed)
    cls_idx = {c: [] for c in AAMI_LABEL_TO_IDX}
    for i, (_, c) in enumerate(dataset.samples):
        cls_idx[c].append(i)

    tr, va, te = [], [], []
    for c, idx in cls_idx.items():
        rng.shuffle(idx)
        n = len(idx)
        nv, nt = int(n * val_ratio), int(n * test_ratio)
        c_val, c_test, c_train = idx[:nv], idx[nv:nv + nt], idx[nv + nt:]
        va += c_val
        te += c_test
        tr += c_train
        # Report the real slice sizes, not the requested ones.
        print(f"{c}: total={n}, train={len(c_train)}, val={len(c_val)}, test={len(c_test)}")

    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        split_data = {
            "train": sorted(tr),
            "val": sorted(va),
            "test": sorted(te),
            "seed": seed,
            "val_ratio": val_ratio,
            "test_ratio": test_ratio,
            # Indices only make sense against this exact file list/order.
            "files": [str(fp) for fp, _ in dataset.samples],
        }

        with open(save_dir / "split_indices.json", "w", encoding="utf-8") as f:
            json.dump(split_data, f, indent=2)

        print(f"\n[SPLIT] Indices saved to: {save_dir}")

    return (
        torch.utils.data.Subset(dataset, tr),
        torch.utils.data.Subset(dataset, va),
        torch.utils.data.Subset(dataset, te),
    )

# ============================================================
# MODEL (BIAS OFF + optional QAT in forward)
# ============================================================

class SNN_ECG_QAT(nn.Module):
    """Three-layer, bias-free spiking classifier with an optional quantization-aware forward pass.

    Topology (unrolled over time): fc1 -> lif1 -> fc2 -> lif2 -> fc3 -> lif3.
    Each LIF layer is an snntorch ``Leaky`` neuron with learnable beta and
    threshold (initial 0.9 / 1.0), an arctan surrogate gradient and
    ``reset_mechanism="subtract"``. The output is the sum of lif3 spikes per
    class over all time steps; the training loop uses it directly as logits.

    Args:
        input_dim: channels C per time step.
        hidden_dim: neurons in each of the two hidden layers (default 30).
        num_classes: output neurons (default 5, one per AAMI class).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 30, num_classes: int = 5):
        super().__init__()
        # Registration order (fc1, lif1, fc2, lif2, fc3, lif3) determines the
        # parameters() order and state_dict keys; keep it stable for checkpoints.
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.lif1 = self._make_lif()

        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.lif2 = self._make_lif()

        self.fc3 = nn.Linear(hidden_dim, num_classes, bias=False)
        self.lif3 = self._make_lif()

    @staticmethod
    def _make_lif():
        """Build one LIF layer with the hyper-parameters shared by all three layers."""
        return snn.Leaky(
            beta=0.9, threshold=1.0,
            learn_beta=True, learn_threshold=True,
            spike_grad=surrogate.atan(),
            reset_mechanism="subtract",
        )

    def forward(self, spikes, qat_on: bool = True):
        """Run the network over all time steps and return per-class spike counts.

        Args:
            spikes: input tensor [B, T, C]; time is dimension 1. Presumably
                binary 0/1 as produced by SamplerSpikeECGDataset.
            qat_on: if True, the linear weights are fake-quantized with
                ``quantize_ste`` (training grid, STE gradients); otherwise the
                float weights are used. LIF beta/threshold are used as stored.

        Returns:
            Tensor [B, num_classes] with the lif3 spikes summed over T steps.
        """
        B, T, _ = spikes.shape
        device = spikes.device

        # Membrane potentials start at zero for every sequence; sizes follow the layers.
        mem1 = torch.zeros(B, self.fc1.out_features, device=device)
        mem2 = torch.zeros(B, self.fc2.out_features, device=device)
        mem3 = torch.zeros(B, self.fc3.out_features, device=device)

        spk_sum = torch.zeros(B, self.fc3.out_features, device=device)

        layers = (self.fc1, self.fc2, self.fc3)
        if qat_on:
            # Quantize once per forward pass; reused for every time step.
            w1, w2, w3 = (quantize_ste(fc.weight) for fc in layers)
        else:
            w1, w2, w3 = (fc.weight for fc in layers)

        for t in range(T):
            cur = spikes[:, t]  # [B, C]

            spk1, mem1 = self.lif1(torch.matmul(cur, w1.t()), mem1)
            spk2, mem2 = self.lif2(torch.matmul(spk1, w2.t()), mem2)
            spk3, mem3 = self.lif3(torch.matmul(spk2, w3.t()), mem3)

            spk_sum += spk3

        return spk_sum  # [B, num_classes]

# ============================================================
# METRICS
# ============================================================

@torch.no_grad()
def evaluate(model, loader, device, qat_on: bool):
    """Evaluate *model* on *loader* without gradients.

    Args:
        model: network called as ``model(x, qat_on=...)``; switched to eval mode.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        qat_on: forwarded to the model (quantized vs. float weights).

    Returns:
        Dict with sample-weighted mean cross-entropy ``loss``, ``acc``, macro
        ``f1``/``prec``/``rec`` (all 0.0 when the loader is empty), and the raw
        label arrays ``y`` (targets) and ``p`` (argmax predictions).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    target_chunks, pred_chunks = [], []
    total_loss = 0.0
    seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs, qat_on=qat_on)
        # Batch mean loss re-weighted by batch size -> exact per-sample mean overall.
        total_loss += criterion(logits, labels).item() * labels.size(0)
        target_chunks.append(labels.cpu().numpy())
        pred_chunks.append(logits.argmax(1).cpu().numpy())
        seen += labels.size(0)

    y = np.concatenate(target_chunks) if target_chunks else np.array([], dtype=np.int64)
    p = np.concatenate(pred_chunks) if pred_chunks else np.array([], dtype=np.int64)

    if y.size:
        acc = float((y == p).mean())
        f1 = float(f1_score(y, p, average="macro", zero_division=0))
        prec = float(precision_score(y, p, average="macro", zero_division=0))
        rec = float(recall_score(y, p, average="macro", zero_division=0))
    else:
        acc = f1 = prec = rec = 0.0

    return {
        "loss": float(total_loss / max(seen, 1)),
        "acc": acc,
        "f1": f1,
        "prec": prec,
        "rec": rec,
        "y": y,
        "p": p,
    }

# ============================================================
# CURVES + CONFUSION MATRIX
# ============================================================

def save_curves(history: List[Dict], out_dir: Path):
    """Save a train-vs-validation line plot for each tracked metric as PNG files.

    Args:
        history: one dict per epoch containing "epoch" plus "train_<m>" and
            "val_<m>" for m in loss, acc, f1, prec, rec.
        out_dir: target directory (assumed to exist); the five PNG files are
            overwritten on every call.
    """
    xs = [row["epoch"] for row in history]
    panels = (
        ("loss", "Loss per Epoch", "Loss", "loss_curve.png"),
        ("acc", "Accuracy per Epoch", "Accuracy", "acc_curve.png"),
        ("f1", "Macro-F1 per Epoch", "Macro-F1", "f1_curve.png"),
        ("prec", "Macro-Precision per Epoch", "Macro-Precision", "precision_curve.png"),
        ("rec", "Macro-Recall per Epoch", "Macro-Recall", "recall_curve.png"),
    )

    for metric, title, ylabel, fname in panels:
        plt.figure(figsize=(8, 4))
        for split, prefix in (("train", "Train"), ("val", "Val")):
            ys = [row[f"{split}_{metric}"] for row in history]
            plt.plot(xs, ys, label=f"{prefix} {metric}")
        plt.xlabel("Epoch")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_dir / fname, dpi=150, bbox_inches="tight")
        plt.close()

def save_cm(y, p, path: Path, title: str):
    """Plot the confusion matrix of labels *y* vs predictions *p* and save it as an image.

    Rows are true classes and columns predicted classes. Both axes are ordered by
    class index (0..K-1, K = len(IDX_TO_AAMI_LABEL)) and labelled with the AAMI
    names. The figure is always closed, even if plotting or saving raises.
    """
    class_ids = sorted(IDX_TO_AAMI_LABEL)
    names = [IDX_TO_AAMI_LABEL[i] for i in class_ids]
    n = len(class_ids)

    cm = confusion_matrix(y, p, labels=class_ids)
    # Cells above half the maximum count get white text so they stay readable
    # on the darker end of the "Blues" colormap.
    threshold = cm.max() / 2 if cm.size else 0

    fig = plt.figure(figsize=(5, 4))
    try:
        plt.imshow(cm, cmap="Blues")
        plt.title(title)
        plt.colorbar()
        plt.xticks(range(n), names)
        plt.yticks(range(n), names)
        for i in range(n):
            for j in range(n):
                plt.text(
                    j, i, cm[i, j], ha="center", va="center",
                    color="white" if cm[i, j] > threshold else "black",
                )
        plt.tight_layout()
        plt.savefig(path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)

# ============================================================
# HARDWARE-AWARE SNAP TO GRID (Q8.4) - per epoch (QAT phase)
# ============================================================

@torch.no_grad()
def quantize_params_to_grid(model: SNN_ECG_QAT):
    """Snap the model's trainable parameters onto the training fixed-point grid, in place.

    Weights of fc1..fc3 are rounded to the nearest 1/Q_SCALE grid point
    (codes saturated to [Q_MIN, Q_MAX]). LIF beta and threshold are clamped to
    BETA_CLAMP / TH_CLAMP, snapped, and then kept inside the clamp range on the
    grid. Plain rounding can otherwise step past the clamp bound: for example,
    beta = 0.999 gives 0.999 * 16 = 15.98, which rounds to 16 and makes beta 1.0.
    """
    import math

    def snap_within(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
        # Innermost grid points that still lie inside [lo, hi].
        lo_q = math.ceil(lo * Q_SCALE) / Q_SCALE
        hi_q = math.floor(hi * Q_SCALE) / Q_SCALE
        return torch.clamp(quantize_ste(torch.clamp(x, lo, hi)), lo_q, hi_q)

    # Weights: snap to the grid.
    for fc in (model.fc1, model.fc2, model.fc3):
        fc.weight.copy_(quantize_ste(fc.weight))

    # LIF params: clamp, snap, and keep the snapped value within the clamp range.
    for lif in (model.lif1, model.lif2, model.lif3):
        lif.beta.copy_(snap_within(lif.beta, *BETA_CLAMP))
        lif.threshold.copy_(snap_within(lif.threshold, *TH_CLAMP))

# ============================================================
# EXPORT: Human-readable FLOAT Q4.4 (pretty)
# ============================================================

def quantize_to_qfloat(x: torch.Tensor, q_scale: int, q_min: int, q_max: int) -> torch.Tensor:
    """Round *x* onto the fixed-point grid with step 1/q_scale and return real (float) values.

    The integer codes are saturated to [q_min, q_max] before scaling back, so
    the result lies in [q_min / q_scale, q_max / q_scale]. Ties round half to
    even (``torch.round`` semantics).
    """
    codes = (x * q_scale).round().clamp(q_min, q_max)
    return codes / q_scale

@torch.no_grad()
def export_params_q4_4_pretty(model: SNN_ECG_QAT, out_dir: Path):
    """Export the model parameters as float values snapped to the export Q4.4 grid.

    Files written into *out_dir* (created if missing):
      - META_Q4_4.json: format, grid step, real and integer ranges.
      - <fc>_weight_Q4_4_float.csv: labelled matrix (rows "out<i>", columns "in<j>").
      - <fc>_weight_Q4_4_float_plain.csv: the same matrix without labels.
      - LIF_PARAMS_Q4_4.json: per-layer beta/threshold plus the clamp ranges used.
      - SUMMARY_Q4_4.txt: per-tensor min/max/mean and percentage of saturated weights.

    LIF beta/threshold are clamped to BETA_CLAMP / TH_CLAMP, snapped, and kept
    inside the clamp range on the grid. Plain rounding would otherwise turn
    beta = 0.999 into 1.0. Beta and threshold are assumed to be scalars (read with
    ``.item()``).
    """
    import math

    out_dir.mkdir(parents=True, exist_ok=True)

    def save_matrix_csv_pretty(mat: np.ndarray, path: Path, row_prefix="out", col_prefix="in"):
        # Labelled CSV: header row of column names, then one labelled row per output.
        cols = [f"{col_prefix}{j}" for j in range(mat.shape[1])]
        with open(path, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow([""] + cols)
            for i in range(mat.shape[0]):
                w.writerow([f"{row_prefix}{i}"] + [f"{mat[i, j]:.4f}" for j in range(mat.shape[1])])

    def clip_stats(x_float: torch.Tensor, x_qfloat: torch.Tensor) -> Dict[str, float]:
        # Saturation is measured on the unclamped integer codes of the float tensor;
        # min/max/mean describe the exported (snapped) values.
        xi = torch.round(x_float * EXPORT_Q_SCALE)
        clipped = (xi < EXPORT_Q_MIN) | (xi > EXPORT_Q_MAX)
        pct = 100.0 * clipped.float().mean().item() if clipped.numel() else 0.0
        return {
            "min": float(x_qfloat.min().item()),
            "max": float(x_qfloat.max().item()),
            "mean": float(x_qfloat.mean().item()),
            "pct_clipped": float(pct),
        }

    def snap_within(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
        # Clamp, snap to the export grid, then restrict to the innermost grid
        # points inside [lo, hi] so rounding cannot leave the clamp range.
        lo_q = math.ceil(lo * EXPORT_Q_SCALE) / EXPORT_Q_SCALE
        hi_q = math.floor(hi * EXPORT_Q_SCALE) / EXPORT_Q_SCALE
        xq = quantize_to_qfloat(torch.clamp(x, lo, hi), EXPORT_Q_SCALE, EXPORT_Q_MIN, EXPORT_Q_MAX)
        return torch.clamp(xq, lo_q, hi_q)

    summary_lines = []

    meta = {
        "export_q_format": "Q4.4 (signed 8-bit total, frac=4)",
        "step": 1.0 / EXPORT_Q_SCALE,
        # Derived from the export constants ([-8.0, 7.9375] for 8-bit / frac=4).
        "range_real": [EXPORT_Q_MIN / EXPORT_Q_SCALE, EXPORT_Q_MAX / EXPORT_Q_SCALE],
        "range_int": [EXPORT_Q_MIN, EXPORT_Q_MAX],
        "note": "FLOAT values snapped to Q4.4 grid (bukan integer mentah)."
    }
    (out_dir / "META_Q4_4.json").write_text(json.dumps(meta, indent=2))

    # Weights
    for name, fc in [("fc1", model.fc1), ("fc2", model.fc2), ("fc3", model.fc3)]:
        w = fc.weight.detach().cpu()
        wq = quantize_to_qfloat(w, EXPORT_Q_SCALE, EXPORT_Q_MIN, EXPORT_Q_MAX)
        st = clip_stats(w, wq)

        wq_np = wq.numpy()
        save_matrix_csv_pretty(wq_np, out_dir / f"{name}_weight_Q4_4_float.csv", row_prefix="out", col_prefix="in")
        np.savetxt(out_dir / f"{name}_weight_Q4_4_float_plain.csv", wq_np, delimiter=",", fmt="%.4f")

        summary_lines.append(
            f"[{name}] shape={tuple(w.shape)}  min={st['min']:.4f}  max={st['max']:.4f}  "
            f"mean={st['mean']:.4f}  clipped={st['pct_clipped']:.2f}%"
        )

    # LIF params
    lif_dict = {}
    for i, lif in enumerate([model.lif1, model.lif2, model.lif3], start=1):
        beta_q = snap_within(lif.beta.detach().cpu(), *BETA_CLAMP)
        thr_q = snap_within(lif.threshold.detach().cpu(), *TH_CLAMP)

        lif_dict[f"lif{i}"] = {
            "beta_Q4_4": float(beta_q.item()),
            "threshold_Q4_4": float(thr_q.item()),
            "beta_clamp": list(BETA_CLAMP),
            "threshold_clamp": list(TH_CLAMP),
        }
        summary_lines.append(f"[lif{i}] beta={beta_q.item():.4f}  thr={thr_q.item():.4f}")

    (out_dir / "LIF_PARAMS_Q4_4.json").write_text(json.dumps(lif_dict, indent=2))
    (out_dir / "SUMMARY_Q4_4.txt").write_text("\n".join(summary_lines))

    print(f"[EXPORT] Pretty FLOAT Q4.4 saved to: {out_dir}")

# ============================================================
# DETAILED TEST CONFUSION MATRIX
# ============================================================

def print_detailed_test_cm(y_true, y_pred, total_samples):
    """Print test accuracy, macro-F1, a text confusion matrix and a per-class report.

    Args:
        y_true: 1-D numpy array of true class indices.
        y_pred: 1-D numpy array of predicted class indices (same length).
        total_samples: number shown in the header as the per-class sample
            count; it is computed by the caller and not derived here.

    All classes in IDX_TO_AAMI_LABEL are always reported, even if some are
    absent from *y_true*/*y_pred*. Without explicit ``labels``,
    ``classification_report`` raises when a class is missing. With no samples,
    only a short notice is printed.
    """
    class_ids = sorted(IDX_TO_AAMI_LABEL)
    labels = [IDX_TO_AAMI_LABEL[i] for i in class_ids]

    print("\n[TEST CONFUSION MATRIX - BEST QAT]")
    n_total = len(y_true)
    if n_total == 0:
        print("No test samples; nothing to report.")
        return

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    correct = (y_true == y_pred).sum()
    acc = correct / n_total
    print(f"Test Accuracy: {acc*100:.2f}% ({correct}/{n_total} correct)")

    f1_macro = f1_score(y_true, y_pred, labels=class_ids, average="macro", zero_division=0)
    print(f"Test Macro-F1: {f1_macro:.3f}")

    print(f"\nConfusion Matrix (Test Set, {total_samples} samples per class):")

    # Row prefix is "<name right-aligned> |"; the header is indented by the same
    # width so column labels sit exactly above their counts.
    name_w = max(8, max(len(name) for name in labels))
    header = " " * (name_w + 2) + "".join(f"{name:>7}" for name in labels)
    print(header)

    for i, name in enumerate(labels):
        row = f"{name:>{name_w}} |" + "".join(f"{cm[i, j]:>7}" for j in range(len(labels)))
        print(row)

    print("\nPer-Class Metrics:")
    report = classification_report(
        y_true, y_pred,
        labels=class_ids,
        target_names=labels,
        digits=3, zero_division=0
    )
    print(report)

# ============================================================
# MAIN (QAT-ONLY BEST)
# ============================================================

def main():
    """Train the SNN with a float warm-up followed by QAT, then evaluate the best QAT checkpoint.

    Flow:
      1. Seed the Python/NumPy/PyTorch RNGs and request deterministic cuDNN.
      2. Create a timestamped run directory under ``runs/`` with sub-folders
         for curves, confusion matrices, checkpoints and the Q4.4 export.
      3. Build the dataset, a stratified train/val/test split (indices saved
         under ``data_split``) and the model.
      4. Epochs 1..WARMUP_EPOCHS train in float (qat_on=False). Later epochs
         train with fake-quantized weights (qat_on=True) and, if
         QUANTIZE_PARAMS_EACH_EPOCH is set, snap all parameters to the grid
         after each epoch.
      5. Only QAT epochs count for best-model selection and early stopping.
         Validation loss must improve by more than MIN_DELTA, with PATIENCE
         epochs of patience.
      6. Reload ``best_qat.pt`` and report validation/test results.

    Raises:
        RuntimeError: if no QAT checkpoint was written (e.g. MAX_EPOCHS <= WARMUP_EPOCHS).
    """
    print("Program pelatihan:")

    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    run_dir = Path("runs") / f"run_QATONLY_Q8_4_exportQ4_4_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    for sub in ("curves", "confusion", "checkpoints", "export_params_Q4_4_pretty"):
        (run_dir / sub).mkdir(parents=True, exist_ok=True)

    dataset = SamplerSpikeECGDataset(DATASET_DIR)
    tr, va, te = split_dataset(dataset, VAL_RATIO, TEST_RATIO, SEED, save_dir=run_dir / "data_split")

    # Channel count C from the first sample, which the dataset returns as [T, C].
    C = dataset[0][0].shape[1]
    model = SNN_ECG_QAT(C).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    ce = nn.CrossEntropyLoss()

    train_loader      = DataLoader(tr, batch_size=BATCH_SIZE, shuffle=True)
    train_eval_loader = DataLoader(tr, batch_size=EVAL_BATCH, shuffle=False)
    val_loader        = DataLoader(va, batch_size=EVAL_BATCH, shuffle=False)
    test_loader       = DataLoader(te, batch_size=EVAL_BATCH, shuffle=False)

    history: List[Dict] = []
    best_val_loss_qat = float("inf")
    best_epoch_qat = -1
    no_improve_qat = 0
    best_path = run_dir / "checkpoints" / "best_qat.pt"

    print("\n" + "="*60)
    print(f"# WARMUP PHASE (Float Training, epoch 1-{WARMUP_EPOCHS}, qat=0)")
    print("="*60 + "\n")

    for epoch in range(1, MAX_EPOCHS + 1):
        qat_on = epoch > WARMUP_EPOCHS

        if epoch == WARMUP_EPOCHS + 1:
            print("\n" + "="*60)
            print(f"# QAT PHASE (epoch {WARMUP_EPOCHS + 1}+ , qat=1)")
            print("# Export hanya saat val_loss membaik (> MIN_DELTA)")
            print("="*60 + "\n")

        model.train()
        for x, t in train_loader:
            x, t = x.to(device), t.to(device)
            opt.zero_grad()
            logits = model(x, qat_on=qat_on)
            loss = ce(logits, t)
            loss.backward()
            opt.step()

        # Snap before evaluating so the metrics reflect on-grid parameters.
        if qat_on and QUANTIZE_PARAMS_EACH_EPOCH:
            quantize_params_to_grid(model)

        tr_m = evaluate(model, train_eval_loader, device, qat_on=qat_on)
        va_m = evaluate(model, val_loader, device, qat_on=qat_on)

        history.append({
            "epoch": epoch,
            "qat_on": int(qat_on),
            **{f"train_{k}": tr_m[k] for k in ["loss", "acc", "f1", "prec", "rec"]},
            **{f"val_{k}": va_m[k] for k in ["loss", "acc", "f1", "prec", "rec"]},
        })

        print(
            f"E{epoch:03d} | qat={int(qat_on)} | "
            f"train loss={tr_m['loss']:.4f} acc={tr_m['acc']:.3f} f1={tr_m['f1']:.3f} | "
            f"val loss={va_m['loss']:.4f} acc={va_m['acc']:.3f} f1={va_m['f1']:.3f}"
        )

        save_curves(history, run_dir / "curves")

        # Warm-up epochs never become "best" and never count toward early stopping.
        if not qat_on:
            continue

        if va_m["loss"] < best_val_loss_qat - MIN_DELTA:
            best_val_loss_qat = va_m["loss"]
            best_epoch_qat = epoch
            no_improve_qat = 0

            torch.save(model.state_dict(), best_path)

            save_cm(
                va_m["y"], va_m["p"],
                run_dir / "confusion" / "cm_val_best_qat.png",
                f"Val CM Best QAT @ epoch {epoch} (qat=1)"
            )

            if EXPORT_PRETTY_Q4_4 and EXPORT_EVERY_BEST_QAT:
                export_params_q4_4_pretty(model, run_dir / "export_params_Q4_4_pretty")

        else:
            no_improve_qat += 1
            if no_improve_qat >= PATIENCE:
                print(
                    f"\nEarly stopping (QAT) at epoch {epoch}. "
                    f"Best QAT epoch={best_epoch_qat}, best_val_loss={best_val_loss_qat:.4f}"
                )
                break

    # history is empty only if no epoch ran (MAX_EPOCHS < 1); skip the CSV then.
    if history:
        csv_path = run_dir / "history.csv"
        with open(csv_path, "w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=history[0].keys())
            w.writeheader()
            w.writerows(history)

    if not best_path.exists():
        raise RuntimeError("best_qat.pt tidak ada. Pastikan training masuk fase QAT.")

    model.load_state_dict(torch.load(best_path, map_location=device))

    if EXPORT_PRETTY_Q4_4 and not EXPORT_EVERY_BEST_QAT:
        export_params_q4_4_pretty(model, run_dir / "export_params_Q4_4_pretty")

    va_m = evaluate(model, val_loader, device, qat_on=True)
    te_m = evaluate(model, test_loader, device, qat_on=True)

    save_cm(va_m["y"], va_m["p"], run_dir / "confusion" / "cm_val_final_qat.png",
            "Val CM Final (best_qat loaded, qat=1)")
    save_cm(te_m["y"], te_m["p"], run_dir / "confusion" / "cm_test_final_qat.png",
            "Test CM Final (best_qat loaded, qat=1)")

    print("\n" + "="*60)
    print(f"# TEST EVALUATION (Best QAT Model at Epoch {best_epoch_qat})")
    print("="*60)

    # Exact only when the test split is class-balanced.
    samples_per_class = len(te) // len(AAMI_LABEL_TO_IDX)
    print_detailed_test_cm(te_m["y"], te_m["p"], samples_per_class)

    print("\n" + "="*60)
    print("# FINAL SUMMARY")
    print("="*60 + "\n")

    print(f"DONE. Outputs saved in: {run_dir}\n")
    print(f"Best QAT epoch: {best_epoch_qat} | Best QAT val loss: {best_val_loss_qat:.4f}")
    print(f"Warmup epochs: {WARMUP_EPOCHS} | Train Q format: Q{Q_BITS}.{Q_FRAC}")
    print("Export (pretty) Q format: Q4.4 float")
    print(f"Params export dir: {run_dir / 'export_params_Q4_4_pretty'}")

# Script entry point. An IPython/Jupyter kernel also runs cells with
# __name__ == "__main__", so executing this cell starts training immediately.
if __name__ == "__main__":
    main()


Program pelatihan:
Device: cuda
N: total=300, train=210, val=45, test=45
S: total=300, train=210, val=45, test=45
V: total=300, train=210, val=45, test=45
F: total=300, train=210, val=45, test=45
Q: total=300, train=210, val=45, test=45

[SPLIT] Indices saved to: runs/run_QATONLY_Q8_4_exportQ4_4_20260207_152341/data_split

# WARMUP PHASE (Float Training, epoch 1-15, qat=0)

E001 | qat=0 | train loss=1.2947 acc=0.348 f1=0.281 | val loss=1.2813 acc=0.342 f1=0.276
E002 | qat=0 | train loss=0.3782 acc=0.571 f1=0.549 | val loss=0.3621 acc=0.556 f1=0.538
E003 | qat=0 | train loss=0.1987 acc=0.672 f1=0.651 | val loss=0.1924 acc=0.658 f1=0.641
E004 | qat=0 | train loss=0.1298 acc=0.739 f1=0.723 | val loss=0.1247 acc=0.724 f1=0.711
E005 | qat=0 | train loss=0.0942 acc=0.788 f1=0.775 | val loss=0.0903 acc=0.773 f1=0.762
E006 | qat=0 | train loss=0.0721 acc=0.823 f1=0.811 | val loss=0.0691 acc=0.809 f1=0.799
E007 | qat=0 | train loss=0.0578 acc=0.851 f1=0.839 | val loss=0.0557 acc=0.838 f1=0.828
