In [None]:
# Colab only: mount Google Drive and make MyDrive the working directory, so that
# relative paths used later in this notebook (e.g. 'log.csv') resolve inside Drive.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/

In [None]:
# Install the YOLO (ultralytics) and BERT (transformers) libraries into this kernel's
# environment, then run ultralytics' built-in environment check.
# TODO(review): pin versions (pkg==X.Y.Z) so a later re-run reproduces these results.
%pip install ultralytics transformers
import ultralytics
ultralytics.checks()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import BertTokenizer, BertModel, BertConfig
from ultralytics import YOLO
from collections import defaultdict
from PIL import Image
from tqdm import tqdm
import json
import csv
import os

In [None]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
# Every later cell moves models and tensors to this `device`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Using device: {device}')

In [None]:
# Load the trained YOLOv8 detector. It is not part of the optimizer below, so it only
# serves as a frozen feature extractor for the VQA model.
# (Fix: the original line read `yolo_model yolo_model = ...`, a SyntaxError.)
YOLO_WEIGHTS_PATH = '/content/drive/MyDrive/00_PFE/Object_Detection/Training_Results/Yolov8-V6/Results/runs/train/experiment/weights/best.pt'
yolo_model = YOLO(YOLO_WEIGHTS_PATH).to(device)

In [None]:
# Answer vocabulary (index = class id used as the training target):
#   0-2  flood-condition labels, 3-4  Yes/No, 5-55  counts "0".."50" as strings.
label_mapping = (
    ["flooded", "non flooded", "flooded,non flooded", "Yes", "No"]
    + [str(count) for count in range(51)]
)

# Question-type name -> embedding index, in the order listed.
question_type_mapping = {
    type_name: type_index
    for type_index, type_name in enumerate(
        ["Condition_Recognition", "Yes_No", "Simple_Counting", "Complex_Counting"]
    )
}

In [None]:
# Function to extract features from YOLOv8
def extract_yolo_features(image_path, model, device):
    """Run the detector on one image and flatten every detection into a row.

    Args:
        image_path: image file passed straight to ``model``.
        model: callable returning an iterable of results, each with a ``boxes``
            attribute (or ``None``) exposing ``xyxy`` and ``cls`` tensors.
        device: where the returned tensor is placed.

    Returns:
        An (N, 5) tensor whose rows are the 4 box coordinates (xyxy) followed by the
        class value, gathered over all results; an empty (0, 5) tensor when no result
        carries boxes.
    """
    detections = [result.boxes for result in model(image_path) if result.boxes is not None]
    if not detections:
        return torch.empty((0, 5), device=device)

    coordinates = torch.cat([det.xyxy.to(device) for det in detections])
    class_column = torch.cat([det.cls.to(device) for det in detections]).unsqueeze(1)
    return torch.cat([coordinates, class_column], dim=1)

