In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path

import os
os.environ['HF_HOME'] = '/vol/biomedic3/bglocker/ugproj2324/nns20/CheXagent/.tuned_projection_model' ## THIS HAS TO BE BEFORE YOU IMPORT TRANSFORMERS
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig, AdamW, get_cosine_schedule_with_warmup
from torchvision import transforms

In [7]:
def setup_model() -> tuple:
    device = "cuda"
    dtype = torch.bfloat16

    processor = AutoProcessor.from_pretrained("StanfordAIMI/CheXagent-8b", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "StanfordAIMI/CheXagent-8b", torch_dtype=dtype, trust_remote_code=True
    ).to(device)
    generation_config = GenerationConfig.from_pretrained("StanfordAIMI/CheXagent-8b")

    return processor, model, device, dtype, generation_config

In [8]:
processor, model, device, dtype, generation_config = setup_model()
    

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [9]:
# define fine-tuning hyperparameters
train_set_percentage = 0.8
val_set_percentage = 0.1
num_epochs = 1
batch_size = 1

In [10]:
test_png_dset_path = Path('/vol/biodata/data/chest_xray/VinDr-CXR/1.0.0_png_512/raw/test')

class VinDrImageTextDataset(Dataset):
    def __init__(self, file_path):
        self.image_paths = []
        self.texts = []
        with open(file_path, 'r') as f:
            for line in f:
                image_path, text = line.strip().split(';')
                self.image_paths.append(image_path)
                self.texts.append(text)
              
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx, no_image = False):
        image_path = self.image_paths[idx]
        text = self.texts[idx]
        if no_image:
            return text
        image = Image.open(test_png_dset_path / f"{image_path}.png").convert("RGB")
        # Convert to PyTorch float tensor #TODO: check if I need to scale the values
        transform = transforms.Compose([transforms.PILToTensor()])
        image_tensor = transform(image)
        return image_tensor, text

    @staticmethod
    def get_stats(dataset, no_image=True):
        no_findings = 0
        for idx in range(len(dataset)):
            # When dataset is VinDrImageTextDataset or supports 'no_image' handling
            if hasattr(dataset, 'image_paths'):  # Assuming 'image_paths' attribute is unique to your dataset
                image, text = dataset.__getitem__(idx, no_image=no_image)
            else:
                _, text = dataset[idx]  # Only text is extracted, assuming it doesn't load image
            if 'no findings' in text:
                no_findings += 1
        return no_findings

In [11]:
dataset_checking = VinDrImageTextDataset("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/image_text_reasoning_datasets/test_tuning_all_left_or_right")
dataset_checking_loader = DataLoader(dataset_checking, batch_size=1, shuffle=False)

# experiment with passing training inputs through a processor before forward pass through model
for image, text in dataset_checking_loader:
    print(text[0])
    inputs = processor(image,text[0], return_tensors="pt")
    print(inputs)


    
    break
    
    #print(f"Image: {image.size()} Text: {text}")   

In this image, there is: left Calcification, left Cardiomegaly, left ILD, right Pneumothorax, right Pleural effusion, right Atelectasis.
{'pixel_values': tensor([[[[[-1.1645, -1.1499, -1.1499,  ..., -1.1499, -1.1499, -1.1499],
           [-1.7777, -1.7777, -1.7631,  ..., -1.6317, -1.6317, -1.6171],
           [-1.7485, -1.7485, -1.7485,  ..., -1.4127, -1.4273, -1.3981],
           ...,
           [-1.7777, -1.7777, -1.7777,  ..., -1.7339, -1.7193, -1.7339],
           [-1.7777, -1.7777, -1.7777,  ..., -1.7193, -1.7193, -1.7339],
           [-1.7777, -1.7777, -1.7777,  ..., -1.7339, -1.7193, -1.7339]],

          [[-1.1068, -1.0918, -1.0918,  ..., -1.0918, -1.0918, -1.0918],
           [-1.7371, -1.7371, -1.7221,  ..., -1.5870, -1.5870, -1.5720],
           [-1.7071, -1.7071, -1.7071,  ..., -1.3619, -1.3769, -1.3469],
           ...,
           [-1.7371, -1.7371, -1.7371,  ..., -1.6921, -1.6771, -1.6921],
           [-1.7371, -1.7371, -1.7371,  ..., -1.6771, -1.6771, -1.6921],
         

  [torch.tensor(pixel_values) for pixel_values in encoding_image_processor["pixel_values"]]


In [12]:
print(type(model))

<class 'transformers_modules.StanfordAIMI.CheXagent-8b.4934e91451945c8218c267aae9c34929a7677829.modeling_chexagent.CheXagentForConditionalGeneration'>


In [13]:

def fine_tune_chexagent():

    # Freeze the vision encoder, q former and language model
    model.vision_model.requires_grad_(False)
    model.qformer.requires_grad_(False)
    model.language_model.requires_grad_(False)
    
    # Make the vision-language bridge trainable
    model.language_projection.requires_grad_(True)
    
    # Prepare the dataset
    dataset = VinDrImageTextDataset("/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/image_text_reasoning_datasets/test_tuning_all_left_or_right")

    # split dataset into train, val, test 80%, 10%, 10%
    train_size = int(train_set_percentage * len(dataset))
    val_size = int(val_set_percentage * len(dataset))
    test_size = len(dataset) - train_size - val_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    
    # # Set up the optimizer and learning rate scheduler
    optimizer = AdamW(model.language_projection.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.05)
    num_training_steps = num_epochs * len(train_loader)
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        for images, texts in train_loader:
            print(type(texts)) # is a tuple because of batching 

        #     inputs = processor(images=images, text=texts, return_tensors="pt").to(device, dtype=dtype)
        #     outputs = model(**inputs,labels = texts)
        #     loss = outputs.loss
        #     loss.backward()
        #     torch.nn.utils.clip_grad_norm_(model.language_projection.parameters(), 1.0)
        #     optimizer.step()
        #     lr_scheduler.step()
        #     optimizer.zero_grad()
        
        # # Evaluation
        # model.eval()
        # val_loss = 0
        # with torch.no_grad():
        #     for images, texts in val_loader:
        #         inputs = processor(images=images, text=texts, return_tensors="pt").to(device, dtype=dtype)
        #         outputs = model(**inputs)
        #         val_loss += outputs.loss.item()
        # val_loss /= len(val_loader)
        # print(f"Epoch {epoch+1}/{num_epochs} - Validation Loss: {val_loss:.4f}")
    
    # # Save the fine-tuned model
    # model.save_pretrained("fine_tuned_chexagent")

In [14]:
fine_tune_chexagent()



<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class 'tuple'>
<class '

: 