I will be testing biomed clip with a little set of images, just searching the elements of calcifications and nodules (maybe include the birads) to make the radiological report. 

In [1]:
from open_clip import create_model_from_pretrained, get_tokenizer # works on open-clip-torch>=2.23.0, timm>=0.9.8
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW
from PIL import Image
import os
import json
import pandas as pd
import numpy as np
import nrrd
from torchvision import transforms
from typing import List, Optional
from transformers import TrainingArguments, Trainer


ComplexMedicalDataset is a custom PyTorch Dataset class designed for handling medical imaging data, specifically mammograms, and their associated text reports for fine-tuning BioMedCLIP. The class manages the loading, processing, and batching of both image and text data.


In [3]:
"""Expected JSON Structure
jsonCopy{
    "sample_id": {
        "image_paths": ["path1.jpg", "path2.jpg", ...],
        "report": "medical report text"
    }
}
"""
class ComplexMedicalDataset(Dataset):
    def __init__(self, data_dir: str, processor, tokenizer):
        self.data_dir = data_dir
        self.processor = processor
        self.tokenizer = tokenizer
        
        with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as f:
            self.data = json.load(f)
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __len__(self):
        return len(self.data)
    
    """
    Purpose: Processes individual images with error handling and format standardization
    Features:

    Converts images to RGB format
    Handles different processor types
    Includes fallback transformation pipeline
    Ensures correct tensor dimensions
    Applies normalization with ImageNet statistics
"""
    
    def process_image(self, image_path):
        
        img = Image.open(os.path.join(self.data_dir, image_path)).convert('RGB')
            
        if self.processor and isinstance(self.processor, transforms.Compose):
            try:
                processed_img = self.processor(img)
                if not isinstance(processed_img, torch.Tensor):
                        processed_img = transforms.ToTensor()(processed_img)
            except:
                transform = transforms.Compose([
                    transforms.Resize(self.image_size),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
                processed_img = transform(img)
                
            # Ensure the image has the correct number of dimensions
            if processed_img.dim() == 3:
                processed_img = processed_img.unsqueeze(0)  # Add batch dimension
            #print(processed_img.size())
                
            return processed_img

    """
    Returns: Dictionary containing:

    Processed image tensors
    Tokenized text


    Processing Steps:

    Loads and processes multiple images if present
    Stacks multiple images into a single tensor
    Processes associated text report
    Returns formatted data dictionary
    """

    def __getitem__(self, idx: int) -> dict:
        item = self.data[idx]
        a = list(item.keys())[0]

        # Process images
        images = []
        for img_path in item[a]['image_paths']:
            processed_img = self.process_image(img_path)
            images.append(processed_img)
       
        
        # Stack images if multiple, otherwise use single image
        if len(images) > 1:
            images = torch.stack(images)
        else:
            images = images[0]
        print(images.size())
        images = images.squeeze(1) # remove the dimension 1

        # Process text
        text = self.tokenizer(
            item[a]['report'], context_length=256
        ).to(self.device)
        
        # Remove batch dimension added by tokenizer
        #text = {k: v.squeeze(0) for k, v in text.items()}
        
        return {
            "image": images,
            "text": text
        }

    """
    Returns: Dictionary containing:

    Processed image tensors
    Tokenized text


    Processing Steps:

    Loads and processes multiple images if present
    Stacks multiple images into a single tensor
    Processes associated text report
    Returns formatted data dictionary
"""

    @staticmethod
    def collate_fn(batch):
        """Custom collate function to handle batching"""
        # Collate images
        images = torch.stack([item['image'] for item in batch])
        print(images.size())
        
        # Collate text
        text_batch = {}
        for key in batch[0]['text'].keys():
            text_batch[key] = torch.stack([item['text'][key] for item in batch])
        
        texts = torch.stack([item['text'] for item in batch])
        
        # Collate indices
        indices = [item['idx'] for item in batch]
        print(images.size())
        
        return {
            'image': images,
            'text': texts
        }, indices


In [4]:
def fine_tune_biomed_clip(model, train_dataloader, num_epochs, device, learning_rate=5e-5):
    """
    Fine-tune BioMedCLIP model on medical imaging data.
    
    Returns: model: Fine-tuned BioMedCLIP model
        
    Example:
        >>> model = BioMedCLIPModel.from_pretrained("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
        >>> dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
        >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        >>> fine_tuned_model = fine_tune_biomed_clip(model, dataloader, num_epochs=10, device=device)
    """
    # Move model to specified device
    model.to(device)
    
    # Initialize optimizer
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        num_batches = 0
        
        # Batch processing
        for batch in train_dataloader:
            try:
                # Move batch to device
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v 
                        for k, v in batch.items()}
                
                # Forward pass
                outputs = model(**batch)
                loss = outputs.loss
                
                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                
                # Accumulate loss
                total_loss += loss.item()
                num_batches += 1
                
                # Print progress every 10 batches
                if num_batches % 10 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, "
                          f"Batch {num_batches}, "
                          f"Loss: {loss.item():.4f}")
                
            except Exception as e:
                print(f"Error in batch: {str(e)}")
                continue
        
        # Calculate and print epoch average loss
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
    
    return model

