# 📊 Notebook 06: Training & Evaluating the Baseline Model

## From Model to Metrics

This notebook teaches you how to train your neural network and evaluate its performance using standard classification metrics. You'll implement the training loop, compute accuracy and F1-scores, and analyze your baseline model's strengths and weaknesses.


## 🧠 Concept Primer: Training and Evaluation

### What We're Doing
Training your neural network to minimize loss and evaluating its performance on unseen data using classification metrics.

### Why This Step is Critical
**Training updates model weights** to learn patterns in the data. **Evaluation measures** how well the model generalizes to new examples.

### Training Loop Components
1. **Forward pass**: Input → Model → Predictions
2. **Loss computation**: Predictions vs true labels
3. **Backward pass**: Compute gradients
4. **Parameter update**: Adjust weights to reduce loss
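
The four components above can be sketched as a single optimization step. This is a minimal illustration, not the notebook's solution: the `nn.Linear` model, the random inputs, and the sizes (8 examples, 5 features, 3 classes) are hypothetical stand-ins for your own model and data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the toy example reproducible

# Hypothetical stand-ins: 8 examples, 5 features, 3 classes.
model = nn.Linear(5, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(8, 5)
labels = torch.randint(0, 3, (8,))

weights_before = model.weight.detach().clone()

model.train()                      # training mode (matters for dropout, batch norm)
optimizer.zero_grad()              # clear gradients left over from a previous step
logits = model(inputs)             # 1. forward pass -> shape [8, 3]
loss = criterion(logits, labels)   # 2. loss computation -> scalar tensor
loss.backward()                    # 3. backward pass fills .grad on each parameter
optimizer.step()                   # 4. parameter update using those gradients

weights_changed = not torch.equal(weights_before, model.weight.detach())
print(f"loss: {loss.item():.4f}, weights changed: {weights_changed}")
```

A full training loop repeats this step for every batch in every epoch.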

### Evaluation Metrics
- **Accuracy**: Fraction of all predictions that are correct
- **Precision** (per class): True positives / (True positives + False positives)
- **Recall** (per class): True positives / (True positives + False negatives)
- **F1-Score** (per class): Harmonic mean of precision and recall
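
To make the formulas concrete, here is a small pure-Python sketch that computes all four metrics for one class treated as "positive". The helper name `binary_metrics` and the toy labels are made up for illustration; in the exercises you would normally rely on `sklearn.metrics` instead.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Return (accuracy, precision, recall, f1) with `positive` as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    correct = sum(1 for t, p in pairs if t == p)

    accuracy = correct / len(pairs)
    # Guard against division by zero when a class is never predicted or never present.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return accuracy, precision, recall, f1


# Toy data: 4 positives and 4 negatives.
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 1, 0, 0, 0]

accuracy, precision, recall, f1 = binary_metrics(y_true, y_pred)
print(f"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
# accuracy=0.625 precision=0.667 recall=0.500 f1=0.571
```

With more than two classes, the same calculation is repeated once per class, and a macro average is the unweighted mean of the per-class scores.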

### Shape Expectations
- **Logits**: `[batch, n_aspects]` → raw model outputs
- **Predictions**: `[batch]` → argmax of logits
- **Labels**: `[batch]` → true class indices
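
The shapes above can be checked on a hand-written tensor. The sizes (`batch = 4`, `n_aspects = 3`) and the logit values are arbitrary choices for this sketch.

```python
import torch

# Hypothetical batch of 4 examples with 3 aspect classes.
logits = torch.tensor([[ 2.0,  0.5, -1.0],
                       [ 0.1,  0.2,  3.0],
                       [-0.5,  1.5,  1.0],
                       [ 0.0, -2.0, -1.0]])   # shape [4, 3]
labels = torch.tensor([0, 2, 2, 0])           # shape [4]

# argmax over the class dimension collapses [batch, n_aspects] to [batch].
predictions = torch.argmax(logits, dim=1)

# Softmax is monotonic, so it does not change which class has the largest score.
same_after_softmax = torch.equal(
    predictions, torch.argmax(torch.softmax(logits, dim=1), dim=1)
)

batch_accuracy = (predictions == labels).float().mean().item()
print(logits.shape, predictions.shape, predictions.tolist(), batch_accuracy)
```

Here the third example is misclassified (predicted class 1, true class 2), so the batch accuracy is 0.75.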


## 🔧 TODO #1: Implement Training Loop

**Task:** Create a training loop that iterates through the batches and updates the model parameters.

**Hint:** Call `model.train()` once, then for each batch run `optimizer.zero_grad()`, the forward pass, the loss computation, `loss.backward()`, and `optimizer.step()`.

**Expected Function:**
```python
def train_model(model, train_loader, criterion, optimizer, num_epochs=20):
    for epoch in range(num_epochs):
        total_loss = 0
        # TODO: Training loop here
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
```

**Track:** Accumulate `loss.item()` over the batches of each epoch, so the printed value is the mean batch loss and you can monitor training progress.


In [None]:
# TODO #1: Implement training loop
# Your code here


## 🔧 TODO #2: Implement Evaluation Function

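**Task:** Create an evaluation function that computes predictions on the test data so metrics can be calculated from them.

**Hint:** Use `model.eval()` and `with torch.no_grad():`, then collect predictions with `torch.argmax(logits, dim=1)`. Applying `torch.softmax` first is optional: softmax preserves the ordering of the logits, so the argmax is the same.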

**Expected Function:**
```python
def evaluate_model(model, test_loader):
    model.eval()
    all_predictions = []
    all_labels = []
    # TODO: Evaluation loop here
    return all_predictions, all_labels
```

**Shape:** Predictions and labels should both be flat lists of integers with the same length (one entry per test example).
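
One way to end up with flat lists of integers is to convert each batch tensor with `.tolist()` and `extend` the running list. The two small tensors below are hypothetical per-batch predictions, not output from your model.

```python
import torch

# Hypothetical per-batch predictions, as produced by argmax on each batch of logits.
batch_predictions = [torch.tensor([0, 2, 1]), torch.tensor([1, 1])]

all_predictions = []
for preds in batch_predictions:
    # .tolist() turns the tensor into plain Python ints;
    # extend() keeps the result flat instead of nesting one list per batch.
    all_predictions.extend(preds.tolist())

print(all_predictions)  # [0, 2, 1, 1, 1]
```

Using `append` instead of `extend` would produce a list of lists, which `sklearn.metrics` functions would not interpret as one label per example.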


In [None]:
# TODO #2: Implement evaluation function
# Your code here


## 🔧 TODO #3: Compute Metrics and Analysis

**Task:** Train the model, evaluate it on the test set, and compute classification metrics.

**Hint:** Use `accuracy_score(y_true, y_pred)`, `classification_report(y_true, y_pred)`, `confusion_matrix(y_true, y_pred)`

**Expected Output:**
- Train model for 20-50 epochs
- Print accuracy, precision, recall, F1-score
- Show confusion matrix
- Write 1-2 paragraph interpretation

**Sample Output:**
```
Accuracy: 0.65
F1-Score (macro): 0.63
Classification Report:
              precision    recall  f1-score   support
           0       0.68      0.62      0.65      1000
           1       0.63      0.69      0.66      1000
```
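
As a rough guide to how these numbers are produced, the sketch below runs the scikit-learn metric functions on made-up labels for three classes. It also uses `f1_score(..., average="macro")`, which is not mentioned in the hint but is one way to obtain a macro F1 value like the one in the sample output.

```python
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

# Made-up labels for three classes; replace with the lists from evaluate_model.
y_true = [0, 0, 0, 1, 1, 1, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 2, 2]

accuracy = accuracy_score(y_true, y_pred)
macro_f1 = f1_score(y_true, y_pred, average="macro")  # unweighted mean of per-class F1
cm = confusion_matrix(y_true, y_pred)                 # rows = true class, columns = predicted

print(f"Accuracy: {accuracy:.2f}")
print(f"F1-Score (macro): {macro_f1:.2f}")
print(cm)
```

In the confusion matrix, off-diagonal entries show which true class (row) was mistaken for which predicted class (column); here one class-0 example was predicted as class 1 and one class-1 example as class 2.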


In [None]:
# TODO #3: Compute metrics and analysis
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Your code here


## 📝 Reflection Prompts

### 🤔 Understanding Check
1. **Which aspects does the baseline confuse most?** Look at the confusion matrix. What patterns do you see?

2. **Why might the training loss keep decreasing while accuracy plateaus?** What does this tell you about the model's learning?

3. **How does the baseline performance compare to random guessing?** Is your model actually learning?

4. **What would you expect to improve with a transformer model?** Consider vocabulary coverage and context understanding.
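
For the comparison with random guessing in question 3, it helps to have concrete reference numbers. The sketch below computes two simple baselines from a list of labels; the helper name `chance_baselines` and the class counts are hypothetical.

```python
from collections import Counter


def chance_baselines(labels):
    """Return (uniform_guess_accuracy, majority_class_accuracy) for a label list."""
    counts = Counter(labels)
    # Guessing uniformly at random is right 1/n_classes of the time on average.
    uniform = 1.0 / len(counts)
    # Always predicting the most frequent class is right as often as that class occurs.
    majority = max(counts.values()) / len(labels)
    return uniform, majority


# Hypothetical imbalanced test set with three classes.
labels = [0] * 50 + [1] * 30 + [2] * 20
uniform, majority = chance_baselines(labels)
print(f"uniform guess: {uniform:.3f}, majority class: {majority:.3f}")
# uniform guess: 0.333, majority class: 0.500
```

On imbalanced data the majority-class baseline is usually the stricter test: a model that scores below it has not clearly learned more than the class frequencies.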

### 🎯 Baseline Analysis
- What are the model's strengths and weaknesses?
- Which aspect classes are easiest/hardest to predict?
- How does the simple architecture limit performance?

---

**Write your reflections here:**
