I will be testing biomed clip with a little set of images, just searching the elements of calcifications and nodules (maybe include the birads) to make the radiological report. 

In [5]:
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')


In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW
from PIL import Image
import os
import json
import pandas as pd
import numpy as np
import nrrd
from torchvision import transforms
from typing import List, Optional
from transformers import TrainingArguments, Trainer


ComplexMedicalDataset is a custom PyTorch Dataset class designed for handling medical imaging data, specifically mammograms, and their associated text reports for fine-tuning BioMedCLIP. The class manages the loading, processing, and batching of both image and text data.


In [7]:
"""Expected JSON Structure
jsonCopy{
    "sample_id": {
        "image_paths": ["path1.jpg", "path2.jpg", ...],
        "report": "medical report text"
    }
}
"""
class ComplexMedicalDataset(Dataset):
    def __init__(self, data_dir: str, processor, tokenizer):
        self.data_dir = data_dir
        self.processor = processor
        self.tokenizer = tokenizer
        
        with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as f:
            self.data = json.load(f)
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __len__(self):
        return len(self.data)
    
    """
    Purpose: Processes individual images with error handling and format standardization
    Features:

    Converts images to RGB format
    Handles different processor types
    Includes fallback transformation pipeline
    Ensures correct tensor dimensions
    Applies normalization with ImageNet statistics
"""
    
    def process_image(self, image_path):
        try:
            img = Image.open(os.path.join(self.data_dir, image_path)).convert('RGB')
            
            if self.processor:
                processed_img = self.processor(img)
            else:
                transform = transforms.Compose([
                    transforms.Resize((224, 224)),  # Adjust size as needed
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
                processed_img = transform(img)
            
            # Ensure the image has the correct number of dimensions
            if processed_img.dim() == 3:
                processed_img = processed_img.unsqueeze(0)  # Add batch dimension
            
            return processed_img

        except Exception as e:
            print(f"Error processing image {image_path}: {str(e)}")
            return torch.zeros((1, 3, 224, 224))  # Return a blank 4D tensor

    """
    Returns: Dictionary containing:

    Processed image tensors
    Tokenized text

    Processing Steps:

    Loads and processes multiple images if present
    Stacks multiple images into a single tensor
    Processes associated text report
    Returns formatted data dictionary
    """

    def __getitem__(self, idx: int) -> dict:
        item = self.data[idx]
        a = list(item.keys())[0]

        # Process images
        images = []
        for img_path in item[a]['image_paths']:
            processed_img = self.process_image(img_path)
            images.append(processed_img)
       
        
        # Stack images if multiple, otherwise use single image
        if len(images) > 1:
            images = torch.cat(images, dim=0)
        else:
            images = images[0]
        print(images.size())
        images = images.squeeze(1) # remove the dimension 1

        # Process text
        text = self.tokenizer(
            item[a]['report'], context_length=256
        ).to(self.device)
        
        # Remove batch dimension added by tokenizer
        #text = {k: v.squeeze(0) for k, v in text.items()}
        
        return {
            "image": images,
            "text": text
        }

    """
    Returns: Dictionary containing:

    Processed image tensors
    Tokenized text


    Processing Steps:

    Loads and processes multiple images if present
    Stacks multiple images into a single tensor
    Processes associated text report
    Returns formatted data dictionary
"""

    @staticmethod
    def collate_fn(batch):
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        """Custom collate function to handle batching"""
        # Collate images
        images = (torch.stack([item['image'] for item in batch])).to(device)
        print(images.shape)
        
        # Collate text
        text_batch = {}
        for key in batch[0]['text'].keys():
            text_batch[key] = torch.stack([item['text'][key] for item in batch])
        
        texts = (torch.stack([item['text'] for item in batch])).to(device)
        
        # Collate indices
        indices = [item['idx'] for item in batch].to(device)
        print(images.size())
        
        return {
            'image': images,
            'text': texts
        }, indices


In [8]:
def fine_tune_biomed_clip(model, train_dataloader, num_epochs, device, learning_rate=5e-5):
    """
    Manual training loop for BioMed CLIP 
    THIS CODE HAS NOT BEEN TESTED
    """
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        num_batches = 0
        
        for batch in train_dataloader:
            try:
                # Move batch to device
                images = batch['image'].to(device)
                texts = batch['text']
                
                if isinstance(texts, dict):
                    texts = {k: v.to(device) for k, v in texts.items()}
                else:
                    texts = texts.to(device)
                
                # Forward pass
                outputs = model(image=images, text=texts)
                loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0]
                
                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                
                # Accumulate loss
                total_loss += loss.item()
                num_batches += 1
                
                if num_batches % 10 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, Batch {num_batches}, Loss: {loss.item():.4f}")
                
            except Exception as e:
                print(f"Error in batch: {str(e)}")
                continue
        
        # Calculate and print average loss for the epoch
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
    
    return model

In [9]:
def generate_text(model, processor, image_paths, entity_info=None, max_length=100):
    """
    Generate text descriptions from medical images using BioMedCLIP.
    THIS CODE HAS NOT BEEN TESTED, I DONT KNOW IF IT WORKS
    """
    # Load and process all images
    images = [Image.open(path) for path in image_paths]
    
    # Process images with the BioMedCLIP processor
    inputs = processor(images, return_tensors="pt", padding=True)
    
    # Move inputs to the same device as model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate text without gradient computation
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)
    
    # Decode the generated text
    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    
    return generated_text

