In [None]:
from IPython.display import clear_output

In [None]:
!pip install salesforce-lavis
!pip install torch 
!pip install torchvision
!pip install transformers
!pip install peft==0.10.0
!pip install datasets
!pip install pillow
!pip install matplotlib
!pip install tabulate
!pip install underthesea
!pip install huggingface_hub
!pip install hf_xet
!pip install python-dotenv

clear_output()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import models
from peft import get_peft_model, LoraConfig, TaskType

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from lavis.models import load_model_and_preprocess
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset

from PIL import Image
from dotenv import load_dotenv
import os
from tqdm import tqdm
from huggingface_hub import hf_hub_download

load_dotenv()

In [None]:
# Project-level configuration shared by the cells below.
PROJECT_NAME = "ViInfographicsVQA"
USERNAME = "Namronaldo2004"  # Hugging Face account that owns the repository

# "<user>/<project>" repository id on the Hugging Face Hub.
HUGGINGFACE_HUB_REPO = f"{USERNAME}/{PROJECT_NAME}"

# Access token read from the environment; None when MODEL_API_KEY is not set.
REPO_ACCESS_TOKEN = os.getenv("MODEL_API_KEY")

# Checkpoints are stored in the Hub repository under a per-baseline folder.
BASELINE_NAME = "Flow1-modified"
CHECKPOINT_FILENAME = "/".join((BASELINE_NAME, "latest_checkpoint.pth"))

# Upper bound on the total number of training epochs.
NUM_EPOCHS = 20

In [None]:
# Import Hugging Face Hub utilities
from huggingface_hub.hf_api import HfFolder  # For handling authentication tokens
from huggingface_hub import HfApi  # Tools for managing repositories on Hugging Face Hub

# Save the Hugging Face authentication token.
# REPO_ACCESS_TOKEN is None when MODEL_API_KEY is missing from the environment;
# in that case nothing is written and any previously stored credentials are kept.
if REPO_ACCESS_TOKEN:
    HfFolder.save_token(REPO_ACCESS_TOKEN)
else:
    print("MODEL_API_KEY is not set; skipping Hugging Face token setup.")

In [None]:
api = HfApi()

# Default location used when the Hub repository has no checkpoint yet.
CHECKPOINT_PATH = "./checkpoints/latest_checkpoint.pth"

checkpoint_on_hub = api.file_exists(
    repo_id=HUGGINGFACE_HUB_REPO,
    filename=CHECKPOINT_FILENAME,
    repo_type="model",
)
if checkpoint_on_hub:
    # Fetch the latest checkpoint; the returned local path replaces the default.
    CHECKPOINT_PATH = hf_hub_download(
        repo_id=HUGGINGFACE_HUB_REPO,
        filename=CHECKPOINT_FILENAME,
        local_dir="./checkpoints",      # keep the file under ./checkpoints
        local_dir_use_symlinks=False,   # avoid symlinks for compatibility
    )

# Make sure the checkpoint directory exists even when nothing was downloaded.
os.makedirs("checkpoints", exist_ok=True)
print(CHECKPOINT_PATH)

In [None]:
class ViInfographicsVQADataset(IterableDataset):
    """
    Streaming wrapper around an iterable of VQA samples.

    Samples are read lazily from `hf_dataset`, one at a time, so nothing is
    preloaded into memory. Each sample must expose the keys "image",
    "question" and "answer".
    """

    def __init__(self, hf_dataset,
                 transform = None, max_instances = None):
        """
        Store the data source and the iteration options.

        Args:
            hf_dataset: Iterable of samples (e.g. a streaming Hugging Face dataset).
            transform: Optional callable applied to every image before it is yielded.
            max_instances: Optional cap on the number of samples yielded per pass;
                None means no cap.
        """
        self.hf_dataset = hf_dataset
        self.transform = transform
        self.max_instances = max_instances

    def __iter__(self):
        """
        Stream the samples in source order.

        Yields:
            A tuple (image, question, answer) where `image` has been passed
            through `transform` when one was provided.
        """
        limit = self.max_instances
        for yielded, record in enumerate(self.hf_dataset):
            # `yielded` equals the number of samples already produced.
            if limit is not None and yielded >= limit:
                return

            picture = record["image"]
            question = record["question"]
            answer = record["answer"]

            if self.transform:
                picture = self.transform(picture)

            yield picture, question, answer

