In [6]:
!pip install pandas numpy torch transformers peft scikit-learn Pillow tqdm
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    ViltProcessor, 
    ViltForQuestionAnswering,
    get_linear_schedule_with_warmup
)
from peft import (
    LoraConfig,
    get_peft_model
)
from sklearn.metrics import accuracy_score, f1_score
from PIL import Image
import re
from tqdm import tqdm

# === Helper function for text normalization ===
def normalize_text(text):
    """Normalize text for consistent comparison"""
    if not isinstance(text, str):
        text = str(text)
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# === Load the data ===
train_df = pd.read_csv("/kaggle/input/vrmini2/train_split.csv")
test_df = pd.read_csv("/kaggle/input/vrmini2/test_split.csv")

print(f"Training samples: {len(train_df)}")
print(f"Testing samples: {len(test_df)}")

# === Device configuration ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# === Load the pre-trained model and processor ===
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
base_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Set max sequence length for text tokenization
MAX_LENGTH = 40  # Adjust as needed based on your data

# === IMPROVEMENT: Adding answer frequency analysis ===
print("Analyzing answer distribution in training data...")
answer_counts = train_df['answer'].value_counts()
print(f"Top 10 most common answers: {answer_counts.head(10)}")
print(f"Number of unique answers: {len(answer_counts)}")

# === Create a custom dataset - FIXED IMPLEMENTATION ===
class VQADataset(Dataset):
    def __init__(self, dataframe, processor, max_length=40):
        self.dataframe = dataframe
        self.processor = processor
        self.max_length = max_length
        
        # IMPROVEMENT: Better handling of answers
        # First normalize all answers for consistency
        self.dataframe['normalized_answer'] = self.dataframe['answer'].apply(normalize_text)
        
        # Get all unique normalized answers in the training set
        all_answers = sorted(self.dataframe['normalized_answer'].unique())
        self.answer_to_id = {answer: idx for idx, answer in enumerate(all_answers)}
        self.id_to_answer = {idx: answer for answer, idx in self.answer_to_id.items()}
        self.num_labels = len(self.answer_to_id)
        
        print(f"Dataset created with {self.num_labels} unique normalized answers")
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image_path = row['full_path']
        question = row['question']
        answer = row['normalized_answer']  # Use normalized answer
        
        try:
            # Load image
            image = Image.open(image_path).convert("RGB")
            
            # Pre-resize image to ensure consistent dimensions
            # ViLT typically uses 384x384 images
            image = image.resize((384, 384))
            
            # Process with processor after consistent resizing
            encoding = self.processor(
                images=image, 
                text=question, 
                return_tensors="pt", 
                padding="max_length",
                max_length=self.max_length,
                truncation=True
            )
            
            # Remove batch dimension
            for k, v in encoding.items():
                encoding[k] = v.squeeze()
                
            # Convert answer to id
            answer_id = self.answer_to_id.get(answer, -1)
            
            # FIXED: Create a one-hot vector for the answer instead of scalar
            return {
                "pixel_values": encoding["pixel_values"],
                "input_ids": encoding["input_ids"],
                "attention_mask": encoding["attention_mask"],
                "token_type_ids": encoding["token_type_ids"],
                "labels": answer_id  # Still return scalar for efficient batching
            }
        except Exception as e:
            print(f"Error processing {image_path}: {e}")
            # Return a properly formatted dummy tensor instead of None
            return self._get_dummy_item()
    
    # Method to create dummy items for error cases 
    def _get_dummy_item(self):
        # Create a dummy item with the right shapes
        dummy_image = torch.zeros((3, 384, 384), dtype=torch.float32)
        dummy_ids = torch.zeros(self.max_length, dtype=torch.long)
        dummy_mask = torch.zeros(self.max_length, dtype=torch.long)
        dummy_type_ids = torch.zeros(self.max_length, dtype=torch.long)
        dummy_label = 0  # Return scalar
        
        return {
            "pixel_values": dummy_image,
            "input_ids": dummy_ids,
            "attention_mask": dummy_mask,
            "token_type_ids": dummy_type_ids,
            "labels": dummy_label
        }

