# Training

In [44]:
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5EncoderModel, ViTModel, T5ForConditionalGeneration
from torchvision import transforms
from PIL import Image
import os
from transformers import get_cosine_schedule_with_warmup, AdamW
from peft import get_peft_model, LoraConfig, TaskType
from torch.utils.data import Dataset, DataLoader
import json
from sklearn.model_selection import train_test_split

In [45]:
class QuestionEncoding(nn.Module):
    def __init__(self, pretrained_model):
        super(QuestionEncoding, self).__init__()
        self.encoder = T5EncoderModel.from_pretrained(pretrained_model)
        self.hidden_dim = self.encoder.config.hidden_size
        self.projection_layers = nn.ModuleList([nn.Linear(self.hidden_dim, self.hidden_dim) for _ in range(self.encoder.config.num_layers)])

    def forward(self, question):
        input_ids = question["input_ids"]
        attention_mask = question["attention_mask"]
        encoded_question = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        projected_question = [projection(encoded_question) for projection in self.projection_layers]
        return projected_question

In [46]:
class QuestionFusing(nn.Module):
    def __init__(self, hidden_dim):
        super(QuestionFusing, self).__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=8)
        self.projection = nn.Linear(hidden_dim, hidden_dim)
        self.gating_projection = nn.Linear(hidden_dim, hidden_dim)
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, visual_features, question_features):
        batch_size, _ = visual_features.size()
        visual_features = visual_features.unsqueeze(1)
        fused_features = torch.cat((visual_features, question_features), dim=1)
        attention_output, _ = self.attention(fused_features, fused_features, fused_features)
        visual_output = attention_output[:, 0, :].unsqueeze(1)
        projected_output = self.projection(visual_output)
        gated_output = self.gating_projection(visual_output) * torch.tanh(self.beta)
        fused_output = projected_output + gated_output
        fused_output = fused_output.squeeze(1)
        return fused_output

In [47]:
class QAViT(nn.Module):
    def __init__(self, vision_model, pretrained_model, fusion_layers):
        super(QAViT, self).__init__()
        self.vision_model = vision_model
        self.question_encoding = QuestionEncoding(pretrained_model)
        self.question_fusing = nn.ModuleList([QuestionFusing(self.question_encoding.hidden_dim) for _ in range(fusion_layers)])
        self.fusion_layers = fusion_layers

    def forward(self, image, question):
        visual_outputs = self.vision_model(pixel_values=image)
        visual_features = visual_outputs.last_hidden_state
        print("hola amigo before fusion",visual_features.shape)
        question_features = self.question_encoding(question)

        num_layers = visual_features.shape[1]
        start_layer = num_layers - self.fusion_layers

        for i in range(start_layer, num_layers):
            visual_features[:, i, :] = self.question_fusing[i - start_layer](visual_features[:, i, :], question_features[i - start_layer])
        print("hola amigo",visual_features.shape)
        visual_outputs.last_hidden_state = visual_features
        return visual_outputs

In [48]:
# Define the training dataset
class QADataset(Dataset):
    def __init__(self, image_dir, annotations, tokenizer, transform):
        self.image_dir = image_dir
        self.annotations = annotations
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        image_id = annotation["image_id"]
        image_path = os.path.join(self.image_dir, f"{image_id}.jpg")

        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)

        question = annotation["question"]
        answer = annotation["multiple_choice_answer"]

        question_tokens = self.tokenizer(question, return_tensors="pt", padding=True, truncation=True)

        return {
            "image": image,
            "question_tokens": {k: v.squeeze(0) for k, v in question_tokens.items()},
            "answer": answer
        }

