In [None]:
from IPython.display import clear_output

!pip install salesforce-lavis
!pip install torch 
!pip install torchvision
!pip install transformers
!pip install peft==0.10.0
!pip install datasets
!pip install pillow
!pip install underthesea
!pip install huggingface_hub
!pip install hf_xet
!pip install google-genai
!pip install easyocr
clear_output()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import models
from peft import get_peft_model, LoraConfig, TaskType

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from lavis.models import load_model_and_preprocess

from PIL import Image, ImageFile
import os
import re
import ast
import numpy as np
from tqdm import tqdm
from huggingface_hub import hf_hub_download, HfApi

from google import genai
from google.genai.types import GenerateContentConfig
import easyocr
from dotenv import load_dotenv

load_dotenv()

ImageFile.LOAD_TRUNCATED_IMAGES = True

2025-05-13 04:18:19.240760: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747109899.467328      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747109899.534744      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  return torch.cuda.amp.custom_fwd(orig_func)  # type: ignore
  return torch.cuda.amp.custom_bwd(orig_func)  # type: ignore


In [None]:
# Hugging Face Hub coordinates of the trained checkpoints.
USERNAME = "Namronaldo2004"  # Hugging Face username
PROJECT_NAME = "ViInfographicsVQA"  # Name of the project
HUGGINGFACE_HUB_REPO = f"{USERNAME}/{PROJECT_NAME}"  # Full "<user>/<repo>" id on the Hub

# Sub-folders of the repo holding each model, and the checkpoint file inside each.
TEXT_BASELINE_NAME = "Flow3-modified/Text"
NONTEXT_BASELINE_NAME = "Flow3-modified/Non-text"
TEXT_CHECKPOINT_FILENAME = f"{TEXT_BASELINE_NAME}/latest_checkpoint.pth"
NONTEXT_CHECKPOINT_FILENAME = f"{NONTEXT_BASELINE_NAME}/latest_checkpoint.pth"

In [None]:
# EasyOCR reader configured for Vietnamese; used to pull raw text out of infographics.
OCR_READER = easyocr.Reader(["vi"])

# Gemini is used to classify questions and to rewrite raw OCR output.
# The key comes from the environment (populated from .env by load_dotenv above).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
GEMINI_MODEL = "gemini-2.5-flash-preview-04-17"
CLIENT = genai.Client(api_key=GEMINI_API_KEY)
GENERATION_CONFIG = GenerateContentConfig(response_mime_type="text/plain")

In [None]:
api = HfApi()
NONTEXT_CHECKPOINT_PATH = "./checkpoints/nontext_latest_checkpoint.pth"
TEXT_CHECKPOINT_PATH = "./checkpoints/text_latest_checkpoint.pth"
os.makedirs("checkpoints", exist_ok=True)


def _resolve_checkpoint(local_path: str, hub_filename: str) -> str:
    """Return a path to a checkpoint file, downloading it from the Hub if needed.

    Args:
        local_path: Preferred local checkpoint path; returned as-is if it exists.
        hub_filename: Path of the checkpoint inside ``HUGGINGFACE_HUB_REPO``.

    Returns:
        ``local_path`` if that file exists. Otherwise, if the file exists on the
        Hub, the path returned by ``hf_hub_download`` (the file is stored under
        ``./checkpoints`` following ``hub_filename``'s folder layout, so it is
        NOT ``local_path``). If the checkpoint is found nowhere, ``local_path``
        is returned unchanged and a warning is printed, because downstream
        loading skips missing files and the model would run untrained.
    """
    if os.path.exists(local_path):
        return local_path

    if api.file_exists(
        repo_id=HUGGINGFACE_HUB_REPO,
        filename=hub_filename,
        repo_type="model",
    ):
        return hf_hub_download(
            repo_id=HUGGINGFACE_HUB_REPO,
            filename=hub_filename,
            local_dir="./checkpoints",  # Store the checkpoint locally in the "checkpoints" directory
        )

    print(
        f"WARNING: checkpoint '{hub_filename}' not found locally ({local_path}) "
        f"or in '{HUGGINGFACE_HUB_REPO}'; the model will keep its initial weights."
    )
    return local_path


NONTEXT_CHECKPOINT_PATH = _resolve_checkpoint(
    NONTEXT_CHECKPOINT_PATH, NONTEXT_CHECKPOINT_FILENAME
)
print(NONTEXT_CHECKPOINT_PATH)
# =======================================================================================================
TEXT_CHECKPOINT_PATH = _resolve_checkpoint(TEXT_CHECKPOINT_PATH, TEXT_CHECKPOINT_FILENAME)
print(TEXT_CHECKPOINT_PATH)

checkpoints/Flow3-modified/Non-text/latest_checkpoint.pth