# === IMPROVED: Better batch size balance ===
BATCH_SIZE = 8

# === FIXED: Updated collate function for VQA model ===
def collate_fn(batch):
    # Filter out None values in case there are any
    batch = [item for item in batch if item is not None]
    if not batch:
        raise ValueError("Empty batch after filtering None values")
    
    # Verify all pixel_values have the same shape
    shapes = [item["pixel_values"].shape for item in batch]
    if len(set(str(s) for s in shapes)) > 1:
        print(f"Warning: Inconsistent shapes found: {shapes}")
    
    # Create a batch-sized tensor of label indices
    label_indices = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
    
    return {
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "token_type_ids": torch.stack([item["token_type_ids"] for item in batch]),
        "label_indices": label_indices  # Keep original indices for evaluation
    }

# === Create datasets and dataloaders ===
train_dataset = VQADataset(train_df, processor, max_length=MAX_LENGTH)
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=collate_fn,
    num_workers=0
)

# === Model preparation ===
num_labels = len(train_dataset.answer_to_id)
print(f"Number of unique answers: {num_labels}")

# FIXED: Properly configure the model for VQA
# The ViLT VQA model expects to output num_classes predictions
base_model.config.num_labels = num_labels
base_model.classifier = nn.Sequential(
    nn.Linear(base_model.config.hidden_size, base_model.config.hidden_size),
    nn.LayerNorm(base_model.config.hidden_size),
    nn.GELU(),
    nn.Linear(base_model.config.hidden_size, num_labels)
)

# IMPROVEMENT: Better LoRA configuration
lora_config = LoraConfig(
    r=32,                      # Higher rank for more expressivity
    lora_alpha=64,             # Higher alpha for better scaling
    lora_dropout=0.05,         # Lower dropout to avoid underfitting
    target_modules=["query", "key", "value", "output.dense"],
    bias="none",               # Don't adapt bias terms
    task_type="QUESTION_ANS"   # Specify the task type
)

# Apply LoRA to the model
model = get_peft_model(base_model, lora_config)
model.to(device)

# Print trainable parameter counts
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} out of {all_params:,} ({100 * trainable_params / all_params:.2f}%)")

# Ensure model's config is updated with our answer map
model.config.id2label = train_dataset.id_to_answer
model.config.label2id = train_dataset.answer_to_id
model.config.num_labels = num_labels

# === IMPROVEMENT: Better training parameters ===
num_epochs = 5              # More epochs for better learning
learning_rate = 5e-4        # Higher learning rate
weight_decay = 0.01         # Weight decay for regularization
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Better training scheduler
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# === IMPROVED: Prediction function with consistent image sizing ===
def predict_answer_finetuned(image_path, question, model, processor, id2label, max_length, device):
    """Improved prediction function with consistent image sizing"""
    try:
        image = Image.open(image_path).convert("RGB")
        
        # Always resize to 384x384 for consistency
        image = image.resize((384, 384))
        
        encoding = processor(
            images=image, 
            text=question, 
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True
        ).to(device)
        
        with torch.no_grad():
            outputs = model(**encoding)
        
        # Get probabilities
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)[0]
        
        # Get top prediction
        predicted_idx = torch.argmax(probs).item()
        confidence = probs[predicted_idx].item()
        
        predicted_answer = id2label.get(predicted_idx, "unknown")
        
        return predicted_answer
    except Exception as e:
        print(f"Error with {image_path}: {e}")
        return "unknown"

# Define a custom loss function for VQA
class VQALoss(nn.Module):
    def __init__(self, num_labels):
        super(VQALoss, self).__init__()
        self.num_labels = num_labels
        self.loss_fn = nn.CrossEntropyLoss()
        
    def forward(self, logits, label_indices):
        return self.loss_fn(logits, label_indices)

# Initialize loss function
vqa_loss_fn = VQALoss(num_labels).to(device)

