# CONFIG

In [1]:
import json
import os
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from peft import LoraConfig, get_peft_model
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from tqdm import tqdm  # progress-bar library

# --- CONFIGURATION ---
MODEL_NAME = "openai/clip-vit-base-patch32"
JSON_FILE = "Super_Final_Dataset_Full.json"  # adjust to your own path
IMG_ROOT = "images"                          # adjust to your own path
BATCH_SIZE = 16        # lower to 8 if VRAM < 8GB
EPOCHS = 10            # number of training epochs
LEARNING_RATE = 1e-4   # learning rate for LoRA (usually a bit higher than full fine-tuning)
NUM_CLASSES = 26
MAX_CHUNKS = 4

# Select the compute device, preferring the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- CLASS NAME TABLE (zero-based class index -> display name) ---
TOPIC_NAMES = dict(enumerate([
    "Restaurant",
    "Chocolate",
    "Chips/Snacks",
    "Seasoning",
    "Alcohol",
    "Coffee/Tea",
    "Soda/Juice",
    "Cars",
    "Electronics",
    "Phone/TV/Internet",
    "Financial",
    "Other Service",
    "Beauty",
    "Healthcare",
    "Clothing",
    "Games",
    "Home Appliance",
    "Travel",
    "Media",
    "Sports",
    "Shopping",
    "Environment",
    "Animals/Pet Care",
    "Safety",
    "Smoking/Alcohol Abuse",
    "Unclear",
]))

  from .autonotebook import tqdm as notebook_tqdm


# DATASET + DATALOADER

In [2]:
# --- 1. DATASET CLASS (updated to support train/val/test splits) ---
class AdDataset(Dataset):
    """Map-style dataset pairing one ad image with its concatenated text fields.

    Samples are looked up by image key in ``json_data``; only the keys listed
    in ``keys_list`` are exposed, so a single loaded JSON dict can back
    several splits.

    Args:
        json_data: Mapping from image key to a per-image record. The record is
            read with ``.get`` for "qa", "caption_text", "slogan_text",
            "ocr_text" and "topic_id".
        keys_list: Image keys that belong to this split.
        img_root_dir: Directory that each image key is joined onto.
        processor: Callable invoked with ``text=``, ``images=`` and tokenizer
            keyword arguments; its result must provide batched "input_ids",
            "attention_mask" and "pixel_values" tensors.
        max_chunks: Number of ``CHUNK_SIZE``-token text windows per sample.
    """

    # Width of one text window.
    # presumably the text encoder's context length - TODO confirm
    CHUNK_SIZE = 77

    def __init__(self, json_data, keys_list, img_root_dir, processor, max_chunks=4):
        self.data = json_data
        self.keys = keys_list  # only the keys of the split this dataset serves
        self.img_root_dir = img_root_dir
        self.processor = processor
        self.max_chunks = max_chunks

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.keys)

    @staticmethod
    def _fit_length(seq, length):
        """Truncate or zero-pad the 1-D tensor ``seq`` to exactly ``length`` items."""
        seq = seq[:length]
        missing = length - seq.shape[0]
        if missing > 0:
            seq = torch.cat([seq, seq.new_zeros(missing)])
        return seq

    def __getitem__(self, idx):
        """Build the model inputs for the sample at position ``idx``.

        Returns:
            A dict with:
              * "pixel_values": first entry of the processor's "pixel_values".
              * "input_ids": tensor of shape ``(max_chunks, CHUNK_SIZE)``.
              * "attention_mask": tensor of shape ``(max_chunks, CHUNK_SIZE)``.
              * "labels": scalar long tensor holding ``int(topic_id) - 1``.

        Raises:
            TypeError / ValueError: if the record's "topic_id" is missing or
                not convertible to ``int``.
        """
        img_key = self.keys[idx]
        item = self.data[img_key]

        # Image: an unreadable file is replaced by a black placeholder so one
        # bad image does not abort a long run. Only ``Exception`` is caught
        # (not KeyboardInterrupt/SystemExit), and the substitution is reported
        # so a wrong image root cannot go unnoticed.
        img_path = os.path.join(self.img_root_dir, img_key)
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as exc:
            warnings.warn(
                f"Could not load image {img_path!r} ({exc}); "
                "substituting a black 224x224 placeholder.",
                RuntimeWarning,
                stacklevel=2,
            )
            image = Image.new('RGB', (224, 224), color='black')

        # Text: join the non-empty fields with single spaces.
        # (The field order was originally: slogan - qa - caption - ocr.)
        text_parts = [
            item.get("qa", ""),
            item.get("caption_text", ""),
            item.get("slogan_text", ""),
            item.get("ocr_text", ""),
        ]
        full_text = " ".join([t for t in text_parts if t])

        # Tokenize once as a single long sequence, then cut it into windows.
        total_len = self.CHUNK_SIZE * self.max_chunks
        inputs = self.processor(
            text=full_text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=total_len,
        )

        # NOTE(review): because the text is tokenized once and then sliced,
        # any begin/end-of-text markers the tokenizer adds presumably appear
        # only in the first / last non-empty window - confirm the text
        # encoder's pooling behaves sensibly for windows without them.
        long_input_ids = self._fit_length(inputs['input_ids'][0], total_len)
        long_attention_mask = self._fit_length(inputs['attention_mask'][0], total_len)

        input_ids = long_input_ids.reshape(self.max_chunks, self.CHUNK_SIZE)
        attention_mask = long_attention_mask.reshape(self.max_chunks, self.CHUNK_SIZE)
        pixel_values = inputs['pixel_values'][0]

        # Shift the stored topic id down by one to get the class index.
        # presumably topic ids are 1-based in the JSON - verify against the data
        label_str = item.get("topic_id")
        label = int(label_str) - 1

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(label, dtype=torch.long)
        }