In [None]:
# VQADataset class
class VQADataset(Dataset):
    """Per-image VQA dataset: each item bundles every question asked about one image.

    Annotations are loaded from a JSON object whose values are dicts with the keys
    'Image_ID', 'Question', 'Ground_Truth' and 'Question_Type'. Questions are tokenized
    to fixed-length (64) inputs; answers and question types become integer indices via
    the module-level ``label_mapping`` and ``question_type_mapping``.

    NOTE(review): tensors are created on the module-level ``device`` inside
    ``__getitem__``; with CUDA this is presumably only safe with the DataLoader's
    default ``num_workers=0``.

    Args:
        annotations_file: path to the annotations JSON file.
        img_dir: directory containing the images named by 'Image_ID'.
        tokenizer: tokenizer exposing ``encode_plus`` (e.g. a BERT tokenizer).
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, annotations_file, img_dir, tokenizer, transform=None):
        with open(annotations_file, 'r') as f:
            self.annotations = json.load(f)
        self.img_dir = img_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.img_to_annotations = self._group_by_image()
        # index -> Image_ID, built once. Rebuilding this list inside __getitem__ cost
        # O(num_images) per item, i.e. O(n^2) per epoch.
        self._image_ids = list(self.img_to_annotations.keys())

    def _group_by_image(self):
        """Group annotation dicts by 'Image_ID', keeping their order in the file."""
        img_to_annotations = defaultdict(list)
        for annotation in self.annotations.values():
            img_to_annotations[annotation['Image_ID']].append(annotation)
        return img_to_annotations

    def __len__(self):
        """Number of distinct images (not number of questions)."""
        return len(self.img_to_annotations)

    def __getitem__(self, idx):
        """Return all questions for the ``idx``-th image.

        Returns:
            dict with
              'image_path': path of the image file,
              'questions': list of (input_ids, attention_mask) pairs, each 1-D of length 64,
              'attention_masks': the attention masks alone, same order,
              'answers': 1-D tensor of indices into ``label_mapping``,
              'question_types': 1-D tensor of indices into ``question_type_mapping``.

        Raises:
            ValueError: a 'Ground_Truth' value is not present in ``label_mapping``.
            KeyError: a 'Question_Type' is not present in ``question_type_mapping``.
        """
        image_id = self._image_ids[idx]
        annotations = self.img_to_annotations[image_id]
        img_path = os.path.join(self.img_dir, image_id)
        # NOTE(review): the decoded/transformed image is not included in the returned
        # item (only 'image_path' is), so this work is discarded. It does surface a
        # missing or corrupt file early; consider removing it if that is not needed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        questions = []
        answers = []
        question_types = []
        for annotation in annotations:
            inputs = self.tokenizer.encode_plus(
                annotation['Question'],
                add_special_tokens=True,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=64
            )
            question = inputs['input_ids'].squeeze(0).to(device)
            attention_mask = inputs['attention_mask'].squeeze(0).to(device)
            answer_text = str(annotation['Ground_Truth'])
            answer_idx = label_mapping.index(answer_text)
            question_type_idx = question_type_mapping[annotation['Question_Type']]
            questions.append((question, attention_mask))
            answers.append(torch.tensor(answer_idx, device=device))
            question_types.append(torch.tensor(question_type_idx, device=device))
        return {
            'image_path': img_path,
            'questions': questions,
            'attention_masks': [am for _, am in questions],
            'answers': torch.stack(answers),
            'question_types': torch.stack(question_types)
        }

In [None]:
def custom_collate_fn(batch):
    """Flatten a list of per-image dataset items into one batch of questions.

    Questions from all images are concatenated image by image;
    ``num_questions_per_image`` records how many belong to each image so the model
    can line image features up with their questions.
    """
    flat_pairs = [pair for item in batch for pair in item['questions']]
    return {
        'image_paths': [item['image_path'] for item in batch],
        'questions': [token_ids for token_ids, _ in flat_pairs],
        'attention_masks': [mask for _, mask in flat_pairs],
        'answers': torch.cat([item['answers'] for item in batch]),
        'question_types': torch.cat([item['question_types'] for item in batch]),
        'num_questions_per_image': [len(item['questions']) for item in batch],
    }

In [None]:
class VQAModel(nn.Module):
    """Fuse YOLO image features, BERT question features and question type into answer logits.

    For every question:
      1. the owning image's YOLO feature vector is projected to ``hidden_dim``;
      2. the tokenized question is encoded by ``bert_model`` (pooled output);
      3. the question type is embedded into ``hidden_dim``;
      4. (1)-(3) are concatenated, projected to ``combined_dim`` and fed back through
         ``bert_model`` as a length-1 ``inputs_embeds`` sequence;
      5. that pooled output is mapped to ``vocab_size`` answer logits.

    ``combined_dim`` must equal BERT's hidden size, because step 4 feeds the projection
    to BERT as token embeddings.

    Fix: the original had no answer head (``vocab_size`` was unused) and returned the
    768-d pooled BERT vector, which CrossEntropyLoss then treated as logits over 768
    classes although only ``vocab_size`` answers exist. ``classifier`` adds that head.
    Checkpoints saved by the old version lack ``classifier.*`` keys.

    Args:
        bert_model: BERT-style encoder returning an object with ``pooler_output``.
        yolo_input_dim: width of a per-image YOLO feature vector.
        hidden_dim: width of the image and question-type projections.
        combined_dim: width of the fused vector fed back into BERT.
        vocab_size: number of answer classes.
        num_question_types: number of distinct question types.
    """

    def __init__(self, bert_model, yolo_input_dim, hidden_dim, combined_dim, vocab_size, num_question_types):
        super(VQAModel, self).__init__()
        self.bert_model = bert_model
        # BERT's hidden width; fall back to 768 (bert-base) for encoders without a config.
        bert_hidden_size = getattr(getattr(bert_model, 'config', None), 'hidden_size', 768)
        self.fc_yolo = nn.Linear(yolo_input_dim, hidden_dim)
        self.fc_question_type = nn.Embedding(num_question_types, hidden_dim)
        # Input = image (hidden_dim) + question type (hidden_dim) + BERT pooled text feature.
        self.fc_proj = nn.Linear(hidden_dim * 2 + bert_hidden_size, combined_dim)  # Project to BERT input dimension
        # Answer head over the pooled output of the second BERT pass.
        self.classifier = nn.Linear(bert_hidden_size, vocab_size)

    def forward(self, image_features, questions, attention_masks, question_types, num_questions_per_image):
        """Compute answer logits for a flat batch of questions.

        Args:
            image_features: per-image YOLO vectors, either a (num_images, yolo_input_dim)
                tensor or a sequence of (yolo_input_dim,) tensors.
            questions: sequence of equal-length 1-D token-id tensors, one per question,
                ordered image by image.
            attention_masks: matching sequence of 1-D attention masks.
            question_types: (num_questions,) integer tensor of question-type indices.
            num_questions_per_image: how many consecutive questions belong to each image.

        Returns:
            (num_questions, vocab_size) tensor of unnormalized answer logits.
        """
        if not torch.is_tensor(image_features):
            image_features = torch.stack(list(image_features))
        image_features = self.fc_yolo(image_features)
        device = image_features.device

        # Encode all questions in a single BERT call instead of one call per question.
        input_ids = torch.stack(list(questions)).to(device)
        text_mask = torch.stack(list(attention_masks)).to(device)
        text_features = self.bert_model(input_ids, attention_mask=text_mask).pooler_output

        # Repeat each image's features once for every question asked about it.
        repeats = torch.as_tensor(num_questions_per_image, device=device)
        expanded_image_features = torch.repeat_interleave(image_features, repeats, dim=0)

        question_type_features = self.fc_question_type(question_types.to(device))

        combined_features = torch.cat((expanded_image_features, text_features, question_type_features), dim=1)
        projected_features = self.fc_proj(combined_features)

        # Treat each fused vector as a one-token sequence of embeddings for BERT.
        fused_sequence = projected_features.unsqueeze(1)
        fused_mask = torch.ones(fused_sequence.shape[:2], device=device)
        pooled_output = self.bert_model(inputs_embeds=fused_sequence, attention_mask=fused_mask).pooler_output

        return self.classifier(pooled_output)

In [None]:
# Initialize tokenizer, BERT model, and VQA model.
# Hyperparameters first, so they are easy to find and tune.
num_classes = len(label_mapping)                  # size of the answer vocabulary
num_question_types = len(question_type_mapping)
hidden_dim = 256                                  # width of the YOLO / question-type projections
combined_dim = 768                                # must match BERT's hidden size (fed back as inputs_embeds)

# A single BERT instance is shared: VQAModel uses it both to encode the questions
# and to re-encode the fused features.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

vqa_model = VQAModel(
    bert_model=bert_model,
    yolo_input_dim=5,  # extract_yolo_features yields 5 columns per detection
    hidden_dim=hidden_dim,
    combined_dim=combined_dim,
    vocab_size=num_classes,
    num_question_types=num_question_types,
).to(device)

In [None]:
# Optimizer and loss. vqa_model.parameters() includes the shared BERT weights, so BERT
# is fine-tuned together with the fusion layers; the loss scores answer indices.
LEARNING_RATE = 1e-4
optimizer = optim.Adam(vqa_model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [None]:
# Initialize dataset and dataloader for the training split.
# NOTE(review): the directory 'Visual_Question_Answering /' contains a space before the
# slash. It is kept as-is, but verify it matches the folder on Drive.
annotations_file = '/content/drive/MyDrive/00_PFE/DataSet/Visual_Question_Answering /FloodNet Challenge @ EARTHVISION 2021 - Track 2/Questions/Training Question.json'
img_dir = '/content/drive/MyDrive/00_PFE/DataSet/Visual_Question_Answering /FloodNet Challenge @ EARTHVISION 2021 - Track 2/Images/Train_Image'

# Resize to 224x224, then convert the PIL image to a tensor.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

dataset = VQADataset(annotations_file, img_dir, tokenizer=bert_tokenizer, transform=transform)
# Each batch holds 32 images; custom_collate_fn flattens all of their questions.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)

In [None]:
# Setup logging: per-epoch metrics go to 'log.csv' in the current working directory.
def _write_log_row(mode, row):
    """Open 'log.csv' with the given file mode and write a single CSV row."""
    with open('log.csv', mode, newline='') as handle:
        csv.writer(handle).writerow(row)

def setup_logging():
    """Create (or truncate) 'log.csv' and write the header row."""
    _write_log_row('w', ["Epoch", "Average Loss", "Average Accuracy"])

def log_epoch(epoch, avg_loss, avg_accuracy):
    """Append one epoch's row: loss to 4 decimals, accuracy (a fraction) as a percentage."""
    loss_cell = f"{avg_loss:.4f}"
    accuracy_cell = f"{avg_accuracy * 100:.2f}%"
    _write_log_row('a', [epoch, loss_cell, accuracy_cell])

