# Test Experiments

Frechet Inception Distance (FID) and Inception Score

In [2]:
from experiment_dir import set_cwd_project_root

set_cwd_project_root()

In [8]:
from runner import SketchTrainer, device
import torch
from tqdm import tqdm
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from raster_dataset import svg_rasterize
import numpy as np


def to_3ch_tensor(img_pil):
    arr = np.array(img_pil, dtype=np.uint8)          # H×W (grayscale)
    t = torch.from_numpy(arr).unsqueeze(0)           # 1×H×W
    return t.repeat(3, 1, 1)                         # 3×H×W


def test_it(sketch_trainer: SketchTrainer):
    model = sketch_trainer.model
    test_loader = sketch_trainer.test_loader
    use_padding_mask = sketch_trainer.use_padding_mask

    model.eval()
    test_token_accuracy = 0.0

    # ---------- 1) Next-token accuracy ----------
    with torch.no_grad():
        for input_ids, target_ids, class_labels in tqdm(test_loader, desc="Testing"):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            class_labels = class_labels.to(device)

            if use_padding_mask:
                mask = input_ids == sketch_trainer.tokenizer.pad_token_id
                logits = model(input_ids, class_labels, src_key_padding_mask=mask)
            else:
                logits = model(input_ids, class_labels)

            preds = logits.argmax(dim=-1)
            mask = target_ids != sketch_trainer.tokenizer.pad_token_id
            correct = (preds[mask] == target_ids[mask]).float().sum()
            total = mask.sum()

            acc = (correct / total) if total > 0 else torch.tensor(0.0, device=device)
            test_token_accuracy += acc.item()

    avg_acc = test_token_accuracy / len(test_loader)
    print(f"Test Next Token Accuracy: {avg_acc:.4f}")

    # ---------- 2) FID + IS ----------
    start_id = sketch_trainer.tokenizer.vocab["START"]
    end_id = sketch_trainer.tokenizer.vocab["END"]

    def _trim_at_end(ids):
        if end_id in ids:
            idx = ids.index(end_id)
            return ids[: idx + 1]
        return ids

    # torchmetrics FID: expects 3×HxW uint8 or float with [0,1] if normalize=True
    fid = FrechetInceptionDistance(normalize=False).to(device)
    inception = InceptionScore(splits=10, normalize=False).to(device)

    model.eval()
    with torch.no_grad():
        for input_ids, target_ids, class_labels in tqdm(test_loader, desc="FID/IS"):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            class_labels = class_labels.to(device)

            if use_padding_mask:
                mask = input_ids == sketch_trainer.tokenizer.pad_token_id
                logits = model(input_ids, class_labels, src_key_padding_mask=mask)
            else:
                logits = model(input_ids, class_labels)

            preds = logits.argmax(dim=-1).cpu()
            targets_cpu = target_ids.cpu()

            real_batch = []
            fake_batch = []

            B = preds.size(0)
            for b in range(B):
                # REAL sequence
                real_ids = [start_id] + targets_cpu[b].tolist()
                real_ids = _trim_at_end(real_ids)
                real_svg = sketch_trainer.tokenizer.decode(real_ids)
                real_img = svg_rasterize(real_svg)     # PIL, grayscale 299×299
                r = to_3ch_tensor(real_img)            # 3×299×299

                # FAKE sequence
                fake_ids = [start_id] + preds[b].tolist()
                fake_ids = _trim_at_end(fake_ids)
                fake_svg = sketch_trainer.tokenizer.decode(fake_ids)
                fake_img = svg_rasterize(fake_svg)
                f = to_3ch_tensor(fake_img)

                real_batch.append(r.unsqueeze(0))      # 1×3×H×W
                fake_batch.append(f.unsqueeze(0))

            real_images = torch.cat(real_batch, dim=0).to(device)  # B×3×H×W
            fake_images = torch.cat(fake_batch, dim=0).to(device)

            fid.update(real_images, real=True)
            fid.update(fake_images, real=False)
            inception.update(fake_images)

    fid_score = fid.compute().item()
    is_mean, is_std = inception.compute()
    is_mean = is_mean.item()
    is_std = is_std.item()

    sketch_trainer.writer.add_scalar("FID/Test", fid_score, 0)
    sketch_trainer.writer.add_scalar("IS/TestMean", is_mean, 0)
    sketch_trainer.writer.add_scalar("IS/TestStd", is_std, 0)

    print(f"Test FID: {fid_score:.4f}")
    print(f"Test Inception Score: mean={is_mean:.4f}, std={is_std:.4f}")