In [3]:
print(f"Using device: {device}")

# Load the data and split it
print("Loading Data...")
processor = CLIPProcessor.from_pretrained(MODEL_NAME, use_fast=True)
with open(JSON_FILE, 'r', encoding='utf-8') as f:
    full_data = json.load(f)

all_keys = list(full_data.keys())

# Collect the labels so the split can be stratified (keeps the class
# proportions the same in every subset).
all_labels_for_split = []
valid_keys = []

print("Scanning labels for split stratification...")
for k in tqdm(all_keys):
    # Skip records whose topic_id is missing or not an integer; only those
    # specific failures are ignored so unrelated errors still surface.
    try:
        lbl = int(full_data[k]['topic_id']) - 1
    except (KeyError, TypeError, ValueError):
        continue
    # Keep only labels that fall inside the known class range.
    if 0 <= lbl < NUM_CLASSES:
        all_labels_for_split.append(lbl)
        valid_keys.append(k)

# --- THREE-WAY SPLIT: TRAIN (80%) - VAL (10%) - TEST (10%) ---
# Step 1: split into Train (80%) and Temp (20%)
train_keys, temp_keys, train_labels, temp_labels = train_test_split(
    valid_keys, all_labels_for_split, test_size=0.2, random_state=42, stratify=all_labels_for_split
)

# Step 2: split Temp into Val (50% of Temp = 10% of total) and Test (50% of Temp = 10% of total)
val_keys, test_keys = train_test_split(
    temp_keys, test_size=0.5, random_state=42, stratify=temp_labels
)

print("Dataset Split Summary:")
print(f"   - Train set: {len(train_keys)} samples")
print(f"   - Val set:   {len(val_keys)} samples")
print(f"   - Test set:  {len(test_keys)} samples")

train_dataset = AdDataset(full_data, train_keys, IMG_ROOT, processor, MAX_CHUNKS)
val_dataset = AdDataset(full_data, val_keys, IMG_ROOT, processor, MAX_CHUNKS)
test_dataset = AdDataset(full_data, test_keys, IMG_ROOT, processor, MAX_CHUNKS)  # dataset for the test split

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)  # loader for the test split

