# Test Experiments

Frechet Inception Distance (FID) and Inception Score

In [9]:
from experiment_dir import set_cwd_project_root

set_cwd_project_root()

In [10]:
from runner import SketchTrainer, sample, device
from prepare_data import add_svg_properties, clean_svg, stroke_to_bezier_single
import torch
from tqdm import tqdm
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import torchvision
from raster_dataset import svg_rasterize
import numpy as np

torch.cuda.empty_cache()

In [11]:
def to_3ch_tensor(img_pil):
    arr = np.array(img_pil, dtype=np.float32) / 255.0  # H×W in [0,1]
    t = torch.from_numpy(arr).unsqueeze(0)  # 1×H×W
    return t.repeat(3, 1, 1)  # 3×H×W


def test_next_token_accuracy(sketch_trainer: SketchTrainer):
    model = sketch_trainer.model
    test_loader = sketch_trainer.test_loader
    use_padding_mask = sketch_trainer.use_padding_mask
    model.eval()
    test_token_accuracy = 0.0

    with torch.no_grad():
        for input_ids, target_ids, class_labels in tqdm(test_loader, desc="Testing"):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            class_labels = class_labels.to(device)

            if use_padding_mask:
                mask = input_ids == sketch_trainer.tokenizer.pad_token_id
                logits = model(input_ids, class_labels, src_key_padding_mask=mask)
            else:
                logits = model(input_ids, class_labels)

            preds = logits.argmax(dim=-1)
            mask = target_ids != sketch_trainer.tokenizer.pad_token_id
            correct = (preds[mask] == target_ids[mask]).float().sum()
            total = mask.sum()

            acc = (correct / total) if total > 0 else torch.tensor(0.0, device=device)
            test_token_accuracy += acc.item()

    avg_acc = test_token_accuracy / len(test_loader)
    print(f"Test Next Token Accuracy: {avg_acc:.4f}")


def generate_autoregressive(
    model,
    class_labels,
    tokenizer,
    max_len,
    device,
    temperature=0.8,
):
    """
    Autoregressively generate token sequences using SketchTransformer.

    Returns:
        seq: (batch, T) long tensor, starting with START, possibly ending with END.
    """
    start_id = tokenizer.vocab["START"]
    end_id = tokenizer.vocab["END"]

    model.eval()
    B = class_labels.size(0)
    seq = torch.full(
        (B, 1),
        start_id,
        dtype=torch.long,
        device=device,
    )

    finished = torch.zeros(B, dtype=torch.bool, device=device)

    for _ in range(max_len - 1):
        logits = model(seq, class_labels)
        next_logits = logits[:, -1, :]

        # Only use the logits at the last time step
        next_logits = logits[:, -1, :] / temperature  # (B, vocab_size)
        probs = torch.softmax(next_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)

        # Append next token
        seq = torch.cat([seq, next_tokens.unsqueeze(1)], dim=1)

        # Track which sequences have produced END
        finished |= next_tokens == end_id
        if finished.all():
            break

    return seq


def trim_ids_for_decode(ids, end_id, pad_id):
    """
    Trim a token list at the first END (inclusive) or PAD (exclusive), whichever comes first.
    """
    cut = len(ids)
    if end_id in ids:
        cut = min(cut, ids.index(end_id) + 1)  # keep END
    if pad_id in ids:
        cut = min(cut, ids.index(pad_id))  # drop PAD
    return ids[:cut]