In [None]:
class EfficientNetFeatureExtractor(nn.Module):
    """
    Visual feature extractor built on the convolutional trunk of a torchvision
    EfficientNet.

    The backbone is only ever run under `torch.no_grad()`, so it is used as a
    fixed feature extractor. It is therefore pinned to eval mode (see `train`),
    which keeps its normalization statistics and any train-only stochastic
    layers from changing when the surrounding model is switched to training.

    Args:
        model_name: One of "efficientnet_b0" ... "efficientnet_b7".
        target_size: Side length of the center crop fed to the backbone.
        central_fraction: Ratio between crop size and resize size; images are
            first resized to `int(target_size / central_fraction)`.
        num_tokens: Number of feature tokens produced per image (default 32).
        token_dim: Size of every feature token (default 1024).
    """

    def __init__(
        self,
        model_name: str = "efficientnet_b0",
        target_size: int = 224,
        central_fraction: float = 0.875,
        num_tokens: int = 32,
        token_dim: int = 1024,
    ):
        super(EfficientNetFeatureExtractor, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model(model_name).to(self.device)
        self.transform = self._build_transform(target_size, central_fraction)

        # Collapse the spatial map to `num_tokens` positions, then (after a
        # permute in `forward`) resample the channel axis to `token_dim`.
        self.pooling1 = nn.AdaptiveAvgPool2d((1, num_tokens))
        self.pooling2 = nn.AdaptiveAvgPool2d((1, token_dim))

        # The backbone is never trained here; keep it in inference mode.
        self.model.eval()

    def _load_model(self, model_name: str) -> nn.Module:
        """
        Build the requested EfficientNet with torchvision's default weights
        and return only its `features` trunk.

        Raises:
            ValueError: If `model_name` is not a supported variant.
        """
        model_dict = {
            "efficientnet_b0": models.efficientnet_b0,
            "efficientnet_b1": models.efficientnet_b1,
            "efficientnet_b2": models.efficientnet_b2,
            "efficientnet_b3": models.efficientnet_b3,
            "efficientnet_b4": models.efficientnet_b4,
            "efficientnet_b5": models.efficientnet_b5,
            "efficientnet_b6": models.efficientnet_b6,
            "efficientnet_b7": models.efficientnet_b7,
        }

        if model_name not in model_dict:
            raise ValueError(f"Unsupported model_name '{model_name}'. Choose from: {list(model_dict.keys())}")

        model = model_dict[model_name](weights="DEFAULT")
        return model.features  # Only use the feature extractor part

    def _build_transform(self, target_size: int, central_fraction: float) -> "transforms.Compose":
        """Return the resize -> center-crop -> tensor -> normalize pipeline."""
        resize_size = int(target_size / central_fraction)
        return transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                 std = [0.229, 0.224, 0.225]),
        ])

    def train(self, mode: bool = True):
        """
        Set the training flag on this module but keep the backbone in eval mode.

        Without this override a parent `.train()` call would also switch the
        backbone to training mode even though it is executed without gradients.
        """
        super().train(mode)
        self.model.eval()
        return self

    def freeze(self):
        """Disable gradients for every backbone parameter."""
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, images: "Sequence[Image.Image]") -> torch.Tensor:
        """
        Extract feature tokens for a batch of images.

        Args:
            images: Iterable of image objects exposing `.convert("RGB")`
                (PIL images in this notebook).

        Returns:
            Tensor of shape (batch, num_tokens, token_dim), detached from the
            backbone (it is computed under `torch.no_grad()`).
        """
        images_tensor = torch.stack([
            self.transform(image.convert("RGB")) for image in images
        ]).to(self.device)

        with torch.no_grad():
            features = self.model(images_tensor)        # (B, C, H, W)

        features = self.pooling1(features)              # (B, C, 1, num_tokens)
        features = features.permute(0, 3, 2, 1)         # (B, num_tokens, 1, C)
        features = self.pooling2(features)              # (B, num_tokens, 1, token_dim)

        # Drop the singleton axis: (B, num_tokens, token_dim).
        return features.flatten(start_dim=2)