# === FIXED: Training loop ===
print("Starting training...")
best_loss = float('inf')

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()
    running_loss = 0.0
    
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")
    for step, batch in enumerate(progress_bar):
        try:
            # Move tensors to device
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            label_indices = batch["label_indices"].to(device)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass - don't pass labels to model, will handle manually
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                pixel_values=pixel_values
            )
            
            # FIXED: Handle VQA model outputs properly
            # Calculate loss using our custom loss function
            loss = vqa_loss_fn(outputs.logits, label_indices)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # Update parameters
            optimizer.step()
            scheduler.step()
            
            # Update progress bar
            running_loss += loss.item()
            progress_bar.set_postfix({"loss": running_loss / (step + 1)})
            
            # Learning rate monitoring
            if step % 100 == 0:
                print(f"Step {step}, LR: {scheduler.get_last_lr()[0]:.2e}")
                
            # Save checkpoint for best loss
            if (step + 1) % 200 == 0:
                avg_loss = running_loss / (step + 1)
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    print(f"Saving checkpoint with loss: {best_loss:.4f}")
                    checkpoint_path = f"./vilt_lora_checkpoint_epoch{epoch+1}_step{step+1}"
                    model.save_pretrained(checkpoint_path)
                    
        except Exception as e:
            print(f"Error in training batch: {e}")
            import traceback
            traceback.print_exc()  # Print full traceback for debugging
            continue
    
    avg_loss = running_loss / len(train_dataloader)
    print(f"Average loss: {avg_loss:.4f}")
    
    # Periodic evaluation
    if (epoch + 1) % 2 == 0 or epoch == num_epochs - 1:
        # Quick validation on a subset
        model.eval()
        val_subset = train_df.sample(min(100, len(train_df)), random_state=epoch)
        correct = 0
        total = 0
        
        print("Running quick validation...")
        for _, row in tqdm(val_subset.iterrows(), total=len(val_subset)):
            try:
                image_path = row["full_path"]
                question = row["question"]
                true_answer = normalize_text(row["answer"])
                
                pred_answer = predict_answer_finetuned(image_path, question, model, processor, 
                                                      train_dataset.id_to_answer, MAX_LENGTH, device)
                pred_answer = normalize_text(pred_answer)
                
                if pred_answer == true_answer:
                    correct += 1
                total += 1
            except Exception as e:
                print(f"Error in validation: {e}")
                
        val_accuracy = correct / total if total > 0 else 0
        print(f"Validation Accuracy: {val_accuracy:.4f} ({correct}/{total})")

# === Save the fine-tuned model ===
model_save_path = "./vilt_lora_finetuned_final"
model.save_pretrained(model_save_path)
processor.save_pretrained(model_save_path)
print(f"Model saved to {model_save_path}")

# === Evaluation ===
print("Starting evaluation...")
model.eval()

# Create a test dataset
test_dataset = VQADataset(test_df, processor, max_length=MAX_LENGTH)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0
)

# Run predictions
predictions = []
ground_truth = []

print("Running predictions on test set...")
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    try:
        image_path = row["full_path"]
        question = row["question"]
        true_answer = normalize_text(row["answer"])
        
        pred_answer = predict_answer_finetuned(
            image_path, question, model, processor, 
            train_dataset.id_to_answer, MAX_LENGTH, device
        )
        pred_answer = normalize_text(pred_answer)
        
        predictions.append(pred_answer)
        ground_truth.append(true_answer)
    except Exception as e:
        print(f"Error with prediction: {e}")
        predictions.append("unknown")
        ground_truth.append(true_answer)

# Add predictions to dataframe
test_df["norm_answer"] = test_df["answer"].apply(normalize_text)
test_df["prediction"] = predictions
test_df["norm_prediction"] = test_df["prediction"].apply(normalize_text)

# Compute Matches
test_df["match"] = test_df["norm_answer"] == test_df["norm_prediction"]

