In [2]:
!pip install torch torchvision transformers datasets



In [None]:
import hashlib
import itertools

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, IterableDataset
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

In [None]:
class WebsightIterableDataset(IterableDataset):
    """Stream (screenshot, HTML) pairs from HuggingFaceM4/WebSight v0.1.

    Both the training and the validation partitions are drawn from the streamed
    ``train`` split. A sample belongs to validation when a stable hash of its
    HTML falls below ``val_ratio``, so instances built with ``validation=False``
    and ``validation=True`` always see disjoint samples, in every process and
    every run.

    Args:
        processor: callable applied to each screenshot; it must return the
            image object fed to the model (e.g. a torchvision transform).
        tokenizer: Hugging Face tokenizer used to encode the HTML.
        max_length: HTML token length; sequences are padded/truncated to it.
        validation: if True, yield the validation partition instead of training.
        val_ratio: fraction of samples, in [0, 1], assigned to validation.

    Yields:
        dict with ``image`` (processor output), ``html_input_ids`` and
        ``html_attention_mask`` (1-D tensors of length ``max_length``) and
        ``raw_html`` (the original HTML string).
    """

    def __init__(self, processor, tokenizer, max_length=512, validation=False, val_ratio=0.1):
        self.dataset = load_dataset(
            "HuggingFaceM4/WebSight",
            "v0.1",
            split="train",
            streaming=True
        )
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.validation = validation
        self.val_ratio = val_ratio

    @staticmethod
    def _split_bucket(text):
        """Map ``text`` to a stable float in [0, 1).

        The builtin ``hash()`` is salted per interpreter (PYTHONHASHSEED), so it
        would give a different train/val split on every run and could disagree
        between DataLoader worker processes. A cryptographic digest does not.
        """
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        # Keep 53 bits so the quotient is exactly representable and strictly < 1.0.
        return (int.from_bytes(digest[:8], "big") >> 11) / float(1 << 53)

    def __iter__(self):
        for item in self.dataset:
            screenshot = item.get('screenshot')
            code = item.get('code')
            if screenshot is None or code is None:
                continue

            is_val = self._split_bucket(code) < self.val_ratio
            if is_val != self.validation:
                continue

            # Screenshots are presumably PIL images (HF `Image` feature) — drop any
            # alpha/palette mode so the processor always receives 3-channel RGB.
            if hasattr(screenshot, 'convert'):
                screenshot = screenshot.convert('RGB')
            image = self.processor(screenshot)

            # Tokenize HTML to a fixed length so default collation can stack it.
            html_tokens = self.tokenizer(
                code,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )

            yield {
                'image': image,
                'html_input_ids': html_tokens.input_ids[0],
                'html_attention_mask': html_tokens.attention_mask[0],
                'raw_html': code
            }

