# CONFIG

In [1]:
import os

from huggingface_hub import login

# Never hard-code the access token in the notebook: it leaks into version
# control and every shared copy. Read it from the HF_TOKEN environment
# variable; without one, login() with no token asks for it interactively.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    login()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import json
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import ViltProcessor, ViltForImagesAndTextClassification, ViltConfig
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from tqdm import tqdm
import numpy as np

# --- CONFIG ---
# ViLT checkpoint pretrained for masked language modeling (MLM); a new
# classification head is trained on top of it (see the MODEL section).
MODEL_NAME = "dandelin/vilt-b32-mlm"

JSON_FILE = "Super_Final_Dataset_Full.json"  # make sure this path is correct
IMG_ROOT = "images"                          # make sure this path is correct
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 5e-5
MAX_LEN = 40       # max text length in tokens; presumably the checkpoint's text position limit — verify before raising
SEED = 42          # global RNG seed for torch and numpy
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs so a re-run starts from the same random state (classifier/LoRA
# init, dropout masks, DataLoader shuffle order). The train/val/test split uses
# its own random_state. GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Using device: {DEVICE}")

# --- CLASS DISPLAY NAMES (index -> English name) ---
TOPIC_NAMES = {
    0: "Restaurant", 1: "Chocolate", 2: "Chips/Snacks", 3: "Seasoning",
    4: "Alcohol", 5: "Coffee/Tea", 6: "Soda/Juice", 7: "Cars", 8: "Electronics",
    9: "Phone/TV/Internet", 10: "Financial",
    11: "Other Service", 12: "Beauty", 13: "Healthcare", 14: "Clothing",
    15: "Games", 16: "Home Appliance", 17: "Travel",
    18: "Media", 19: "Sports", 20: "Shopping", 21: "Environment",
    22: "Animals/Pet Care", 23: "Safety", 24: "Smoking/Alcohol Abuse",
    25: "Unclear"
}

Using device: cuda


# DATASET + DATALOADER

In [3]:
# --- DATASET CLASS (adapted for ViLT) ---
class AdDataset(Dataset):
    """Image + text ad-classification dataset yielding ViLT-ready tensors.

    Each item pairs the image at ``os.path.join(img_root, key)`` with one text
    string built from the record's ``qa``, ``caption_text``, ``slogan_text``
    and ``ocr_text`` fields, and the class index
    ``label2id[str(record['topic_id'])]``.

    Args:
        full_data_dict: mapping ``key -> record`` for the whole dataset.
        keys_list: keys belonging to this split (e.g. ``"2/71762.jpg"``); each
            key is also the image path relative to ``img_root``.
        img_root: directory the image keys are relative to.
        processor: callable invoked as ``processor(image, text,
            return_tensors="pt", padding="max_length", truncation=True,
            max_length=max_len)`` (a ``ViltProcessor`` in this notebook).
        label2id: mapping ``str(topic_id) -> class index``.
        max_len: maximum text length (tokens) passed to the processor.
        image_size: side of the square every image is resized to, so that all
            samples in a batch share one pixel shape. Defaults to 384.
    """

    def __init__(self, full_data_dict, keys_list, img_root, processor, label2id, max_len=128,
                 image_size=384):
        self.full_data = full_data_dict
        self.keys = keys_list  # keys of this split, e.g. "2/71762.jpg"
        self.img_root = img_root
        self.processor = processor
        self.label2id = label2id
        self.max_len = max_len
        self.image_size = image_size

    def __len__(self):
        return len(self.keys)

    def _load_image(self, key):
        """Load ``key`` as an RGB image of size (image_size, image_size); black placeholder if unreadable."""
        size = (self.image_size, self.image_size)
        img_path = os.path.join(self.img_root, key)
        try:
            # `with` closes the underlying file handle once the pixels are copied.
            with Image.open(img_path) as img:
                # Force a fixed square size so pixel_values stack cleanly in a batch.
                return img.convert("RGB").resize(size)
        except Exception:
            # Missing/corrupt image -> black placeholder so one bad file does not
            # abort an epoch. `except Exception` (not a bare except) still lets
            # KeyboardInterrupt / SystemExit propagate.
            return Image.new('RGB', size, color='black')

    @staticmethod
    def _build_text(item):
        """Concatenate the record's text fields into a single prompt (missing fields -> '')."""
        # Plain concatenation rather than chunking. Field order: QA, caption,
        # slogan, OCR (an earlier version used QA, slogan, caption, OCR). Keep the
        # format stable: a checkpoint expects the format it was trained on.
        return (
            f"QA: {item.get('qa', '')} "
            f"Caption: {item.get('caption_text', '')} "
            f"Slogan: {item.get('slogan_text', '')} "
            f"OCR: {item.get('ocr_text', '')}"
        )

    def __getitem__(self, idx):
        key = self.keys[idx]
        item = self.full_data[key]

        image = self._load_image(key)
        text = self._build_text(item)
        label = self.label2id[str(item['topic_id'])]

        # The processor encodes image and text together for a single example.
        encoding = self.processor(
            image,
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )

        # Remove only the leading batch dim of the single-example encoding.
        # squeeze(0), not squeeze(): a bare squeeze() would also collapse other
        # size-1 dims (e.g. input_ids when max_len == 1).
        def _unbatch(name):
            return encoding[name].squeeze(0)

        return {
            'input_ids': _unbatch('input_ids'),
            'token_type_ids': _unbatch('token_type_ids'),
            'attention_mask': _unbatch('attention_mask'),
            'pixel_values': _unbatch('pixel_values'),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [4]:
# --- DATA PREPARATION ---

print("Loading Data & Processor...")
# ViltProcessor (not CLIPProcessor): prepares both text and image for ViLT.
processor = ViltProcessor.from_pretrained(MODEL_NAME)

with open(JSON_FILE, 'r', encoding='utf-8') as f:
    full_data = json.load(f)

# 1. Label mapping: topic_id (as a string) -> contiguous class index 0..N-1.
#    Scan the whole file so every topic_id that occurs gets an index.
all_unique_labels = {str(v['topic_id']) for v in full_data.values()}


def _label_sort_key(label):
    """Order numeric ids numerically ('2' < '10'), then non-numeric ids alphabetically.

    Returning a tuple keeps keys comparable when both kinds occur; a key that
    yields int for some labels and str for others makes sorted() raise TypeError.
    """
    return (0, int(label), "") if label.isdigit() else (1, 0, label)


# Deterministic order so the label -> index mapping (and hence the meaning of
# each classifier output) is identical on every run.
sorted_labels = sorted(all_unique_labels, key=_label_sort_key)
label2id = {label: i for i, label in enumerate(sorted_labels)}
id2label = {i: label for label, i in label2id.items()}

NUM_CLASSES = len(label2id)
print(f"Found {NUM_CLASSES} classes.")

# 2. Keep only samples whose image exists on disk, and record their class
#    index for stratified splitting.
all_keys_valid = []
all_labels_for_split = []

print("Scanning data for split stratification...")
for k in tqdm(list(full_data.keys())):
    if not os.path.exists(os.path.join(IMG_ROOT, k)):
        continue

    topic_id = str(full_data[k]['topic_id'])
    if topic_id in label2id:
        all_keys_valid.append(k)
        all_labels_for_split.append(label2id[topic_id])  # class index used to stratify

# --- 3-WAY SPLIT: TRAIN (80%) - VAL (10%) - TEST (10%), stratified by class ---
# Step 1: train (80%) vs temp (20%).
train_keys, temp_keys, train_labels, temp_labels = train_test_split(
    all_keys_valid, all_labels_for_split, test_size=0.2, random_state=42, stratify=all_labels_for_split
)

# Step 2: split temp in half -> val (10% of total) and test (10% of total).
val_keys, test_keys = train_test_split(
    temp_keys, test_size=0.5, random_state=42, stratify=temp_labels
)

# Sanity checks: the three splits partition the valid keys without overlap.
assert len(train_keys) + len(val_keys) + len(test_keys) == len(all_keys_valid)
assert set(train_keys).isdisjoint(val_keys)
assert set(train_keys).isdisjoint(test_keys)
assert set(val_keys).isdisjoint(test_keys)

print("Dataset Split Summary:")
print(f"   - Train set: {len(train_keys)} samples")
print(f"   - Val set:   {len(val_keys)} samples")
print(f"   - Test set:  {len(test_keys)} samples")

# Dataset instances share full_data; each gets only its split's keys.
train_dataset = AdDataset(full_data, train_keys, IMG_ROOT, processor, label2id, MAX_LEN)
val_dataset = AdDataset(full_data, val_keys, IMG_ROOT, processor, label2id, MAX_LEN)
test_dataset = AdDataset(full_data, test_keys, IMG_ROOT, processor, label2id, MAX_LEN)

# DataLoaders (num_workers=0: single-process loading).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print("Data Preparation Complete!")

Loading Data & Processor...


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Found 26 classes.
Scanning data for split stratification...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 59282/59282 [00:00<00:00, 73823.97it/s]


Dataset Split Summary:
   - Train set: 47425 samples
   - Val set:   5928 samples
   - Test set:  5929 samples
Data Preparation Complete!


# MODEL

In [5]:
def get_vilt_lora_model(num_labels, id2label, label2id, model_name=None,
                        lora_r=8, lora_alpha=16, lora_dropout=0.1):
    """Load a ViLT image+text classifier and wrap it with LoRA adapters.

    LoRA adapters are placed on the self-attention ``query``/``value``
    projections, and the newly initialised ``classifier`` head is fully
    trained (``modules_to_save``); the remaining weights are not trainable
    (see the printed trainable-parameter summary).

    Args:
        num_labels: number of output classes.
        id2label: class index -> label mapping stored in the model config.
        label2id: label -> class index mapping stored in the model config.
        model_name: checkpoint to load; ``None`` means the notebook's MODEL_NAME.
        lora_r: LoRA rank (default 8).
        lora_alpha: LoRA scaling alpha (default 16).
        lora_dropout: dropout on the LoRA path (default 0.1).

    Returns:
        The PEFT-wrapped model.
    """
    if model_name is None:
        model_name = MODEL_NAME
    print(f"Loading ViLT model: {model_name}")

    # Base model with a classification head.
    model = ViltForImagesAndTextClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
        num_images=1,                  # one image per sample (added by the author to fix a loading error)
        ignore_mismatched_sizes=True,  # required: our number of output classes differs from the checkpoint
        use_safetensors=True
    )

    peft_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=["query", "value"],  # self-attention projections
        lora_dropout=lora_dropout,
        bias="none",
        modules_to_save=["classifier"]  # train the whole new classifier head too
    )

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    return model

