<a href="https://www.kaggle.com/code/rishabhsingh18/adamv2-experiment-rissingh?scriptVersionId=191460691" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [3]:
import os
import shutil
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import pairwise_distances, accuracy_score, f1_score, precision_score, recall_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

Data selection: this experiment uses the NIH Chest X-ray dataset, which is also available on Kaggle, and torchvision for preprocessing the images.

The images are also organized into one sub-directory per finding label, so that each image's class is explicit.


In [None]:
# Define paths
data_dir = '/kaggle/input/data'
# The previous value was a literal placeholder that can never exist on disk.
# NOTE(review): assumes the NIH metadata CSV (named Data_Entry_2017.csv in the public
# NIH release) sits at the root of data_dir — confirm against the attached dataset.
labels_file = os.path.join(data_dir, 'Data_Entry_2017.csv')

# Step 1: Organize the dataset
def organize_nih_dataset(data_dir, labels_file, output_dir=None, test_every=5, move=True):
    """Sort NIH chest X-ray images into ``train/<label>/`` and ``test/<label>/`` folders.

    The result is the layout ``torchvision.datasets.ImageFolder`` reads: one
    sub-directory per distinct ``'Finding Labels'`` string. Rows labelled
    ``'No Finding'`` are skipped.

    Args:
        data_dir: Directory containing an ``images`` sub-directory with the raw files.
        labels_file: CSV with at least the columns ``'Image Index'`` (file name)
            and ``'Finding Labels'`` (label string).
        output_dir: Where ``train`` and ``test`` are created. Defaults to
            ``data_dir`` (the original behaviour); point it at a writable location
            when ``data_dir`` is read-only.
        test_every: Every ``test_every``-th CSV row goes to ``test`` and the rest to
            ``train`` (5 gives the original 80/20 split).
        move: Move the files (original behaviour) when True; copy them when False
            so the source directory is left untouched.

    Returns:
        ``(train_dir, test_dir)``: the two directories that were populated.

    Raises:
        ValueError: if ``test_every`` is smaller than 1.
        OSError: propagated from ``shutil`` when a listed image cannot be
            transferred (e.g. it is missing from ``images``).
    """
    if test_every < 1:
        raise ValueError(f'test_every must be >= 1, got {test_every}')
    if output_dir is None:
        output_dir = data_dir

    images_dir = os.path.join(data_dir, 'images')
    train_dir = os.path.join(output_dir, 'train')
    test_dir = os.path.join(output_dir, 'test')
    transfer = shutil.move if move else shutil.copy2

    labels = pd.read_csv(labels_file)

    # Create the same set of class directories in BOTH splits so that ImageFolder,
    # which derives class indices from the sorted folder names, gives train and test
    # identical indices.
    # NOTE(review): 'Finding Labels' may join several findings with '|', so every
    # combination becomes its own class; rare combinations can leave a folder empty in
    # one split, which some torchvision versions' ImageFolder reject — confirm.
    classes = [cls for cls in labels['Finding Labels'].unique() if cls != 'No Finding']
    for split_dir in (train_dir, test_dir):
        os.makedirs(split_dir, exist_ok=True)
        for cls in classes:
            os.makedirs(os.path.join(split_dir, cls), exist_ok=True)

    # Move (or copy) images to corresponding directories
    rows = zip(labels['Image Index'], labels['Finding Labels'])
    for position, (file_name, label) in enumerate(tqdm(rows, total=len(labels))):
        # Skip images without any findings
        if label == 'No Finding':
            continue

        # Deterministic split on the row position (not the DataFrame index label, which
        # need not be 0..n-1): every test_every-th row is held out for testing.
        dest_dir = test_dir if position % test_every == 0 else train_dir

        src_path = os.path.join(images_dir, file_name)
        dest_path = os.path.join(dest_dir, label, file_name)
        transfer(src_path, dest_path)

    return train_dir, test_dir
        
        # Step 2: Data Preparation
batch_size = 32

# Resize every image to 224x224, convert to a tensor, and normalise each channel
# with the standard ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def _folder_split(split_name, shuffle):
    """Return the ImageFolder for ``<data_dir>/<split_name>`` and a DataLoader over it."""
    split_dataset = datasets.ImageFolder(root=os.path.join(data_dir, split_name), transform=transform)
    return split_dataset, DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle)


# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_dataset, train_loader = _folder_split('train', shuffle=True)
test_dataset, test_loader = _folder_split('test', shuffle=False)

Next, we implement the AdamV2 model with its three branches:

- Localizability - learns to distinguish different anatomical structures. Uses student and teacher encoders to process images and generate embeddings, maximizing the consistency between the anchor image and its augmented views.

- Composability - learns part-whole relationships by assembling larger anatomical structures from smaller parts. Uses student and teacher networks to process parts of an image, generating embeddings that represent the whole structure.

- Decomposability - learns whole-part relationships by decomposing larger structures into smaller parts. Processes the whole image and its parts to maximize the agreement between the embeddings of the whole and its parts.

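Here is a simplified example using a ResNet-50 backbone: each branch is a single linear head on the shared backbone features, and the student/teacher training objectives described above are not implemented.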

Zero-shot evaluation measures the model's ability to produce meaningful embeddings without any fine-tuning, using a nearest-neighbor search over the embeddings.



In [None]:
# Model Architecture
class AdamV2(nn.Module):
    """ResNet-50 backbone shared by three linear projection heads.

    The heads stand for the localizability, composability and decomposability
    branches. In this sketch each branch is just one ``nn.Linear`` applied to the
    same pooled backbone features; no student/teacher networks are built.
    """

    def __init__(self, embedding_dim=128, pretrained=True):
        """Build the backbone and the three heads.

        Args:
            embedding_dim: Output size of every branch head (128 reproduces the
                original model).
            pretrained: Load torchvision's ImageNet-1k (V1) ResNet-50 weights when
                True; use random initialisation when False.
        """
        super().__init__()
        # `weights=True` is rejected by torchvision's multi-weight API (it accepts a
        # ResNet50_Weights member, its name, or None), so pass the enum explicitly.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.resnet50(weights=weights)
        feature_dim = self.backbone.fc.in_features  # 2048 for ResNet-50
        self.backbone.fc = nn.Identity()  # Remove the last classification layer

        # Define heads for each branch
        self.localizability_head = nn.Linear(feature_dim, embedding_dim)
        self.composability_head = nn.Linear(feature_dim, embedding_dim)
        self.decomposability_head = nn.Linear(feature_dim, embedding_dim)

    def forward(self, x):
        """Return ``(localizability, composability, decomposability)`` embeddings.

        Args:
            x: Image batch for the ResNet backbone — presumably ``(batch, 3, H, W)``;
                the data pipeline in this notebook resizes to 224x224.

        Returns:
            Three tensors of shape ``(batch, embedding_dim)``.
        """
        features = self.backbone(x)
        return (
            self.localizability_head(features),
            self.composability_head(features),
            self.decomposability_head(features),
        )

# Move the model to the GPU when one is available; otherwise keep it on the CPU
# (an unconditional `.cuda()` fails on CPU-only machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AdamV2().to(device)

# Zero Shot evaluation

In [None]:
# Zero-shot Evaluation
model.eval()
# Send batches to whichever device the model's parameters live on instead of
# assuming CUDA is present.
device = next(model.parameters()).device
embeddings = []
labels = []

# Collect localizability-head embeddings and their class indices from the test dataset
with torch.no_grad():
    for images, label in test_loader:
        localizability, composability, decomposability = model(images.to(device))
        embeddings.append(localizability.cpu().numpy())
        labels.append(label.numpy())

embeddings = np.concatenate(embeddings, axis=0)
labels = np.concatenate(labels, axis=0)

# Perform nearest neighbor search
def nearest_neighbor_accuracy(embeddings, labels, k=1, metric='euclidean'):
    """Leave-one-out k-nearest-neighbour hit rate.

    Each sample is compared with every *other* sample (itself is excluded). It
    counts as correct when at least one of its ``k`` closest neighbours has the
    same label.

    Args:
        embeddings: 2-D array-like of shape ``(n_samples, n_features)``.
        labels: 1-D array-like of length ``n_samples``.
        k: Number of neighbours to inspect; values above ``n_samples - 1`` use all
            other samples, and values below 1 yield 0.0.
        metric: Any distance name accepted by ``scipy.spatial.distance.cdist``.

    Returns:
        The hit rate as a float in ``[0, 1]``.

    Raises:
        ValueError: if there are no samples, ``embeddings`` is not 2-D, or its row
            count differs from the number of labels.
    """
    from scipy.spatial.distance import cdist

    embeddings = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    n = len(labels)
    if n == 0 or embeddings.ndim != 2 or embeddings.shape[0] != n:
        raise ValueError(
            'embeddings must be 2-D with one row per label and at least one sample'
        )

    k_eff = max(0, min(k, n - 1))
    if k_eff == 0:
        return 0.0

    distances = cdist(embeddings, embeddings, metric=metric)
    # Exclude every sample from its own neighbour list explicitly. Assuming "self
    # sorts to position 0" breaks when another sample ties at distance 0.
    np.fill_diagonal(distances, np.inf)
    # Only the k closest matter, so a partial partition suffices (no full sort).
    neighbours = np.argpartition(distances, k_eff - 1, axis=1)[:, :k_eff]

    hits = (labels[neighbours] == labels[:, None]).any(axis=1)
    return float(hits.mean())

# Top-1 leave-one-out nearest-neighbour accuracy of the embeddings before any fine-tuning
accuracy = nearest_neighbor_accuracy(embeddings=embeddings, labels=labels, k=1)
print(f'Zero-shot Nearest Neighbor Accuracy: {accuracy:.4f}')


- Few-shot transfer learning: fine-tune the model on a small labeled dataset and evaluate its performance (the code below uses the full training loader unless a smaller few-shot loader is substituted).

In [None]:
# Few-shot Transfer Learning
# NOTE(review): this trains on the full `train_loader`; substitute a small
# `few_shot_train_loader` to make it a genuine few-shot run.
few_shot_epochs = 5
device = next(model.parameters()).device  # run wherever the model lives (GPU or CPU)

# CrossEntropyLoss needs one logit per class, but the localizability head emits a
# fixed-size embedding (128-d in AdamV2): any class index >= that size would be out
# of bounds. Train a linear classifier over the dataset's classes on top of it.
num_classes = len(train_dataset.classes)
few_shot_classifier = nn.Linear(model.localizability_head.out_features, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
few_shot_optimizer = torch.optim.Adam(
    list(model.parameters()) + list(few_shot_classifier.parameters()), lr=0.0001
)

for epoch in range(few_shot_epochs):
    model.train()
    few_shot_classifier.train()
    running_loss, seen = 0.0, 0
    # `targets` (not `labels`) so the zero-shot `labels` array is not overwritten
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        few_shot_optimizer.zero_grad()

        # Forward pass: classify the localizability embedding
        localizability, composability, decomposability = model(images)
        classification_loss = criterion(few_shot_classifier(localizability), targets)

        # Backward pass and optimization
        classification_loss.backward()
        few_shot_optimizer.step()

        # Accumulate on-device to avoid a host sync every batch
        running_loss += classification_loss.detach() * images.size(0)
        seen += images.size(0)

    # Report the mean loss over the epoch rather than only the last batch's loss
    epoch_loss = float(running_loss) / max(seen, 1)
    print(f'Few-shot Epoch [{epoch + 1}/{few_shot_epochs}], Loss: {epoch_loss:.4f}')

# Evaluate on the test set: nearest-neighbour accuracy of the fine-tuned embeddings
model.eval()
few_shot_embeddings = []
few_shot_labels = []

with torch.no_grad():
    for images, label in test_loader:
        localizability, composability, decomposability = model(images.to(device))
        few_shot_embeddings.append(localizability.cpu().numpy())
        few_shot_labels.append(label.numpy())

few_shot_embeddings = np.concatenate(few_shot_embeddings, axis=0)
few_shot_labels = np.concatenate(few_shot_labels, axis=0)

few_shot_accuracy = nearest_neighbor_accuracy(few_shot_embeddings, few_shot_labels, k=1)
print(f'Few-shot Nearest Neighbor Accuracy: {few_shot_accuracy:.4f}')

- Full fine-tuning: train the model on the entire labeled dataset.

In [None]:
# Full Fine-tuning
full_tuning_epochs = 10
device = next(model.parameters()).device  # run wherever the model lives (GPU or CPU)

# CrossEntropyLoss needs one logit per class, but the localizability head emits a
# fixed-size embedding (128-d in AdamV2): any class index >= that size would be out
# of bounds. Train a linear classifier over the dataset's classes on top of it.
num_classes = len(train_dataset.classes)
full_tuning_classifier = nn.Linear(model.localizability_head.out_features, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
full_tuning_optimizer = torch.optim.Adam(
    list(model.parameters()) + list(full_tuning_classifier.parameters()), lr=0.0001
)

for epoch in range(full_tuning_epochs):
    model.train()
    full_tuning_classifier.train()
    running_loss, seen = 0.0, 0
    # `targets` (not `labels`) so the zero-shot `labels` array is not overwritten
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        full_tuning_optimizer.zero_grad()

        # Forward pass: classify the localizability embedding
        localizability, composability, decomposability = model(images)
        classification_loss = criterion(full_tuning_classifier(localizability), targets)

        # Backward pass and optimization
        classification_loss.backward()
        full_tuning_optimizer.step()

        # Accumulate on-device to avoid a host sync every batch
        running_loss += classification_loss.detach() * images.size(0)
        seen += images.size(0)

    # Report the mean loss over the epoch rather than only the last batch's loss
    epoch_loss = float(running_loss) / max(seen, 1)
    print(f'Full Fine-tuning Epoch [{epoch + 1}/{full_tuning_epochs}], Loss: {epoch_loss:.4f}')

# Evaluate on the test set: nearest-neighbour accuracy of the fine-tuned embeddings
model.eval()
full_tuning_embeddings = []
full_tuning_labels = []

with torch.no_grad():
    for images, label in test_loader:
        localizability, composability, decomposability = model(images.to(device))
        full_tuning_embeddings.append(localizability.cpu().numpy())
        full_tuning_labels.append(label.numpy())

full_tuning_embeddings = np.concatenate(full_tuning_embeddings, axis=0)
full_tuning_labels = np.concatenate(full_tuning_labels, axis=0)

full_tuning_accuracy = nearest_neighbor_accuracy(full_tuning_embeddings, full_tuning_labels, k=1)
print(f'Full Fine-tuning Nearest Neighbor Accuracy: {full_tuning_accuracy:.4f}')

- Analysis and Validation 

Perform analyses to assess the model's understanding of anatomical structures, using t-SNE to visualize the embeddings.

In [None]:
# Feature Analysis
def plot_embeddings(embeddings, labels, perplexity=30, n_iter=300):
    """Project embeddings to 2-D with t-SNE and scatter-plot them coloured by label.

    Args:
        embeddings: 2-D array of shape ``(n_samples, n_features)``.
        labels: 1-D array of length ``n_samples``; one scatter series (and legend
            entry) is drawn per distinct label.
        perplexity: t-SNE perplexity; lowered to ``n_samples - 1`` for small inputs,
            since t-SNE requires it to be smaller than the number of samples.
        n_iter: Number of t-SNE optimisation iterations.
    """
    import inspect

    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)

    # scikit-learn deprecated TSNE's `n_iter` in favour of `max_iter` (1.5) and then
    # dropped it; pass the iteration count under whichever name this version accepts.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    tsne = TSNE(
        n_components=2,
        perplexity=min(perplexity, len(embeddings) - 1),
        **{iter_kw: n_iter},
    )
    tsne_results = tsne.fit_transform(embeddings)

    plt.figure(figsize=(10, 10))
    for label in np.unique(labels):
        mask = labels == label  # boolean mask keeps the selected points 1-D
        plt.scatter(tsne_results[mask, 0], tsne_results[mask, 1], label=label)

    plt.legend()
    plt.show()


# Visualise the test-set embeddings produced after full fine-tuning
plot_embeddings(embeddings=full_tuning_embeddings, labels=full_tuning_labels)

In [None]:
# Validation
def evaluate_classification_performance(labels_true, labels_pred):
    """Print and return accuracy plus support-weighted F1, precision and recall.

    Args:
        labels_true: Ground-truth class labels.
        labels_pred: Predicted class labels, same length as ``labels_true``.

    Returns:
        dict with keys ``'accuracy'``, ``'f1'``, ``'precision'`` and ``'recall'``.
    """
    # zero_division=0 yields the same 0 that scikit-learn substitutes for classes
    # that are never predicted, without emitting UndefinedMetricWarning.
    metrics = {
        'accuracy': accuracy_score(labels_true, labels_pred),
        'f1': f1_score(labels_true, labels_pred, average='weighted', zero_division=0),
        'precision': precision_score(labels_true, labels_pred, average='weighted', zero_division=0),
        'recall': recall_score(labels_true, labels_pred, average='weighted', zero_division=0),
    }

    print(f'Accuracy: {metrics["accuracy"]:.4f}')
    print(f'F1 Score: {metrics["f1"]:.4f}')
    print(f'Precision: {metrics["precision"]:.4f}')
    print(f'Recall: {metrics["recall"]:.4f}')
    return metrics


# Usage, once labels_pred are obtained from a classifier on top of the embeddings:
# labels_pred = ...  # Obtain these from your classifier
# evaluate_classification_performance(full_tuning_labels, labels_pred)

References:

SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton
Paper
BYOL: Bootstrap Your Own Latent - A New Approach to Self-Supervised Learning

Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko
Paper
MoCo: Momentum Contrast for Unsupervised Visual Representation Learning

Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, Ross Girshick
Paper
Hierarchical Representation Learning
GLOM: How to represent part-whole hierarchies in a neural network

Geoffrey Hinton
Paper
Learning Disentangled Representations with Semi-Supervised Deep Generative Models

N. Siddharth, Brooks Paige, Jan-Willem van de Meent, Alban Desmaison, Noah D. Goodman, Pushmeet Kohli, Frank Wood, Philip H. S. Torr
Paper
Medical Image Analysis
CheXNet: Radiologist-Level Pneumonia Detection on Chest X-Rays with Deep Learning

Pranav Rajpurkar, Jeremy Irvin, Kaylie Zhu, Brandon Yang, Hershel Mehta, Tony Duan, Daisy Ding, Aarti Bagul, Curtis Langlotz, Katie Shpanskaya, Matthew P. Lungren, Andrew Y. Ng
Paper
U-Net: Convolutional Networks for Biomedical Image Segmentation

Olaf Ronneberger, Philipp Fischer, Thomas Brox
Paper
General Deep Learning
Deep Learning
Ian Goodfellow, Yoshua Bengio, Aaron Courville
Book
Further Reading and Tools
PyTorch Documentation

Official PyTorch documentation and tutorials
Website
Keras Documentation

Official Keras documentation and tutorials
Website
Self-Supervised Learning: The Dark Matter of Intelligence

Yann LeCun, Ishan Misra
Blog post (Facebook AI, 2021)