In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from timm import create_model
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sentence_transformers import SentenceTransformer

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset and transforms
data_dir = "/kaggle/input/breakhist-binary-classificationfew-shot/BreakHist/Final 200X"

transform_vit = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

transform_clip = transforms.Compose([
    transforms.Resize((224, 224)),
])

# Create datasets
train_dataset_vit = ImageFolder(root=f"{data_dir}/Train", transform=transform_vit)
test_dataset_vit = ImageFolder(root=f"{data_dir}/Test", transform=transform_vit)
train_dataset_clip = ImageFolder(root=f"{data_dir}/Train", transform=transform_clip)
test_dataset_clip = ImageFolder(root=f"{data_dir}/Test", transform=transform_clip)

# Load models
vit_model = create_model('vit_small_patch16_224', pretrained=True)
vit_model.head = nn.Linear(vit_model.num_features, len(train_dataset_vit.classes))
vit_model.to(device)

clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
text_encoder = SentenceTransformer('all-MiniLM-L6-v2').to(device)

# Hyperparameters
num_epochs = 50
learning_rate = 0.001
batch_size = 32
shots_per_class = 5

# Feature fusion helper functions
def calculate_feature_relevance(features, labels):
    """
    Calculate feature relevance using Fisher's criterion for each feature dimension
    """
    unique_labels = np.unique(labels)
    n_features = features.shape[1]
    relevance_scores = np.zeros(n_features)
    
    for i in range(n_features):
        feature_i = features[:, i]
        
        # Calculate mean and variance for each class
        class_means = np.array([np.mean(feature_i[labels == label]) for label in unique_labels])
        class_vars = np.array([np.var(feature_i[labels == label]) for label in unique_labels])
        
        # Fisher's criterion: inter-class variance / intra-class variance
        between_class_var = np.var(class_means)
        within_class_var = np.mean(class_vars)
        
        # Avoid division by zero
        relevance_scores[i] = between_class_var / (within_class_var + 1e-10)
    
    # Normalize scores
    return relevance_scores / (np.sum(relevance_scores) + 1e-10)

def calculate_feature_similarity(vit_features, clip_features):
    """
    Calculate similarity between ViT and CLIP features using canonical correlation analysis
    """
    # Normalize features
    vit_norm = vit_features - np.mean(vit_features, axis=0)
    clip_norm = clip_features - np.mean(clip_features, axis=0)
    
    # Calculate correlation matrix
    corr_matrix = np.dot(vit_norm.T, clip_norm) / (vit_features.shape[0] - 1)
    
    # Calculate similarity scores
    similarity_scores = np.abs(np.diag(corr_matrix))
    return similarity_scores / (np.sum(similarity_scores) + 1e-10)

def weighted_feature_fusion(vit_features, clip_features, labels=None):
    """
    Perform weighted feature fusion using relevance and similarity metrics
    """
    # Calculate feature relevance if labels are provided (for training set)
    if labels is not None:
        vit_relevance = calculate_feature_relevance(vit_features, labels)
        clip_relevance = calculate_feature_relevance(clip_features, labels)
    else:
        # For test set, use uniform relevance
        vit_relevance = np.ones(vit_features.shape[1]) / vit_features.shape[1]
        clip_relevance = np.ones(clip_features.shape[1]) / clip_features.shape[1]
    
    # Calculate feature similarity
    similarity_scores = calculate_feature_similarity(vit_features, clip_features)
    
    # Combine metrics to create final weights
    vit_weights = vit_relevance * np.mean(similarity_scores)
    clip_weights = clip_relevance * np.mean(similarity_scores)
    
    # Normalize weights
    vit_weights = vit_weights / (np.sum(vit_weights) + 1e-10)
    clip_weights = clip_weights / (np.sum(clip_weights) + 1e-10)
    
    # Apply weighted fusion
    fused_features = np.concatenate([
        vit_features * vit_weights.reshape(1, -1),
        clip_features * clip_weights.reshape(1, -1)
    ], axis=1)
    
    return fused_features

def get_support_query_indices(dataset, shots_per_class):
    class_indices = {label: [] for label in range(len(dataset.classes))}
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)
    
    support_indices, query_indices = [], []
    for label, indices in class_indices.items():
        np.random.shuffle(indices)
        support_indices.extend(indices[:shots_per_class])
        query_indices.extend(indices[shots_per_class:])
    
    return support_indices, query_indices

# Get support and query indices
support_indices, query_indices = get_support_query_indices(train_dataset_vit, shots_per_class)

