#### Preliminary Testing - Phase 4 - Model Training - Multi-class classification

This notebook implements a single training script shared by all three models (CLIP, BLIP, and ViLT): the selected backbone is fine-tuned together with one linear classification head per question.

In [17]:
import os

import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import (
    BlipForQuestionAnswering, BlipProcessor,
    CLIPModel, CLIPProcessor,
    ViltForQuestionAnswering, ViltModel, ViltProcessor,
)

# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
# Kept as a plain string ("cuda" / "cpu"); every `.to(device)` call accepts that form.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [18]:
# Questions and answer classes.
#
# `classes[i]` lists every answer string allowed for the i-th question column of the
# annotation CSV, i.e. the columns after the first two (presumably `video_id` and
# `frame` — confirm against the CSV). ClassificationVQADataset fits one LabelEncoder
# per entry and pairs it with the question columns purely by position, so this list
# must stay in the same order as those columns.
# LabelEncoder sorts the strings it is fitted on, so the integer labels follow the
# sorted order of each list, not the order written here.
#
# NOTE(review): `questions` is not referenced by the cells in this notebook — the
# dataset uses the CSV column headers as the question text. Keep it in sync with
# those headers (and with `classes`) or drop it.
questions = [
    "What limb is injured?",
    "Is the patient intubated?",
    "Where is the catheter inserted?",
    "Is there bleeding?",
    "Has the bleeding stopped?",
    "Is the patient moving?",
    "Is the patient breathing?",
    "Is there a tourniquet?",
    "Is there a chest tube?",
    "Are the patient and instruments secured?",
    "If a limb is missing which one?",
    "Is there mechanical ventilation?",
    "What is the position of the injury?"
]

# Answers read from the CSV are matched after `str(answer).strip()`. Any answer that
# is not listed here makes LabelEncoder.transform raise at load time.
# NOTE(review): the "If a limb is missing which one?" options have no 'right arm'
# (unlike "What limb is injured?") — confirm against the annotation guide.
classes = [
    ['no limb is injured', 'left leg', 'left arm', 'right leg', 'right arm'],
    ["can't identify", 'no', 'yes'],
    ['no catheter is used', 'lower limb'],
    ['no', 'yes'],
    ['there is no bleeding', 'no', 'yes'],
    ["can't identify", 'yes', 'no'],
    ["can't identify", 'no', 'yes'],
    ['no', 'yes'],
    ['no', 'yes'],
    ['no', 'yes', "can't identify"],
    ['none', 'left arm', 'left leg', 'right leg'],
    ["can't identify", 'no', 'yes'],
    ['thorax', 'throat', "can't identify", 'lower limb', 'abdomen', 'upper limb']
]


