<img src='./image/download.png'>

In [12]:
import numpy as np
import matplotlib.pyplot as plt 
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW

import torch.nn as nn
from transformers import AutoModel, ViTModel, AutoModelForCausalLM, Trainer, TrainingArguments
import os 
import warnings
warnings.filterwarnings('ignore')

os.environ['CUDA_VISIBLE_DEVICES'] = "0, 1"     # This means that the program can see and use GPU 0 and GPU 1.

device = torch.device('cuda')
device

device(type='cuda')

In [13]:
import torch
import torch.nn as nn
from transformers import ViTModel, ViTFeatureExtractor, AutoTokenizer, AutoModel, AutoModelForCausalLM

In [14]:
DECODER_MODEL = 'gpt2'
TEXT_ENCODER_MODEL = 'distilbert-base-uncased'
IMAGE_ENCODER_MODEL = 'facebook/dino-vitb16'

In [15]:
decoder_tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL)
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
image_feature_extractor = ViTFeatureExtractor.from_pretrained(IMAGE_ENCODER_MODEL)

In [16]:
from PIL import Image
import os
import json
from collections import defaultdict
import numpy as np
from tqdm.auto import tqdm

def retrieve_image(image_file):
    try:
        image = Image.open(image_file)
        return image
    except:
        return None

def load_vqa_data(
    annotations_file, questions_file, images_folder, load_images=False, 
    start_at=None, end_at=None, max_images=None, max_questions=None,
    agree_threshold=5
):
    with open(annotations_file, "r") as f:
        annotations_data = json.load(f)

    with open(questions_file, "r") as f:
        questions_data = json.load(f)

    data = []
    images_used = defaultdict(int)
    # Create a dictionary to map question_id to the annotation data
    annotations_dict = {annotation["question_id"]: annotation for annotation in annotations_data["annotations"]}
    print(len(annotations_dict))
    for question in tqdm(questions_data["questions"][start_at:end_at]):
        question_id = question["question_id"]
        annotation = annotations_dict[question_id]

        image_id = question["image_id"]
        image_file = f"{images_folder}/COCO_{images_folder}_{str(image_id).zfill(12)}.jpg"
        if max_questions and images_used[image_file] >= max_questions:
            continue
        all_answers = [ans["answer"] for ans in annotation["answers"]]
        if all_answers.count(annotation['multiple_choice_answer']) < agree_threshold:
            continue

        if load_images:
            # Load the image and convert it to a numpy array
            image = retrieve_image(image_file)
            if not image:
                continue
            image.close()  # Close the image object
            
        else:
            if not os.path.exists(image_file):
                continue
            # Store the image file path
            image = image_file

        # Add the data as a dictionary
        data.append(
            {
                "image_id": image_id,
                "question_id": question_id,
                "question": question["question"],
                "answer": decoder_tokenizer.bos_token + ' ' + annotation["multiple_choice_answer"]+decoder_tokenizer.eos_token,
                "all_answers": all_answers,
                "image": image,
            }
        )
        images_used[image_file] += 1
        
        if max_images and len(images_used) >= max_images:
            break

    return data

In [17]:
# text tokenizer
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_ENCODER_MODEL)

In [18]:
from PIL import Image
import torchvision.transforms as T 
from transformers import DataCollatorWithPadding
import numpy as np 

def preprocess_image(image):
    # open the image if the input is a file path
    if type(image) == str:
        img = image.open(image)
    else:
        img = image
        
    # check the number of channels in the images and convert to RGB if necessary
    if img.mode == "L":         # "L" stands for grayscale mode
        img_rgb = img.convert('RGB')
    else:
        img_rgb = img
        
    return img_rgb
        