In [10]:
import torch
from torch.utils.data import DataLoader
from open_clip import create_model_from_pretrained, get_tokenizer

"""
BioMedCLIP Setup and Data Loading Script

This script initializes the BioMedCLIP model, preprocessor, and tokenizer,
then sets up the dataset and dataloader for training or inference.

"""

# Initialize BioMedCLIP model and preprocessor
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# Initialize tokenizer
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# Create dataset instance
dataset = ComplexMedicalDataset(
    data_dir="/Users/YusMolina/Documents/tesis/biomedCLIP/data/datosMex/",
    processor=preprocess,
    tokenizer=tokenizer
)

# Verify dataset loading by checking a sample
print(f"Sample from dataset: {dataset[4]}")

# Create DataLoader
dataloader = DataLoader(
    dataset, 
    batch_size=32, 
    shuffle=True, 
    collate_fn=ComplexMedicalDataset.collate_fn
    )

print(f"DataLoader configuration: {dataloader}")

"""
Expected directory structure:
/data/datosMex/
    ├── dataset_info.json
    ├── images/
    │   ├── image1.jpg
    │   ├── image2.jpg
    │   └── ...
    └── reports/
        ├── report1.txt
        └── ...

dataset_info.json format:
{
    "case_id": {
        "image_paths": ["path/to/image1.jpg", ...],
        "report": "medical report text"
    },
    ...
}
"""

# Configuration parameters
CONFIG = {
    'batch_size': 32,
    'num_workers': 4,
    'model_name': 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
    'data_dir': "/Users/YusMolina/Documents/tesis/biomedCLIP/data/datosMex/",
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
}

# Optional: Print configuration for verification
def print_setup_info():
    """Print configuration and setup information for verification."""
    print("\nBioMedCLIP Setup Information:")
    print(f"Device: {CONFIG['device']}")
    print(f"Dataset size: {len(dataset)}")
    print(f"Number of batches: {len(dataloader)}")
    print(f"Batch size: {CONFIG['batch_size']}")
    print(f"Number of workers: {CONFIG['num_workers']}")
    print("\nModel Information:")
    print(f"Model name: {CONFIG['model_name']}")
    


