# Installing the dependencies

In [None]:
!pip install unsloth
!pip install transformers accelerate bitsandbytes
!pip install datasets pillow sacrebleu

# Loading the Datasets

In [None]:
from itertools import islice

from datasets import load_dataset, Dataset

# Stream WebSight so only the samples that are kept get downloaded.
db = load_dataset("HuggingFaceM4/WebSight", "v0.2", split="train", streaming=True)

# Keep exactly NUM_SAMPLES items: the first NUM_TRAIN for fine-tuning and the
# remainder for validation. islice stops after NUM_SAMPLES items, so the
# validation split is NUM_SAMPLES - NUM_TRAIN long (no off-by-one extra item).
NUM_SAMPLES = 100
NUM_TRAIN = 80

DB = list(islice(db, NUM_SAMPLES))
sub_db = DB[:NUM_TRAIN]
val_db = DB[NUM_TRAIN:]
print("A smaller Database created ")

In [None]:
# Report how many samples ended up in the training and validation splits.
print(len(sub_db), len(val_db))

In [None]:
from datasets import Features, Value, Image

# Column schema for the training subset: one image column and two string columns.
features = Features(
    {
        "image": Image(),
        "text": Value(dtype="string"),
        "llm_generated_idea": Value(dtype="string"),
    }
)

# Build the training Dataset from the in-memory samples using that schema.
db_small = Dataset.from_list(sub_db, features=features)

print("A smaller Database created ")

# Fine Tuning Flow

In [None]:
# Instruction text placed in the user turn of every training conversation.
prompt = "Generate the HTML code for this webpage."

In [None]:
# Display the first training sample to inspect its fields.
db_small[0]

In [None]:
# Initialised empty here; a later cell reassigns it with the converted samples.
converted_dataset = []

In [None]:
def convert_to_conversation(sample):
    """Wrap one dataset sample in a two-turn chat conversation.

    The user turn carries the module-level ``prompt`` text followed by the
    sample's image; the assistant turn carries the sample's ``"text"`` value.

    Args:
        sample: Mapping with ``"image"`` and ``"text"`` entries.

    Returns:
        A dict of the form ``{"messages": [user_turn, assistant_turn]}``.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": sample["text"]},
        ],
    }
    return {"messages": [user_turn, assistant_turn]}

In [None]:
# Convert every training sample into the chat-message format, in order.
converted_dataset = list(map(convert_to_conversation, db_small))

In [None]:
# Show the first converted conversation as a sanity check.
print(converted_dataset[0])

In [None]:
from unsloth import FastVisionModel

# Load the LLaVA-1.5 7B checkpoint with 4-bit loading enabled and Unsloth's
# gradient-checkpointing mode selected.
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/llava-1.5-7b-hf",
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
)

In [None]:
# Which parts of the model take part in fine-tuning.
finetune_targets = dict(
    finetune_vision_layers=True,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
)

# Adapter hyperparameters passed through to get_peft_model.
adapter_settings = dict(
    r=8,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

model = FastVisionModel.get_peft_model(model, **finetune_targets, **adapter_settings)

In [None]:
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
from unsloth import is_bf16_supported


# Switch the model into training mode before building the trainer.
FastVisionModel.for_training(model)

# Pick the mixed-precision dtype once: bf16 where supported, fp16 otherwise.
use_bf16 = is_bf16_supported()

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer),
    train_dataset = converted_dataset,
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 10,
        # A positive max_steps takes precedence over num_train_epochs, so this
        # run stops after 30 optimizer steps regardless of the epoch setting.
        max_steps = 30,
        num_train_epochs = 1,
        learning_rate = 2e-4,
        fp16 = not use_bf16,
        bf16 = use_bf16,
        # Log every step: with only 30 steps in total, a 200-step logging
        # interval would never report the training loss.
        logging_steps = 1,
        save_strategy='steps',
        # NOTE(review): save_steps exceeds max_steps, so no periodic checkpoint
        # is reached during this run -- confirm that is intended.
        save_steps=200,
        save_total_limit=2,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",

        # The vision collator consumes the raw "messages" column itself, so the
        # trainer must neither drop columns nor pre-tokenize the dataset.
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        dataset_num_proc = 4,
        max_length = 1024,
    ),
)

In [None]:
# Run fine-tuning; the value returned by train() is kept in trainer_stats.
trainer_stats = trainer.train()

# Inference

In [None]:
# val_db[0]["text"]

In [None]:
from transformers import TextStreamer

FastVisionModel.for_inference(model)  # Enable for inference

# First validation screenshot plus the instruction text for the user turn.
image = val_db[0]["image"]
instruction = "Generate the HTML code for this webpage."

# Single user turn: an image placeholder followed by the text instruction.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": instruction},
        ],
    }
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# Encode image and templated prompt together and move the tensors to the GPU.
inputs = tokenizer(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to("cuda")

# Stream decoded tokens as they are generated, leaving out the prompt itself.
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
_ = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=128,
    use_cache=True,
    temperature=1.5,
    min_p=0.1,
)

In [None]:
# Display the first validation sample (the one used for single-image inference).
val_db[0]

In [None]:
from tqdm import tqdm
import torch
import sacrebleu

preds = []
refs = []

# The prompt is identical for every validation sample, so build it once
# instead of re-rendering the chat template on each iteration.
instruction = "Generate the HTML code for this webpage."
messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": instruction}
    ]}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

for sample in tqdm(val_db):
    # Reference HTML for this screenshot.
    refs.append(sample["text"].strip())

    # add_special_tokens=False keeps tokenization consistent with the
    # single-sample inference cell, which passes the same flag after applying
    # the chat template.
    inputs = tokenizer(
        sample["image"],
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to("cuda")

    # Generate
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=1024)

    # generate() returns the prompt tokens followed by the continuation. Decode
    # only the newly generated part so the prompt text does not leak into the
    # predictions and distort BLEU / token accuracy.
    prompt_len = inputs["input_ids"].shape[1]
    pred = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
    preds.append(pred)


In [None]:
# corpus_bleu takes a list of reference streams; there is a single reference
# per prediction here, hence the one-element wrapper around refs.
reference_streams = [refs]
bleu = sacrebleu.corpus_bleu(preds, reference_streams)
print("BLEU Score:", bleu.score)

In [None]:
def token_accuracy(pred, ref, tokenize=None):
    """Return the fraction of reference tokens matched position by position.

    Both strings are tokenized, tokens are compared index by index over the
    overlapping prefix, and the number of matches is divided by the number of
    reference tokens.

    Args:
        pred: Predicted text.
        ref: Reference text.
        tokenize: Optional callable mapping a string to a list of tokens.
            Defaults to the ``tokenize`` method of the module-level
            ``tokenizer``.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when the reference has no tokens.
    """
    if tokenize is None:
        # The module-level object may be a processor wrapping the text
        # tokenizer; use the wrapped tokenizer when one is present.
        text_tokenizer = getattr(tokenizer, "tokenizer", tokenizer)
        tokenize = text_tokenizer.tokenize

    pred_tok = tokenize(pred)
    ref_tok = tokenize(ref)

    # An empty reference would otherwise cause a division by zero.
    if not ref_tok:
        return 0.0

    # zip stops at the shorter sequence, i.e. the overlapping prefix.
    correct = sum(p == r for p, r in zip(pred_tok, ref_tok))
    return correct / len(ref_tok)

# Score each prediction against its reference, then report the mean.
token_accs = [token_accuracy(hyp, gold) for hyp, gold in zip(preds, refs)]
mean_token_accuracy = sum(token_accs) / len(token_accs)
print("Token-Level Accuracy:", mean_token_accuracy)