In [19]:
# Base Dataset for classification
class ClassificationVQADataset(Dataset):
    """One sample per *answered* (frame, question) cell of an annotation table.

    The first two columns of ``dataframe`` are skipped; every remaining column is a
    question whose header doubles as the question text. ``classes[i]`` lists the
    allowed answers for the i-th question column (paired by position).

    Args:
        dataframe: annotation table with ``video_id`` and ``frame`` columns plus one
            column per question. Missing answers (NaN/None) are not turned into samples.
        image_dir: root folder holding ``<video_id>/<video_id>_frame<frame>.jpg``.
        processor: HuggingFace processor; stored on the instance, not used here.
        classes: per-question list of answer strings, each used to fit a LabelEncoder.

    Raises:
        ValueError: if there are more question columns than entries in ``classes``
            (those columns would have no label encoder).
    """

    def __init__(self, dataframe, image_dir, processor, classes):
        self.data = dataframe
        self.image_dir = image_dir
        self.processor = processor
        self.qa_columns = dataframe.columns[2:]
        if len(self.qa_columns) > len(classes):
            raise ValueError(
                f"{len(self.qa_columns)} question columns but only "
                f"{len(classes)} class lists"
            )
        self.label_encoders = [LabelEncoder().fit(cls) for cls in classes]
        # Index only the answered cells up front. This keeps __len__ exact and means
        # __getitem__ never has to skip a missing answer. (Skipping by recursing to
        # idx + 1 could recurse without bound and duplicated neighbouring samples.)
        self.samples = self._answered_pairs(dataframe, self.qa_columns)

    @staticmethod
    def _answered_pairs(dataframe, qa_columns):
        """Return ``(row position, question index)`` for every non-missing answer, row-major."""
        rows, cols = dataframe[qa_columns].notna().to_numpy().nonzero()
        return list(zip(rows.tolist(), cols.tolist()))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``text`` (question), ``image`` (RGB PIL image), ``label`` and ``question_idx``.

        ``label`` is the LabelEncoder index of the stripped answer; an answer that is
        not in ``classes[question_idx]`` makes ``transform`` raise ValueError.
        """
        row_idx, q_idx = self.samples[idx]
        row = self.data.iloc[row_idx]
        question = self.qa_columns[q_idx]
        answer = row[question]
        # os.path.join needs str components; video ids may be parsed as numbers.
        video_id = str(row['video_id'])
        image_path = os.path.join(self.image_dir, video_id, f"{video_id}_frame{row['frame']}.jpg")
        # Close the underlying file handle once the RGB copy has been made.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        label = self.label_encoders[q_idx].transform([str(answer).strip()])[0]
        return {
            "text": question.strip(),
            "image": image,
            "label": torch.tensor(label, dtype=torch.long),
            "question_idx": q_idx
        }

# Collate function generator
def get_collate_fn(processor):
    """Build a DataLoader ``collate_fn`` that runs ``processor`` over a whole batch.

    The returned function turns a list of dataset items into a dict with
    ``input_ids``, ``attention_mask`` and ``pixel_values`` (from the processor, with
    padding and truncation enabled), a stacked ``labels`` tensor, and
    ``question_idxs`` as a plain Python list (one head index per example).
    """

    def classification_collate(batch):
        texts, images, label_tensors, question_idxs = [], [], [], []
        for sample in batch:
            texts.append(sample["text"])
            images.append(sample["image"])
            label_tensors.append(sample["label"])
            question_idxs.append(sample["question_idx"])
        labels = torch.stack(label_tensors)

        encoded = processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "pixel_values": encoded["pixel_values"],
            "labels": labels,
            "question_idxs": question_idxs,
        }

    return classification_collate


In [35]:
# Unified classifier
class VQAClassifier(nn.Module):
    """Pretrained vision-language backbone plus one linear head per question.

    The backbone yields one pooled vector per example. The head chosen by
    ``question_idx`` maps that vector to the class logits for that question.

    Args:
        model_name: "blip", "clip", or anything else. Any other name is treated as
            an encoder whose output has ``last_hidden_state`` (e.g. ViLT).
        base_model: the pretrained HuggingFace model.
        hidden_dim: width of the pooled vector produced for ``model_name``.
        num_classes_per_question: number of answer classes for each question head.
    """

    def __init__(self, model_name, base_model, hidden_dim, num_classes_per_question):
        super().__init__()
        self.name = model_name
        self.base = base_model
        self.classifiers = nn.ModuleList([
            nn.Linear(hidden_dim, num_classes) for num_classes in num_classes_per_question
        ])

    def forward(self, input_ids, attention_mask, pixel_values, question_idx):
        """Return logits of shape ``(batch, num_classes_per_question[question_idx])``.

        A single head is applied, so every example in the batch must belong to the
        same question.
        """
        if self.name == "blip":
            pooled = self._blip_pool(input_ids, attention_mask, pixel_values)

        elif self.name == "clip":
            outputs = self.base(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            # Fuse the modalities by summing the projected text and image embeddings.
            pooled = outputs.text_embeds + outputs.image_embeds

        else:  # For ViLT and any other HuggingFace encoder-style model
            outputs = self.base(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            pooled = outputs.last_hidden_state[:, 0, :]  # CLS token

        return self.classifiers[question_idx](pooled)

    def _blip_pool(self, input_ids, attention_mask, pixel_values):
        """[CLS] state of BLIP's text encoder with the question cross-attending to the image."""
        image_embeds = self.base.vision_model(pixel_values=pixel_values).last_hidden_state
        # Every image token is real (no padding), so the cross-attention mask is all ones.
        image_attention_mask = torch.ones(
            image_embeds.shape[:-1], dtype=torch.long, device=image_embeds.device
        )
        # Call the text encoder as a whole rather than its embeddings/encoder submodules.
        # That way it expands the 0/1 padding mask into the additive mask its attention
        # layers expect, and it receives the image tokens through cross-attention.
        # Previously the image embeddings were computed but never used.
        text_outputs = self.base.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        return text_outputs.last_hidden_state[:, 0, :]  # [CLS]


