In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from PIL import Image
import torchvision.transforms as transforms
import json
import os


In [2]:
class DisasterVQADataset(Dataset):
    def __init__(self, images_base_dir, annotations_path, class_to_label_path, transform=None):
        with open(annotations_path, 'r') as f:
            self.annotations = json.load(f)
        with open(class_to_label_path, 'r') as f:
            self.class_to_label = json.load(f)
        self.images_base_dir = images_base_dir
        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        
        # Filter out missing images during initialization
        self.valid_indices = []
        for idx, item in enumerate(self.annotations):
            img_filename = os.path.basename(item['Image_dir'])
            img_path = os.path.join(self.images_base_dir, img_filename)
            if os.path.exists(img_path):
                self.valid_indices.append(idx)
            else:
                print(f"Warning: Image not found: {img_path}")

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        # Use valid_indices to skip missing images
        actual_idx = self.valid_indices[idx]
        item = self.annotations[actual_idx]
        
        # Extract only the image filename, ignoring folder paths in Image_dir
        img_filename = os.path.basename(item['Image_dir'])
        img_path = os.path.join(self.images_base_dir, img_filename)

        try:
            image = Image.open(img_path).convert('RGB')
            image = self.transform(image)
        except Exception as e:
            # Return a blank image if there's an error
            print(f"Error loading image {img_path}: {e}")
            image = torch.zeros(3, 224, 224)  # Default blank image

        question = item['Question']
        encoding = self.tokenizer(question, padding='max_length', max_length=20,
                                  truncation=True, return_tensors='pt')
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        answer_str = item['Ground_Truth']
        answer_token = self.class_to_label.get(answer_str, 0)  # fallback to 0

        return image, input_ids, attention_mask, answer_token


In [3]:
# Visual encoder (CNN)
class VisualEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, 2, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((14, 14))
        )
    
    def forward(self, x):
        return self.cnn(x)


In [4]:
# Attention module fusing image and question features
class AttentionModule(nn.Module):
    def __init__(self, img_feat_dim, txt_feat_dim):
        super().__init__()
        self.img_proj = nn.Conv2d(img_feat_dim, 256, 1)
        self.text_proj = nn.Linear(txt_feat_dim, 256)
        self.attn_conv = nn.Conv2d(256, 1, 1)
    
    def forward(self, img_feats, text_feats):
        img_proj = self.img_proj(img_feats)
        text_proj = self.text_proj(text_feats).unsqueeze(-1).unsqueeze(-1)
        joint = torch.tanh(img_proj + text_proj)
        attn_scores = self.attn_conv(joint)
        attn_weights = torch.softmax(attn_scores.view(attn_scores.size(0), -1), dim=-1)
        attn_weights = attn_weights.view_as(attn_scores)
        attended = (img_feats * attn_weights).sum(dim=[2, 3])
        return attended, attn_weights


In [5]:
# SAM-VQA model combining vision and language
class SAMVQA(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.visual_encoder = VisualEncoder()
        self.language_encoder = BertModel.from_pretrained("bert-base-uncased")
        self.attention = AttentionModule(256, 768)  # 768 from bert pooler output
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, images, input_ids, attention_mask):
        img_feats = self.visual_encoder(images)
        txt_feats = self.language_encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        attended_feat, attn_weights = self.attention(img_feats, txt_feats)
        logits = self.classifier(attended_feat)
        return logits, attn_weights



In [6]:
def analyze_dataset_issues(images_base_dir, annotations_path):
    """Analyze what's wrong with the dataset"""
    print("\n=== DATASET ANALYSIS ===")
    
    # Load annotations
    with open(annotations_path, 'r') as f:
        annotations = json.load(f)
    
    # Get actual image files in directory
    actual_images = set([f for f in os.listdir(images_base_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
    
    # Get referenced images from annotations
    referenced_images = set()
    for item in annotations:
        img_filename = os.path.basename(item['Image_dir'])
        referenced_images.add(img_filename)
    
    print(f"Actual images in directory: {len(actual_images)}")
    print(f"Referenced images in annotations: {len(referenced_images)}")
    print(f"Annotations count: {len(annotations)}")
    
    # Find missing images
    missing = referenced_images - actual_images
    print(f"\nMissing images: {len(missing)}")
    
    # Find extra images (not referenced)
    extra = actual_images - referenced_images
    print(f"Extra images (not referenced): {len(extra)}")
    
    # Show sample of missing images
    if missing:
        print(f"\nSample of missing images (first 10):")
        for img in list(missing)[:10]:
            print(f"  - {img}")
    
    # Check for case sensitivity issues
    actual_images_lower = set([f.lower() for f in actual_images])
    missing_due_to_case = []
    for ref_img in referenced_images:
        if ref_img not in actual_images and ref_img.lower() in actual_images_lower:
            missing_due_to_case.append(ref_img)
    
    if missing_due_to_case:
        print(f"\nPossible case sensitivity issues ({len(missing_due_to_case)} images):")
        for img in missing_due_to_case[:5]:
            print(f"  - {img} (referenced) vs actual: {[f for f in actual_images if f.lower() == img.lower()]}")
    
    return len(missing) == 0


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Use your actual paths here!
    images_base_dir = "data/Images/train_images"
    annotations_path = "data/data/train_annotations.json"
    class_to_label_path = "data/data/class_to_label.json"

    # First, analyze what's wrong
    all_images_exist = analyze_dataset_issues(images_base_dir, annotations_path)
    
    if not all_images_exist:
        print("\n❌ Dataset has missing images. Please fix the issues above before training.")
        print("\nPossible solutions:")
        print("1. Download the missing images")
        print("2. Check if images are in subdirectories")
        print("3. Fix the Image_dir paths in your JSON file")
        print("4. Use case-correct filenames")
        return

    try:
        dataset = DisasterVQADataset(images_base_dir, annotations_path, class_to_label_path)
        print(f"Dataset loaded with {len(dataset)} valid samples out of {len(dataset.annotations)} total")
        
        if len(dataset) == 0:
            print("No valid samples found! Check your image paths.")
            return
            
        dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

        with open(class_to_label_path, 'r') as f:
            class_to_label = json.load(f)
        num_classes = len(class_to_label)
        print(f"Number of classes: {num_classes}")

        model = SAMVQA(num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss()

        print("Starting training...")
        for epoch in range(5):
            model.train()
            total_loss = 0
            batch_count = 0
            
            for batch_idx, (images, input_ids, attention_mask, answers) in enumerate(dataloader):
                images = images.to(device)
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                answers = answers.to(device)

                optimizer.zero_grad()
                logits, _ = model(images, input_ids, attention_mask)
                loss = criterion(logits, answers)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                batch_count += 1
                
                if batch_idx % 10 == 0:  # Print progress every 10 batches
                    print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
            
            avg_loss = total_loss / batch_count
            print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
            
    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    train()

Using device: cuda

=== DATASET ANALYSIS ===
Actual images in directory: 1364
Referenced images in annotations: 411
Annotations count: 1833

Missing images: 411
Extra images (not referenced): 1364

Sample of missing images (first 10):
  - 7880.JPG
  - 7143.JPG
  - 8797.JPG
  - 9231.JPG
  - 7431.JPG
  - 7232.JPG
  - 6594.JPG
  - 9344.JPG
  - 7876.JPG
  - 6898.JPG

❌ Dataset has missing images. Please fix the issues above before training.

Possible solutions:
1. Download the missing images
2. Check if images are in subdirectories
3. Fix the Image_dir paths in your JSON file
4. Use case-correct filenames