Using device: cuda
Loading Data...
Scanning labels for split stratification...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 59282/59282 [00:00<00:00, 1516212.56it/s]

Dataset Split Summary:
   - Train set: 47425 samples
   - Val set:   5928 samples
   - Test set:  5929 samples





# MODEL

In [4]:
# --- 2. MODEL CLASS (LoRA configuration kept as requested) ---
class MultimodalCLIPClassifier(nn.Module):
    """CLIP model wrapped with LoRA adapters plus a linear classification head.

    The head consumes the projected image embedding concatenated with the
    chunk-averaged projected text embedding, hence its input width of
    ``2 * projection_dim``.

    Args:
        num_classes: Number of output logits.
        base_model_name: Checkpoint name passed to ``CLIPModel.from_pretrained``.
    """

    def __init__(self, num_classes, base_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        backbone = CLIPModel.from_pretrained(base_model_name, use_safetensors=True)

        # LoRA settings (parameters kept exactly as configured).
        lora_settings = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.1,
            bias="none",
        )
        self.clip = get_peft_model(backbone, lora_settings)

        fused_width = self.clip.config.projection_dim * 2
        self.classifier = nn.Linear(fused_width, num_classes)

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        """Compute logits, and the cross-entropy loss when ``labels`` is given.

        Args:
            pixel_values: Image batch forwarded to the vision model.
            input_ids: Token ids of shape ``(batch, num_chunks, seq_len)``.
            attention_mask: Mask viewed to the same flattened layout as
                ``input_ids``.
            labels: Optional class indices for the loss.

        Returns:
            ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is ``None``.
        """
        batch_size, num_chunks, seq_len = input_ids.shape
        clip_core = self.clip.base_model.model

        # Image branch: element 1 of the vision output goes through the
        # visual projection.  # presumably the pooled output - TODO confirm
        vision_out = clip_core.vision_model(pixel_values=pixel_values)
        image_embeds = clip_core.visual_projection(vision_out[1])

        # Text branch: fold the chunk axis into the batch axis so every chunk
        # is encoded as an independent sequence.
        text_out = clip_core.text_model(
            input_ids=input_ids.view(-1, seq_len),
            attention_mask=attention_mask.view(-1, seq_len),
        )
        per_chunk_embeds = clip_core.text_projection(text_out[1])

        # Restore the chunk axis and average over it.
        # NOTE(review): every chunk gets equal weight, including chunks whose
        # attention mask is all zeros - confirm that is intended.
        text_embeds = per_chunk_embeds.view(batch_size, num_chunks, -1).mean(dim=1)

        fused = torch.cat((image_embeds, text_embeds), dim=1)
        logits = self.classifier(fused)

        if labels is None:
            return None, logits
        return nn.functional.cross_entropy(logits, labels), logits

# TRAINER

In [5]:
# --- 3. TRAINING ENGINE ---
def train_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: Callable as ``model(pixel_values, input_ids, attention_mask,
            labels)`` returning ``(loss, logits)``.
        dataloader: Iterable of batch dicts with "pixel_values", "input_ids",
            "attention_mask" and "labels" tensors.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy in percent over all samples seen.
    """
    batch_fields = ('pixel_values', 'input_ids', 'attention_mask', 'labels')

    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(dataloader, leave=True)
    for batch in progress:
        pixel_values, input_ids, attention_mask, labels = (
            batch[name].to(device) for name in batch_fields
        )

        optimizer.zero_grad()

        # Forward and backward pass.
        loss, logits = model(pixel_values, input_ids, attention_mask, labels)
        loss.backward()

        # Gradient clipping guards against exploding gradients (which lead to NaN).
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Running statistics for the epoch summary.
        batch_loss = loss.item()
        loss_sum += batch_loss
        predicted = logits.detach().argmax(dim=1)
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()
        progress.set_description(f"Train Loss: {batch_loss:.4f}")

    return loss_sum / len(dataloader), 100 * n_correct / n_seen

def evaluate(model, dataloader, device):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Callable as ``model(pixel_values, input_ids, attention_mask,
            labels)`` returning ``(loss, logits)``.
        dataloader: Iterable of batch dicts with "pixel_values", "input_ids",
            "attention_mask" and "labels" tensors.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy in percent over all samples seen.
    """
    batch_fields = ('pixel_values', 'input_ids', 'attention_mask', 'labels')

    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            pixel_values, input_ids, attention_mask, labels = (
                batch[name].to(device) for name in batch_fields
            )

            loss, logits = model(pixel_values, input_ids, attention_mask, labels)

            loss_sum += loss.item()
            predicted = logits.argmax(dim=1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return loss_sum / len(dataloader), 100 * n_correct / n_seen

In [None]:
# Build the model and report how many parameters LoRA leaves trainable.
print("Initializing Model...")
model = MultimodalCLIPClassifier(num_classes=NUM_CLASSES).to(device)
model.clip.print_trainable_parameters()

# Optimizer: only parameters with requires_grad=True are optimized.
optimizer = optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=LEARNING_RATE,
)

