# Packages, Imports, and Setup


In [None]:
# Mount Google Drive at /content/drive (Colab-only: google.colab is not available elsewhere)
# so the project folder used by the next cell is reachable.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
# Work from the project root on Drive so relative paths used below (requirements.txt, scripts/, checkpoints) resolve.
%cd /content/drive/My Drive/Machine Learning/Sketch2Graphviz

In [None]:
# Install the project's dependencies into the running kernel's environment (%pip, not !pip, targets this kernel).
%pip install -r requirements.txt

In [None]:
import os
import random

import numpy as np
import torch
from dotenv import load_dotenv
from huggingface_hub import login, notebook_login
from torch.utils.data import DataLoader, RandomSampler, random_split

from scripts.data import get_graphviz_hf_dataloaders, get_json_graphviz_json_dataloaders
from scripts.model import (
    Sketch2GraphvizVLM,
    print_num_params,
    save_sketch2graphviz_vlm_local,
    load_sketch2graphviz_vlm_local,
    save_sketch2graphviz_vlm_hf,
    load_sketch2graphviz_vlm_hf,
)
from scripts.finetune_lora import finetune_vlm_lora
from scripts.eval import evaluate_vlm
from scripts.inference import predict_graphviz_dot

# Seed every RNG the pipeline may touch: Python, NumPy, PyTorch CPU, and all CUDA devices.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)

# Load variables from a .env file (if one is found) into the process environment.
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")  # None when HF_TOKEN is not set

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Disabled shell alternative to the login() cell below. As written it would pass the literal
# string "hf_token", not the variable's value (IPython would need `{hf_token}` to interpolate it).
# Passing a token on a command line can also expose it in shell history / process listings.
# !huggingface-cli login --token hf_token

In [None]:
# Authenticate with the Hugging Face Hub using the HF_TOKEN value read in the setup cell
# (presumably required to download gated models such as meta-llama/Llama-3.1-8B-Instruct).
# hf_token is None if HF_TOKEN was not set; notebook_login() is the interactive alternative.
login(token=hf_token)
# notebook_login()

In [None]:
# Show the GPU(s) available to this runtime and their current memory usage.
!nvidia-smi

In [None]:
# Batch size passed to the dataloader factory below.
batch_size = 1

# Instruction prompt given to the model for every sample. The same text is passed to
# fine-tuning, evaluation, and inference, so editing it changes what the model is trained
# on: checkpoints trained with an earlier version of this text should be retrained (or
# evaluated/queried with the exact text they were trained with).
instruction = """
## 1. Role Definition

You are an expert compiler that converts images of Graphviz diagrams into their exact Graphviz DOT code.
Given an image of a graph, using only the image, output only the DOT code, starting with either 'digraph' or 'graph', with no explanations, no markdown, and no extra text.
Graphviz DOT is a plain-text language for describing graphs as nodes and edges with optional attributes such as labels, shapes, colors, and styles, for both directed ('digraph') and undirected ('graph') diagrams.

## 2. Core Syntax Rules (The "Grammar")

Follow these strict syntax constraints:

1.  **Graph Type:**
      * Use `digraph` (Directed Graph) for hierarchies, flows, or dependencies. Use `->` for edges.
      * Use `graph` (Undirected Graph) for physical networks or mutual connections. Use `--` for edges.
2.  **Identifiers:**
      * Alphanumeric strings (e.g., `A`, `node1`) do not need quotes.
      * Strings with spaces, special characters, or reserved keywords MUST be enclosed in double quotes (e.g., `"User Login"`, `"Data-Base"`).
3.  **Statement Termination:** End all node, edge, and attribute statements with a semicolon `;`.
4.  **Scope:** All code must be enclosed within braces `{ ... }`.

## 3. Attribute Dictionary (The "Vocabulary")

Apply attributes using brackets `[key=value]`. If multiple attributes are needed, comma-separate them or use spaces: `[shape=box, color=red]`.

### Node Attributes

  * **`shape`**:
      * Process/Step: `box`
      * Start/End: `ellipse` or `oval`
      * Decision: `diamond`
      * Database: `cylinder`
      * Code/Structure: `record` (use `|` to separate fields in label)
  * **`style`**: `filled`, `rounded`, `dotted`, `invis`
  * **`fillcolor`**: Hex codes (`#FF0000`) or common names (`lightblue`). Only visible if `style=filled`.
  * **`label`**: The visible text. If omitted, the identifier is used.

### Edge Attributes

  * **`label`**: Text displayed along the line.
  * **`style`**: `solid` (default), `dashed` (future/theoretical), `dotted`.
  * **`dir`**: `forward` (default in a `digraph`), `back`, `both`, `none` (default in a `graph`).
  * **`color`**: Edge color.

## 4. Structural Logic

  * **Clusters:** To group nodes visually (draw a box around them), the subgraph name **MUST** start with the prefix `cluster_` (e.g., `subgraph cluster_backend { ... }`).
  * **Rank:** To force nodes onto the same rank (the same horizontal level in a top-to-bottom layout), use `{ rank=same; NodeA; NodeB; }`.
  * **Direction:** Default flow is Top-to-Bottom. For Left-to-Right flow, add `rankdir=LR;` at the top of the graph scope.

## 5. HTML Labels (Advanced Formatting)

If needed for tables or complex text formatting:

  * Do **not** wrap the label in quotes `"..."`.
  * Wrap the label in angle brackets `<...>` and use HTML-like syntax inside them.
  * Example: `label=< <B>Bold Text</B><BR/>Subtitle >`

## 6. Output Protocol

  * Output **only** the raw DOT code, without markdown code fences.
  * Do not include any explanations.
  * Ensure all braces `{}` are balanced.

"""