class Blip2ViTExtractor(nn.Module):
    """
    Image feature extractor backed by the LAVIS `blip2_feature_extractor`
    model (`pretrain` variant), followed by a learned 768 -> 1024 projection.
    """

    def __init__(self):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        blip_model, vis_processors, _ = load_model_and_preprocess(
            name="blip2_feature_extractor",
            model_type="pretrain",
            is_eval=True,
            device=self.device
        )
        self.model = blip_model
        # Only the evaluation-time image processor is used.
        self.preprocess = vis_processors["eval"]

        # Linear layer projecting the 768-dim embeddings to 1024 dims.
        self.linear_proj = nn.Linear(768, 1024)

    def freeze(self):
        """Disable gradients for the BLIP-2 model (the projection stays trainable)."""
        for weight in self.model.parameters():
            weight.requires_grad = False

    def forward(self, images):
        """
        Embed a batch of images.

        Args:
            images: Iterable of image objects exposing `.convert("RGB")`.

        Returns:
            Tensor of shape (B, N, 1024): the `image_embeds` returned by the
            model in "image" mode, passed through `linear_proj`.
        """
        pixel_batch = torch.stack(
            [self.preprocess(picture.convert("RGB")) for picture in images]
        ).to(self.device)

        extracted = self.model.extract_features(
            samples={"image": pixel_batch},
            mode="image"
        )
        embeddings = extracted.image_embeds  # (B, N, 768)

        return self.linear_proj(embeddings)  # (B, N, 1024)