In [40]:
# Model registry: for each backbone, the pretrained model, its matching processor,
# and the width of the pooled vector VQAClassifier builds from it. The widths are
# read from each checkpoint's config instead of being hard-coded, so they stay
# correct if a checkpoint is swapped.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
BLIP_CHECKPOINT = "Salesforce/blip-vqa-base"
VILT_CHECKPOINT = "dandelin/vilt-b32-mlm"

_clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
_blip_model = BlipForQuestionAnswering.from_pretrained(BLIP_CHECKPOINT)
_vilt_model = ViltModel.from_pretrained(VILT_CHECKPOINT)

models = {
    "clip": {
        "model": _clip_model,
        "processor": CLIPProcessor.from_pretrained(CLIP_CHECKPOINT),
        # VQAClassifier sums text_embeds + image_embeds, which live in the projection space.
        "hidden_dim": _clip_model.config.projection_dim,
    },
    "blip": {
        "model": _blip_model,
        "processor": BlipProcessor.from_pretrained(BLIP_CHECKPOINT),
        # VQAClassifier pools the [CLS] state of BLIP's text encoder.
        "hidden_dim": _blip_model.config.text_config.hidden_size,
    },
    "vilt": {
        "model": _vilt_model,
        "processor": ViltProcessor.from_pretrained(VILT_CHECKPOINT),
        # VQAClassifier pools the [CLS] state of ViltModel's last_hidden_state.
        "hidden_dim": _vilt_model.config.hidden_size,
    },
}



In [41]:
# Load the annotation table and hold out 20% of its rows (fixed seed) as a test split.
# NOTE(review): the split is per row, so frames from the same video_id can end up in
# both train_df and test_df — consider grouping by video_id before trusting test metrics.
df = pd.read_csv("data/sample-df.csv")
train_df, test_df = train_test_split(
    df,
    random_state=42,
    test_size=0.2,
)


In [42]:
# Choose the backbone to train. Set MODEL_NAME to one of the registry keys for a
# non-interactive run (Restart & Run All); leave it as None to be prompted.
MODEL_NAME = None

available_models = ", ".join(models)
print("Available models:", available_models)
if MODEL_NAME is None:
    selected_model = input(f"Enter the model you want to train ({available_models}): ")
else:
    selected_model = MODEL_NAME
selected_model = selected_model.strip().lower()

if selected_model not in models:
    raise ValueError(f"Invalid model name. Choose from: {list(models.keys())}")


Available models: clip, blip, vilt


In [43]:
# Train the selected model.
# Each example goes through the model on its own because every question has its own
# classification head (and class count). A batch's losses are averaged and then
# applied with a single optimizer step.
NUM_EPOCHS = 3
BATCH_SIZE = 4
LEARNING_RATE = 1e-5
SEED = 42
CHECKPOINT_DIR = "pt"

# Seed before building the loader and the heads so shuffling and head init repeat.
torch.manual_seed(SEED)

config = models[selected_model]
print(f"\n🚀 Training {selected_model.upper()}...")
processor = config["processor"]
collate_fn = get_collate_fn(processor)
train_dataset = ClassificationVQADataset(train_df, image_dir="frames_sample", processor=processor, classes=classes)
if len(train_dataset) == 0:
    raise ValueError("Training split has no answered questions to train on.")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

base_model = config["model"].to(device)
model = VQAClassifier(selected_model, base_model, config["hidden_dim"], [len(c) for c in classes]).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} - {selected_model}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        question_idxs = batch["question_idxs"]

        optimizer.zero_grad()
        losses = []
        for i, q_idx in enumerate(question_idxs):
            logits = model(
                input_ids[i].unsqueeze(0),
                attention_mask[i].unsqueeze(0),
                pixel_values[i].unsqueeze(0),
                question_idx=q_idx
            )
            losses.append(criterion(logits, labels[i].unsqueeze(0)))

        batch_loss = torch.stack(losses).mean()
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"✅ {selected_model.upper()} Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

# torch.save does not create missing directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"{selected_model}_classifier.pt"))



🚀 Training VILT...


Epoch 1 - vilt: 100%|██████████| 7/7 [00:13<00:00,  1.95s/it]


✅ VILT Epoch 1 Avg Loss: 1.0944


Epoch 2 - vilt: 100%|██████████| 7/7 [00:14<00:00,  2.10s/it]


✅ VILT Epoch 2 Avg Loss: 0.8698


Epoch 3 - vilt: 100%|██████████| 7/7 [00:16<00:00,  2.36s/it]


✅ VILT Epoch 3 Avg Loss: 0.6163