# Metrics
accuracy = accuracy_score(test_df["norm_answer"], test_df["norm_prediction"])
f1 = f1_score(test_df["norm_answer"], test_df["norm_prediction"], average="macro", zero_division=0)

# Print Results
print("\n=== Evaluation Metrics for LoRA Fine-tuned ViLT ===")
print(f"Accuracy     : {accuracy:.4f}")
print(f"F1 Score     : {f1:.4f}")

# Error analysis
print("\n=== Error Analysis ===")
from collections import Counter
error_cases = test_df[~test_df["match"]]
print(f"Total errors: {len(error_cases)}")

top_predictions = Counter(error_cases["norm_prediction"]).most_common(5)
print(f"Top wrong predictions: {top_predictions}")

# Check the most common answer categories
print("\n=== Common Reference Answer Categories ===")
answer_cats = Counter(test_df["norm_answer"]).most_common(10)
print(f"Top reference answers: {answer_cats}")

# Sample predictions
print("\n=== Sample Predictions vs References ===")
sample = test_df.sample(10, random_state=42)
for _, row in sample.iterrows():
    print(f"Question : {row['question']}")
    print(f"Reference: {row['answer']} (Normalized: {row['norm_answer']})")
    print(f"Prediction: {row['prediction']} (Normalized: {row['norm_prediction']})")
    print(f"Match    : {row['match']}")
    print("---")

# Save results
test_df.to_csv("test_results_lora_finetuned_improved.csv", index=False)
print("Results saved to test_results_lora_finetuned_improved.csv")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Training samples: 9639
Testing samples: 2410
Using device: cuda




Analyzing answer distribution in training data...
Top 10 most common answers: answer
Yes         680
Plastic     380
Black       355
Two         295
Blue        249
Pink        202
Three       164
Abstract    143
Silicone    139
White       134
Name: count, dtype: int64
Number of unique answers: 1765
Dataset created with 1736 unique normalized answers
Number of unique answers: 1736
Trainable parameters: 3,833,856 out of 117,355,976 (3.27%)
Starting training...
Epoch 1/5


Epoch 1:   0%|          | 1/1205 [00:00<08:09,  2.46it/s, loss=7.62]

Step 0, LR: 8.31e-07


Epoch 1:   8%|▊         | 101/1205 [00:35<06:03,  3.04it/s, loss=7.47]

Step 100, LR: 8.39e-05


Epoch 1:  17%|█▋        | 200/1205 [01:07<05:35,  3.00it/s, loss=7.31]

Saving checkpoint with loss: 7.3084


Epoch 1:  17%|█▋        | 201/1205 [01:07<05:31,  3.02it/s, loss=7.31]

Step 200, LR: 1.67e-04


Epoch 1:  25%|██▍       | 301/1205 [01:40<05:01,  2.99it/s, loss=7.11]

Step 300, LR: 2.50e-04


Epoch 1:  33%|███▎      | 400/1205 [02:13<04:34,  2.93it/s, loss=6.95]

Saving checkpoint with loss: 6.9459


Epoch 1:  33%|███▎      | 401/1205 [02:14<04:30,  2.97it/s, loss=6.95]

Step 400, LR: 3.33e-04


Epoch 1:  42%|████▏     | 501/1205 [02:47<03:51,  3.04it/s, loss=6.83]

Step 500, LR: 4.16e-04


Epoch 1:  50%|████▉     | 600/1205 [03:20<03:26,  2.93it/s, loss=6.73]

Saving checkpoint with loss: 6.7262


Epoch 1:  50%|████▉     | 601/1205 [03:20<03:22,  2.98it/s, loss=6.72]

Step 600, LR: 4.99e-04


Epoch 1:  58%|█████▊    | 701/1205 [03:53<02:46,  3.02it/s, loss=6.62]

Step 700, LR: 4.91e-04


Epoch 1:  66%|██████▋   | 800/1205 [04:26<02:23,  2.81it/s, loss=6.53]

Saving checkpoint with loss: 6.5282


