In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, GPT2Tokenizer
from datasets import load_dataset
from tqdm import tqdm
from model import ImageToHTMLModel

In [None]:
def preprocess_data(example, processor, tokenizer, max_length=512):
    """Convert one dataset row into the tensors the model is trained on.

    Args:
        example: A single dataset row; must provide ``example['image']`` and
            ``example['html']``.
        processor: Image processor (e.g. ``CLIPProcessor``), called as
            ``processor(images=..., return_tensors="pt")``; its result must
            expose ``.pixel_values``.
        tokenizer: Text tokenizer applied to the HTML string. Because
            sequences are padded to ``max_length``, it must have a pad token
            set.
        max_length: Length every label sequence is truncated/padded to.
            Defaults to 512, the value previously hard-coded here.

    Returns:
        dict with ``"pixel_values"`` (processor output with its leading batch
        axis removed) and ``"labels"`` (token ids, exactly ``max_length`` long).
    """
    image = processor(images=example['image'], return_tensors="pt").pixel_values
    html = tokenizer(
        example['html'],
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt",
    )
    # return_tensors="pt" adds a leading batch axis of size 1 for a single
    # example. Index it away explicitly: ``squeeze()`` would also drop any
    # other size-1 axis (e.g. labels with max_length=1 would become 0-d).
    # NOTE(review): padded positions keep the tokenizer's pad id in "labels",
    # so they count toward the loss unless the model masks them — confirm.
    return {"pixel_values": image[0], "labels": html.input_ids[0]}

def train(model, train_dataloader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(pixel_values, labels)``; its output
            must expose a single-element ``.loss`` tensor (it is
            back-propagated and read with ``.item()``).
        train_dataloader: Iterable of batches, each a mapping with
            ``"pixel_values"`` and ``"labels"`` tensors. It does not need to
            support ``len()``.
        optimizer: Optimizer over ``model``'s parameters; stepped once per
            batch.
        device: Device each batch's tensors are moved to.

    Returns:
        Unweighted mean of ``loss.item()`` over the batches seen this epoch
        (a smaller final batch counts the same as a full one).

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(train_dataloader):
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values, labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Count batches actually processed rather than calling len(): this
        # also works for loaders without __len__ (e.g. iterable datasets).
        num_batches += 1

    if num_batches == 0:
        # Without this guard the division below fails with an opaque
        # ZeroDivisionError.
        raise ValueError("train_dataloader yielded no batches")
    return total_loss / num_batches

def main():
    """Fine-tune ``ImageToHTMLModel`` on a WebSight subset and save its weights.

    Loads the first 1000 training rows and converts each into
    ``pixel_values``/``labels`` tensors with ``preprocess_data``. It then
    trains for 5 epochs with AdamW (lr 5e-5, batch size 8), printing the mean
    loss after each epoch. Finally it writes the model ``state_dict`` to
    ``image_to_html_model.pth`` in the current working directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load dataset (a subset for demonstration).
    # NOTE(review): preprocess_data reads example['html']; confirm that column
    # name and whether a dataset config/version must be passed here, against
    # the WebSight dataset card.
    dataset = load_dataset("HuggingFaceM4/WebSight", split="train[:1000]")

    # Initialize models and processors
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # padding="max_length" in preprocess_data needs a pad token; reuse EOS.
    gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

    # Preprocess dataset
    processed_dataset = dataset.map(
        lambda example: preprocess_data(example, clip_processor, gpt_tokenizer),
        remove_columns=dataset.column_names
    )
    # map() stores the returned tensors in Arrow, and rows read back as plain
    # Python lists by default. Without a torch format, the DataLoader's default
    # collate produces lists, and batch[...].to(device) in train() fails.
    processed_dataset.set_format(type="torch", columns=["pixel_values", "labels"])

    # Create data loader
    train_dataloader = DataLoader(processed_dataset, batch_size=8, shuffle=True)

    # Initialize model
    model = ImageToHTMLModel().to(device)

    # Training loop
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    num_epochs = 5

    for epoch in range(num_epochs):
        avg_loss = train(model, train_dataloader, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    # Save the model
    torch.save(model.state_dict(), "image_to_html_model.pth")

if __name__ == "__main__":
    main()