def test_fid_inception_score(
    sketch_trainer: SketchTrainer,
    num_samples=None,
    bezier_postprocess=True,
    canvas_size=128,
):
    model = sketch_trainer.model
    test_loader = sketch_trainer.test_loader
    tokenizer = sketch_trainer.tokenizer
    pad_id = tokenizer.pad_token_id
    start_id = tokenizer.vocab["START"]
    end_id = tokenizer.vocab["END"]

    def trim_ids_for_decode(ids):
        cut = len(ids)
        if end_id in ids:
            cut = min(cut, ids.index(end_id) + 1)
        if pad_id in ids:
            cut = min(cut, ids.index(pad_id))
        return ids[:cut]

    # torchmetrics FID & IS
    fid = FrechetInceptionDistance(normalize=False).to(device)
    inception = InceptionScore(splits=10, normalize=False).to(device)

    all_real_images = []  # will hold all images across every batch
    all_fake_images = []

    if num_samples is None:
        num_samples = len(test_loader)

    it = list(test_loader)[:num_samples]

    model.eval()
    with torch.no_grad():
        for _, target_ids, class_labels in tqdm(it, desc="FID/IS"):
            class_labels = class_labels.to(device)
            max_len = model.max_len
            fake_seqs = generate_autoregressive(
                model=model,
                class_labels=class_labels,
                tokenizer=tokenizer,
                max_len=max_len,
                device=device,
            )
            fake_cpu = fake_seqs.cpu()

            real_batch = []
            fake_batch = []

            B = fake_cpu.size(0)
            for b in range(B):
                real_ids = [start_id] + [
                    t for t in target_ids[b].tolist() if t != pad_id
                ]
                real_ids = trim_ids_for_decode(real_ids)
                real_svg = tokenizer.decode(real_ids)
                if bezier_postprocess:
                    real_svg = stroke_to_bezier_single(real_svg)
                    real_svg = clean_svg(real_svg)
                real_svg = add_svg_properties(
                    real_svg, width=canvas_size, height=canvas_size
                )
                real_img = svg_rasterize(real_svg)  # PIL grayscale image
                r = to_3ch_tensor(real_img)

                fake_ids = [t for t in fake_cpu[b].tolist() if t != pad_id]
                fake_ids = trim_ids_for_decode(fake_ids)
                fake_svg = tokenizer.decode(fake_ids)
                if bezier_postprocess:
                    fake_svg = stroke_to_bezier_single(fake_svg)
                    fake_svg = clean_svg(fake_svg)
                fake_svg = add_svg_properties(
                    fake_svg, width=canvas_size, height=canvas_size
                )

                fake_img = svg_rasterize(fake_svg)
                f = to_3ch_tensor(fake_img)

                real_batch.append(r.unsqueeze(0))  # 1×3×H×W
                fake_batch.append(f.unsqueeze(0))

            all_real_images.extend(real_batch)
            all_fake_images.extend(fake_batch)

            real_images = torch.cat(real_batch, dim=0).to(
                device=device, dtype=torch.uint8
            )
            fake_images = torch.cat(fake_batch, dim=0).to(
                device=device, dtype=torch.uint8
            )

            fid.update(real_images, real=True)
            fid.update(fake_images, real=False)
            inception.update(fake_images)

    all_real_images = torch.cat(all_real_images, dim=0).cpu()
    all_fake_images = torch.cat(all_fake_images, dim=0).cpu()

    # save one image each
    torchvision.utils.save_image(all_real_images, "real_grid.png")
    torchvision.utils.save_image(all_fake_images, "fake_grid.png")

    fid_score = fid.compute().item()
    is_mean, is_std = inception.compute()
    is_mean = is_mean.item()
    is_std = is_std.item()

    sketch_trainer.writer.add_scalar("FID/Test", fid_score, 0)
    sketch_trainer.writer.add_scalar("IS/TestMean", is_mean, 0)
    sketch_trainer.writer.add_scalar("IS/TestStd", is_std, 0)

    print(f"Test FID: {fid_score:.4f}")
    print(f"Test Inception Score: mean={is_mean:.4f}, std={is_std:.4f}")

In [12]:
from transformers import CLIPProcessor, CLIPModel