Epoch 1:  66%|██████▋   | 801/1205 [04:26<02:22,  2.83it/s, loss=6.53]

Step 800, LR: 4.82e-04


Epoch 1:  75%|███████▍  | 901/1205 [04:59<01:39,  3.06it/s, loss=6.44]

Step 900, LR: 4.72e-04


Epoch 1:  83%|████████▎ | 1000/1205 [05:32<01:09,  2.94it/s, loss=6.37]

Saving checkpoint with loss: 6.3745


Epoch 1:  83%|████████▎ | 1001/1205 [05:32<01:08,  2.97it/s, loss=6.37]

Step 1000, LR: 4.63e-04


Epoch 1:  91%|█████████▏| 1101/1205 [06:05<00:34,  3.02it/s, loss=6.32]

Step 1100, LR: 4.54e-04


Epoch 1: 100%|█████████▉| 1200/1205 [06:38<00:01,  2.98it/s, loss=6.26]

Saving checkpoint with loss: 6.2614


Epoch 1: 100%|█████████▉| 1201/1205 [06:38<00:01,  2.98it/s, loss=6.26]

Step 1200, LR: 4.45e-04


Epoch 1: 100%|██████████| 1205/1205 [06:39<00:00,  3.01it/s, loss=6.26]


Average loss: 6.2593
Epoch 2/5


Epoch 2:   0%|          | 1/1205 [00:00<06:29,  3.09it/s, loss=6.12]

Step 0, LR: 4.44e-04


Epoch 2:   8%|▊         | 101/1205 [00:33<06:04,  3.03it/s, loss=5.6] 

Step 100, LR: 4.35e-04


Epoch 2:  17%|█▋        | 200/1205 [01:06<05:41,  2.94it/s, loss=5.53]

Saving checkpoint with loss: 5.5317


Epoch 2:  17%|█▋        | 201/1205 [01:06<05:37,  2.97it/s, loss=5.53]

Step 200, LR: 4.26e-04


Epoch 2:  25%|██▍       | 301/1205 [01:39<04:54,  3.07it/s, loss=5.48]

Step 300, LR: 4.17e-04


Epoch 2:  33%|███▎      | 400/1205 [02:11<04:32,  2.96it/s, loss=5.46]

Saving checkpoint with loss: 5.4568


Epoch 2:  33%|███▎      | 401/1205 [02:11<04:28,  2.99it/s, loss=5.46]

Step 400, LR: 4.07e-04


Epoch 2:  42%|████▏     | 501/1205 [02:44<03:48,  3.09it/s, loss=5.4] 

Step 500, LR: 3.98e-04


Epoch 2:  50%|████▉     | 600/1205 [03:17<03:23,  2.98it/s, loss=5.4] 

Saving checkpoint with loss: 5.3973


Epoch 2:  50%|████▉     | 601/1205 [03:17<03:21,  3.00it/s, loss=5.4]

Step 600, LR: 3.89e-04


Epoch 2:  58%|█████▊    | 701/1205 [03:50<02:45,  3.05it/s, loss=5.35]

Step 700, LR: 3.80e-04


Epoch 2:  66%|██████▋   | 800/1205 [04:22<02:16,  2.96it/s, loss=5.35]

Saving checkpoint with loss: 5.3470


Epoch 2:  66%|██████▋   | 801/1205 [04:23<02:15,  2.98it/s, loss=5.35]

Step 800, LR: 3.71e-04


Epoch 2:  75%|███████▍  | 901/1205 [04:56<01:39,  3.06it/s, loss=5.34]

Step 900, LR: 3.61e-04


Epoch 2:  83%|████████▎ | 1000/1205 [05:28<01:08,  2.99it/s, loss=5.31]

Saving checkpoint with loss: 5.3115


Epoch 2:  83%|████████▎ | 1001/1205 [05:28<01:07,  3.01it/s, loss=5.31]

Step 1000, LR: 3.52e-04


Epoch 2:  91%|█████████▏| 1101/1205 [06:01<00:33,  3.07it/s, loss=5.3] 

