In [1]:
import os
import json
import torch
from datasets import Dataset, DatasetDict, load_dataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms
import numpy as np

from transformers import (
    Trainer,
    TrainingArguments,
    VisionTextDualEncoderModel,
    VisionTextDualEncoderProcessor,
    AutoTokenizer,
    AutoImageProcessor
)
from PIL import Image

In [2]:
# Define paths
base_dir = "/home/jupyter/novice"
jsonl_path = os.path.join(base_dir, "vlm.jsonl")
images_dir = os.path.join(base_dir, "images")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Initialize model and processor
model = VisionTextDualEncoderModel.from_vision_text_pretrained("google/siglip-base-patch16-224", "bert-base-uncased").to(device)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)
config = model.config



config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/813M [00:00<?, ?B/s]

The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight', 'logit_scale']` are newly initialized. You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/368 [00:00<?, ?B/s]

In [4]:
# Define the directory to save cropped images
cropped_images_dir = "/home/jupyter/til-24-base/derrick/clip/images"
os.makedirs(cropped_images_dir, exist_ok=True)

# Function to crop images based on bounding boxes
def crop_and_save_images(jsonl_path, images_dir):
    cropped_data = {"image_path": [], "caption": []}
    with open(jsonl_path, "r") as f:
        for line in f:
            if line.strip() == "":
                continue
            instance = json.loads(line.strip())
            image_name = instance["image"]
            image_path = os.path.join(images_dir, image_name)
            annotations = instance["annotations"]
            for i, annotation in enumerate(annotations):
                bbox = annotation["bbox"]
                caption = annotation["caption"]
                x, y, w, h = bbox
                # image = Image.open(image_path).convert("RGB")
                # cropped_image = image.crop((x, y, x+w, y+h))
                cropped_image_path = os.path.join(cropped_images_dir, f"{image_name[:-4]}_{i}.jpg")
                # cropped_image.save(cropped_image_path)
                cropped_data["image_path"].append(cropped_image_path)
                cropped_data["caption"].append(caption)
    return cropped_data


In [5]:
# Load the dataset
dataset = crop_and_save_images(jsonl_path, images_dir)

In [6]:
# Convert to Dataset
dataset = Dataset.from_dict(dataset)

In [7]:
# Split the dataset
dataset = dataset.train_test_split(test_size=0.2)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]

In [8]:
# Define image transformations
class Transform(torch.nn.Module):
    def __init__(self, image_size, mean, std):
        super().__init__()
        self.transforms = torch.nn.Sequential(
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(mean, std),
        )

    def forward(self, x) -> torch.Tensor:
        with torch.no_grad():
            x = self.transforms(x)
        return x

In [9]:
# Initialize torchvision transforms and jit it for faster processing.
image_transformations = Transform(
    config.vision_config.image_size, image_processor.image_mean, image_processor.image_std
)
image_transformations = torch.jit.script(image_transformations)

In [10]:
# Preprocess function
def preprocess_dataset(data, split):
    
    # We need to tokenize inputs and targets.
    column_names = data.column_names

    # Get the column names for input/targets
    image_column = "image_path"
    caption_column = "caption"
    dataset_columns = (image_column, caption_column)

    # Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.
    def tokenize_captions(examples):
        captions = list(examples[caption_column])
        text_inputs = tokenizer(captions, padding="max_length", truncation=True)
        examples["input_ids"] = text_inputs.input_ids
        examples["attention_mask"] = text_inputs.attention_mask
        return examples

    def transform_images(examples):
        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]]
        examples["pixel_values"] = [image_transformations(image) for image in images]
        return examples

    data = data.map(
        function=tokenize_captions,
        batched=True,
        remove_columns=[col for col in column_names if col != image_column],
        desc=f"Running tokenizer on {split} dataset",
    )

    # Transform images on the fly as doing it on the whole dataset takes too much time.
    data.set_transform(transform_images)
    return data

In [11]:
train_data = preprocess_dataset(train_dataset, "train")
eval_data = preprocess_dataset(eval_dataset, "validation")

Running tokenizer on train dataset:   0%|          | 0/11963 [00:00<?, ? examples/s]



Running tokenizer on validation dataset:   0%|          | 0/2991 [00:00<?, ? examples/s]

In [12]:
# # Define data collator
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)
    attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "return_loss": True,
    }

In [13]:
output_dir="siglip-base-finetune"
learning_rate=1e-5
weight_decay=0.1
batch_size=12
num_epochs=15

In [14]:
training_args = TrainingArguments(
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    per_device_train_batch_size=batch_size,
    remove_unused_columns=False,
    output_dir=output_dir,
    report_to='none',  # disable wandb
    num_train_epochs=num_epochs,  # Increase epochs to allow early stopping to function
    eval_strategy="steps",  # Change evaluation strategy to steps
    save_strategy="steps",  # Set save strategy to match eval strategy
    save_steps=2000  # Save checkpoints every 2000 steps
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    data_collator=collate_fn
)
# Train the model
trainer.train()

ValueError: You have to specify input_ids

In [None]:
metrics = trainer.evaluate()
print(metrics)

In [None]:
trainer.save_model("model/clip-large-finetune")
tokenizer.save_pretrained("model/clip-large-finetune")
image_processor.save_pretrained("model/clip-large-finetune")