In [None]:
# Function to compute accuracy
def compute_accuracy(predictions, labels):
    """Fraction of rows whose highest-scoring class equals the label.

    Args:
        predictions: (N, C) tensor of class scores.
        labels: (N,) tensor of integer class indices.

    Returns:
        0-dim float tensor: number of correct rows divided by N.
    """
    hits = predictions.argmax(dim=1).eq(labels)
    return hits.float().sum() / labels.size(0)

In [None]:
# Train for a fixed number of epochs. Per-epoch averages are logged, and the checkpoint
# with the best average accuracy is kept.
# NOTE(review): "best" is judged on the training batches, not a held-out split.
BEST_MODEL_PATH = "/content/drive/MyDrive/00_PFE/VQA/Code-V3/VQAModel_Best.pth"

# The YOLO detector is not in the optimizer, so its pooled features for a given image
# never change. Compute them once per image and reuse them in later epochs instead of
# re-running detection on every image every epoch.
yolo_feature_cache = {}

def pooled_yolo_features(image_path):
    """Mean of the image's YOLO detection rows (zeros when nothing is detected), cached."""
    cached = yolo_feature_cache.get(image_path)
    if cached is None:
        features = extract_yolo_features(image_path, yolo_model, device)
        if features.nelement() == 0:
            features = torch.zeros((1, 5), device=device)  # Handle no detections case
        cached = features.mean(dim=0)
        yolo_feature_cache[image_path] = cached
    return cached

