In [None]:
# --- Standard library ---
import argparse
import os

# --- Third-party ---
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

# --- Project-local (relative order kept as in the original cell) ---
from utils.util import set_seed, set_gpu_devices, save_file
from networks.model import VideoQAmodel
from DataLoader import VideoQADataset
from utils.logger import logger
# The training/eval routines are shared with train.py so notebook and script runs behave the same.
# NOTE: the imported `eval` shadows Python's builtin `eval` for the rest of the notebook.
from train import train, eval, predict

print("Libraries imported successfully.")

# Configuration and Path Setup

Here we define the paths for:  
1. **Video Features**: `.pt` files organized in `train`, `val`, and `test` subfolders.  
2. **Object Features**: A root folder with one subfolder per video ID (e.g. `object_feature_path/video_id/*.pkl`).  
3. **Split Directory**: Directory containing `train.txt`, `valid.txt`, and `test.txt`, which list the video IDs belonging to each split.  
4. **Annotation Path**: Directory containing the per-split CSV annotation files (`{split}.csv`).

In [None]:
class Config:
    """Run configuration: data locations plus dataset, model, optimizer and runtime settings.

    Every non-dunder class attribute is later forwarded to the model constructor as a
    keyword argument, so attribute names must stay in sync with what the model accepts.
    """

    # ---- Data locations (absolute local paths: adjust them for your machine) ----
    video_feature_root: str = r"D:\KLTN\TranSTR\causalvid\features\vit"  # holds train/val/test folders of .pt files
    object_feature_path: str = r"D:\KLTN\TranSTR\causalvid\features\objects"  # holds one subfolder per video_id
    sample_list_path: str = r"D:\KLTN\TranSTR\causalvid\data\vqa\causal\anno"  # holds the {split}.csv annotation files
    split_dir_txt: str = r"D:\KLTN\TranSTR\causalvid\data\splits"  # holds train.txt / valid.txt / test.txt

    # ---- Training / optimization ----
    v: str = "v1"
    bs: int = 8  # DataLoader batch size
    lr: float = 1e-4  # base learning rate for the non-text-encoder parameter group
    epoch: int = 20  # number of training epochs
    gpu: int = 0  # passed to set_gpu_devices
    dropout: float = 0.3
    encoder_dropout: float = 0.3
    patience: int = 5  # ReduceLROnPlateau patience
    gamma: float = 0.1  # ReduceLROnPlateau reduction factor
    decay: float = 1e-4  # AdamW weight decay

    # ---- Dataset ----
    dataset: str = 'causal-vid'
    objs: int = 10  # passed to the dataset as obj_num
    n_query: int = 5

    # ---- Transformer ----
    d_model: int = 512
    word_dim: int = 768
    topK_frame: int = 20
    topK_obj: int = 10
    num_encoder_layers: int = 2
    num_decoder_layers: int = 2
    nheads: int = 8
    normalize_before: bool = True
    activation: str = 'gelu'

    # ---- Text encoder ----
    text_encoder_lr: float = 1e-5  # learning rate for parameters whose name contains "text_encoder"
    freeze_text_encoder: bool = False
    text_encoder_type: str = "bert-base-uncased"
    text_pool_mode: int = 1

    # ---- Misc (fields below are consumed by the model; their meaning is not visible in this notebook) ----
    hard_eval: bool = False
    pos_ratio: float = 1.0
    neg_ratio: float = 1.0
    a: float = 1.0
    use_amp: bool = True  # forwarded to train() and GradScaler
    num_workers: int = 4  # DataLoader worker processes
    frame_feat_dim: int = 768  # must match the ViT frame-feature dimension
    obj_feat_dim: int = 2048 + 5  # must match the object-feature dimension

# Instantiate the run configuration, then pin down the GPU selection and RNG state.
args = Config()
set_gpu_devices(args.gpu)
set_seed(999)  # fixed seed so reruns are reproducible
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device: {}".format(device))

## 1. Initialize DataLoaders
We use the modified `VideoQADataset`, which:
- Loads the standard annotations from CSV.
- **Filters by split TXT**: keeps only video IDs listed in `split_dir_txt/{split}.txt`.
- **Filters by feature existence**: keeps only videos whose `.pt` feature file actually exists.
- **Loads object features** from per-video folders (e.g. `.../video_id/imageXX.pkl`).

This lets the pipeline run even when features are missing for some of the listed videos: those videos are simply skipped.

In [None]:
print("Initializing Datasets...")


def build_split_loader(split, shuffle):
    """Build the VideoQADataset for one split and wrap it in a DataLoader.

    All splits share the same annotation/feature/split-list paths from `args`;
    only the split name and whether batches are shuffled differ.

    Args:
        split: split name passed to VideoQADataset ('train', 'val' or 'test').
        shuffle: whether the DataLoader reshuffles samples every epoch.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = VideoQADataset(
        split=split,
        n_query=args.n_query,
        obj_num=args.objs,
        sample_list_path=args.sample_list_path,
        video_feature_path=args.video_feature_root,
        object_feature_path=args.object_feature_path,
        split_dir=args.split_dir_txt,  # the folder holding the per-split .txt video lists
    )
    loader = DataLoader(
        dataset=dataset,
        batch_size=args.bs,
        shuffle=shuffle,
        num_workers=args.num_workers,
        # Pinned host memory only speeds up host->GPU copies; it is pointless on CPU-only runs.
        pin_memory=(device.type == "cuda"),
    )
    return dataset, loader


try:
    train_dataset, train_loader = build_split_loader('train', shuffle=True)
    val_dataset, val_loader = build_split_loader('val', shuffle=False)
    test_dataset, test_loader = build_split_loader('test', shuffle=False)
except Exception as e:
    print(f"\nError initializing datasets: {e}")
    print("Check your paths in Config.")
    # Re-raise: every later cell needs these loaders, so continuing would only
    # fail further down with a confusing NameError.
    raise

print("\nDatasets initialized successfully.")

## 2. Verify Data Loading
Load a single training batch and print its shapes to confirm that the data pipeline works and the dimensions are as expected.

In [None]:
# Sanity check: pull exactly one training batch and print its shapes.
try:
    if len(train_loader) == 0:
        print("Warning: Train loader is empty. Check your data paths and splits.")
    else:
        # next(iter(...)) fetches a single batch without a loop-and-break.
        vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_key = next(iter(train_loader))
        print("--- Data Verification ---")
        print(f"Frame Features Shape: {vid_frame_feat.shape}")
        print(f"Object Features Shape: {vid_obj_feat.shape}")
        print(f"Question Batch Size: {len(qns_word)}")
        print(f"Sample Question: {qns_word[0]}")
        print("--- Verification Passed ---")
except Exception as e:
    print(f"Data Verification Failed: {e}")
    # Re-raise so a broken data pipeline stops "Run All" here instead of
    # surfacing later inside the training loop.
    raise

## 3. Model Setup and Training
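Initialize the model, optimizer and LR scheduler from the configuration, then run the training loop. The checkpoint with the best validation accuracy is saved under `./models/`.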

In [None]:
# Forward every Config field (dunder entries excluded) plus the device to the model constructor.
config_dict = {key: value for key, value in Config.__dict__.items() if not key.startswith('__')}
config_dict['device'] = device
model = VideoQAmodel(**config_dict)
# Move the model to its device *before* constructing the optimizer, as the torch.optim docs recommend.
model.to(device)

# Two parameter groups: parameters whose name contains "text_encoder" use args.text_encoder_lr,
# all other trainable parameters use the optimizer's base lr (args.lr).
trainable_params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
param_dicts = [
    {"params": [p for name, p in trainable_params if "text_encoder" not in name]},
    {"params": [p for name, p in trainable_params if "text_encoder" in name], "lr": args.text_encoder_lr},
]
optimizer = torch.optim.AdamW(params=param_dicts, lr=args.lr, weight_decay=args.decay)
# 'max' mode: the LR is reduced by args.gamma when the monitored score (val accuracy) stops improving.
scheduler = ReduceLROnPlateau(optimizer, 'max', factor=args.gamma, patience=args.patience, verbose=True)
xe = nn.CrossEntropyLoss().to(device)
# Loss scaling only applies to CUDA AMP; keep it disabled on CPU instead of relying on a runtime warning.
# NOTE(review): `scaler` is not passed to train() in the training cell; train() presumably manages
# AMP scaling itself — confirm before relying on this object.
scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp and device.type == "cuda")

print("Model and Optimizer created.")

In [None]:
print("Starting Training Loop...")
best_eval_score = 0.0
# 0 means "no checkpoint saved yet". It is overwritten by the first epoch that beats
# best_eval_score, so a run where no epoch improves does not point at a
# best_model_epoch_1.ckpt that may be left over from an earlier run.
best_epoch = 0

if len(train_loader) > 0:
    for epoch in range(1, args.epoch + 1):
        print(f"Epoch {epoch}/{args.epoch}")
        train_loss, train_acc = train(model, optimizer, train_loader, xe, device, use_amp=args.use_amp)
        eval_score = eval(model, val_loader, device)
        # The scheduler watches validation accuracy (it was built in 'max' mode).
        scheduler.step(eval_score)

        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f} | Val Acc: {eval_score:.2f}")

        # Model selection uses validation accuracy only.
        if eval_score > best_eval_score:
            best_eval_score = eval_score
            best_epoch = epoch
            save_path = f'./models/best_model_epoch_{epoch}.ckpt'
            os.makedirs('./models', exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"  Saved new best model to {save_path}")

        # Test accuracy is reported for monitoring only; it never influences checkpoint selection.
        test_score = eval(model, test_loader, device)
        print(f"  Test Acc: {test_score:.2f}")

    print("Training Complete.")
    print(f"Best Val Acc: {best_eval_score:.2f} at Epoch {best_epoch}")
else:
    print("Cannot start training because train_loader is empty.")

## 4. Final Test Evaluation
Load the best checkpoint (selected by validation accuracy), evaluate it on the test set, and save the predictions under `./prediction/`.

In [None]:
# Restore the checkpoint with the best validation accuracy (if one was saved), then evaluate on test.
best_ckpt_path = f'./models/best_model_epoch_{best_epoch}.ckpt'
if os.path.exists(best_ckpt_path):
    # map_location keeps loading working when the checkpoint was saved on a different device.
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
    print("Loaded best model.")
else:
    print(f"Warning: {best_ckpt_path} not found; evaluating the model currently in memory.")

results, test_acc = predict(model, test_loader, device)
print(f"Final Test Accuracy: {test_acc:.2f}%")

# Save results
os.makedirs('./prediction', exist_ok=True)
result_path = f'./prediction/results_best_{best_epoch}.json'
save_file(results, result_path)
print(f"Predictions saved to {result_path}")