In [1]:
# import torch
# from transformers import pipeline

# pipeline = pipeline(
#     task="image-classification",
#     model="google/vit-base-patch16-224",
#     device=-1,
# )
# pipeline(
#     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
# )

Step-by-Step: Fine-Tuning ViT on Quick, Draw!
1. Install and Load Data

In [2]:
# from quickdraw import QuickDrawDataGroup

# # Example: load 1000 recognized circle drawings
# qdg = QuickDrawDataGroup("circle", recognized=True, max_drawings=1000)
# images = [d.image.convert("RGB") for d in qdg.drawings]

2. Label & Organize

Assign numeric labels for each category:

In [3]:
from quickdraw import QuickDrawDataGroup

categories = ["circle", "square", "line"]  # customize as needed
label_map = {name: i for i, name in enumerate(categories)}

# Collect images and labels
all_images, all_labels = [], []
for name in categories:
    group = QuickDrawDataGroup(name, recognized=True, max_drawings=300)
    for d in group.drawings:
        all_images.append(d.image.convert("RGB"))
        all_labels.append(label_map[name])

loading circle drawings
load complete
loading square drawings
load complete
downloading line from https://storage.googleapis.com/quickdraw_dataset/full/binary/line.bin
download complete
loading line drawings
load complete


3. Create Dataset (via Hugging Face datasets)

In [4]:
import os
from PIL import Image

output_dir = "shapes_dataset/images"
os.makedirs(output_dir, exist_ok=True)

paths, labels = [], []

for i, (img, label) in enumerate(zip(all_images, all_labels)):
    path = os.path.join(output_dir, f"img_{i}.png")
    img.save(path)
    paths.append(path)
    labels.append(label)

from datasets import Dataset
import pandas as pd

df = pd.DataFrame({"image": paths, "label": labels})  # Use paths instead of all_images
ds = Dataset.from_pandas(df)
ds = ds.train_test_split(test_size=0.2)

  from .autonotebook import tqdm as notebook_tqdm


4. Preprocess with AutoImageProcessor

In [5]:
from transformers import AutoImageProcessor
from PIL import Image

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

def preprocess_function(examples):
    # Load images from paths
    images = [Image.open(img_path).convert("RGB") for img_path in examples["image"]]
    # Process images and return pixel values
    inputs = processor(images, return_tensors="pt")
    examples["pixel_values"] = inputs["pixel_values"]
    return examples

# Apply preprocessing to the entire dataset
ds = ds.map(preprocess_function, batched=True, batch_size=32)

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Map: 100%|██████████| 720/720 [00:06<00:00, 119.86 examples/s]
Map: 100%|██████████| 180/180 [00:01<00:00, 126.83 examples/s]


5. Load Model for Classification

In [6]:
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(categories),
    id2label={i: name for name, i in label_map.items()},
    label2id=label_map,
    ignore_mismatched_sizes=True  # Ignore size mismatch for classification head
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


6. Train with Trainer API

In [7]:
import torch
from transformers import TrainingArguments, Trainer, DefaultDataCollator

# Use the default data collator which handles tensors properly
data_collator = DefaultDataCollator()

training_args = TrainingArguments(
    output_dir="./vit-quickdraw",
    eval_strategy="epoch",  # Changed from evaluation_strategy
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    learning_rate=5e-5,
    logging_steps=10,
    save_strategy="epoch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=processor,
    data_collator=data_collator
)

trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,0.0467,0.011844
2,0.0012,0.017189
3,0.0008,0.01451


TrainOutput(global_step=135, training_loss=0.04382045715012484, metrics={'train_runtime': 1475.4562, 'train_samples_per_second': 1.464, 'train_steps_per_second': 0.091, 'total_flos': 1.6738419776569344e+17, 'train_loss': 0.04382045715012484, 'epoch': 3.0})

7. Evaluate & Use

In [8]:
from transformers import pipeline
clf = pipeline("image-classification", model=trainer.model, feature_extractor=processor)
clf(your_pil_image)

Device set to use cpu


NameError: name 'your_pil_image' is not defined

In [9]:
trainer.save_model("./vit-quickdraw-final")

print("Model saved to './vit-quickdraw-final'")

Model saved to './vit-quickdraw-final'


# Load Model and Use

In [2]:
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import torch
import numpy as np

# Load model and processor
model_path = './vit-quickdraw-final'
model = ViTForImageClassification.from_pretrained(model_path)
processor = ViTImageProcessor.from_pretrained(model_path)
model.eval()

def classify_doodle(image):
    """
    Classify a doodle image
    
    Args:
        image: PIL Image or numpy array or path to image file
    
    Returns:
        tuple: (predicted_class, confidence_score, all_probabilities)
    """
    # Handle different input types
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    
    # Ensure image is RGB
    if image.mode != 'RGB':
        image = image.convert('RGB')
    
    # Process the image
    inputs = processor(images=image, return_tensors="pt")
    
    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_class_id = probabilities.argmax().item()
        confidence = probabilities.max().item()
    
    # Get class labels and probabilities
    predicted_label = model.config.id2label[str(predicted_class_id)]
    all_probs = {model.config.id2label[str(i)]: prob.item() 
                 for i, prob in enumerate(probabilities[0])}
    
    return predicted_label, confidence, all_probs

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
model.parameters()
sum(p.numel() for p in model.parameters() if p.requires_grad)

85800963

In [2]:
test_image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
result = classify_doodle(test_image_path)
print(result)

FileNotFoundError: [Errno 2] No such file or directory: 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg'

In [None]:
from transformers import pipeline, AutoImageProcessor
from PIL import Image

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
your_pil_image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
clf = pipeline("image-classification", model=model, feature_extractor=processor)
clf(your_pil_image)


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
Device set to use cpu


[{'label': 'circle', 'score': 0.40838727355003357},
 {'label': 'line', 'score': 0.3109080195426941},
 {'label': 'square', 'score': 0.2807047367095947}]

In [8]:
# clf("https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2Fcreazilla-store.fra1.digitaloceanspaces.com%2Fcliparts%2F7822516%2Fhand-drawn-circles-clipart-md.png&f=1&nofb=1&ipt=3dde44d393119c19cd53efe09d82655be0eca574b56726fe4148de728d49ee24")
# clf("https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2Fcontent.clipchamp.com%2Fcontent-repo%2Fcontent%2Fpreviews%2Fcc_ea807c5a.png&f=1&nofb=1&ipt=cc8ab3d317a4a685990a4ee9b588014c37391bd11e7c77cb44c718ed16ee4679")
# clf("https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2Fcdn1.vectorstock.com%2Fi%2F1000x1000%2F65%2F35%2Fhand-drawn-circle-line-sketch-set-circular-vector-26936535.jpg&f=1&nofb=1&ipt=6b73329471feaa59e4f85137f0fca30aba8575629952e782fe86703cfe558511")
clf("https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2Fas2.ftcdn.net%2Fv2%2Fjpg%2F05%2F52%2F34%2F03%2F1000_F_552340334_TXgPQTLmEPSlyJ6mZ1S6ixCZpmE4dvpV.jpg&f=1&nofb=1&ipt=655e28b444d76496d73cd20b9d80befefd25b4e11d822fed0197e183d40c3d78")

[{'label': 'line', 'score': 0.9834213852882385},
 {'label': 'circle', 'score': 0.012197574600577354},
 {'label': 'square', 'score': 0.004381043836474419}]