def science_qa_data_collator(batch):
    # Preprocess and tokenize text
    if batch[0].get('answer'):
        text_inputs = [sample['question'] for sample in batch]
    else:
        text_inputs = [f"{sample['question']} Choices are: {'. '.join(sample['choices'])}" for sample in batch]
    text_tensors = text_tokenizer(text_inputs, padding=True, return_tensors='pt')
    
    try:
        image_inputs = image_feature_extractor([preprocess_image(sample["image"]) for sample in batch])
    except Exception as e:
        print(e)
        print([sample["image"] for sample in batch])
    image_tensors = torch.from_numpy(np.stack(image_inputs['pixel_values']))
    
    # prepare decoder inputs [targets]
    target_inputs = [sample["answer"] for sample in batch]
    target_tensors = decoder_tokenizer(target_inputs, padding=True, return_tensors="pt")
    
    # return input tensors
    return {
        "input_text": text_tensors["input_ids"],
        "attention_mask": text_tensors["attention_mask"],
        "input_image": image_tensors,
        "decoder_input_ids": target_tensors["input_ids"],
        "labels": target_tensors["input_ids"]
    }

In [19]:
from tqdm.notebook import tqdm


class MultiModalModel(nn.Module):
    """
    A MultiModal class used to perform visual question answering (VAQ).
    It consists of encoders for text and image and a decoder for generating the answer
    
    Attributes
    ---------
    text_encoder : A model to encode text input.
    image_encoder : A model to encode image input.
    decoder : A model to decode and genrate answers.
    text_projection : A linear layer to project text encoding to a specific size.
    image_projection : A linear layer to project image encoding to a specific size.
    """
    
    def __init__(self, text_encoder_model, image_encoder_model, decoder_model, freeze=None, load_from=None):
        """
        Initialize the MultiModalModel
        
        Parameters
        ----------
        text_encoder_model (str): Pre-trained text encoder model name.
        image_encoder_model (str): Pre-trained image encoder model name.
        decode_model (str): Pre-trained decoder model name.
        freeze (str. optional): Which parts of the model to freeze. Can be 'encoders', 'decoder', 'all' or specific encoder.
        load_from (str. optional): Path to a checkpoint file to load the model.
        """
        super(MultiModalModel, self).__init__()
        
        # Initialize text and image encoders
        self.text_encoder = AutoModel.from_pretrained(text_encoder_model)
        self.image_encoder = ViTModel.from_pretrained(image_encoder_model)
        
        # Initialize the GPT-2 decoder
        self.decoder = AutoModelForCausalLM.from_pretrained(
            decoder_model, add_cross_attention=True, tie_word_embeddings=True
        )
        
        # Initialize linear layers for projecting embedded features
        self.text_projection = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
        self.image_projection = nn.Linear(self.image_encoder.config.hidden_size, self.decoder.config.hidden_size)
        
        # Freeze specified if required or load from a checkpoint
        if load_from:
            self.load_model_checkpoint(load_from)
        else:
            self.freeze(freeze)
            
    def freeze(self, freeze):
        """
        Freeze specific parts of the model to prevent them from being updated during training
        
        Parameters
        --------
        freeze (str): which parts to freeze. Can be 'encoders', 'decoder', 'all' or specific encoder
        """
        if not freeze:
            return
        print('Freezing...')
        if freeze in ('encoders', 'all') or 'text_encoder' in freeze:
            print('Freezing text encoder')
            for param in self.text_encoder.parameters():
                param.requires_grad = False
        
        if freeze in ('encoders', 'all') or 'image_encoder' in freeze:
            print('Freezing image encoder')
            for param in self.image_encoer.parameters():
                param.requires_grad = False
                
        if freeze in ('decoder', 'all'):
            print('Freezing decoder (except for cross attention)')
            for name, param in self.decoder.named_parameters():
                if 'corssattention' not in name:
                    param.requires_grad = False
                    
    def load_model_checkpoint(self, path):
        """
        Load the model from a saved checkpoint.
        
        Parameters
        --------
        path(str) : Path to the saved checkpoint.
        """
        checkpoint = torch.load(path)
        checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.item()}
        self.load_state_dict(checkpoint)
        
    def encode_text(self, input_text, attention_mask):
        """ 
        Encode text using the text encoder and project it to a specific size.
        
        Parameters
        --------
        torch.Tensor: Projected text encoding
        """
        self.check_input(input_text, "input_text")
        text_encoded = self.text_encoder(input_text, attention_mask=attention_mask).last_hidden_state.mean(dim=1)
        return self.text_projection(text_encoded)
    
    def encode_image(self, input_image):
        """
        Encode image using the image encoder and project it to a specific size.
        
        Parameters:
        ----------
        input_image(torch.Tensor) : input image tensor.
        
        Returns:
        ----------
        torch.Tensor : Projected image encoding
        """
        self.check_input(input_image, "input_image")
        image_encoded = self.image_encoder(input_image).last_hidden_state.mean(dim=1)
        return self.image_projection(image_encoded)
    
    def forward(self, input_text, input_image, decoder_input_ids, attention_mask, labels=None):
        """
        Forward pass through the model.
        
        Parameters
        --------
        input_text (torch.Tensor): input text tensor
        input_image (torch.Tensor): input image tensor
        decoder_input_ids (torch.Tensor) : Decode input IDs tensor
        attention_mask (torch.Tensor): Attention mask for the input text
        labels (torch.Tensor, optional): Ground truth labels for the target
        
        Returns
        ---------
        torch.Tensor: Decoder output
        """
        self.check_input(decoder_input_ids, "decoder_input_ids")
        
        # Encode text and image
        text_projected = self.encode_text(input_text, attention_mask)
        image_projected = self.encode_image(input_image)
        
        # Combined encoded features
        combined_features = (text_projected + image_projected) / 2
        if labels is not None:
            labels = torch.where(labels == decoder_tokenizer.pad_token_id, -100, labels)
            
        # Decode with GPT-2
        decoder_outputs = self.decode(
            input_ids=decoder_input_ids,
            labels=labels,
            encoder_hidden_state=combined_features.unsqueeze()
        )
        return decoder_outputs
    
    def generate(self, image, questions, max_text_length=5):
        """
        Generate answers for the given image and list of questions.
        
        Parameters
        ----------
        image (Image): input image
        question (list): List of questions related to the image
        max_text_length (int, optional): Maximum text length for generated answers
        
        Returns
        ----------
        Image: input image
        """
        
        # encode text and image
        image = retrieve_image(image)
        image_input = image_feature_extractor(images=[preprocess_image(image)], return_tensors='pt')
        input_image = image_input["pixel_values"]
        image_projected = self.encode_image(input_image)
        
        for question in questions:
            i = text_tokenizer(question, return_tensors='pt')
            text_projected = self.encode_text(i['input_ids'], i['attention_mask'])
            
            # combine encoded features
            combined_features = (text_projected + image_projected) / 2
            
            generated_so_far = torch.LongTensor([[decoder_tokenizer.bos_token_id]])
            with torch.no_grad():
                for _ in tqdm(range(max_text_length)):
                    
                    decode_outputs = self.decoder(
                        input_ids=generated_so_far,
                        encoder_hidden_states=combined_features.unsqueeze()
                    )
                    next_token_logits = decode_outputs.logits[:, -1, :]
                    next_token_probs = F.softmax(next_token_logits, dim=-1)
                    next_token = next_token_logits.argmax(-1)
                    confidence = next_token_probs[0, next_token].item()
                    print("Next token:", decoder_tokenizer.decode(next_token), "Confidence:", confidence)
                    generated_so_far = torch.cat((generated_so_far, next_token.unsqueeze(0)), dim=1)
            print(question, decoder_tokenizer.decode(generated_so_far[0]))
            
        return image
                
        
     

In [20]:
OUTPUT_DIR = './data/vga_custom'

In [21]:
train_data = load_vga_data(
    "./data/v2_mscoco_train2014_annotations.json", "./data/v2_OpenEnded_mscoco_train2014_questions.json", "./data/train2014"
)

443757


  0%|          | 0/443757 [00:00<?, ?it/s]

In [22]:
train_data

[]

In [24]:
torch.cuda.get_device_capability()[0]

8