In [None]:
!pip install -U timm==1.0.17 transformers==4.51.1

In [1]:
import gc
import os
import glob
import json
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoProcessor, AutoModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, get_scheduler
from torch.optim import AdamW
from PIL import Image

In [2]:
class OCRDataset(Dataset):
    """Map-style dataset yielding ``(prompt, label, image)`` OCR samples.

    Every ``*.jpg`` file directly inside ``images_root`` is one sample. Its
    label is read from the file with the same stem and a ``.txt`` extension.

    Args:
        images_root: Directory that is globbed for ``*.jpg`` files.
        prompt: Text returned as the first element of every sample.
        resize: Optional size passed to ``PIL.Image.resize``; ``None`` keeps
            the image at its original size.
    """

    def __init__(self, images_root, prompt="What is the printed values?", resize=None):
        # glob returns files in arbitrary order; sorting makes an index refer
        # to the same file on every run.
        self.images = sorted(glob.glob(images_root+"/*.jpg"))
        self.resize = resize
        self.prompt = prompt

    def __len__(self):
        """Return the number of ``*.jpg`` files found (usable or not)."""
        return len(self.images)

    @staticmethod
    def _label_path_for(image_path):
        """Return the ``.txt`` path next to ``image_path``.

        Only the final extension is swapped, so a ``.jpg`` substring elsewhere
        in the path (e.g. in a directory name) is left alone.
        """
        return os.path.splitext(image_path)[0] + ".txt"

    def _read_label(self, image_path):
        """Return the cleaned label for ``image_path`` or ``None`` if unusable.

        A sample is unusable when the image or label file is missing, or when
        the label is empty after cleaning.
        """
        label_path = self._label_path_for(image_path)
        if (not os.path.exists(image_path)) or (not os.path.exists(label_path)):
            return None

        with open(label_path) as f:
            # rstrip("DH") removes any run of trailing 'D' / 'H' characters
            # (it is a character set, not a literal "DH" suffix).
            label = f.read().strip().rstrip("DH").strip()

        return label or None

    def __getitem__(self, idx):
        """Return ``(prompt, label, image)`` for ``idx``.

        If the sample at ``idx`` is unusable, the preceding samples are tried
        in turn (wrapping around to the end of the list), so the returned
        sample may belong to a different index.

        Raises:
            IndexError: If ``idx`` is out of range, or if no sample in the
                dataset is usable.
        """
        count = len(self.images)
        if not -count <= idx < count:
            raise IndexError(f"index {idx} is out of range for a dataset of size {count}")

        start = idx % count
        # Iterative backward scan: bounded to one pass over the dataset and
        # immune to the recursion limit on long runs of bad samples.
        for offset in range(count):
            image_path = self.images[(start - offset) % count]

            # Check the label first so unusable samples never pay for decoding.
            label = self._read_label(image_path)
            if label is None:
                continue

            # The context manager closes the underlying file; resize/convert
            # return independent copies that remain valid afterwards.
            with Image.open(image_path) as raw_image:
                image = raw_image
                if self.resize is not None:
                    image = image.resize(self.resize)
                image = image.convert("RGBA").convert("RGB")

            return self.prompt, label, image

        raise IndexError("no usable (image, label) pair found in the dataset")

def collate_fn(batch):
    """Collate ``(prompt, label, image)`` samples into processor inputs.

    Args:
        batch: Sequence of ``(question, answer, image)`` tuples.

    Returns:
        ``(inputs, answers)``: the processor output for the whole batch and
        the tuple of label strings. The tensors are left on the CPU; the
        consumer moves the ones it needs to the target device. Keeping the
        device transfer out of the collate function also keeps it usable with
        ``num_workers > 0``.
    """
    questions, answers, images = zip(*batch)
    # `processor` is a notebook-level global; it must be defined before the
    # data loaders are iterated.
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True)
    return inputs, answers

# Number of samples per batch for both loaders.
BATCH_SIZE = 4

# Both splits are resized to the same (256, 128) size.
train_dataset = OCRDataset(images_root="./dataset/energy-meter/train", resize=(256, 128))
val_dataset = OCRDataset(images_root="./dataset/energy-meter/val", resize=(256, 128))

# Only the training loader shuffles; both load in the main process.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    collate_fn=collate_fn,
)

In [19]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the checkpoint, then move the model onto the selected device.
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
model = model.to(device)

# Matching processor for the same checkpoint.
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)

In [20]:
def infer(image, prompt="What is the printed values?", max_new_tokens=128):
    """Generate the model's text answer for a single image.

    Uses the notebook-level ``model``, ``processor`` and ``device`` globals.

    Args:
        image: PIL image; converted to RGB when it is in another mode.
        prompt: Text prompt passed to the processor alongside the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The ``"response"`` entry produced by
        ``processor.post_process_generation`` for the first decoded sequence.
    """
    torch.cuda.empty_cache()

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Pre-bind so the cleanup in `finally` is valid even if the processor or
    # generate() raises part-way through.
    inputs = None
    generated_ids = None
    try:
        with torch.inference_mode():
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=max_new_tokens,
                num_beams=3
            )
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            parsed = processor.post_process_generation(
                generated_text,
                task="response",
                image_size=(image.width, image.height),
            )
            return parsed["response"]
    finally:
        # Drop the tensor references first so empty_cache() can actually
        # release their blocks. Tensors created under inference_mode carry no
        # autograd graph, so no detach() is needed.
        del inputs, generated_ids
        torch.cuda.empty_cache()