torch.Size([4, 3, 224, 224])
Sample from dataset: {'image': tensor([[[[-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
          [ 0.1055, -0.2594, -1.0915,  ..., -1.7923, -1.7923, -1.7923],
          [ 1.9303,  1.9303,  1.8427,  ..., -1.7923, -1.7923, -1.7923],
          ...,
          [ 1.9303,  1.9303,  1.9303,  ..., -1.7923, -1.7923, -1.7923],
          [ 1.9303,  1.9303,  1.9303,  ..., -1.7923, -1.7923, -1.7923],
          [ 1.9303,  1.9303,  1.9303,  ..., -1.7923, -1.7923, -1.7923]],

         [[-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
          [ 0.1989, -0.1763, -1.0317,  ..., -1.7521, -1.7521, -1.7521],
          [ 2.0749,  2.0749,  1.9848,  ..., -1.7521, -1.7521, -1.7521],
          ...,
          [ 2.0749,  2.0749,  2.0749,  ..., -1.7521, -1.7521, -1.7521],
          [ 2.0749,  2.0749,  2.0749,  ..., -1.7521, -1.7521, -1.7521],
          [ 2.0749,  2.0749,  2.0749,  ..., -1.7521, -1.7521, -1.7521]],

         [[-1.4802, -1.4802, -1.4802,  ...

In [11]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


taken from the tutorial: https://huggingface.co/docs/transformers/training

In [12]:
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer")

In [13]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

In [14]:
import numpy as np
import torch

def compute_metrics(eval_pred):
    """
    Compute evaluation metrics for CLIP-like models, specifically designed for
    image-text matching tasks.
    
    NOT TESTED
    """
    outputs = eval_pred.predictions
    
    # Handle CLIP-style embeddings (image_embeds, text_embeds)
    if isinstance(outputs, tuple):
        image_embeds, text_embeds = outputs
        
        # Compute cosine similarity matrix
        similarity = torch.matmul(
            torch.from_numpy(image_embeds), 
            torch.from_numpy(text_embeds).t()
        )
        
        # Get predictions (diagonal elements should be highest for correct pairs)
        predictions = torch.argmax(similarity, dim=1).numpy()
        labels = np.arange(len(predictions))  # Assuming paired data
        
        # Calculate evaluation metrics
        accuracy = (predictions == labels).mean()
        mean_similarity = similarity.diagonal().mean().item()
        
        # Return dictionary of metrics
        return {
            "accuracy": accuracy,                    # Matching accuracy
            "mean_similarity": mean_similarity       # Average similarity for correct pairs
        }
    
    # Handle standard classification outputs
    else:
        predictions = np.argmax(outputs, axis=-1)
        labels = eval_pred.label_ids
        accuracy = (predictions == labels).mean()
        return {"accuracy": accuracy}




In [15]:
# I know is bad, sorry
train_dataset = dataset
eval_dataset = dataset

In [16]:
# Training arguments setup
def setup_training(
    output_dir: str = "./results",
    num_epochs: int = 3,
    eval_steps: int = 100,
    learning_rate: float = 5e-5
) -> TrainingArguments:
   
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,        
        eval_steps=eval_steps,
        learning_rate=learning_rate,
        metric_for_best_model="accuracy",
        logging_dir="./logs",
        logging_steps=10,
        report_to="tensorboard",
        # Gradient related
        gradient_accumulation_steps=1,
        weight_decay=0.01,
        # Mixed precision training
        fp16=torch.cuda.is_available(),
        # Other settings
        remove_unused_columns=False,  # Important for CLIP-like models
    )

In [17]:
training_args = setup_training(output_dir="./results")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset= train_dataset,
    eval_dataset= eval_dataset,
    compute_metrics=compute_metrics,
)

Prints to obain more information of the model 

In [18]:
trainer.train()

  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])
torch.Size([4, 3, 224, 224])


ValueError: too many values to unpack (expected 4)

In [16]:
help(model.forward)


  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])


ValueError: too many values to unpack (expected 4)

In [20]:
import torchinfo

torchinfo.summary(model, input_size=(4, 3, 224, 224))  # specify input size


Layer (type:depth-idx)                             Output Shape              Param #
CustomTextCLIP                                     [4, 512]                  109,710,849
├─TimmModel: 1-1                                   [4, 512]                  --
│    └─VisionTransformer: 2-1                      [4, 768]                  152,064
│    │    └─PatchEmbed: 3-1                        [4, 196, 768]             590,592
│    │    └─Dropout: 3-2                           [4, 197, 768]             --
│    │    └─Identity: 3-3                          [4, 197, 768]             --
│    │    └─Identity: 3-4                          [4, 197, 768]             --
│    │    └─Sequential: 3-5                        [4, 197, 768]             85,054,464
│    │    └─LayerNorm: 3-6                         [4, 197, 768]             1,536
│    │    └─Identity: 3-7                          [4, 768]                  --
│    │    └─Dropout: 3-8                           [4, 768]                  --
│    

In [15]:
model

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768

In [20]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import Dataset, DataLoader
from typing import Optional
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import os
import json

# Make sure the Dataset is defined at the module level
class ComplexMedicalDataset(Dataset):
    def __init__(self, data_dir: str, processor, tokenizer):
        super().__init__()  # Add super().__init__()
        self.data_dir = data_dir
        self.processor = processor
        self.tokenizer = tokenizer
        
        with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as f:
            self.data = json.load(f)
        
        self.image_size = (224, 224)
        self.transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                              std=[0.229, 0.224, 0.225])
        ])

    def process_image(self, image_path):
        try:
            img = Image.open(os.path.join(self.data_dir, image_path)).convert('RGB')
            return self.transform(img)
        except Exception as e:
            print(f"Error processing image {image_path}: {str(e)}")
            return torch.zeros((3, *self.image_size))

    def __getitem__(self, idx: int) -> dict:
        item = self.data[idx]
        a = list(item.keys())[0]
        
        # Process first image only
        image = self.process_image(item[a]['image_paths'][0])
        
        # Process text
        text = self.tokenizer(
            item[a]['report']
        )
        
        # Remove batch dimension from text tensors
        text = {k: v.squeeze(0) for k, v in text.items()}
        
        return {
            "image": image,
            "text": text
        }

    def __len__(self):
        return len(self.data)