# Early-stopping state: stop after `patience` consecutive epochs without a
# new best validation loss.
best_val_loss = float("inf")
patience = 2
early_stop_counter = 0

print("\n--- START TRAINING ---")
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, device)

    print(f"Result - Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Result - Val Loss: {val_loss:.4f}   | Val Acc: {val_acc:.2f}%")

    improved = val_loss < best_val_loss
    if improved:
        # New best validation loss: checkpoint the weights and reset the counter.
        best_val_loss = val_loss
        early_stop_counter = 0
        torch.save(model.state_dict(), "clip_qa_cap_slo_ocr.pth")
        print("--> Saved Best Model!")
    else:
        early_stop_counter += 1
        print(f"‚è∏ No improvement ({early_stop_counter}/{patience})")

    # Early stopping
    if early_stop_counter >= patience:
        print("üõë Early stopping triggered!")
        break

print("Training Completed.")

Initializing Model...
trainable params: 491,520 || all params: 151,768,833 || trainable%: 0.3239

--- START TRAINING ---

Epoch 1/10


Train Loss: 3.3928: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [35:41<00:00,  1.38it/s]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [03:15<00:00,  1.89it/s]


Result - Train Loss: 1.2338 | Train Acc: 68.60%
Result - Val Loss: 0.7585   | Val Acc: 80.28%
--> Saved Best Model!

Epoch 2/10


Train Loss: 0.5145: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [23:39<00:00,  2.09it/s]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [01:49<00:00,  3.39it/s]


Result - Train Loss: 0.6207 | Train Acc: 83.41%
Result - Val Loss: 0.6044   | Val Acc: 83.11%
--> Saved Best Model!

Epoch 3/10


Train Loss: 3.1314: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [23:41<00:00,  2.09it/s]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [01:49<00:00,  3.38it/s]


Result - Train Loss: 0.4630 | Train Acc: 87.60%
Result - Val Loss: 0.5688   | Val Acc: 83.50%
--> Saved Best Model!

Epoch 4/10


Train Loss: 0.0010: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [23:43<00:00,  2.08it/s]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [01:50<00:00,  3.36it/s]


Result - Train Loss: 0.3554 | Train Acc: 90.28%
Result - Val Loss: 0.5608   | Val Acc: 83.94%
--> Saved Best Model!

Epoch 5/10


Train Loss: 0.0014: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [23:42<00:00,  2.08it/s]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [01:50<00:00,  3.34it/s]


Result - Train Loss: 0.2709 | Train Acc: 92.71%
Result - Val Loss: 0.5853   | Val Acc: 83.96%
‚è∏ No improvement (1/2)

Epoch 6/10


Train Loss: 2.1124: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2965/2965 [23:43<00:00,  2.08it/s]
Evaluating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [01:50<00:00,  3.37it/s]

Result - Train Loss: 0.2044 | Train Acc: 94.61%
Result - Val Loss: 0.6127   | Val Acc: 83.79%
‚è∏ No improvement (2/2)
üõë Early stopping triggered!
Training Completed.





