In [None]:
import torch
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from transformers import AutoTokenizer

from finetuning_dataset import QuickDrawFineTuningDataset

# Load the GPT-2 tokenizer and use its EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Project-local dataset restricted to the "apple" and "cat" labels.
# NOTE(review): presumably yields tokenized sketch sequences capped at
# max_length tokens -- confirm against finetuning_dataset.py.
dataset = QuickDrawFineTuningDataset(labels=["apple", "cat"], tokenizer=tokenizer, max_length=1024)
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Only request fp16 mixed precision when a CUDA device is present. The
# generation cell already falls back to CPU, so a hard-coded fp16=True made
# this cell the one place that could not run without a GPU.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=5,
    per_device_train_batch_size=8,
    save_steps=2000,        # checkpoint interval, in optimizer steps
    save_total_limit=200,   # upper bound on checkpoints kept in output_dir
    logging_dir="./logs",
    logging_steps=100,      # the loss table in the output has one row per 100 steps
    fp16=use_fp16,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

Downloading QuickDraw files: 100%|██████████| 2/2 [00:00<00:00, 7717.21it/s]
Loading QuickDraw files: 100%|██████████| 2/2 [00:05<00:00,  2.79s/it]


Loaded tokenized data from sketch_tokenized_dataset_test_gpt2.pkl


Step,Training Loss
100,1.0268
200,0.2781
300,0.2644
400,0.252
500,0.2481
600,0.2513
700,0.2491
800,0.2436
900,0.2386
1000,0.2332


KeyboardInterrupt: 

In [35]:
import html

import torch
from IPython.display import display, HTML

# Sample one sketch from the fine-tuned model and render it as an inline SVG.
prompt = "cat sketch"
inputs = tokenizer(prompt, return_tensors="pt")

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
inputs = {k: v.to(device) for k, v in inputs.items()}

# The model comes straight from trainer.train() (interrupted or not), which
# leaves it in training mode; generate() does not switch modes itself.
# Without this call dropout would stay active while sampling.
model.eval()

# Generate text
outputs = model.generate(
    **inputs,
    max_length=1024,
    num_return_sequences=1,
    do_sample=True,          # random sampling
    top_k=50,                # limits sampling pool
    top_p=0.95,              # nucleus sampling
    temperature=0.8,         # controls randomness
    pad_token_id=tokenizer.eos_token_id,
)

# Decode and print
out = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(out)

out = out.split(prompt, 1)[-1]  # remove prompt

# The path data is raw model output and ends up inside an HTML attribute.
# Escape it so a stray quote or angle bracket cannot break out of the
# attribute; well-formed path data (letters, digits, commas, spaces) is
# left unchanged by the escaping.
safe_path = html.escape(out, quote=True)

out_svg=f"""<svg viewBox="0 0 64 64"><g stroke-width="0.8">
<path d="{safe_path}" stroke="black" fill="none"/>
</g></svg>"""

print(out_svg)

display(HTML(f"""<div style="display:inline-block; width: 150px; background-color: white; margin-right:10px;"><b>char len: {len(out_svg)}</b><br>{out_svg}</div>"""))

cat sketch M25,13C14,0 4,8 3,19C0,32 5,52 21,57C25,57 27,62 29,62M27,12C22,12 19,9 13,13C9,16 10,19 7,22C7,26 14,27 17,28C20,30 31,25 32,27C33,25 33,24 34,23C37,23 37,21 38,20M35,13C35,10 33,5 33,3C31,0 33,8 33,12M26,15C27,14 29,14 29,15M43,15C39,15 35,16 36,17M31,16C34,16 36,17 41,16M31,17C32,19 37,17 39,17M30,17C34,17 38,16 40,15M31,17C32,19 32,20 32,25C31,24 32,24 31,24M29,26C27,29 26,30 25,29M25,26C26,27 28,28 31,28M25,27C25,28 26,30 27,30M29,25C29,25 28,28 28,28M28,26C32,28 37,29 37,29
<svg viewBox="0 0 64 64"><g stroke-width="0.8">
<path d=" M25,13C14,0 4,8 3,19C0,32 5,52 21,57C25,57 27,62 29,62M27,12C22,12 19,9 13,13C9,16 10,19 7,22C7,26 14,27 17,28C20,30 31,25 32,27C33,25 33,24 34,23C37,23 37,21 38,20M35,13C35,10 33,5 33,3C31,0 33,8 33,12M26,15C27,14 29,14 29,15M43,15C39,15 35,16 36,17M31,16C34,16 36,17 41,16M31,17C32,19 37,17 39,17M30,17C34,17 38,16 40,15M31,17C32,19 32,20 32,25C31,24 32,24 31,24M29,26C27,29 26,30 25,29M25,26C26,27 28,28 31,28M25,27C25,28 26,30 27,30M29,25C29,