Step 1100, LR: 3.43e-04


Epoch 2: 100%|█████████▉| 1200/1205 [06:34<00:01,  2.93it/s, loss=5.28]

Saving checkpoint with loss: 5.2822


Epoch 2: 100%|█████████▉| 1201/1205 [06:34<00:01,  2.97it/s, loss=5.28]

Step 1200, LR: 3.34e-04


Epoch 2: 100%|██████████| 1205/1205 [06:35<00:00,  3.04it/s, loss=5.28]


Average loss: 5.2813
Running quick validation...


100%|██████████| 100/100 [00:03<00:00, 31.29it/s]


Validation Accuracy: 0.3300 (33/100)
Epoch 3/5


Epoch 3:   0%|          | 1/1205 [00:00<06:31,  3.08it/s, loss=6.06]

Step 0, LR: 3.33e-04


Epoch 3:   8%|▊         | 101/1205 [00:33<05:58,  3.08it/s, loss=4.99]

Step 100, LR: 3.24e-04


Epoch 3:  17%|█▋        | 200/1205 [01:05<05:36,  2.98it/s, loss=4.95]

Saving checkpoint with loss: 4.9538


Epoch 3:  17%|█▋        | 201/1205 [01:05<05:34,  3.00it/s, loss=4.96]

Step 200, LR: 3.15e-04


Epoch 3:  25%|██▍       | 301/1205 [01:38<04:53,  3.08it/s, loss=4.9] 

Step 300, LR: 3.06e-04


Epoch 3:  33%|███▎      | 400/1205 [02:11<04:31,  2.96it/s, loss=4.9] 

Saving checkpoint with loss: 4.8963


Epoch 3:  33%|███▎      | 401/1205 [02:11<04:35,  2.92it/s, loss=4.89]

Step 400, LR: 2.96e-04


Epoch 3:  42%|████▏     | 501/1205 [02:44<03:56,  2.98it/s, loss=4.9] 

Step 500, LR: 2.87e-04


Epoch 3:  50%|████▉     | 600/1205 [03:16<03:24,  2.96it/s, loss=4.88]

Saving checkpoint with loss: 4.8849


Epoch 3:  50%|████▉     | 601/1205 [03:16<03:22,  2.98it/s, loss=4.89]

Step 600, LR: 2.78e-04


Epoch 3:  58%|█████▊    | 701/1205 [03:49<02:45,  3.05it/s, loss=4.9] 

Step 700, LR: 2.69e-04


Epoch 3:  66%|██████▋   | 801/1205 [04:22<02:11,  3.07it/s, loss=4.89]

Step 800, LR: 2.59e-04


Epoch 3:  75%|███████▍  | 901/1205 [04:54<01:41,  3.00it/s, loss=4.86]

Step 900, LR: 2.50e-04


Epoch 3:  83%|████████▎ | 1000/1205 [05:27<01:08,  2.99it/s, loss=4.88]

Saving checkpoint with loss: 4.8771


Epoch 3:  83%|████████▎ | 1001/1205 [05:27<01:07,  3.02it/s, loss=4.88]

Step 1000, LR: 2.41e-04


Epoch 3:  91%|█████████▏| 1101/1205 [06:00<00:33,  3.09it/s, loss=4.89]

Step 1100, LR: 2.32e-04


Epoch 3: 100%|█████████▉| 1200/1205 [06:32<00:01,  2.94it/s, loss=4.88]

Saving checkpoint with loss: 4.8751


Epoch 3: 100%|█████████▉| 1201/1205 [06:32<00:01,  2.98it/s, loss=4.88]

Step 1200, LR: 2.23e-04


Epoch 3: 100%|██████████| 1205/1205 [06:34<00:00,  3.06it/s, loss=4.88]


Average loss: 4.8756
Epoch 4/5


Epoch 4:   0%|          | 1/1205 [00:00<06:31,  3.08it/s, loss=4.74]

Step 0, LR: 2.22e-04


