# Test Experiments

Frechet Inception Distance (FID) and Inception Score

In [15]:
from experiment_dir import set_cwd_project_root

set_cwd_project_root()

In [16]:
from runner import SketchTrainer, sample, device
from utils import top_k_filtering, top_p_filtering
from prepare_data import add_svg_properties, clean_svg, stroke_to_bezier_single
from main import load_config
import torch
from tqdm import tqdm
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import torchvision
from raster_dataset import svg_rasterize
import numpy as np

torch.cuda.empty_cache()


def to_3ch_tensor(img_pil):
    arr = np.array(img_pil, dtype=np.float32) / 255.0  # H×W in [0,1]
    t = torch.from_numpy(arr).unsqueeze(0)  # 1×H×W
    return t.repeat(3, 1, 1)  # 3×H×W


def test_next_token_accuracy(sketch_trainer: SketchTrainer):
    model = sketch_trainer.model
    test_loader = sketch_trainer.test_loader
    use_padding_mask = sketch_trainer.use_padding_mask
    model.eval()
    test_token_accuracy = 0.0

    with torch.no_grad():
        for input_ids, target_ids, class_labels in tqdm(test_loader, desc="Testing"):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            class_labels = class_labels.to(device)

            if use_padding_mask:
                mask = input_ids == sketch_trainer.tokenizer.pad_token_id
                logits = model(input_ids, class_labels, src_key_padding_mask=mask)
            else:
                logits = model(input_ids, class_labels)

            preds = logits.argmax(dim=-1)
            mask = target_ids != sketch_trainer.tokenizer.pad_token_id
            correct = (preds[mask] == target_ids[mask]).float().sum()
            total = mask.sum()

            acc = (correct / total) if total > 0 else torch.tensor(0.0, device=device)
            test_token_accuracy += acc.item()

    avg_acc = test_token_accuracy / len(test_loader)
    print(f"Test Next Token Accuracy: {avg_acc:.4f}")