In [None]:
class ImageToHTMLGenerator(nn.Module):
    """Screenshot-to-HTML model: a frozen MobileNetV3-Small encoder feeding BART-base.

    Each image is reduced to one pooled MobileNetV3 feature vector. That vector
    is linearly projected to BART's hidden size and handed to BART as a
    length-1 encoder-output sequence, so BART's own text encoder is never run.
    The decoder cross-attends to that single vector while producing HTML tokens.
    """

    def __init__(self):
        super().__init__()
        # MobileNetV3-Small for feature extraction. Dropping the last child (the
        # classifier) leaves features + global average pool, i.e. a pooled
        # (batch, C, 1, 1) feature map.
        mobilenet = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        self.image_encoder = nn.Sequential(*list(mobilenet.children())[:-1])
        # The encoder is frozen: no gradients, and it is kept in eval mode (see
        # `train`) so its BatchNorm running statistics are not overwritten.
        self.image_encoder.requires_grad_(False)
        self.image_encoder.eval()
        image_feature_dim = mobilenet.classifier[0].in_features  # 576 for mobilenet_v3_small

        # BART-base for text generation (the full bart-base model, not DistilBART).
        self.bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

        # Map image features into BART's hidden size (768 for bart-base).
        self.projection = nn.Linear(image_feature_dim, self.bart.config.d_model)

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen image encoder in eval mode.

        Running the encoder under ``torch.no_grad()`` alone is not enough: in
        train mode its BatchNorm layers would still update their running stats.
        """
        super().train(mode)
        self.image_encoder.eval()
        return self

    def forward(self, image, html_input_ids=None, html_attention_mask=None):
        """Compute the teacher-forced LM loss, or generate HTML when no targets are given.

        Args:
            image: batch of preprocessed images accepted by MobileNetV3
                (presumably (batch, 3, H, W), ImageNet-normalized — TODO confirm).
            html_input_ids: target HTML token ids, (batch, seq_len). When None,
                HTML is generated instead of computing a loss.
            html_attention_mask: 1 for real tokens and 0 for padding, same shape as
                ``html_input_ids``. Padded positions are excluded from the loss.

        Returns:
            BART's output object (with ``.loss``) when ``html_input_ids`` is given,
            otherwise the tensor of token ids returned by ``bart.generate``.
        """
        with torch.no_grad():
            image_features = self.image_encoder(image).flatten(1)  # (batch, C)

        # A single "encoder position" per image, wrapped in the output type that
        # `generate` needs for batch-size inference and beam expansion.
        encoder_outputs = BaseModelOutput(
            last_hidden_state=self.projection(image_features).unsqueeze(1)
        )

        if html_input_ids is None:
            return self.bart.generate(
                encoder_outputs=encoder_outputs,
                max_length=256,
                num_beams=2,
                early_stopping=True
            )

        # -100 is the ignore index of BART's LM loss: don't train on padding.
        labels = html_input_ids
        if html_attention_mask is not None:
            labels = html_input_ids.masked_fill(html_attention_mask == 0, -100)

        # The HTML mask belongs to the decoder side. Passing it as `attention_mask`
        # would make it the cross-attention mask over the length-1 image sequence,
        # which does not match in shape. BART builds decoder_input_ids by
        # shifting `labels` right.
        return self.bart(
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=html_attention_mask,
            labels=labels
        )

def create_data_loaders(processor, tokenizer, batch_size=8, num_workers=2):
    """Build streaming train/validation DataLoaders over WebSight.

    Args:
        processor: screenshot -> model-input callable, passed to the dataset.
        tokenizer: tokenizer used to encode the HTML targets.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes (0 loads in the main process).

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        # Pinned host memory lets `.to(device, non_blocking=True)` copies overlap
        # with compute; it only has an effect when CUDA is available.
        "pin_memory": torch.cuda.is_available(),
    }
    if num_workers > 0:
        # DataLoader rejects prefetch_factor when no worker processes are used.
        loader_kwargs["prefetch_factor"] = 2

    train_dataset = WebsightIterableDataset(processor, tokenizer, validation=False)
    val_dataset = WebsightIterableDataset(processor, tokenizer, validation=True)

    train_loader = DataLoader(train_dataset, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)

    return train_loader, val_loader

In [None]:
def _batch_to_device(batch, device, non_blocking=False):
    """Return the model's keyword inputs from a collated ``batch``, moved to ``device``."""
    return {
        key: batch[key].to(device, non_blocking=non_blocking)
        for key in ('image', 'html_input_ids', 'html_attention_mask')
    }