# Training and Evaluation


In [None]:
# Disabled alternative loader (presumably Hugging Face-backed, judging by its name); it
# returns separate train / val / test loaders.
# train_dataloader, val_dataloader, test_dataloader = get_graphviz_hf_dataloaders(
#     batch_size=batch_size,
#     root_dir="graphviz_rendered",
#     image_size=(336, 336),  # (672, 672), (1008, 1008)
# )

# Active loader: samples described by a JSON file; only train and test loaders are returned.
# image_size=(336, 336) presumably matches the 336px CLIP checkpoint used by the model cell.
json_loader_kwargs = dict(
    json_path="simple_synthetic_data_gen.json",
    batch_size=batch_size,
    root_dir="graphviz_rendered_json",
    image_size=(336, 336),  # (672, 672), (1008, 1008), None
)
train_dataloader, test_dataloader = get_json_graphviz_json_dataloaders(**json_loader_kwargs)

In [None]:
# Vision-language model built from a CLIP ViT image encoder and a 4-bit quantized Llama model.
model = Sketch2GraphvizVLM(
    vit_model_id="openai/clip-vit-large-patch14-336",
    llama_model_id="meta-llama/Llama-3.1-8B-Instruct",
    quantization="4-bit",
    tile_images=False,  # True
    device=device,
)
model = model.to(device)

# Memory-saving configuration of the language model for fine-tuning.
llama_lm = model.llama_model
llama_lm.gradient_checkpointing_enable()  # recompute activations during backward instead of storing them
llama_lm.config.use_cache = False  # the KV cache is not used with gradient checkpointing during training
llama_lm.enable_input_require_grads()  # input embeddings require grad so adapter training works with frozen weights

print_num_params(model)

In [None]:
# LoRA / optimizer hyperparameters; trailing comments are alternative settings.
lora_rank = 32
lora_dropout = 0.1

lr_vit = 1e-5  # 5e-6
lr_proj = 1e-4  # 5e-5
lr_lora = 2e-4  # 1e-4

weight_decay_vit = 1e-2  # 5e-2
weight_decay_proj = 1e-2  # 5e-2
weight_decay_lora = 1e-2  # 0.0

max_grad_norm = 1.0

num_epochs = 10

# Fraction of the training data held out for early stopping / checkpoint selection.
val_fraction = 0.1

