# Preliminary Research: ProtoPNet

- ProtoPNet augments a standard CNN with a small set of learned prototypes (representative patterns in feature space).
- For a new image, it finds which prototypes match parts of the image, and the final decision is based on how strongly those prototypes match.
- The model can show which training patch each prototype corresponds to and where it matched on the test image, which makes the prediction easier to explain.

### Overall High-Level Architecture

```text
Input image
  └→ Backbone CNN → Feature map (H × W × D)
       └→ (optional) 1×1 adapter → Feature map (H × W × D')
            └→ Prototype layer (M prototypes of size D'×p×p)
                 └→ Per-prototype similarity scores (M values)
                      └→ Linear classifier (weights W: num_classes × M)
                           └→ Softmax → class probabilities
```

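To sanity-check the shapes in the diagram above, here is a small helper that traces them stage by stage. It is a sketch under stated assumptions: a backbone with total stride 32 (which holds for EfficientNetV2-S's final feature map), channels-last layout, and hypothetical sizes D' = 128 and M = 50 (10 prototypes per class). `trace_shapes` is an illustrative name, not a library function.

```python
def trace_shapes(image_size, stride, backbone_channels, proto_channels,
                 num_prototypes, num_classes, proto_size=1):
    """Return the tensor shape at each stage of the architecture (channels last).

    Assumes a square input whose side is divisible by `stride`, and prototypes
    that slide over the feature map without padding.
    """
    h = w = image_size // stride  # spatial size after the backbone
    return {
        "input": (image_size, image_size, 3),
        "backbone_features": (h, w, backbone_channels),
        "adapter_features": (h, w, proto_channels),
        "prototypes": (num_prototypes, proto_size, proto_size, proto_channels),
        # one similarity value per valid prototype position
        "similarity_maps": (num_prototypes, h - proto_size + 1, w - proto_size + 1),
        "similarity_scores": (num_prototypes,),  # after global max pooling
        "classifier_weights": (num_classes, num_prototypes),
        "class_probabilities": (num_classes,),
    }

for size in (224, 256):
    shapes = trace_shapes(size, stride=32, backbone_channels=1280,
                          proto_channels=128, num_prototypes=50, num_classes=5)
    print(size, shapes["backbone_features"], shapes["similarity_scores"])
```

Under the stride-32 assumption, a 224×224 input gives the 7×7×1280 map used in the later diagram, while a 256×256 input gives an 8×8 map.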
### 1. Feature Extractor (Backbone CNN)
- This is a standard CNN backbone, such as EfficientNetV2-S.
- It takes the image (256×256×3) and turns it into a feature map: a spatially downsampled grid where each position holds a D-dimensional vector encoding local features such as texture, edges, or nucleus shape.
- In short, it transforms raw pixels into useful signals for the prototype layer.
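The optional 1×1 adapter from the architecture diagram sits between the backbone and the prototype layer. A 1×1 convolution is simply the same linear map D → D' applied independently at every feature-map position. Below is a minimal pure-Python sketch assuming channels-last nested lists; `conv1x1` and the toy weights are hypothetical. The optional `activation` argument is included because the original ProtoPNet add-on layers end with a sigmoid.

```python
import math

def conv1x1(feature_map, weights, bias, activation=None):
    """Apply a 1x1 convolution: the same linear map D -> D' at every (row, col).

    feature_map: H x W x D nested lists; weights: D' x D; bias: length D'.
    """
    out = []
    for row in feature_map:
        out_row = []
        for vec in row:
            # dot product of each output-channel weight row with this position's vector
            projected = [sum(w * x for w, x in zip(w_row, vec)) + b
                         for w_row, b in zip(weights, bias)]
            if activation is not None:
                projected = [activation(v) for v in projected]
            out_row.append(projected)
        out.append(out_row)
    return out

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

# Toy 2x2 feature map with D = 3 channels, projected to D' = 2 channels.
fmap = [[[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]],
        [[0.5, 0.5, 0.5], [2.0, 0.0, 0.0]]]
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 1.0]]
b = [0.0, 0.0]
print(conv1x1(fmap, W, b))
print(conv1x1(fmap, W, b, activation=sigmoid))
```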

### 2. Prototype Layer
- A prototype is a learned, patch-shaped vector in the same feature space as this feature map, representing a "typical pattern" for a class. Each prototype is learned during training and assigned to one class.
- When the image passes through the CNN backbone, ProtoPNet checks whether any patch of the image's feature map looks similar to the stored prototypes.
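One way to store the prototypes and their class assignments is sketched below. It assumes 1×1 prototypes (the shape used in the original ProtoPNet paper), a fixed number of prototypes per class, and uniform random initialization; `class_identity` and `init_prototypes` are hypothetical helpers. The five class names come from the prediction example later in these notes.

```python
import random

def class_identity(num_classes, prototypes_per_class):
    """num_prototypes x num_classes 0/1 matrix: identity[j][k] == 1 if prototype j belongs to class k."""
    num_prototypes = num_classes * prototypes_per_class
    return [[1 if j // prototypes_per_class == k else 0 for k in range(num_classes)]
            for j in range(num_prototypes)]

def init_prototypes(num_prototypes, channels, seed=0):
    """Random 1x1 prototypes (one D'-dimensional vector each), uniform in [0, 1)."""
    rng = random.Random(seed)
    return [[rng.random() for _ in range(channels)] for _ in range(num_prototypes)]

classes = ["Normal", "Parabasal", "Koilocytotic", "Dyskeratotic", "Metaplastic"]
identity = class_identity(len(classes), prototypes_per_class=2)
prototypes = init_prototypes(len(identity), channels=4)
for j, row in enumerate(identity):
    print(f"prototype {j} -> {classes[row.index(1)]}")
```

Keeping the class assignment as an explicit 0/1 matrix is convenient because the classifier initialization, the clustering/separation losses, and the last-layer penalty all need to know which prototypes belong to which class.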

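### 3. Similarity Computation
- For each prototype, the model looks at every patch in the feature map and measures how close it is to the prototype.
- Closeness is measured with squared Euclidean (L2) distance $d$, which ProtoPNet converts into a similarity score $\log\left(\frac{d+1}{d+\epsilon}\right)$, so smaller distances give larger similarities.
- It then keeps the most similar patch using global max pooling over all positions.
- Each prototype therefore produces a single similarity score, giving M scores in total.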
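The three steps above (distance, conversion to similarity, global max pooling) can be sketched for a single prototype as follows. This assumes a 1×1 prototype, a channels-last nested-list feature map, and ε = 1e-4; the function names are hypothetical.

```python
import math

def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def log_similarity(distance, eps=1e-4):
    """ProtoPNet-style similarity: large when the distance is small, small when it is large."""
    return math.log((distance + 1.0) / (distance + eps))

def prototype_score(feature_map, prototype):
    """Compare one 1x1 prototype with every position; return (best score, (row, col))."""
    best_score, best_pos = -math.inf, None
    for r, row in enumerate(feature_map):
        for c, vec in enumerate(row):
            score = log_similarity(squared_l2(vec, prototype))
            if score > best_score:  # global max pooling over the similarity map
                best_score, best_pos = score, (r, c)
    return best_score, best_pos

fmap = [[[0.0, 0.0], [1.0, 1.0]],
        [[0.9, 1.1], [3.0, 0.0]]]
score, pos = prototype_score(fmap, prototype=[1.0, 1.0])
print(pos, round(score, 2))
```

Returning the position as well as the score is what later makes the "where did it match" explanation possible.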

### 4. Fully Connected Layer (Classifier)
- It is a linear layer that computes a weighted sum of the prototype similarity scores, producing one logit per class.
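A minimal sketch of the classifier as a plain weighted sum. The initialization (weight 1 for a prototype's own class, -0.5 for every other class) follows the original ProtoPNet paper; `init_last_layer`, `class_logits`, and the toy scores are hypothetical.

```python
def init_last_layer(identity, incorrect_weight=-0.5):
    """num_classes x num_prototypes weights: 1 for a prototype's own class, incorrect_weight otherwise."""
    num_classes = len(identity[0])
    return [[1.0 if row[k] else incorrect_weight for row in identity]
            for k in range(num_classes)]

def class_logits(weights, scores):
    """Weighted sum of prototype similarity scores, one logit per class."""
    return [sum(w * s for w, s in zip(w_row, scores)) for w_row in weights]

# 2 classes x 2 prototypes each; prototypes 0-1 belong to class 0, 2-3 to class 1.
identity = [[1, 0], [1, 0], [0, 1], [0, 1]]
W = init_last_layer(identity)
scores = [4.0, 1.0, 0.5, 0.5]
print(class_logits(W, scores))
```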

### 5. Final Prediction
- The network outputs probabilities for each class.
- It also shows which prototypes were activated and where in the image they matched, providing explainability.
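The two outputs of this step can be sketched together: a softmax over the class logits, and a ranking of prototypes by how much each one contributed (weight × similarity) to the predicted class. The weights and scores are toy values, and `top_evidence` is a hypothetical helper, not part of any library.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_evidence(weights, scores, class_index, k=2):
    """Rank prototypes by their contribution (weight * similarity) to one class logit."""
    contributions = [(j, weights[class_index][j] * s) for j, s in enumerate(scores)]
    return sorted(contributions, key=lambda t: t[1], reverse=True)[:k]

weights = [[1.0, 1.0, -0.5, -0.5],
           [-0.5, -0.5, 1.0, 1.0]]
scores = [4.0, 1.0, 0.5, 0.5]
logits = [sum(w * s for w, s in zip(row, scores)) for row in weights]
probs = softmax(logits)
pred = probs.index(max(probs))
print(pred, [round(p, 3) for p in probs])
print(top_evidence(weights, scores, pred))
```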

## Training ProtoPNet

ProtoPNet is trained in 3 main phases (don’t worry, simpler than it sounds):

### 1. Joint Training:

Train the CNN backbone and the prototypes together with the standard cross-entropy classification loss.

Also add two extra losses:

- Clustering loss: every training image should have at least one patch close to a prototype of its own class.

- Separation loss: the patches of a training image should stay far from the prototypes of other classes.
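For one training image, both extra terms can be computed from each prototype's minimum distance to the image's patches. A sketch under these assumptions: `min_distances` is already computed, the cross-entropy value is a placeholder, and the weights `lambda_clst` and `lambda_sep` are hypothetical rather than tuned values. In practice the terms are averaged over a batch.

```python
def cluster_and_separation(min_distances, identity, label):
    """Clustering and separation terms for one training image.

    min_distances[j]: smallest squared distance between prototype j and any patch of the image.
    identity[j][k] == 1 if prototype j belongs to class k.
    """
    own = [d for d, row in zip(min_distances, identity) if row[label]]
    other = [d for d, row in zip(min_distances, identity) if not row[label]]
    cluster = min(own)        # minimizing pulls the closest own-class prototype even closer
    separation = -min(other)  # minimizing pushes the closest wrong-class prototype away
    return cluster, separation

identity = [[1, 0], [1, 0], [0, 1], [0, 1]]
min_distances = [0.2, 1.5, 0.9, 3.0]
clst, sep = cluster_and_separation(min_distances, identity, label=0)
cross_entropy = 0.35                 # placeholder value, not computed here
lambda_clst, lambda_sep = 0.8, 0.08  # hypothetical loss weights
total = cross_entropy + lambda_clst * clst + lambda_sep * sep
print(clst, sep, round(total, 3))
```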

### 2. Projection Step:

After some training, each prototype is “projected” onto the closest latent patch from the training images of its own class.

This makes prototypes real and interpretable (you can actually look at them as small image patches).
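A brute-force sketch of the projection step, assuming 1×1 prototypes and a small in-memory dataset of precomputed feature maps; `project_prototypes` and the toy data are hypothetical. It also records where each projected prototype came from, which is what lets you display it as an image patch.

```python
def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def project_prototypes(prototypes, proto_classes, dataset):
    """Replace each prototype with its nearest training patch from its own class.

    prototypes: list of vectors; proto_classes[j]: class of prototype j;
    dataset: list of (image_id, label, feature_map) with H x W x D feature maps.
    Returns the new prototypes and where each came from as (image_id, row, col).
    """
    projected, sources = [], []
    for proto, cls in zip(prototypes, proto_classes):
        best = None  # (distance, vector, source)
        for image_id, label, fmap in dataset:
            if label != cls:
                continue  # only patches from the prototype's own class qualify
            for r, row in enumerate(fmap):
                for c, vec in enumerate(row):
                    d = squared_l2(vec, proto)
                    if best is None or d < best[0]:
                        best = (d, list(vec), (image_id, r, c))
        if best is None:
            projected.append(list(proto))  # no same-class patch found: leave unchanged
            sources.append(None)
        else:
            projected.append(best[1])
            sources.append(best[2])
    return projected, sources

dataset = [
    ("img_a", 0, [[[0.1, 0.9], [0.8, 0.2]]]),
    ("img_b", 1, [[[0.5, 0.5], [0.0, 1.0]]]),
]
new_protos, sources = project_prototypes([[1.0, 0.0], [0.1, 0.9]], [0, 1], dataset)
print(new_protos, sources)
```

Note that prototype 1 is not snapped to the identical patch in `img_a`, because that image belongs to a different class.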

### 3. Last-Layer Fine-Tuning:

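With all other layers frozen, the final classifier weights are optimized using cross-entropy plus an L1 penalty on the weights that connect prototypes to classes they do not belong to. This pushes those weights toward zero, so each class decision relies mainly on its own prototypes.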
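The sparsity penalty used in this phase can be sketched as follows. The penalty strength `lam` is a hypothetical value, and the cross-entropy term is passed in rather than computed; `wrong_class_l1` and `last_layer_objective` are illustrative names.

```python
def wrong_class_l1(weights, identity):
    """Sum of |w_kj| over connections from prototype j to a class k it does not belong to."""
    return sum(abs(weights[k][j])
               for k in range(len(weights))
               for j in range(len(identity))
               if not identity[j][k])

def last_layer_objective(cross_entropy, weights, identity, lam=1e-4):
    """Cross-entropy plus a penalty that drives wrong-class weights toward 0."""
    return cross_entropy + lam * wrong_class_l1(weights, identity)

identity = [[1, 0], [1, 0], [0, 1], [0, 1]]
weights = [[1.0, 1.0, -0.5, -0.5],
           [-0.5, -0.5, 1.0, 1.0]]
print(wrong_class_l1(weights, identity))
```

Only the off-class weights are penalized; the weights linking prototypes to their own class are left free.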

```text
Input Image (e.g., 224×224×3 cervical cell slide)
        │
        ▼
────────────────────────────────────────────
1. Backbone CNN (e.g., EfficientNetV2-S)
────────────────────────────────────────────
- Extracts features (edges, textures, nucleus shapes).
- Output: Feature Map (e.g., 7×7×1280).
        │
        ▼
────────────────────────────────────────────
2. Prototype Layer
────────────────────────────────────────────
- Contains several "prototypes" (small patches in feature space).
- Each prototype = typical pattern for a class.
  Example: "normal nucleus edge", "abnormal chromatin texture".
- For each prototype:
    Compare with all patches in the feature map.
        │
        ▼
────────────────────────────────────────────
3. Similarity Computation
────────────────────────────────────────────
- Compute similarity between prototype & feature patches.
- Use squared L2 distance, converted to a similarity score.
- Global Max Pooling → pick the strongest match.
- Output: One similarity score per prototype.
        │
        ▼
────────────────────────────────────────────
4. Fully Connected Layer (Linear Classifier)
────────────────────────────────────────────
- Combines similarity scores from prototypes.
- Positive weights = evidence for a class.
- Negative weights = evidence against a class.
- Example:
    Prototype #7 (abnormal nucleus) → strong for Dyskeratotic.
    Prototype #2 (smooth edge) → strong for Normal.
        │
        ▼
────────────────────────────────────────────
5. Final Prediction
────────────────────────────────────────────
- Softmax layer → Probability distribution across classes.
- Example:
    Normal: 0.05
    Parabasal: 0.10
    Koilocytotic: 0.15
    Dyskeratotic: 0.65
    Metaplastic: 0.05
        │
        ▼
────────────────────────────────────────────
6. Explainability Output
────────────────────────────────────────────
- Shows which prototypes were activated.
- Shows where in the input image the prototype matched.
- Gives a "this looks like that" explanation.
```
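Step 6 needs the best-matching feature-map cell translated back into pixel coordinates. A rough sketch, assuming a stride-32 backbone and treating each cell as a non-overlapping 32×32 tile; `receptive_box` is a hypothetical helper. The original ProtoPNet instead upsamples the similarity map to the input resolution and boxes the most strongly activated pixels, so this is only a coarse approximation.

```python
def receptive_box(row, col, stride, image_size):
    """Approximate image region (top, left, bottom, right) covered by feature-map cell (row, col).

    Simplification: each cell is treated as a non-overlapping stride x stride tile,
    although real receptive fields are larger and overlap.
    """
    top, left = row * stride, col * stride
    bottom = min(top + stride, image_size)
    right = min(left + stride, image_size)
    return top, left, bottom, right

# e.g. a prototype's best match was at cell (3, 5) of a 7x7 feature map
print(receptive_box(3, 5, stride=32, image_size=224))
```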