In [49]:
def qa_collate_fn(batch):
    # Stack images
    images = torch.stack([item["image"] for item in batch])

    # Prepare questions
    question_tokens = {k: torch.cat([item["question_tokens"][k] for item in batch]) for k in batch[0]["question_tokens"]}
    padded_question_tokens = {
        "input_ids": torch.nn.utils.rnn.pad_sequence([item["question_tokens"]["input_ids"] for item in batch], batch_first=True, padding_value=0),
        "attention_mask": torch.nn.utils.rnn.pad_sequence([item["question_tokens"]["attention_mask"] for item in batch], batch_first=True, padding_value=0),
    }

    # Tokenize and pad answers
    answers = [item["answer"] for item in batch]
    answer_tokens = tokenizer(answers, return_tensors="pt", padding=True, truncation=True)

    return {
        "image": images,
        "question_tokens": padded_question_tokens,
        "answer": answer_tokens
    }


In [50]:
def load_and_split_dataset(image_dir, json_path, tokenizer, transform):
    # Load the annotations from JSON
    with open(json_path, 'r') as file:
        data = json.load(file)
    annotations = data["annotations"]

    # Split into train (80%), val (10%), and test (10%)
    train_annotations, test_annotations = train_test_split(annotations, test_size=0.2, random_state=42)
    val_annotations, test_annotations = train_test_split(test_annotations, test_size=0.5, random_state=42)

    # Create Dataset objects
    train_dataset = QADataset(image_dir, train_annotations, tokenizer, transform)
    val_dataset = QADataset(image_dir, val_annotations, tokenizer, transform)
    test_dataset = QADataset(image_dir, test_annotations, tokenizer, transform)

    return train_dataset, val_dataset, test_dataset

In [51]:
# Apply LoRa to T5 model
def apply_lora_to_t5(t5_model):
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q", "k"]
    )
    return get_peft_model(t5_model, lora_config)

In [52]:
def train_qavit_with_validation(qavit_model, t5_model, train_loader, val_loader, tokenizer, num_epochs, device):
    qavit_model.train()
    t5_model.train()

    # Apply LoRa to T5 model
    t5_model = apply_lora_to_t5(t5_model)

    # Optimizer and Scheduler
    optimizer = AdamW([
        {"params": qavit_model.parameters(), "lr": 1e-4},
        {"params": t5_model.parameters(), "lr": 5e-5}
    ])
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=1000,
        num_training_steps=len(train_loader) * num_epochs
    )

    for epoch in range(num_epochs):
        # Training
        total_loss = 0
        qavit_model.train()
        t5_model.train()
        for batch in train_loader:
            images = batch["image"].to(device)
            question_tokens = {k: v.to(device) for k, v in batch["question_tokens"].items()}
            answers = {k: v.to(device) for k, v in batch["answer"].items()}

            optimizer.zero_grad()

            # Forward pass through QA-ViT model
            visual_outputs = qavit_model(images, question_tokens)
            visual_features = visual_outputs.last_hidden_state.mean(dim=1)

            # Forward pass through T5 model
            encoder_outputs = (visual_features.unsqueeze(1).repeat(1, question_tokens["input_ids"].size(1), 1), None, None)

            t5_outputs = t5_model(
                input_ids=question_tokens["input_ids"],
                attention_mask=question_tokens["attention_mask"],
                encoder_outputs=encoder_outputs,
                labels=answers["input_ids"]
            )

            loss = t5_outputs.loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}")

        # Validation
        qavit_model.eval()
        t5_model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                images = batch["image"].to(device)
                question_tokens = {k: v.to(device) for k, v in batch["question_tokens"].items()}
                answers = {k: v.to(device) for k, v in batch["answer"].items()}

                # Forward pass through QA-ViT model
                visual_outputs = qavit_model(images, question_tokens)
                visual_features = visual_outputs.last_hidden_state.mean(dim=1)

                # Forward pass through T5 model
                encoder_outputs = (visual_features.unsqueeze(1).repeat(1, question_tokens["input_ids"].size(1), 1), None, None)

                t5_outputs = t5_model(
                    input_ids=question_tokens["input_ids"],
                    attention_mask=question_tokens["attention_mask"],
                    encoder_outputs=encoder_outputs,
                    labels=answers["input_ids"]
                )

                loss = t5_outputs.loss
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_val_loss:.4f}")


