In [None]:
from google.colab import drive
drive.mount('/content/drive')

!pip install nltk rouge-score
import nltk
nltk.download('punkt')

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import nltk

nltk.download('punkt', force=True)
print(nltk.data.path)

!pip install git+https://github.com/openai/CLIP.git


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import pandas as pd
import clip
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from nltk.tokenize import word_tokenize

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


['/root/nltk_data', '/usr/nltk_data', '/usr/share/nltk_data', '/usr/lib/nltk_data', '/usr/share/nltk_data', '/usr/local/share/nltk_data', '/usr/lib/nltk_data', '/usr/local/lib/nltk_data']
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-9ehk6v46
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-9ehk6v46
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
def generate_caption(row):
    """Turn one row of MRNet labels into a short text caption.

    Reads ``row['abnormal']``, ``row['acl']`` and ``row['meniscus']`` (a pandas row or any
    mapping works).

    - ``abnormal == 0`` always yields ``"Healthy knee"``, whatever the other labels are.
    - Otherwise the positive findings (ACL first, then meniscus) are joined with
      ``" and a "`` and terminated with a period.
    - An abnormal row with neither finding yields ``"Unspecified abnormality."``.
    """
    if row['abnormal'] == 0:
        return "Healthy knee"

    label_names = (('acl', "ACL tear"), ('meniscus', "Meniscus tear"))
    present = [name for column, name in label_names if row[column] == 1]

    if not present:
        return "Unspecified abnormality."
    return " and a ".join(present) + "."