In [5]:
def generate_text(model, processor, image_paths, entity_info=None, max_length=100):
    """
    Generate text descriptions from medical images using BioMedCLIP.
    Returns: str: Generated text description
        
    Example:
        >>> model = BioMedCLIPModel.from_pretrained("microsoft/BiomedCLIP-PubMedBERT_256")
        >>> processor = BioMedCLIPProcessor.from_pretrained("microsoft/BiomedCLIP-PubMedBERT_256")
        >>> image_paths = ["path/to/mammo1.jpg", "path/to/mammo2.jpg"]
        >>> description = generate_text(model, processor, image_paths)
    
    Note:
        - Images should be in a readable format (e.g., JPG, PNG)
        - Model should be on the appropriate device (CPU/CUDA)
        - Ensure all images exist before calling function
    """
    # Load and process all images
    images = [Image.open(path) for path in image_paths]
    
    # Process images with the BioMedCLIP processor
    inputs = processor(images, return_tensors="pt", padding=True)
    
    # Move inputs to the same device as model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate text without gradient computation
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)
    
    # Decode the generated text
    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    
    return generated_text

In [6]:
import torch
from torch.utils.data import DataLoader
from open_clip import create_model_from_pretrained, get_tokenizer

"""
BioMedCLIP Setup and Data Loading Script

This script initializes the BioMedCLIP model, preprocessor, and tokenizer,
then sets up the dataset and dataloader for training or inference.

"""

# Initialize BioMedCLIP model and preprocessor
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# Initialize tokenizer
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

# Create dataset instance
dataset = ComplexMedicalDataset(
    data_dir="/Users/YusMolina/Documents/tesis/biomedCLIP/data/datosMex/",
    processor=preprocess,
    tokenizer=tokenizer
)

# Verify dataset loading by checking a sample
print(f"Sample from dataset: {dataset[4]}")

# Create DataLoader
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,          # Adjust based on available GPU memory
    shuffle=True,           # Shuffle data during training
    collate_fn=ComplexMedicalDataset.collate_fn,  # Custom batching function
    num_workers=4           # Adjust based on CPU cores available
)

print(f"DataLoader configuration: {dataloader}")

"""
Expected directory structure:
/data/datosMex/
    ├── dataset_info.json
    ├── images/
    │   ├── image1.jpg
    │   ├── image2.jpg
    │   └── ...
    └── reports/
        ├── report1.txt
        └── ...

dataset_info.json format:
{
    "case_id": {
        "image_paths": ["path/to/image1.jpg", ...],
        "report": "medical report text"
    },
    ...
}
"""

# Configuration parameters
CONFIG = {
    'batch_size': 32,
    'num_workers': 4,
    'model_name': 'microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224',
    'data_dir': "/Users/YusMolina/Documents/tesis/biomedCLIP/data/datosMex/",
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
}

# Optional: Print configuration for verification
def print_setup_info():
    """Print configuration and setup information for verification."""
    print("\nBioMedCLIP Setup Information:")
    print(f"Device: {CONFIG['device']}")
    print(f"Dataset size: {len(dataset)}")
    print(f"Number of batches: {len(dataloader)}")
    print(f"Batch size: {CONFIG['batch_size']}")
    print(f"Number of workers: {CONFIG['num_workers']}")
    print("\nModel Information:")
    print(f"Model name: {CONFIG['model_name']}")
    


