In [None]:
# ViT-MultiRAGNet

| Paper Block                        | Implemented As                              |
| ---------------------------------- | ------------------------------------------- |
| Patch Partition + Linear Embedding | ViT patch embedding (timm)                  |
| Transformer Encoder Stack          | ViT encoder (L layers)                      |
| CLS Token                          | `feats[:,0]`                                |
| Clinical Data (8-D vector)         | Metadata → single linear layer (8 → 768)    |
| Clinical Embedding & Query Gen     | Image CLS + Clinical embedding              |
| Knowledge Base                     | FAISS feature memory                        |
| Top-K Retrieval                    | FAISS `search()`                            |
| Cross-Attention RAG Fusion         | `nn.MultiheadAttention`                     |
| Final Fusion                       | Residual fusion                             |
| Classification Head                | MLP                                         |
| Segmentation Decoder               | Conv stack + bilinear upsampling (no skips) |

In [1]:
!pip install -q timm faiss-cpu albumentations opencv-python scikit-learn

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gensim 4.3.0 requires FuzzyTM>=0.4.0, which is not installed.
tables 3.8.0 requires blosc2~=2.0.0, which is not installed.
tables 3.8.0 requires cython>=0.29.21, which is not installed.
astropy 5.3.4 requires numpy<2,>=1.21, but you have numpy 2.4.1 which is incompatible.
langchain 0.1.16 requires langchain-core<0.2.0,>=0.1.42, but you have langchain-core 0.3.79 which is incompatible.
langchain 0.1.16 requires langsmith<0.2.0,>=0.1.17, but you have langsmith 0.4.42 which is incompatible.
langchain 0.1.16 requires numpy<2,>=1, but you have numpy 2.4.1 which is incompatible.
langchain 0.1.16 requires tenacity<9.0.0,>=8.1.0, but you have tenacity 9.0.0 which is incompatible.
langchain-community 0.0.38 requires langchain-core<0.2.0,>=0.1.52, but you have langchain-core 0.3.79 which is incompatible.
langchain-community

In [2]:
# STAGE 1 — IMPORTS & CONFIG
import os, cv2, torch, faiss, timm
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

# Run configuration: every tunable value used by the later stages lives here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_SIZE = 224      # side length images are resized to before entering the ViT
BATCH_SIZE = 4
EPOCHS = 10
LR = 1e-4
BASE = "k_CBIS-DDSM"  # Update this path to your dataset location
SEED = 42             # same value the train/validation split uses as random_state

# Seed NumPy and PyTorch so weight initialisation, dropout and DataLoader
# shuffling are repeatable when the notebook is restarted and run top to bottom.
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")

ImportError: DLL load failed while importing _swigfaiss: The specified module could not be found.

In [None]:
# STAGE 2 — LOAD CSV (CBIS-DDSM Dataset)
# Read the calcification and mass case tables from the dataset folder.
calc = pd.read_csv(f"{BASE}/calc_case(with_jpg_img).csv")
mass = pd.read_csv(f"{BASE}/mass_case(with_jpg_img).csv")

# Stack both tables under a fresh 0..n-1 index, then keep only the rows that
# actually reference a full-mammogram JPEG.
combined = pd.concat([calc, mass], ignore_index=True)
has_full_image = combined["jpg_fullMammo_img_path"].notna()
df = combined[has_full_image]

pathology_counts = df["pathology"].value_counts()
print(f"Total samples: {len(df)}")
print(f"Pathology distribution:\n{pathology_counts}")

