In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [3]:
import torch
from torch import nn
from transformers import CLIPModel, CLIPProcessor, GPT2LMHeadModel, GPT2Tokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [5]:
class ImageToHTMLModel(nn.Module):
    """Image-conditioned HTML generator: CLIP vision encoder -> linear projection -> GPT-2.

    The vision encoder's hidden state at sequence position 0 is projected into
    GPT-2's embedding width and used as a one-position prefix for the decoder.

    Args:
        clip_model_name: Checkpoint name passed to ``CLIPModel.from_pretrained``.
        gpt_model_name: Checkpoint name passed to ``GPT2LMHeadModel.from_pretrained``.
    """

    def __init__(self, clip_model_name="openai/clip-vit-base-patch32", gpt_model_name="gpt2"):
        super().__init__()
        # Only the vision tower of CLIP is kept.
        self.clip = CLIPModel.from_pretrained(clip_model_name).vision_model
        self.gpt = GPT2LMHeadModel.from_pretrained(gpt_model_name)
        # Bridges the encoder's hidden width to the decoder's embedding width.
        self.linear = nn.Linear(self.clip.config.hidden_size, self.gpt.config.n_embd)

    def forward(self, pixel_values, labels=None):
        """Compute the training output (with ``labels``) or generate (without).

        Args:
            pixel_values: Image batch forwarded unchanged to the vision encoder
                (presumably ``(batch, channels, height, width)`` -- confirm
                against the processor in use).
            labels: Optional 2-D ``(batch, seq_len)`` tensor of target token ids.
                Entries equal to ``-100`` are ignored by the loss. When given,
                the same ids are also used as teacher-forcing decoder inputs.

        Returns:
            With ``labels``: the output of ``self.gpt(...)`` (exposes ``.loss``).
            Without ``labels``: the result of ``self.gpt.generate(...)``.
        """
        image_features = self.clip(pixel_values).last_hidden_state[:, 0, :]
        prefix_embeds = self.linear(image_features).unsqueeze(1)  # (batch, 1, n_embd)

        if labels is None:
            return self.gpt.generate(inputs_embeds=prefix_embeds, max_length=512)

        # Teacher forcing: the decoder must see [image prefix, t0, t1, ...] so that
        # each position can be trained to predict the next token. Feeding only the
        # prefix (one position) cannot be aligned with seq_len labels.
        # Ignored label positions (-100) are not valid embedding indices, so they
        # are clamped to id 0 for the lookup only; their loss stays masked.
        token_ids = labels.clamp(min=0)
        token_embeds = self.gpt.get_input_embeddings()(token_ids)
        inputs_embeds = torch.cat([prefix_embeds, token_embeds], dim=1)

        # One extra ignored label for the prefix position keeps labels the same
        # length as the inputs; the LM head shifts labels internally, so the
        # prefix position ends up predicting the first real token.
        prefix_labels = labels.new_full((labels.size(0), 1), -100)
        full_labels = torch.cat([prefix_labels, labels], dim=1)

        return self.gpt(inputs_embeds=inputs_embeds, labels=full_labels)

def preprocess_data(example, processor, tokenizer):
    """Turn one raw example into the tensors the model trains on.

    Args:
        example: Mapping with an ``'image'`` entry and an ``'html'`` entry.
            NOTE(review): confirm the dataset really exposes the markup under
            ``'html'``; this function raises ``KeyError`` otherwise.
        processor: Image processor, called as
            ``processor(images=..., return_tensors="pt")``; its result must
            expose ``.pixel_values`` with a leading batch dimension of 1.
        tokenizer: Tokenizer, called with truncation and ``max_length`` padding
            to 512 tokens; its result must expose ``.input_ids`` and
            ``.attention_mask``, each with a leading batch dimension of 1.

    Returns:
        dict with ``"pixel_values"`` and ``"labels"``, both with the leading
        batch dimension removed. Padding positions in ``"labels"`` are ``-100``.
    """
    pixel_values = processor(images=example['image'], return_tensors="pt").pixel_values
    html = tokenizer(example['html'], truncation=True, max_length=512, padding="max_length", return_tensors="pt")

    # Padding must not contribute to the language-modelling loss; -100 is the
    # index the Hugging Face / PyTorch cross-entropy loss ignores.
    labels = html.input_ids.masked_fill(html.attention_mask == 0, -100)

    # squeeze(0) removes only the batch dimension. A bare squeeze() would also
    # drop any other size-1 dimension (e.g. one channel, or a one-token sequence).
    return {"pixel_values": pixel_values.squeeze(0), "labels": labels.squeeze(0)}