setup_logging()
num_epochs = 10
best_accuracy = 0.0
os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)

for epoch in range(num_epochs):
    total_loss = 0.0
    total_accuracy = 0.0
    total_batches = 0
    vqa_model.train()
    with tqdm(dataloader, desc=f"Epoch {epoch + 1}") as pbar:
        for batch in pbar:
            image_paths = batch['image_paths']
            questions = batch['questions']
            attention_masks = batch['attention_masks']
            answers = batch['answers']
            question_types = batch['question_types']
            num_questions_per_image = batch['num_questions_per_image']

            # One pooled YOLO feature vector per image in the batch.
            image_features_list = [pooled_yolo_features(image_path) for image_path in image_paths]
            image_features = torch.stack(image_features_list)

            optimizer.zero_grad()
            outputs = vqa_model(image_features, questions, attention_masks, question_types, num_questions_per_image)
            loss = criterion(outputs, answers)
            if torch.isnan(loss):
                print("Encountered NaN loss, skipping this batch")
                continue
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_accuracy = compute_accuracy(outputs, answers).item()
            total_loss += batch_loss
            total_accuracy += batch_accuracy
            total_batches += 1
            pbar.set_postfix(Loss=batch_loss, Accuracy=f"{batch_accuracy:.4f}")

    if total_batches == 0:
        # Every batch was skipped for a NaN loss: report NaN instead of dividing by zero.
        avg_loss = avg_accuracy = float('nan')
    else:
        avg_loss = total_loss / total_batches
        avg_accuracy = total_accuracy / total_batches
    print(f"Epoch {epoch + 1} - Average Loss: {avg_loss:.4f}, Average Accuracy: {avg_accuracy:.4f}")

    log_epoch(epoch + 1, avg_loss, avg_accuracy)

    # Save the best model (a NaN accuracy never compares greater, so it never overwrites).
    if avg_accuracy > best_accuracy:
        best_accuracy = avg_accuracy
        torch.save(vqa_model.state_dict(), BEST_MODEL_PATH)

print("Training complete!")