def test_clip_score(
    sketch_trainer: SketchTrainer,
    dataset,
    num_samples=None,
    canvas_size=512,
    bezier_postprocess=True,
):
    model = sketch_trainer.model
    test_loader = sketch_trainer.test_loader
    tokenizer = sketch_trainer.tokenizer
    pad_id = tokenizer.pad_token_id
    start_id = tokenizer.vocab["START"]
    end_id = tokenizer.vocab["END"]

    def trim_ids_for_decode(ids):
        cut = len(ids)
        if end_id in ids:
            cut = min(cut, ids.index(end_id) + 1)
        if pad_id in ids:
            cut = min(cut, ids.index(pad_id))
        return ids[:cut]

    device = next(model.parameters()).device

    # Load CLIP
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    real_scores = []
    fake_scores = []

    # num_samples = number of batches to use
    if num_samples is None:
        num_samples = len(test_loader)

    it = list(test_loader)[:num_samples]

    model.eval()
    clip_model.eval()

    with torch.no_grad():
        for _, target_ids, class_labels in tqdm(it, desc="CLIP Score"):
            class_labels = class_labels.to(device)
            max_len = model.max_len

            fake_seqs = generate_autoregressive(
                model=model,
                class_labels=class_labels,
                tokenizer=tokenizer,
                max_len=max_len,
                device=device,
            )
            fake_cpu = fake_seqs.cpu()

            # Convert numeric labels → strings
            text_labels = [dataset.label_map[c.item()] for c in class_labels]
            text = [f"a {lbl} sketch" for lbl in text_labels]

            # Encode text using CLIP
            text_inputs = processor(text=text, return_tensors="pt", padding=True).to(device)
            text_embeds = clip_model.get_text_features(**text_inputs)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

            B = fake_cpu.size(0)

            for b in range(B):
                # --- REAL ---
                real_ids = [start_id] + [t for t in target_ids[b].tolist() if t != pad_id]
                real_ids = trim_ids_for_decode(real_ids)
                real_svg = tokenizer.decode(real_ids)

                if bezier_postprocess:
                    real_svg = stroke_to_bezier_single(real_svg)
                    real_svg = clean_svg(real_svg)

                real_svg = add_svg_properties(real_svg, width=canvas_size, height=canvas_size)
                real_img = svg_rasterize(real_svg).convert("RGB")

                # --- FAKE ---
                fake_ids = [t for t in fake_cpu[b].tolist() if t != pad_id]
                fake_ids = trim_ids_for_decode(fake_ids)
                fake_svg = tokenizer.decode(fake_ids)

                if bezier_postprocess:
                    fake_svg = stroke_to_bezier_single(fake_svg)
                    fake_svg = clean_svg(fake_svg)

                fake_svg = add_svg_properties(fake_svg, width=canvas_size, height=canvas_size)
                fake_img = svg_rasterize(fake_svg).convert("RGB")

                # Encode both images at once
                image_inputs = processor(images=[real_img, fake_img], return_tensors="pt").to(device)
                image_embeds = clip_model.get_image_features(**image_inputs)
                image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)

                # Compute cosine similarity
                # image_embeds: (2, D)
                # text_embeds: (B, D)
                # We want the score of this sample only => index b
                scores = (image_embeds @ text_embeds[b].unsqueeze(1)).squeeze().tolist()

                real_clip_score = scores[0]
                fake_clip_score = scores[1]

                real_scores.append(real_clip_score)
                fake_scores.append(fake_clip_score)

    avg_real_score = sum(real_scores) / len(real_scores)
    avg_fake_score = sum(fake_scores) / len(fake_scores)

    print(f"Average Real CLIP Score: {avg_real_score:.4f}")
    print(f"Average Fake CLIP Score: {avg_fake_score:.4f}")


In [13]:
# %pip install --upgrade pip
# %pip install --upgrade Pillow

In [None]:
from dataset import QuickDrawDataset
from sketch_tokenizers import DeltaPenPositionTokenizer
from models import SketchTransformerConditional
from runner import SketchTrainer, sample

label_names = ["cake" ,"butterfly","flower","mug","sea turtle"]
dataset = QuickDrawDataset(label_names=label_names, download=True)
tokenizer = DeltaPenPositionTokenizer(bins=32)

model = SketchTransformerConditional(
    vocab_size=len(tokenizer.vocab),
    d_model=512,
    nhead=8,
    num_layers=8,
    max_len=200,
    num_classes=len(label_names),
)

training_config = {
    "batch_size": 128,
    "num_epochs": 15,
    "learning_rate": 1e-4,
    "log_dir": "logs/tes1",
    "splits": [0.85, 0.075, 0.075],
    # "use_padding_mask": True,
    "checkpoint_path": "_site/model_checkpoint_15.pt",
}

trainer = SketchTrainer(model, dataset, tokenizer, training_config)

Downloading QuickDrawDataset files: 100%|██████████| 5/5 [00:00<00:00, 11416.18it/s]
Loading QuickDrawDataset: 5it [00:00, 338.58it/s]
Tokenizing dataset: 100%|██████████| 5/5 [00:00<00:00, 629.34it/s]


Resumed training from checkpoint: C:\Code\Generative-SVG\_site\model_checkpoint_15.pt


In [15]:
# test_next_token_accuracy(trainer)

In [16]:
test_fid_inception_score(trainer, 20)

FID/IS: 100%|██████████| 20/20 [03:28<00:00, 10.44s/it]


Test FID: 1.3663
Test Inception Score: mean=1.0852, std=0.0064


In [17]:
test_fid_inception_score(trainer, 20, bezier_postprocess=False)

FID/IS: 100%|██████████| 20/20 [03:55<00:00, 11.79s/it]


Test FID: 1.4441
Test Inception Score: mean=1.0668, std=0.0034


In [18]:
test_fid_inception_score(trainer, 20, canvas_size=64)

FID/IS: 100%|██████████| 20/20 [03:31<00:00, 10.58s/it]


Test FID: 0.3535
Test Inception Score: mean=1.0294, std=0.0010


In [19]:
test_clip_score(trainer, dataset, 20)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
CLIP Score: 100%|██████████| 20/20 [05:13<00:00, 15.67s/it]

Average Real CLIP Score: 0.2429
Average Fake CLIP Score: 0.2406