class ClipCaptionModel(nn.Module):
    """
    Image-captioning model that conditions GPT-2 on a CLIP image embedding.

    A linear layer projects the CLIP embedding into ``prefix_len`` pseudo-token embeddings.
    These are prepended to the embedded caption tokens and fed to GPT-2 through
    ``inputs_embeds``.

    The language-modelling loss covers only the caption:
    - prefix positions are labelled -100 (ignored by the loss);
    - padding positions are labelled -100, except the first padding position;
    - that first pad is kept as a target so the model learns to emit the pad token after
      the caption. In this notebook the pad token is GPT-2's EOS (``init_model`` sets
      ``pad_token = eos_token``).
    Right padding is assumed (the GPT-2 tokenizer default) — TODO confirm if that changes.

    Args:
        clip_dim (int, optional): Dimensionality of the CLIP image embeddings. Default is 512.
        prefix_len (int, optional): Number of prefix embeddings produced from one image
            embedding. Default is 10.

    forward(image_embedding, captions, attention_mask):
        image_embedding (torch.Tensor): (batch_size, clip_dim) CLIP embeddings; cast to float32.
        captions (torch.Tensor): (batch_size, caption_len) caption token ids.
        attention_mask (torch.Tensor): (batch_size, caption_len); 1 for real tokens, 0 for padding.

        Returns the output of ``self.gpt`` called with ``labels``. For ``GPT2LMHeadModel``
        this carries ``logits`` and ``loss``.
    """
    def __init__(self, clip_dim=512, prefix_len=10):
        super().__init__()
        self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
        self.prefix_len = prefix_len
        # One linear map produces all prefix embeddings at once: clip_dim -> prefix_len * n_embd
        self.clip_project = nn.Linear(clip_dim, self.gpt.config.n_embd * prefix_len)

    def forward(self, image_embedding, captions, attention_mask):
        batch_size = captions.shape[0]

        # CLIP embeddings may arrive in half precision; the projection layer is float32.
        image_embedding = image_embedding.float()

        prefix_embedding = self.clip_project(image_embedding).view(batch_size, self.prefix_len, -1)
        caption_embeddings = self.gpt.transformer.wte(captions)
        embeddings = torch.cat((prefix_embedding, caption_embeddings), dim=1)

        # The prefix is always attended to. Use the caller's mask dtype so torch.cat
        # does not have to promote mixed dtypes.
        prefix_mask = torch.ones(
            (batch_size, self.prefix_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        extended_attention = torch.cat((prefix_mask, attention_mask), dim=1)

        prefix_labels = torch.full(
            (batch_size, self.prefix_len), -100, dtype=captions.dtype, device=captions.device
        )
        labels = torch.cat((prefix_labels, self._caption_labels(captions, attention_mask)), dim=1)

        return self.gpt(inputs_embeds=embeddings, attention_mask=extended_attention, labels=labels)

    @staticmethod
    def _caption_labels(captions, attention_mask):
        """Return caption labels with padding set to -100, keeping the first pad as the stop target."""
        is_pad = attention_mask == 0
        # Running count of pads seen so far; a pad whose count is 1 is the first one.
        first_pad = is_pad & (is_pad.long().cumsum(dim=1) == 1)
        return captions.masked_fill(is_pad & ~first_pad, -100)

class MRICaptionDataset(Dataset):
    """
    Dataset of MRNet exams paired with text captions.

    Each item does four things:
    - loads ``{image_dir}/{exam_id}.npy`` (a stack of slices indexed along axis 0);
    - picks up to ``num_slices_to_use`` evenly spaced slice indices;
    - converts the middle slice into a 3-channel image passed through ``transform``;
    - tokenizes the caption, padded/truncated to ``max_length`` tokens.

    Args:
        dataframe (pandas.DataFrame): Must contain an 'exam' column (the exam id, zero-padded
            to 4 digits to build the file name) and a 'caption' column.
        image_dir (str): Directory containing the ``.npy`` scans.
        transform (callable): Applied to the middle slice as a ``PIL.Image``.
        tokenizer: HuggingFace-style tokenizer, called with ``padding="max_length"``,
            ``truncation=True`` and ``return_tensors="pt"``.
        max_length (int, optional): Caption length after padding/truncation. Default is 50.
        num_slices_to_use (int, optional): Number of evenly spaced slice indices to return.
            Every slice is returned when the scan has no more than this many. Default is 5.

    Attributes:
        clip_dim (int): Fixed to 512; not used inside this class.

    ``__getitem__(idx)`` returns a tuple
    ``(img_tensor, input_ids, attention_mask, exam_id, selected_slices)``, where
    ``exam_id`` is the zero-padded id string and ``selected_slices`` is a list of int
    slice indices.
    """
    def __init__(self, dataframe, image_dir, transform, tokenizer, max_length=50, num_slices_to_use=5):
        self.data = dataframe
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_slices_to_use = num_slices_to_use
        self.clip_dim = 512

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_uint8(slice_img):
        """Min-max scale a slice to uint8 in [0, 255]; a constant slice becomes all zeros."""
        lo, hi = slice_img.min(), slice_img.max()
        if hi == lo:
            # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
            return np.zeros(slice_img.shape, dtype=np.uint8)
        return ((slice_img - lo) / (hi - lo) * 255).astype(np.uint8)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        exam_id = str(row['exam']).zfill(4)
        caption = row['caption']
        img_path = os.path.join(self.image_dir, f"{exam_id}.npy")

        scan = np.load(img_path)
        num_slices = scan.shape[0]

        # Evenly spaced slice indices; the training/eval loops pass these on for CLIP encoding.
        if num_slices <= self.num_slices_to_use:
            selected_slices = list(range(num_slices))
        else:
            selected_slices = np.linspace(0, num_slices - 1, self.num_slices_to_use, dtype=int).tolist()

        # The middle slice is returned as a representative image tensor.
        slice_img = self._to_uint8(scan[num_slices // 2])
        pil_img = Image.fromarray(np.stack([slice_img] * 3, axis=-1))
        img_tensor = self.transform(pil_img)

        tokens = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = tokens.input_ids.squeeze(0)
        attention_mask = tokens.attention_mask.squeeze(0)

        return img_tensor, input_ids, attention_mask, exam_id, selected_slices


def collate_fn(batch):
    """
    Merge a list of ``MRICaptionDataset`` samples into one batch.

    The first three fields of every sample (image tensor, input ids, attention mask) are
    stacked along a new leading batch dimension. Exam ids and selected-slice lists are passed
    through unchanged, as the tuples produced by ``zip``.

    Args:
        batch (list of tuples): Samples of the form
            ``(image, input_ids, attention_mask, exam_id, selected_slices)``.

    Returns:
        tuple: ``(images, input_ids, attention_masks, exam_ids, selected_slices)``:
            - images: (batch_size, C, H, W) tensor;
            - input_ids, attention_masks: (batch_size, max_length) tensors;
            - exam_ids, selected_slices: tuples, one entry per sample.
    """
    imgs, ids, masks, exams, slices = zip(*batch)
    return (
        torch.stack(imgs),
        torch.stack(ids),
        torch.stack(masks),
        exams,
        slices,
    )

def encode_selected_slices(clip_model, exam_id, selected_slices, image_dir, transform, device):
    """
    Encode selected slices of one MRI scan with CLIP and return their mean embedding.

    Steps:
    1. Load ``{image_dir}/{exam_id}.npy``.
    2. Min-max scale each selected slice to uint8 (a constant slice becomes all zeros).
    3. Replicate each slice to 3 channels and apply ``transform``.
    4. Encode all selected slices in a single ``clip_model.encode_image`` call under
       ``torch.no_grad()``.

    Args:
        clip_model: CLIP model exposing ``encode_image``.
        exam_id (str): Exam identifier; the file name is ``f"{exam_id}.npy"``.
        selected_slices (list of int): Indices into the scan's first axis.
        image_dir (str): Directory containing the ``.npy`` scans.
        transform (callable): Maps a ``PIL.Image`` to an image tensor.
        device (torch.device or str): Device the stacked slice batch is moved to.

    Returns:
        torch.Tensor: Shape (1, embedding_dim), the mean of the per-slice embeddings, in
        whatever dtype ``encode_image`` returns.

    Raises:
        ValueError: If ``selected_slices`` is empty.
    """
    if len(selected_slices) == 0:
        raise ValueError(f"No slices selected for exam {exam_id}")

    scan = np.load(os.path.join(image_dir, f"{exam_id}.npy"))

    slice_tensors = []
    for slice_idx in selected_slices:
        slice_img = scan[slice_idx]
        lo, hi = slice_img.min(), slice_img.max()
        if hi == lo:
            # Constant slice: avoid 0/0 -> NaN, whose cast to uint8 is undefined.
            slice_u8 = np.zeros(slice_img.shape, dtype=np.uint8)
        else:
            slice_u8 = ((slice_img - lo) / (hi - lo) * 255).astype(np.uint8)
        pil_img = Image.fromarray(np.stack([slice_u8] * 3, axis=-1))
        slice_tensors.append(transform(pil_img))

    # One forward pass for all slices instead of one per slice.
    batch = torch.stack(slice_tensors).to(device)
    with torch.no_grad():
        embeddings = clip_model.encode_image(batch)

    return embeddings.mean(dim=0, keepdim=True)

def train_epoch(caption_model, clip_model, dataloader, optimizer, device, image_dir, transform):
    """
    Train the captioning model for one epoch.

    For each batch:
    1. Each exam is embedded with ``encode_selected_slices``, i.e. the mean CLIP embedding
       of its selected slices, computed under ``torch.no_grad``.
    2. The caption loss is computed by ``caption_model``.
    3. One optimizer step is taken.

    Args:
        caption_model (nn.Module): Captioning model; called as
            ``caption_model(image_embeddings, input_ids, attention_mask)`` and expected to
            return an object with ``.loss``.
        clip_model: CLIP model used to encode the slices.
        dataloader (DataLoader): Yields ``(images, input_ids, attention_mask, exam_ids,
            selected_slices)`` batches, as produced by ``collate_fn``.
        optimizer (torch.optim.Optimizer): Optimizer over the caption model's parameters.
        device (torch.device or str): Device for the token tensors and CLIP inputs.
        image_dir (str): Directory containing the ``.npy`` scans.
        transform (callable): Preprocessing applied to each slice before encoding.

    Returns:
        float: Mean batch loss over the epoch (0.0 for an empty dataloader).
    """
    caption_model.train()
    total_loss = 0.0

    for _images, input_ids, attention_mask, exam_ids, selected_slices in tqdm(dataloader, desc="Training"):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # One (1, clip_dim) embedding per exam, concatenated into (batch_size, clip_dim).
        image_embeddings = torch.cat([
            encode_selected_slices(clip_model, exam_id, slices, image_dir, transform, device)
            for exam_id, slices in zip(exam_ids, selected_slices)
        ], dim=0)

        loss = caption_model(image_embeddings, input_ids, attention_mask).loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Guard against an empty dataloader, as evaluate() does.
    return total_loss / max(1, len(dataloader))

def evaluate(caption_model, clip_model, dataloader, device, tokenizer, image_dir, transform):
    """
    Compute the validation loss and caption-level classification metrics.

    For every batch:
    - compute the teacher-forced caption loss;
    - for each sample, generate a caption with beam search from the image prefix alone;
    - compare it with the decoded ground-truth caption by exact string match (after
      stripping surrounding whitespace).

    Also prints a caption-frequency summary and the macro metrics.

    Args:
        caption_model (ClipCaptionModel): Model exposing ``clip_project``, ``prefix_len`` and ``gpt``.
        clip_model: CLIP model used to encode the slices.
        dataloader (DataLoader): Validation batches, as produced by ``collate_fn``.
        device (torch.device or str): Device for tensors.
        tokenizer: GPT-2 tokenizer used for decoding.
        image_dir (str): Directory containing the validation ``.npy`` scans.
        transform (callable): Preprocessing applied to each slice before encoding.

    Returns:
        tuple: ``(avg_loss, classification_metrics)``, where ``classification_metrics`` is
        the dict returned by ``calculate_classification_metrics``.
    """
    caption_model.eval()
    total_loss = 0.0
    caption_pairs = []

    with torch.no_grad():
        for _images, input_ids, attention_mask, exam_ids, selected_slices in tqdm(dataloader, desc="Validation"):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            image_embeddings = torch.cat([
                encode_selected_slices(clip_model, exam_id, slices, image_dir, transform, device)
                for exam_id, slices in zip(exam_ids, selected_slices)
            ], dim=0)

            outputs = caption_model(image_embeddings, input_ids, attention_mask)
            total_loss += outputs.loss.item()

            for i in range(input_ids.size(0)):
                true_caption = tokenizer.decode(input_ids[i], skip_special_tokens=True).strip()
                pred_caption = _generate_caption_from_embedding(
                    caption_model, image_embeddings[i], tokenizer, device
                )
                caption_pairs.append((true_caption, pred_caption))

    classification_metrics = calculate_classification_metrics(caption_pairs)
    avg_loss = total_loss / max(1, len(dataloader))
    caption_analysis = _summarize_caption_pairs(caption_pairs)

    print("=== BASIC METRICS ===")
    for key, value in caption_analysis.items():
        if key != "prediction_to_true_map":  # skip the large map for cleaner output
            print(f"{key}: {value}")

    print(f"\nAverage Validation Loss: {avg_loss:.4f}")

    print("\n=== CLASSIFICATION METRICS ===")
    print(f"Overall Accuracy: {classification_metrics['overall_accuracy']:.4f}")
    print(f"Macro Average Precision: {classification_metrics['macro_precision']:.4f}")
    print(f"Macro Average Recall: {classification_metrics['macro_recall']:.4f}")
    print(f"Macro Average F1: {classification_metrics['macro_f1']:.4f}")

    return avg_loss, classification_metrics


def _generate_caption_from_embedding(caption_model, image_embedding, tokenizer, device):
    """Beam-search one caption from a single CLIP embedding, conditioning GPT-2 on the image prefix only."""
    prefix_embed = caption_model.clip_project(image_embedding.float().unsqueeze(0)) \
        .view(1, caption_model.prefix_len, -1)
    attention_prefix = torch.ones(1, caption_model.prefix_len, device=device)

    generated = caption_model.gpt.generate(
        inputs_embeds=prefix_embed,
        max_length=50,
        num_beams=5,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=attention_prefix
    )
    # Strip so a leading-space BPE token does not break the exact-match comparison.
    return tokenizer.decode(generated[0], skip_special_tokens=True).strip()


def _summarize_caption_pairs(caption_pairs):
    """Build the caption-frequency / repetition summary printed by evaluate()."""
    from collections import Counter

    true_counts = Counter(true_cap for true_cap, _ in caption_pairs)
    pred_counts = Counter(pred_cap for _, pred_cap in caption_pairs)

    # For each predicted caption, how often each true caption was behind it.
    pred_to_true_map = {}
    for true_cap, pred_cap in caption_pairs:
        inner = pred_to_true_map.setdefault(pred_cap, {})
        inner[true_cap] = inner.get(true_cap, 0) + 1

    total = len(caption_pairs)
    return {
        "total_samples": total,
        "unique_predicted_captions": len(pred_counts),
        "unique_true_captions": len(true_counts),
        "top_predicted_captions": pred_counts.most_common(10),
        "top_true_captions": true_counts.most_common(10),
        "prediction_to_true_map": pred_to_true_map,
        # Fraction of samples whose prediction repeats an earlier one; 0.0 with no samples.
        "repetition_rate": 1 - (len(pred_counts) / total) if total else 0.0,
    }


def calculate_classification_metrics(caption_pairs):
    """
    Calculate per-caption classification metrics, treating each unique caption as a class.

    Every caption that appears as a true or predicted value is a class (the union of both).
    The one-vs-rest confusion counts for each class come from a single counting pass over
    ``caption_pairs``.

    Args:
        caption_pairs: Iterable of ``(true_caption, predicted_caption)`` tuples.

    Returns:
        dict with keys:
            'per_caption_metrics': ``{caption: {'tp', 'fp', 'fn', 'tn', 'precision', 'recall',
                'f1', 'accuracy', 'support'}}``, where support is the number of true instances;
            'macro_precision', 'macro_recall', 'macro_f1': unweighted means over classes
                (0.0 when there are no classes);
            'overall_accuracy': fraction of exact matches (0.0 for empty input);
            'total_caption_classes': number of classes.
    """
    from collections import Counter

    caption_pairs = list(caption_pairs)
    n_pairs = len(caption_pairs)

    true_counts = Counter(true_cap for true_cap, _ in caption_pairs)
    pred_counts = Counter(pred_cap for _, pred_cap in caption_pairs)
    hit_counts = Counter(true_cap for true_cap, pred_cap in caption_pairs if true_cap == pred_cap)

    # First-seen order keeps the class order deterministic; iteration order of a set of
    # strings varies between interpreter runs.
    all_captions = list(dict.fromkeys(caption for pair in caption_pairs for caption in pair))

    per_caption_metrics = {}
    for caption in all_captions:
        # One-vs-rest confusion counts for this caption.
        tp = hit_counts[caption]
        fp = pred_counts[caption] - tp
        fn = true_counts[caption] - tp
        tn = n_pairs - tp - fp - fn

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        # n_pairs > 0 here: a class only exists if some pair mentions it.
        accuracy = (tp + tn) / n_pairs

        per_caption_metrics[caption] = {
            'tp': tp,
            'fp': fp,
            'fn': fn,
            'tn': tn,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'accuracy': accuracy,
            'support': tp + fn,  # number of true instances of this caption
        }

    n_classes = len(per_caption_metrics)

    def _macro(metric_name):
        # Unweighted mean over classes; 0.0 for empty input instead of ZeroDivisionError.
        if not n_classes:
            return 0.0
        return sum(m[metric_name] for m in per_caption_metrics.values()) / n_classes

    overall_accuracy = sum(hit_counts.values()) / n_pairs if n_pairs else 0.0

    return {
        'per_caption_metrics': per_caption_metrics,
        'macro_precision': _macro('precision'),
        'macro_recall': _macro('recall'),
        'macro_f1': _macro('f1'),
        'overall_accuracy': overall_accuracy,
        'total_caption_classes': len(all_captions),
    }


def init_model(device, clip_name="ViT-B/32", prefix_len=10, lr=5e-5, weight_decay=0.01):
    """
    Build the models, tokenizer and optimizer used for training and inference.

    Steps:
    1. Load the pre-trained CLIP model ``clip_name``, freeze its parameters and put it in
       eval mode.
    2. Load the GPT-2 tokenizer, using EOS as the pad token.
    3. Create a ``ClipCaptionModel`` whose projection input size matches CLIP's embedding size.
    4. Create an AdamW optimizer over the caption model's parameters.

    Args:
        device (torch.device or str): Device to load the models onto.
        clip_name (str, optional): CLIP checkpoint name passed to ``clip.load``. Default "ViT-B/32".
        prefix_len (int, optional): Number of prefix embeddings in the caption model. Default 10.
        lr (float, optional): AdamW learning rate. Default 5e-5.
        weight_decay (float, optional): AdamW weight decay. Default 0.01.

    Returns:
        tuple: ``(clip_model, preprocess, tokenizer, caption_model, optimizer)``:
            - clip_model: the frozen CLIP model;
            - preprocess: the image transform returned by ``clip.load``;
            - tokenizer: the GPT-2 tokenizer;
            - caption_model: the ``ClipCaptionModel`` on ``device``;
            - optimizer: the ``torch.optim.AdamW`` optimizer.
    """
    clip_model, preprocess = clip.load(clip_name, device=device)

    # CLIP is used only as a frozen feature extractor.
    for param in clip_model.parameters():
        param.requires_grad = False
    clip_model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 has no pad token; reuse EOS so captions can be padded to a fixed length.
    tokenizer.pad_token = tokenizer.eos_token

    # Match the projection input to the CLIP embedding size. Fall back to 512 (the
    # ViT-B/32 size used previously) if the encoder does not expose output_dim.
    clip_dim = getattr(getattr(clip_model, "visual", None), "output_dim", 512)
    caption_model = ClipCaptionModel(clip_dim=clip_dim, prefix_len=prefix_len).to(device)

    optimizer = torch.optim.AdamW(
        caption_model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    return clip_model, preprocess, tokenizer, caption_model, optimizer

def train_model(train_df, val_df, num_epochs=25, batch_size=8,
                train_image_dir=None, valid_image_dir=None, checkpoint_path=None):
    """
    Train the CLIP-prefix captioning model and record validation metrics per epoch.

    After every epoch the model is evaluated on ``val_df``. Whenever the validation loss
    improves, the caption model's ``state_dict`` is saved to ``checkpoint_path``.

    Args:
        train_df (pandas.DataFrame): Training frame with 'exam' and 'caption' columns.
        val_df (pandas.DataFrame): Validation frame with the same columns.
        num_epochs (int): Number of training epochs.
        batch_size (int): DataLoader batch size for both splits.
        train_image_dir (str, optional): Directory of training ``.npy`` scans. Defaults to
            the notebook-level ``image_dir``.
        valid_image_dir (str, optional): Directory of validation scans. Defaults to the
            notebook-level ``val_image_dir``.
        checkpoint_path (str, optional): Where to save the best model. Defaults to
            ``best_model_<plane>.pt`` in the project Drive folder, where ``<plane>`` is the
            last component of ``train_image_dir``.

    Returns:
        dict[str, list[float]]: Per-epoch history for 'macro_precision', 'macro_recall',
        'macro_f1', 'overall_accuracy', 'train_losses' and 'val_losses'.
    """
    # Fall back to the directories defined in the data-loading cell.
    if train_image_dir is None:
        train_image_dir = image_dir
    if valid_image_dir is None:
        valid_image_dir = val_image_dir
    if checkpoint_path is None:
        # Name the checkpoint after the plane (e.g. ".../train/axial" -> "axial") so runs
        # on different planes do not overwrite each other.
        plane_name = os.path.basename(os.path.normpath(train_image_dir))
        checkpoint_path = os.path.join("/content/drive/MyDrive/biodata Project", f"best_model_{plane_name}.pt")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, preprocess, tokenizer, caption_model, optimizer = init_model(device)

    train_dataset = MRICaptionDataset(train_df, train_image_dir, preprocess, tokenizer, num_slices_to_use=10)
    val_dataset = MRICaptionDataset(val_df, valid_image_dir, preprocess, tokenizer, num_slices_to_use=10)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)

    # Start at +inf so any finite validation loss counts as an improvement.
    best_val_loss = float("inf")
    metrics = {
        'macro_precision': [],
        'macro_recall': [],
        'macro_f1': [],
        'overall_accuracy': [],
        'train_losses': [],
        'val_losses': [],
    }

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        avg_train_loss = train_epoch(
            caption_model, clip_model, train_loader, optimizer, device, train_image_dir, preprocess
        )
        avg_val_loss, classification_metrics = evaluate(
            caption_model, clip_model, val_loader, device, tokenizer, valid_image_dir, preprocess
        )

        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(caption_model.state_dict(), checkpoint_path)

        for key in ('macro_precision', 'macro_recall', 'macro_f1', 'overall_accuracy'):
            metrics[key].append(classification_metrics[key])
        metrics['train_losses'].append(avg_train_loss)
        metrics['val_losses'].append(avg_val_loss)

    return metrics

In [None]:
base_dir = "/content/drive/MyDrive/biodata Project/MRNet-v1.0"
plane = "axial"  # can be 'axial', 'coronal', 'sagittal'
image_dir = os.path.join(base_dir, "train", plane)
val_image_dir = os.path.join(base_dir, "valid", plane)

# The MRNet label CSVs appear to have no header row. With pandas' default header=0 the
# first exam is consumed as column names: the recorded training counts sum to 1129
# exams instead of 1130. NOTE(review): confirm against the raw files.
abnormal_df_train = pd.read_csv(os.path.join(base_dir, "train-abnormal.csv"), header=None, names=['exam', 'abnormal'])
acl_df_train = pd.read_csv(os.path.join(base_dir, "train-acl.csv"), header=None, names=['exam', 'acl'])
meniscus_df_train = pd.read_csv(os.path.join(base_dir, "train-meniscus.csv"), header=None, names=['exam', 'meniscus'])

train_df = abnormal_df_train.merge(acl_df_train, on='exam').merge(meniscus_df_train, on='exam')
train_df['caption'] = train_df.apply(generate_caption, axis=1)

# One boolean mask per caption class.
acl_meniscus_mask = (train_df['acl'] == 1) & (train_df['meniscus'] == 1)
acl_only_mask     = (train_df['acl'] == 1) & (train_df['meniscus'] == 0)
meniscus_only_mask = (train_df['acl'] == 0) & (train_df['meniscus'] == 1)
healthy_mask      = (train_df['abnormal'] == 0)
unspecified_mask  = (train_df['abnormal'] == 1) & (train_df['acl'] == 0) & (train_df['meniscus'] == 0)

# Undersample every class to the size of the smallest one so all captions are equally
# represented (derived from the data instead of a hard-coded count).
class_masks = [acl_meniscus_mask, acl_only_mask, meniscus_only_mask, healthy_mask, unspecified_mask]
n_per_class = int(min(mask.sum() for mask in class_masks))

df_acl_meniscus = train_df[acl_meniscus_mask].sample(n=n_per_class, random_state=42)
df_acl_only     = train_df[acl_only_mask].sample(n=n_per_class, random_state=42)
df_meniscus     = train_df[meniscus_only_mask].sample(n=n_per_class, random_state=42)
df_healthy      = train_df[healthy_mask].sample(n=n_per_class, random_state=42)
df_unspecified  = train_df[unspecified_mask].sample(n=n_per_class, random_state=42)

# Concatenate and shuffle.
train_df_balanced = pd.concat([
    df_acl_meniscus,
    df_acl_only,
    df_meniscus,
    df_healthy,
    df_unspecified
], ignore_index=True).sample(frac=1, random_state=42)

# CLIP's resize + normalisation constants. Note: train_model uses the `preprocess`
# returned by clip.load (via init_model), not this pipeline.
clip_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711))
])

