In [1]:
# Importer les bibliothèques nécessaires
import pickle
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader,random_split
import cv2
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
from torchvision.models import resnet50,ResNet50_Weights
from transformers import BertModel, BertTokenizer, BartForConditionalGeneration, BartTokenizer
from transformers import AdamW


In [None]:
# Script pour charger un dataset VQA, le prétraiter et configurer des DataLoaders pour l'entraînement et la validation.

In [12]:
# Configuration du modele sur CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
torch.cuda.empty_cache()

#classer le dataset en #[image,qst,answer]
class VQADataset(Dataset):
    def __init__(self, dataset_path, transform=None):
        with open(dataset_path, 'rb') as file:
            self.data = pickle.load(file)
        self.transform = transform
    
    def __len__(self):
        return len(self.data)

    def image_to_matrix(self, image_path):
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image
    
    def __getitem__(self, idx):
        image_path, question, answer = self.data[idx]
        image_matrix = self.image_to_matrix(image_path)
        #renvoyer la liste qui contient [image,qst,answer]
        return image_matrix, question, answer


transform = transforms.Compose([
    transforms.Resize((512, 512)),   #la taille attendue par resnet50
    transforms.ToTensor(),         
])

dataset_train_path = "C:/Users/ACER/OneDrive/Bureau/dataset_v3/vqa_dataset_full.pkl"
vqa_dataset_train = VQADataset(dataset_train_path, transform=transform)
train_size = int(0.00001 * len(vqa_dataset_train))
rest_size = len(vqa_dataset_train) - train_size
#tester le fine tuning du modele avec une petite partie du train_data
vqa_dataset_train_10, _ = random_split(vqa_dataset_train, [train_size, rest_size]) 



dataset_val_path = "C:/Users/ACER/OneDrive/Bureau/dataset_v3/vqa_dataset_val_full.pkl"
vqa_dataset_val = VQADataset(dataset_val_path, transform=transform)
val_size = int(0.00001 * len(vqa_dataset_val))
rest_size = len(vqa_dataset_val) - val_size
#pour tester le fine tuning du modele avec une petite partie du val data
vqa_dataset_val_10, _ = random_split(vqa_dataset_val, [val_size, rest_size])


print(f"Dataset_train length: {len(vqa_dataset_train_10)}")
print(f"Dataset_val length: {len(vqa_dataset_val_10)}")



#un batch size de 8
batch_size = 8
train_dataloader = DataLoader(vqa_dataset_train_10, batch_size=batch_size, shuffle=True) #melanger les donnees a chq epoch
val_dataloader = DataLoader(vqa_dataset_val_10, batch_size=batch_size, shuffle=False)

num_batches_train = len(train_dataloader)
print(f"Number of batches in dataloader: {num_batches_train}")


num_batches_val = len(val_dataloader)
print(f"Number of batches in dataloader: {num_batches_val}")

#afficher un element de chaque data

for images, questions, answers in train_dataloader:
    print(f"Image Tensor Shape: {images.shape}")  # Afficher la forme du tensor de l'imag
    print(f"Question: {questions[0]}")
    print(f"Answer: {answers[0]}")
    
    image_to_show = images[0].permute(1, 2, 0).numpy()  # Convertir de [C, H, W] à [H, W, C]
    image_to_show = (image_to_show * 255).astype(np.uint8)  # Remettre à une échelle 0-255
    Image.fromarray(image_to_show).show()  
    break  

print("*/*/*/*/*/*/*/*//**/*/*/*/*/*//*///*")
    
for images, questions, answers in val_dataloader:
    print(f"Image Tensor Shape: {images.shape}")  
    print(f"Question: {questions[0]}")
    print(f"Answer: {answers[0]}")
    
    image_to_show = images[0].permute(1, 2, 0).numpy()  # Convertir de [C, H, W] à [H, W, C]
    image_to_show = (image_to_show * 255).astype(np.uint8)  
    Image.fromarray(image_to_show).show()  
    break  
    
    