In [53]:
# Define the data paths and questions/answers
image_dir = "./segmentedImage"
json_path = "/Users/sudhanshu/Desktop/UMASS_COURSES_SEMESTERS/SEM_2/NLP/Project/dataset/combined_data.json"#"dataset/combined_data.json"


questionId_path ="/Users/sudhanshu/Desktop/UMASS_COURSES_SEMESTERS/SEM_2/NLP/Project/dataset/questionId_image_mapping.json" #"dataset/questionId_image_mapping.json"


# image_paths = ["/home/dpadalia_umass_edu/685proj/pink_bear.jpg"]
# questions = ["What is the color of the object?"]
# answers = ["pink"]

# Load and preprocess the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Tokenizer and model initialization
pretrained_model = "google/flan-t5-base"
tokenizer = T5Tokenizer.from_pretrained(pretrained_model)

# Create the dataset and dataloader
train_dataset, val_dataset, test_dataset = load_and_split_dataset(image_dir, json_path, tokenizer, transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=qa_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=qa_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=qa_collate_fn)

# Load the ViT base model
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")

# Initialize the QA-ViT model
fusion_layers = 4
qavit_model = QAViT(vit_model, pretrained_model, fusion_layers)

# Load the T5 model for conditional generation
t5_model = T5ForConditionalGeneration.from_pretrained(pretrained_model)

# Move models to the appropriate device (e.g., GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
qavit_model.to(device)
t5_model.to(device)

# Train the QA-ViT model
num_epochs = 5
# train_qavit(qavit_model, t5_model, train_loader, tokenizer, num_epochs, device)
#train_qavit_with_validation(qavit_model, t5_model, train_loader, val_loader, tokenizer, num_epochs, device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Inference

In [54]:
from transformers.modeling_outputs import BaseModelOutput

def infer_qavit(qavit_model, t5_model, image_path, question, tokenizer, device):
    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(image).unsqueeze(0).to(device)

    # Tokenize the question
    question_tokens = tokenizer(question, return_tensors="pt", padding=True, truncation=True)
    question_tokens = {k: v.to(device) for k, v in question_tokens.items()}

    # Perform inference with the QA-ViT model
    qavit_model.eval()
    with torch.no_grad():
        visual_outputs = qavit_model(image, question_tokens)
        print("visual features without mean ",visual_outputs.last_hidden_state.shape)
        visual_features = visual_outputs.last_hidden_state.mean(dim=1)
    print("visual features",visual_features.shape)
        

    # Construct the encoder outputs
    encoder_hidden_state = visual_features.unsqueeze(1).repeat(1, question_tokens["input_ids"].size(1), 1)
    print(encoder_hidden_state.shape)
    encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden_state)

    # Generate output with T5 model
    t5_model.eval()
    with torch.no_grad():
        output_ids = t5_model.generate(
            input_ids=question_tokens["input_ids"],
            attention_mask=question_tokens["attention_mask"],
            encoder_outputs=encoder_outputs,
            max_length=10,
            num_beams=1,
            early_stopping=True
        )

    # Decode the generated output
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True).split()[0]

    return output_text

In [55]:
# Example usage
image_path = "/Users/sudhanshu/Desktop/UMASS_COURSES_SEMESTERS/SEM_2/NLP/Project/lang_segment_anything/segmentedImage/262148_262148000_image.png"
question = "Where is he looking?"

output_text = infer_qavit(qavit_model, t5_model, image_path, question, tokenizer, device)
print("Generated Output:", output_text)

hola amigo before fusion torch.Size([1, 197, 768])
hola amigo torch.Size([1, 197, 768])
visual features without mean  torch.Size([1, 197, 768])
visual features torch.Size([1, 768])
torch.Size([1, 7, 768])
Generated Output: Subscribe


# Legacy code

