#### AveragedMedicalCLIPLoss implementation steps.

Problem statement: Medical text reports often consist of multiple labels or sentences. Some of these texts have similar meanings and may refer to one or more images. Similarly, one image can be described by one or more report sections (from the same or different reports) or even by labels. The CLIP loss does not take into account the underlying similarity of the tokenized texts, so for a large batch (which may contain similar medical labels or sentences) the loss computation is heavily affected, because labels/texts are repeated in the resulting [n, n] logits matrix. The issue arises mainly when labels are similar, since the dot product then assigns high logits to several labels at once. This restricts the use of CLIP for binary or multi-class tasks with a fixed list of texts or labels, as these are repeated across the columns of the logits per image.
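To make the problem concrete, here is a minimal sketch of the standard symmetric CLIP loss applied to a toy batch. Everything in it is an assumption made for illustration: `clip_loss` is a hypothetical helper (not the `mmgclip` implementation), the embeddings are hand-made orthogonal unit vectors standing in for a perfectly trained model, and the logit scale of 100 is arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def clip_loss(image_emb, text_emb, logit_scale=100.0):
    """Hypothetical helper: symmetric CLIP loss with diagonal targets."""
    logits_per_image = logit_scale * image_emb @ text_emb.t()
    logits_per_text = logits_per_image.t()
    # Standard CLIP assumes image i matches text i and nothing else.
    targets = torch.arange(image_emb.shape[0])
    loss_image = F.cross_entropy(logits_per_image, targets)
    loss_text = F.cross_entropy(logits_per_text, targets)
    return 0.5 * (loss_image + loss_text)


# Two classes with perfectly separated (orthogonal) embeddings.
class_emb = torch.eye(2)

# Batch without repeated texts: one sample per class.
unique_ids = torch.tensor([0, 1])
# Batch with repeated texts: each class appears twice.
repeated_ids = torch.tensor([0, 1, 0, 1])

loss_unique = clip_loss(class_emb[unique_ids], class_emb[unique_ids])
loss_repeated = clip_loss(class_emb[repeated_ids], class_emb[repeated_ids])

print(f"unique texts:   {loss_unique.item():.4f}")
print(f"repeated texts: {loss_repeated.item():.4f}  (ln 2 = {math.log(2):.4f})")
```

With repeated texts, each row of the logits contains two equally high entries, so the softmax can put at most half of the probability on the diagonal target. Even this idealized model therefore cannot push the loss below ln 2.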

One way to handle this is to set the batch size equal to the number of unique classes: a batch size of 2 for a binary task and a batch size of n_classes for a multi-class task. This affects training because the batch size is very small, and it assumes that the dataloader samples the classes equally.

To overcome this, we introduce AveragedMedicalCLIPLoss, a loss function that takes into account the similarity between the labels/texts in each batch independently, using the cosine similarity of the text embeddings. Texts whose similarity meets a threshold are given the same label, while unique texts are given unique labels.

We can then average the logits per image over the texts that share a label, and compute the CLIP loss for each logits type as per the official implementation.

In [17]:
import sys
sys.path.append('../')

import mmgclip
import torch
import os
import numpy as np
from sentence_transformers import SentenceTransformer, util

torch.cuda.empty_cache() 

# for auto reload when changes are made in the package
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
mmgconfig = mmgclip.config()
mmgconfig.epochs = 30
mmgconfig.annotated_dataset_path = '../data/02_data_T_regions'

# create an experiment name, all outputs will be stored inside
# default output dir is 'outputs', experiment output path will be 
# {mmgconfig.export_dir}/{mmgconfig.experiment_name}
mmgconfig.experiment_name = 'ClassifierExperiment'
mmgconfig.export_dir = '../outputs' 
mmgconfig.warmup_epochs = 0.1
mmgconfig.weight_decay = 1e-4
mmgconfig.train_batch_size = 128
mmgconfig.val_test_batch_size = 128
mmgconfig.loss_type = 'CLIPLoss'
mmgconfig.projection_name = 'MultiLinearHead'
mmgconfig.learning_rate = 5e-5
mmgconfig.run_name = f'{mmgconfig.loss_type}_n={mmgconfig.train_batch_size}'
mmgconfig.output_projection_dimension = 256

INFO	 | config[epochs] will change from '15' to '30'
INFO	 | config[annotated_dataset_path] will change from '/local/abdel/mmg-clip/mmgclip/../data/02_data_T_regions' to '../data/02_data_T_regions'
INFO	 | config[experiment_name] will change from 'clip' to 'ClassifierExperiment'
INFO	 | config[export_dir] will change from 'outputs' to '../outputs'
INFO	 | config[warmup_epochs] will change from '0.1' to '0.1'
INFO	 | config[weight_decay] will change from '0.0001' to '0.0001'
INFO	 | config[train_batch_size] will change from '128' to '128'
INFO	 | config[val_test_batch_size] will change from '128' to '128'
INFO	 | config[loss_type] will change from 'CLIPLoss' to 'CLIPLoss'
INFO	 | config[projection_name] will change from 'LinearProjectionLayer' to 'MultiLinearHead'
INFO	 | config[learning_rate] will change from '5e-05' to '5e-05'
INFO	 | config[run_name] will change from 'run01' to

Load a checkpoint for experimenting.

In [24]:
model = mmgclip.model(config=mmgconfig)
model.load_state_dict(
    torch.load(
        os.path.join(mmgconfig.export_dir, mmgconfig.experiment_name, mmgconfig.run_name, 'checkpoints', '20240306_6f4bbca6_lossCLIPLoss_epo30_seed42_lr5e-05_weight_decay0.001warmup_epochs0.1_train_bs128_projectionMultiLinearHead256.pth'))['model_state_dict'])

clf = mmgclip.PromptClassifier(model=model)

INFO	 | Initializing pretrained `emilyalsentzer/Bio_ClinicalBERT` as the text encoder and tokenizer.

INFO	 | Embeddings are projected to 256 features using MultiLinearHead.


The number of labels must be equal to the image batch size n.

In [25]:
# tokenize the labels using the same model tokenizer
labels = [
    'this benign',
    'malignant',
    'benign case',
    'malignant case',
    'benign',
    'malignant',
    'benign',
    'xyz'
    ]

# labels = [
#     'circumscribed',
#     'obscured',
#     'microlobulated',
#     'indistinct',
#     'spiculated',
#     'microlobulated',
#     'indistinct',
#     'spiculated',
#     ]

text_tokens = clf.tokenizer(
    labels, 
    padding="max_length", 
    truncation=True, 
    return_tensors="pt", 
    max_length=model.config.sequence_length)

print(text_tokens['input_ids'].shape)

batch = dict()
batch['text_tokens'] = text_tokens

torch.Size([8, 256])


Obtain the text embeddings. Note that the embeddings must be passed through the learned projection layer here; otherwise they will be random.

In [27]:
text_embeddings = model.encode_text(batch=batch)
text_embeddings = model.text_projection_layer(text_embeddings)
text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 
text_embeddings.shape

torch.Size([8, 256])

Measure the cosine similarity between every pair of sentences.

In [28]:
def measure_embeddings_similarity(embeddings):
    # pairwise cosine similarity between all rows, shape [n, n]
    return util.cos_sim(embeddings, embeddings)

def embedding_labels_similarities(similarities, labels):
    # map each 'label_i vs. label_j' pair (i < j) to its similarity score;
    # repeated labels produce the same key, so later pairs overwrite earlier ones
    d = {}
    for i, v1 in enumerate(labels):
        for j, v2 in enumerate(labels):
            if i >= j:
                continue
            d[v1 + ' vs. ' + v2] = similarities[i][j].item()

    # sort by score, highest first
    d_sorted = dict(sorted(d.items(), key=lambda x: x[1], reverse=True))
    return d_sorted

# measure the cosine similarity
cosine_scores = measure_embeddings_similarity(text_embeddings)
print(cosine_scores)
# for visualizing the scores, print them
d_sorted = embedding_labels_similarities(cosine_scores, labels)
d_sorted

tensor([[1.0000, 0.7714, 0.8803, 0.8041, 0.9058, 0.7714, 0.9058, 0.7866],
        [0.7714, 1.0000, 0.8387, 0.9547, 0.8594, 1.0000, 0.8594, 0.6772],
        [0.8803, 0.8387, 1.0000, 0.9082, 0.9064, 0.8387, 0.9064, 0.8011],
        [0.8041, 0.9547, 0.9082, 1.0000, 0.8636, 0.9547, 0.8636, 0.7153],
        [0.9058, 0.8594, 0.9064, 0.8636, 1.0000, 0.8594, 1.0000, 0.7719],
        [0.7714, 1.0000, 0.8387, 0.9547, 0.8594, 1.0000, 0.8594, 0.6772],
        [0.9058, 0.8594, 0.9064, 0.8636, 1.0000, 0.8594, 1.0000, 0.7719],
        [0.7866, 0.6772, 0.8011, 0.7153, 0.7719, 0.6772, 0.7719, 1.0000]],
       device='cuda:0', grad_fn=<MmBackward0>)


{'benign vs. benign': 1.0000001192092896,
 'malignant vs. malignant': 1.0,
 'malignant vs. malignant case': 0.9547221660614014,
 'malignant case vs. malignant': 0.9547221660614014,
 'benign case vs. malignant case': 0.9082435965538025,
 'benign case vs. benign': 0.9063839912414551,
 'this benign vs. benign': 0.9057816863059998,
 'this benign vs. benign case': 0.8802900910377502,
 'malignant case vs. benign': 0.8636056184768677,
 'malignant vs. benign': 0.8594213724136353,
 'benign vs. malignant': 0.8594213724136353,
 'malignant vs. benign case': 0.8387349247932434,
 'benign case vs. malignant': 0.8387349247932434,
 'this benign vs. malignant case': 0.8040817379951477,
 'benign case vs. xyz': 0.8010841012001038,
 'this benign vs. xyz': 0.786582887172699,
 'benign vs. xyz': 0.7719182968139648,
 'this benign vs. malignant': 0.771402895450592,
 'malignant case vs. xyz': 0.7153266072273254,
 'malignant vs. xyz': 0.677194356918335}

Assign a shared label to each group of texts whose similarity is at or above the specified threshold.

In [31]:
def assign_labels(cosine_sim_matrix, threshold=0.7):
    num_texts = cosine_sim_matrix.shape[0]
    labels = [-1] * num_texts  # Initialize labels with -1
    
    current_label = 0
    
    for i in range(num_texts):
        if labels[i] == -1:  # If the text hasn't been assigned a label yet
            labels[i] = current_label  # Assign it the current label
            for j in range(i+1, num_texts):
                if cosine_sim_matrix[i][j] >= threshold:
                    if labels[j] == -1:
                        labels[j] = current_label  # Assign the same label if similarity >= threshold
            current_label += 1
    
    return labels

list_labels = assign_labels(cosine_scores, threshold=0.8)
print(list_labels)


[0, 1, 0, 0, 0, 1, 0, 2]
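The grouping depends on the chosen threshold. The sketch below re-runs the same greedy assignment for two thresholds on the similarity matrix printed above, copied into a plain nested list so that no model is needed. `assign_labels_from_rows` is a hypothetical list-based variant of `assign_labels`, written only for this illustration, and the values are the rounded ones from the printed tensor.

```python
# Texts in the same order as the rows/columns of the matrix.
texts = ['this benign', 'malignant', 'benign case', 'malignant case',
         'benign', 'malignant', 'benign', 'xyz']

# Cosine similarities copied from the printed tensor (4 decimals).
similarity_rows = [
    [1.0000, 0.7714, 0.8803, 0.8041, 0.9058, 0.7714, 0.9058, 0.7866],
    [0.7714, 1.0000, 0.8387, 0.9547, 0.8594, 1.0000, 0.8594, 0.6772],
    [0.8803, 0.8387, 1.0000, 0.9082, 0.9064, 0.8387, 0.9064, 0.8011],
    [0.8041, 0.9547, 0.9082, 1.0000, 0.8636, 0.9547, 0.8636, 0.7153],
    [0.9058, 0.8594, 0.9064, 0.8636, 1.0000, 0.8594, 1.0000, 0.7719],
    [0.7714, 1.0000, 0.8387, 0.9547, 0.8594, 1.0000, 0.8594, 0.6772],
    [0.9058, 0.8594, 0.9064, 0.8636, 1.0000, 0.8594, 1.0000, 0.7719],
    [0.7866, 0.6772, 0.8011, 0.7153, 0.7719, 0.6772, 0.7719, 1.0000],
]


def assign_labels_from_rows(rows, threshold):
    """Hypothetical list-based variant of the greedy assignment above."""
    n = len(rows)
    group_ids = [-1] * n  # -1 marks a text without a group
    current = 0
    for i in range(n):
        if group_ids[i] != -1:
            continue  # already absorbed by an earlier group
        group_ids[i] = current  # text i starts a new group
        for j in range(i + 1, n):
            # Later texts are compared only with the first member of the group.
            if group_ids[j] == -1 and rows[i][j] >= threshold:
                group_ids[j] = current
        current += 1
    return group_ids


groups_08 = assign_labels_from_rows(similarity_rows, threshold=0.8)
groups_09 = assign_labels_from_rows(similarity_rows, threshold=0.9)

for text, g08, g09 in zip(texts, groups_08, groups_09):
    print(f"{text:15s} threshold 0.8 -> {g08}   threshold 0.9 -> {g09}")
```

At 0.8, 'malignant case' ends up in the first group because its similarity to 'this benign' (0.8041) passes the threshold before it is ever compared with 'malignant' (0.9547). At 0.9 it joins 'malignant', but 'benign case' becomes its own group, since it is compared only with 'this benign' (0.8803) and not with 'benign' (0.9064). The threshold therefore likely needs tuning per text encoder and label set.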


Loss implementation based on text similarity.

In [32]:
logits_per_image = torch.tensor(
    [[-0.3695, -0.8987, -0.3323, -0.3540, -0.3375, -0.5998, -0.3583, -0.0797],
     [-0.9398, -1.1682, -0.9602, -0.7505, -1.0275, -0.5558, -0.3456, -0.3068],
     [-0.8346, -1.1233, -0.7055, -0.4546, -0.6598, -0.6412, -0.6927, -0.1958],
     [-0.8875, -1.3657, -0.6414, -0.8099, -0.8178, -0.8100, -0.6184, -0.1464],
     [-0.7839, -1.2652, -0.6129, -0.4527, -0.5410, -0.4618, -0.4844, -0.3835],
     [-1.0263, -1.3110, -0.7902, -0.7323, -0.6832, -0.9224, -0.6688, -0.6417],
     [-0.5663, -0.5041, -0.5145, -0.0413, -0.2905, -0.2322, -0.3936,  0.0914],
     [-0.1942, -0.7119, -0.3226, -0.1033, -0.2929, -0.1779, -0.2586, -0.1330]])

logits_per_image.shape

torch.Size([8, 8])

In [33]:
set(list_labels)

{0, 1, 2}

In [34]:
# sort the group ids so that column k of the averaged logits corresponds to group k
unique_labels = sorted(set(list_labels))
averaged_logits = []

for label in unique_labels:
    # columns of the texts that belong to this group
    label_indices = [i for i, l in enumerate(list_labels) if l == label]
    label_logits = logits_per_image[:, label_indices]
    # average the logits of the grouped texts for every image
    averaged_logit = torch.mean(label_logits, dim=1)
    averaged_logits.append(averaged_logit)

averaged_logits_tensor = torch.stack(averaged_logits, dim=1)
print(averaged_logits_tensor.softmax(dim=-1))

tensor([[0.3354, 0.2250, 0.4396],
        [0.2786, 0.2631, 0.4583],
        [0.2929, 0.2368, 0.4703],
        [0.2813, 0.2017, 0.5170],
        [0.3378, 0.2531, 0.4091],
        [0.3493, 0.2495, 0.4012],
        [0.2805, 0.2785, 0.4410],
        [0.3428, 0.2777, 0.3794]])


In [35]:
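import torch.nn.functional as F

# the target of image i is the group id of its paired text i
labels = torch.tensor(list_labels)
print(labels)

# cross-entropy over the averaged (grouped) logits per image
loss_i = F.cross_entropy(averaged_logits_tensor, labels)

print(loss_i)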

tensor([0, 1, 0, 0, 0, 1, 0, 2])
tensor(1.2048)
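The averaging and cross-entropy steps above can be collected into one helper. This is a minimal sketch under stated assumptions: `averaged_image_loss` is a hypothetical name (not part of `mmgclip`), it covers only the image-side term computed in this notebook, and it reuses the example logits and group ids from the cells above.

```python
import torch
import torch.nn.functional as F


def averaged_image_loss(logits_per_image, group_ids):
    """Hypothetical helper: average logits over grouped texts, then cross-entropy."""
    group_ids = torch.as_tensor(group_ids)
    unique_groups = torch.unique(group_ids)  # sorted group ids
    # One averaged column per group, shape [n_images, n_groups].
    averaged = torch.stack(
        [logits_per_image[:, group_ids == g].mean(dim=1) for g in unique_groups],
        dim=1,
    )
    # Map each group id to its column index (identity when ids are 0..k-1).
    targets = torch.searchsorted(unique_groups, group_ids)
    return F.cross_entropy(averaged, targets)


# Example logits and group ids copied from the cells above.
example_logits = torch.tensor(
    [[-0.3695, -0.8987, -0.3323, -0.3540, -0.3375, -0.5998, -0.3583, -0.0797],
     [-0.9398, -1.1682, -0.9602, -0.7505, -1.0275, -0.5558, -0.3456, -0.3068],
     [-0.8346, -1.1233, -0.7055, -0.4546, -0.6598, -0.6412, -0.6927, -0.1958],
     [-0.8875, -1.3657, -0.6414, -0.8099, -0.8178, -0.8100, -0.6184, -0.1464],
     [-0.7839, -1.2652, -0.6129, -0.4527, -0.5410, -0.4618, -0.4844, -0.3835],
     [-1.0263, -1.3110, -0.7902, -0.7323, -0.6832, -0.9224, -0.6688, -0.6417],
     [-0.5663, -0.5041, -0.5145, -0.0413, -0.2905, -0.2322, -0.3936,  0.0914],
     [-0.1942, -0.7119, -0.3226, -0.1033, -0.2929, -0.1779, -0.2586, -0.1330]])
example_groups = [0, 1, 0, 0, 0, 1, 0, 2]

example_loss = averaged_image_loss(example_logits, example_groups)
print(example_loss)
```

On these inputs the helper reproduces the loss printed above (about 1.2048). A text-side term is not computed in this notebook, so it is left out of the sketch as well.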
