# Packages, Imports, and Setup


In [None]:
from google.colab import drive

# Attach Google Drive inside the Colab runtime at the path the project is read from.
drive_mount_point = "/content/drive"
drive.mount(drive_mount_point)

In [None]:
# Switch the working directory to the project folder on the mounted Drive so that
# relative paths used later (requirements.txt, checkpoints, graphs/...) resolve there.
%cd /content/drive/My Drive/Machine Learning/Sketch2Graphviz

In [None]:
# Install the project's dependencies listed in requirements.txt into the kernel's environment.
%pip install -r requirements.txt

In [None]:
import os
import random
import numpy as np
import torch
from dotenv import load_dotenv
from huggingface_hub import login

from scripts.data import get_graphviz_hf_dataloaders
from scripts.model import Sketch2GraphvizVLM, print_num_params, load_sketch2graph_vlm
from scripts.finetune_lora import finetune_vlm_lora
from scripts.eval import evaluate_vlm
from scripts.inference import predict_graphviz_dot

# Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) random number
# generators with one shared value.
SEED = 42
for seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(SEED)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Load variables from a .env file (when present) into the process environment,
# then read the Hugging Face access token; hf_token is None if HF_TOKEN is unset.
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")

In [None]:
# Log the Hugging Face CLI in with the token held in the Python variable `hf_token`.
# The braces make IPython substitute the variable's value; without them the shell
# receives the literal text "hf_token" as the token.
!huggingface-cli login --token {hf_token}

In [None]:
# Authenticate the huggingface_hub Python client with the token read from HF_TOKEN.
login(token=hf_token)

In [None]:
# Print the GPU(s) and their current memory usage as reported by the NVIDIA driver.
!nvidia-smi

In [None]:
# Batch size handed to the dataloader factory.
# (The author left "4" as a trailing note — presumably an alternative setting.)
batch_size = 16

# Prompt text passed to training, evaluation and inference alongside each image.
instruction = "".join(
    (
        "You are a compiler that converts images of Graphviz diagrams into their exact Graphviz DOT code. ",
        "Given the image, output only the DOT code, starting with either 'digraph' or 'graph', with no explanations, no markdown, and no extra text.\n",
    )
)

# Training and Evaluation


In [None]:
# Build the train / validation / test dataloaders in a single call.
# NOTE(review): the 336x336 image size presumably matches the "patch14-336" CLIP
# checkpoint used for the vision encoder — confirm before changing either one.
dataloader_kwargs = dict(
    batch_size=batch_size,
    root_dir="graphviz_rendered",
    image_size=(336, 336),
)
(
    train_dataloader,
    val_dataloader,
    test_dataloader,
) = get_graphviz_hf_dataloaders(**dataloader_kwargs)

In [None]:
# Instantiate the vision-language model from a CLIP vision checkpoint and a Llama
# checkpoint, requesting 4-bit quantization, then move it to the selected device.
vlm_kwargs = {
    "vit_model_id": "openai/clip-vit-large-patch14-336",
    "llama_model_id": "meta-llama/Llama-3.1-8B",
    "quantization": "4-bit",
    "device": device,
}
model = Sketch2GraphvizVLM(**vlm_kwargs)
model = model.to(device)

# Prepare the language model for fine-tuning, in this order: turn on gradient
# checkpointing, switch off the generation cache, and make the inputs require grads.
model.llama_model.gradient_checkpointing_enable()
model.llama_model.config.use_cache = False
model.llama_model.enable_input_require_grads()

# Report parameter counts for the assembled model.
print_num_params(model)

In [None]:
# --- Fine-tuning hyperparameters -------------------------------------------
num_epochs = 5
lora_rank = 32
weight_decay = 1e-2
max_grad_norm = 1.0

# Separate learning rates, forwarded as lr_vit / lr_lora / lr_proj.
lr_vit = 1e-5
lr_lora = 2e-4
lr_proj = 1e-4

# Everything the fine-tuning routine receives, collected in one place.
finetune_kwargs = dict(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    instruction=instruction,
    rank=lora_rank,
    lr_vit=lr_vit,
    lr_lora=lr_lora,
    lr_proj=lr_proj,
    weight_decay=weight_decay,
    num_epochs=num_epochs,
    max_grad_norm=max_grad_norm,
    model_save_dir="checkpoints",
    device=device,
)

# The call hands back the model together with the train and validation losses.
model, train_losses, val_losses = finetune_vlm_lora(**finetune_kwargs)

print(f"Train Losses: {train_losses}")
print(f"Val Losses: {val_losses}")

In [None]:
# Compute the loss on the test dataloader.
# The checkpoint epoch is tied to num_epochs so it refers to the run configured in
# the training cell; the previous hard-coded 10 exceeded the 5 epochs trained there,
# so it could only point at a missing checkpoint or one left over from an older run.
# NOTE(review): assumes the last checkpoint is labelled with num_epochs — confirm
# against the naming used by finetune_vlm_lora / evaluate_vlm.
test_loss = evaluate_vlm(
    model=model,
    iterator=test_dataloader,
    instruction=instruction,
    description="Testing",
    model_load_dir="checkpoints",
    epoch=num_epochs,
    device=device,
)

print(f"Test Loss: {test_loss:.6f}")

In [None]:
# Optional: uncomment to load weights from the "checkpoints" directory into `model`
# (presumably to reuse a saved run instead of retraining — confirm against
# load_sketch2graph_vlm).
# NOTE(review): epoch_load=10 does not match num_epochs = 5 in the training cell.
# model = load_sketch2graph_vlm(
#     model=model,
#     model_load_dir="checkpoints",
#     epoch_load=10,
#     device=device,
# )

In [None]:
# Run inference on one example image and print the DOT text that comes back.
# NOTE(review): do_sample=True with temperature=1.0 presumably means the output is
# sampled rather than greedy — confirm this is intended for exact-code generation.
generation_kwargs = {
    "max_new_tokens": 1024,
    "do_sample": True,
    "temperature": 1.0,
}
predicted_graphviz_output = predict_graphviz_dot(
    model=model,
    image="graphs/graph_1.png",
    instruction=instruction,
    **generation_kwargs,
    device=device,
)

print(predicted_graphviz_output)