In [1]:
# Installs if you use google colab
#! pip install datasets transformers torch huggingface_hub accelerate bitsandbytes

In [11]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import huggingface_hub
from datasets import load_dataset
import os
from dotenv import load_dotenv
from PIL import Image
from torch.utils.data import DataLoader
import random


import bitsandbytes 


In [19]:
# Log on to hugging face to pull models
load_dotenv()
auth_token = os.getenv("HF")
huggingface_hub.login(auth_token)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cpu


In [14]:
'''
Load dataset 
'''

base_dir = os.getcwd()
dataset_path = os.path.join(base_dir, "PMC-VQA")

data_files = {
    "train": os.path.join(dataset_path, "train.csv"),
    "test": os.path.join(dataset_path, "test.csv"),
}
images_path = os.path.join(dataset_path, "PMC_images_unzipped")

dataset = load_dataset("csv", data_files=data_files)
print(dataset["train"][0])

example_image = "C:\\Users\\scott\\UCF_misc\\MedVQA\\PMC-VQA\\PMC_images_unzipped\\figures_0\\PMC1395322_F2.jpg"

{'Figure_path': 'PMC1064097_F1.jpg', 'Question': 'What is the uptake pattern in the breast? ', 'Answer': 'Focal uptake pattern', 'Choice A': ' A:Diffuse uptake pattern ', 'Choice B': ' B:Focal uptake pattern ', 'Choice C': ' C:No uptake pattern ', 'Choice D': ' D:Cannot determine from the information given ', 'Answer_label': 'B'}


In [15]:
'''
Load LLM

# Model: deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
# Memory usage: ~36 GB

EXAMPLE USAGE:
prompt = "What is the capital of France?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Free the inputs from GPU memory
del inputs
torch.cuda.empty_cache()
'''

# Set up quantization config
bnb_config = transformers.BitsAndBytesConfig(
    load_in_8bit=True,
    #nb_4bit_compute_dtype=torch.float16,
)

#llm_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
llm_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Load tokenizer
llm_tokenizer = transformers.AutoTokenizer.from_pretrained(llm_path)
#llm_tokenizer.requires_grad_(False)

# Load model with 4-bit quantization
llm_model = transformers.AutoModelForCausalLM.from_pretrained(
    llm_path,
    #quantization_config=bnb_config,
    device_map="auto", 
)
for param in llm_model.parameters():
    param.requires_grad = False



d_llm = llm_model.config.hidden_size
print(d_llm)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


1536


In [20]:
'''
Load Vison model with a linear layer. 

Model: openai/clip-vit-large-patch14
Memory usage: 3 GB

EXAMPLE USAGE:


'''
class VisionEncoder(nn.Module):
    def __init__(self, projection_dim=d_llm):
        #Initializes the contrastive model.
        #projection_dim (int): The dimension of the projection space for contrastive learning.

        super(VisionEncoder, self).__init__()
        self.vision_processor = transformers.CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", use_auth_token=auth_token)
        self.vision_model = transformers.CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14", use_auth_token=auth_token).to(device)
        self.vision_model.requires_grad_(False)

        '''
        # Freeze the weights of both the vision encoder and deepseek model.
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
        # Determine the output feature dimension of the vision encoder.
        # This may vary based on your specific model.
        # For example, CLIP's vision encoder often has a config attribute like 'hidden_size'.
        feature_dim = (
            self.vision_encoder.config.hidden_size
            if hasattr(self.vision_encoder, 'config') and hasattr(self.vision_encoder.config, 'hidden_size')
            else 768  # Fallback default dimension
        )
        '''

        vision_model_dim = self.vision_model.config.hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(vision_model_dim, 4*vision_model_dim),
            nn.ReLU(),
            nn.Linear(4*vision_model_dim, projection_dim)
        )
        
        

    def forward(self, images):
        with torch.no_grad():
            #image = Image.open()
            inputs = self.vision_processor(images=images, return_tensors="pt")      ##### Preprocess the image: resize, normalize, etc.        #####
            outputs = self.vision_model(**inputs)                           ##### Forward pass through the vision encoder              #####
            vision_embeddings = outputs.last_hidden_state              ##### The last_hidden_state contains the vision embeddings #####

        proj_features = self.mlp(vision_embeddings)
        #proj_features = F.normalize(proj_features, p=2, dim=1) # used in contrastive learning

        return proj_features