torch.Size([4, 1, 3, 224, 224])
{'image': tensor([[[[-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
          [ 0.1055, -0.2594, -1.0915,  ..., -1.7923, -1.7923, -1.7923],
          [ 1.9303,  1.9303,  1.8427,  ..., -1.7923, -1.7923, -1.7923],
          ...,
          [ 1.9303,  1.9303,  1.9303,  ..., -1.7923, -1.7923, -1.7923],
          [ 1.9303,  1.9303,  1.9303,  ..., -1.7923, -1.7923, -1.7923],
          [ 1.9303,  1.9303,  1.9303,  ..., -1.7923, -1.7923, -1.7923]],

         [[-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
          [ 0.1989, -0.1763, -1.0317,  ..., -1.7521, -1.7521, -1.7521],
          [ 2.0749,  2.0749,  1.9848,  ..., -1.7521, -1.7521, -1.7521],
          ...,
          [ 2.0749,  2.0749,  2.0749,  ..., -1.7521, -1.7521, -1.7521],
          [ 2.0749,  2.0749,  2.0749,  ..., -1.7521, -1.7521, -1.7521],
          [ 2.0749,  2.0749,  2.0749,  ..., -1.7521, -1.7521, -1.7521]],

         [[-1.4802, -1.4802, -1.4802,  ..., -1.4802, -1.4802

In [7]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


taken from the tutorial: https://huggingface.co/docs/transformers/training

In [8]:
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer")

In [9]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

In [10]:
import numpy as np
import torch

def compute_metrics(eval_pred):
    """
    Compute evaluation metrics for CLIP-like models, specifically designed for
    image-text matching tasks.
    
    Args:
        eval_pred: An object containing model predictions and labels.
                  Expected to contain:
                  - A tuple of (image_embeddings, text_embeddings)
    
    Returns:
        dict: Dictionary containing computed metrics:
              - accuracy: Match accuracy between images and text
              - mean_similarity: Average similarity score for correct pairs
    
    Note:
        - Assumes paired data where index i in images corresponds to index i in text
        - Similarity is computed using dot product between normalized embeddings
        - For non-CLIP outputs, falls back to standard classification metrics
    """
    outputs = eval_pred.predictions
    
    # Handle CLIP-style embeddings (image_embeds, text_embeds)
    if isinstance(outputs, tuple):
        image_embeds, text_embeds = outputs
        
        # Compute cosine similarity matrix
        similarity = torch.matmul(
            torch.from_numpy(image_embeds), 
            torch.from_numpy(text_embeds).t()
        )
        
        # Get predictions (diagonal elements should be highest for correct pairs)
        predictions = torch.argmax(similarity, dim=1).numpy()
        labels = np.arange(len(predictions))  # Assuming paired data
        
        # Calculate evaluation metrics
        accuracy = (predictions == labels).mean()
        mean_similarity = similarity.diagonal().mean().item()
        
        # Return dictionary of metrics
        return {
            "accuracy": accuracy,                    # Matching accuracy
            "mean_similarity": mean_similarity       # Average similarity for correct pairs
        }
    
    # Handle standard classification outputs
    else:
        predictions = np.argmax(outputs, axis=-1)
        labels = eval_pred.label_ids
        accuracy = (predictions == labels).mean()
        return {"accuracy": accuracy}




In [12]:
# I know is bad, sorry
train_dataset = dataset
eval_dataset = dataset

In [13]:
# Training arguments setup
def setup_training(
    output_dir: str = "./results",
    num_epochs: int = 3,
    eval_steps: int = 100,
    learning_rate: float = 5e-5
) -> TrainingArguments:
   
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,        
        eval_steps=eval_steps,
        learning_rate=learning_rate,
        metric_for_best_model="accuracy",
        logging_dir="./logs",
        logging_steps=10,
        report_to="tensorboard",
        # Gradient related
        gradient_accumulation_steps=1,
        weight_decay=0.01,
        # Mixed precision training
        fp16=torch.cuda.is_available(),
        # Other settings
        remove_unused_columns=False,  # Important for CLIP-like models
    )

In [14]:
training_args = setup_training(output_dir="./results")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset= train_dataset,
    eval_dataset= eval_dataset,
    compute_metrics=compute_metrics,
)

prints to obain more info of the model 

In [15]:
trainer.train()

Help on method forward in module open_clip.model:

forward(image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None) method of open_clip.model.CustomTextCLIP instance
    Defines the computation performed at every call.
    
    Should be overridden by all subclasses.
    
    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.



In [16]:
help(model.forward)


  0%|          | 0/3 [00:00<?, ?it/s]

torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])
torch.Size([4, 1, 3, 224, 224])


ValueError: too many values to unpack (expected 4)

In [20]:
import torchinfo

torchinfo.summary(model, input_size=(4, 3, 224, 224))  # specify input size


Layer (type:depth-idx)                             Output Shape              Param #
CustomTextCLIP                                     [4, 512]                  109,710,849
├─TimmModel: 1-1                                   [4, 512]                  --
│    └─VisionTransformer: 2-1                      [4, 768]                  152,064
│    │    └─PatchEmbed: 3-1                        [4, 196, 768]             590,592
│    │    └─Dropout: 3-2                           [4, 197, 768]             --
│    │    └─Identity: 3-3                          [4, 197, 768]             --
│    │    └─Identity: 3-4                          [4, 197, 768]             --
│    │    └─Sequential: 3-5                        [4, 197, 768]             85,054,464
│    │    └─LayerNorm: 3-6                         [4, 197, 768]             1,536
│    │    └─Identity: 3-7                          [4, 768]                  --
│    │    └─Dropout: 3-8                           [4, 768]                  --
│    

In [17]:
model

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768

In [None]:
# Generate text for a new set of images (inference)

base_path = "/Users/YusMolina/Documents/tesis/biomedCLIP/data"

# List of image filenames
image_filenames = ["image1.jpg", "image2.jpg", "image3.jpg", "image4.jpg"]

# Create the full paths dynamically
new_image_paths = [base_path + filename for filename in image_filenames]

    
# Example of providing entity info during inference
entity_info = {
    "entity1": True,
    "entity2": False,
    "entity3": True
}
    
# Generate text with entity info
generated_text_with_entities = generate_text(fine_tuned_model, processor, new_image_paths, entity_info)
print("Generated Text with Entities:", generated_text_with_entities)
    
# Generate text without entity info
generated_text_without_entities = generate_text(fine_tuned_model, processor, new_image_paths)
print("Generated Text without Entities:", generated_text_without_entities)