# Packages, Imports, and Setup


In [None]:
# Mount Google Drive at /content/drive so the project folder stored there is
# reachable from this Colab runtime.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
# Work from the project folder on Drive; later cells use paths relative to it
# (requirements.txt, checkpoints, the dataset JSON, ...).
%cd /content/drive/My Drive/Machine Learning/Sketch2Graphviz

In [None]:
# Install the dependencies listed in the project's requirements.txt.
%pip install -r requirements.txt

In [None]:
import os
import random
import numpy as np
import torch
from dotenv import load_dotenv
from huggingface_hub import login
import matplotlib.pyplot as plt

from scripts.data import get_json_graphviz_json_dataloaders
from scripts.model import (
    Sketch2GraphvizVLM,
    save_sketch2graphviz_vlm,
    load_sketch2graphviz_vlm,
    print_num_params,
)
from scripts.finetune_lora import finetune_vlm_lora
from scripts.eval import evaluate_vlm, generate_vlm_outputs, evaluate_vlm_outputs
from scripts.inference import predict_graphviz_dot_from_image
from scripts.embeddings import get_graphviz_image_embeddings
from scripts.psql_vector_db import (
    store_embeddings_in_db,
    get_top_k_similar_vectors_from_db,
)
from scripts.prompts import graphviz_code_from_image_instruction


# One fixed seed for every RNG in play: PyTorch (CPU and all CUDA devices),
# NumPy and Python's own `random`.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Relax float32 matmul precision and allow TF32 for CUDA matmuls.
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True

# Batch size handed to the dataloaders.
batch_size = 1

# Load variables from a .env file (if any), then read the Hugging Face token.
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")

# Use the GPU when one is available, otherwise the CPU; displayed as cell output.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Authenticate with the Hugging Face Hub using the HF_TOKEN value read during setup.
# NOTE(review): hf_token is None when HF_TOKEN is unset — confirm how login() behaves then.
login(token=hf_token)

In [None]:
# Report the GPU(s) attached to this runtime.
!nvidia-smi

# Loading Data


In [None]:
# Build the train and test dataloaders from the synthetic Graphviz dataset.
# NOTE(review): return_tensor=False presumably yields image objects rather than
# tensors (the next cell converts one with np.array) — confirm in scripts.data.
# NOTE(review): handdrawn_probability=0.30 presumably is the chance a sample is
# rendered in a hand-drawn style — confirm in scripts.data.
train_dataloader, test_dataloader = get_json_graphviz_json_dataloaders(
    json_path="synthetic_data_gen.json",
    batch_size=batch_size,
    root_dir="graphviz_rendered_json",
    image_size=(768, 768),
    return_tensor=False,
    handdrawn_probability=0.30,
)

In [None]:
# Sanity check: pull one training batch, show its first image and print the
# Graphviz source paired with it.
train_sample = next(iter(train_dataloader))

plt.imshow(np.array(train_sample["images"][0]))
plt.axis(False)
plt.show()

# The subscript uses single quotes: reusing the enclosing double quote inside an
# f-string is a SyntaxError on Python versions before 3.12.
print(f"Graphviz Code: {train_sample['graphviz_code'][0]}")

# Loading Sketch2Graphviz Model


In [None]:
# Fine-tune the model in 16-bit and do inferencing with 4-bit quantization

# Instantiate the Sketch2Graphviz model on top of Llama 3.2 11B Vision-Instruct,
# loaded in 16-bit for fine-tuning.
model = Sketch2GraphvizVLM(
    llama_model_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
    quantization="16-bit",
    device=device,
)

# For any non-16-bit load: enable gradient checkpointing, set use_cache to False
# and enable input gradients on the underlying Llama model.
# NOTE(review): with quantization="16-bit" above this branch is skipped, so none
# of these are applied to the 16-bit fine-tune — confirm that is intended.
if model.quantization != "16-bit":
    model.llama_model.gradient_checkpointing_enable()
    model.llama_model.config.use_cache = False
    model.llama_model.enable_input_require_grads()

print_num_params(model)

# LoRA Fine-Tuning and Evaluation


In [None]:
# LoRA adapter settings passed to finetune_vlm_lora.
lora_rank = 64
lora_dropout = 0.1

# Effective batch size = batch_size * grad_accumulation_steps = 16
grad_accumulation_steps = 16

# Optimisation and early-stopping settings passed to finetune_vlm_lora.
lr = 1e-4  # 2e-4
weight_decay = 1e-2  # 1e-3
warmup_ratio = 0.1
early_stopping_patience = 2
max_grad_norm = 1.0

num_epochs = 12