vision_encoder = VisionEncoder()



In [None]:
# Import dataset
base_dir = os.getcwd()
dataset_path = os.path.join(base_dir, "PMC-VQA")

data_files = {
    "train": os.path.join(dataset_path, "train_2.csv"),
    "test": os.path.join(dataset_path, "test_2.csv"),
}
images_2_path = os.path.join(dataset_path, "images_2")

dataset_2 = load_dataset("csv", data_files=data_files)
dataset_2['test'][0]

{'index': 62,
 'Figure_path': 'PMC8253867_Fig2_41.jpg',
 'Caption': 'CT pulmonary angiogram reveals encasement and displacement of the left anterior descending coronary artery ( blue arrows ).',
 'Question': ' What is the name of the artery encased and displaced in the image? ',
 'Choice A': ' A: Right Coronary Artery ',
 'Choice B': ' B: Left Anterior Descending Coronary Artery ',
 'Choice C': ' C: Circumflex Coronary Artery ',
 'Choice D': ' D: Superior Mesenteric Artery ',
 'Answer': 'B',
 'split': 'test'}

In [26]:
class LLaVA(nn.Module):
    def __init__(self):
        super(LLaVA, self).__init__()
        self.vision_tower = vision_encoder
        self.llm = llm_model
        self.tokenizer = llm_tokenizer
        self.embedding_layer = llm_model.get_input_embeddings()
        self.prompt = (
            "Please Describe this image. "
            "Place your description and only the description with no extra commentary after 'ANSWER: ' "
            "Again, place your description and only the description with no extra commentary after 'ANSWER: ' "
        )
        self.cot_len = 64 # number of reasoning tokens to be masked out. 
        self.reply = "ANSWER: "

    def forward_align(self, image, caption, vision_first):    
        ###########################################################################################################################
        # Create the Input of the actual embeddings
        ###########################################################################################################################
        B = len(caption) # Batch size
        
        # Get the vision embeddings
        vision_embeds = self.vision_tower(image)

        # Tokenize the text prompt to go before the masked out CoT tokens
        tokens_prompt = self.tokenizer(
            self.prompt, 
            padding=True, 
            truncation=True, 
            return_tensors="pt"
        ).to(vision_embeds.device)
        tokens_prompt_ids = tokens_prompt.input_ids.repeat(B, 1)  # [B, 64]
        tokens_prompt_mask = tokens_prompt.attention_mask.repeat(B, 1)  # [B, 64]
        token_prompt_len = tokens_prompt.input_ids.shape[1]

        # Tokenize the (masked) CoT tokens
        cot_string = " ".join(["<>"] * 64)  
        tokens_cot = self.tokenizer( # -> "<mask> <mask> <mask> ... 64 times ..."
            cot_string, 
            return_tensors="pt"
        )
        tokens_cot_ids = tokens_cot.input_ids.repeat(B, 1)  # [B, 64]
        tokens_cot_mask = tokens_cot.attention_mask.repeat(B, 1)  # [B, 64]
        cot_token_len = tokens_cot.input_ids.shape[1]
        
        # Tokenize the text reply to go after the masked out CoT tokens
        reply = [
            self.reply + cap for cap in caption
        ]
        tokens_reply = self.tokenizer(
            reply, 
            padding=True, 
            truncation=True, 
            return_tensors="pt"
        ).to(vision_embeds.device)
        tokens_reply_len = tokens_reply.input_ids.shape[1]

        # Get the embedings for the 3 parts
        with torch.no_grad():
            print(vision_embeds.shape)
            prompt_embeds = self.embedding_layer(tokens_prompt_ids)
            print(prompt_embeds.shape)
            cot_embeds = self.embedding_layer(tokens_cot_ids)
            print(cot_embeds.shape)
            reply_embeds = self.embedding_layer(tokens_reply.input_ids)
            print(reply_embeds.shape)

        # This is used for the input to then LLM
        # Concat the text embedings with the vision ones
        if vision_first:
            combined_embeds = torch.cat([vision_embeds, prompt_embeds, cot_embeds, reply_embeds], dim=1)
        else: 
            combined_embeds = torch.cat([prompt_embeds, vision_embeds, cot_embeds, reply_embeds], dim=1) 

        print(combined_embeds.shape)

        ###########################################################################################################################
        # Create attention mask
        ###########################################################################################################################
        text_attention_mask = torch.cat([tokens_prompt_mask,
                                            tokens_cot_mask,
                                            tokens_reply.attention_mask], dim=1)

        # For the attention mask, we need one mask token per "vision token" plus the text tokens
        vision_seq_len = vision_embeds.shape[1]
        #text_seq_len = text_embeds.shape[1]
        vision_mask = torch.ones(B, vision_seq_len, dtype=torch.long, device=combined_embeds.device)
        #text_mask = torch.ones(batch_size, text_seq_len, dtype=torch.long, device=combined_embeds.device)
        
        if vision_first:
            combined_attention_mask = torch.cat([vision_mask, text_attention_mask], dim=1)
        else:
            combined_attention_mask = torch.cat([text_attention_mask, vision_mask, ], dim=1)
        ###########################################################################################################################
        # Create the labels portion of input-ids, vision and CoT masked out with -100
        ###########################################################################################################################
        # Creates the a tensor of input IDs for the text portion with CoT masked out
        text_embeds = torch.full(
            (B, token_prompt_len + cot_token_len + tokens_reply_len),
            -100,
            dtype=torch.long,
            device=vision_embeds.device
        )
        print(f"TExt embed labels: {text_embeds.shape}")
        # Add the input_ids for prompt and replys
        text_embeds[:, :token_prompt_len] = tokens_prompt_ids
        text_embeds[:, (token_prompt_len + cot_token_len):] = tokens_reply.input_ids
        text_embeds_len = text_embeds.shape[1]        
        
        # Need to mask out the vision embeddings for the label
        extended_labels = torch.full(
            (B, vision_seq_len + text_embeds_len),
            -100,  # ignore_index
            dtype=torch.long,
            device=vision_embeds.device
        )
        
        # Copy the text tokens into the text portion
        if vision_first:
            extended_labels[:, vision_seq_len:] = text_embeds  # the text portion at end
        else:
            # Prompt before image, cot, and reply tokens
            extended_labels[:, :(vision_seq_len+self.cot_len+tokens_reply_len)] = tokens_prompt_ids
            # Reply after prompt, image, cot tokens
            extended_labels[:, (token_prompt_len+vision_seq_len+self.cot_len):] = tokens_reply.input_ids
        
        # Run through the model
        outputs = self.llm(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention_mask,
            labels=extended_labels   # teacher-forcing
        )
        loss = outputs.loss  # cross-entropy

        # Retrun the loss
        return loss
        