Epoch 4:   8%|▊         | 101/1205 [00:32<05:59,  3.07it/s, loss=4.73]

Step 100, LR: 2.13e-04


Epoch 4:  17%|█▋        | 200/1205 [01:05<05:36,  2.99it/s, loss=4.69]

Saving checkpoint with loss: 4.6939


Epoch 4:  17%|█▋        | 201/1205 [01:05<05:32,  3.02it/s, loss=4.69]

Step 200, LR: 2.04e-04


Epoch 4:  25%|██▍       | 301/1205 [01:38<04:52,  3.09it/s, loss=4.66]

Step 300, LR: 1.94e-04


Epoch 4:  33%|███▎      | 400/1205 [02:10<04:30,  2.98it/s, loss=4.66]

Saving checkpoint with loss: 4.6585


Epoch 4:  33%|███▎      | 401/1205 [02:10<04:27,  3.01it/s, loss=4.66]

Step 400, LR: 1.85e-04


Epoch 4:  42%|████▏     | 501/1205 [02:43<03:51,  3.05it/s, loss=4.63]

Step 500, LR: 1.76e-04


Epoch 4:  50%|████▉     | 600/1205 [03:16<03:22,  2.98it/s, loss=4.62]

Saving checkpoint with loss: 4.6169


Epoch 4:  50%|████▉     | 601/1205 [03:16<03:20,  3.01it/s, loss=4.61]

Step 600, LR: 1.67e-04


Epoch 4:  58%|█████▊    | 701/1205 [03:49<02:46,  3.03it/s, loss=4.59]

Step 700, LR: 1.58e-04


Epoch 4:  66%|██████▋   | 800/1205 [04:21<02:15,  3.00it/s, loss=4.58]

Saving checkpoint with loss: 4.5761


Epoch 4:  66%|██████▋   | 801/1205 [04:21<02:14,  3.00it/s, loss=4.58]

Step 800, LR: 1.48e-04


Epoch 4:  75%|███████▍  | 901/1205 [04:54<01:38,  3.09it/s, loss=4.56]

Step 900, LR: 1.39e-04


Epoch 4:  83%|████████▎ | 1000/1205 [05:26<01:08,  2.99it/s, loss=4.56]

Saving checkpoint with loss: 4.5606


Epoch 4:  83%|████████▎ | 1001/1205 [05:27<01:07,  3.02it/s, loss=4.56]

Step 1000, LR: 1.30e-04


Epoch 4:  91%|█████████▏| 1101/1205 [06:00<00:34,  3.00it/s, loss=4.56]

Step 1100, LR: 1.21e-04


Epoch 4: 100%|█████████▉| 1200/1205 [06:32<00:01,  2.97it/s, loss=4.55]

Saving checkpoint with loss: 4.5535


Epoch 4: 100%|█████████▉| 1201/1205 [06:32<00:01,  3.01it/s, loss=4.56]

Step 1200, LR: 1.11e-04


Epoch 4: 100%|██████████| 1205/1205 [06:33<00:00,  3.06it/s, loss=4.56]


Average loss: 4.5551
Running quick validation...


100%|██████████| 100/100 [00:03<00:00, 31.74it/s]


Validation Accuracy: 0.4100 (41/100)
Epoch 5/5


Epoch 5:   0%|          | 1/1205 [00:00<06:42,  2.99it/s, loss=3.95]

Step 0, LR: 1.11e-04


Epoch 5:   8%|▊         | 101/1205 [00:33<06:07,  3.01it/s, loss=4.32]

Step 100, LR: 1.02e-04


Epoch 5:  17%|█▋        | 200/1205 [01:05<05:39,  2.96it/s, loss=4.38]

Saving checkpoint with loss: 4.3838


Epoch 5:  17%|█▋        | 201/1205 [01:05<05:34,  3.00it/s, loss=4.38]

Step 200, LR: 9.26e-05


Epoch 5:  25%|██▍       | 301/1205 [01:38<04:55,  3.06it/s, loss=4.39]

