# 🧠 PyTorch Fundamentals for AI Assistant Development

Welcome to Phase 1 of building your AI assistant from scratch! This notebook covers essential PyTorch concepts needed for implementing transformer architectures and language models.

## 📚 Learning Objectives

By the end of this notebook, you will:
- Master PyTorch tensor operations and automatic differentiation
- Understand neural network building blocks
- Implement attention mechanisms from scratch
- Build transformer components
- Gain hands-on experience with gradient computation and optimization

## 🎯 Prerequisites

- Basic Python programming knowledge
- Familiarity with linear algebra and calculus
- Understanding of machine learning concepts

Let's start building your AI assistant from the ground up! 🚀

## 1. Import Required Libraries

First, let's import all the libraries we'll need for our PyTorch fundamentals exploration.

In [None]:
# Essential imports for PyTorch and deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Numerical and visualization libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple, List, Dict, Any
import math
import random

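# Set plotting style ('seaborn-v0_8' needs Matplotlib >= 3.6; fall back to the default style if it is missing)
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'default')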
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print("✅ Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. PyTorch Tensor Operations

Tensors are the fundamental building blocks of PyTorch. Let's explore tensor creation, manipulation, and operations that are essential for building neural networks.

In [None]:
# 2.1 Tensor Creation
print("=== Tensor Creation ===")

# Create tensors in different ways
zeros_tensor = torch.zeros(3, 4)
ones_tensor = torch.ones(2, 3)
random_tensor = torch.randn(2, 3, 4)  # Normal distribution
uniform_tensor = torch.rand(3, 3)     # Uniform [0, 1]

print(f"Zeros tensor shape: {zeros_tensor.shape}")
print(f"Random tensor:\n{random_tensor[0]}")  # Show first matrix

# From Python lists/arrays
list_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
numpy_array = np.array([[1, 2], [3, 4]])
from_numpy = torch.from_numpy(numpy_array)  # Shares memory with numpy_array (no copy)

print(f"From list: {list_tensor}")
print(f"From numpy: {from_numpy}")

# 2.2 Tensor Properties and Device Management
print("\n=== Tensor Properties ===")
x = torch.randn(3, 4, 5)
print(f"Shape: {x.shape}")
print(f"Data type: {x.dtype}")
print(f"Device: {x.device}")
print(f"Number of elements: {x.numel()}")

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x_gpu = x.to(device)
print(f"Tensor on device: {x_gpu.device}")

# 2.3 Essential Tensor Operations
print("\n=== Essential Operations ===")

# Matrix operations
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.matmul(A, B)  # Matrix multiplication
print(f"Matrix multiplication: {A.shape} @ {B.shape} = {C.shape}")

# Element-wise operations
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])

print(f"Addition: {x + y}")
print(f"Multiplication: {x * y}")
print(f"Power: {torch.pow(x, 2)}")

# Reduction operations
data = torch.randn(3, 4)
print(f"Sum: {torch.sum(data)}")
print(f"Mean: {torch.mean(data)}")
print(f"Max: {torch.max(data)}")
print(f"Sum along axis 1: {torch.sum(data, dim=1)}")

# Reshaping and view operations
original = torch.randn(2, 3, 4)
reshaped = original.view(6, 4)      # Same data, new shape (requires contiguous memory)
flattened = original.flatten()      # Flatten to 1D
print(f"Original: {original.shape} -> Reshaped: {reshaped.shape} -> Flattened: {flattened.shape}")
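
Two details from the cell above tend to cause surprises later: `torch.from_numpy` shares memory with the source array, and `view` only works on contiguous tensors. The sketch below illustrates both. The variable names are illustrative and not part of the notebook.

```python
import numpy as np
import torch

# torch.from_numpy shares memory with the NumPy array: a write to one shows up in the other.
arr = np.array([[1, 2], [3, 4]])
shared = torch.from_numpy(arr)
arr[0, 0] = 99
shares_memory = shared[0, 0].item() == 99

# torch.tensor copies the data, so later writes to the array do not propagate.
copied = torch.tensor(arr)
arr[0, 0] = -1
copy_is_independent = copied[0, 0].item() == 99

# transpose returns a non-contiguous tensor, so .view() on it raises a RuntimeError.
t = torch.arange(6).reshape(2, 3).transpose(0, 1)  # shape [3, 2]
try:
    t.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True

# reshape copies when it has to; .contiguous().view(...) is the explicit equivalent.
flat_a = t.reshape(6)
flat_b = t.contiguous().view(6)

print(f"shares memory: {shares_memory}, copy independent: {copy_is_independent}")
print(f"view on transposed tensor failed: {view_failed}")
print(f"reshape and contiguous().view() agree: {torch.equal(flat_a, flat_b)}")
```

The same `transpose(...).contiguous().view(...)` pattern appears again when the attention heads are merged in Section 5.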

## 3. Automatic Differentiation (Autograd)

Automatic differentiation is the heart of PyTorch's neural network training. Let's explore how gradients are computed automatically.

In [None]:
# 3.1 Basic Gradient Computation
print("=== Basic Gradients ===")

# Simple function: y = x^2 + 3x + 1
x = torch.tensor([2.0], requires_grad=True)
y = x**2 + 3*x + 1

print(f"x = {x.item()}")
print(f"y = x^2 + 3x + 1 = {y.item()}")

# Compute gradient dy/dx
y.backward()
print(f"dy/dx = 2x + 3 = {x.grad.item()}")
print(f"At x=2: dy/dx = {2*2 + 3} (analytical) vs {x.grad.item()} (autograd)")

# 3.2 Gradient Computation with Vectors
print("\n=== Vector Gradients ===")

# Multi-variable function
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.sum(x**2)

print(f"x = {x}")
print(f"y = sum(x^2) = {y.item()}")

y.backward()
print(f"dy/dx = 2x = {x.grad}")

# 3.3 Gradient Accumulation
print("\n=== Gradient Accumulation ===")

x = torch.tensor([1.0], requires_grad=True)

# First computation
y1 = x**2
y1.backward()
print(f"After first backward: x.grad = {x.grad}")

# Second computation (gradients accumulate!)
y2 = x**3
y2.backward()
print(f"After second backward: x.grad = {x.grad}")

# Clear gradients
x.grad.zero_()
print(f"After zero_(): x.grad = {x.grad}")

# 3.4 Computational Graph Example
print("\n=== Computational Graph ===")

# Create a more complex computation
a = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)

# Forward pass
c = a * b      # c = a * b
d = c + a      # d = c + a = a * b + a
e = d**2       # e = d^2 = (a * b + a)^2

print(f"a = {a.item()}, b = {b.item()}")
print(f"c = a * b = {c.item()}")
print(f"d = c + a = {d.item()}")
print(f"e = d^2 = {e.item()}")

# Backward pass
e.backward()

print(f"de/da = {a.grad.item()}")
print(f"de/db = {b.grad.item()}")

# Manual verification:
# e = (a*b + a)^2 = (a*(b+1))^2
# de/da = 2*(a*(b+1))*(b+1) = 2*a*(b+1)^2
# de/db = 2*(a*(b+1))*a = 2*a^2*(b+1)
print(f"Manual de/da = 2*a*(b+1)^2 = {2*a.item()*(b.item()+1)**2}")
print(f"Manual de/db = 2*a^2*(b+1) = {2*a.item()**2*(b.item()+1)}")

# 3.5 Higher-order Derivatives
print("\n=== Higher-order Derivatives ===")

x = torch.tensor([2.0], requires_grad=True)
y = x**4

# First derivative
grad1 = torch.autograd.grad(y, x, create_graph=True)[0]
print(f"First derivative dy/dx = {grad1.item()}")

# Second derivative
grad2 = torch.autograd.grad(grad1, x)[0]
print(f"Second derivative d²y/dx² = {grad2.item()}")

# Manual verification: y = x^4, dy/dx = 4x^3, d²y/dx² = 12x^2
print(f"Manual: dy/dx = 4x^3 = {4 * x.item()**3}")
print(f"Manual: d²y/dx² = 12x^2 = {12 * x.item()**2}")
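
A common way to sanity-check autograd is to compare it against a finite-difference estimate. The sketch below does this for the computational-graph function e = (a*b + a)^2 from above, using central differences in float64. The step size `h` and the helper name `f` are assumptions chosen for illustration.

```python
import torch

def f(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Same function as the computational-graph example: e = (a*b + a)^2
    return (a * b + a) ** 2

# float64 keeps the finite-difference rounding error small
a = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
b = torch.tensor(3.0, dtype=torch.float64, requires_grad=True)

# Gradients from autograd
f(a, b).backward()

# Central differences: (f(x + h) - f(x - h)) / (2h)
h = 1e-6
with torch.no_grad():  # no graph is needed for the numerical estimate
    num_da = (f(a + h, b) - f(a - h, b)) / (2 * h)
    num_db = (f(a, b + h) - f(a, b - h)) / (2 * h)

print(f"de/da: autograd={a.grad.item():.6f}, numerical={num_da.item():.6f}")
print(f"de/db: autograd={b.grad.item():.6f}, numerical={num_db.item():.6f}")
```

PyTorch also ships `torch.autograd.gradcheck`, which automates this kind of comparison for double-precision inputs.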

## 4. Neural Network Building Blocks

Now let's build the fundamental components of neural networks using PyTorch's nn module.

In [None]:
# 4.1 Basic Neural Network Layers
print("=== Basic Neural Network Layers ===")

# Linear (fully connected) layer
input_size, output_size = 10, 5
linear_layer = nn.Linear(input_size, output_size)

print(f"Linear layer: {input_size} -> {output_size}")
print(f"Weight shape: {linear_layer.weight.shape}")
print(f"Bias shape: {linear_layer.bias.shape}")

# Test with input
x = torch.randn(3, input_size)  # Batch of 3 samples
output = linear_layer(x)
print(f"Input shape: {x.shape} -> Output shape: {output.shape}")

# 4.2 Activation Functions
print("\n=== Activation Functions ===")

# Common activation functions
x = torch.linspace(-3, 3, 100)

# ReLU
relu = F.relu(x)
leaky_relu = F.leaky_relu(x, negative_slope=0.1)

# Sigmoid and Tanh
sigmoid = torch.sigmoid(x)
tanh = torch.tanh(x)

# GELU (used in transformers)
gelu = F.gelu(x)

# Plot activation functions
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.ravel()

activations = [
    (relu, "ReLU"),
    (leaky_relu, "Leaky ReLU"),
    (sigmoid, "Sigmoid"),
    (tanh, "Tanh"),
    (gelu, "GELU"),
]

for i, (activation, name) in enumerate(activations):
    axes[i].plot(x.numpy(), activation.numpy())
    axes[i].set_title(name)
    axes[i].grid(True)
    axes[i].set_xlabel('x')
    axes[i].set_ylabel(f'{name}(x)')

# Remove empty subplot
axes[-1].remove()

plt.tight_layout()
plt.show()

# 4.3 Building a Simple Neural Network
print("\n=== Simple Neural Network ===")

class SimpleNN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super(SimpleNN, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.layer1(x))
        x = self.dropout(x)
        x = F.relu(self.layer2(x))
        x = self.dropout(x)
        x = self.layer3(x)
        return x

# Create and test the network
model = SimpleNN(input_size=784, hidden_size=128, output_size=10)
print(f"Model architecture:\n{model}")

# Count parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params:,}")

# Test forward pass
dummy_input = torch.randn(32, 784)  # Batch of 32 MNIST-like images
output = model(dummy_input)
print(f"Input: {dummy_input.shape} -> Output: {output.shape}")

# 4.4 Loss Functions and Optimization
print("\n=== Loss Functions and Optimization ===")

# Cross-entropy loss for classification
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Dummy training step
model.train()  # Set to training mode

# Forward pass
logits = model(dummy_input)
targets = torch.randint(0, 10, (32,))  # Random class labels

# Compute loss
loss = criterion(logits, targets)
print(f"Loss: {loss.item():.4f}")

# Backward pass
optimizer.zero_grad()  # Clear gradients
loss.backward()        # Compute gradients
optimizer.step()       # Update parameters

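# Re-evaluate in eval mode so dropout does not add noise to the measurement
# (the first loss above was computed with dropout active, so compare loosely)
model.eval()
with torch.no_grad():
    loss_after = criterion(model(dummy_input), targets)
model.train()
print(f"Loss after one step: {loss_after.item():.4f}")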

# 4.5 Different Loss Functions
print("\n=== Different Loss Functions ===")

# Generate some data for demonstration
batch_size = 16
predictions = torch.randn(batch_size, 10)
targets_classification = torch.randint(0, 10, (batch_size,))
targets_regression = torch.randn(batch_size, 10)

# Classification losses
cross_entropy = nn.CrossEntropyLoss()(predictions, targets_classification)
print(f"Cross Entropy Loss: {cross_entropy.item():.4f}")

# Regression losses
mse_loss = nn.MSELoss()(predictions, targets_regression)
mae_loss = nn.L1Loss()(predictions, targets_regression)
huber_loss = nn.HuberLoss()(predictions, targets_regression)

print(f"MSE Loss: {mse_loss.item():.4f}")
print(f"MAE Loss: {mae_loss.item():.4f}")
print(f"Huber Loss: {huber_loss.item():.4f}")
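
To make clear what `nn.CrossEntropyLoss` computes, the sketch below reproduces it by hand: take the log-softmax of the raw logits, pick the log-probability of the true class, negate, and average over the batch. The batch size and seed are arbitrary illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # 4 samples, 10 classes
targets = torch.randint(0, 10, (4,))  # class indices, not one-hot vectors

# Built-in: expects raw logits and applies log-softmax internally
builtin = F.cross_entropy(logits, targets)

# Manual: log-softmax, select the true class per row, negate, average
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(4), targets].mean()

print(f"built-in: {builtin.item():.6f}, manual: {manual.item():.6f}")
```

Because the loss applies log-softmax itself, the model's final layer should output raw logits, as `SimpleNN` does, rather than softmax probabilities.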

## 5. Attention Mechanism Implementation

The attention mechanism is the heart of transformer models. Let's implement it from scratch to understand how it works.
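
For reference, the code in this section implements the standard scaled dot-product attention formula:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

Here $d_k$ is the dimension of each key vector. In `SimpleAttention` it equals `hidden_size`; in `MultiHeadAttention` it equals `d_model // num_heads`. Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the vector dimension, which would otherwise push the softmax into regions with very small gradients.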

In [None]:
# 5.1 Simple Attention Mechanism
print("=== Simple Attention Mechanism ===")

class SimpleAttention(nn.Module):
    """Simple attention mechanism for understanding the concept."""
    
    def __init__(self, hidden_size: int):
        super(SimpleAttention, self).__init__()
        self.hidden_size = hidden_size
        self.W_q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_k = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_v = nn.Linear(hidden_size, hidden_size, bias=False)
        
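    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: [batch_size, seq_len, hidden_size]
            key: [batch_size, seq_len, hidden_size]
            value: [batch_size, seq_len, hidden_size]
            mask: Optional mask broadcastable to [batch_size, seq_len, seq_len];
                positions where mask == 0 are not attended to

        Returns:
            output: [batch_size, seq_len, hidden_size]
            attention_weights: [batch_size, seq_len, seq_len]
        """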
        # Transform inputs
        Q = self.W_q(query)  # [batch_size, seq_len, hidden_size]
        K = self.W_k(key)    # [batch_size, seq_len, hidden_size]
        V = self.W_v(value)  # [batch_size, seq_len, hidden_size]
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.hidden_size)
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights

# Test simple attention
batch_size, seq_len, hidden_size = 2, 8, 64
attention = SimpleAttention(hidden_size)

# Create sample sequences
sequences = torch.randn(batch_size, seq_len, hidden_size)

# Self-attention (query, key, value are the same)
output, weights = attention(sequences, sequences, sequences)

print(f"Input shape: {sequences.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")

# Visualize attention weights for first sample
plt.figure(figsize=(8, 6))
plt.imshow(weights[0].detach().numpy(), cmap='Blues')
plt.colorbar()
plt.title('Attention Weights Visualization')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.show()

# 5.2 Multi-Head Attention
print("\n=== Multi-Head Attention ===")

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention as used in Transformers."""
    
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear transformations for Q, K, V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def scaled_dot_product_attention(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                                   mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot-product attention."""
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        output = torch.matmul(attention_weights, V)
        return output, attention_weights
    
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, d_model = query.size()
        
        # Project, then split into heads: [batch_size, num_heads, seq_len, d_k].
        # Using -1 lets key/value have a different length than the query.
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Add a head dimension only to 3-D masks ([batch, seq, seq]). A 4-D mask
        # ([batch, 1, seq, seq]) already lines up with the 4-D scores; unsqueezing it
        # again would produce a 5-D mask that no longer matches them.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)
        
        attention_output, _ = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads and put through final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model
        )
        
        output = self.W_o(attention_output)
        return output

# Test multi-head attention
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)

# Test input
batch_size, seq_len = 2, 10
x = torch.randn(batch_size, seq_len, d_model)

# Forward pass
output = mha(x, x, x)  # Self-attention
print(f"Multi-head attention input: {x.shape}")
print(f"Multi-head attention output: {output.shape}")

# Count parameters
mha_params = sum(p.numel() for p in mha.parameters())
print(f"Multi-head attention parameters: {mha_params:,}")

# 5.3 Positional Encoding
print("\n=== Positional Encoding ===")

class PositionalEncoding(nn.Module):
    """Positional encoding for transformer models."""
    
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # Create div_term for sine and cosine functions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
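        # Reshape to [max_len, 1, d_model] so it broadcasts over the batch dimension,
        # then register as a buffer (saved and moved with the model, but not trained)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to embeddings of shape [seq_len, batch_size, d_model]."""
        x = x + self.pe[:x.size(0), :]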
        return self.dropout(x)

# Test positional encoding
pos_encoding = PositionalEncoding(d_model=128, max_len=100)

# Create sample embeddings
seq_len, batch_size, d_model = 20, 4, 128
embeddings = torch.randn(seq_len, batch_size, d_model)

# Apply positional encoding
encoded = pos_encoding(embeddings)
print(f"Embeddings shape: {embeddings.shape}")
print(f"Positionally encoded shape: {encoded.shape}")

# Visualize positional encoding
plt.figure(figsize=(12, 6))
pe_matrix = pos_encoding.pe[:50, 0, :].numpy()  # First 50 positions
plt.imshow(pe_matrix.T, cmap='RdBu', aspect='auto')
plt.colorbar()
plt.title('Positional Encoding Visualization')
plt.xlabel('Position')
plt.ylabel('Embedding Dimension')
plt.show()

print("Position encoding provides unique patterns for each position!")
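
The `mask` argument of the attention classes above is what makes autoregressive modeling possible. The sketch below applies a lower-triangular (causal) mask in a standalone function and checks that changing the last token leaves the outputs at earlier positions unchanged. It is a simplified illustration, not the notebook's class: `causal_attention` is a hypothetical helper, it skips the learned projections, and it fills masked scores with `-inf` rather than the `-1e9` used above.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Scaled dot-product attention with a lower-triangular (causal) mask."""
    seq_len, d_k = q.size(-2), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # True on and below the diagonal: position i may attend to positions <= i
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~allowed, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
x = torch.randn(1, 5, 16)  # [batch_size, seq_len, d_k]
out, weights = causal_attention(x, x, x)

# Replace the last token and recompute
x_changed = x.clone()
x_changed[:, -1] = torch.randn(16)
out_changed, _ = causal_attention(x_changed, x_changed, x_changed)

print("Attention weights (first sample):")
print(weights[0])
print("Earlier positions unchanged:", torch.allclose(out[:, :-1], out_changed[:, :-1]))
```

Every row of the weight matrix still sums to 1, but all entries above the diagonal are zero, so no position receives information from a later one.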

## 6. Mini Transformer Block

Let's combine everything we've learned to build a basic transformer block!

In [None]:
# 6.1 Feed-Forward Network
class FeedForward(nn.Module):
    """Position-wise feed-forward network."""
    
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

# 6.2 Transformer Block
class TransformerBlock(nn.Module):
    """A single decoder-style transformer block (self-attention + feed-forward, post-layer-norm)."""
    
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super(TransformerBlock, self).__init__()
        
        # Multi-head self-attention
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        
        # Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Transformer block forward pass with residual connections."""
        # Self-attention with residual connection and layer norm
        attn_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x

# 6.3 Mini Language Model
class MiniLanguageModel(nn.Module):
    """A very simple language model using transformer blocks."""
    
    def __init__(self, vocab_size: int, d_model: int, num_heads: int, 
                 num_layers: int, d_ff: int, max_seq_len: int = 1024, dropout: float = 0.1):
        super(MiniLanguageModel, self).__init__()
        
        self.d_model = d_model
        self.vocab_size = vocab_size
        
        # Token embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        
        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Initialize model weights."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.xavier_uniform_(module.weight)
    
    def create_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Create a lower-triangular causal mask of shape [1, seq_len, seq_len].

        MultiHeadAttention adds the head dimension, so a single leading batch
        dimension is enough here.
        """
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
        return mask.unsqueeze(0)
    
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = input_ids.size()
        device = input_ids.device
        
        # Token embeddings (scaled by sqrt(d_model))
        token_embeds = self.token_embedding(input_ids) * math.sqrt(self.d_model)
        
        # Add positional encoding
        # Transpose for positional encoding (expects [seq_len, batch_size, d_model])
        token_embeds = token_embeds.transpose(0, 1)
        hidden_states = self.pos_encoding(token_embeds)
        # Transpose back to [batch_size, seq_len, d_model]
        hidden_states = hidden_states.transpose(0, 1)
        
        # Create causal mask
        causal_mask = self.create_causal_mask(seq_len, device)
        
        # Pass through transformer blocks
        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(hidden_states, causal_mask)
        
        # Output projection to vocabulary
        logits = self.output_projection(hidden_states)
        
        return logits
    
    def count_parameters(self) -> int:
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# Test the mini language model
print("=== Mini Language Model Test ===")

# Model configuration
vocab_size = 1000
d_model = 256
num_heads = 8
num_layers = 4
d_ff = 1024
max_seq_len = 512

# Create model
model = MiniLanguageModel(vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len)

print(f"Model created with {model.count_parameters():,} parameters")

# Test forward pass
batch_size, seq_len = 2, 32
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

print(f"Input shape: {input_ids.shape}")

# Forward pass
logits = model(input_ids)
print(f"Output logits shape: {logits.shape}")
print(f"Expected shape: [{batch_size}, {seq_len}, {vocab_size}]")

# Test loss computation
targets = torch.randint(0, vocab_size, (batch_size, seq_len))
criterion = nn.CrossEntropyLoss()

# Reshape for loss computation
loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
print(f"Loss: {loss.item():.4f}")

print("\n✅ Mini transformer forward pass and loss computation completed!")
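
Once a model returns logits of shape `[batch_size, seq_len, vocab_size]`, text generation reduces to repeatedly picking a next token and appending it. The sketch below shows greedy decoding under that assumption. `TinyStandInLM` and `greedy_generate` are hypothetical names introduced only to keep the example self-contained; the stand-in has no attention and exists purely to provide the right input and output shapes.

```python
import torch
import torch.nn as nn

class TinyStandInLM(nn.Module):
    """Hypothetical placeholder: maps [batch, seq] token ids to [batch, seq, vocab] logits."""

    def __init__(self, vocab_size: int, d_model: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(input_ids))

@torch.no_grad()
def greedy_generate(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    """Append the highest-scoring next token, one step at a time."""
    was_training = model.training
    model.eval()  # disable dropout during generation
    for _ in range(max_new_tokens):
        logits = model(input_ids)                                    # [batch, seq, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # [batch, 1]
        input_ids = torch.cat([input_ids, next_token], dim=1)
    model.train(was_training)  # restore the previous mode
    return input_ids

torch.manual_seed(0)
lm = TinyStandInLM(vocab_size=50)
prompt = torch.randint(0, 50, (2, 4))
generated = greedy_generate(lm, prompt, max_new_tokens=6)
print(f"Prompt shape: {prompt.shape} -> Generated shape: {generated.shape}")
```

The same function should accept `MiniLanguageModel` in place of the stand-in, provided the total sequence length stays within the model's `max_seq_len`. An untrained model will produce arbitrary tokens, so this only demonstrates the mechanics.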

## 7. 🎯 Key Takeaways and Next Steps

### What You've Learned:

1. **PyTorch Fundamentals**
   - Tensor operations and device management
   - Automatic differentiation (autograd)
   - Neural network building blocks

2. **Attention Mechanisms**
   - Simple attention implementation
   - Multi-head attention from scratch
   - Positional encoding for sequence modeling

3. **Transformer Components**
   - Feed-forward networks
   - Layer normalization and residual connections
   - Complete transformer block implementation

4. **Mini Language Model**
   - End-to-end transformer-based language model
   - Causal masking for autoregressive generation
   - Parameter counting and model architecture

### What's Next:

🔬 **Phase 2**: Set up your training environment with GPU optimization and distributed training capabilities.

📊 **Phase 3**: Implement large-scale data collection and preprocessing pipelines.

🏗️ **Phase 4**: Build a full-scale transformer model with all the optimizations.

🚀 **Phase 5**: Train your first small language model and see it generate text!

### 💡 Pro Tips:

- **Practice**: Try modifying the model architecture (more layers, heads, etc.)
- **Experiment**: Change activation functions, dropout rates, layer normalization
- **Understand**: Make sure you understand each component before moving forward
- **Debug**: Use `print()` statements and visualization to understand data flow

### 🔧 Exercises to Try:

1. **Modify the attention mechanism** to use different attention patterns
2. **Implement different positional encodings** (learnable vs fixed)
3. **Add more sophisticated masking** for different sequence tasks
4. **Experiment with model sizes** and see how performance changes
5. **Implement gradient clipping** and learning rate scheduling
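
As a possible starting point for exercise 5, the sketch below combines gradient clipping with a warmup-then-decay learning rate schedule. The model, data, step counts, and `max_norm` value are illustrative assumptions, not recommendations from this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(8, 2)  # hypothetical stand-in for a real model
optimizer = optim.Adam(model.parameters(), lr=1e-3)

warmup_steps, total_steps = 5, 20  # illustrative values

def lr_lambda(step: int) -> float:
    # Multiplier on the base learning rate: linear warmup, then linear decay
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return max(0.0, (total_steps - step) / (total_steps - warmup_steps))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))
lrs = []
for step in range(total_steps):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    # Rescale gradients in place so their global L2 norm is at most max_norm
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Recompute the norm after clipping to confirm the bound holds
    grad_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()  # advance the schedule once per optimizer step

print(f"First LR: {lrs[0]:.6f}, peak LR: {max(lrs):.6f}, last LR: {lrs[-1]:.6f}")
print(f"Gradient norm after clipping (last step): {grad_norm.item():.4f}")
```

Clipping goes between `loss.backward()` and `optimizer.step()`, and the scheduler is stepped after the optimizer.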

Remember: Building AI from scratch is challenging but incredibly rewarding! Each component you implement deepens your understanding of how modern AI systems work.

Good luck with your AI assistant journey! 🤖✨