In [None]:
# Caption distribution of the full (imbalanced) training split
train_df["caption"].value_counts(sort=True, ascending=False)

Unnamed: 0_level_0,count
caption,Unnamed: 1_level_1
Unspecified abnormality.,432
Meniscus tear.,272
Healthy knee,217
ACL tear and a Meniscus tear.,125
ACL tear.,83


In [None]:
# Caption distribution after undersampling: every class should have the same count
train_df_balanced["caption"].value_counts(sort=True, ascending=False)

Unnamed: 0_level_0,count
caption,Unnamed: 1_level_1
ACL tear and a Meniscus tear.,83
ACL tear.,83
Meniscus tear.,83
Unspecified abnormality.,83
Healthy knee,83


In [None]:
# Validation labels come from the same three kinds of label files as the training split.
# The MRNet label CSVs appear to have no header row, so read them with header=None to
# keep the first exam. NOTE(review): confirm against the raw files.
abnormal_df_val = pd.read_csv(os.path.join(base_dir, "valid-abnormal.csv"), header=None, names=['exam', 'abnormal'])
acl_df_val = pd.read_csv(os.path.join(base_dir, "valid-acl.csv"), header=None, names=['exam', 'acl'])
meniscus_df_val = pd.read_csv(os.path.join(base_dir, "valid-meniscus.csv"), header=None, names=['exam', 'meniscus'])