# TRAINER

In [6]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is a dict with ``input_ids``, ``token_type_ids``,
    ``attention_mask``, ``pixel_values`` and ``labels`` (as yielded by
    AdDataset); ``model(...)`` must return an object exposing ``.logits``.

    Returns:
        (avg_loss, acc): per-sample mean training loss and fraction of correct
        argmax predictions, both measured on the fly while weights change.
        Both are 0.0 for an empty dataloader.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc="Training")

    for batch in progress_bar:
        # Move the batch to the target device.
        labels = batch['labels'].to(device)
        inputs = {
            name: batch[name].to(device)
            for name in ('input_ids', 'token_type_ids', 'attention_mask', 'pixel_values')
        }

        optimizer.zero_grad()

        # ViLT can compute a loss itself when given labels; we take the logits
        # and apply `criterion` ourselves instead.
        logits = model(**inputs).logits
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        # `criterion` is assumed to average over the batch (reduction='mean',
        # the nn.CrossEntropyLoss default). Re-weight by batch size so a short
        # final batch does not count as much as a full one.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        total += batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()

        progress_bar.set_postfix(loss=loss.item())

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, correct / total

def evaluate(model, dataloader, criterion, device):
    """Run ``model`` over ``dataloader`` without gradients and collect metrics.

    Batches and model output follow the same contract as ``train_one_epoch``.

    Returns:
        (avg_loss, acc, all_labels, all_preds): per-sample mean loss, fraction
        of correct argmax predictions, and the true / predicted class index of
        every sample in loader order. For an empty dataloader the scalars are
        0.0 and the lists are empty.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            labels = batch['labels'].to(device)
            outputs = model(
                input_ids=batch['input_ids'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                pixel_values=batch['pixel_values'].to(device)
            )
            logits = outputs.logits

            # `criterion` is assumed to average over the batch (reduction='mean');
            # weight by batch size so a partial last batch is not over-counted.
            # This value drives checkpoint selection / early stopping.
            batch_size = labels.size(0)
            loss_sum += criterion(logits, labels).item() * batch_size

            predicted = logits.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        return 0.0, 0.0, all_labels, all_preds
    return loss_sum / total, correct / total, all_labels, all_preds

In [None]:
# 4. Init model (ViLT + LoRA)
model = get_vilt_lora_model(len(label2id), id2label, label2id)
model.to(DEVICE)

# 5. Optimizer & loss.
# Only the LoRA adapters and the classifier head require grad; hand AdamW
# exactly those parameters.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# 6. Training loop: checkpoint whenever val loss improves; stop after
# `patience` consecutive epochs without improvement.
best_val_loss = float('inf')
patience = 2
early_stop_counter = 0
save_path = "vilt_qa_cap_slo_ocr.pth"

print("\n--- START TRAINING ---")
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, DEVICE)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0
        # Saves the full state dict (frozen base weights included).
        torch.save(model.state_dict(), save_path)
        print("--> Saved Best Model!")
    else:
        early_stop_counter += 1
        print(f"‚è∏ No improvement ({early_stop_counter}/{patience})")

    if early_stop_counter >= patience:
        print("üõë Early stopping triggered!")
        break

print("Training Complete!")

# After the loop `model` holds the LAST epoch's weights, which after early
# stopping are not the checkpointed best. Reload the checkpoint written by this
# run (if any) so later uses of `model` match the saved best weights.
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load(save_path, map_location=DEVICE, weights_only=True))
    print(f"Restored best checkpoint from {save_path} (val loss {best_val_loss:.4f})")

Loading ViLT model: dandelin/vilt-b32-mlm


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Some weights of ViltForImagesAndTextClassification were not initialized from the model checkpoint at dandelin/vilt-b32-mlm and are newly initialized: ['classifier.0.bias', 'classifier.0.weight', 'classifier.1.bias', 'classifier.1.weight', 'classifier.3.bias', 'classifier.3.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 907,034 || all params: 113,114,164 || trainable%: 0.8019

--- START TRAINING ---

Epoch 1/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:47<00:00,  1.99it/s, loss=0.116]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.79it/s]


Train Loss: 1.0323 | Train Acc: 0.7222
Val Loss:   0.7647 | Val Acc:   0.7913
--> Saved Best Model!

Epoch 2/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:43<00:00,  2.00it/s, loss=0.24]  
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.79it/s]


Train Loss: 0.6929 | Train Acc: 0.8075
Val Loss:   0.6877 | Val Acc:   0.8109
--> Saved Best Model!

Epoch 3/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:43<00:00,  2.00it/s, loss=0.0474]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.78it/s]


Train Loss: 0.6084 | Train Acc: 0.8306
Val Loss:   0.6548 | Val Acc:   0.8209
--> Saved Best Model!

Epoch 4/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:42<00:00,  2.00it/s, loss=2.16]  
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.78it/s]


Train Loss: 0.5503 | Train Acc: 0.8451
Val Loss:   0.6414 | Val Acc:   0.8239
--> Saved Best Model!

Epoch 5/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:44<00:00,  2.00it/s, loss=0.00288]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:12<00:00,  2.79it/s]


Train Loss: 0.4992 | Train Acc: 0.8575
Val Loss:   0.6246 | Val Acc:   0.8295
--> Saved Best Model!

Epoch 6/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:43<00:00,  2.00it/s, loss=0.0429]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.79it/s]


Train Loss: 0.4547 | Train Acc: 0.8700
Val Loss:   0.6295 | Val Acc:   0.8268
‚è∏ No improvement (1/2)

Epoch 7/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:44<00:00,  2.00it/s, loss=0.00842]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.78it/s]


Train Loss: 0.4151 | Train Acc: 0.8796
Val Loss:   0.6227 | Val Acc:   0.8283
--> Saved Best Model!

Epoch 8/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:44<00:00,  2.00it/s, loss=0.0405]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.79it/s]


Train Loss: 0.3751 | Train Acc: 0.8909
Val Loss:   0.6173 | Val Acc:   0.8316
--> Saved Best Model!

Epoch 9/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:44<00:00,  2.00it/s, loss=0.0451]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.78it/s]


Train Loss: 0.3384 | Train Acc: 0.9021
Val Loss:   0.6202 | Val Acc:   0.8338
‚è∏ No improvement (1/2)

Epoch 10/10


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [24:44<00:00,  2.00it/s, loss=0.000579]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:13<00:00,  2.78it/s]

Train Loss: 0.3026 | Train Acc: 0.9126
Val Loss:   0.6308 | Val Acc:   0.8316
‚è∏ No improvement (2/2)
üõë Early stopping triggered!
Training Complete!





# EVALUATING

In [8]:
# --- 2. FINAL EVALUATION ON THE TEST SET (adapted for ViLT) ---
def _topic_display_name(topic_id_str):
    """Return the TOPIC_NAMES entry for a topic id string, or "Topic <id>" if none.

    Assumes topic ids in the JSON are 1-based, i.e. topic "1" -> TOPIC_NAMES
    key 0 (Restaurant), as the original mapping did — TODO confirm against
    the dataset.
    """
    try:
        topic_key = int(topic_id_str) - 1
    except (TypeError, ValueError):
        return f"Topic {topic_id_str}"
    return TOPIC_NAMES.get(topic_key, f"Topic {topic_id_str}")


def test_final_model(model, test_loader, device, label2id):
    """Evaluate ``model`` on ``test_loader`` and print accuracy plus a per-class report.

    Args:
        model: classifier whose forward(...) returns an object with ``.logits``.
        test_loader: iterable of batch dicts with the AdDataset keys.
        device: device the batches are moved to.
        label2id: the training mapping ``str(topic_id) -> class index``; its
            inverse turns model class indices back into topic ids / names.

    Returns:
        None; results are printed.
    """
    print("\n" + "="*50)
    print("ƒêANG CH·∫†Y ƒê√ÅNH GI√Å TR√äN T·∫¨P TEST (FINAL EVALUATION)")
    print("="*50)

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            labels = batch['labels'].to(device)
            outputs = model(
                input_ids=batch['input_ids'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                pixel_values=batch['pixel_values'].to(device)
            )
            # Predicted class = highest-scoring logit.
            predicted = outputs.logits.argmax(dim=1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    print(f"\n‚úÖ TEST ACCURACY: {acc*100:.2f}%")

    # Class indices in model order (0..N-1) -> original topic id -> display name.
    # label2id may hold fewer than 26 topics, so names are resolved per index.
    index_to_topic = {idx: topic for topic, idx in label2id.items()}
    class_indices = sorted(label2id.values())
    target_names = [_topic_display_name(index_to_topic[idx]) for idx in class_indices]

    print("\nüìä CHI TI·∫æT THEO T·ª™NG CLASS:")
    # labels= pins the report to every class the model knows. Without it sklearn
    # infers classes from y_true/y_pred and raises ValueError when a class is
    # absent from both (the count no longer matches target_names).
    # digits=4 prints four decimal places.
    print(classification_report(all_labels, all_preds, labels=class_indices,
                                target_names=target_names, digits=4))

    print("\n(L∆∞u √Ω: B·∫°n c√≥ th·ªÉ d√πng Confusion Matrix ƒë·ªÉ xem chi ti·∫øt nh·∫ßm l·∫´n gi·ªØa c√°c l·ªõp n·∫øu c·∫ßn)")

In [None]:
# --- 3. INFERENCE ON TEST SET ---

# Saved checkpoint path (must match the file name written during training).
SAVED_MODEL_PATH = "vilt_qa_cap_slo_ocr.pth"

print(f"Loading Best Model from {SAVED_MODEL_PATH} for Final Testing...")

# 1. Rebuild the exact architecture used for training. Requires
#    get_vilt_lora_model and label2id / id2label from the earlier cells.
final_model = get_vilt_lora_model(len(label2id), id2label, label2id)
final_model.to(DEVICE)

# 2. Load the trained weights. On failure, report and re-raise: continuing would
#    evaluate the freshly initialised (untrained) classifier head and print
#    meaningless test scores.
try:
    state_dict = torch.load(SAVED_MODEL_PATH, map_location=DEVICE, weights_only=True)
    final_model.load_state_dict(state_dict)
    print("Model weights loaded successfully!")
except Exception as e:
    print(f"Error loading weights: {e}")
    print("H√£y ki·ªÉm tra l·∫°i ƒë∆∞·ªùng d·∫´n file .pth")
    raise

# 3. Final test-set evaluation.
test_final_model(final_model, test_loader, DEVICE, label2id)

Loading Best Model from best_vilt_lora_model_loss.pth for Final Testing...
Loading ViLT model: dandelin/vilt-b32-mlm


Some weights of ViltForImagesAndTextClassification were not initialized from the model checkpoint at dandelin/vilt-b32-mlm and are newly initialized: ['classifier.0.bias', 'classifier.0.weight', 'classifier.1.bias', 'classifier.1.weight', 'classifier.3.bias', 'classifier.3.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 907,034 || all params: 113,114,164 || trainable%: 0.8019
Model weights loaded successfully!

ƒêANG CH·∫†Y ƒê√ÅNH GI√Å TR√äN T·∫¨P TEST (FINAL EVALUATION)


Testing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [02:12<00:00,  2.79it/s]


‚úÖ TEST ACCURACY: 83.22%

üìä CHI TI·∫æT THEO T·ª™NG CLASS:
                       precision    recall  f1-score   support

           Restaurant     0.8819    0.8667    0.8742       405
            Chocolate     0.8308    0.9003    0.8642       371
         Chips/Snacks     0.7486    0.7528    0.7507       178
            Seasoning     0.8889    0.7778    0.8296        72
              Alcohol     0.9223    0.9321    0.9272       280
           Coffee/Tea     0.8714    0.8714    0.8714        70
           Soda/Juice     0.8709    0.9115    0.8908       407
                 Cars     0.9627    0.9518    0.9572       705
          Electronics     0.8443    0.8672    0.8556       369
    Phone/TV/Internet     0.7015    0.6438    0.6714        73
            Financial     0.7500    0.8333    0.7895       144
        Other Service     0.3355    0.4643    0.3895       112
               Beauty     0.9083    0.9067    0.9075       579
           Healthcare     0.7246    0.5155    0.6024  