Step 300, LR: 8.33e-05


Epoch 5:  33%|███▎      | 400/1205 [02:11<04:39,  2.88it/s, loss=4.37]

Saving checkpoint with loss: 4.3734


Epoch 5:  33%|███▎      | 401/1205 [02:11<04:37,  2.89it/s, loss=4.38]

Step 400, LR: 7.41e-05


Epoch 5:  42%|████▏     | 501/1205 [02:44<03:51,  3.04it/s, loss=4.35]

Step 500, LR: 6.49e-05


Epoch 5:  50%|████▉     | 600/1205 [03:16<03:22,  2.98it/s, loss=4.33]

Saving checkpoint with loss: 4.3302


Epoch 5:  50%|████▉     | 601/1205 [03:17<03:23,  2.97it/s, loss=4.33]

Step 600, LR: 5.57e-05


Epoch 5:  58%|█████▊    | 701/1205 [03:49<02:45,  3.05it/s, loss=4.33]

Step 700, LR: 4.65e-05


Epoch 5:  66%|██████▋   | 800/1205 [04:22<02:16,  2.97it/s, loss=4.32]

Saving checkpoint with loss: 4.3240


Epoch 5:  66%|██████▋   | 801/1205 [04:22<02:17,  2.94it/s, loss=4.32]

Step 800, LR: 3.72e-05


Epoch 5:  75%|███████▍  | 901/1205 [04:55<01:39,  3.05it/s, loss=4.31]

Step 900, LR: 2.80e-05


Epoch 5:  83%|████████▎ | 1000/1205 [05:28<01:09,  2.97it/s, loss=4.31]

Saving checkpoint with loss: 4.3096


Epoch 5:  83%|████████▎ | 1001/1205 [05:28<01:08,  2.98it/s, loss=4.31]

Step 1000, LR: 1.88e-05


Epoch 5:  91%|█████████▏| 1101/1205 [06:01<00:33,  3.08it/s, loss=4.31]

Step 1100, LR: 9.59e-06


Epoch 5: 100%|█████████▉| 1200/1205 [06:34<00:01,  2.99it/s, loss=4.3] 

Saving checkpoint with loss: 4.3044


Epoch 5: 100%|█████████▉| 1201/1205 [06:34<00:01,  3.02it/s, loss=4.3]

Step 1200, LR: 3.69e-07


Epoch 5: 100%|██████████| 1205/1205 [06:35<00:00,  3.04it/s, loss=4.3]


Average loss: 4.3018
Running quick validation...


100%|██████████| 100/100 [00:03<00:00, 31.28it/s]


Validation Accuracy: 0.4900 (49/100)
Model saved to ./vilt_lora_finetuned_final
Starting evaluation...
Dataset created with 738 unique normalized answers
Running predictions on test set...


100%|██████████| 2410/2410 [01:17<00:00, 31.04it/s]


=== Evaluation Metrics for LoRA Fine-tuned ViLT ===
Accuracy     : 0.3983
F1 Score     : 0.0501

=== Error Analysis ===
Total errors: 1450
Top wrong predictions: [('redmi', 217), ('two', 198), ('floral', 89), ('abstract', 53), ('white', 49)]

=== Common Reference Answer Categories ===
Top reference answers: [('yes', 176), ('plastic', 109), ('black', 99), ('two', 72), ('blue', 53), ('silicone', 49), ('pink', 40), ('three', 39), ('brown', 36), ('white', 34)]

=== Sample Predictions vs References ===
Question : Considering the image and metadata, where would this blanket be most appropriately used?
Reference: Outdoors (Normalized: outdoors)
Prediction: outdoors (Normalized: outdoors)
Match    : True
---
Question : Based on the image and metadata, where would this product most likely be found in a physical store?
Reference: Kitchen (Normalized: kitchen)
Prediction: amazon (Normalized: amazon)
Match    : False
---
Question : What is the color of the shoe?
Reference: Grey (Normalized: grey)