In [None]:
# Check model performance before training
for idx in range(min(3, len(train_dataset))):
    # Fetch the sample once: every dataset access re-reads and re-decodes the
    # image from disk.
    sample_image = train_dataset[idx][2]
    display(sample_image)
    print(infer(sample_image))
    print()

In [22]:
def save_model_and_processor(model, processor, output_dir):
    """Persist model and processor to ``output_dir`` and patch ``config.json``.

    Both objects are written through their ``save_pretrained`` methods (model
    first). The saved ``config.json`` is then reloaded and its ``model_type``
    entries (top level, ``text_config`` and ``vision_config``) are all set to
    ``"davit"`` before the file is written back.

    Args:
        model: Object exposing ``save_pretrained(directory)``.
        processor: Object exposing ``save_pretrained(directory)``.
        output_dir: Target directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    for artifact in (model, processor):
        artifact.save_pretrained(output_dir)

    config_file = os.path.join(output_dir, "config.json")
    with open(config_file, "r") as reader:
        saved_config = json.load(reader)

    # Overwrite the top-level type first, then each nested sub-config.
    saved_config["model_type"] = "davit"
    for section in ("text_config", "vision_config"):
        saved_config[section]["model_type"] = "davit"

    with open(config_file, "w") as writer:
        json.dump(saved_config, writer)

def _batch_loss(model, processor, device, batch):
    """Run one forward pass on ``batch`` and return the model's loss tensor."""
    inputs, answers = batch

    input_ids = inputs["input_ids"].to(device)
    pixel_values = inputs["pixel_values"].to(device)
    labels = processor.tokenizer(
        text=list(answers),
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False,
    ).input_ids.to(device)

    # Padding added to align label lengths must not contribute to the loss.
    # -100 is the default ignore_index of torch's cross-entropy loss.
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is not None:
        labels = labels.masked_fill(labels == pad_token_id, -100)

    return model(input_ids=input_ids, pixel_values=pixel_values, labels=labels).loss


def train_model(train_loader, val_loader, model, processor, device, epochs=10, lr=1e-6):
    """Fine-tune ``model`` and checkpoint it after every epoch.

    Each epoch runs one pass over ``train_loader`` (AdamW with a linear
    learning-rate decay and no warm-up), then computes the mean loss over
    ``val_loader``. The model and processor are saved to
    ``./weights/model_latest`` after every epoch and to
    ``./weights/model_best`` whenever the validation loss improves.

    Args:
        train_loader: Iterable of ``(inputs, answers)`` training batches.
        val_loader: Iterable of ``(inputs, answers)`` validation batches.
        model: Model called with ``input_ids``, ``pixel_values`` and
            ``labels``; its output must expose ``.loss``.
        processor: Processor whose ``tokenizer`` encodes the answers.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Initial AdamW learning rate.
    """
    # Frozen parameters (requires_grad=False) never receive gradients, so
    # they are not handed to the optimizer at all.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        # Train loop with progress bar
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
            loss = _batch_loss(model, processor, device, batch)

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            train_loss += loss.item()

            # Release references before asking the allocator to free cached
            # blocks; otherwise empty_cache() cannot reclaim them.
            del loss, batch
            torch.cuda.empty_cache()

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        tqdm.write(f"Average Training Loss: {avg_train_loss}")

        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
                val_loss += _batch_loss(model, processor, device, batch).item()
        torch.cuda.empty_cache()

        # With no validation batches the loss is undefined: report NaN, which
        # never compares as an improvement, so only model_latest is written.
        num_val_batches = len(val_loader)
        avg_val_loss = val_loss / num_val_batches if num_val_batches else float("nan")
        tqdm.write(f"Average Validation Loss: {avg_val_loss}")

        # Save latest model
        save_model_and_processor(model, processor, "./weights/model_latest")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Save best model
            save_model_and_processor(model, processor, "./weights/model_best")

In [None]:
# We will freeze image encoder for faster training.
# `requires_grad` is the flag autograd checks; assigning any other attribute
# name on a parameter leaves it fully trainable.
for param in model.vision_tower.parameters():
    param.requires_grad = False

train_model(train_loader, val_loader, model, processor, device, epochs=10)

In [None]:
# Check model performance after training
for idx in range(min(3, len(train_dataset))):
    # Fetch the sample once: every dataset access re-reads and re-decodes the
    # image from disk.
    sample_image = train_dataset[idx][2]
    display(sample_image)
    print(infer(sample_image))
    print()

In [None]:
def exact_match_accuracy(dataset):
    """Return the fraction of samples whose prediction equals the label exactly.

    Each sample is loaded once and its image is passed to ``infer``; the
    returned string is compared to the stored label with ``==``.

    Args:
        dataset: Indexable collection of ``(prompt, label, image)`` samples.

    Returns:
        Accuracy in ``[0, 1]``, or NaN when the dataset is empty.
    """
    total = len(dataset)
    correct = 0
    for idx in tqdm(range(total)):
        # One access per sample: indexing the dataset decodes the image.
        _prompt, label, image = dataset[idx]
        if infer(image) == label:
            correct += 1

    return correct / total if total else float("nan")


print(f"Train Accuracy: {exact_match_accuracy(train_dataset)}")
print(f"Val Accuracy: {exact_match_accuracy(val_dataset)}")