# EVALUATING

In [8]:
# --- 4. FINAL EVALUATION ON THE TEST SET ---
def test_final_model(model, test_loader, device):
    """Evaluate ``model`` on the held-out test set and print a report.

    Prints the overall accuracy followed by a per-class precision / recall /
    F1 table with one row for each of the ``NUM_CLASSES`` classes.

    Args:
        model: Callable as ``model(pixel_values, input_ids, attention_mask,
            labels)`` returning ``(loss, logits)``.
        test_loader: Iterable of batch dicts with "pixel_values", "input_ids",
            "attention_mask" and "labels" tensors.
        device: Device the batch tensors are moved to.

    Returns:
        None. All results are printed.
    """
    print("\n" + "="*50)
    print("ƒêANG CH·∫†Y ƒê√ÅNH GI√Å TR√äN T·∫¨P TEST (FINAL EVALUATION)")
    print("="*50)

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            _, logits = model(pixel_values, input_ids, attention_mask, labels)

            predicted = logits.argmax(dim=1)

            # Collect plain Python ints for the sklearn metrics.
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Compute the metrics
    acc = accuracy_score(all_labels, all_preds)
    print(f"\n‚úÖ TEST ACCURACY: {acc*100:.2f}%")

    # Class ids and display names for the report. Passing the full id list
    # explicitly keeps `target_names` aligned even when some class never
    # appears in the test labels or predictions; without it sklearn infers
    # the labels from the data and rejects the mismatched name list.
    class_ids = list(range(NUM_CLASSES))
    target_names = [TOPIC_NAMES[i] for i in class_ids]

    print("\nüìä CHI TI·∫æT THEO T·ª™NG CLASS:")
    print(classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=target_names,
        digits=4,
        zero_division=0,  # classes with no samples/predictions score 0 without a warning
    ))

    print("\n(L∆∞u √Ω: B·∫°n c√≥ th·ªÉ d√πng Confusion Matrix ƒë·ªÉ xem chi ti·∫øt nh·∫ßm l·∫´n gi·ªØa c√°c l·ªõp n·∫øu c·∫ßn)")

In [None]:
# --- INFERENCE ON TEST SET ---
# Once training has finished, reload the best checkpoint for the final test.
print("Loading Best Model for Final Testing...")

# Start from a freshly initialised model, then restore the saved weights.
final_model = MultimodalCLIPClassifier(num_classes=NUM_CLASSES).to(device)
best_state = torch.load("clip_qa_cap_slo_ocr.pth", map_location=device, weights_only=True)
final_model.load_state_dict(best_state)

# Run the test-set evaluation.
test_final_model(final_model, test_loader, device)

Loading Best Model for Final Testing...

ƒêANG CH·∫†Y ƒê√ÅNH GI√Å TR√äN T·∫¨P TEST (FINAL EVALUATION)


Testing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [03:18<00:00,  1.86it/s]


‚úÖ TEST ACCURACY: 84.36%

üìä CHI TI·∫æT THEO T·ª™NG CLASS:
                       precision    recall  f1-score   support

           Restaurant     0.8988    0.8988    0.8988       405
            Chocolate     0.8469    0.8949    0.8702       371
         Chips/Snacks     0.7380    0.7753    0.7562       178
            Seasoning     0.9200    0.6389    0.7541        72
              Alcohol     0.9355    0.9321    0.9338       280
           Coffee/Tea     0.7794    0.7571    0.7681        70
           Soda/Juice     0.8828    0.9435    0.9121       407
                 Cars     0.9563    0.9617    0.9590       705
          Electronics     0.8266    0.8916    0.8579       369
    Phone/TV/Internet     0.8088    0.7534    0.7801        73
            Financial     0.7885    0.8542    0.8200       144
        Other Service     0.5733    0.3839    0.4599       112
               Beauty     0.8997    0.9447    0.9217       579
           Healthcare     0.7556    0.7010    0.7273  