# Run LoRA fine-tuning. `model` is rebound to the returned model, and the two
# loss histories are printed below. Checkpoints go to the "checkpoints" directory.
# NOTE(review): the test dataloader is also used as the validation set, so early
# stopping is driven by the same data that is later reported as test results.
model, train_losses, val_losses = finetune_vlm_lora(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    instruction=graphviz_code_from_image_instruction,
    rank=lora_rank,
    lora_dropout=lora_dropout,
    grad_accumulation_steps=grad_accumulation_steps,
    lr=lr,
    weight_decay=weight_decay,
    warmup_ratio=warmup_ratio,
    num_epochs=num_epochs,
    use_val_early_stopping=True,
    early_stopping_patience=early_stopping_patience,
    max_grad_norm=max_grad_norm,
    model_save_dir="checkpoints",
    device=device,
)

print(f"Train Losses: {train_losses}")
print(f"Val Losses: {val_losses}")

In [None]:
# Optional: uncomment to evaluate a saved checkpoint instead of the in-memory model.
# model = load_sketch2graphviz_vlm(
#     model_load_dir="checkpoints",
#     epoch_load=1,
#     quantization="16-bit",
#     is_training=False,
#     device=device,
# )

# Compute the loss of the current `model` over the test dataloader.
test_loss = evaluate_vlm(
    model=model,
    iterator=test_dataloader,
    instruction=graphviz_code_from_image_instruction,
    description="Testing",
    device=device,
)

print(f"Test Loss: {test_loss:.6f}")

In [None]:
# Generate model outputs for the test dataloader with retrieval disabled and save
# them to testing_outputs_no_rag.jsonl, which the next cell evaluates.
# NOTE(review): top_K_rag is presumably ignored while use_rag=False, and
# temperature presumably has no effect while do_sample=False — confirm in scripts.eval.
testing_vlm_outputs_no_rag = generate_vlm_outputs(
    model=model,
    iterator=test_dataloader,
    instruction=graphviz_code_from_image_instruction,
    use_rag=False,
    top_K_rag=5,
    max_new_tokens=1024,
    do_sample=False,
    temperature=1.0,
    skip_special_tokens=True,
    description="Testing",
    outputs_save_path="testing_outputs_no_rag.jsonl",
    device=device,
)

In [None]:
# Evaluate the no-RAG generations saved in testing_outputs_no_rag.jsonl; the
# result is displayed as the cell output.
evaluation_results_no_rag = evaluate_vlm_outputs(
    description="Evaluating outputs",  # label typo fixed (was "otuputs")
    outputs_load_path="testing_outputs_no_rag.jsonl",
)

evaluation_results_no_rag

In [None]:
# Generate model outputs for the test dataloader with retrieval enabled (top 5)
# and save them to testing_outputs_rag.jsonl, which the next cell evaluates.
# NOTE(review): use_rag=True presumably queries the pgvector table that is only
# created and populated in the "PostgreSQL + PGVector Database RAG" section
# further down, so on a fresh Restart & Run All this cell runs before the
# database exists — run that section first, and confirm in scripts.eval.
testing_vlm_outputs_rag = generate_vlm_outputs(
    model=model,
    iterator=test_dataloader,
    instruction=graphviz_code_from_image_instruction,
    use_rag=True,
    top_K_rag=5,
    max_new_tokens=1024,
    do_sample=False,
    temperature=1.0,
    skip_special_tokens=True,
    description="Testing",
    outputs_save_path="testing_outputs_rag.jsonl",
    device=device,
)

In [None]:
# Evaluate the RAG generations saved in testing_outputs_rag.jsonl; the result is
# displayed as the cell output.
evaluation_results_rag = evaluate_vlm_outputs(
    description="Evaluating outputs",  # label typo fixed (was "otuputs")
    outputs_load_path="testing_outputs_rag.jsonl",
)

evaluation_results_rag

In [None]:
# best_epoch = 2

# save_sketch2graphviz_vlm(
#     model=model, model_save_dir="checkpoints", epoch_save=best_epoch
# )

# model = load_sketch2graphviz_vlm(
#     model_load_dir="checkpoints",
#     epoch_load=best_epoch,
#     quantization="16-bit",
#     is_training=False,
#     device=device,
# )

# PostgreSQL + PGVector Database RAG


In [None]:
# Register the PostgreSQL apt repository: refresh the package lists, install the
# tools needed to fetch and convert the signing key, store the de-armored key
# under /etc/apt/trusted.gpg.d, then add the repository entry for this release.
!apt-get update -qq
!apt-get install -y -qq ca-certificates curl gnupg2 lsb-release > /dev/null
!curl -fsSL https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor -o /etc/apt/trusted.gpg.d/postgresql.gpg
!echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list

In [None]:
# Install PostgreSQL 18
# (package lists are refreshed first so the repository added above is picked up)
!apt-get update -qq
!apt-get install -y -qq postgresql-18 postgresql-contrib-18 > /dev/null

# Start the server and print the client version as a quick check.
!service postgresql start
!psql --version

In [None]:
# Install PGVector
!sudo apt update && sudo apt install postgresql-18-pgvector

