In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoProcessor, AdamW, get_cosine_schedule_with_warmup
from pathlib import Path
from PIL import Image

In [5]:
# Location of the VinDr-CXR test-split images. The path components suggest
# 512px PNG exports of release 1.0.0 -- TODO confirm against the data layout.
# NOTE(review): absolute, machine-specific path; not referenced by the other
# cells visible in this part of the notebook.
test_png_dset_path = Path('/vol/biodata/data/chest_xray/VinDr-CXR/1.0.0_png_512/raw/test')

class VinDrImageTextDataset(Dataset):
    """Dataset of ``(image, text)`` pairs listed in a ``;``-delimited index file.

    Every non-blank line of the index file has the form ``<image_path>;<text>``.
    Only the first ``;`` is treated as the delimiter, so the text part may
    itself contain semicolons. The whole index is read eagerly at construction
    time; images are opened lazily, one per ``__getitem__`` call.
    """

    def __init__(self, file_path):
        """Parse the index file into parallel path / text lists.

        Args:
            file_path: Path of the text file to read, one sample per line.

        Raises:
            ValueError: If a non-blank line does not contain a ``;`` separator.
        """
        self.image_paths = []
        self.texts = []
        with open(file_path, 'r') as f:
            for line_number, line in enumerate(f, start=1):
                record = line.strip()
                if not record:
                    # Blank lines (typically a trailing newline at the end of
                    # the file) carry no sample; skip them instead of crashing.
                    continue
                # partition() splits on the FIRST ';' only, so any further
                # semicolons stay inside the text.
                image_path, separator, text = record.partition(';')
                if not separator:
                    raise ValueError(
                        f"{file_path}, line {line_number}: expected "
                        f"'<image_path>;<text>', got {record!r}"
                    )
                self.image_paths.append(image_path)
                self.texts.append(text)

    def __len__(self):
        """Return the number of samples read from the index file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A ``(image, text)`` tuple where ``image`` is a PIL image converted
            to RGB and ``text`` is the string stored next to its path.
        """
        image_path = self.image_paths[idx]
        text = self.texts[idx]
        # convert() returns a new, fully loaded image, so the source file can
        # be closed as soon as the conversion is done.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        return image, text

# Build the dataset from the 'test_tuning_all_left_or_right' index file. The
# constructor reads the whole index immediately; images are only opened when a
# sample is accessed.
# NOTE(review): absolute, machine-specific path.
dset = VinDrImageTextDataset('/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/image_text_reasoning_datasets/test_tuning_all_left_or_right')

In [None]:

def _collate_image_text(batch):
    """Transpose a list of ``(image, text)`` samples into ``(images, texts)`` lists.

    The default DataLoader collate function cannot batch PIL images, so the
    samples are kept as plain Python lists and handed to the processor as-is.
    """
    images, texts = zip(*batch)
    return list(images), list(texts)


def _encode_with_labels(processor, images, texts, device, dtype):
    """Run the processor on one batch and attach language-modelling labels.

    Texts are padded to a common length so they can be stacked into a tensor.
    Labels are a copy of ``input_ids`` with padded positions set to -100 (the
    index ignored by the cross-entropy loss).
    """
    inputs = processor(
        images=images, text=texts, padding=True, return_tensors="pt"
    ).to(device, dtype=dtype)
    labels = inputs["input_ids"].clone()
    if "attention_mask" in inputs:
        labels[inputs["attention_mask"] == 0] = -100
    inputs["labels"] = labels
    return inputs


def fine_tune_chexagent(num_epochs=10, batch_size=32, learning_rate=1e-4,
                        output_dir="fine_tuned_chexagent"):
    """Fine-tune only the vision-language projection of the CheXagent model.

    Relies on names defined outside this function: ``setup_model`` (expected to
    return ``(processor, model, device, dtype, generation_config)``),
    ``ImageTextDataset`` and the ``train_image_paths`` / ``train_texts`` /
    ``val_image_paths`` / ``val_texts`` globals.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Batch size for both the training and validation loaders.
        learning_rate: Initial AdamW learning rate (cosine-decayed, no warmup).
        output_dir: Directory passed to ``model.save_pretrained``.

    NOTE(review): Hugging Face conditional-generation models normally populate
    ``outputs.loss`` only when ``labels`` are supplied; this function therefore
    passes labels explicitly. Confirm this matches CheXagent's remote code.
    """
    # Load the pretrained CheXagent model
    processor, model, device, dtype, _generation_config = setup_model()

    # Freeze every parameter, then re-enable only the vision-language bridge.
    # Only the projection is handed to the optimizer, so freezing the rest
    # does not change what is learned; it just avoids computing (and never
    # clearing) gradients for parameters that are not being optimised.
    model.requires_grad_(False)
    model.language_projection.requires_grad_(True)

    # Prepare the dataset
    train_dataset = ImageTextDataset(train_image_paths, train_texts)
    val_dataset = ImageTextDataset(val_image_paths, val_texts)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=_collate_image_text)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            collate_fn=_collate_image_text)

    # Set up the optimizer and learning rate scheduler
    trainable_params = list(model.language_projection.parameters())
    optimizer = AdamW(trainable_params, lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.05)
    num_training_steps = num_epochs * len(train_loader)
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        for images, texts in train_loader:
            inputs = _encode_with_labels(processor, images, texts, device, dtype)
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Evaluation
        model.eval()
        val_loss_total = 0.0
        with torch.no_grad():
            for images, texts in val_loader:
                inputs = _encode_with_labels(processor, images, texts, device, dtype)
                outputs = model(**inputs)
                val_loss_total += outputs.loss.item()
        num_val_batches = len(val_loader)
        # An empty validation set has no meaningful loss; report NaN rather
        # than dividing by zero.
        val_loss = val_loss_total / num_val_batches if num_val_batches else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs} - Validation Loss: {val_loss:.4f}")

    # Save the fine-tuned model
    model.save_pretrained(output_dir)

# Entry point: start fine-tuning when this code runs as the main module.
# In a notebook kernel __name__ is '__main__', so executing this cell
# launches the full training run.
if __name__ == '__main__':
    fine_tune_chexagent()