def train(model, train_dataloader, optimizer, device):
    """Run one pass over ``train_dataloader`` and return the mean per-batch loss.

    Args:
        model: Callable as ``model(pixel_values, labels)``; the result must
            expose a ``.loss`` tensor. ``model.train()`` is called first.
        train_dataloader: Iterable of batches, each a mapping with
            ``'pixel_values'`` and ``'labels'`` tensors. It does not need to
            implement ``__len__``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called once
            per batch.
        device: Device the batch tensors are moved to.

    Returns:
        float: Sum of ``loss.item()`` over all batches divided by the number of
        batches seen, or ``nan`` if the loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in tqdm(train_dataloader):
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values, labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Counted here rather than via len() so length-less iterables work too.
        num_batches += 1

    if num_batches == 0:
        # No batches means the mean is undefined; report that instead of
        # failing with ZeroDivisionError.
        return float("nan")
    return total_loss / num_batches

def main(num_examples=10, batch_size=8, num_epochs=5, learning_rate=5e-5,
         output_path="image_to_html_model.pth"):
    """Fine-tune ``ImageToHTMLModel`` on a slice of WebSight and save its weights.

    Args:
        num_examples: Number of leading training examples to use.
        batch_size: Batch size for the training ``DataLoader``.
        num_epochs: Number of passes over the data.
        learning_rate: Learning rate for ``AdamW``.
        output_path: File the model ``state_dict`` is written to.

    The defaults reproduce the original hard-coded demonstration settings.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load dataset (a small slice for demonstration).
    # NOTE(review): the recorded run of this cell was interrupted while ~400 MB
    # shards were downloading; consider `streaming=True` if only a handful of
    # examples are needed.
    dataset = load_dataset("HuggingFaceM4/WebSight", split=f"train[:{num_examples}]")

    # preprocess_data reads the markup from an 'html' column. If the dataset
    # exposes it as 'text' instead (presumably the case for WebSight -- confirm
    # against the dataset card), rename it so preprocessing does not hit a
    # KeyError. This is a no-op when an 'html' column already exists.
    if "html" not in dataset.column_names and "text" in dataset.column_names:
        dataset = dataset.rename_column("text", "html")

    # Initialize models and processors
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # The tokenizer needs a pad token for padding="max_length"; EOS is reused.
    gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

    # Preprocess dataset
    processed_dataset = dataset.map(
        lambda example: preprocess_data(example, clip_processor, gpt_tokenizer),
        remove_columns=dataset.column_names
    )
    # `map` stores results as Arrow data, so items come back as nested Python
    # lists by default. The default DataLoader collate would then produce lists
    # of tensors and `batch[...].to(device)` would fail. Ask for torch tensors.
    processed_dataset.set_format(type="torch", columns=["pixel_values", "labels"])

    # Create data loader
    train_dataloader = DataLoader(processed_dataset, batch_size=batch_size, shuffle=True)

    # Initialize model
    model = ImageToHTMLModel().to(device)

    # Training loop
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        avg_loss = train(model, train_dataloader, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    # Save the model
    torch.save(model.state_dict(), output_path)

# Entry point: runs the whole pipeline when this cell (or the file, as a script)
# is executed, since __name__ is "__main__" in both cases.
if __name__ == "__main__":
    main()



Resolving data files:   0%|          | 0/738 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/738 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/738 [00:00<?, ?files/s]

(…)-00005-of-00738-89cf78e53b934db0.parquet:   0%|          | 0.00/408M [00:00<?, ?B/s]

(…)-00006-of-00738-ba36f1dbd3143674.parquet:   0%|          | 0.00/414M [00:00<?, ?B/s]

(…)-00007-of-00738-00b0a9a4836cf7a5.parquet:   0%|          | 0.00/411M [00:00<?, ?B/s]

KeyboardInterrupt: 