In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy dataset for the demo: 10 random samples with 5 features each,
# plus one binary class label (0 or 1) per sample.
x = torch.randn((10, 5))
y = torch.randint(low=0, high=2, size=(10,))

In [None]:
class SimpleNN(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Linear.

    Maps a batch of shape ``(batch, in_features)`` to raw class logits of
    shape ``(batch, num_classes)``. No softmax is applied here, because
    ``nn.CrossEntropyLoss`` expects unnormalised logits.

    Args:
        in_features: Number of features per input sample (default 5).
        hidden_features: Width of the hidden layer (default 3).
        num_classes: Number of output logits / classes (default 2).
    """

    def __init__(
        self,
        in_features: int = 5,
        hidden_features: int = 3,
        num_classes: int = 2,
    ) -> None:
        super().__init__()
        # Kept so callers can fail fast on labels >= num_classes before the loss.
        self.num_classes = num_classes
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)  # Output logits per class

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        x = self.fc1(x)
        x = F.relu(x)
        return self.fc2(x)


In [None]:
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Basic resourcefulness: set to True to print shapes and sample values
# right before the loss is computed (instead of keeping commented-out prints).
SHOW_DEBUG_INFO = False

for epoch in range(1):  # One epoch for demo
    optimizer.zero_grad()
    out = model(x)

    if SHOW_DEBUG_INFO:
        print("Model output shape:", out.shape)  # Expected: (10, 2)
        print("Target shape:", y.shape)          # Expected: (10,)
        print("Output (first 2):", out[:2])
        print("Target (first 2):", y[:2])

    try:
        loss = criterion(out, y)
    except (RuntimeError, ValueError, IndexError) as e:
        # CrossEntropyLoss can raise ValueError (input/target batch-size
        # mismatch) or IndexError (target class out of range on CPU), not
        # only RuntimeError, so all three are caught here.
        print(f"❌ {type(e).__name__} during loss computation:")
        print(e)
        breakpoint()  # Interactive debug; disable with PYTHONBREAKPOINT=0
        raise  # Fail fast: don't silently continue after a broken loss
    else:
        loss.backward()
        optimizer.step()
        print("✅ Step completed. Loss:", loss.item())


In [None]:
def test_model_output_shape():
    """Check that a fresh SimpleNN maps a (batch, 5) input to (batch, 2) logits."""
    batch_size = 4
    test_x = torch.randn(batch_size, 5)
    test_model = SimpleNN()
    with torch.no_grad():  # Shape check only; no autograd graph needed.
        out = test_model(test_x)
    # Include the actual shape so a failure is diagnosable without re-running.
    assert out.shape == (batch_size, 2), (
        f"Model output shape should be (batch_size, 2), got {tuple(out.shape)}"
    )


test_model_output_shape()
print("✅ Shape test passed!")

<div align="center">  
  <img src="../asset/fail%20early%20fail%20fast%20fail%20often.webp" >  
  <br><br>
</div>


Fail-fast debugging is about **catching errors as close to their source as possible**, instead of letting them cascade into confusing tracebacks.

#### How to Implement Fail Fast:

* **Assert assumptions immediately:**

  ```python
  assert x.dim() == 2, "Expected 2D tensor for input x"
  assert y.max() < num_classes, "Target label exceeds number of classes"
  ```

* **Check for `NaN`, `inf`, or exploding values early:**

  ```python
  if torch.isnan(x).any():
      raise ValueError("Found NaN in input tensor x")
  if x.abs().max() > 1e6:
      raise ValueError("Input values exploding: check your initialization")
  ```

* **Validate function inputs:**

  ```python
  def train_step(x, y):
      if not isinstance(x, torch.Tensor):
          raise TypeError("Expected torch.Tensor for input x")
  ```

* **Use dummy batches before full training:**

  ```python
  dummy_x = torch.randn(2, 5)
  # Labels must lie in [0, num_classes); num_classes must match the model's output size
  dummy_y = torch.randint(0, num_classes, (2,))
  loss = F.cross_entropy(model(dummy_x), dummy_y)
  print(f"Dummy loss: {loss.item():.4f}")
  ```

* **Write helper test cases**:
  Create mini test functions to verify key behavior:

  ```python
  def test_shapes():
      out = model(torch.randn(4, 5))
      assert out.shape == (4, num_classes), f"Unexpected output shape: {out.shape}"
  ```



## How to Look Things Up

### Read the **Traceback** Carefully
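The last line shows the exception type and message; the frame just above it shows the file and line where the error was raised. Python prints the call stack with the most recent call last, so the top of the traceback is where execution **started** and the bottom is where it **crashed**. Look for messages like `Expected input shape (N, C, H, W) but got (N, D)` to understand shape mismatches.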

### Use **Google + Stack Overflow + ChatGPT** Effectively
- Copy and paste the error message (or its most specific part) into Google or Stack Overflow, adding keywords like `"PyTorch"`, `"CrossEntropyLoss"` or `"tensor shape"` to narrow the results. [Discuss PyTorch](https://discuss.pytorch.org) is a great place for subtle issues.