In [None]:
trainer = load_config("configs/example_0.toml")
test_it(trainer)

Downloading QuickDrawDataset files: 100%|██████████| 1/1 [00:00<00:00, 1262.96it/s]
Loading QuickDrawDataset: 1it [00:00, 1047.27it/s]
Tokenizing dataset: 100%|██████████| 1/1 [00:00<00:00, 2928.98it/s]


No checkpoint found, starting fresh training.


Initial Eval: 100%|██████████| 81/81 [00:00<00:00, 246.72it/s]
Testing: 100%|██████████| 41/41 [00:01<00:00, 24.74it/s]


Test Next Token Accuracy: 0.0009


FID/IS: 100%|██████████| 41/41 [00:27<00:00,  1.49it/s]


Test FID: 220.4308
Test Inception Score: mean=1.9268, std=0.0229


In [6]:
from runner import SketchTrainer, device
import torch
from tqdm import tqdm
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from raster_dataset import svg_rasterize
import numpy as np

def to_3ch_tensor(img_pil):
    arr = np.array(img_pil, dtype=np.float32) / 255.0  # H×W in [0,1]
    t = torch.from_numpy(arr).unsqueeze(0)              # 1×H×W
    return t.repeat(3, 1, 1)                            # 3×H×W

def generate_autoregressive(
    model,
    class_labels,
    tokenizer,
    max_len,
    use_padding_mask: bool,
    device,
):
    """
    Autoregressively generate token sequences using SketchTransformer.

    Returns:
        seq: (batch, T) long tensor, starting with START, possibly ending with END.
    """
    pad_id   = tokenizer.pad_token_id
    start_id = tokenizer.vocab["START"]
    end_id   = tokenizer.vocab["END"]

    model.eval()
    B = class_labels.size(0)

    # Start with START only
    seq = torch.full(
        (B, 1),
        start_id,
        dtype=torch.long,
        device=device,
    )

    finished = torch.zeros(B, dtype=torch.bool, device=device)

    for _ in range(max_len - 1):
        if use_padding_mask:
            pad_mask = seq == pad_id  # (B, T)
            logits = model(seq, class_labels, src_key_padding_mask=pad_mask)
        else:
            logits = model(seq, class_labels)

        # Only use the logits at the last time step
        next_logits = logits[:, -1, :]   # (B, vocab_size)
        next_tokens = next_logits.argmax(dim=-1)  # (B,)

        # Append next token
        seq = torch.cat([seq, next_tokens.unsqueeze(1)], dim=1)

        # Track which sequences have produced END
        finished |= (next_tokens == end_id)
        if finished.all():
            break

    return seq

def trim_ids_for_decode(ids, end_id, pad_id):
    """
    Trim a token list at the first END (inclusive) or PAD (exclusive), whichever comes first.
    """
    cut = len(ids)
    if end_id in ids:
        cut = min(cut, ids.index(end_id) + 1)  # keep END
    if pad_id in ids:
        cut = min(cut, ids.index(pad_id))      # drop PAD
    return ids[:cut]