class BiomedCLIPModule(pl.LightningModule):
    def __init__(self, model, learning_rate=5e-5):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.save_hyperparameters(ignore=['model'])

    def forward(self, image, text):
        return self.model(image=image, text=text)

    def training_step(self, batch, batch_idx):
        images = batch['image']
        texts = batch['text']
        
        # Make sure images are in the correct shape (B, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)
            
        outputs = self(images, texts)
        loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0]
        
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images = batch['image']
        texts = batch['text']
        
        if images.dim() == 3:
            images = images.unsqueeze(0)
            
        outputs = self(images, texts)
        loss = outputs.loss if hasattr(outputs, 'loss') else outputs[0]
        
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=10,
            eta_min=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }

class BiomedCLIPDataModule(pl.LightningDataModule):
    def __init__(self, 
                 data_dir: str,
                 processor,
                 tokenizer,
                 batch_size: int = 32,
                 num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.processor = processor
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit' or stage is None:
            self.train_dataset = ComplexMedicalDataset(
                self.data_dir, 
                self.processor, 
                self.tokenizer
            )
            self.val_dataset = self.train_dataset

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

def train_model_lightning(model, preprocess, tokenizer):
    # Initialize data module
    model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

    data_module = BiomedCLIPDataModule(
        data_dir="/Users/YusMolina/Documents/tesis/biomedCLIP/data/datosMex/",
        processor=preprocess,
        tokenizer=tokenizer,
        batch_size=32,
        num_workers=0  # Set to 0 for debugging
    )
    
    # Initialize model
    pl_model = BiomedCLIPModule(model)
    
    # Setup callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='biomed-clip-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        monitor='val_loss',
        mode='min'
    )
    
    # Setup logger
    logger = TensorBoardLogger("lightning_logs", name="biomed_clip")
    
    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=10,
        accelerator='auto',
        devices=1,
        callbacks=[checkpoint_callback],
        logger=logger,
        gradient_clip_val=0.5,
        log_every_n_steps=10,
        deterministic=True
    )
    
    # Train the model
    trainer.fit(pl_model, data_module)
    
    return trainer, pl_model

# Now you can run the training
if __name__ == '__main__':
    model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
    tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

    trainer, trained_model = train_model_lightning(model, preprocess, tokenizer)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type           | Params
-----------------------------------------
0 | model | CustomTextCLIP | 195 M 
-----------------------------------------
195 M     Trainable params
0         Non-trainable params
195 M     Total params
783.611   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

AttributeError: 'Tensor' object has no attribute 'items'