# Trace forward and backward
**Eivind Brendryen, 2026-02-15**

In this example, we will trace a segment both backward and forward. We will see how to do it, and discuss how this affects aggregation.

In [1]:
from pathlib import Path
from aqua_tracekit import SdtModel, SdtSchema, Aggregation
from pathlib import Path
from datetime import datetime
from IPython.display import HTML, Javascript, display
import polars as pl

In [2]:
# Create the model and load containers, segments and transfers.
# The CSV file names are presumably resolved relative to base_path ("data") — TODO confirm.
base_path = Path("data")
model = SdtModel(base_path=str(base_path.resolve()))
df_containers = model.load_containers("containers.csv")
df_segments = model.load_segments("segments.csv")
df_transfers = model.load_transfers("transfers.csv")

# Plot the trace graph; visualize_trace() returns HTML that is rendered inline.
html = model.visualize_trace()
display(HTML(html))
# Highlight the treatment segment P4, the segment traced in this example.
# The graph is interactive: click on a segment to highlight its forward and backward trace.
# NOTE(review): sdtSelectsegment is presumably a JS function defined by the HTML rendered above.
display(Javascript('sdtSelectsegment("P4")'))


<IPython.core.display.Javascript object>

P4 is the segment we want to trace forward and backward. 

In [4]:
# Use P4 as the only origin segment and trace it both forward and backward.
# The resulting traceability index has one row per (origin, traced segment): the trace
# direction plus forward/backward share factors for count and biomass.
# NOTE: the variable name misspells "traceability"; it is kept because later cells use it.
df_origin_segments = df_segments.filter(pl.col(SdtSchema.Segment.SEGMENT_ID)=="P4")
df_traceabilty_index = model.trace_segments(df_origin_segments)
df_traceabilty_index.head(11)

origin_segment_id,traced_segment_id,direction,share_count_forward,share_biomass_forward,share_count_backward,share_biomass_backward
str,str,str,f64,f64,f64,f64
"""P4""","""P4""","""identity""",1.0,1.0,1.0,1.0
"""P4""","""P6""","""forward""",0.4,0.4,1.0,1.0
"""P4""","""P7""","""forward""",0.6,0.6,0.2,0.2
"""P4""","""P1""","""backward""",0.1,0.1,0.4,0.4
"""P4""","""P2""","""backward""",1.0,1.0,0.6,0.6


The column "direction" indicates whether the traced_segment_id was found by backward or forward tracing. We will now discuss how
this affects the way we aggregate data. First we will load some mortality data. For simplicity, the mortality is 10 dead fish at 16:00
for all segments, every day they exist.

In [5]:
# Load mortality data on segment level and parse the mortality_count column as float
df_mortality = model.load_segment_timeseries("mortality.csv")
df_mortality = model.parse_float(df_mortality, "mortality_count")
# Map the time series onto the traceability index: each traced-segment row gets that
# segment's date_time and mortality_count values (see the output of the next cell)
df_traced_data = model.add_data_to_trace(df_mortality, df_traceabilty_index)


In [6]:
# Inspect the raw traced rows at 16:00 on January 8th
jan_8th = datetime(year=2025, month=1, day=8, hour=16)
df_jan_8th = df_traced_data.filter(
    pl.col(SdtSchema.TimeSeries.DATE_TIME) == jan_8th
)
df_jan_8th.head(11)

origin_segment_id,traced_segment_id,direction,share_count_forward,share_biomass_forward,share_count_backward,share_biomass_backward,date_time,mortality_count
str,str,str,f64,f64,f64,f64,datetime[μs],f64
"""P4""","""P1""","""backward""",0.1,0.1,0.4,0.4,2025-01-08 16:00:00,10.0
"""P4""","""P2""","""backward""",1.0,1.0,0.6,0.6,2025-01-08 16:00:00,10.0


As expected, there are two traced segments, each with a mortality of 10, together with their share factors.
10% of the fish in P1 went to P4, so we attribute 1 dead fish from P1 to P4 (share_count_forward = 0.1).
All of the fish in P2 went to P4, so we attribute 10 dead fish from P2 to P4 (share_count_forward = 1.0).

So the traced mortality for P4 on Jan 8th is a total of 11 fish. We can implement this with an aggregation that scales and sums:




In [7]:
def my_agg(group: pl.DataFrame) -> dict:
    """Naive aggregation: scale each row's mortality by its forward count share and sum.

    This gives the right answer for backward-traced segments only; the notebook
    shows further down why it is wrong for forward-traced segments.
    """
    shares = group["share_count_forward"]
    mortality = group["mortality_count"]
    return {"weighted_mortality": (mortality * shares).sum()}

aggs = [Aggregation.custom(my_agg)]
result = SdtModel.aggregate_traced_data(df_traced_data, aggs).sort(
    SdtSchema.TimeSeries.DATE_TIME
)

# Aggregated (traced) mortality for the origin segment at 16:00 on January 8th
df_jan_8th = result.filter(pl.col(SdtSchema.TimeSeries.DATE_TIME) == jan_8th)
df_jan_8th.head(11)


origin_segment_id,date_time,weighted_mortality
str,datetime[μs],f64
"""P4""",2025-01-08 16:00:00,11.0


In [8]:
# Inspect the raw traced rows at 16:00 on February 8th
feb_8th = datetime(year=2025, month=2, day=8, hour=16)
df_feb_8th = df_traced_data.filter(
    pl.col(SdtSchema.TimeSeries.DATE_TIME) == feb_8th
)
df_feb_8th.head(11)

origin_segment_id,traced_segment_id,direction,share_count_forward,share_biomass_forward,share_count_backward,share_biomass_backward,date_time,mortality_count
str,str,str,f64,f64,f64,f64,datetime[μs],f64
"""P4""","""P6""","""forward""",0.4,0.4,1.0,1.0,2025-02-08 16:00:00,10.0
"""P4""","""P7""","""forward""",0.6,0.6,0.2,0.2,2025-02-08 16:00:00,10.0


In [9]:
# Let's look at the aggregated data for this date, computed with the naive
# (forward-share only) my_agg from the earlier cell.
# Note: this rebinds df_feb_8th from the raw rows to the aggregated row.
df_feb_8th = result.filter(pl.col(SdtSchema.TimeSeries.DATE_TIME)==feb_8th)
df_feb_8th.head(11)

origin_segment_id,date_time,weighted_mortality
str,datetime[μs],f64
"""P4""",2025-02-08 16:00:00,10.0


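We can see that this does not work. P6 has fish only from P4, so all of its mortality should be attributed to P4 (share_count_backward = 1.0). Only 20% of the fish in P7 came from P4, so 2 of its 10 dead fish should be attributed to P4 (share_count_backward = 0.2). The correct total is therefore 12, not 10.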

***=> we have to use the backward factor (1.0), not the forward (0.4)***

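Or, as a more general rule: when the trace direction changes, we need to swap factors. Forward-traced segments use the backward share, and backward-traced segments use the forward share: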





In [10]:
def my_agg(group: pl.DataFrame) -> dict:
    """Direction-aware aggregation: weight each row's mortality by the right count share.

    The factor depends on how the row's segment was reached from the origin:

    * forward-traced rows use ``share_count_backward``, the fraction of the traced
      segment's fish that came from the origin;
    * backward-traced rows (and the identity row) use ``share_count_forward``,
      the fraction of the traced segment's fish that went to the origin.

    The factor is chosen per row instead of from the first row only. A group
    (presumably one origin segment at one timestamp, as the aggregated output
    suggests) is not guaranteed to contain a single trace direction; for example,
    a backward-traced parent may still exist while forward-traced children do.

    Args:
        group: Traced rows of one aggregation group, containing the trace direction,
            the count share factor columns and ``mortality_count``.

    Returns:
        ``{"weighted_mortality": sum(share * mortality_count)}``. An empty group
        gives 0.0 instead of raising IndexError.
    """
    is_forward = (
        pl.col(SdtSchema.TraceabilityIndex.TRACE_DIRECTION)
        == SdtSchema.DIRECTION.FORWARD
    )
    # Swap factors when the trace direction changes: forward rows take the backward
    # share, everything else (backward, identity) takes the forward share.
    share = (
        pl.when(is_forward)
        .then(pl.col(SdtSchema.TraceabilityIndex.FACTORS.SHARE_COUNT_BACKWARD))
        .otherwise(pl.col(SdtSchema.TraceabilityIndex.FACTORS.SHARE_COUNT_FORWARD))
    )
    weighted_mortality = group.select((share * pl.col("mortality_count")).sum()).item()
    return {"weighted_mortality": weighted_mortality}

aggs = [Aggregation.custom(my_agg)]
result = SdtModel.aggregate_traced_data(df_traced_data, aggs).sort(
    SdtSchema.TimeSeries.DATE_TIME
)

# Direction-aware results for both dates: Jan 8th (backward-traced rows)
# and Feb 8th (forward-traced rows)
df_8th = result.filter(
    pl.col(SdtSchema.TimeSeries.DATE_TIME).is_in([jan_8th, feb_8th])
)
df_8th.head(11)


origin_segment_id,date_time,weighted_mortality
str,datetime[μs],f64
"""P4""",2025-01-08 16:00:00,11.0
"""P4""",2025-02-08 16:00:00,12.0


Perfect! We can compare this with the built-in weighted-sum aggregation:

In [11]:
# Same computation with the built-in weighted sum, weighting mortality_count by count.
# NOTE(review): presumably it applies the same direction-dependent factor swap as
# the custom my_agg; the output below matches it for both dates.
aggs = [
    Aggregation.weighted_sum(["mortality_count"], SdtSchema.AGGREGATE_BY.COUNT)
]
result = SdtModel.aggregate_traced_data(df_traced_data, aggs)
result = result.sort(SdtSchema.TimeSeries.DATE_TIME)

df_8th = result.filter(pl.col(SdtSchema.TimeSeries.DATE_TIME).is_in([jan_8th, feb_8th]))
df_8th.head(11)


origin_segment_id,date_time,mortality_count
str,datetime[μs],f64
"""P4""",2025-01-08 16:00:00,11.0
"""P4""",2025-02-08 16:00:00,12.0
