In [None]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List, Dict, Optional
from itertools import combinations

print(f"TensorFlow version: {tf.__version__}")

# Configure plotting. Matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8-*', and plt.style.use raises OSError for unknown names. Fall
# back to seaborn's own equivalent theme so this cell runs on any version.
PLOT_STYLE = 'seaborn-v0_8-darkgrid'
if PLOT_STYLE in plt.style.available:
    plt.style.use(PLOT_STYLE)
else:
    sns.set_style('darkgrid')
sns.set_palette('husl')

---

## Section 1: Few-Shot Learning Fundamentals

### Problem Formulation

**Traditional Learning:**
- Large labeled dataset
- Train for many iterations
- Deploy on same distribution

**Few-Shot Learning:**
- Small labeled dataset (N-way, K-shot)
- Quick adaptation to new task
- Generalize to new classes

### Notation
- **N-way classification:** N different classes
- **K-shot:** K labeled examples per class
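- **Support set:** $\{(x_i, y_i)\}_{i=1}^{N \cdot K}$, the labeled examples available for the task
- **Query set:** Examples to classify using the support set; their labels are used only to compute the loss (during meta-training) and the accuracy (during evaluation)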

### Example
- 5-way 1-shot: 5 classes, 1 example each = 5 support examples
- 5-way 5-shot: 5 classes, 5 examples each = 25 support examples

In [None]:
# Generate synthetic few-shot learning dataset.
# Seed every RNG in play so Restart & Run All reproduces the same numbers.
# NumPy's global RNG drives task sampling below. keras.utils.set_random_seed
# additionally seeds Python's `random` and the TensorFlow backend, which the
# Keras models use for weight initialisation.
SEED = 42
np.random.seed(SEED)
keras.utils.set_random_seed(SEED)

def generate_few_shot_task(n_way: int, k_shot: int, n_query: int, 
                          feature_dim: int = 64, n_tasks: int = 100,
                          noise_std: float = 0.1,
                          rng: Optional[np.random.Generator] = None):
    """
    Generate synthetic N-way K-shot few-shot learning tasks.

    Every task draws `n_way` fresh class centres from a standard normal in
    `feature_dim` dimensions. Each support/query example is its class centre
    plus isotropic Gaussian noise with standard deviation `noise_std`.
    Because centres are resampled per task, every task contains new classes.
    Within a task, labels are 0..n_way-1, grouped by class (not shuffled).

    Args:
        n_way: Number of classes per task.
        k_shot: Support examples per class.
        n_query: Query examples per class.
        feature_dim: Dimensionality of each example.
        n_tasks: Number of tasks to generate.
        noise_std: Std of the per-example noise around the class centre.
            Larger values make tasks harder. Default 0.1.
        rng: Optional random source exposing ``standard_normal``, e.g.
            ``np.random.default_rng(seed)`` or a ``RandomState``. Defaults to
            NumPy's global RNG, so ``np.random.seed`` controls the output.

    Returns:
        support_x: (n_tasks, n_way*k_shot, feature_dim) float array
        support_y: (n_tasks, n_way*k_shot) int array
        query_x: (n_tasks, n_way*n_query, feature_dim) float array
        query_y: (n_tasks, n_way*n_query) int array
    """
    sampler = np.random if rng is None else rng

    n_support = n_way * k_shot
    n_query_total = n_way * n_query
    # Preallocate so the documented shapes also hold when n_tasks == 0 or
    # k_shot/n_query == 0. Building arrays from empty lists would collapse
    # the rank.
    support_x = np.empty((n_tasks, n_support, feature_dim))
    query_x = np.empty((n_tasks, n_query_total, feature_dim))

    for task in range(n_tasks):
        class_centers = sampler.standard_normal((n_way, feature_dim))
        for class_id, center in enumerate(class_centers):
            # Draw order is: centres, then per class its support block, then
            # its query block. This is the same stream as drawing one example
            # at a time, so seeded runs reproduce earlier results.
            support_x[task, class_id * k_shot:(class_id + 1) * k_shot] = (
                center + sampler.standard_normal((k_shot, feature_dim)) * noise_std)
            query_x[task, class_id * n_query:(class_id + 1) * n_query] = (
                center + sampler.standard_normal((n_query, feature_dim)) * noise_std)

    # Labels are identical for every task: class ids grouped by class.
    support_y = np.tile(np.repeat(np.arange(n_way), k_shot), (n_tasks, 1))
    query_y = np.tile(np.repeat(np.arange(n_way), n_query), (n_tasks, 1))
    return support_x, support_y, query_x, query_y

# Episode configuration (5-way, 1-shot, 5 queries per class), also reused by
# the training cell further down.
n_way = 5
k_shot = 1
n_query = 5

# Materialise a fixed pool of 1000 tasks with 64-dimensional features.
(support_x, support_y,
 query_x, query_y) = generate_few_shot_task(n_way, k_shot, n_query,
                                            feature_dim=64, n_tasks=1000)

print(f"Few-Shot Learning Dataset ({n_way}-way {k_shot}-shot):")
print(f"Support set: {support_x.shape}")
print(f"Query set: {query_x.shape}")
print(f"Total tasks: {len(support_x)}")

---

## Section 2: Siamese Networks for Metric Learning

In [None]:
class SiameseNetwork(keras.Model):
    """Siamese network for metric learning.

    Both members of a pair pass through the same weight-shared embedding
    network. The model returns the Euclidean distance between the two
    embeddings, together with the embeddings themselves.

    Args:
        input_dim: Dimensionality of each input feature vector.
        embedding_dim: Dimensionality of the learned embedding.
    """

    def __init__(self, input_dim: int, embedding_dim: int = 32):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim

        # Shared embedding network. `keras.Input` declares the input shape
        # instead of the `input_shape=` layer argument, which newer Keras
        # deprecates.
        self.embedding_network = keras.Sequential([
            keras.Input(shape=(input_dim,)),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(64, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(embedding_dim)
        ])

    def call(self, inputs, training=None):
        """Embed both members of a pair and measure their distance.

        Args:
            inputs: A pair ``(x1, x2)`` of batches. Assumes 2-D
                (batch, input_dim) inputs — TODO confirm for other ranks.
            training: Forwarded to the embedding network, which controls
                BatchNormalization behaviour.

        Returns:
            distance: Euclidean distance per pair (reduced over axis 1).
            z1, z2: Embeddings of ``x1`` and ``x2``.
        """
        x1, x2 = inputs
        z1 = self.embedding_network(x1, training=training)
        z2 = self.embedding_network(x2, training=training)

        # tf.norm has a NaN gradient when z1 == z2, which is exactly what the
        # contrastive loss drives similar pairs towards. Clamp the squared
        # distance away from zero before the square root so gradients stay
        # finite.
        squared = tf.reduce_sum(tf.square(z1 - z2), axis=1)
        distance = tf.sqrt(tf.maximum(squared, 1e-12))
        return distance, z1, z2

    def get_embedding(self, x, training=None):
        """Return the shared-network embedding of ``x``.

        ``training`` is forwarded to BatchNormalization.
        """
        return self.embedding_network(x, training=training)

def contrastive_loss(y_true, y_pred, margin=1.0):
    """
    Contrastive loss for Siamese networks.

        loss = mean(y * d**2 + (1 - y) * max(margin - d, 0)**2)

    Similar pairs (y = 1) are pulled together. Dissimilar pairs (y = 0) are
    pushed apart until their distance reaches `margin`; beyond that they
    contribute nothing.

    Args:
        y_true: 1 if similar pair, 0 if dissimilar
        y_pred: Euclidean distance per pair
        margin: Margin for dissimilar pairs

    Returns:
        Scalar tensor: the mean loss over all pairs.
    """
    # Work in y_pred's floating dtype. A hard cast of y_true to float32
    # fails against float64 distances, e.g. ones computed from NumPy arrays.
    y_pred = tf.convert_to_tensor(y_pred)
    if not y_pred.dtype.is_floating:
        y_pred = tf.cast(y_pred, tf.float32)
    y_true = tf.cast(y_true, y_pred.dtype)

    # For similar pairs: minimize distance
    similar_loss = y_true * tf.square(y_pred)

    # For dissimilar pairs: penalize only distances smaller than the margin
    dissimilar_loss = (1 - y_true) * tf.square(tf.maximum(margin - y_pred, 0))

    return tf.reduce_mean(similar_loss + dissimilar_loss)

# Create Siamese network. Compiling only attaches an Adam optimizer (no loss
# is registered). Pair the model with `contrastive_loss` in a custom training
# step to train it. NOTE(review): it is not trained in this part of the
# notebook.
siamese_net = SiameseNetwork(input_dim=64, embedding_dim=32)
siamese_net.compile(optimizer='adam')

print("✅ Siamese Network defined")

---

## Section 3: Prototypical Networks

In [None]:
class PrototypicalNetwork(keras.Model):
    """Prototypical Network for few-shot learning.

    An embedding network maps inputs into a metric space. Each class is
    represented by a prototype: the mean embedding of its support examples.
    A query is scored by its negative Euclidean distance to each prototype.

    Args:
        input_dim: Dimensionality of each input feature vector.
        embedding_dim: Dimensionality of the learned embedding space.
    """

    def __init__(self, input_dim: int, embedding_dim: int = 32):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim

        # Feature extractor. `keras.Input` declares the input shape instead
        # of the `input_shape=` layer argument, which newer Keras deprecates.
        self.feature_extractor = keras.Sequential([
            keras.Input(shape=(input_dim,)),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(64, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(embedding_dim)
        ])

    def call(self, inputs, training=None):
        """Embed `inputs`. `training` is forwarded to BatchNormalization."""
        return self.feature_extractor(inputs, training=training)

    def compute_prototypes(self, support_x, support_y, n_way, training=None):
        """
        Compute class prototypes (mean embeddings).

        Args:
            support_x: Support examples (n_support, input_dim)
            support_y: Support labels (n_support,), integers in [0, n_way)
            n_way: Number of classes
            training: Forwarded to the embedding network. Pass True while
                meta-training in a custom loop. Outside `fit`, the default
                (None) runs BatchNormalization in inference mode.

        Returns:
            prototypes: (n_way, embedding_dim)
        """
        embeddings = self(support_x, training=training)
        # For each class c: mean embedding of the support examples labelled c.
        return tf.stack([
            tf.reduce_mean(tf.boolean_mask(embeddings, support_y == c), axis=0)
            for c in range(n_way)
        ])

    def predict_query(self, prototypes, query_x, training=None):
        """
        Predict labels for query examples using prototypes.

        Args:
            prototypes: Class prototypes (n_way, embedding_dim)
            query_x: Query examples (n_query, input_dim)
            training: Forwarded to the embedding network (see
                ``compute_prototypes``).

        Returns:
            logits: (n_query, n_way), the negative Euclidean distances.
        """
        query_embeddings = self(query_x, training=training)

        # Squared distance from every query to every prototype via
        # broadcasting: (n_query, 1, d) - (1, n_way, d) -> (n_query, n_way).
        squared_distances = tf.reduce_sum(
            tf.square(query_embeddings[:, None, :] - prototypes[None, :, :]),
            axis=2
        )
        # tf.norm has a NaN gradient at zero distance, e.g. a query identical
        # to a 1-shot support example. Clamp before the square root.
        distances = tf.sqrt(tf.maximum(squared_distances, 1e-12))

        # Closer prototypes get larger logits.
        return -distances

# Create prototypical network. The custom episodic training loop below reuses
# the compiled loss (`proto_net.loss`) and optimizer (`proto_net.optimizer`).
# The metrics are only consulted by Keras' built-in fit/evaluate.
proto_net = PrototypicalNetwork(input_dim=64, embedding_dim=32)
proto_net.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()]
)

print("✅ Prototypical Network defined")

In [None]:
# Training Prototypical Network on few-shot tasks.
# Episodic meta-training: every epoch samples fresh N-way K-shot tasks (new
# class centres), and each task contributes one gradient step.
N_EPOCHS = 20          # number of training epochs
TASKS_PER_EPOCH = 32   # episodes (= gradient steps) per epoch
N_VAL_TASKS = 10       # fresh episodes used to estimate validation accuracy
LOG_EVERY = 5          # print a progress line every LOG_EVERY epochs


def run_episode(task_support_x, task_support_y, task_query_x, training):
    """Return query logits of shape (n_query_total, n_way) for one episode.

    Support and query examples go through `proto_net` in one forward pass,
    with `training` passed explicitly. A bare ``proto_net(x)`` outside `fit`
    runs BatchNormalization in inference mode, so during meta-training its
    batch statistics would never be used and its moving averages never
    updated.
    """
    n_support = len(task_support_x)
    embeddings = proto_net(
        np.concatenate([task_support_x, task_query_x], axis=0), training=training
    )
    support_emb, query_emb = embeddings[:n_support], embeddings[n_support:]
    # Prototype of class c = mean embedding of its support examples.
    prototypes = tf.math.unsorted_segment_mean(
        support_emb, task_support_y, num_segments=n_way
    )
    # Logits are negative Euclidean distances. Clamp before the sqrt so the
    # gradient stays finite when a query coincides with a prototype.
    squared = tf.reduce_sum(
        tf.square(query_emb[:, None, :] - prototypes[None, :, :]), axis=2
    )
    return -tf.sqrt(tf.maximum(squared, 1e-12))


def episode_accuracy(logits, labels):
    """Fraction of queries whose highest-scoring class equals `labels`."""
    predictions = tf.argmax(logits, axis=1)
    correct = tf.equal(predictions, tf.cast(labels, predictions.dtype))
    return float(tf.reduce_mean(tf.cast(correct, tf.float32)))


print("🚀 Training Prototypical Network...\n")

train_accuracies = []
val_accuracies = []

for epoch in range(N_EPOCHS):
    # Episode-local names avoid clobbering the global support_x/query_x
    # dataset created in the data-generation cell.
    episode_sx, episode_sy, episode_qx, episode_qy = generate_few_shot_task(
        n_way, k_shot, n_query, n_tasks=TASKS_PER_EPOCH
    )

    loss_sum = 0.0
    acc_sum = 0.0
    for task_idx in range(TASKS_PER_EPOCH):
        with tf.GradientTape() as tape:
            logits = run_episode(
                episode_sx[task_idx], episode_sy[task_idx], episode_qx[task_idx],
                training=True,
            )
            loss = proto_net.loss(episode_qy[task_idx], logits)

        # Backpropagation
        gradients = tape.gradient(loss, proto_net.trainable_variables)
        proto_net.optimizer.apply_gradients(zip(gradients, proto_net.trainable_variables))

        loss_sum += float(loss)
        acc_sum += episode_accuracy(logits, episode_qy[task_idx])

    train_loss = loss_sum / TASKS_PER_EPOCH
    train_acc = acc_sum / TASKS_PER_EPOCH

    # Validation on fresh tasks, with BatchNormalization in inference mode.
    val_sx, val_sy, val_qx, val_qy = generate_few_shot_task(
        n_way, k_shot, n_query, n_tasks=N_VAL_TASKS
    )
    val_acc = sum(
        episode_accuracy(
            run_episode(val_sx[t], val_sy[t], val_qx[t], training=False), val_qy[t]
        )
        for t in range(N_VAL_TASKS)
    ) / N_VAL_TASKS

    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    if (epoch + 1) % LOG_EVERY == 0:
        print(f"Epoch {epoch + 1}/{N_EPOCHS} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

print("\n✅ Training complete!")
print(f"Final validation accuracy: {val_accuracies[-1]:.4f}")

---

## Section 4: Key Takeaways

### Meta-Learning Approaches

| Approach | Key Idea | Pros | Cons |
|----------|----------|------|------|
| **Metric Learning** | Learn distance metric | Simple, interpretable | Requires good embeddings |
| **Prototypical Nets** | Mean class embeddings | Efficient, works well | Assumes each class forms a single (roughly spherical Gaussian) cluster in embedding space |
| **MAML** | Meta-gradient updates | Theoretically sound | Computationally expensive |
| **Siamese Nets** | Pairwise comparison | Flexible | Requires careful pair selection |

### Few-Shot Learning Challenges
1. **Data Scarcity:** Limited labeled examples
2. **Distribution Shift:** Domain adaptation needed
3. **Task Variability:** Different tasks, varying difficulty
4. **Overfitting Risk:** Small datasets → high variance

### Practical Applications
- Character recognition (Omniglot dataset)
- Image classification (miniImageNet)
- Face recognition with new identities
- Personalized recommendation systems
- Rapid model adaptation to new domains

In [None]:
# Visualize learning curves, one point per epoch numbered from 1 to match the
# "Epoch k/N" progress lines printed during training.
epochs = range(1, len(train_accuracies) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, train_accuracies, label='Train Accuracy', linewidth=2)
ax.plot(epochs, val_accuracies, label='Val Accuracy', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Prototypical Network Training Progress')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

print("""
📚 Meta-Learning & Few-Shot Learning - Summary
==============================================

✅ Topics Covered:
  • Few-shot learning problem formulation
  • Metric learning and embeddings
  • Siamese networks
  • Prototypical networks
  • Training on few-shot tasks
  • Evaluation on novel classes

💡 Key Insights:
  • Few-shot learning enables rapid adaptation
  • Metric learning is core to modern approaches
  • Task-based training is more effective than standard training
  • Prototypical networks are simple yet effective

🎯 Next Steps:
  1. Try different embedding dimensions
  2. Experiment with various distance metrics
  3. Apply to real datasets (Omniglot, miniImageNet)
  4. Implement MAML for better performance
  5. Explore multi-task meta-learning
""")