# Architecture Patterns: Connecting Encoders and Decoders

This notebook covers common patterns for composing transformer systems: standard encoder–decoder, parameter sharing, hierarchical skip connections, and parallel processing. It also touches on scaling strategies (width/depth/compound), efficiency (distillation, quantization, pruning), and fine-tuning approaches (adapters, LoRA).

## Learning objectives
- Understand multiple ways to connect transformer components
- Learn when to prefer each pattern
- Try a small demo of cross-attention and shape manipulations with `einops`

## Outline
1. Standard encoder–decoder and cross-attention
2. Shared-parameter encoder/decoder
3. Hierarchical skip connections
4. Parallel encoder/decoder updates
5. Scaling strategies and practical tips
6. Efficiency techniques and fine-tuning

## References (Papers)
- Vaswani et al., 2017 — "Attention Is All You Need" (arXiv:1706.03762)
- Kaplan et al., 2020 — "Scaling Laws for Neural Language Models" (arXiv:2001.08361)
- Tan & Le, 2019 — "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" (arXiv:1905.11946)
- Hinton et al., 2015 — "Distilling the Knowledge in a Neural Network" (arXiv:1503.02531)
- Jacob et al., 2017 — "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" (arXiv:1712.05877)
- Frankle & Carbin, 2019 — "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks" (arXiv:1803.03635)
- Houlsby et al., 2019 — "Parameter-Efficient Transfer Learning for NLP" (arXiv:1902.00751)
- Hu et al., 2021 — "LoRA: Low-Rank Adaptation of Large Language Models" (arXiv:2106.09685)


In [None]:
# Cross-attention demo with shape ops
import torch
from einops import rearrange
from connection_patterns import CrossAttention

B, S, T, C = 2, 16, 12, 64  # batch, src_len, tgt_len, channels
# nn.MultiheadAttention defaults to a (seq, batch, embed) layout (batch_first=False)
query = torch.randn(T, B, C)   # decoder-side queries
key   = torch.randn(S, B, C)   # encoder-side keys
value = torch.randn(S, B, C)   # encoder-side values

attn = CrossAttention(d_model=C, num_heads=4)
output, weights = attn(query, key, value)
print('Cross-attn out:', output.shape, 'weights:', weights.shape)

# Example rearrange pattern: split a flattened leading axis back into batch and sequence
# Suppose we flattened to (B*N, C) and want to restore (B, N, C)
B_, N, C_ = 2, 8, 64
flat = torch.randn(B_*N, C_)
restored = rearrange(flat, '(B N) C -> B N C', B=B_, N=N)
print('Restored shape:', restored.shape)
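
The demo above imports `CrossAttention` from the local `connection_patterns` module, whose source is not shown here. For readers without that module, the following is a minimal sketch of what such a layer might look like. It assumes a thin wrapper around `nn.MultiheadAttention` with a residual connection and layer norm; the class name `CrossAttentionSketch` and its `dropout` parameter are hypothetical and may differ from the actual implementation.

```python
import torch
from torch import nn


class CrossAttentionSketch(nn.Module):
    """Hypothetical stand-in for connection_patterns.CrossAttention (an assumption)."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        # batch_first=False (the default) means inputs are (seq, batch, embed).
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, key, value):
        # Queries come from the decoder side (x); keys and values from the encoder.
        attended, weights = self.attn(x, key, value, need_weights=True)
        # Residual connection followed by layer norm (post-norm ordering).
        return self.norm(x + attended), weights


B, S, T, C = 2, 16, 12, 64  # batch, src_len, tgt_len, channels
x = torch.randn(T, B, C)       # decoder states
memory = torch.randn(S, B, C)  # encoder outputs, used as both keys and values

layer = CrossAttentionSketch(d_model=C, num_heads=4)
out, weights = layer(x, memory, memory)
print('out:', out.shape)          # same shape as the queries: (T, B, C)
print('weights:', weights.shape)  # averaged over heads: (B, T, S)
```

The output keeps the query shape `(T, B, C)`, while the returned weights are averaged over heads and have shape `(B, T, S)`: one distribution over source positions per target position.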