def test_it(sketch_trainer: SketchTrainer):
    model = sketch_trainer.model
    test_loader = sketch_trainer.test_loader
    use_padding_mask = sketch_trainer.use_padding_mask

    model.eval()
    test_token_accuracy = 0.0

    # ---------- 1) Next-token accuracy (unchanged) ----------
    with torch.no_grad():
        for input_ids, target_ids, class_labels in tqdm(test_loader, desc="Testing"):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            class_labels = class_labels.to(device)

            if use_padding_mask:
                mask = input_ids == sketch_trainer.tokenizer.pad_token_id
                logits = model(input_ids, class_labels, src_key_padding_mask=mask)
            else:
                logits = model(input_ids, class_labels)

            preds = logits.argmax(dim=-1)
            mask = target_ids != sketch_trainer.tokenizer.pad_token_id
            correct = (preds[mask] == target_ids[mask]).float().sum()
            total = mask.sum()

            acc = (correct / total) if total > 0 else torch.tensor(0.0, device=device)
            test_token_accuracy += acc.item()

    avg_acc = test_token_accuracy / len(test_loader)
    print(f"Test Next Token Accuracy: {avg_acc:.4f}")

    # ---------- 2) FID + IS with proper generation ----------
    tokenizer = sketch_trainer.tokenizer
    pad_id   = tokenizer.pad_token_id
    start_id = tokenizer.vocab["START"]
    end_id   = tokenizer.vocab["END"]

    def trim_ids_for_decode(ids):
        cut = len(ids)
        if end_id in ids:
            cut = min(cut, ids.index(end_id) + 1)
        if pad_id in ids:
            cut = min(cut, ids.index(pad_id))
        return ids[:cut]

    # torchmetrics FID & IS
    fid = FrechetInceptionDistance(normalize=False).to(device)
    inception = InceptionScore(splits=10, normalize=False).to(device)

    model.eval()
    with torch.no_grad():
        for _, target_ids, class_labels in tqdm(test_loader, desc="FID/IS"):
            target_ids = target_ids.to(device)
            class_labels = class_labels.to(device)

            # 1) REAL sequences: from dataset targets
            targets_cpu = target_ids.cpu()

            # 2) FAKE sequences: from *autoregressive* generation
            max_len = getattr(sketch_trainer, "max_len", model.max_len)
            fake_seqs = generate_autoregressive(
                model=model,
                class_labels=class_labels,
                tokenizer=tokenizer,
                max_len=max_len,
                use_padding_mask=use_padding_mask,
                device=device,
            )
            fake_cpu = fake_seqs.cpu()

            real_batch = []
            fake_batch = []

            B = fake_cpu.size(0)
            for b in range(B):
                # ----- REAL -----
                # If your targets already include START, don't add another START here.
                # This version assumes targets DO NOT include START.
                real_ids = [start_id] + [
                    t for t in targets_cpu[b].tolist() if t != pad_id
                ]
                real_ids = trim_ids_for_decode(real_ids)
                real_svg = tokenizer.decode(real_ids)
                real_img = svg_rasterize(real_svg)   # PIL grayscale 299×299 (or similar)
                r = to_3ch_tensor(real_img)

                # ----- FAKE -----
                fake_ids = [t for t in fake_cpu[b].tolist() if t != pad_id]
                fake_ids = trim_ids_for_decode(fake_ids)
                fake_svg = tokenizer.decode(fake_ids)
                fake_img = svg_rasterize(fake_svg)
                f = to_3ch_tensor(fake_img)

                real_batch.append(r.unsqueeze(0))  # 1×3×H×W
                fake_batch.append(f.unsqueeze(0))

            real_images = torch.cat(real_batch, dim=0).to(device=device, dtype=torch.uint8)
            fake_images = torch.cat(fake_batch, dim=0).to(device=device, dtype=torch.uint8)

            fid.update(real_images, real=True)
            fid.update(fake_images, real=False)
            inception.update(fake_images)

    fid_score = fid.compute().item()
    is_mean, is_std = inception.compute()
    is_mean = is_mean.item()
    is_std = is_std.item()

    sketch_trainer.writer.add_scalar("FID/Test", fid_score, 0)
    sketch_trainer.writer.add_scalar("IS/TestMean", is_mean, 0)
    sketch_trainer.writer.add_scalar("IS/TestStd", is_std, 0)

    print(f"Test FID: {fid_score:.4f}")
    print(f"Test Inception Score: mean={is_mean:.4f}, std={is_std:.4f}")

In [None]:
from main import load_config, SketchTrainer

trainer = load_config("configs/example_0.toml")
test_it(trainer)

Downloading QuickDrawDataset files: 100%|██████████| 1/1 [00:00<00:00, 1099.71it/s]
Loading QuickDrawDataset: 1it [00:00, 963.32it/s]
Tokenizing dataset: 100%|██████████| 1/1 [00:00<00:00, 3111.50it/s]


No checkpoint found, starting fresh training.


Initial Eval: 100%|██████████| 81/81 [00:00<00:00, 438.76it/s]
Testing: 100%|██████████| 41/41 [00:01<00:00, 24.96it/s]


Test Next Token Accuracy: 0.0008


FID/IS: 100%|██████████| 41/41 [02:46<00:00,  4.06s/it]


Test FID: 10.1799
Test Inception Score: mean=1.0000, std=0.0000