In [56]:
# def run_qavit(image_path, question, pretrained_model, fusion_layers):
#     # Load and preprocess the image
#     image = Image.open(image_path).convert("RGB")
#     transform = transforms.Compose([
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     ])
#     image = transform(image).unsqueeze(0)

#     # Tokenize the question
#     tokenizer = T5Tokenizer.from_pretrained(pretrained_model)
#     question_tokens = tokenizer(question, return_tensors="pt", padding=True, truncation=True)

#     # Load the ViT base model
#     vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")

#     # Initialize the QAViT model
#     qavit_model = QAViT(vit_model, pretrained_model, fusion_layers)

#     # Move the image, question, and QAViT model to the appropriate device (e.g., GPU)
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     # device = torch.device("cpu")
    
#     image = image.to(device)
#     question_tokens = {k: v.to(device) for k, v in question_tokens.items()}
#     qavit_model.to(device)

#     # Forward pass through the QAViT model
#     with torch.no_grad():
#         visual_outputs = qavit_model(image, question_tokens)
#         visual_features = visual_outputs.last_hidden_state

#     # Load the T5 model for conditional generation
#     t5_model = T5ForConditionalGeneration.from_pretrained(pretrained_model)
#     t5_model.to(device)
    
#     input_ids = question_tokens["input_ids"]
#     attention_mask = question_tokens["attention_mask"]
#     visual_features = visual_features.mean(dim=1)  # Perform average pooling
#     encoder_inputs = input_ids
#     encoder_attention_mask = attention_mask

#     output_ids = t5_model.generate(
#         input_ids=encoder_inputs,
#         attention_mask=encoder_attention_mask,
#         max_length=100,
#         num_beams=4,
#         early_stopping=True
#     )

#     # Decode the generated output
#     output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

#     return output_text

In [57]:
# # Example usage
# image_path = "/home/dpadalia_umass_edu/685proj/pink_bear.jpg"
# question = "What is the color of the object?"
# pretrained_model = "google/flan-t5-base"
# fusion_layers = 4

# output_text = run_qavit(image_path, question, pretrained_model, fusion_layers)
# print("Generated Output:", output_text)

In [58]:
# # Training function
# def train_qavit(qavit_model, t5_model, train_loader, tokenizer, num_epochs, device):
#     qavit_model.train()
#     t5_model.train()

#     # Apply LoRa to T5 model
#     t5_model = apply_lora_to_t5(t5_model)

#     # Optimizer and Scheduler
#     optimizer = AdamW([
#         {"params": qavit_model.parameters(), "lr": 1e-4},
#         {"params": t5_model.parameters(), "lr": 5e-5}
#     ])
#     scheduler = get_cosine_schedule_with_warmup(
#         optimizer,
#         num_warmup_steps=1000,
#         num_training_steps=len(train_loader) * num_epochs
#     )

#     criterion = nn.CrossEntropyLoss()

#     for epoch in range(num_epochs):
#         total_loss = 0
#         for batch in train_loader:
#             images = batch["image"].to(device)
#             question_tokens = {k: v.to(device) for k, v in batch["question_tokens"].items()}
#             answers = tokenizer(batch["answer"], return_tensors="pt", padding=True, truncation=True)
#             answers = {k: v.to(device) for k, v in answers.items()}

#             optimizer.zero_grad()

#             # Forward pass through QA-ViT model
#             visual_outputs = qavit_model(images, question_tokens)
#             visual_features = visual_outputs.last_hidden_state.mean(dim=1)
            
#             encoder_outputs = (visual_features.unsqueeze(1).repeat(1, question_tokens["input_ids"].size(1), 1), None, None)
            
#             # Forward pass through T5 model
#             t5_outputs = t5_model(
#                 input_ids=question_tokens["input_ids"],
#                 attention_mask=question_tokens["attention_mask"],
#                 encoder_outputs=encoder_outputs,
#                 labels=answers["input_ids"]
#             )

#             loss = t5_outputs.loss
#             loss.backward()
#             optimizer.step()
#             scheduler.step()

#             total_loss += loss.item()

#         print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")