In [None]:
class BARTpho(nn.Module):
    """
    Thin wrapper that exposes the encoder, decoder and LM head of a Hugging
    Face seq2seq checkpoint (default: "vinai/bartpho-syllable") as separate
    sub-modules, optionally with LoRA adapters injected.

    Args:
        model_name: Checkpoint name passed to `from_pretrained`.
        device: Device all sub-modules and tokenized inputs are moved to.
        max_length: Maximum number of tokens produced by `generate`.
        use_lora: When True, wrap the model with PEFT LoRA adapters.
        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters.
        target_modules: Module names (or a single pattern string) that receive
            LoRA adapters.
    """

    def __init__(
        self,
        model_name="vinai/bartpho-syllable",
        device="cpu",
        max_length=50,
        use_lora=False,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=("q_proj", "v_proj")
    ):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.device = device
        self.max_length = max_length

        # Load base model
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Apply LoRA if needed
        if use_lora:
            if not isinstance(target_modules, str):
                target_modules = list(target_modules)
            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                bias="none",
                task_type=TaskType.SEQ_2_SEQ_LM,
                target_modules=target_modules
            )
            peft_model = get_peft_model(seq2seq_model, lora_config)
            # The adapter-injected seq2seq model sits two levels below the
            # PEFT wrapper; without LoRA the loaded model is used directly.
            seq2seq_model = peft_model.base_model.model

        # Expose the three parts separately and place all of them on `device`.
        self.encoder = seq2seq_model.model.encoder.to(device)
        self.decoder = seq2seq_model.model.decoder.to(device)
        self.lm_head = seq2seq_model.lm_head.to(device)

    def encode(self, input_texts):
        """
        Tokenize `input_texts` (padded and truncated) and run the encoder.

        Returns:
            Dict with "encoder_hidden_states" (encoder last hidden state),
            "attention_mask" and "input_ids" of the tokenized batch.
        """
        encoded = self.tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)

        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        return {
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
            "attention_mask": attention_mask,
            "input_ids": input_ids
        }

    def decode(
        self,
        answer_input_ids,
        answer_attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """
        Run the decoder over `answer_input_ids`, attending to the given encoder
        states, and project its last hidden state through the LM head.

        Returns:
            Logits with one vocabulary distribution per decoder position.
        """
        decoder_outputs = self.decoder(
            input_ids=answer_input_ids,
            attention_mask=answer_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        logits = self.lm_head(decoder_outputs.last_hidden_state)
        return logits

    def generate(self, encoder_hidden_states, encoder_attention_mask):
        """
        Greedy decoding for a whole batch.

        Every sequence starts from the tokenizer's EOS token. Decoding stops
        after `max_length` steps or as soon as every sequence has produced EOS;
        sequences that finish early are filled with the pad token.

        Returns:
            List of decoded strings (special tokens removed), one per batch row.
        """
        device = encoder_hidden_states.device
        batch_size = encoder_hidden_states.size(0)

        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = eos_id

        decoder_input_ids = torch.full(
            (batch_size, 1), eos_id, dtype=torch.long, device=device
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        with torch.no_grad():
            for _ in range(self.max_length):
                logits = self.decode(
                    answer_input_ids=decoder_input_ids,
                    answer_attention_mask=torch.ones_like(decoder_input_ids),
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask
                )

                next_token = logits[:, -1, :].argmax(-1)
                # Rows that already emitted EOS only receive padding.
                next_token = next_token.masked_fill(finished, pad_id)
                decoder_input_ids = torch.cat(
                    [decoder_input_ids, next_token.unsqueeze(-1)], dim=-1
                )

                finished = finished | (next_token == eos_id)
                if bool(finished.all()):
                    break

        return self.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)

    def freeze_encoder(self, layers_to_freeze=None):
        """
        Disable gradients for the encoder.

        Args:
            layers_to_freeze: Collection of encoder layer indices to freeze;
                None freezes every encoder parameter.
        """
        if layers_to_freeze is None:
            for param in self.encoder.parameters():
                param.requires_grad = False
        else:
            for idx, layer in enumerate(self.encoder.layers):
                if idx in layers_to_freeze:
                    for param in layer.parameters():
                        param.requires_grad = False

    def unfreeze_encoder(self):
        """Enable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = True

In [None]:
class PositionWiseFeedForward(nn.Module):
    """
    Two-layer feed-forward block (Linear -> GELU -> Linear) with dropout after
    each layer, a residual connection and a final LayerNorm.
    """

    def __init__(self, d_model = 768, d_ff = 2048, dropout = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return LayerNorm(x + FFN(x)); the output has the same shape as `x`."""
        hidden = F.gelu(self.fc1(x))
        hidden = self.dropout1(hidden)
        projected = self.dropout2(self.fc2(hidden))
        residual = x + projected
        return self.norm(residual)

class MultiHeadAttention(nn.Module):
    """
    Batch-first multi-head attention followed by dropout, a residual
    connection onto the queries and a LayerNorm.
    """

    def __init__(self, d_model = 768, num_heads = 8, dropout = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim = d_model,
            num_heads = num_heads,
            dropout = dropout,
            batch_first = True,
        )
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, attention_mask=None):
        """
        Args:
            queries, keys, values: Batch-first input tensors.
            attention_mask: Optional mask over the keys where truthy entries
                are positions to attend to and falsy entries are ignored.

        Returns:
            LayerNorm(queries + Dropout(attention output)).
        """
        # nn.MultiheadAttention expects the opposite convention: True = ignore.
        key_padding_mask = None if attention_mask is None else ~attention_mask.bool()
        attended, _ = self.attn(queries, keys, values, key_padding_mask = key_padding_mask)
        return self.norm(queries + self.dropout(attended))

class EncoderLayer(nn.Module):
    """One attention block followed by one position-wise feed-forward block."""

    def __init__(self, d_model = 768, num_heads = 8, d_ff = 2048, dropout = 0.1):
        super().__init__()
        self.mhatt = MultiHeadAttention(d_model, num_heads, dropout)
        self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)

    def forward(self, queries, keys, values, attention_mask = None):
        """Attend from `queries` to `keys`/`values`, then apply the feed-forward block."""
        attended = self.mhatt(queries, keys, values, attention_mask)
        return self.pwff(attended)