In [None]:
class EfficientNetFeatureExtractor(nn.Module):
    """EfficientNet backbone that turns a batch of images into 32 feature tokens.

    Each image is resized, center-cropped and normalised with ImageNet
    statistics, then passed through the torchvision EfficientNet ``features``
    trunk under ``torch.no_grad()`` (the backbone receives no gradients).
    The resulting ``(B, C, H', W')`` map is pooled to ``(B, 32, 1024)``:
    height is averaged away, width is adaptively pooled to 32 positions
    (the tokens), and the channel axis C is adaptively pooled to 1024.

    Args:
        model_name: One of ``efficientnet_b0`` ... ``efficientnet_b7``.
        target_size: Side length of the square center crop.
        central_fraction: Fraction of the resized image kept by the crop;
            images are first resized to ``int(target_size / central_fraction)``.
    """

    def __init__(
        self,
        model_name: str = "efficientnet_b0",
        target_size: int = 224,
        central_fraction: float = 0.875,
    ):
        super(EfficientNetFeatureExtractor, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model(model_name).to(self.device)
        self.transform = self._build_transform(target_size, central_fraction)

        # (B, C, H', W') -> (B, C, 1, 32): collapse height, 32 width positions.
        self.pooling1 = nn.AdaptiveAvgPool2d((1, 32))
        # Applied after moving C to the last axis: (B, 32, 1, C) -> (B, 32, 1, 1024).
        self.pooling2 = nn.AdaptiveAvgPool2d((1, 1024))

    def _load_model(self, model_name: str) -> nn.Module:
        """Return the convolutional ``features`` trunk of a pretrained EfficientNet.

        Raises:
            ValueError: if ``model_name`` is not a supported variant.
        """
        model_dict = {
            "efficientnet_b0": models.efficientnet_b0,
            "efficientnet_b1": models.efficientnet_b1,
            "efficientnet_b2": models.efficientnet_b2,
            "efficientnet_b3": models.efficientnet_b3,
            "efficientnet_b4": models.efficientnet_b4,
            "efficientnet_b5": models.efficientnet_b5,
            "efficientnet_b6": models.efficientnet_b6,
            "efficientnet_b7": models.efficientnet_b7,
        }

        if model_name not in model_dict:
            raise ValueError(
                f"Unsupported model_name '{model_name}'. Choose from: {list(model_dict.keys())}"
            )

        model = model_dict[model_name](weights="DEFAULT")
        return model.features  # Only use the feature extractor part

    def _build_transform(
        self, target_size: int, central_fraction: float
    ) -> "transforms.Compose":
        """Build the resize -> center-crop -> tensor -> ImageNet-normalise pipeline."""
        resize_size = int(target_size / central_fraction)
        return transforms.Compose(
            [
                transforms.Resize(resize_size),
                transforms.CenterCrop(target_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def freeze(self):
        """Disable gradients for all backbone parameters."""
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, images: "Sequence[Image.Image]") -> torch.Tensor:
        """Extract feature tokens for a batch of images.

        Args:
            images: Sequence of PIL images (any mode; converted to RGB).

        Returns:
            Tensor of shape ``(B, 32, 1024)`` on ``self.device``.
        """
        images_tensor = torch.stack(
            [self.transform(image.convert("RGB")) for image in images]
        ).to(self.device)

        with torch.no_grad():
            features = self.model(images_tensor)  # (B, C, H', W')

        features = self.pooling1(features)  # (B, C, 1, 32)
        features = features.permute(0, 3, 2, 1)  # (B, 32, 1, C)
        features = self.pooling2(features)  # (B, 32, 1, 1024)

        # reshape (not view) so this never depends on memory contiguity.
        return features.reshape(features.shape[0], features.shape[1], -1)


class Blip2ViTExtractor(nn.Module):
    """BLIP-2 image feature extractor whose token embeddings are projected to 1024-d.

    The LAVIS ``blip2_feature_extractor`` (pretrain weights, eval mode) produces
    ``image_embeds`` of width 768; a trainable ``nn.Linear`` maps them to 1024
    so they can be concatenated with the other 1024-d visual tokens.
    """

    def __init__(self):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model, processors, _ = load_model_and_preprocess(
            name="blip2_feature_extractor",
            model_type="pretrain",
            is_eval=True,
            device=self.device,
        )
        self.model = model
        # Only the evaluation-time image processor is used.
        self.preprocess = processors["eval"]

        # Linear projection 768 -> 1024 applied to every BLIP-2 image token.
        self.linear_proj = nn.Linear(768, 1024)

    def freeze(self):
        """Disable gradients for the BLIP-2 model (the projection stays trainable)."""
        for weight in self.model.parameters():
            weight.requires_grad = False

    def forward(self, images):
        """Encode a sequence of PIL images into ``(B, N, 1024)`` token features."""
        pixel_batch = torch.stack(
            [self.preprocess(img.convert("RGB")).to(self.device) for img in images]
        )
        extracted = self.model.extract_features(
            samples={"image": pixel_batch}, mode="image"
        )
        # image_embeds: (B, N, 768) -> projected to (B, N, 1024)
        return self.linear_proj(extracted.image_embeds)

In [None]:
class BARTPho(nn.Module):
    """BARTpho seq2seq model split into encoder, decoder and LM head.

    ``encode`` embeds arbitrarily long texts (split into word chunks, each
    encoded separately and concatenated). ``decode``/``generate`` run the
    decoder against caller-provided encoder states, so a fusion module can sit
    between encoding and decoding.

    Args:
        model_name: Hugging Face id of the seq2seq checkpoint.
        device: Device the sub-modules and tokenised inputs are placed on.
        max_length: Maximum number of tokens produced by ``generate``.
        use_lora: Inject LoRA adapters with PEFT.
        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters.
        target_modules: Module names that receive LoRA adapters;
            defaults to ``["q_proj", "v_proj"]``.
    """

    # Texts are split into chunks of this many whitespace-separated words;
    # each chunk is tokenised (with truncation) and encoded on its own.
    CHUNK_WORDS = 384

    def __init__(
        self,
        model_name="vinai/bartpho-syllable",
        device="cpu",
        max_length=50,
        use_lora=False,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=None,
    ):
        super().__init__()
        if target_modules is None:
            target_modules = ["q_proj", "v_proj"]
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = device
        self.max_length = max_length

        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        if use_lora:
            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                bias="none",
                task_type=TaskType.SEQ_2_SEQ_LM,
                target_modules=target_modules,
            )
            # get_base_model() unwraps the PeftModel back to the (LoRA-augmented)
            # *ForConditionalGeneration model, so both branches share one code
            # path below. Reaching through `.base_model.model.model` only works
            # for the PEFT wrapper and crashed when use_lora=False.
            seq2seq_model = get_peft_model(seq2seq_model, lora_config).get_base_model()

        # Attribute names (encoder / decoder / lm_head) define state_dict keys;
        # keep them stable so saved checkpoints still load.
        self.encoder = seq2seq_model.get_encoder().to(device)
        self.decoder = seq2seq_model.get_decoder().to(device)  # BART/M-BART decoder
        self.lm_head = seq2seq_model.lm_head.to(device)

    def _encode_text(self, text):
        """Encode one text chunk-by-chunk.

        Returns:
            ``(input_ids, attention_mask, hidden_states)`` with shapes
            ``(seq_len,)``, ``(seq_len,)`` and ``(seq_len, hidden_size)``.
        """
        words = text.split()
        step = self.CHUNK_WORDS
        # An empty/whitespace-only text still yields one (empty) chunk so the
        # tokenizer emits its special tokens and torch.cat has something to join.
        chunks = [
            " ".join(words[start : start + step]) for start in range(0, len(words), step)
        ] or [""]

        ids_parts, mask_parts, hidden_parts = [], [], []
        for chunk in chunks:
            encoded = self.tokenizer(
                chunk, return_tensors="pt", padding=True, truncation=True
            ).to(self.device)
            input_ids = encoded["input_ids"]  # (1, chunk_len)
            attention_mask = encoded["attention_mask"]  # (1, chunk_len)
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

            ids_parts.append(input_ids)
            mask_parts.append(attention_mask)
            hidden_parts.append(outputs.last_hidden_state)

        return (
            torch.cat(ids_parts, dim=1).squeeze(0),
            torch.cat(mask_parts, dim=1).squeeze(0),
            torch.cat(hidden_parts, dim=1).squeeze(0),
        )

    def encode(self, input_texts):
        """Encode a batch of texts and right-pad them to a common length.

        Args:
            input_texts: Non-empty sequence of strings.

        Returns:
            dict with ``encoder_hidden_states`` ``(B, L, H)``, ``attention_mask``
            ``(B, L)`` and ``input_ids`` ``(B, L)``. Padded positions hold
            ``pad_token_id`` / 0 / zero vectors respectively.

        Raises:
            ValueError: if ``input_texts`` is empty.
        """
        if len(input_texts) == 0:
            raise ValueError("input_texts must contain at least one text")

        per_text = [self._encode_text(text) for text in input_texts]
        max_seq_len = max(ids.shape[0] for ids, _, _ in per_text)
        pad_id = self.tokenizer.pad_token_id

        input_ids = torch.stack(
            [F.pad(ids, (0, max_seq_len - ids.shape[0]), value=pad_id) for ids, _, _ in per_text]
        )
        attention_mask = torch.stack(
            [F.pad(mask, (0, max_seq_len - mask.shape[0]), value=0) for _, mask, _ in per_text]
        )
        # Pad the sequence axis (dim 0) with zeros; F.pad keeps the tensor's dtype.
        encoder_hidden_states = torch.stack(
            [F.pad(hidden, (0, 0, 0, max_seq_len - hidden.shape[0])) for _, _, hidden in per_text]
        )

        return {
            "encoder_hidden_states": encoder_hidden_states,  # shape: (batch_size, seq_len, hidden_size)
            "attention_mask": attention_mask,  # shape: (batch_size, seq_len)
            "input_ids": input_ids,  # shape: (batch_size, seq_len)
        }

    def decode(
        self,
        answer_input_ids,
        answer_attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run the decoder over ``answer_input_ids`` and return vocabulary logits."""
        decoder_outputs = self.decoder(
            input_ids=answer_input_ids,
            attention_mask=answer_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        logits = self.lm_head(decoder_outputs.last_hidden_state)
        return logits

    def generate(self, encoder_hidden_states, encoder_attention_mask):
        """Greedy-decode up to ``max_length`` tokens and return decoded strings.

        Rows that have produced ``<eos>`` are padded for the remaining steps so
        that tokens generated for other, unfinished rows do not leak into their
        text; decoding stops once every row has finished.
        """
        batch_size = encoder_hidden_states.size(0)
        device = encoder_hidden_states.device
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = eos_id

        # Every row starts decoding from eos_token_id.
        decoder_input_ids = torch.full(
            (batch_size, 1),
            fill_value=eos_id,
            dtype=torch.long,
            device=device,
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        with torch.no_grad():
            for _ in range(self.max_length):
                logits = self.decode(
                    answer_input_ids=decoder_input_ids,
                    answer_attention_mask=torch.ones_like(
                        decoder_input_ids, device=device
                    ),
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                )

                next_token = logits[:, -1, :].argmax(-1, keepdim=True)  # (B, 1)
                next_token = next_token.masked_fill(finished.unsqueeze(-1), pad_id)
                decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

                finished |= next_token.squeeze(-1) == eos_id
                if finished.all():
                    break

        return self.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)

    def freeze_encoder(self, layers_to_freeze=None):
        """Freeze the whole encoder, or only the encoder layers whose index is listed."""
        if layers_to_freeze is None:
            for param in self.encoder.parameters():
                param.requires_grad = False
        else:
            for idx, layer in enumerate(self.encoder.layers):
                if idx in layers_to_freeze:
                    for param in layer.parameters():
                        param.requires_grad = False

    def unfreeze_encoder(self):
        """Re-enable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = True

In [None]:
class PositionWiseFeedForward(nn.Module):
    """Post-norm Transformer feed-forward sublayer: ``LayerNorm(x + FFN(x))``.

    FFN = Linear(d_model -> d_ff) -> GELU -> Dropout -> Linear(d_ff -> d_model) -> Dropout.
    """

    def __init__(self, d_model=768, d_ff=2048, dropout=0.1):
        super().__init__()
        # Attribute names form state_dict keys; keep them stable for checkpoints.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        hidden = self.dropout1(F.gelu(self.fc1(x)))
        projected = self.dropout2(self.fc2(hidden))
        return self.norm(x + projected)


class MultiHeadAttention(nn.Module):
    """Post-norm residual multi-head attention: ``LayerNorm(q + Dropout(Attn(q, k, v)))``.

    ``attention_mask`` (optional, shape ``(B, S_k)``) marks keys to KEEP with
    1/True; it is inverted into PyTorch's ``key_padding_mask`` convention.
    """

    def __init__(self, d_model=768, num_heads=8, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, attention_mask=None):
        # nn.MultiheadAttention ignores keys where key_padding_mask is True.
        key_padding_mask = None if attention_mask is None else ~attention_mask.bool()
        attended = self.attn(queries, keys, values, key_padding_mask=key_padding_mask)[0]
        return self.norm(queries + self.dropout(attended))


class EncoderLayer(nn.Module):
    """Attention sublayer followed by a position-wise feed-forward sublayer."""

    def __init__(self, d_model=768, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.mhatt = MultiHeadAttention(d_model, num_heads, dropout)
        self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)

    def forward(self, queries, keys, values, attention_mask=None):
        return self.pwff(self.mhatt(queries, keys, values, attention_mask))


class BiDirectionalCrossAttention(nn.Module):
    """Co-attention encoder fusing two token streams (A = "vision", B = "text").

    Both streams get learned absolute position embeddings plus LayerNorm. Each
    of ``num_layers`` rounds then applies, in order: A attends to B, B attends
    to the updated A, A self-attention, B self-attention. The output is the
    concatenation ``[A; B]`` along the sequence axis: ``(B, len_A + len_B, d_model)``.
    """

    def __init__(
        self,
        d_model=1024,
        num_heads=8,
        d_ff=2048,
        dropout=0.1,
        num_layers=3,
        max_len=6000,
    ):
        super().__init__()

        # max_len bounds the sequence length of each stream (position-embedding size).
        self.vision_pos_embed = nn.Embedding(max_len, d_model)
        self.text_pos_embed = nn.Embedding(max_len, d_model)

        self.vision_norm = nn.LayerNorm(d_model)
        self.text_norm = nn.LayerNorm(d_model)

        self.d_model = d_model

        def make_stack():
            return nn.ModuleList(
                [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
            )

        self.vision_language_attn_layers = make_stack()
        self.language_vision_attn_layers = make_stack()
        self.vision_self_attn_layers = make_stack()
        self.language_self_attn_layers = make_stack()

    def forward(self, vision_feats, vision_mask, text_feats, text_mask):
        batch_size, v_len, _ = vision_feats.size()
        _, t_len, _ = text_feats.size()

        v_positions = torch.arange(v_len, device=vision_feats.device).expand(batch_size, -1)
        t_positions = torch.arange(t_len, device=text_feats.device).expand(batch_size, -1)

        vision = self.vision_norm(vision_feats + self.vision_pos_embed(v_positions))
        text = self.text_norm(text_feats + self.text_pos_embed(t_positions))

        layer_rounds = zip(
            self.vision_language_attn_layers,
            self.language_vision_attn_layers,
            self.vision_self_attn_layers,
            self.language_self_attn_layers,
        )
        for cross_v, cross_t, self_v, self_t in layer_rounds:
            vision = cross_v(vision, text, text, text_mask)
            text = cross_t(text, vision, vision, vision_mask)
            vision = self_v(vision, vision, vision, vision_mask)
            text = self_t(text, text, text, text_mask)

        return torch.cat([vision, text], dim=1)  # (B, V+T, D)

In [None]:
class NonTextModel(nn.Module):
    """VQA model for questions about the visual (non-text) content of an image.

    Local EfficientNet-B7 tokens and global BLIP-2 tokens are concatenated,
    fused with the BARTpho-encoded question by ``BiDirectionalCrossAttention``,
    and decoded by the BARTpho decoder.
    """

    def __init__(self):
        super(NonTextModel, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.local_visual_extractor = EfficientNetFeatureExtractor(
            model_name="efficientnet_b7"
        ).to(self.device)
        self.global_visual_extractor = Blip2ViTExtractor().to(self.device)

        # Pass the device explicitly so BARTPho places tokenised inputs on it.
        self.bart_pho = BARTPho(device=self.device, use_lora=True)
        self.encoder = BiDirectionalCrossAttention(max_len=1028).to(self.device)

    def _fuse(self, images, questions):
        """Encode images + questions and fuse them.

        Returns:
            ``(encoder_output, encoder_attention_mask)`` where the output covers
            the visual tokens followed by the question tokens.
        """
        local_features = self.local_visual_extractor(images)  # (B, N_local, D)
        global_features = self.global_visual_extractor(images)  # (B, N_global, D)
        vision_feats = torch.cat([local_features, global_features], dim=1).to(self.device)
        # Every visual token is real, so the vision mask is all True.
        vision_mask = torch.ones(
            vision_feats.size()[:-1], dtype=torch.bool, device=self.device
        )

        text_encoding = self.bart_pho.encode(questions)
        text_feats = text_encoding["encoder_hidden_states"]
        question_attention_mask = text_encoding["attention_mask"]

        encoder_output = self.encoder(
            vision_feats, vision_mask, text_feats, question_attention_mask
        ).to(self.device)
        encoder_attention_mask = torch.cat(
            [vision_mask, question_attention_mask], dim=1
        ).to(self.device)
        return encoder_output, encoder_attention_mask

    def forward(self, images, questions, answers):
        """Teacher-forced pass: return decoder logits for the tokenised ``answers``."""
        encoder_output, encoder_attention_mask = self._fuse(images, questions)

        answer_encoded = self.bart_pho.tokenizer(
            answers, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

        return self.bart_pho.decode(
            answer_input_ids=answer_encoded["input_ids"],
            answer_attention_mask=answer_encoded["attention_mask"],
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_attention_mask,
        )

    def generate(self, images, questions):
        """Greedy-decode answers for ``images``/``questions`` (no gradients)."""
        with torch.no_grad():
            encoder_output, encoder_attention_mask = self._fuse(images, questions)
            return self.bart_pho.generate(encoder_output, encoder_attention_mask)

In [None]:
class TextModel(nn.Module):
    """VQA model for questions about text in an image.

    OCR text and the question are both encoded with the (frozen) BARTpho
    encoder, fused by ``BiDirectionalCrossAttention``, and decoded by the
    BARTpho decoder.
    """

    def __init__(self):
        super(TextModel, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Pass the device explicitly so BARTPho places tokenised inputs on it.
        self.bart_pho = BARTPho(device=self.device, use_lora=True)
        self.bart_pho.freeze_encoder()
        self.encoder = BiDirectionalCrossAttention().to(self.device)

    def _fuse(self, ocr_infos, questions):
        """Encode OCR texts + questions and fuse them.

        Returns:
            ``(encoder_output, encoder_attention_mask)`` where the output covers
            the OCR tokens followed by the question tokens.
        """
        ocr_encoding = self.bart_pho.encode(ocr_infos)
        question_encoding = self.bart_pho.encode(questions)
        ocr_attention_mask = ocr_encoding["attention_mask"]
        question_attention_mask = question_encoding["attention_mask"]

        encoder_output = self.encoder(
            ocr_encoding["encoder_hidden_states"],
            ocr_attention_mask,
            question_encoding["encoder_hidden_states"],
            question_attention_mask,
        )
        encoder_attention_mask = torch.cat(
            [ocr_attention_mask, question_attention_mask], dim=1
        )
        return encoder_output, encoder_attention_mask

    def forward(self, ocr_infos, questions, answers):
        """Teacher-forced pass: return decoder logits for the tokenised ``answers``."""
        encoder_output, encoder_attention_mask = self._fuse(ocr_infos, questions)

        answer_encoded = self.bart_pho.tokenizer(
            answers, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

        return self.bart_pho.decode(
            answer_input_ids=answer_encoded["input_ids"],
            answer_attention_mask=answer_encoded["attention_mask"],
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_attention_mask,
        )

    def generate(self, ocr_infos, questions):
        """Greedy-decode answers from OCR text and questions (no gradients)."""
        with torch.no_grad():
            encoder_output, encoder_attention_mask = self._fuse(ocr_infos, questions)
            return self.bart_pho.generate(encoder_output, encoder_attention_mask)

In [None]:
def build_type_prompt(question: str) -> str:
    """Build the Gemini prompt that classifies *question* as "text" or "non-text".

    The Vietnamese instructions define both labels ("text": questions about
    text/numbers/words in the image; "non-text": questions about images,
    objects, colours, shapes, ...). The question is appended after them,
    followed by a reminder to answer with a single word.

    Args:
        question: The user's question, inserted verbatim.

    Returns:
        The complete prompt string.
    """
    # .strip() only trims the outer whitespace; interior lines keep their
    # 4-space source indentation inside the prompt.
    instructions = """
    Bạn sẽ nhận được một câu hỏi. Hãy phân loại xem câu hỏi đó là "text" hoặc là "non-text".
    
    **Định nghĩa**:
    - "text": bao gồm các câu hỏi dựa trên các yếu tố trích xuất từ văn bản, liên quan đến các số liệu, từ ngữ, hoặc bất kỳ hình thức văn bản nào.
    - "non-text": bao gồm các câu hỏi không dựa trên các yếu tố văn bản, chẳng hạn như các câu hỏi dựa trên hình ảnh, vật thể, màu sắc, hình dạng, v.v.
    
    Trả lời đầu ra là một từ duy nhất: "text" hoặc "non-text".
    
    Dưới đây là câu hỏi cần phân loại:
    """.strip()

    # Reminder placed after the question: "Return exactly one word: text or non-text".
    ending = '\n\nHãy trả về 1 từ duy nhất: "text" hoặc "non-text".'

    full_prompt = f"{instructions}\n{question}{ending}"
    return full_prompt

In [None]:
def build_ocr_prompt(ocr_chunks: list[str]) -> str:
    """Build the Gemini prompt that rewrites raw OCR fragments into sentences.

    The (Vietnamese) instructions ask Gemini to group related OCR lines,
    rewrite them as short complete sentences suitable for text embedding, and
    return only a valid Python list of strings. The example list is wrapped in
    a ```python fence that is now properly closed, so the "Dữ liệu" (data)
    section that follows is no longer inside the code fence.

    Args:
        ocr_chunks: Text fragments returned by OCR; inserted via their ``repr``.

    Returns:
        The complete prompt string.
    """
    prompt = f"""Bạn nhận được một danh sách văn bản ngắn, trích xuất từ ảnh infographic (OCR).

    Yêu cầu:
    - Gom nhóm các dòng có liên quan theo ngữ cảnh.
    - Viết lại thành các câu hoàn chỉnh, ngắn gọn, rõ nghĩa dùng để thực hiện text embedding.
    - Trả về **duy nhất một Python list hợp lệ** chứa các câu, ví dụ:
    ```python
    [
        "Câu hoàn chỉnh 1.",
        "Câu hoàn chỉnh 2.",
        ...
    ]
    ```

    Dữ liệu:
    {ocr_chunks}

    """

    return prompt

In [None]:
def fix_list_string(raw_str: str) -> str:
    """Coerce a loosely formatted, line-per-item list into a valid Python list literal.

    1. Strip surrounding whitespace.
    2. Insert ``", "`` between two double quotes separated only by whitespace
       (adjacent string items that are missing a comma).
    3. Drop outer ``[``/``]`` characters and treat every non-empty line as one
       item: remove one pair of surrounding double quotes, escape backslashes
       and double quotes inside, re-quote, and append a comma unless the item
       sits on the very last line.

    The result is meant to be parsed with ``ast.literal_eval``.
    """
    normalized = re.sub(r'"\s*"', '", "', raw_str.strip())
    raw_lines = normalized.strip("[]").split("\n")
    last_index = len(raw_lines) - 1

    items = []
    for index, raw_line in enumerate(raw_lines):
        content = raw_line.strip().rstrip(",")
        if content == "":
            continue

        # Remove one pair of wrapping quotes so they are not escaped below.
        if content[:1] == '"' and content[-1:] == '"':
            content = content[1:-1]

        escaped = content.replace("\\", "\\\\").replace('"', '\\"')
        trailing = "," if index < last_index else ""
        items.append('"' + escaped + '"' + trailing)

    return "[\n" + "\n".join(items) + "\n]"


def clean_ocr_contents(response: str) -> str:
    """Turn Gemini's list-of-sentences reply into one space-joined string.

    Handles replies wrapped in a ```python or bare ``` fence (any case),
    including replies with text before or after the fence. Curly double
    quotes are normalised to straight ones, the list text is repaired with
    ``fix_list_string`` and parsed with ``ast.literal_eval`` (never ``eval``).

    Raises:
        ValueError / SyntaxError: if the repaired text still cannot be parsed.
    """
    text = response.strip()
    fenced = re.search(
        r"```(?:python)?\s*(.*?)\s*```", text, flags=re.IGNORECASE | re.DOTALL
    )
    if fenced:
        # Keep only the fenced content; ignore any preamble/epilogue.
        text = fenced.group(1)
    else:
        # No complete fence: drop a dangling opening/closing marker if present.
        text = re.sub(r"^```(?:python)?\s*", "", text, flags=re.IGNORECASE)
        text = re.sub(r"\s*```$", "", text)

    # Replace curly double quotes with straight ones.
    text = text.replace("“", '"').replace("”", '"')

    meaning_sentences = ast.literal_eval(fix_list_string(text))
    return " ".join(meaning_sentences)

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_checkpoint(model, filepath):
    """Load ``checkpoint["model_state_dict"]`` from *filepath* into *model*.

    Args:
        model: Module whose weights are replaced.
        filepath: Path to a file written with ``torch.save`` containing a
            ``"model_state_dict"`` entry.

    Returns:
        True if weights were loaded, False if *filepath* is not a file (the
        model then keeps its current weights and a warning is emitted instead
        of silently continuing).
    """
    import warnings

    if not os.path.isfile(filepath):
        warnings.warn(
            f"Checkpoint not found at {filepath!r}; model keeps its current weights."
        )
        return False

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(filepath, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
    return True


# Text-question model: build it on DEVICE, load trained weights if the
# checkpoint file exists, and switch to inference mode.
TEXT_MODEL = TextModel().to(DEVICE)
load_checkpoint(TEXT_MODEL, TEXT_CHECKPOINT_PATH)
TEXT_MODEL.eval()

# Non-text (visual) question model: same steps with its own checkpoint.
NONTEXT_MODEL = NonTextModel().to(DEVICE)
load_checkpoint(NONTEXT_MODEL, NONTEXT_CHECKPOINT_PATH)
NONTEXT_MODEL.eval()

Downloading: "https://download.pytorch.org/models/efficientnet_b7_lukemelas-c5b4e57e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b7_lukemelas-c5b4e57e.pth
100%|██████████| 255M/255M [00:02<00:00, 93.6MB/s] 


vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

100%|██████████| 1.89G/1.89G [00:47<00:00, 42.7MB/s]
  state_dict = torch.load(cached_file, map_location="cpu")


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

100%|██████████| 712M/712M [00:17<00:00, 43.2MB/s] 
  checkpoint = torch.load(cached_file, map_location="cpu")


config.json:   0%|          | 0.00/897 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

dict.txt:   0%|          | 0.00/360k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.58G [00:00<?, ?B/s]

  return torch.load(checkpoint_file, map_location="cpu")
  checkpoint = torch.load(filepath, map_location=DEVICE)


In [None]:
def generate_answer(image: Image.Image, question: str) -> "list[str] | str":
    """Answer *question* about *image*, routing it to the text or non-text model.

    Gemini first labels the question "text" or "non-text".
    - "text": EasyOCR extracts text fragments, Gemini rewrites them into
      sentences, and ``TEXT_MODEL`` answers from that OCR text.
    - "non-text": ``NONTEXT_MODEL`` answers from the image itself.

    Returns:
        The model's ``generate`` output (a list of answer strings, one per
        question) on success, or a plain error-message string when the
        question type is unrecognised or the OCR rewriting step fails.
    """
    prompt = build_type_prompt(question)
    print(f"Prompt: {prompt}")
    type_response = CLIENT.models.generate_content(
        model=GEMINI_MODEL, contents=prompt, config=GENERATION_CONFIG
    )
    # .text can be None (e.g. a blocked reply); the label may also come back
    # quoted or with trailing punctuation, so normalise before comparing.
    question_type = (type_response.text or "").strip().strip("\"'`.").strip().lower()
    print(f"Question types: {question_type}")

    if question_type == "text":
        image_np = np.array(image)
        ocr_chunks = [chunk[1] for chunk in OCR_READER.readtext(image_np)]
        contents = [build_ocr_prompt(ocr_chunks), image]

        try:
            response = CLIENT.models.generate_content(
                model=GEMINI_MODEL, contents=contents, config=GENERATION_CONFIG
            )
            ocr_info = clean_ocr_contents(response.text)
        except Exception as exc:
            # Typically a blocked Gemini reply or unparsable output; log the real
            # cause instead of silently hiding it behind the user-facing message.
            print(f"OCR cleaning failed: {exc!r}")
            return (
                "Your image may contain sensitive or violent content, please change it!"
            )
        print(f"OCR info: {ocr_info}")

        with torch.no_grad():
            answer = TEXT_MODEL.generate([ocr_info], [question])
        print("Text model responded")
        return answer

    if question_type == "non-text":
        with torch.no_grad():
            answer = NONTEXT_MODEL.generate([image], [question])
        print("Non-text model responded")
        return answer

    return "Please try again!"

# def generate_answer(image: Image, question: str) -> str:
#     with torch.no_grad():
#         answer = NONTEXT_MODEL.generate([image], [question])
#         print(f"Non-text model responded")
#     return answer

___

In [None]:
!pip install -q fastapi uvicorn nest-asyncio pyngrok pillow

clear_output()

In [None]:
!ngrok config add-authtoken <YOUR_NGROK_TOKEN_KEY_HERE>

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml                                


In [None]:
import io
import gc
import base64
from PIL import Image
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pyngrok import ngrok
import nest_asyncio
import uvicorn

# Allow nested async event loop (required for Jupyter/Kaggle)
nest_asyncio.apply()


# Request body for POST /infer: the questions and the image they refer to.
class InferenceRequest(BaseModel):
    questions: list[str]  # question strings about the image
    image_base64: str  # base64-encoded image bytes


# Response body for POST /infer.
class InferenceResponse(BaseModel):
    answers: list[str]  # answer strings produced by generate_answer

In [None]:
app = FastAPI()

# Development CORS policy: any origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site send credentialed requests; consider allow_credentials=False or an
# explicit origin list before exposing this beyond development.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)


def decode_image(base64_string: str) -> Image.Image:
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts raw base64 as well as a data URL (``data:image/png;base64,<data>``),
    whose header is stripped before decoding.

    Raises:
        ValueError: if the payload is not decodable or not a readable image
            (the original error is chained as ``__cause__``).
    """
    payload = base64_string.strip()
    # Non-validating b64decode would silently turn a data-URL header into
    # garbage bytes, so remove it explicitly.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]

    try:
        image_data = base64.b64decode(payload)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception as exc:
        raise ValueError("Invalid base64 image input") from exc

    print("Image decoded")
    return image


@app.post("/infer", response_model=InferenceResponse)
def infer(request: InferenceRequest):
    """Answer every question in *request* about its base64-encoded image.

    Returns one answer string per question, in order. Responds with 400 for
    an empty question list or an undecodable image, and 500 for unexpected
    errors during inference. GPU/Python memory is released after each request.
    """
    if not request.questions:
        raise HTTPException(status_code=400, detail="At least one question is required")

    try:
        image = decode_image(request.image_base64)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    try:
        answers = []
        for question in request.questions:
            result = generate_answer(image, question)
            print(result)
            # generate_answer returns either a list with the model's answer or
            # a plain error-message string; normalise both to one string so
            # the response always validates as list[str].
            if isinstance(result, str):
                answers.append(result)
            else:
                answers.append(result[0] if len(result) > 0 else "")
        return InferenceResponse(answers=answers)

    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    finally:
        # Explicit memory cleanup
        del image
        gc.collect()  # Run Python garbage collection
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Free up CUDA memory


SERVER_PORT = 8000

try:
    ngrok.kill()  # drop tunnels left over from a previous run of this cell
    public_url = ngrok.connect(SERVER_PORT)
    print(public_url)
    uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT, log_level="info")
except KeyboardInterrupt:
    # Interrupting the cell is the normal way to stop the server.
    pass
except Exception as exc:
    # Previously swallowed silently; surface startup/runtime failures
    # (e.g. bad ngrok auth token, port already in use).
    print(f"Server stopped with an error: {exc!r}")
finally:
    ngrok.kill()  # close the public tunnel once the server is down

INFO:     Started server process [31]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


NgrokTunnel: "https://204d-34-141-197-202.ngrok-free.app" -> "http://localhost:8000"


INFO:     Shutting down
