In [1]:
import torch

from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [2]:
DEVICE = torch.device("cuda")
HF_TOKEN = "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"

In [4]:
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
    token=HF_TOKEN,
    cache_dir="/juice2/scr2/nlp/pix2code/huggingface"
)
MODEL = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
    token=HF_TOKEN,
    cache_dir="/juice2/scr2/nlp/pix2code/huggingface",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(DEVICE)

Downloading (…)rocessor_config.json:   0%|          | 0.00/351 [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

Downloading added_tokens.json:   0%|          | 0.00/61.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Downloading config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading (…)guration_vmistral.py:   0%|          | 0.00/14.8k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/HuggingFaceM4/VLM_WebSight_finetuned:
- configuration_vmistral.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading modeling_vmistral.py:   0%|          | 0.00/81.0k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/HuggingFaceM4/VLM_WebSight_finetuned:
- modeling_vmistral.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)fetensors.index.json:   0%|          | 0.00/79.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading (…)of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

Downloading (…)of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Downloading (…)of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Downloading (…)of-00004.safetensors:   0%|          | 0.00/1.69G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Downloading generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

In [5]:
image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids

In [6]:
def convert_to_rgb(image):
    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
    # for transparent images. The call to `alpha_composite` handles this case
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite

In [7]:
# The processor is the same as the Idefics processor except for the BILINEAR interpolation,
# so this is a hack in order to redefine ONLY the transform method
def custom_transform(x):
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x

In [8]:
inputs = PROCESSOR.tokenizer(
    f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
    return_tensors="pt",
    add_special_tokens=False,
)

In [9]:
image_path = '../../testset_100/5672.png'
with Image.open(image_path) as image:
    inputs["pixel_values"] = PROCESSOR.image_processor([image], transform=custom_transform)

In [10]:
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
generated_ids = MODEL.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_length=4096)
generated_text = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [11]:
print(generated_text)

<html>
<style>
* {
    margin: 0;
    padding: 0;
    box-sizing: border-box;
}

body {
    font-family: 'Playfair Display', serif;
    color: #333;
}

header {
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    display: flex;
    justify-content: space-between;
    align-items: center;
    padding: 1rem 2rem;
    background-color: #fff;
    box-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);
}

.logo {
    font-size: 1.5rem;
    color: #575fcf;
}

nav ul {
    display: flex;
    list-style: none;
}

nav ul li {
    margin: 0 1rem;
}

nav ul li a {
    text-decoration: none;
    color: #575fcf;
    font-size: 1.2rem;
}

.cta {
    background-color: #575fcf;
    color: #fff;
    padding: 1rem 2rem;
    text-transform: uppercase;
    letter-spacing: 0.1rem;
}

.hero {
    height: 100vh;
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    text-align: center;
    padding: 2rem;
}

.hero h1 {
    font-size: 4rem;
    margin-bottom: 

In [13]:
from tqdm import tqdm 
import os

In [16]:
test_data_dir = "../../testset_100"
predictions_dir = "../../predictions_100/websight"

for filename in tqdm(os.listdir(test_data_dir)):
    if filename.endswith("2.png") or filename.endswith("5.png"):
        image_path = os.path.join(test_data_dir, filename)
        with Image.open(image_path) as image:
            inputs["pixel_values"] = PROCESSOR.image_processor([image], transform=custom_transform)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        generated_ids = MODEL.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_length=4096)
        generated_text = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]

        with open(os.path.join(predictions_dir, filename.replace(".png", ".html")), "w", encoding='utf-8') as f:
            f.write(generated_text)

100%|█████████████████████████████████████████████████████████████████████████████| 201/201 [32:46<00:00,  9.78s/it]