# Create subsets
support_set_vit = Subset(train_dataset_vit, support_indices)
query_set_vit = Subset(train_dataset_vit, query_indices)
support_set_clip = Subset(train_dataset_clip, support_indices)
query_set_clip = Subset(train_dataset_clip, query_indices)

# Extract labels for support set
support_labels = np.array([label for _, label in support_set_vit])

def generate_texts(dataset):
    texts = []
    for img, _ in tqdm(dataset, desc="Generating texts"):
        if isinstance(img, torch.Tensor):
            img = transforms.ToPILImage()(img)
        
        inputs = clip_processor(
            images=img,
            text="Describe in as much detail as possible this histopathology image focusing on features distinguishing malignant and benign tumors.",
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(device)
        
        with torch.no_grad():
            image_features = clip_model.get_image_features(pixel_values=inputs.pixel_values)
            generated_text = "Generated text description"  # Placeholder
        
        texts.append(generated_text)
    return texts

# Generate and encode texts
print("Generating support set texts...")
support_texts = generate_texts(support_set_clip)
print("Generating query set texts...")
query_texts = generate_texts(query_set_clip)
print("Generating test set texts...")
test_texts = generate_texts(test_dataset_clip)

def encode_texts(texts):
    return text_encoder.encode(texts, convert_to_tensor=True, device=device).cpu().numpy()

print("Encoding texts...")
support_text_features = encode_texts(support_texts)
query_text_features = encode_texts(query_texts)
test_text_features = encode_texts(test_texts)

# Training loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

def train_model_on_support():
    support_loader = DataLoader(support_set_vit, batch_size=batch_size, shuffle=True)
    for epoch in range(num_epochs):
        vit_model.train()
        epoch_loss = 0
        correct = 0
        total = 0
        
        for images, labels in tqdm(support_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            
            outputs = vit_model(images)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            epoch_loss += loss.item()
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Acc: {100*correct/total:.2f}%")

print("Training model on support set...")
train_model_on_support()

def extract_fused_features(loader, text_features, labels=None):
    vit_features, all_labels = [], []
    vit_model.eval()
    
    with torch.no_grad():
        for i, (images, targets) in enumerate(tqdm(loader)):
            images = images.to(device)
            features = vit_model(images).cpu().numpy()
            vit_features.append(features)
            all_labels.append(targets.numpy())
    
    vit_features = np.vstack(vit_features)
    all_labels = np.hstack(all_labels)
    
    # Apply weighted feature fusion
    fused_features = weighted_feature_fusion(vit_features, text_features, labels)
    return fused_features, all_labels

# Prepare loaders
support_loader = DataLoader(support_set_vit, batch_size=1, shuffle=False)
query_loader = DataLoader(query_set_vit, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset_vit, batch_size=1, shuffle=False)

# Extract fused features
print("Extracting support set features...")
support_fused, _ = extract_fused_features(support_loader, support_text_features, support_labels)
print("Extracting query set features...")
query_fused, query_labels = extract_fused_features(query_loader, query_text_features)
print("Extracting test set features...")
test_fused, test_labels = extract_fused_features(test_loader, test_text_features)

# Calculate prototypes using fused features
class_prototypes = []
for label in np.unique(support_labels):
    class_features = support_fused[support_labels == label]
    class_prototypes.append(class_features.mean(axis=0))
class_prototypes = np.array(class_prototypes)

# Evaluate using fused features
all_features = np.vstack([query_fused, test_fused])
all_labels = np.concatenate([query_labels, test_labels])

correct = 0
for feature, label in zip(all_features, all_labels):
    similarities = cosine_similarity([feature], class_prototypes)[0]
    predicted = np.argmax(similarities)
    correct += (predicted == label)

print(f"Final accuracy: {100*correct/len(all_labels):.2f}%")

Generating support set texts...


Generating texts: 100%|██████████| 10/10 [00:00<00:00, 77.20it/s]


Generating query set texts...


Generating texts: 100%|██████████| 1399/1399 [00:17<00:00, 78.42it/s]


Generating test set texts...


Generating texts: 100%|██████████| 604/604 [00:07<00:00, 81.08it/s]

Encoding texts...





Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/44 [00:00<?, ?it/s]

Batches:   0%|          | 0/19 [00:00<?, ?it/s]

Training model on support set...


Epoch 1/50: 100%|██████████| 1/1 [00:00<00:00,  8.80it/s]


Epoch [1/50], Loss: 0.5896, Acc: 80.00%


Epoch 2/50: 100%|██████████| 1/1 [00:00<00:00,  9.15it/s]


Epoch [2/50], Loss: 5.8440, Acc: 50.00%


Epoch 3/50: 100%|██████████| 1/1 [00:00<00:00,  8.78it/s]


Epoch [3/50], Loss: 7.2101, Acc: 50.00%


Epoch 4/50: 100%|██████████| 1/1 [00:00<00:00,  9.40it/s]


Epoch [4/50], Loss: 1.7024, Acc: 50.00%


Epoch 5/50: 100%|██████████| 1/1 [00:00<00:00,  9.59it/s]


Epoch [5/50], Loss: 3.6750, Acc: 50.00%


Epoch 6/50: 100%|██████████| 1/1 [00:00<00:00,  9.63it/s]


Epoch [6/50], Loss: 0.7946, Acc: 50.00%


Epoch 7/50: 100%|██████████| 1/1 [00:00<00:00,  9.67it/s]


Epoch [7/50], Loss: 0.6550, Acc: 50.00%


Epoch 8/50: 100%|██████████| 1/1 [00:00<00:00,  9.49it/s]


Epoch [8/50], Loss: 0.7473, Acc: 50.00%


Epoch 9/50: 100%|██████████| 1/1 [00:00<00:00,  9.34it/s]


Epoch [9/50], Loss: 0.7753, Acc: 50.00%


Epoch 10/50: 100%|██████████| 1/1 [00:00<00:00,  9.60it/s]


Epoch [10/50], Loss: 0.6512, Acc: 90.00%


Epoch 11/50: 100%|██████████| 1/1 [00:00<00:00,  9.27it/s]


Epoch [11/50], Loss: 0.7200, Acc: 50.00%


Epoch 12/50: 100%|██████████| 1/1 [00:00<00:00,  9.15it/s]


Epoch [12/50], Loss: 0.6475, Acc: 70.00%


Epoch 13/50: 100%|██████████| 1/1 [00:00<00:00,  9.20it/s]


Epoch [13/50], Loss: 0.6913, Acc: 50.00%


Epoch 14/50: 100%|██████████| 1/1 [00:00<00:00,  9.49it/s]


Epoch [14/50], Loss: 0.6372, Acc: 100.00%


Epoch 15/50: 100%|██████████| 1/1 [00:00<00:00,  9.60it/s]


Epoch [15/50], Loss: 0.6666, Acc: 50.00%


Epoch 16/50: 100%|██████████| 1/1 [00:00<00:00,  9.31it/s]


Epoch [16/50], Loss: 0.6219, Acc: 80.00%


Epoch 17/50: 100%|██████████| 1/1 [00:00<00:00,  9.04it/s]


Epoch [17/50], Loss: 0.6375, Acc: 50.00%


Epoch 18/50: 100%|██████████| 1/1 [00:00<00:00,  7.90it/s]


Epoch [18/50], Loss: 0.5988, Acc: 80.00%


Epoch 19/50: 100%|██████████| 1/1 [00:00<00:00,  8.95it/s]


Epoch [19/50], Loss: 0.5990, Acc: 60.00%


Epoch 20/50: 100%|██████████| 1/1 [00:00<00:00,  9.12it/s]


Epoch [20/50], Loss: 0.5624, Acc: 80.00%


Epoch 21/50: 100%|██████████| 1/1 [00:00<00:00,  9.50it/s]


Epoch [21/50], Loss: 0.5470, Acc: 70.00%


Epoch 22/50: 100%|██████████| 1/1 [00:00<00:00,  8.99it/s]


Epoch [22/50], Loss: 0.4921, Acc: 100.00%


Epoch 23/50: 100%|██████████| 1/1 [00:00<00:00,  9.50it/s]


Epoch [23/50], Loss: 0.4655, Acc: 80.00%


Epoch 24/50: 100%|██████████| 1/1 [00:00<00:00,  9.20it/s]


Epoch [24/50], Loss: 0.3820, Acc: 100.00%


Epoch 25/50: 100%|██████████| 1/1 [00:00<00:00,  9.46it/s]


Epoch [25/50], Loss: 0.3054, Acc: 100.00%


Epoch 26/50: 100%|██████████| 1/1 [00:00<00:00,  9.52it/s]


Epoch [26/50], Loss: 0.2587, Acc: 90.00%


Epoch 27/50: 100%|██████████| 1/1 [00:00<00:00,  8.84it/s]


Epoch [27/50], Loss: 0.3258, Acc: 80.00%


Epoch 28/50: 100%|██████████| 1/1 [00:00<00:00,  9.23it/s]


Epoch [28/50], Loss: 0.8586, Acc: 70.00%


Epoch 29/50: 100%|██████████| 1/1 [00:00<00:00,  9.44it/s]


Epoch [29/50], Loss: 0.0548, Acc: 100.00%


Epoch 30/50: 100%|██████████| 1/1 [00:00<00:00,  9.48it/s]


Epoch [30/50], Loss: 0.8334, Acc: 70.00%


Epoch 31/50: 100%|██████████| 1/1 [00:00<00:00,  8.88it/s]


Epoch [31/50], Loss: 0.7610, Acc: 70.00%


Epoch 32/50: 100%|██████████| 1/1 [00:00<00:00,  9.54it/s]


Epoch [32/50], Loss: 0.7779, Acc: 70.00%


Epoch 33/50: 100%|██████████| 1/1 [00:00<00:00,  9.71it/s]


Epoch [33/50], Loss: 0.1342, Acc: 100.00%


Epoch 34/50: 100%|██████████| 1/1 [00:00<00:00,  9.31it/s]


Epoch [34/50], Loss: 0.4574, Acc: 70.00%


Epoch 35/50: 100%|██████████| 1/1 [00:00<00:00,  9.34it/s]


Epoch [35/50], Loss: 0.4803, Acc: 70.00%


Epoch 36/50: 100%|██████████| 1/1 [00:00<00:00,  9.61it/s]


Epoch [36/50], Loss: 0.1427, Acc: 100.00%


Epoch 37/50: 100%|██████████| 1/1 [00:00<00:00,  9.48it/s]


Epoch [37/50], Loss: 0.1960, Acc: 90.00%


Epoch 38/50: 100%|██████████| 1/1 [00:00<00:00,  9.60it/s]


Epoch [38/50], Loss: 0.2681, Acc: 90.00%


Epoch 39/50: 100%|██████████| 1/1 [00:00<00:00,  9.26it/s]


Epoch [39/50], Loss: 0.1105, Acc: 100.00%


Epoch 40/50: 100%|██████████| 1/1 [00:00<00:00,  9.56it/s]


Epoch [40/50], Loss: 0.0750, Acc: 100.00%


Epoch 41/50: 100%|██████████| 1/1 [00:00<00:00,  9.56it/s]


Epoch [41/50], Loss: 0.1299, Acc: 90.00%


Epoch 42/50: 100%|██████████| 1/1 [00:00<00:00,  9.30it/s]


Epoch [42/50], Loss: 0.1147, Acc: 90.00%


Epoch 43/50: 100%|██████████| 1/1 [00:00<00:00,  9.68it/s]


Epoch [43/50], Loss: 0.0383, Acc: 100.00%


Epoch 44/50: 100%|██████████| 1/1 [00:00<00:00,  9.62it/s]


Epoch [44/50], Loss: 0.0197, Acc: 100.00%


Epoch 45/50: 100%|██████████| 1/1 [00:00<00:00,  9.54it/s]


Epoch [45/50], Loss: 0.0407, Acc: 100.00%


Epoch 46/50: 100%|██████████| 1/1 [00:00<00:00,  9.63it/s]


Epoch [46/50], Loss: 0.0465, Acc: 100.00%


Epoch 47/50: 100%|██████████| 1/1 [00:00<00:00,  9.26it/s]


Epoch [47/50], Loss: 0.0150, Acc: 100.00%


Epoch 48/50: 100%|██████████| 1/1 [00:00<00:00,  9.55it/s]


Epoch [48/50], Loss: 0.0040, Acc: 100.00%


Epoch 49/50: 100%|██████████| 1/1 [00:00<00:00,  9.58it/s]


Epoch [49/50], Loss: 0.0021, Acc: 100.00%


Epoch 50/50: 100%|██████████| 1/1 [00:00<00:00,  9.25it/s]


Epoch [50/50], Loss: 0.0038, Acc: 100.00%
Extracting support set features...


100%|██████████| 10/10 [00:00<00:00, 108.33it/s]


Extracting query set features...


100%|██████████| 1399/1399 [00:14<00:00, 96.76it/s] 


Extracting test set features...


100%|██████████| 604/604 [00:06<00:00, 98.70it/s] 


Final accuracy: 87.12%