In [None]:
# Create a superuser role named "root" and the sketch2graphvizdb database.
# NOTE(review): the role presumably exists so that later connections made with
# user=None authenticate as the OS user running the notebook — confirm.
# NOTE(review): neither statement is guarded, so re-running this cell presumably
# reports "already exists" errors.
!sudo -u postgres psql -c "CREATE USER root WITH SUPERUSER"
!sudo -u postgres psql -c "CREATE DATABASE sketch2graphvizdb"

In [None]:
# Install colab-xterm, the extension loaded in the next cell.
%pip install colab-xterm

In [None]:
load_ext colabxterm

In [None]:
# Open a terminal inside the notebook; use it for the manual psql setup commands
# listed in the next cell.
%xterm

In [None]:
# Setup commands to run in xterm/cli:

# sudo -u postgres psql
# \c sketch2graphvizdb
# CREATE EXTENSION IF NOT EXISTS vector;
# \dx

In [None]:
# Compute image embeddings for the training dataloader with the current model.
# The result is indexed below as a sequence of (graphviz_code, embedding_vector)
# pairs.
graphviz_codes_and_embedding_vectors = get_graphviz_image_embeddings(
    model=model,
    dataloader=train_dataloader,
    device=device,
)

In [None]:
# Store the embeddings in the graphviz_embeddings table of sketch2graphvizdb.
# embedding_dim is read from the first vector's length, so the list must be
# non-empty.
# NOTE(review): user=None presumably falls back to the default database user —
# confirm in scripts.psql_vector_db.
store_embeddings_in_db(
    embedding_data=graphviz_codes_and_embedding_vectors,
    dbname="sketch2graphvizdb",
    user=None,
    table_name="graphviz_embeddings",
    embedding_dim=graphviz_codes_and_embedding_vectors[0][1].shape[0],  # 4096
)

In [None]:
# Sanity-check retrieval: query the database with the first stored embedding and
# print the five results returned for it.
code, query_vector = graphviz_codes_and_embedding_vectors[0]

vector_similarity_results = get_top_k_similar_vectors_from_db(
    embedding_vector=query_vector,
    top_K=5,
    dbname="sketch2graphvizdb",
    user=None,
    table_name="graphviz_embeddings",
)

# Each result is an (id, graphviz_code, distance) triple. It is unpacked into
# row_id rather than `id` so the builtin id() is not rebound for the rest of the
# notebook session.
for row_id, graphviz_code, embedding_distance in vector_similarity_results:
    print(f"ID: {row_id}")
    print(f"Graphviz Code: {graphviz_code}")
    print(f"Embedding Distance: {embedding_distance}")
    print("\n")

In [None]:
# Save sketch2graphvizdb data
# Dumps the database to sketch2graphvizdb.sql in the current working directory.
!pg_dump sketch2graphvizdb > sketch2graphvizdb.sql

In [None]:
# Load sketch2graphvizdb data
# Replays sketch2graphvizdb.sql into the existing sketch2graphvizdb database.
# NOTE(review): presumably meant for a fresh runtime in place of recomputing the
# embeddings; replaying it over an already populated database will likely report
# "already exists" errors — confirm.
!psql -d sketch2graphvizdb -f sketch2graphvizdb.sql

# Inference


In [None]:
# Load and inference with 4-bit quantization

# Rebinds `model` to the epoch-2 checkpoint from "checkpoints", loaded in 4-bit
# with is_training=False. The epoch number is hard-coded, so that checkpoint
# must exist.
model = load_sketch2graphviz_vlm(
    model_load_dir="checkpoints",
    epoch_load=2,
    quantization="4-bit",  # 16-bit
    is_training=False,
    device=device,
)

In [None]:
# Without RAG

# Predict Graphviz DOT code for a single test image with retrieval disabled, then
# print it.
# NOTE(review): temperature=0.3 presumably has no effect while do_sample=False —
# confirm in scripts.inference.
# NOTE(review): skip_special_tokens=False presumably leaves special tokens in the
# printed text — confirm.
predicted_graphviz_output = predict_graphviz_dot_from_image(
    model=model,
    image="testing_graphs/graph_6.png",
    instruction=graphviz_code_from_image_instruction,
    should_print_instruction=False,
    use_rag=False,
    top_K_rag=5,
    max_new_tokens=2048,
    do_sample=False,
    temperature=0.3,
    skip_special_tokens=False,
    device=device,
)

print(predicted_graphviz_output)

In [None]:
# With RAG

# Same image and settings as the no-RAG cell, but with retrieval enabled (top 5);
# `predicted_graphviz_output` is overwritten with the new prediction.
# NOTE(review): presumably requires the graphviz_embeddings table in
# sketch2graphvizdb to be populated — confirm in scripts.inference.
predicted_graphviz_output = predict_graphviz_dot_from_image(
    model=model,
    image="testing_graphs/graph_6.png",
    instruction=graphviz_code_from_image_instruction,
    should_print_instruction=False,
    use_rag=True,
    top_K_rag=5,
    max_new_tokens=2048,
    do_sample=False,
    temperature=0.3,
    skip_special_tokens=False,
    device=device,
)

print(predicted_graphviz_output)