# Defensive Programming

**Teaching:** 30 min  
**Exercises:** 10 min

## Learning Objectives

- Explain what an assertion is.
- Add assertions that check the program's state is correct.
- Correctly add precondition and postcondition assertions to functions.
- Explain what test-driven development is, and use it when creating new functions.
- Explain why variables should be initialized using actual data values rather than arbitrary constants.

## Questions

- How can I make my programs more reliable?

---

Our previous lessons have introduced the basic tools of programming: variables and lists, file I/O, loops, conditionals, and functions. What they *haven't* done is show us how to tell whether a program is getting the right answer, and how to tell if it's *still* getting the right answer as we make changes to it.

To achieve that, we need to:
- Write programs that check their own operation.
- Write and run tests for widely-used functions.
- Make sure we know what "correct" actually means.

The good news is, doing these things will speed up our programming, not slow it down. As in real carpentry — the kind done with lumber — the time saved by measuring carefully before cutting a piece of wood is much greater than the time that measuring takes.

## Setup

Let's start by importing the libraries we'll need and loading our inflammation data:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob

In [None]:
# Load our inflammation data
data = np.loadtxt('../data/inflammation-01.csv', delimiter=',')
print(f"Data shape: {data.shape}")
print(f"Data type: {data.dtype}")

## Assertions

The first step toward getting the right answers from our programs is to assume that mistakes *will* happen and to guard against them. This is called **defensive programming**, and the most common way to do it is to add **assertions** to our code so that it checks itself as it runs.

An assertion is simply a statement that something must be true at a certain point in a program. When Python sees one, it evaluates the assertion's condition. If it's true, Python does nothing. If it's false, Python raises an `AssertionError`, which halts the program immediately (unless the error is caught) and prints the error message if one is provided.

In [None]:
# Example: check that all values are non-negative
# (this cell stops with an AssertionError when it reaches -0.001)
numbers = [1.5, 2.3, 0.7, -0.001, 4.4]
total = 0.0
for num in numbers:
    assert num >= 0.0, 'Data should only contain non-negative values'
    total += num
print('total is:', total)

In [None]:
# Let's fix the data and try again
numbers = [1.5, 2.3, 0.7, 0.001, 4.4]  # Fixed negative value
total = 0.0
for num in numbers:
    assert num >= 0.0, 'Data should only contain non-negative values'
    total += num
print('total is:', total)

### Types of Assertions

Broadly speaking, assertions fall into three categories:

- A **precondition** is something that must be true at the start of a function in order for it to work correctly.
- A **postcondition** is something that the function guarantees is true when it finishes.
- An **invariant** is something that is always true at a particular point inside a piece of code.
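
The first two categories appear in the functions that follow, but an invariant is easier to picture with a small loop. The sketch below is illustrative only: `running_totals` is a hypothetical helper, not part of the inflammation analysis, and it marks all three kinds of assertion in one place.

```python
def running_totals(values):
    """Return the cumulative sums of a list of non-negative numbers."""
    # Precondition: the input must be what the function expects.
    assert all(v >= 0 for v in values), 'Values must be non-negative'

    totals = []
    current = 0
    for v in values:
        current += v
        # Invariant: with non-negative inputs, the running total never decreases.
        assert not totals or current >= totals[-1], 'Running total decreased'
        totals.append(current)

    # Postcondition: one cumulative sum per input value.
    assert len(totals) == len(values), 'Output length must match input length'
    return totals


print(running_totals([1, 2, 0, 4]))  # [1, 3, 3, 7]
```

The invariant sits inside the loop because it describes something that should hold on every pass, not just at the start or the end.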

Let's create a function to analyze our inflammation data with proper assertions:

In [None]:
def analyze_inflammation(data):
    """Analyze inflammation data and return statistics.
    
    Parameters:
    data: 2D numpy array of inflammation measurements
    
    Returns:
    dict: Dictionary containing mean, max, and min values
    """
    # Preconditions
    assert isinstance(data, np.ndarray), 'Input must be a numpy array'
    assert data.ndim == 2, 'Data must be 2-dimensional'
    assert data.size > 0, 'Data cannot be empty'
    assert np.all(data >= 0), 'All inflammation values must be non-negative'
    
    # Perform calculations
    mean_inflammation = np.mean(data)
    max_inflammation = np.max(data)
    min_inflammation = np.min(data)
    
    # Create result dictionary
    result = {
        'mean': mean_inflammation,
        'max': max_inflammation,
        'min': min_inflammation
    }
    
    # Postconditions
    assert min_inflammation <= mean_inflammation <= max_inflammation, \
        'Mean should be between min and max values'
    assert len(result) == 3, 'Result should contain exactly 3 statistics'
    
    return result

In [None]:
# Test our function with valid data
stats = analyze_inflammation(data)
print(f"Inflammation statistics: {stats}")

In [None]:
# Test with invalid data (this should trigger an assertion)
try:
    invalid_data = np.array([[1, 2], [-1, 4]])  # 2D array containing a negative value
    stats = analyze_inflammation(invalid_data)
except AssertionError as e:
    print(f"Caught assertion error: {e}")

### Rectangle Normalization Example

Let's look at a more complex example. Suppose we represent each rectangle as a tuple of four coordinates `(x0, y0, x1, y1)`, giving its lower left and upper right corners. We want to normalize the rectangle so that its lower left corner is at the origin and its longest side is 1.0 units long:

In [None]:
def normalize_rectangle(rect):
    """Normalizes a rectangle so that it is at the origin and 1.0 units long on its longest axis.
    Input should be of the format (x0, y0, x1, y1).
    (x0, y0) and (x1, y1) define the lower left and upper right corners
    of the rectangle, respectively."""
    
    # Preconditions
    assert len(rect) == 4, 'Rectangles must contain 4 coordinates'
    x0, y0, x1, y1 = rect
    assert x0 < x1, 'Invalid X coordinates'
    assert y0 < y1, 'Invalid Y coordinates'

    dx = x1 - x0
    dy = y1 - y0
    if dx > dy:
        scaled = dy / dx  # Short side divided by long side, so the result is at most 1.0
        upper_x, upper_y = 1.0, scaled
    else:
        scaled = dx / dy
        upper_x, upper_y = scaled, 1.0

    # Postconditions
    assert 0 < upper_x <= 1.0, 'Calculated upper X coordinate invalid'
    assert 0 < upper_y <= 1.0, 'Calculated upper Y coordinate invalid'

    return (0, 0, upper_x, upper_y)

In [None]:
# Test the rectangle normalization
print("Normalizing a tall rectangle:")
print(normalize_rectangle((0.0, 0.0, 1.0, 5.0)))

print("\nNormalizing a wide rectangle:")
print(normalize_rectangle((0.0, 0.0, 5.0, 1.0)))

In [None]:
# Test with invalid input
try:
    print(normalize_rectangle((0.0, 1.0, 2.0)))  # missing the fourth coordinate
except AssertionError as e:
    print(f"Caught assertion error: {e}")

try:
    print(normalize_rectangle((4.0, 2.0, 1.0, 5.0)))  # X axis inverted
except AssertionError as e:
    print(f"Caught assertion error: {e}")
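
To see what the postconditions add, here is a sketch of a hypothetical variant, `normalize_rectangle_buggy`, in which the wide-rectangle branch divides the wrong way round. The name and the bug are assumptions made for illustration. The rest of the logic is copied from the function above so the cell runs on its own.

```python
def normalize_rectangle_buggy(rect):
    """Hypothetical buggy variant: the wide-rectangle branch divides the wrong way."""
    # Preconditions (same as the correct version)
    assert len(rect) == 4, 'Rectangles must contain 4 coordinates'
    x0, y0, x1, y1 = rect
    assert x0 < x1, 'Invalid X coordinates'
    assert y0 < y1, 'Invalid Y coordinates'

    dx = x1 - x0
    dy = y1 - y0
    if dx > dy:
        scaled = dx / dy  # Bug: should be dy / dx
        upper_x, upper_y = 1.0, scaled
    else:
        scaled = dx / dy
        upper_x, upper_y = scaled, 1.0

    # Postconditions (same as the correct version)
    assert 0 < upper_x <= 1.0, 'Calculated upper X coordinate invalid'
    assert 0 < upper_y <= 1.0, 'Calculated upper Y coordinate invalid'
    return (0, 0, upper_x, upper_y)


# The tall rectangle takes the branch that is still correct...
print(normalize_rectangle_buggy((0.0, 0.0, 1.0, 5.0)))

# ...but the wide one trips a postcondition instead of returning a wrong answer.
try:
    normalize_rectangle_buggy((0.0, 0.0, 5.0, 1.0))
except AssertionError as e:
    print(f"Caught assertion error: {e}")
```

Without the postconditions, the second call would quietly return a rectangle five units tall. With them, the mistake surfaces at the line where it was made.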

## Test-Driven Development

An assertion checks that something is true at a particular point in the program. The next step is to check the overall behavior of a piece of code, i.e., to make sure that it produces the right output when it's given a particular input.

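Suppose we need a function, `range_overlap`, that finds where several ranges overlap (the example below covers the details). Most novice programmers would tackle it like this: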
1. Write a function `range_overlap`.
2. Call it interactively on two or three different inputs.
3. If it produces the wrong answer, fix the function and re-run that test.

There's a better way:
1. Write a short function for each test.
2. Write a `range_overlap` function that should pass those tests.
3. If `range_overlap` produces any wrong answers, fix it and re-run the test functions.

Writing the tests *before* writing the function they exercise is called **test-driven development** (TDD).

### Example: Range Overlap Function

Suppose we need to find where two or more time series overlap. The range of each time series is represented as a pair of numbers: the times at which the series started and ended. The output is the largest range that all of them include.

Let's start by defining an empty function:

In [None]:
def range_overlap(ranges):
    """Return common overlap among a set of [left, right] ranges."""
    pass

Now let's write some tests *before* implementing the function:

In [None]:
# These should fail initially since our function doesn't do anything yet
try:
    assert range_overlap([(0.0, 1.0)]) == (0.0, 1.0)
    print("Test 1 passed")
except AssertionError:
    print("Test 1 failed (expected for empty function)")

try:
    assert range_overlap([(2.0, 3.0), (2.0, 4.0)]) == (2.0, 3.0)
    print("Test 2 passed")
except AssertionError:
    print("Test 2 failed (expected for empty function)")

try:
    assert range_overlap([(0.0, 1.0), (0.0, 2.0), (-1.0, 1.0)]) == (0.0, 1.0)
    print("Test 3 passed")
except AssertionError:
    print("Test 3 failed (expected for empty function)")

We also need to decide what to do when ranges don't overlap. We decide that:
1. Every overlap has to have non-zero width
2. We will return the special value `None` when there's no overlap

Let's add tests for edge cases:

In [None]:
# Test edge cases
# Note: the placeholder function returns None, so these checks pass trivially for now.
# They only become meaningful once range_overlap is actually implemented.
try:
    assert range_overlap([(0.0, 1.0), (5.0, 6.0)]) is None  # No overlap
    print("Edge test 1 passed")
except AssertionError:
    print("Edge test 1 failed")

try:
    assert range_overlap([(0.0, 1.0), (1.0, 2.0)]) is None  # Touch at endpoints
    print("Edge test 2 passed")
except AssertionError:
    print("Edge test 2 failed")

try:
    assert range_overlap([]) is None  # Empty input
    print("Edge test 3 passed")
except AssertionError:
    print("Edge test 3 failed")

Now let's implement the function properly:

In [None]:
def range_overlap(ranges):
    """Return common overlap among a set of [left, right] ranges."""
    if not ranges:  # Handle empty input
        return None
    
    # Initialize from the first range
    max_left = ranges[0][0]
    min_right = ranges[0][1]
    
    for (left, right) in ranges:
        max_left = max(max_left, left)
        min_right = min(min_right, right)
    
    # Check if there's actually an overlap
    if max_left >= min_right:
        return None
    
    return (max_left, min_right)

Let's create a comprehensive test function:

In [None]:
def test_range_overlap():
    """Test the range_overlap function with various inputs."""
    assert range_overlap([(0.0, 1.0), (5.0, 6.0)]) is None
    assert range_overlap([(0.0, 1.0), (1.0, 2.0)]) is None
    assert range_overlap([(0.0, 1.0)]) == (0.0, 1.0)
    assert range_overlap([(2.0, 3.0), (2.0, 4.0)]) == (2.0, 3.0)
    assert range_overlap([(0.0, 1.0), (0.0, 2.0), (-1.0, 1.0)]) == (0.0, 1.0)
    assert range_overlap([]) is None
    print("All tests passed!")

# Run all tests
test_range_overlap()
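
A single test function stops at its first failing assertion, so later checks never run. One way to follow the "short function for each test" advice is sketched below. The individual test names and the `run_tests` helper are hypothetical, and `range_overlap` is repeated only so the cell runs on its own. Test runners such as pytest collect functions named `test_*` automatically and would make the helper unnecessary.

```python
def range_overlap(ranges):
    """Same logic as the implementation above, repeated so this cell is self-contained."""
    if not ranges:
        return None
    max_left, min_right = ranges[0]
    for left, right in ranges:
        max_left = max(max_left, left)
        min_right = min(min_right, right)
    if max_left >= min_right:
        return None
    return (max_left, min_right)


def test_single_range():
    assert range_overlap([(0.0, 1.0)]) == (0.0, 1.0)

def test_two_ranges():
    assert range_overlap([(2.0, 3.0), (2.0, 4.0)]) == (2.0, 3.0)

def test_no_overlap():
    assert range_overlap([(0.0, 1.0), (5.0, 6.0)]) is None

def test_touching_endpoints():
    assert range_overlap([(0.0, 1.0), (1.0, 2.0)]) is None

def test_empty_input():
    assert range_overlap([]) is None


def run_tests(tests):
    """Hypothetical helper: run every test and return the names of those that failed."""
    failed = []
    for test in tests:
        try:
            test()
        except AssertionError:
            failed.append(test.__name__)
    return failed


failed = run_tests([test_single_range, test_two_ranges, test_no_overlap,
                    test_touching_endpoints, test_empty_input])
print(f"{len(failed)} test(s) failed: {failed}")
```

Because each check lives in its own function, a failure report names exactly which behavior broke, and the remaining tests still run.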

## Applying Defensive Programming to Inflammation Data

Let's create a robust function to analyze inflammation data across multiple files:

In [None]:
def analyze_inflammation_files(filenames):
    """Analyze inflammation data from multiple files with defensive programming."""
    # Preconditions
    assert isinstance(filenames, list), 'Filenames must be provided as a list'
    assert len(filenames) > 0, 'At least one filename must be provided'
    
    results = []
    
    for filename in filenames:
        assert isinstance(filename, str), f'Filename must be a string: {filename}'
        
        try:
            data = np.loadtxt(filename, delimiter=',')
        except IOError:
            print(f"Warning: Could not read file {filename}")
            continue
        except ValueError:
            print(f"Warning: Invalid data format in file {filename}")
            continue
        
        # Validate data
        assert data.ndim == 2, f'Data in {filename} must be 2-dimensional'
        assert data.size > 0, f'Data in {filename} cannot be empty'
        
        # Check for reasonable inflammation values
        if np.any(data < 0):
            print(f"Warning: Negative values found in {filename}")
        if np.any(data > 20):
            print(f"Warning: Unusually high inflammation values in {filename}")
        
        # Calculate statistics
        file_stats = {
            'filename': filename,
            'shape': data.shape,
            'mean': np.mean(data),
            'max': np.max(data),
            'min': np.min(data),
            'std': np.std(data)
        }
        
        # Postcondition: verify statistics make sense
        assert file_stats['min'] <= file_stats['mean'] <= file_stats['max'], \
            f'Invalid statistics for {filename}'
        assert file_stats['std'] >= 0, f'Standard deviation cannot be negative for {filename}'
        
        results.append(file_stats)
    
    # Final postcondition
    assert len(results) > 0, 'No valid files were processed'
    
    return results

In [None]:
# Test with our inflammation data files
inflammation_files = sorted(glob.glob('../data/inflammation-*.csv'))[:3]  # First 3 files, in name order
print(f"Analyzing files: {inflammation_files}")

results = analyze_inflammation_files(inflammation_files)

for result in results:
    print(f"\nFile: {result['filename']}")
    print(f"  Shape: {result['shape']}")
    print(f"  Mean: {result['mean']:.2f}")
    print(f"  Range: {result['min']:.2f} - {result['max']:.2f}")
    print(f"  Std Dev: {result['std']:.2f}")

## Exercise: Creating Defensive Functions

Let's practice creating defensive functions for common data analysis tasks:

In [None]:
def safe_average(values):
    """Calculate average with defensive programming.
    
    Your task: Add appropriate preconditions and postconditions
    """
    # TODO: Add preconditions here
    
    result = sum(values) / len(values)
    
    # TODO: Add postconditions here
    
    return result

# Test the function
test_data = [1, 2, 3, 4, 5]
print(f"Average of {test_data}: {safe_average(test_data)}")

### Solution:

In [None]:
def safe_average(values):
    """Calculate average with defensive programming."""
    # Preconditions
    assert len(values) > 0, 'Cannot calculate average of empty sequence'
    assert all(isinstance(x, (int, float)) for x in values), 'All values must be numeric'
    
    result = sum(values) / len(values)
    
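    # Postconditions
    # Allow a tiny tolerance: floating-point rounding can push the mean of
    # near-identical values just outside [min, max], e.g. for [0.1, 0.1, 0.1].
    tolerance = 1e-9 * max(1.0, abs(result))
    assert min(values) - tolerance <= result <= max(values) + tolerance, \
        'Average should be between min and max values'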
    
    return result

# Test the function
test_data = [1, 2, 3, 4, 5]
print(f"Average of {test_data}: {safe_average(test_data)}")

# Test with invalid data
try:
    safe_average([])
except AssertionError as e:
    print(f"Caught expected error: {e}")

try:
    safe_average([1, 2, 'three'])
except AssertionError as e:
    print(f"Caught expected error: {e}")

## Best Practices for Defensive Programming

1. **Fail early, fail often**: The greater the distance between when and where an error occurs and when it's noticed, the harder the error will be to debug.

2. **Turn bugs into assertions or tests**: Whenever you fix a bug, write an assertion that catches the mistake should you make it again.

3. **Initialize from data**: Always initialize variables using actual data values rather than arbitrary constants.

4. **Write tests first**: Test-driven development helps you think about what your function should actually do.

5. **Document your assumptions**: Use assertions to make your assumptions about data explicit.
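
The third point is easy to state but harder to appreciate without seeing it go wrong. The sketch below uses two hypothetical helpers, `largest_bad` and `largest_good`, on made-up readings. The only difference between them is where the starting value comes from.

```python
def largest_bad(values):
    """Hypothetical buggy version: starts from an arbitrary constant."""
    largest = 0.0  # Arbitrary starting point; silently wrong if every value is negative
    for v in values:
        largest = max(largest, v)
    return largest


def largest_good(values):
    """Starts from the first data value, so the result is always one of the inputs."""
    assert len(values) > 0, 'Need at least one value'
    largest = values[0]
    for v in values:
        largest = max(largest, v)
    return largest


readings = [-3.5, -1.2, -7.0]
print(largest_bad(readings))   # 0.0, which is not even in the data
print(largest_good(readings))  # -1.2
```

This is the same reason `range_overlap` initializes `max_left` and `min_right` from `ranges[0]` and not from fixed numbers.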

## Exercise: Defensive Data Quality Check

Create a function that performs comprehensive quality checks on inflammation data:

In [None]:
def check_data_quality(data, filename="unknown"):
    """Perform comprehensive quality checks on inflammation data.
    
    Returns a dictionary with quality metrics and warnings.
    """
    quality_report = {
        'filename': filename,
        'valid': True,
        'warnings': [],
        'errors': []
    }
    
    # Add your quality checks here
    # Check for:
    # - Correct data type and shape
    # - Missing values (NaN)
    # - Negative values
    # - Outliers (values > 20)
    # - Suspicious patterns (all zeros, all same value)
    
    return quality_report

# Test with our data
quality = check_data_quality(data, "inflammation-01.csv")
print(quality)

### Solution:

In [None]:
def check_data_quality(data, filename="unknown"):
    """Perform comprehensive quality checks on inflammation data."""
    quality_report = {
        'filename': filename,
        'valid': True,
        'warnings': [],
        'errors': []
    }
    
    try:
        # Check data type and basic structure
        if not isinstance(data, np.ndarray):
            quality_report['errors'].append('Data is not a numpy array')
            quality_report['valid'] = False
            return quality_report
        
        if data.ndim != 2:
            quality_report['errors'].append(f'Data should be 2D, got {data.ndim}D')
            quality_report['valid'] = False
        
        if data.size == 0:
            quality_report['errors'].append('Data is empty')
            quality_report['valid'] = False
            return quality_report
        
        # Check for missing values
        if np.any(np.isnan(data)):
            nan_count = np.sum(np.isnan(data))
            quality_report['warnings'].append(f'Found {nan_count} missing values (NaN)')
        
        # Check for negative values
        if np.any(data < 0):
            neg_count = np.sum(data < 0)
            quality_report['warnings'].append(f'Found {neg_count} negative values')
        
        # Check for outliers
        if np.any(data > 20):
            outlier_count = np.sum(data > 20)
            max_val = np.max(data)
            quality_report['warnings'].append(f'Found {outlier_count} outliers (>20), max={max_val:.2f}')
        
        # Check for suspicious patterns
        if np.all(data == 0):
            quality_report['warnings'].append('All values are zero')
        
        unique_values = len(np.unique(data))
        total_values = data.size
        if unique_values == 1:
            quality_report['warnings'].append('All values are identical')
        elif unique_values < total_values * 0.1:  # Less than 10% unique values (rough heuristic; integer-valued data can trigger it)
            quality_report['warnings'].append(f'Low diversity: only {unique_values} unique values')
        
        # Add summary statistics
        quality_report['stats'] = {
            'shape': data.shape,
            'mean': float(np.mean(data)),
            'std': float(np.std(data)),
            'min': float(np.min(data)),
            'max': float(np.max(data)),
            'unique_values': unique_values
        }
        
    except Exception as e:
        quality_report['errors'].append(f'Unexpected error: {str(e)}')
        quality_report['valid'] = False
    
    return quality_report

# Test with our data
quality = check_data_quality(data, "inflammation-01.csv")
print(f"Data quality report for {quality['filename']}:")
print(f"Valid: {quality['valid']}")
print(f"Warnings: {len(quality['warnings'])}")
for warning in quality['warnings']:
    print(f"  - {warning}")
print(f"Errors: {len(quality['errors'])}")
for error in quality['errors']:
    print(f"  - {error}")
if 'stats' in quality:
    print(f"Stats: {quality['stats']}")

## Summary

In this lesson, we learned about defensive programming techniques:

1. **Assertions** help catch errors early and document assumptions about our code
2. **Preconditions** check that function inputs are valid
3. **Postconditions** verify that function outputs are correct
4. **Test-driven development** involves writing tests before implementing functions
5. **Defensive programming** makes our code more robust and easier to debug

### Key Points

- Program defensively: assume that errors are going to arise, and write code to detect them when they do
- Put assertions in programs to check their state as they run, and to help readers understand how those programs are supposed to work
- Use preconditions to check that the inputs to a function are safe to use
- Use postconditions to check that the output from a function is safe to use
- Write tests before writing code in order to help determine exactly what that code is supposed to do

Defensive programming takes a little extra time up front, but it saves much more time in the long run by catching errors early and making code more reliable and maintainable.