# Grouped/Gridded Data Fitting in nbragg

This tutorial demonstrates how to analyze spatially-resolved or multi-sample neutron transmission data using nbragg's grouped fitting capabilities.

## Table of Contents
1. [Introduction](#introduction)
2. [Loading Grouped Data](#loading)
3. [Fitting Grouped Data](#fitting)
4. [Visualizing Results](#visualization)
5. [Accessing Individual Results](#individual)
6. [Saving and Loading](#saving)
7. [Advanced Features](#advanced)

In [None]:
import nbragg
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import tempfile

## 1. Introduction <a name="introduction"></a>

nbragg supports three types of grouped data:

- **2D Grids**: Spatially-resolved measurements on a regular grid (e.g., imaging data)
- **1D Arrays**: Linear sequences of measurements (e.g., scan along a line)
- **Named Groups**: Arbitrary collections with custom identifiers (e.g., different samples)

All group types support:
- Parallel fitting with `n_jobs` parameter
- Flexible index access (tuples, integers, strings)
- Parameter mapping and visualization
- Save/load functionality

## 2. Loading Grouped Data <a name="loading"></a>

### Creating Example Data

For this tutorial, we'll create synthetic grouped data with varying thickness across a 2D grid:

In [None]:
# Create temporary directory for example data
tmp_dir = tempfile.mkdtemp()
tmp_path = Path(tmp_dir)
print(f"Creating example data in: {tmp_path}")

# Generate 3x3 grid of measurements with varying properties
for i in range(3):
    for j in range(3):
        # Create synthetic data with position-dependent thickness
        channels = np.arange(100, 300)
        # Thickness varies across the grid
        thickness_variation = 1.0 + 0.2 * i + 0.1 * j
        signal_counts = 1000 - channels * (2 * thickness_variation) + np.random.randint(-20, 20, len(channels))
        signal_counts = np.maximum(signal_counts, 300)
        
        # Write signal file
        signal_file = tmp_path / f"signal_x{i}_y{j}.csv"
        with open(signal_file, 'w') as f:
            f.write("channel,counts\n")
            for ch, cnt in zip(channels, signal_counts):
                f.write(f"{ch},{cnt}\n")
        
        # Write openbeam file
        ob_counts = 1000 + np.random.randint(-10, 10, len(channels))
        ob_file = tmp_path / f"ob_x{i}_y{j}.csv"
        with open(ob_file, 'w') as f:
            f.write("channel,counts\n")
            for ch, cnt in zip(channels, ob_counts):
                f.write(f"{ch},{cnt}\n")

print(f"Created {len(list(tmp_path.glob('signal_*.csv')))} signal files")
print(f"Created {len(list(tmp_path.glob('ob_*.csv')))} openbeam files")
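
Each signal/open-beam pair written above is eventually reduced to a transmission spectrum. The following is a minimal sketch of that reduction, assuming independent Poisson counts in both measurements. It is not nbragg's internal code, and `transmission_with_error` is a hypothetical helper introduced only for illustration.

```python
import numpy as np

def transmission_with_error(signal, openbeam):
    """Return T = signal / openbeam and its Poisson-propagated uncertainty."""
    signal = np.asarray(signal, dtype=float)
    openbeam = np.asarray(openbeam, dtype=float)
    trans = signal / openbeam
    # For independent Poisson counts: (sigma_T / T)^2 = 1/signal + 1/openbeam
    err = trans * np.sqrt(1.0 / signal + 1.0 / openbeam)
    return trans, err

trans, err = transmission_with_error([800, 600, 400], [1000, 1000, 1000])
print(trans)
print(err)
```

The relative uncertainty grows as the counts drop, which is why strongly attenuated regions of a spectrum tend to be noisier.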

### Loading with Glob Patterns

The most common way to load grouped data is using glob patterns:

In [None]:
# Load grouped data using glob patterns
data = nbragg.Data.from_grouped(
    signal=str(tmp_path / "signal_*.csv"),
    openbeam=str(tmp_path / "ob_*.csv"),
    L=10,  # flight-path length in meters (used for the TOF-to-wavelength conversion)
    tstep=10e-6,  # time step in seconds
    verbosity=1
)

print(f"\nLoaded grouped data:")
print(f"  Number of groups: {len(data.indices)}")
print(f"  Group shape: {data.group_shape}")
print(f"  Indices: {data.indices[:5]}...")  # Show first 5
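
The `L` and `tstep` arguments set the wavelength axis of the loaded data. As a rough sketch of the physics, the de Broglie relation gives wavelength = h·t / (m·L). The helper below assumes that time of flight equals channel number times `tstep` with no time offset, which may not match nbragg's exact convention; `channel_to_wavelength` is a hypothetical name.

```python
import numpy as np

H_PLANCK = 6.62607015e-34       # Planck constant, J s
M_NEUTRON = 1.67492749804e-27   # neutron mass, kg

def channel_to_wavelength(channels, tstep, L):
    """Convert TOF channel numbers to wavelength in angstrom.

    Assumes time of flight = channel * tstep with no time offset.
    """
    tof = np.asarray(channels, dtype=float) * tstep    # seconds
    wavelength_m = H_PLANCK * tof / (M_NEUTRON * L)     # lambda = h t / (m L)
    return wavelength_m * 1e10                          # metres -> angstrom

wl = channel_to_wavelength(np.arange(100, 300), tstep=10e-6, L=10)
print(f"{wl.min():.3f} to {wl.max():.3f} angstrom")
```

Under this assumption the synthetic channels 100 to 299 span roughly 0.40 to 1.18 Å, so it is worth checking that any wavelength window passed to a fit overlaps the range your data actually cover.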

### Alternative Loading Methods

You can also load data from:
- **Folders**: `Data.from_grouped(signal="path/to/signal_folder/", openbeam="path/to/ob_folder/")`
- **File lists**: `Data.from_grouped(signal=["file1.csv", "file2.csv"], openbeam=[...])`
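
Explicit file lists are handy when you want to be sure that each signal file is paired with the right open-beam file. The sketch below builds two matched lists by parsing the `x{i}_y{j}` index from the file names used in this tutorial. `paired_file_lists` is a hypothetical helper, and how nbragg matches files internally is not shown here.

```python
import re
import tempfile
from pathlib import Path

def paired_file_lists(folder):
    """Return (signal_files, ob_files) sorted by the (x, y) index in each name."""
    pattern = re.compile(r"_x(\d+)_y(\d+)\.csv$")

    def by_index(prefix):
        found = {}
        for path in Path(folder).glob(f"{prefix}_x*_y*.csv"):
            match = pattern.search(path.name)
            if match:
                found[(int(match.group(1)), int(match.group(2)))] = str(path)
        return found

    signal, ob = by_index("signal"), by_index("ob")
    common = sorted(signal.keys() & ob.keys())   # keep only indices present in both
    return [signal[k] for k in common], [ob[k] for k in common]

# Demo on empty placeholder files
demo_dir = Path(tempfile.mkdtemp())
for i in range(2):
    for j in range(2):
        (demo_dir / f"signal_x{i}_y{j}.csv").touch()
        (demo_dir / f"ob_x{i}_y{j}.csv").touch()

signal_files, ob_files = paired_file_lists(demo_dir)
print(len(signal_files), Path(signal_files[0]).name, Path(ob_files[0]).name)
```

The two lists can then be passed as `signal=signal_files, openbeam=ob_files`, following the file-list form listed above.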

### Visualizing Grouped Data

You can plot individual groups or view the entire dataset:

In [None]:
# Plot a specific group using tuple index
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
data.plot(index=(1, 1), ax=ax)
ax.set_title("Transmission data for group (1, 1)")
plt.tight_layout()
plt.show()

You can also access groups using string indices:

In [None]:
# Both formats work: "(1,1)" or "(1, 1)"
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
data.plot(index="(0,0)", ax=ax1)  # No spaces
data.plot(index="(2, 2)", ax=ax2)  # With spaces
ax1.set_title("Group (0, 0)")
ax2.set_title("Group (2, 2)")
plt.tight_layout()
plt.show()
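
To see why both string spellings can refer to the same group, here is a minimal sketch of how a string index could be normalized to a tuple. This is an illustration only: `normalize_index` is a hypothetical helper, not nbragg's implementation.

```python
import ast

def normalize_index(index):
    """Convert "(1,1)" or "(1, 1)" to the tuple (1, 1); pass other values through."""
    if isinstance(index, str):
        try:
            parsed = ast.literal_eval(index)
        except (ValueError, SyntaxError):
            return index          # a plain name such as "sample_a"
        return parsed if isinstance(parsed, (tuple, int)) else index
    return index

print(normalize_index("(0,0)"), normalize_index("(2, 2)"), normalize_index("sample_a"))
```

`ast.literal_eval` only evaluates Python literals, so it is a safer choice than `eval` for parsing user-supplied index strings.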

## 3. Fitting Grouped Data <a name="fitting"></a>

### Basic Fitting

Fitting grouped data works the same way as fitting a single dataset; the `n_jobs` argument sets how many groups are fitted in parallel:

In [None]:
# Define the model
xs = nbragg.CrossSection(iron=nbragg.materials["Fe_sg229_Iron-alpha"])
model = nbragg.TransmissionModel(xs, vary_basic=True)

print("Model parameters:")
print(model.params)

In [None]:
# Fit all groups in parallel
result = model.fit(
    data,
    n_jobs=4,  # Use 4 parallel workers
    progress_bar=True,  # Show progress
    wlmin=1.5,
    wlmax=5.0
)

print(f"\nFitting complete!")
print(f"Number of fitted groups: {len(result.indices)}")
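
Conceptually, a grouped fit applies the same fit to every group and collects the results by index. The sketch below shows that pattern with a thread pool from the standard library and a straight-line `np.polyfit` as a stand-in for the real model. It is an assumption-laden illustration of the idea, not a description of how nbragg schedules its workers.

```python
from concurrent.futures import ThreadPoolExecutor
import numpy as np

def fit_group(item):
    """Stand-in fit: least-squares straight line through one group's data."""
    index, (x, y) = item
    slope, intercept = np.polyfit(x, y, 1)
    return index, {"slope": slope, "intercept": intercept}

# Toy groups on a 3x3 grid, with a slope that depends on the grid position
x = np.linspace(1.5, 5.0, 50)
groups = {(i, j): (x, (1.0 + 0.2 * i + 0.1 * j) * x + 0.5)
          for i in range(3) for j in range(3)}

# max_workers plays the role of n_jobs in this sketch
with ThreadPoolExecutor(max_workers=4) as pool:
    results = dict(pool.map(fit_group, groups.items()))

print(results[(1, 1)])
```

Because each group is fitted independently, the work splits cleanly across workers and the results can be keyed by the same indices as the input data.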

### Summary Statistics

View a comprehensive summary of all fit results:

In [None]:
# Get summary DataFrame with all parameters and errors
summary_df = result.summary()

# Display first few rows
print("\nSummary DataFrame (first 5 rows):")
print(summary_df.head())

### HTML Report (Jupyter)

In Jupyter notebooks, you can display a formatted HTML summary:

In [None]:
# Display HTML report in Jupyter
from IPython.display import HTML, display
display(HTML(result.fit_report()))

## 4. Visualizing Results <a name="visualization"></a>

### Parameter Maps

Parameter maps show how a fitted parameter varies across groups. nbragg picks the plot type from the data structure (heatmap for 2D grids, line plot for 1D arrays, bar chart for named groups):

In [None]:
# Plot thickness parameter map (auto-detects 2D heatmap for grid data)
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
result.plot_parameter_map("thickness", ax=ax, cmap="viridis")
ax.set_title("Thickness variation across sample")
plt.tight_layout()
plt.show()
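
For grid data, a parameter map amounts to arranging one value per group into a 2D array. The sketch below shows that step with plain NumPy, assuming the first tuple element is the row and the second the column (nbragg's own orientation may differ). `to_grid` is a hypothetical helper.

```python
import numpy as np

def to_grid(values):
    """Arrange {(row, col): value} into a 2D array, NaN where a group is missing."""
    n_rows = max(i for i, _ in values) + 1
    n_cols = max(j for _, j in values) + 1
    grid = np.full((n_rows, n_cols), np.nan)
    for (i, j), value in values.items():
        grid[i, j] = value
    return grid

values = {(i, j): 1.0 + 0.2 * i + 0.1 * j for i in range(3) for j in range(3)}
del values[(2, 0)]          # pretend one group is missing or failed to fit
grid = to_grid(values)
print(grid)
```

An array like `grid` can be passed to `ax.imshow(grid)`; matplotlib leaves NaN cells blank, which makes missing groups easy to spot.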

You can map other parameters, and pass `plot_errors=True` to show a parameter's uncertainties:

In [None]:
# Plot the norm parameter (left) and its uncertainty (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

result.plot_parameter_map("norm", ax=ax1, cmap="plasma")
ax1.set_title("Normalization parameter")

result.plot_parameter_map("norm", plot_errors=True, ax=ax2, cmap="plasma")
ax2.set_title("Normalization errors")

plt.tight_layout()
plt.show()

### Filtering with Queries

You can filter which groups to display using pandas query syntax:

In [None]:
# Plot only groups where the fit succeeded and the reduced chi-square is below 2
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
result.plot_parameter_map(
    "thickness",
    query="success == True and redchi < 2.0",
    ax=ax,
    cmap="coolwarm"
)
ax.set_title("Thickness (filtered: successful fits with reduced χ² < 2)")
plt.tight_layout()
plt.show()
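
The query string uses the same syntax as `pandas.DataFrame.query`. The sketch below applies the same expression to a small hand-made table, so you can test a query before using it for plotting. The column names mirror those in the query above, and the numbers are made up for illustration.

```python
import pandas as pd

# Made-up values standing in for a per-group summary table
summary = pd.DataFrame(
    {
        "thickness": [1.02, 1.31, 0.95, 1.48],
        "redchi": [1.1, 3.4, 0.9, 1.7],
        "success": [True, True, False, True],
    },
    index=["(0, 0)", "(0, 1)", "(1, 0)", "(1, 1)"],
)

selected = summary.query("success == True and redchi < 2.0")
print(selected.index.tolist())
```

Only groups that satisfy both conditions remain: the group with `redchi` of 3.4 and the group with `success` set to `False` are dropped.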

## 5. Accessing Individual Results <a name="individual"></a>

You can access and analyze individual group results just like standard fit results:

In [None]:
# Access result for a specific group
single_result = result[(1, 1)]  # or result["(1,1)"] or result["(1, 1)"]

# stderr can be None when uncertainties could not be estimated, so format defensively
def fmt(par):
    err = f"{par.stderr:.4f}" if par.stderr is not None else "n/a"
    return f"{par.value:.4f} ± {err}"

print("Result for group (1, 1):")
print(f"  Success: {single_result.success}")
print(f"  Reduced χ²: {single_result.redchi:.4f}")
print(f"  Thickness: {fmt(single_result.params['thickness'])}")
print(f"  Norm: {fmt(single_result.params['norm'])}")
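
The reduced χ² reported for each group is the usual goodness-of-fit measure: the sum of squared, uncertainty-weighted residuals divided by the number of degrees of freedom. A minimal sketch of that definition, with made-up numbers and a hypothetical helper name:

```python
import numpy as np

def reduced_chi_square(data, model, sigma, n_varied):
    """chi^2 / (N - n_varied), with chi^2 = sum(((data - model) / sigma)^2)."""
    resid = (np.asarray(data) - np.asarray(model)) / np.asarray(sigma)
    chi2 = float(np.sum(resid**2))
    dof = resid.size - n_varied      # degrees of freedom
    return chi2 / dof

data = [1.0, 2.1, 2.9, 4.2, 5.0]
model = [1.0, 2.0, 3.0, 4.0, 5.0]
sigma = [0.1] * 5

redchi = reduced_chi_square(data, model, sigma, n_varied=2)
print(f"{redchi:.2f}")
```

Values near 1 indicate residuals consistent with the stated uncertainties; values much larger than 1 suggest a poor model or underestimated errors.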

### Plotting Individual Fits

In [None]:
# Plot fit for a specific group
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
result.plot(index=(1, 1), ax=ax)
plt.tight_layout()
plt.show()

### Stages Summary for Individual Groups

If you used multi-stage fitting, you can view the stages summary for any group:

In [None]:
# View stages summary for a specific group (if multi-stage fitting was used)
# stages_table = result.stages_summary(index=(1, 1))
# print(stages_table)

### Plotting Total Cross-Section

Plot the total cross-section for an individual group:

In [None]:
# Plot total cross-section for a specific group
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
result.plot_total_xs(index=(1, 1), plot_dspace=False, ax=ax)
plt.tight_layout()
plt.show()

## 6. Saving and Loading <a name="saving"></a>

### Saving Grouped Results

Save all grouped fit results to a single file:

In [None]:
# Save grouped results
result_file = tmp_path / "grouped_results.json"
result.save(str(result_file))
print(f"Saved grouped results to {result_file}")

### Loading Grouped Results

In [None]:
# Load grouped results
loaded_result = nbragg.models.GroupedFitResult.load(str(result_file))

print(f"Loaded results:")
print(f"  Number of groups: {len(loaded_result.indices)}")
print(f"  Group shape: {loaded_result.group_shape}")

# Verify the loaded data matches
print(f"\nVerifying loaded data...")
original_thickness = result[(1, 1)].params['thickness'].value
loaded_thickness = loaded_result[(1, 1)].params['thickness'].value
print(f"Original thickness (1,1): {original_thickness:.6f}")
print(f"Loaded thickness (1,1): {loaded_thickness:.6f}")
print(f"Match: {np.isclose(original_thickness, loaded_thickness)}")
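
One practical detail when storing grouped results as JSON is that object keys must be strings, so tuple indices need a string form on disk. The sketch below shows a round trip under that constraint. It illustrates the general idea only and makes no claim about the layout of the file nbragg writes.

```python
import ast
import json

results = {(0, 0): {"thickness": 1.02}, (1, 1): {"thickness": 1.31}}

# JSON object keys must be strings, so tuple indices are stringified on save
encoded = json.dumps({str(key): value for key, value in results.items()})

# ... and parsed back into tuples on load
decoded = {ast.literal_eval(key): value for key, value in json.loads(encoded).items()}

print(decoded == results)
```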

## 7. Advanced Features <a name="advanced"></a>

### Working with Named Groups

For non-spatial data or custom groupings:

In [None]:
# Create named group data
named_dir = tmp_path / "named_groups"
named_dir.mkdir(exist_ok=True)

for name in ["sample_a", "sample_b", "sample_c"]:
    channels = np.arange(100, 300)
    signal_counts = 1000 - channels * 2 + np.random.randint(-20, 20, len(channels))
    signal_counts = np.maximum(signal_counts, 300)
    
    with open(named_dir / f"{name}_signal.csv", 'w') as f:
        f.write("channel,counts\n")
        for ch, cnt in zip(channels, signal_counts):
            f.write(f"{ch},{cnt}\n")
    
    ob_counts = 1000 + np.random.randint(-10, 10, len(channels))
    with open(named_dir / f"{name}_ob.csv", 'w') as f:
        f.write("channel,counts\n")
        for ch, cnt in zip(channels, ob_counts):
            f.write(f"{ch},{cnt}\n")

# Load named groups
named_data = nbragg.Data.from_grouped(
    signal=str(named_dir / "*_signal.csv"),
    openbeam=str(named_dir / "*_ob.csv"),
    L=10,
    tstep=10e-6,
    verbosity=0
)

print(f"Named groups: {named_data.indices}")

In [None]:
# Fit named groups
named_result = model.fit(named_data, n_jobs=2, progress_bar=False)

# For named groups, parameter maps default to bar charts
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
named_result.plot_parameter_map("thickness", ax=ax)
ax.set_title("Thickness comparison across named samples")
plt.tight_layout()
plt.show()

### Multi-Stage Fitting with Grouped Data

You can use Rietveld-type multi-stage fitting with grouped data:

In [None]:
# Define stages for sequential refinement
model.stages = {
    'basic': ['norm', 'thickness'],
    # Add more stages as needed
}

# Fit with stages
staged_result = model.fit(data, stages='all', n_jobs=4, progress_bar=True)

# View stages summary for a specific group
# staged_result.stages_summary(index=(1, 1))

### Adding Grouped Datasets

You can combine measurements from multiple runs:

In [None]:
# Load two datasets with the same group structure (the same files are reused here for illustration)
data1 = nbragg.Data.from_grouped(
    signal=str(tmp_path / "signal_*.csv"),
    openbeam=str(tmp_path / "ob_*.csv"),
    L=10, tstep=10e-6, verbosity=0
)

data2 = nbragg.Data.from_grouped(
    signal=str(tmp_path / "signal_*.csv"),
    openbeam=str(tmp_path / "ob_*.csv"),
    L=10, tstep=10e-6, verbosity=0
)

# Add them together (combines counts)
combined_data = data1 + data2
print(f"Combined {len(combined_data.indices)} groups")
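
Combining runs pays off statistically. The sketch below assumes that adding two datasets sums the signal counts and the open-beam counts bin by bin, then compares the relative Poisson uncertainty of the transmission before and after. The numbers are made up, and `relative_error` is a hypothetical helper.

```python
import numpy as np

def relative_error(signal, openbeam):
    """Relative Poisson uncertainty of the transmission signal / openbeam."""
    return np.sqrt(1.0 / signal + 1.0 / openbeam)

signal_1, ob_1 = np.array([800.0, 600.0]), np.array([1000.0, 1000.0])
signal_2, ob_2 = np.array([820.0, 580.0]), np.array([1000.0, 1000.0])

# Assumed behaviour of `data1 + data2`: counts are summed bin by bin
signal_sum, ob_sum = signal_1 + signal_2, ob_1 + ob_2
trans_sum = signal_sum / ob_sum

rel_single = relative_error(signal_1, ob_1)
rel_sum = relative_error(signal_sum, ob_sum)
print(trans_sum)
print(rel_single)
print(rel_sum)
```

Doubling the counts reduces the relative uncertainty by roughly a factor of √2, which is the main reason to merge repeated runs before fitting.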

## Summary

This tutorial covered:

1. **Loading grouped data** from files using glob patterns, folders, or lists
2. **Parallel fitting** across groups with the `n_jobs` parameter
3. **Parameter mapping** with auto-detection of plot types
4. **Individual result access** with flexible indexing
5. **Saving/loading** grouped results
6. **Advanced features** like named groups and multi-stage fitting

### Key Features

- **Flexible indexing**: Access groups with tuples `(0, 0)`, strings `"(0,0)"` or `"(0, 0)"`, integers, or names
- **Parallel fitting**: Set `n_jobs` to fit several groups at once
- **Smart visualization**: Auto-detects appropriate plot type (heatmap, line, bar) based on data structure
- **Query filtering**: Use pandas queries to filter groups in visualizations
- **Familiar API**: Standard methods (`plot`, `stages_summary`, `plot_total_xs`) work with grouped results through the `index` argument

### For More Information

- See the [main nbragg tutorial](nbragg_tutorial.ipynb) for general fitting concepts
- See the [Rietveld tutorial](Rietveld_in_nbragg_tutorial.ipynb) for multi-stage fitting
- Check the [documentation](https://nbragg.readthedocs.io) for API reference

In [None]:
# Cleanup temporary files
import shutil
shutil.rmtree(tmp_dir)
print("Cleaned up temporary files")