def _endless_batches(loader):
    """Yield batches from ``loader`` indefinitely, restarting it when it is exhausted.

    Returns, instead of looping forever, if a complete pass yields nothing.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            return


def train_model(model, train_loader, val_loader, num_epochs=10, device='cuda',
                steps_per_epoch=250, val_steps_per_epoch=None, lr=2e-5,
                checkpoint_path="best_model.pth"):
    """Train ``model`` on streaming loaders, validating and checkpointing every epoch.

    Because the data is streamed, an "epoch" is a fixed number of optimizer
    steps. The training stream is consumed continuously across epochs, so each
    epoch sees new batches rather than the same leading ones. Validation
    restarts its loader every epoch, so a restartable stream is scored on the
    same leading batches each time.

    Args:
        model: module called as ``model(image=..., html_input_ids=...,
            html_attention_mask=...)`` and returning an object with a scalar
            ``.loss``. It is not moved here; place it on ``device`` beforehand.
        train_loader: iterable of collated dicts with 'image', 'html_input_ids'
            and 'html_attention_mask' tensors.
        val_loader: iterable with the same batch format, used for validation.
        num_epochs: number of epochs.
        device: device that batch tensors are moved to.
        steps_per_epoch: training batches per epoch.
        val_steps_per_epoch: validation batches per epoch; defaults to
            ``max(1, steps_per_epoch // 5)``.
        lr: AdamW learning rate.
        checkpoint_path: where the best-validation-loss checkpoint is saved.

    Returns:
        The best validation loss seen.

    Raises:
        RuntimeError: if no training or no validation batch succeeds in an epoch.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # `verbose=` is deprecated in recent PyTorch; the current LR is printed per epoch instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    if val_steps_per_epoch is None:
        val_steps_per_epoch = max(1, steps_per_epoch // 5)

    best_val_loss = float('inf')
    # One long-lived iterator. Re-enumerating the loader each epoch would
    # restart the stream and retrain on identical batches every epoch.
    train_stream = _endless_batches(train_loader)

    for epoch in range(num_epochs):
        model.train()
        train_loss, train_steps = 0.0, 0

        train_pbar = tqdm(
            itertools.islice(train_stream, steps_per_epoch),
            total=steps_per_epoch,
            desc=f'Epoch {epoch+1}/{num_epochs} [Train]'
        )
        for batch in train_pbar:
            try:
                outputs = model(**_batch_to_device(batch, device, non_blocking=True))
                loss = outputs.loss

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            except Exception as e:
                # Best effort: skip a bad batch, but keep the failure visible.
                print(f"Error in training batch ({type(e).__name__}): {e}")
                continue

            train_loss += loss.item()
            train_steps += 1
            # Average over successful steps only, not over attempted ones.
            train_pbar.set_postfix({'loss': train_loss / train_steps})

        if train_steps == 0:
            raise RuntimeError(
                f"Epoch {epoch+1}: no training batch succeeded; see the errors printed above."
            )
        avg_train_loss = train_loss / train_steps

        model.eval()
        val_loss, val_steps = 0.0, 0

        val_pbar = tqdm(
            itertools.islice(val_loader, val_steps_per_epoch),
            total=val_steps_per_epoch,
            desc=f'Epoch {epoch+1}/{num_epochs} [Val]'
        )
        with torch.no_grad():
            for batch in val_pbar:
                try:
                    outputs = model(**_batch_to_device(batch, device))
                    batch_loss = outputs.loss.item()
                except Exception as e:
                    print(f"Error in validation batch ({type(e).__name__}): {e}")
                    continue

                val_loss += batch_loss
                val_steps += 1
                val_pbar.set_postfix({'loss': val_loss / val_steps})

        if val_steps == 0:
            raise RuntimeError(
                f"Epoch {epoch+1}: no validation batch succeeded; see the errors printed above."
            )
        avg_val_loss = val_loss / val_steps
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
            }, checkpoint_path)

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Training Loss: {avg_train_loss:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f}')
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

    return best_val_loss

In [None]:
# Seed torch so the randomly initialised projection layer (and dropout) are reproducible.
torch.manual_seed(42)

# The dataset calls `processor` on each raw screenshot, so it must turn a PIL image
# into the normalised tensor the model's MobileNetV3 encoder expects. It must not
# be a feature-extractor module: ImageToHTMLGenerator already runs MobileNetV3 itself.
_screenshot_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # squash the whole page to the encoder input size (no crop)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet stats
])


def processor(screenshot):
    """Convert a screenshot (PIL image) to a (3, 224, 224) ImageNet-normalised tensor."""
    # convert("RGB") drops any alpha channel so the encoder always receives 3 channels.
    # NOTE: a notebook-defined function reaches DataLoader workers under the "fork"
    # start method (Linux); spawn-started workers cannot load it from __main__.
    return _screenshot_transform(screenshot.convert("RGB"))


tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
train_loader, val_loader = create_data_loaders(processor, tokenizer)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageToHTMLGenerator().to(device)
print(f"Using device: {device}")

train_model(model, train_loader, val_loader, num_epochs=10, device=device)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_small-047dcff4.pth
100%|██████████| 9.83M/9.83M [00:00<00:00, 97.7MB/s]


vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/738 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/71 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/738 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/71 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]


 cuda


Epoch 1/10 [Train]:   0%|          | 0/250 [00:00<?, ?it/s]