# Test Experiments

Frechet Inception Distance (FID) and Inception Score

In [1]:
from experiment_dir import set_cwd_project_root

set_cwd_project_root()

In [None]:
from runner import SketchTrainer, sample, device
from main import load_config
import torch
from tqdm import tqdm
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import torchvision
from raster_dataset import svg_rasterize
import numpy as np


torch.cuda.empty_cache()


def to_3ch_tensor(img_pil):
    arr = np.array(img_pil, dtype=np.uint8)          # H×W (grayscale)
    t = torch.from_numpy(arr).unsqueeze(0)           # 1×H×W
    return t.repeat(3, 1, 1)                         # 3×H×W


def test_next_token_accuracy(sketch_trainer: SketchTrainer):
    model = sketch_trainer.model
    test_loader = sketch_trainer.test_loader
    use_padding_mask = sketch_trainer.use_padding_mask

    model.eval()
    test_token_accuracy = 0.0

    # ---------- 1) Next-token accuracy ----------
    with torch.no_grad():
        for input_ids, target_ids, class_labels in tqdm(test_loader, desc="Testing"):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            class_labels = class_labels.to(device)

            if use_padding_mask:
                mask = input_ids == sketch_trainer.tokenizer.pad_token_id
                logits = model(input_ids, class_labels, src_key_padding_mask=mask)
            else:
                logits = model(input_ids, class_labels)

            preds = logits.argmax(dim=-1)
            mask = target_ids != sketch_trainer.tokenizer.pad_token_id
            correct = (preds[mask] == target_ids[mask]).float().sum()
            total = mask.sum()

            acc = (correct / total) if total > 0 else torch.tensor(0.0, device=device)
            test_token_accuracy += acc.item()

    avg_acc = test_token_accuracy / len(test_loader)
    print(f"Test Next Token Accuracy: {avg_acc:.4f}")



def test_fid_inception_score(sketch_trainer: SketchTrainer):
    test_loader = sketch_trainer.test_loader
    start_id = sketch_trainer.tokenizer.vocab["START"]
    end_id = sketch_trainer.tokenizer.vocab["END"]
    model = sketch_trainer.model

    generations = []

    for _, _, class_labels in tqdm(
        test_loader, desc="Generating Test Samples"
    ):
        sample_generations = []
        for j in range(class_labels.size(0)):
            sample_generations.append(sample(
                model,
                start_tokens=[start_id],
                eos_id=end_id,
                class_label=class_labels[j].unsqueeze(0),
                temperature=0.8,
                top_k=20,
                top_p=0.7,
                greedy=False,
            ))
        generations.append(sample_generations)

    # torchmetrics FID: expects 3×HxW uint8 or float with [0,1] if normalize=True
    fid = FrechetInceptionDistance(normalize=False).to(device)
    inception = InceptionScore(splits=10, normalize=False).to(device)

    # Each sketch in the test set is rasterized and compared to a sketch which was autoregressively generated

    # TODO: maybe only sample the first N sketches from each set / batch
    for i, item in enumerate(test_loader):
        real_batch = []
        generations_batch = []

        for j in range(len(item)):
            real_ids = [start_id] + item[0][j].tolist()
            real_svg = sketch_trainer.tokenizer.decode(real_ids)
            real_img = svg_rasterize(real_svg)
            r = to_3ch_tensor(real_img)

            generation_ids = generations[i][j].tolist()
            generation_svg = sketch_trainer.tokenizer.decode(generation_ids)
            generation_img = svg_rasterize(generation_svg)
            g = to_3ch_tensor(generation_img)

            real_batch.append(r.unsqueeze(0))
            generations_batch.append(g.unsqueeze(0))
        
        real_images = torch.cat(real_batch, dim=0).to(device)  # B×3×H×W
        generated_images = torch.cat(generations_batch, dim=0).to(device)

        # create and show a grid
        real_grid = torchvision.utils.make_grid(real_images, nrow=32)
        generated_grid = torchvision.utils.make_grid(generated_images, nrow=32)

        torchvision.utils.save_image(real_grid, "real_grid.png")
        torchvision.utils.save_image(generated_grid, "generated_grid.png")

        fid.update(real_images, real=True)
        fid.update(generated_images, real=False)
        inception.update(generated_images)

    fid_score = fid.compute().item()
    is_mean, is_std = inception.compute()
    is_mean = is_mean.item()
    is_std = is_std.item()

    sketch_trainer.writer.add_scalar("FID/Test", fid_score, 0)
    sketch_trainer.writer.add_scalar("IS/TestMean", is_mean, 0)
    sketch_trainer.writer.add_scalar("IS/TestStd", is_std, 0)

    print(f"Test FID: {fid_score:.4f}")
    print(f"Test Inception Score: mean={is_mean:.4f}, std={is_std:.4f}")



In [3]:
# %pip install --upgrade pip
# %pip install --upgrade Pillow

In [4]:
from dataset import QuickDrawDataset
from sketch_tokenizers import DeltaPenPositionTokenizer
from models import SketchTransformerConditional
from runner import SketchTrainer, sample
from prepare_data import stroke_to_bezier_single, clean_svg

label_names = ["bird", "crab", "guitar", "donut", "whale", "penguin", "skull", "fish"]
dataset = QuickDrawDataset(label_names=label_names, download=True)
tokenizer = DeltaPenPositionTokenizer(bins=16)

model = SketchTransformerConditional(
    vocab_size=len(tokenizer.vocab),
    d_model=512,
    nhead=8,
    num_layers=8,
    max_len=200,
    num_classes=len(label_names),
)

training_config = {
    "batch_size": 128,
    "num_epochs": 15,
    "learning_rate": 1e-4,
    "log_dir": "logs/sketch_transformer_experiment_3",
    "splits": [0.85, 0.1, 0.05],
    # "use_padding_mask": True,
    "checkpoint_path": "logs/sketch_transformer_experiment_3/20251112_125146/model_3.pt",
}

trainer = SketchTrainer(model, dataset, tokenizer, training_config)

Downloading QuickDrawDataset files: 100%|██████████| 8/8 [00:00<00:00, 1163.67it/s]
Loading QuickDrawDataset: 8it [00:00, 338.09it/s]
Tokenizing dataset: 100%|██████████| 8/8 [00:00<00:00, 1808.28it/s]


Resumed training from checkpoint: logs/sketch_transformer_experiment_3/20251112_125146/model_3.pt


Initial Eval: 100%|██████████| 824/824 [00:02<00:00, 332.75it/s]


In [5]:
# trainer = load_config("configs/example_0.toml")

In [6]:
test_next_token_accuracy(trainer)

Testing: 100%|██████████| 412/412 [00:32<00:00, 12.56it/s]

Test Next Token Accuracy: 0.3378





In [9]:
test_fid_inception_score(trainer)

Generating Test Samples:   1%|          | 4/412 [01:36<2:44:39, 24.22s/it]


KeyboardInterrupt: 