In [None]:
# STAGE 3 — DATASET (IMAGE + CLINICAL VECTOR)
class CBISDataset(Dataset):
    """CBIS-DDSM samples as ``(image, mask, clinical, label)`` tuples.

    Args:
        df: Case table. Rows must provide ``jpg_fullMammo_img_path``,
            ``jpg_ROI_img_path``, ``assessment``, ``subtlety``, ``image view``,
            ``left or right breast``, ``abnormality type`` and ``pathology``.
            Breast density is read from ``breast_density`` or, failing that,
            ``breast density``.
        augment: When True, a random horizontal flip (p=0.5) is applied to the
            image and mask together.

    Each item is:
        * image    - normalised float tensor resized to IMAGE_SIZE x IMAGE_SIZE
        * mask     - float tensor with a leading channel axis, scaled by 1/255
        * clinical - float tensor of 8 values (the last two are placeholders)
        * label    - 1 if the pathology string contains "MALIGNANT", else 0
    """

    # Column spellings tried, in order, for the breast-density score.
    # NOTE(review): the space-separated spelling is assumed to occur in some
    # CBIS-DDSM CSV exports - confirm against the actual files.
    _DENSITY_COLUMNS = ("breast_density", "breast density")

    def __init__(self, df, augment=False):
        self.df = df.reset_index(drop=True)
        self.tf = A.Compose([
            A.Resize(IMAGE_SIZE, IMAGE_SIZE),
            A.HorizontalFlip(p=0.5 if augment else 0),
            A.Normalize(),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _scaled(row, columns, scale):
        """Return the first usable numeric value among ``columns``, divided by ``scale``.

        Absent columns, NaN and non-numeric entries are skipped; if nothing is
        usable the result is 0.0, so one incomplete row cannot put NaN into the
        clinical vector (and from there into the loss).
        """
        for name in columns:
            if name not in row:
                continue
            value = pd.to_numeric(row[name], errors="coerce")
            if pd.notna(value):
                return float(value) / scale
        return 0.0

    def __getitem__(self, idx):
        r = self.df.iloc[idx]

        image_path = r["jpg_fullMammo_img_path"]
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None; fail with the path
            # instead of an opaque error from cvtColor below.
            raise FileNotFoundError(f"Could not read mammogram image: {image_path!r}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Load ROI mask if available
        mask = None
        roi_path = r["jpg_ROI_img_path"]
        if isinstance(roi_path, str) and os.path.exists(roi_path):
            mask = cv2.imread(roi_path, 0)
        if mask is None:
            # No ROI file, or a file that could not be decoded: empty mask.
            mask = np.zeros(img.shape[:2], np.uint8)
        elif mask.shape != img.shape[:2]:
            # The transform receives image and mask together, so bring the
            # mask onto the image grid first (nearest keeps it 0/255-like).
            mask = cv2.resize(
                mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST
            )

        t = self.tf(image=img, mask=mask)
        img, mask = t["image"], t["mask"].unsqueeze(0).float() / 255.0

        # 8-D Clinical vector (normalized)
        clinical = torch.tensor([
            self._scaled(r, ("assessment",), 5),
            self._scaled(r, ("subtlety",), 5),
            self._scaled(r, self._DENSITY_COLUMNS, 4),
            1 if r["image view"] == "MLO" else 0,
            1 if r["left or right breast"] == "LEFT" else 0,
            1 if r["abnormality type"] == "mass" else 0,
            0,  # Placeholder for additional clinical features
            0   # Placeholder for additional clinical features
        ], dtype=torch.float)

        label = torch.tensor(1 if "MALIGNANT" in r["pathology"].upper() else 0)

        return img, mask, clinical, label

print("Dataset class defined ✓")

In [None]:
# STAGE 4 — ViT ENCODER (PATCH + TRANSFORMER STACK)
class ViTEncoder(nn.Module):
    """
    Paper Block: Patch Partition + Linear Embedding + Transformer Encoder Stack

    Wraps timm's pretrained ``vit_base_patch16_224`` backbone, created with
    ``num_classes=0`` so that no classification head is attached.
    """
    def __init__(self):
        super().__init__()
        backbone = timm.create_model(
            "vit_base_patch16_224", pretrained=True, num_classes=0
        )
        self.vit = backbone

    def forward(self, x):
        # Token sequence from the backbone; downstream code treats index 0 as
        # the CLS token and the remaining tokens as the patch grid.
        # Expected shape: (B, 197, 768) = 1 CLS + 196 patch tokens.
        tokens = self.vit.forward_features(x)
        return tokens

print("ViT Encoder defined ✓")

In [None]:
# STAGE 5 — RAG MEMORY (KNOWLEDGE BASE)
class RAGMemory:
    """
    Paper Block: Knowledge Base + Top-K Retrieval

    A FAISS ``IndexFlatL2`` paired with a Python list (``store``) holding the
    same vectors in insertion order, so that the ids returned by a search can
    be mapped back to feature vectors.

    Args:
        dim: Length of every stored and queried vector.
    """
    def __init__(self, dim=768):
        self.dim = dim
        self.index = faiss.IndexFlatL2(dim)
        self.store = []

    def _as_matrix(self, vectors):
        """Convert a tensor/array to a C-contiguous float32 copy of shape (n, dim).

        Tensors are detached first so inputs that still carry gradients can be
        converted; a single 1-D vector is promoted to one row.

        Raises:
            ValueError: if the result is not 2-D with ``dim`` columns.
        """
        if isinstance(vectors, torch.Tensor):
            vectors = vectors.detach().cpu().numpy()
        # np.array copies, so rows kept in `store` never alias caller memory.
        matrix = np.atleast_2d(np.array(vectors, dtype=np.float32))
        if matrix.ndim != 2 or matrix.shape[1] != self.dim:
            raise ValueError(
                f"expected vectors of shape (n, {self.dim}), got {matrix.shape}"
            )
        return np.ascontiguousarray(matrix)

    def add(self, feats):
        """Add feature vectors (shape (n, dim) or (dim,)) to the knowledge base."""
        feats = self._as_matrix(feats)
        self.index.add(feats)
        self.store.extend(feats)

    def query(self, q, k=5):
        """Retrieve the top-k stored vectors for each query vector.

        Args:
            q: Queries of shape (B, dim) or a single vector of shape (dim,).
            k: Number of neighbours per query.

        Returns:
            float32 tensor of shape (B, k, dim). A slot whose search id does
            not map to a stored vector (e.g. fewer than k vectors indexed) is
            left as zeros.
        """
        q = self._as_matrix(q)
        _, idx = self.index.search(q, k)

        n_stored = len(self.store)
        out = np.zeros((len(q), k, self.dim), dtype=np.float32)
        for b, row in enumerate(idx):
            for j, i in enumerate(row):
                if 0 <= i < n_stored:
                    out[b, j] = self.store[i]

        return torch.from_numpy(out)

    def size(self):
        """Number of vectors currently held by the FAISS index."""
        return self.index.ntotal

print("RAG Memory defined ✓")

In [None]:
# STAGE 6 — CROSS-ATTENTION RAG FUSION (CORE OF PAPER)
class CrossAttention(nn.Module):
    """
    Paper Block: Cross-Attention RAG Fusion

    A single query vector per sample attends over that sample's retrieved
    vectors, which serve as both keys and values.
    """
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, batch_first=True
        )

    def forward(self, q, kv):
        # q:  (B, dim)    -> promoted to a length-1 sequence (B, 1, dim)
        # kv: (B, K, dim) -> K retrieved vectors per sample
        query_seq = torch.unsqueeze(q, 1)
        attended, _weights = self.attn(query=query_seq, key=kv, value=kv)
        # Drop the length-1 sequence axis again: (B, dim)
        return torch.squeeze(attended, dim=1)

print("Cross-Attention module defined ✓")

In [None]:
# STAGE 7 — ViT-MultiRAGNet (FINAL MODEL)
class ViTMultiRAGNet(nn.Module):
    """
    ViT-MultiRAGNet: one ViT encoder feeding two output heads.

    - ``encoder``    : ViTEncoder producing the token sequence
    - ``clin_fc``    : linear projection of the 8-value clinical vector to 768
    - ``cross_attn`` : query (CLS + clinical) attends over retrieved features
    - ``cls_head``   : two-layer MLP on the fused vector -> class logits
    - ``seg_head``   : three convolutions on the patch-token grid, whose
                       output is bilinearly resized to IMAGE_SIZE

    forward(x, clinical, rag_feats) returns ``(seg_out, cls_out)``.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        width = 768  # token width shared by every sub-module below

        self.encoder = ViTEncoder()
        self.clin_fc = nn.Linear(8, width)  # Clinical data embedding
        self.cross_attn = CrossAttention(width)

        # Classification head (MLP)
        cls_layers = [
            nn.Linear(width, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.cls_head = nn.Sequential(*cls_layers)

        # Segmentation decoder: conv-BN-ReLU twice, then a 1x1 conv to 1 channel
        seg_layers = [
            nn.Conv2d(width, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
        ]
        self.seg_head = nn.Sequential(*seg_layers)

    def forward(self, x, clinical, rag_feats):
        tokens = self.encoder(x)  # (B, 197, 768) per the encoder's note

        # Query = CLS token + projected clinical vector, both (B, 768)
        query = tokens[:, 0] + self.clin_fc(clinical)

        # Attend over the retrieved features, then add the query back (residual)
        fused = query + self.cross_attn(query, rag_feats)
        cls_out = self.cls_head(fused)

        # Fold the patch tokens (everything after CLS) back into a square grid
        batch, n_tokens, channels = tokens.shape
        side = int(np.sqrt(n_tokens - 1))  # 14 when there are 196 patch tokens
        patch_grid = tokens[:, 1:].transpose(1, 2).reshape(batch, channels, side, side)

        seg_out = F.interpolate(
            self.seg_head(patch_grid),
            size=(IMAGE_SIZE, IMAGE_SIZE),
            mode='bilinear',
            align_corners=False,
        )

        return seg_out, cls_out

print("ViT-MultiRAGNet model defined ✓")

In [None]:
# STAGE 8 — DATA PREPARATION & TRAINING SETUP
# 80/20 split, stratified on the raw pathology string, with a fixed seed.
# NOTE(review): rows are split individually; if one patient or image appears
# in several rows it can end up on both sides - confirm whether a
# patient-level split is required.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["pathology"],
    random_state=42,
)

# Training batches are shuffled and randomly flipped; validation batches are not.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0)
train_loader = DataLoader(CBISDataset(train_df, augment=True), shuffle=True, **loader_kwargs)
val_loader = DataLoader(CBISDataset(val_df, augment=False), shuffle=False, **loader_kwargs)

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")

In [None]:
# STAGE 8B — MODEL & OPTIMIZER INITIALIZATION
# Network on the target device, plus an empty retrieval memory.
model = ViTMultiRAGNet()
model = model.to(DEVICE)
rag = RAGMemory()

# One loss per head: per-pixel BCE on the segmentation logits and
# cross-entropy on the class logits.
seg_criterion = nn.BCEWithLogitsLoss()
cls_criterion = nn.CrossEntropyLoss()

# A single Adam optimizer updates every parameter of the model.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model initialized on {DEVICE}")
print(f"Total parameters: {n_params:,}")

In [None]:
# STAGE 8C — BUILD KNOWLEDGE BASE (POPULATE RAG MEMORY)
print("Building knowledge base from training data...")

# Start from an empty memory so that re-running this cell rebuilds the
# knowledge base instead of appending a second copy of every vector.
rag = RAGMemory()

# Index the training images without augmentation and in a fixed order, so the
# stored features do not depend on random flips or on the shuffle order.
kb_loader = DataLoader(
    CBISDataset(train_df, augment=False),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

model.eval()
with torch.no_grad():
    for img, _, _, _ in tqdm(kb_loader, desc="Building RAG Memory"):
        feats = model.encoder(img.to(DEVICE))[:, 0]  # CLS token features
        rag.add(feats.cpu().numpy())

print(f"Knowledge base populated with {rag.size()} feature vectors ✓")

In [None]:
# STAGE 9 — TRAINING LOOP
print("Starting training...")

for epoch in range(EPOCHS):
    model.train()
    total_loss, correct, total = 0, 0, 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for img, mask, clin, y in pbar:
        img, mask, clin, y = (
            img.to(DEVICE),
            mask.to(DEVICE),
            clin.to(DEVICE),
            y.to(DEVICE),
        )

        # Retrieval query: CLS features of this batch, computed without
        # gradients. The encoder runs again inside model(...) below, where
        # gradients are tracked.
        with torch.no_grad():
            cls_feats = model.encoder(img)[:, 0]
        q = cls_feats.detach().cpu().numpy()
        rag_feats = rag.query(q).to(DEVICE)

        # Forward pass through both heads
        seg_out, cls_out = model(img, clin, rag_feats)

        # Joint objective: unweighted sum of segmentation and classification loss
        seg_loss = seg_criterion(seg_out, mask)
        cls_loss = cls_criterion(cls_out, y)
        loss = seg_loss + cls_loss

        # Parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running totals for the progress bar and the epoch summary
        batch_loss = loss.item()
        total_loss += batch_loss
        pred = cls_out.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)

        # Passed as a dict (not keyword arguments) to keep the loss/acc order.
        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{100*correct/total:.1f}%'
        })

    epoch_loss = total_loss / len(train_loader)
    epoch_acc = 100 * correct / total
    print(f"Epoch {epoch+1}: Loss = {epoch_loss:.4f}, Train Acc = {epoch_acc:.2f}%")

In [None]:
# STAGE 10 — VALIDATION & EVALUATION
print("Running validation...")
model.eval()
preds, gts = [], []

# Class ids produced by the dataset labels (0 = benign, 1 = malignant).
# Passing them explicitly keeps the report and the confusion matrix at two
# classes even when one class never appears among the predictions or targets.
class_ids = [0, 1]

with torch.no_grad():
    for img, _, clin, y in tqdm(val_loader, desc="Validating"):
        img = img.to(DEVICE)
        clin = clin.to(DEVICE)

        # Same retrieval step as in training: query the memory with CLS features
        q = model.encoder(img)[:, 0].cpu().numpy()
        rag_feats = rag.query(q).to(DEVICE)

        _, cls_out = model(img, clin, rag_feats)
        preds.extend(cls_out.argmax(1).cpu().numpy())
        gts.extend(y.numpy())

print(f"\n{'='*50}")
print(f"Validation Accuracy: {accuracy_score(gts, preds)*100:.2f}%")
print(f"{'='*50}")
print("\nClassification Report:")
print(classification_report(
    gts,
    preds,
    labels=class_ids,
    target_names=['BENIGN', 'MALIGNANT'],
    zero_division=0,  # report 0 instead of warning when a class is never predicted
))
print("\nConfusion Matrix:")
print(confusion_matrix(gts, preds, labels=class_ids))