Using device: cpu
Dataset_train length: 4
Dataset_val length: 2
Number of batches in dataloader: 1
Number of batches in dataloader: 1
Image Tensor Shape: torch.Size([4, 3, 512, 512])
Question: Is her back showing?
Answer: yes
*/*/*/*/*/*/*/*//**/*/*/*/*/*//*///*
Image Tensor Shape: torch.Size([2, 3, 512, 512])
Question: Is she sitting inside?
Answer: yes


In [6]:


# Visual Feature Extractor avec ResNet50
class VisualFeatureExtractor(nn.Module):
    def __init__(self, hidden_dim=768): #768 est le dim_embed de bert
        super(VisualFeatureExtractor, self).__init__()
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT).to(device) #utilisant les poids de resnet pre entraines
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, hidden_dim) #projection sur hidden_dim

    def forward(self, image):
        image = image.to(device) # l image vers gpu
        features = self.resnet(image) #map features
        return features  #shape (batch_size,768)

# Question Encoder avec BERT
class QuestionEncoder(nn.Module):
    def __init__(self):
        super(QuestionEncoder, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased").to(device)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    def forward(self, question):
        inputs = self.tokenizer(question, return_tensors="pt", padding=True, truncation=True) #tokenisation de la qst
        inputs = {key: value.to(device) for key, value in inputs.items()} #passer tout en gpu
        outputs = self.bert(**inputs) #passer les valeurs du dic en param a bert
        return outputs.last_hidden_state #shape (batch_size,seq_lg,768)

# Cross-Attention 
class CrossAttention(nn.Module):
    def __init__(self, hidden_dim=768):
        super(CrossAttention, self).__init__()
        self.cross_attention = nn.MultiheadAttention(hidden_dim, num_heads=4).to(device) #cross attention avec 4 heads

    def forward(self, question_embeddings, visual_features):
        visual_features = visual_features.unsqueeze(0) #shape (1,batch_size,768)
        question_embeddings = question_embeddings.permute(1, 0, 2)
        fusion_features, _ = self.cross_attention(question_embeddings, visual_features, visual_features)
        return fusion_features.permute(1, 0, 2) #(batch_size,seq_lg,768)

# Answer Generator avec BART
class AnswerGenerator(nn.Module):
    def __init__(self, hidden_dim=768):
        super(AnswerGenerator, self).__init__()
        self.bart = BartForConditionalGeneration.from_pretrained("facebook/bart-large").to(device)
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
        self.fc = nn.Linear(hidden_dim, self.bart.config.d_model)
        self.layer_norm = nn.LayerNorm(hidden_dim)  # ajouter normalization pour reduire les fluctuations des valeurs
        
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.bart.resize_token_embeddings(len(self.tokenizer))

    def forward(self, fusion_features, max_length=50):
        # normalisation
        normalized_features = self.layer_norm(fusion_features)
        
        # projeter sur la dim attendue par BART
        projected_features = self.fc(normalized_features)
        
        #moyenne sur seq_lg
        averaged_features = projected_features.mean(dim=1)
        
        # Generation de la réponse
        outputs = self.bart.generate(
            inputs_embeds=averaged_features.unsqueeze(1),
            max_length=max_length,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id
        )
        return outputs


# combiner les sous modèles
class VQAModel(nn.Module):
    def __init__(self, hidden_dim=768):
        super(VQAModel, self).__init__()
        self.visual_extractor = VisualFeatureExtractor(hidden_dim)
        self.question_encoder = QuestionEncoder()
        self.cross_attention = CrossAttention(hidden_dim)
        self.answer_generator = AnswerGenerator()

    def forward(self, image, question):
        visual_features = self.visual_extractor(image)
        question_embeddings = self.question_encoder(question).to(device)
        fusion_features = self.cross_attention(question_embeddings, visual_features).to(device)
        answer = self.answer_generator(fusion_features)
        return answer

# initialiser le modèle
model = VQAModel()
model = model.to(device)


In [None]:
#Fine tuning

In [9]:
#fine tuning (en premier c'etait pour corss_attention+answer_generator mais ça prend bcp de temps)

optimizer = AdamW([
    {'params': model.answer_generator.parameters(), 'lr': 5e-5}  # Fine-tune uniquement BART
], weight_decay=1e-4)

# fonction Fine-tuning 
def fine_tune_vqa(model, train_dataloader, val_dataloader, num_epochs=5, checkpoint_path="C:/Users/ACER/OneDrive/Bureau/dataset_v3/best_model_vqa"):
    best_val_loss = float('inf')
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    #entrainnement
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        print(f"Epoch {epoch+1}/{num_epochs}: Training starts...\n")
        for batch_idx, (image, question, answer) in enumerate(train_dataloader):
            optimizer.zero_grad() #initialiser l optimizer
            #passer tous les elements en device
            image = image.to(device)
            question_embeddings = model.question_encoder(question).to(device)
            visual_features = model.visual_extractor(image.to(device)).to(device)
            fusion_features = model.cross_attention(question_embeddings, visual_features)
            outputs = model.answer_generator(fusion_features)
            targets = tokenizer(answer, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
            #calculer la difference entre answer_gen et ground_truth
            loss = model.answer_generator.bart(input_ids=targets, labels=targets).loss
            #ignorer les NaN dans le calcul de la loss totale
            if loss is not None and not torch.isnan(loss):
                total_train_loss += loss.item()
            #backpropag
            loss.backward()
            optimizer.step()

            progress = (batch_idx + 1) / len(train_dataloader) * 100
            print(f"Batch {batch_idx}/{len(train_dataloader)} - Loss: {loss.item()} {progress:.2f}%")

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} - Avg training loss: {avg_train_loss}")

        
        #validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch_idx, (image, question, answer) in enumerate(val_dataloader):
                image = image.to(device)
                question_embeddings = model.question_encoder(question).to(device)
                visual_features = model.visual_extractor(image.to(device))
                fusion_features = model.cross_attention(question_embeddings, visual_features)
                outputs = model.answer_generator(fusion_features)
                targets = tokenizer(answer, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
                loss = model.answer_generator.bart(input_ids=targets, labels=targets).loss
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} - Validation loss: {avg_val_loss}")
        #sauvegarder le meilleure modèle
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved with val loss: {best_val_loss}\n")

In [None]:
fine_tune_vqa(model, train_dataloader, val_dataloader, num_epochs=5)

In [13]:
def load_and_evaluate_model(image_path, question, checkpoint_path):
    # Charger le modèle 
    model = VQAModel().to(device)
    #charger les poids_a_jour
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # Prétraitement de l'image
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # Ajouter une dimension batch

    # Tokenizer pour la question
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

    # Encoder la question
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True)
    question_embeddings = model.question_encoder(question).to(device)

    # Extraire les caractéristiques visuelles de l'image
    visual_features = model.visual_extractor(image).to(device)

    fusion_features = model.cross_attention(question_embeddings, visual_features)

    with torch.no_grad():
        outputs = model.answer_generator(fusion_features)

    # Décoder la réponse
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return answer




In [14]:
image_path = "C:/Users/ACER/OneDrive/Images/01.jpg"
question = "how many animal in the image?"
checkpoint_path = "C:/Users/ACER/Downloads/VQA_epoch2.pth"
#afficher la rép générée (dans cet expl c était sans poids a jour)
answer = load_and_evaluate_model(image_path, question, checkpoint_path)
print(f"Answer: {answer}") 

Answer: This is not a good time to start. Let's start at the beginning. This is a very good time. I'm going to start with the back of the head.This is a great time to stop and look around. This