model = LLaVA()

In [27]:
def train(model, dataset, device, epochs, batch_size, lr, logging_steps):
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)

    model.to(device)

    for epoch in range(epochs):
        total_train_loss = 0.0

        for step, batch in enumerate(train_loader):
            figure_paths = batch['Figure_path']  # list of image filenames
            captions     = batch['Caption']      # list of caption strings

            # --- Load and transform images ---
            images = []
            for fig_path in figure_paths:
                img_full_path = os.path.join("./PMC-VQA/images_2/figures/", fig_path)
                img = Image.open(img_full_path)
                images.append(img)

            #images = torch.stack(images).to(device)
            # Forward pass for training, returns cross-entropy loss
            val = random.randint(1, 1234)
            loss = model.forward_align(images, captions, val % 2 == 0)
            total_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % logging_steps == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{step+1}/{len(train_loader)}], "
                      f"Loss: {loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{epochs}] completed. Avg Training Loss: {avg_train_loss:.4f}")


train(model, dataset_2['train'], device, 1, 5, 0.0001, 1)



torch.Size([5, 257, 1536])
torch.Size([5, 43, 1536])
torch.Size([5, 65, 1536])
torch.Size([5, 61, 1536])
torch.Size([5, 426, 1536])
TExt embed labels: torch.Size([5, 169])
Epoch [1/1], Step [1/30521], Loss: 4.8676
torch.Size([5, 257, 1536])
torch.Size([5, 43, 1536])
torch.Size([5, 65, 1536])
torch.Size([5, 52, 1536])
torch.Size([5, 417, 1536])
TExt embed labels: torch.Size([5, 160])


KeyboardInterrupt: 