def generate_autoregressive(
    model,
    class_labels,
    tokenizer,
    max_len,
    device,
    temperature=0.8,
):
    """
    Autoregressively generate token sequences using SketchTransformer.

    Returns:
        seq: (batch, T) long tensor, starting with START, possibly ending with END.
    """
    start_id = tokenizer.vocab["START"]
    end_id = tokenizer.vocab["END"]

    model.eval()
    B = class_labels.size(0)
    seq = torch.full(
        (B, 1),
        start_id,
        dtype=torch.long,
        device=device,
    )

    finished = torch.zeros(B, dtype=torch.bool, device=device)

    for _ in range(max_len - 1):
        logits = model(seq, class_labels)
        next_logits = logits[:, -1, :]

        # Only use the logits at the last time step
        next_logits = logits[:, -1, :] / temperature  # (B, vocab_size)
        probs = torch.softmax(next_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

        # Append next token
        seq = torch.cat([seq, next_tokens.unsqueeze(1)], dim=1)

        # Track which sequences have produced END
        finished |= next_tokens == end_id
        if finished.all():
            break

    return seq


def trim_ids_for_decode(ids, end_id, pad_id):
    """
    Trim a token list at the first END (inclusive) or PAD (exclusive), whichever comes first.
    """
    cut = len(ids)
    if end_id in ids:
        cut = min(cut, ids.index(end_id) + 1)  # keep END
    if pad_id in ids:
        cut = min(cut, ids.index(pad_id))  # drop PAD
    return ids[:cut]


def test_fid_inception_score(
    sketch_trainer: SketchTrainer,
    num_samples=None,
    bezier_postprocess=True,
    canvas_size=128,
):
    model = sketch_trainer.model
    test_loader = sketch_trainer.test_loader
    tokenizer = sketch_trainer.tokenizer
    pad_id = tokenizer.pad_token_id
    start_id = tokenizer.vocab["START"]
    end_id = tokenizer.vocab["END"]

    def trim_ids_for_decode(ids):
        cut = len(ids)
        if end_id in ids:
            cut = min(cut, ids.index(end_id) + 1)
        if pad_id in ids:
            cut = min(cut, ids.index(pad_id))
        return ids[:cut]

    # torchmetrics FID & IS
    fid = FrechetInceptionDistance(normalize=False).to(device)
    inception = InceptionScore(splits=10, normalize=False).to(device)

    all_real_images = []  # will hold all images across every batch
    all_fake_images = []

    if num_samples is None:
        num_samples = len(test_loader)

    it = list(test_loader)[:num_samples]

    model.eval()
    with torch.no_grad():
        for _, target_ids, class_labels in tqdm(it, desc="FID/IS"):
            class_labels = class_labels.to(device)
            max_len = model.max_len
            fake_seqs = generate_autoregressive(
                model=model,
                class_labels=class_labels,
                tokenizer=tokenizer,
                max_len=max_len,
                device=device,
            )
            fake_cpu = fake_seqs.cpu()

            real_batch = []
            fake_batch = []

            B = fake_cpu.size(0)
            for b in range(B):
                real_ids = [start_id] + [
                    t for t in target_ids[b].tolist() if t != pad_id
                ]
                real_ids = trim_ids_for_decode(real_ids)
                real_svg = tokenizer.decode(real_ids)
                if bezier_postprocess:
                    real_svg = stroke_to_bezier_single(real_svg)
                    real_svg = clean_svg(real_svg)
                real_svg = add_svg_properties(
                    real_svg, width=canvas_size, height=canvas_size
                )
                real_img = svg_rasterize(real_svg)  # PIL grayscale image
                r = to_3ch_tensor(real_img)

                fake_ids = [t for t in fake_cpu[b].tolist() if t != pad_id]
                fake_ids = trim_ids_for_decode(fake_ids)
                fake_svg = tokenizer.decode(fake_ids)
                if bezier_postprocess:
                    fake_svg = stroke_to_bezier_single(fake_svg)
                    fake_svg = clean_svg(fake_svg)
                fake_svg = add_svg_properties(
                    fake_svg, width=canvas_size, height=canvas_size
                )

                fake_img = svg_rasterize(fake_svg)
                f = to_3ch_tensor(fake_img)

                real_batch.append(r.unsqueeze(0))  # 1×3×H×W
                fake_batch.append(f.unsqueeze(0))

            all_real_images.extend(real_batch)
            all_fake_images.extend(fake_batch)

            real_images = torch.cat(real_batch, dim=0).to(
                device=device, dtype=torch.uint8
            )
            fake_images = torch.cat(fake_batch, dim=0).to(
                device=device, dtype=torch.uint8
            )

            fid.update(real_images, real=True)
            fid.update(fake_images, real=False)
            inception.update(fake_images)

    all_real_images = torch.cat(all_real_images, dim=0).cpu()
    all_fake_images = torch.cat(all_fake_images, dim=0).cpu()

    # save one image each
    torchvision.utils.save_image(all_real_images, "real_grid.png")
    torchvision.utils.save_image(all_fake_images, "fake_grid.png")

    fid_score = fid.compute().item()
    is_mean, is_std = inception.compute()
    is_mean = is_mean.item()
    is_std = is_std.item()

    sketch_trainer.writer.add_scalar("FID/Test", fid_score, 0)
    sketch_trainer.writer.add_scalar("IS/TestMean", is_mean, 0)
    sketch_trainer.writer.add_scalar("IS/TestStd", is_std, 0)

    print(f"Test FID: {fid_score:.4f}")
    print(f"Test Inception Score: mean={is_mean:.4f}, std={is_std:.4f}")

In [17]:
# %pip install --upgrade pip
# %pip install --upgrade Pillow

In [None]:
from dataset import QuickDrawDataset
from sketch_tokenizers import DeltaPenPositionTokenizer
from models import SketchTransformerConditional
from runner import SketchTrainer, sample

label_names = ["monkey", "fish", "sailboat", "skull", "whale"]
dataset = QuickDrawDataset(label_names=label_names, download=True)
tokenizer = DeltaPenPositionTokenizer(bins=16)

model = SketchTransformerConditional(
    vocab_size=len(tokenizer.vocab),
    d_model=384,
    nhead=8,
    num_layers=8,
    max_len=200,
    num_classes=len(label_names),
)

training_config = {
    "batch_size": 128,
    "num_epochs": 15,
    "learning_rate": 1e-4,
    "log_dir": "logs/sketch_transformer_experiment_5",
    "splits": [0.85, 0.1, 0.05],
    # "use_padding_mask": True,
    "checkpoint_path": "logs/sketch_transformer_experiment_5/20251115_192110/model_8.pt",
}

trainer = SketchTrainer(model, dataset, tokenizer, training_config)

Downloading QuickDrawDataset files: 100%|██████████| 5/5 [00:00<00:00, 10974.11it/s]
Loading QuickDrawDataset: 5it [00:00, 326.09it/s]
Tokenizing dataset: 100%|██████████| 5/5 [00:00<00:00, 618.30it/s]


Resumed training from checkpoint: logs/sketch_transformer_experiment_5/20251115_192110/model_8.pt


Initial Eval: 100%|██████████| 464/464 [00:02<00:00, 177.13it/s]


In [19]:
# trainer = load_config("configs/example_0.toml")

In [20]:
# test_next_token_accuracy(trainer)

In [None]:
test_fid_inception_score(trainer, 2)

FID/IS: 100%|██████████| 2/2 [00:08<00:00,  4.15s/it]


Test FID: 5.2216
Test Inception Score: mean=1.0596, std=0.0061


In [22]:
test_fid_inception_score(trainer, 20, bezier_postprocess=False)

FID/IS: 100%|██████████| 20/20 [01:29<00:00,  4.48s/it]


Test FID: 0.8337
Test Inception Score: mean=1.0513, std=0.0027


In [23]:
test_fid_inception_score(trainer, 20, canvas_size=64)

FID/IS: 100%|██████████| 20/20 [01:36<00:00,  4.81s/it]


Test FID: 0.2979
Test Inception Score: mean=1.0254, std=0.0007
