# 2. The Neural Network Model

This notebook explains the `EmbeddingClassifier` model, a simple 2-layer neural network that:
1. Projects embeddings to a new space (optional)
2. Classifies them into species/categories

## Why do we need a model?

We already have embeddings from BirdNET/Perch, so why train another model?

**Answer:** The pre-trained embeddings are general-purpose. We want to:
- Adapt them to OUR specific dataset
- Learn which features matter most for OUR classification task

## Model Architecture
The model consists of two customisable components:
1. A **projection layer**, which maps the input feature embedding to a hidden embedding (the same dimension by default, or a custom size). The hidden embedding is what we use to visualise learning; the input feature embedding itself doesn't change.
2. A **classification head**, with one output per class.

```
Input Embedding (1024 dims)
        ↓
   [Linear Layer] ← projection layer
        ↓
Hidden Embedding (1024 dims or custom)
        ↓
   [Linear Layer] ← classification head
        ↓
Class Logits (23 classes)
```
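
The diagram can be sketched as a small PyTorch module. This is an illustrative stand-in, not the actual `core.model.EmbeddingClassifier` source: the class name `SketchClassifier` is hypothetical, and it assumes the two linear layers are applied back to back with no activation in between, which the diagram does not specify.

```python
import torch
import torch.nn as nn


class SketchClassifier(nn.Module):
    """Hypothetical stand-in for the two-layer architecture in the diagram."""

    def __init__(self, input_dim=1024, hidden_dim=None, num_classes=23):
        super().__init__()
        # Assumption: the hidden size falls back to the input size when not given
        hidden_dim = hidden_dim or input_dim
        self.projection = nn.Linear(input_dim, hidden_dim)    # projection layer
        self.classifier = nn.Linear(hidden_dim, num_classes)  # classification head

    def get_embedding(self, x):
        # Hidden embedding, before classification
        return self.projection(x)

    def forward(self, x):
        return self.classifier(self.get_embedding(x))


sketch = SketchClassifier(hidden_dim=512)
sketch_logits = sketch(torch.randn(4, 1024))  # batch of 4 fake embeddings
print(sketch_logits.shape)
```

Each row of the output holds one raw score per class, so a batch of 4 embeddings gives a `(4, 23)` tensor.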

In [1]:
import torch
import torch.nn as nn
import sys
from pathlib import Path

# Add parent directory to path to import our model
sys.path.insert(0, str(Path().absolute().parent.parent))
from core.model import EmbeddingClassifier

# Create a model instance
model = EmbeddingClassifier(
    input_dim=1024,    # BirdNET embedding size
    hidden_dim=512,    # Reduce to 512 dims (optional)
    num_classes=23     # Number of bird species
)

print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

EmbeddingClassifier(
  (projection): Linear(in_features=1024, out_features=512, bias=True)
  (classifier): Linear(in_features=512, out_features=23, bias=True)
)

Total parameters: 536,599
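
The parameter total can be checked by hand. A `Linear(in_features, out_features)` layer with a bias stores `in_features * out_features` weights plus `out_features` biases, so the two layers printed above add up as follows (the variable names are only for illustration):

```python
# Each Linear(in, out) layer holds in*out weights plus out biases
input_dim, hidden_dim, num_classes = 1024, 512, 23

projection_params = input_dim * hidden_dim + hidden_dim     # 524,800
classifier_params = hidden_dim * num_classes + num_classes  # 11,799
total_params = projection_params + classifier_params

print(f"{total_params:,}")  # 536,599
```

Almost all of the parameters sit in the projection layer; the classification head is comparatively tiny.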


## How the Model Works

Let's trace what happens to a single embedding:

In [2]:
# Create a fake embedding (batch_size=1, features=1024)
fake_embedding = torch.randn(1, 1024)

print("Input shape:", fake_embedding.shape)

# Forward pass
with torch.no_grad():  # Don't compute gradients for this demo
    logits = model(fake_embedding)
    
print("Output shape:", logits.shape)
print(f"\nThese are 'logits' - raw scores for each of {logits.shape[1]} classes")
print("First 5 logits:", logits[0, :5])

Input shape: torch.Size([1, 1024])
Output shape: torch.Size([1, 23])

These are 'logits' - raw scores for each of 23 classes
First 5 logits: tensor([-0.2530,  0.0870,  0.2666,  2.8586,  0.6068])


## Converting Logits to Probabilities

Logits are raw scores. To get probabilities (that sum to 1), we use **softmax**:

In [3]:
import torch.nn.functional as F

with torch.no_grad():
    logits = model(fake_embedding)
    probabilities = F.softmax(logits, dim=1)
    
print("Probabilities shape:", probabilities.shape)
print("Sum of probabilities:", probabilities.sum().item())
print("\nTop 5 class probabilities:")
top5_probs, top5_classes = torch.topk(probabilities, 5)
for i in range(5):
    print(f"  Class {top5_classes[0, i].item()}: {top5_probs[0, i].item():.4f}")

Probabilities shape: torch.Size([1, 23])
Sum of probabilities: 0.9999999403953552

Top 5 class probabilities:
  Class 3: 0.3044
  Class 9: 0.1671
  Class 22: 0.0641
  Class 18: 0.0629
  Class 15: 0.0601
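
To make the softmax step concrete, here is a small sketch that computes it by hand and compares the result with `F.softmax`. The logits are made up with `torch.randn`, not taken from the model, and the names `demo_logits` and `manual_probs` are only for illustration. The sum printed above is 0.9999999 rather than exactly 1 because of float32 rounding.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
demo_logits = torch.randn(1, 23)  # made-up logits, not the model's output

# softmax(z)_i = exp(z_i) / sum_j exp(z_j)
# Subtracting the row maximum first leaves the result unchanged
# and avoids overflow when logits are large.
shifted = demo_logits - demo_logits.max(dim=1, keepdim=True).values
manual_probs = shifted.exp() / shifted.exp().sum(dim=1, keepdim=True)

# Should agree with the built-in up to floating-point tolerance
print(torch.allclose(manual_probs, F.softmax(demo_logits, dim=1)))
```

Softmax preserves the ordering of the logits, so the class with the largest logit is also the class with the largest probability.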


## Getting the Intermediate Embedding

The model's `get_embedding` method returns the hidden (intermediate) embedding. This is useful for:
- Visualization (we can reduce 512 dims to 3D using PCA)
- Understanding what the model learnt

In [None]:
with torch.no_grad():
    hidden_embedding = model.get_embedding(fake_embedding)
    
print("Hidden embedding shape:", hidden_embedding.shape)
print("This is the 'learned' representation of our input")
print("\nFirst 10 values:", hidden_embedding[0, :10])
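
As a rough sketch of the PCA idea mentioned above, the snippet below reduces a batch of 512-dimensional vectors to 3D with `torch.pca_lowrank`, which computes an approximate (randomized) PCA. The batch here is random stand-in data, not real hidden embeddings, and the names `hidden_batch` and `points_3d` are hypothetical. PCA needs many samples to be meaningful, so in practice you would stack the hidden embeddings of a whole dataset.

```python
import torch

torch.manual_seed(0)
# Made-up stand-in for a batch of hidden embeddings (100 samples, 512 dims)
hidden_batch = torch.randn(100, 512)

# pca_lowrank centers the data internally (center=True by default)
# and returns U, S, V; the columns of V are the principal directions.
U, S, V = torch.pca_lowrank(hidden_batch, q=3)

# Project the centered data onto the first 3 principal directions
centered = hidden_batch - hidden_batch.mean(dim=0, keepdim=True)
points_3d = centered @ V[:, :3]

print(points_3d.shape)
```

Each row of `points_3d` is one sample's coordinates in the 3D space, ready to be passed to a scatter plot.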

## Training the Model

Training happens via standard supervised learning:
1. Forward pass: Get predictions
2. Compute loss: Compare predictions to true labels
3. Backward pass: Calculate gradients
4. Update weights: Adjust model parameters

Here's a single training step on fake data:

In [None]:
# Create fake training data
batch_size = 8
X_batch = torch.randn(batch_size, 1024)  # 8 embeddings
y_batch = torch.randint(0, 23, (batch_size,))  # 8 random labels (0-22)

# Setup training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Single training step
model.train()
optimizer.zero_grad()

# Forward
logits = model(X_batch)
loss = criterion(logits, y_batch)

# Backward
loss.backward()
optimizer.step()

# Evaluate on the same batch, using the logits computed before the weight update
with torch.no_grad():
    predictions = torch.argmax(logits, dim=1)
    accuracy = (predictions == y_batch).float().mean()

print(f"Loss: {loss.item():.4f}")
print(f"Accuracy: {accuracy.item():.4f}")
print(f"\nPredictions: {predictions.tolist()}")
print(f"True labels:  {y_batch.tolist()}")
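
Repeating the same four steps gives a training loop. The sketch below runs 20 steps on one fixed fake batch. It uses a hypothetical stand-in, `toy_model`, built from two plain linear layers, which is an assumption about the architecture rather than the real `EmbeddingClassifier`. As a reference point, a classifier that assigns equal probability to all 23 classes has a cross-entropy loss of ln(23), about 3.14.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for EmbeddingClassifier (assumption: two plain linear layers)
toy_model = nn.Sequential(nn.Linear(1024, 512), nn.Linear(512, 23))
X_toy = torch.randn(8, 1024)          # 8 fake embeddings
y_toy = torch.randint(0, 23, (8,))    # 8 random labels (0-22)

optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

toy_model.train()
losses = []
for step in range(20):
    optimizer.zero_grad()                       # clear old gradients
    loss = criterion(toy_model(X_toy), y_toy)   # forward pass + loss
    loss.backward()                             # backward pass
    optimizer.step()                            # update weights
    losses.append(loss.item())

print(f"Uniform-guess loss ln(23): {math.log(23):.4f}")
print(f"First step loss: {losses[0]:.4f}")
print(f"Last step loss:  {losses[-1]:.4f}")
```

Because the labels are random and the batch is reused, the falling loss only shows the model memorising these 8 examples; it says nothing about generalisation.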

## Summary

**The model does 3 things:**
1. **Projects** input embeddings to a (possibly smaller) hidden space
2. **Classifies** hidden embeddings into species categories
3. **Learns** which features are important through training

**Key insight:** As we add more labeled data through active learning, this model gets better at:
- Distinguishing between species
- Creating meaningful embeddings
- Generalizing to unlabeled data