# Infer model on array

---

## Imports

In [1]:
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest

from functions import _get_output_array_size, predict_on_array

## Testing the `_get_output_array_size` function

In [2]:
%%writefile test_get_array_size.py
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest

from functions import _get_output_array_size

Overwriting test_get_array_size.py


In [3]:
%%writefile -a test_get_array_size.py

@pytest.fixture
def bgen_fixture() -> xbatcher.BatchGenerator:
    """Provide a BatchGenerator over a random (x=100, y=100, t=10) DataArray.

    Batches are 10 long along ``x`` and ``y`` with an overlap of 5 on each;
    ``t`` is not listed in ``input_dims``.
    """
    dim_sizes = {"x": 100, "y": 100, "t": 10}

    # Integer coordinates 0..size-1 on every dimension.
    source = xr.DataArray(
        data=np.random.rand(*dim_sizes.values()),
        dims=tuple(dim_sizes),
        coords={name: np.arange(size) for name, size in dim_sizes.items()},
    )

    return xbatcher.BatchGenerator(
        source,
        input_dims={"x": 10, "y": 10},
        input_overlap={"x": 5, "y": 5},
    )

@pytest.mark.parametrize(
    "case_description, output_tensor_dim, new_dim, resample_dim, expected_output",
    [
        # x: 5 out of 10 in -> 100 * 0.5 = 50; y: 20 out of 10 in -> 100 * 2 = 200
        (
            "Resampling only: Downsample x, Upsample y",
            {"x": 5, "y": 20},
            [],
            ["x", "y"],
            {"x": 50, "y": 200},
        ),
        # A new dimension keeps the size given in output_tensor_dim
        (
            "New dimensions only: Add a 'channel' dimension",
            {"channel": 3},
            ["channel"],
            [],
            {"channel": 3},
        ),
        # x: 30 out of 10 in -> 100 * 3 = 300; channel is new
        (
            "Mixed: Resample x and add new channel dimension",
            {"x": 30, "channel": 12},
            ["channel"],
            ["x"],
            {"x": 300, "channel": 12},
        ),
        # Output batch size equals input batch size, so the array size is unchanged
        (
            "Identity resampling (ratio=1)",
            {"x": 10, "y": 10},
            [],
            ["x", "y"],
            {"x": 100, "y": 100},
        ),
        # 't' is absent from the batcher's input_dims and is declared as new
        (
            "Dimension not in batcher is treated as new",
            {"t": 5},
            ["t"],
            [],
            {"t": 5},
        ),
    ],
)
def test_get_output_array_size_scenarios(
    bgen_fixture,
    case_description,
    output_tensor_dim,
    new_dim,
    resample_dim,
    expected_output,
):
    """
    Check `_get_output_array_size` against the expected size of each valid case.

    `bgen_fixture` is the BatchGenerator built by the fixture of the same name.
    `case_description` only labels the case: it shows up in the pytest ID and
    in the assertion message.
    """
    size_kwargs = {
        "bgen": bgen_fixture,
        "output_tensor_dim": output_tensor_dim,
        "new_dim": new_dim,
        "resample_dim": resample_dim,
    }
    actual_size = _get_output_array_size(**size_kwargs)

    assert actual_size == expected_output, f"Failed on case: {case_description}"

Appending to test_get_array_size.py


In [4]:
%%writefile -a test_get_array_size.py

def test_get_output_array_size_raises_assertion_error_on_non_integer_size():
    """
    `_get_output_array_size` must raise AssertionError when resampling would
    produce a fractional output dimension size.
    """
    # 'x' has 101 elements; batches of 10 mapped to outputs of 5 would give
    # 101 * (5 / 10) = 50.5, which is not an integer.
    odd_sized = xr.DataArray(
        np.random.rand(101, 100, 10),
        dims=("x", "y", "t"),
    )
    generator = xbatcher.BatchGenerator(odd_sized, input_dims={"x": 10})

    size_kwargs = {
        "bgen": generator,
        "output_tensor_dim": {"x": 5},
        "new_dim": [],
        "resample_dim": ["x"],
    }
    with pytest.raises(AssertionError):
        _get_output_array_size(**size_kwargs)

Appending to test_get_array_size.py


In [5]:
!pytest -v test_get_array_size.py