class BiDirectionalCrossAttention(nn.Module):
    """
    Fuses vision and text token sequences with stacked cross- and
    self-attention layers.

    Each of the `num_layers` rounds applies, in order: vision -> text
    cross-attention, text -> vision cross-attention, vision self-attention and
    text self-attention. The two resulting sequences are concatenated along
    the token axis.
    """

    def __init__(self, d_model = 1024, num_heads = 8, d_ff = 2048, dropout = 0.1, num_layers = 3, max_len = 1028):
        super().__init__()

        # Learned absolute position embeddings, one table per modality.
        self.vision_pos_embed = nn.Embedding(max_len, d_model)
        self.text_pos_embed = nn.Embedding(max_len, d_model)

        self.vision_norm = nn.LayerNorm(d_model)
        self.text_norm = nn.LayerNorm(d_model)

        self.d_model = d_model  # D = 1024 by default
        # A separate text projection used to live here; it was removed because
        # it is no longer needed.

        def build_stack():
            return nn.ModuleList(
                [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
            )

        self.vision_language_attn_layers = build_stack()
        self.language_vision_attn_layers = build_stack()
        self.vision_self_attn_layers = build_stack()
        self.language_self_attn_layers = build_stack()

    def forward(self, vision_feats, vision_mask, text_feats, text_mask):
        """
        Args:
            vision_feats: Vision tokens, shape (B, V, D).
            vision_mask: Mask over vision tokens (truthy = keep).
            text_feats: Text tokens, shape (B, T, D).
            text_mask: Mask over text tokens (truthy = keep).

        Returns:
            Fused tokens of shape (B, V + T, D), vision tokens first.
        """
        batch_size, v_len, _ = vision_feats.size()
        _, t_len, _ = text_feats.size()

        # Position ids 0..len-1 for every row of the batch.
        vision_positions = torch.arange(v_len, device=vision_feats.device).unsqueeze(0).expand(batch_size, -1)
        text_positions = torch.arange(t_len, device=text_feats.device).unsqueeze(0).expand(batch_size, -1)

        vision_feats = self.vision_norm(vision_feats + self.vision_pos_embed(vision_positions))
        text_feats = self.text_norm(text_feats + self.text_pos_embed(text_positions))

        layer_rounds = zip(
            self.vision_language_attn_layers,
            self.language_vision_attn_layers,
            self.vision_self_attn_layers,
            self.language_self_attn_layers,
        )
        for vision_to_text, text_to_vision, vision_self, text_self in layer_rounds:
            # Cross-attention: text attends to the already-updated vision tokens.
            vision_feats = vision_to_text(vision_feats, text_feats, text_feats, text_mask)
            text_feats = text_to_vision(text_feats, vision_feats, vision_feats, vision_mask)

            # Self-attention within each modality.
            vision_feats = vision_self(vision_feats, vision_feats, vision_feats, vision_mask)
            text_feats = text_self(text_feats, text_feats, text_feats, text_mask)

        return torch.cat([vision_feats, text_feats], dim=1)  # (B, V+T, D)

In [None]:
class NonTextModel(nn.Module):
    """
    VQA model that fuses local (EfficientNet) and global (BLIP-2) image
    features with the BARTpho-encoded question, then decodes the answer with
    the BARTpho decoder.
    """

    def __init__(self):
        super(NonTextModel, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.local_visual_extractor = EfficientNetFeatureExtractor(model_name='efficientnet_b7').to(self.device)
        self.global_visual_extractor = Blip2ViTExtractor().to(self.device)

        self.bart_pho = BARTpho(device = self.device, use_lora = True)  # pass the device explicitly
        self.encoder = BiDirectionalCrossAttention().to(self.device)

    def _fuse_inputs(self, images, questions):
        """
        Encode images and questions and fuse them with the cross-attention encoder.

        Returns:
            (encoder_output, encoder_attention_mask): the fused token sequence
            and the matching mask (all vision tokens kept, followed by the
            question attention mask).
        """
        local_features = self.local_visual_extractor(images)
        global_features = self.global_visual_extractor(images)
        vision_feats = torch.cat([local_features, global_features], dim=1).to(self.device)
        # Every vision token is valid, so its mask is all ones.
        vision_mask = torch.ones(vision_feats.size()[:-1], dtype=torch.bool).to(self.device)

        text_encoding = self.bart_pho.encode(questions)
        text_feats = text_encoding["encoder_hidden_states"]
        question_attention_mask = text_encoding["attention_mask"]

        encoder_output = self.encoder(vision_feats, vision_mask, text_feats, question_attention_mask)
        encoder_attention_mask = torch.cat([vision_mask, question_attention_mask], dim=1)
        return encoder_output, encoder_attention_mask

    def forward(self, images, questions, answers):
        """
        Teacher-forced forward pass.

        The decoder input is the tokenized answer shifted one position to the
        right, starting with the tokenizer's EOS token (the same start token
        `BARTpho.generate` uses). As a result the logits at position t are the
        prediction for answer token t given only the tokens before it, so they
        line up one-to-one with the unshifted tokenized answers as labels.

        Returns:
            Logits of shape (batch, answer_length, vocab_size).
        """
        encoder_output, encoder_attention_mask = self._fuse_inputs(images, questions)

        tokenizer = self.bart_pho.tokenizer
        answer_encoded = tokenizer(
            answers, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        answer_input_ids = answer_encoded["input_ids"]
        answer_attention_mask = answer_encoded["attention_mask"]

        # Shift right: [start, y_0, ..., y_{T-2}] is used to predict [y_0, ..., y_{T-1}].
        start_tokens = torch.full_like(answer_input_ids[:, :1], tokenizer.eos_token_id)
        decoder_input_ids = torch.cat([start_tokens, answer_input_ids[:, :-1]], dim=1)
        decoder_attention_mask = torch.cat(
            [torch.ones_like(answer_attention_mask[:, :1]), answer_attention_mask[:, :-1]],
            dim=1,
        )

        logits = self.bart_pho.decode(
            answer_input_ids=decoder_input_ids,
            answer_attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_attention_mask
        )

        return logits

    def generate(self, images, questions):
        """
        Generate answers for a batch of (image, question) pairs.

        Runs without gradients and with every sub-module temporarily in eval
        mode (so dropout is inactive); the previous train/eval state of each
        sub-module is restored afterwards.

        Returns:
            Whatever `BARTpho.generate` returns for the fused encoder states.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                encoder_output, encoder_attention_mask = self._fuse_inputs(images, questions)
                encoder_output = encoder_output.to(self.device)
                encoder_attention_mask = encoder_attention_mask.to(self.device)
                return self.bart_pho.generate(encoder_output, encoder_attention_mask)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

In [None]:
class VQATrainer:
    """
    Minimal training loop for the VQA model: runs epochs over `train_loader`,
    tracks the average loss per epoch, saves a checkpoint after every epoch
    and uploads it to the Hugging Face Hub repository.
    """

    def __init__(self, model, train_loader, criterion, optimizer, lr_scheduler):
        """
        Args:
            model: Model called as `model(images, questions, answers)`; it must
                expose `bart_pho.tokenizer`.
            train_loader: Iterable of (images, questions, answers) batches.
            criterion: Loss taking (flattened logits, flattened label ids).
            optimizer: Optimizer stepped once per batch.
            lr_scheduler: Scheduler stepped once per epoch.
        """
        self.model = model
        self.train_loader = train_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Store training loss values for monitoring
        self.train_losses = []

    def train(self, start_epoch = 0, num_epochs = NUM_EPOCHS):
        """
        Train for up to `num_epochs` epochs starting at `start_epoch`, never
        going past the global NUM_EPOCHS limit.

        Raises:
            RuntimeError: If the loader yields no batch in an epoch (the
                average loss would otherwise be undefined).
        """
        self.model.train()

        final_epoch = min(start_epoch + num_epochs, NUM_EPOCHS)
        hub_api = HfApi()  # one client is enough for all uploads
        tokenizer = self.model.bart_pho.tokenizer

        for epoch in range(start_epoch, final_epoch):
            total_loss = 0.0
            num_batches = 0
            pbar = tqdm(self.train_loader, desc=f"[Epoch {epoch+1}/{NUM_EPOCHS}] Training", leave=False)

            for images, questions, answers in pbar:
                # Forward pass
                logits = self.model(images, questions, answers)

                # Labels are the token ids of the tokenized answers.
                answer_encoded = tokenizer(
                    answers, return_tensors="pt", padding=True, truncation=True
                )
                labels = answer_encoded["input_ids"].to(self.device)

                loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                pbar.set_postfix(loss = loss.item())
                num_batches += 1

            if num_batches == 0:
                # Fail before touching the scheduler or the checkpoint.
                raise RuntimeError(
                    f"No training batches were produced in epoch {epoch + 1}; "
                    "check the training data loader."
                )

            self.lr_scheduler.step()
            avg_loss = total_loss / num_batches
            print(f"[Epoch {epoch + 1}] Train Loss: {avg_loss:.4f}")

            # Keep the loss history so it is stored in the checkpoint.
            self.train_losses.append(avg_loss)

            # Persist locally first, then push the same file to the Hub.
            self.save_checkpoint(epoch, CHECKPOINT_PATH)

            hub_api.upload_file(
                path_or_fileobj = CHECKPOINT_PATH,               # local file
                path_in_repo = CHECKPOINT_FILENAME,              # destination path inside the repo
                repo_id = HUGGINGFACE_HUB_REPO,                  # target repository
                repo_type = "model",
                commit_message = f"Completed training {BASELINE_NAME} until epoch {epoch + 1}!"
            )

        print("Completed training!")

    def save_checkpoint(self, epoch, filepath):
        """
        Save model, optimizer, scheduler and loss history to `filepath`.

        The stored 'epoch' is `epoch + 1`, i.e. the epoch to resume from. The
        file is written to a temporary path and then moved into place, so an
        interrupted save never destroys the previous checkpoint.
        """
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.lr_scheduler.state_dict(),
            'losses': self.train_losses
        }

        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

        tmp_path = f"{filepath}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, filepath)  # replaces any existing checkpoint in one step
        print(f"Checkpoint saved at {filepath}")

    def load_checkpoint(self, filepath):
        """
        Restore model, optimizer, scheduler and loss history from `filepath`.

        Returns:
            The epoch index to resume from, or 0 when no checkpoint file exists.
        """
        if os.path.isfile(filepath):
            checkpoint = torch.load(filepath, map_location = self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch']
            self.train_losses = checkpoint['losses']

            print(f"Loaded checkpoint from {filepath}, resuming from epoch {start_epoch + 1}")
            return start_epoch
        else:
            print(f"No checkpoint found at {filepath}, starting from scratch")
            return 0

In [None]:
# Stream the training split from the Hub instead of downloading it up front.
train_hf_dataset = load_dataset(
    "Namronaldo2004/ViInfographicsVQA",
    split="train",
    streaming=True,
)
# No image transform and no sample cap: every streamed sample is used as-is.
train_dataset = ViInfographicsVQADataset(train_hf_dataset)

def custom_collate_fn(batch):
    """
    Regroup a list of (image, question, answer) samples into three lists.

    Nothing is stacked into tensors here; the batch stays as plain Python
    lists of images, questions and answers.
    """
    images, questions, answers = (list(column) for column in zip(*batch))
    return images, questions, answers
    
# Batches of 16 samples, delivered in streaming order as lists.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

In [None]:
# Build the model
VQA_model = NonTextModel()

# Loss function: label positions equal to the pad token id are ignored
criterion = nn.CrossEntropyLoss(
    ignore_index=VQA_model.bart_pho.tokenizer.pad_token_id
)

# Optimizer and learning-rate scheduler
# NOTE(review): lr = 1e-8 is unusually small for AdamW — confirm it is intended.
optimizer = optim.AdamW(VQA_model.parameters(), lr=1e-8)
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Create the trainer and resume from the checkpoint when one is available
VQA_trainer = VQATrainer(
    VQA_model,
    train_dataloader,
    criterion,
    optimizer,
    lr_scheduler,
)
start_epoch = VQA_trainer.load_checkpoint(CHECKPOINT_PATH)

In [None]:
# Quick sanity check: generate an answer for the first streamed sample.
sample = next(iter(train_hf_dataset))
image, question = sample["image"], sample["question"]

VQA_model.generate([image], [question])

In [None]:
# Continue from the restored epoch and train for at most two more epochs.
VQA_trainer.train(
    start_epoch=start_epoch,
    num_epochs=2,
)