val_df = abnormal_df_val.merge(acl_df_val, on='exam').merge(meniscus_df_val, on='exam')

# Build captions from the validation rows themselves. Applying to train_df would assign
# training captions to validation exams by index alignment.
val_df['caption'] = val_df.apply(generate_caption, axis=1)

val_df["caption"].value_counts()

Unnamed: 0_level_0,count
caption,Unnamed: 1_level_1
Unspecified abnormality.,40
Meniscus tear.,40
Healthy knee,23
ACL tear and a Meniscus tear.,10
ACL tear.,6


In [None]:
# Train on the class-balanced split; validate on the full validation split
metrics = train_model(
    train_df=train_df_balanced,
    val_df=val_df,
    num_epochs=25,
    batch_size=8,
)


Epoch 1/25


Training:   2%|▏         | 1/52 [00:21<18:41, 22.00s/it]

In [None]:
# Per-epoch history returned by train_model
for name, history in metrics.items():
    print(f"{name}: {history}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Best value reached by each classification metric over all epochs.
# Each maximum may come from a different epoch.
score_keys = ["macro_precision", "macro_recall", "macro_f1", "overall_accuracy"]
max_values = {metric: max(metrics[metric]) for metric in score_keys if metrics.get(metric)}

# Plotting
fig, ax = plt.subplots()
bars = ax.bar(max_values.keys(), max_values.values(), color='skyblue')

# Annotate bars with values, lifted slightly above each bar top
for bar in bars:
    height = bar.get_height()
    ax.annotate(f'{height:.3f}',
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom')

# All plotted metrics lie in [0, 1]; extra headroom keeps the annotations inside the axes
ax.set_ylim(0, 1.1)
ax.set_ylabel("Max Value")
ax.set_title("Maximum Metric Values")
plt.xticks(rotation=15)
plt.tight_layout()
# The file name follows the plane configured in the data-loading cell
plt.savefig(f"/content/{plane}_results.png", dpi=300)
plt.show()