# Early stopping must not look at the test set. Previously test_dataloader was passed as
# val_dataloader, so the "Test Loss" reported later was measured on the same data used to
# select the model and was optimistically biased. Instead, carve a seeded validation split
# out of the training data and leave test_dataloader untouched until final evaluation.
# Assumes train_dataloader wraps a map-style (sized, indexable) dataset — TODO confirm.
# Note: val_subset shares whatever transforms the training dataset applies.
train_dataset_full = train_dataloader.dataset
num_val = max(1, round(len(train_dataset_full) * val_fraction))
num_train = len(train_dataset_full) - num_val
assert num_train > 0, "Training set too small to hold out a validation split"
train_subset, val_subset = random_split(
    train_dataset_full,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(SEED),
)

# Rebuild loaders with the original collation so batches have the same structure;
# shuffle the training split only if the original training loader shuffled.
split_loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=train_dataloader.collate_fn,
    num_workers=train_dataloader.num_workers,
    pin_memory=train_dataloader.pin_memory,
)
train_split_dataloader = DataLoader(
    train_subset,
    shuffle=isinstance(train_dataloader.sampler, RandomSampler),
    **split_loader_kwargs,
)
val_dataloader = DataLoader(val_subset, shuffle=False, **split_loader_kwargs)

model, train_losses, val_losses = finetune_vlm_lora(
    model=model,
    train_dataloader=train_split_dataloader,
    val_dataloader=val_dataloader,
    instruction=instruction,
    rank=lora_rank,
    lora_dropout=lora_dropout,
    lr_vit=lr_vit,
    lr_proj=lr_proj,
    lr_lora=lr_lora,
    weight_decay_vit=weight_decay_vit,
    weight_decay_proj=weight_decay_proj,
    weight_decay_lora=weight_decay_lora,
    num_epochs=num_epochs,
    use_val_early_stopping=True,
    max_grad_norm=max_grad_norm,
    model_save_dir="checkpoints",
    device=device,
)

print(f"Train Losses: {train_losses}")
print(f"Val Losses: {val_losses}")

In [None]:
# Evaluate on the test loader. evaluate_vlm also receives checkpoint arguments
# (model_load_dir / epoch_load); NOTE(review): confirm in scripts/eval.py which
# checkpoint epoch_load=None selects.
test_eval_kwargs = dict(
    model=model,
    iterator=test_dataloader,
    instruction=instruction,
    description="Testing",
    model_load_dir="checkpoints",
    epoch_load=None,
    device=device,
)
test_loss = evaluate_vlm(**test_eval_kwargs)

print("Test Loss: {:.6f}".format(test_loss))

In [None]:
# Save and Load from HuggingFace (disabled — uncomment the call you need).
# NOTE(review): this loads epoch 10 while the local save/load cell uses epoch 6; confirm which checkpoint is intended.

# save_sketch2graphviz_vlm_hf(model=model, huggingface_username="rishaba")

# model = load_sketch2graphviz_vlm_hf(
#     model_load_dir="checkpoints",
#     epoch_load=10,
#     huggingface_username="rishaba",
#     device=device,
# )

In [None]:
# Save and Load from Local "checkpoints" directory (disabled — uncomment the call you need).

# save_sketch2graphviz_vlm_local(model=model, model_save_dir="checkpoints", epoch_save=6)

# model = load_sketch2graphviz_vlm_local(
#     model_load_dir="checkpoints",
#     epoch_load=6,
#     device=device,
# )

In [None]:
# Transcribe one sample image into DOT code.
# Greedy decoding (do_sample=False) is used because the task asks for an *exact* transcription:
# sampling at temperature 1.0 made the output vary between runs and admitted low-probability
# tokens into code that must parse. Set do_sample=True to explore alternative outputs.
predicted_graphviz_output = predict_graphviz_dot(
    model=model,
    image="testing_graphs/graph_1.png",
    instruction=instruction,
    max_new_tokens=1024,
    do_sample=False,
    temperature=1.0,  # presumably only relevant when do_sample=True — confirm in scripts/inference.py
    skip_special_tokens=True,
    device=device,
)

# print() (not the repr) so the DOT code's newlines render as line breaks.
print(predicted_graphviz_output)