platform linux -- Python 3.10.16, pytest-8.4.1, pluggy-1.6.0 -- /srv/conda/envs/notebook/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/jovyan/xbatcher-deep-learning/notebooks
plugins: anyio-4.8.0
collected 6 items                                                              [0m[1m

test_get_array_size.py::test_get_output_array_size_scenarios[Resampling only: Downsample x, Upsample y-output_tensor_dim0-new_dim0-resample_dim0-expected_output0] [32mPASSED[0m[32m [ 16%][0m
test_get_array_size.py::test_get_output_array_size_scenarios[New dimensions only: Add a 'channel' dimension-output_tensor_dim1-new_dim1-resample_dim1-expected_output1] [32mPASSED[0m[32m [ 33%][0m
test_get_array_size.py::test_get_output_array_size_scenarios[Mixed: Resample x and add new channel dimension-output_tensor_dim2-new_dim2-resample_dim2-expected_output2] [32mPASSED[0m[32m [ 50%][0m
test_get_array_size.py::test_get_output_array_size_scenarios[Identity resampling (ratio=1)-output_tensor_dim3-n

## Testing the `predict_on_array` function

In [19]:
%%writefile test_predict_on_array.py
import xarray as xr
import numpy as np
import torch
import xbatcher
import pytest
from xbatcher.loaders.torch import MapDataset

from functions import _get_output_array_size, predict_on_array
from dummy_models import *

Overwriting test_predict_on_array.py


In [20]:
import xarray as xr
import numpy as np
import torch
import xbatcher
import pytest
from xbatcher.loaders.torch import MapDataset

from functions import _get_output_array_size, predict_on_array
from dummy_models import *

In [21]:
# 5 x 5 x 5 float32 tensor holding 0..124; display the first row of the first slab.
input_tensor = torch.arange(125, dtype=torch.float32).reshape(5, 5, 5)
input_tensor[0, 0, :]

tensor([0., 1., 2., 3., 4.])

In [22]:
# Instantiate the dummy model with -1 as its only constructor argument and call
# it on `input_tensor`; the call result is what this cell displays.
# NOTE(review): presumably -1 is the `ax` argument (the test module calls
# `MeanAlongDim(ax=2)`), i.e. the mean over the last axis — confirm in dummy_models.
model = MeanAlongDim(-1)
model(input_tensor)

tensor([[  2.,   7.,  12.,  17.,  22.],
        [ 27.,  32.,  37.,  42.,  47.],
        [ 52.,  57.,  62.,  67.,  72.],
        [ 77.,  82.,  87.,  92.,  97.],
        [102., 107., 112., 117., 122.]])

In [23]:
%%writefile -a test_predict_on_array.py

@pytest.fixture
def map_dataset_fixture() -> MapDataset:
    """
    Build a small MapDataset on top of a deterministic BatchGenerator.

    - Source: float DataArray with dims x=20, y=10 and integer coordinates.
    - Values: np.arange(200) laid out row-major, so results are easy to check.
    - Batches: x=10, y=5, overlapping by x=2, y=2.
    """
    n_x, n_y = 20, 10

    # Kept deliberately small so the reassembled output can be verified by hand
    grid = xr.DataArray(
        data=np.arange(n_x * n_y).reshape(n_x, n_y),
        dims=("x", "y"),
        coords={"x": np.arange(n_x), "y": np.arange(n_y)},
    ).astype(float)

    batch_generator = xbatcher.BatchGenerator(
        grid,
        input_dims={"x": 10, "y": 5},
        input_overlap={"x": 2, "y": 2},
    )
    return MapDataset(batch_generator)

Appending to test_predict_on_array.py


In [None]:
    # Notebook-level copy of the array and batch generator that
    # `map_dataset_fixture` builds (same sizes, batches and overlap);
    # presumably kept for interactive inspection.
    data = xr.DataArray(
        np.arange(20 * 10).reshape(20, 10),
        dims=("x", "y"),
        coords=dict(x=np.arange(20), y=np.arange(10)),
    ).astype(float)

    bgen = xbatcher.BatchGenerator(
        data,
        input_dims={"x": 10, "y": 5},
        input_overlap={"x": 2, "y": 2},
    )

In [24]:
%%writefile -a test_predict_on_array.py

@pytest.mark.parametrize(
    "model, output_tensor_dim, new_dim, resample_dim, expected_transform",
    [
        # Case 1: Resampling - Downsampling with a subset model
        # NOTE(review): 'y' appears in output_tensor_dim but in neither new_dim nor
        # resample_dim — confirm this combination is what predict_on_array expects.
        (
            SubsetAlongAxis(ax=1, n=5), # Corresponds to 'x' dim in batch
            {'x': 5, 'y': 5},
            [],
            ['x'],
            lambda da: da.isel(x=slice(0, 5)) # Expected: take first 5 elements of original 'x'
        ),
        # Case 2: Dimension reduction with a mean model
        (
            MeanAlongDim(ax=2), # Corresponds to 'y' dim in batch
            {'x': 10},
            [],
            ['x'],
            lambda da: da.mean(dim='y') # Expected: mean along original 'y'
        ),
    ]
)
def test_predict_on_array_reassembly(
    map_dataset_fixture,
    model,
    output_tensor_dim,
    new_dim,
    resample_dim,
    expected_transform
):
    """
    Tests that predict_on_array correctly reassembles batches from different models.

    The test runs `predict_on_array` on the fixture dataset, then rebuilds the
    expected summed array and per-cell counter by applying `expected_transform`
    (the xarray equivalent of `model`) to every batch and accumulating it at the
    batch's rescaled position. Both returned arrays must match the rebuilt ones.

    Parameters:
        map_dataset_fixture: MapDataset produced by the fixture of the same name.
        model: module passed to `predict_on_array`.
        output_tensor_dim: dimension name -> size of one model output.
        new_dim: dimension names passed through as `new_dim`.
        resample_dim: dimension names whose batch slices are rescaled.
        expected_transform: callable applied to each batch DataArray to obtain
            the values the model is expected to produce for that batch.

    NOTE(review): the notebook's recorded pytest run shows both parametrized
    cases FAILED; the traceback is cut off, so the cause is not visible here.
    """
    # --- Run the function under test ---
    # Using a small batch_size to ensure multiple iterations
    predicted_da, predicted_n = predict_on_array(
        dataset=map_dataset_fixture,
        model=model,
        output_tensor_dim=output_tensor_dim,
        new_dim=new_dim,
        resample_dim=resample_dim,
        batch_size=4 
    )

    # --- Manually calculate the expected result ---
    bgen = map_dataset_fixture.generator
    # 1. Create the expected output array structure
    # Zero-filled accumulator whose dims follow the key order of expected_size;
    # no coordinates are attached. expected_n counts how often each cell is written.
    expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, resample_dim)
    expected_da = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))
    expected_n = xr.full_like(expected_da, 0)

    # 2. Manually iterate through batches and apply the same logic as the function
    for i in range(len(map_dataset_fixture)):
        batch_da = bgen[i]
        
        # Apply the same transformation the model would
        transformed_batch = expected_transform(batch_da)
        
        # Get the rescaled indexer
        # Only keys listed in resample_dim are copied into new_indexer; every other
        # key of the batch selector is dropped, so the .loc writes below cover the
        # full extent of those dimensions.
        # NOTE(review): in case 1 this presumably drops the 'y' slice even though
        # 'y' is in output_tensor_dim — verify against predict_on_array.
        old_indexer = bgen.batch_selectors[i]
        new_indexer = {}
        for key in old_indexer:
            if key in resample_dim:
                # Scale slice bounds by output size / input batch size, truncating to int
                resample_ratio = output_tensor_dim[key] / bgen.input_dims[key]
                new_indexer[key] = slice(
                    int(old_indexer[key].start * resample_ratio),
                    int(old_indexer[key].stop * resample_ratio)
                )
        
        # Add the result to our manually calculated array
        # NOTE(review): expected_da has no coordinates, so .loc with these slices
        # presumably falls back to positional (end-exclusive) indexing — confirm.
        expected_da.loc[new_indexer] += transformed_batch.values
        expected_n.loc[new_indexer] += 1

    # --- Assert that the results are identical ---
    # We test the raw summed output and the overlap counter array
    xr.testing.assert_allclose(predicted_da, expected_da)
    xr.testing.assert_allclose(predicted_n, expected_n)

Appending to test_predict_on_array.py


In [25]:
!pytest -v test_predict_on_array.py

platform linux -- Python 3.10.16, pytest-8.4.1, pluggy-1.6.0 -- /srv/conda/envs/notebook/bin/python3.10
cachedir: .pytest_cache
rootdir: /home/jovyan/xbatcher-deep-learning/notebooks
plugins: anyio-4.8.0
collected 2 items                                                              [0m[1m

test_predict_on_array.py::test_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-<lambda>] [31mFAILED[0m[31m [ 50%][0m
[31mFAILED[0m[31m [100%][0mpredict_on_array_reassembly[model1-output_tensor_dim1-new_dim1-resample_dim1-<lambda>] 

[31m[1m_ test_predict_on_array_reassembly[model0-output_tensor_dim0-new_dim0-resample_dim0-<lambda>] _[0m

map_dataset_fixture = <xbatcher.loaders.torch.MapDataset object at 0x7f4d4a77cdc0>
model = SubsetAlongAxis(), output_tensor_dim = {'x': 5, 'y': 5}, new_dim = []
resample_dim = ['x'], expected_transform = <function <lambda> at 0x7f4d4a136cb0>

    [0m[37m@pytest[39;49;00m.mark.parametrize([90m[39;49;00m
        [33m"[39