In [1123]:
from hybrid_sys_id_dataset import HybridSysIDDataset
from kLinReg import KLinRegMultiOutput
import numpy as np
import scipy.io
from torch.utils.data import DataLoader

In [1124]:
# Read the recorded trajectory from the MATLAB file
data = scipy.io.loadmat('xout.mat')
x = data['xout']  # the states are assumed to be stored under the 'xout' key

# Wrap the raw states in the regression dataset
# NOTE(review): `na` is forwarded to HybridSysIDDataset; the (N, na*4) feature
# shape suggests it is the number of past samples per regressor — confirm.
na = 4
dataset = HybridSysIDDataset(x, na=na)

# Materialise every (features, target) pair as stacked NumPy arrays
sample_indices = range(len(dataset))
X = np.stack([dataset[k][0].numpy() for k in sample_indices])  # (N, na*4)
y = np.stack([dataset[k][1].numpy() for k in sample_indices])  # (N, 4)

print(X.shape, y.shape)  # expected: (N, na*4), (N, 4)



(313, 16) (313, 4)


In [1125]:
s = 4  # number of modes
d = X.shape[1]  # number of regressor columns
M = y.shape[1]  # number of output columns

# Initial guess for theta: shape (d, s, M).
# The fit starts from this random guess, so it is drawn from a seeded generator:
# with the unseeded global RNG every "Restart & Run All" produced a different
# starting point and therefore potentially different modes and figures.
rng = np.random.default_rng(0)
theta_init = rng.standard_normal((d, s, M))

klinreg = KLinRegMultiOutput(X, y, s, theta_init, max_iters=100, tol=1e-6)

In [1126]:
# Run the fit. The call prints a convergence message and returns two values:
# the fitted parameters and the mode assignments.
# NOTE(review): mode_assignments is presumably one mode index per sample (it is
# plotted against the time step further down) — confirm against KLinRegMultiOutput.
theta_final, mode_assignments = klinreg.fit()


Converged at iteration 8


In [1127]:
def predict_kLinReg(X_new, klinreg, y_true=None):
    """Predict outputs by picking, per sample, the mode with the smallest residual.

    For every row of ``X_new`` the prediction of each mode is computed from the
    parameter slice ``klinreg.theta[:, mode, :, klinreg.i]``; the mode whose
    prediction has the smallest squared error against the reference target is
    selected and its prediction is returned.

    Note that the mode is chosen using the reference targets, so this is a
    best-case ("oracle") reconstruction rather than a prediction for data whose
    targets are unknown.

    Parameters
    ----------
    X_new : array of shape (N_new, d)
        Regressor rows to evaluate.
    klinreg : object
        Fitted model exposing ``theta`` (indexed as ``theta[:, mode, :, i]``)
        and the index ``i`` of the parameter slice to use.
    y_true : array of shape (>= N_new, M), optional
        Reference targets aligned row-by-row with ``X_new``. When omitted, the
        notebook-level ``y`` is used, which is what this function always did
        before the parameter existed.

    Returns
    -------
    numpy.ndarray of shape (N_new, M)
        Prediction of the best-matching mode for each sample.

    Raises
    ------
    ValueError
        If the reference targets have fewer rows than ``X_new``.
    """
    # Parameters of all modes at the model's current index: (d, s, M).
    theta = klinreg.theta[:, :, :, klinreg.i]
    # Take the mode/output counts from the parameters instead of notebook globals.
    n_modes, n_outputs = theta.shape[1], theta.shape[2]

    # Backward-compatible fallback: earlier versions silently read the global `y`.
    y_ref = y if y_true is None else y_true

    N_new = X_new.shape[0]
    if len(y_ref) < N_new:
        raise ValueError(
            f"reference targets have {len(y_ref)} rows but X_new has {N_new}"
        )

    y_pred = np.zeros((N_new, n_outputs))
    for i in range(N_new):
        # Prediction of every mode for this sample: (s, M)
        mode_preds = np.stack([X_new[i] @ theta[:, j, :] for j in range(n_modes)])
        # Squared residual of each mode against the reference target
        residuals = np.sum((mode_preds - y_ref[i]) ** 2, axis=1)
        # Reuse the winning mode's prediction instead of recomputing it
        y_pred[i] = mode_preds[np.argmin(residuals)]

    return y_pred

# Evaluate on X, the same regressors the model was constructed with,
# so this is an in-sample reconstruction.
y_pred = predict_kLinReg(X, klinreg)
print(y_pred)

[[ 9.94159268e-01  4.98752196e+00  1.91388880e-01 -3.01838434e-01]
 [ 9.94170576e-01  4.98315988e+00  1.91282490e-01 -3.99951670e-01]
 [ 9.94180410e-01  4.97782565e+00  1.91106404e-01 -4.97862265e-01]
 ...
 [ 9.72051621e-01  1.76151208e-03  1.91775301e+00 -3.06407313e-03]
 [ 9.91176635e-01  1.79630105e-03  1.91776081e+00 -3.12930176e-03]
 [ 1.01030165e+00  1.83109002e-03  1.91776862e+00 -3.19453039e-03]]


In [1128]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Only the first N_plot samples are drawn
N_plot = 200
y_true_plot, y_pred_plot = y[:N_plot], y_pred[:N_plot]

labels = ['x_pos', 'y_pos', 'x_vel', 'y_vel']

# One panel per output column, laid out on a 2x2 grid
fig = make_subplots(rows=2, cols=2, subplot_titles=labels)

true_style = dict(color='blue', width=2)
pred_style = dict(color='red', width=2, dash='dash')

for i in range(4):
    # Map the flat column index onto 1-based grid coordinates
    row, col = i // 2 + 1, i % 2 + 1

    true_trace = go.Scatter(y=y_true_plot[:, i], mode='lines', name=f'True {labels[i]}', line=true_style)
    pred_trace = go.Scatter(y=y_pred_plot[:, i], mode='lines', name=f'Predicted {labels[i]}', line=pred_style)

    # Ground truth first, dashed prediction on top of it
    fig.add_trace(true_trace, row=row, col=col)
    fig.add_trace(pred_trace, row=row, col=col)

fig.update_layout(height=600, width=800, title_text="Time Series: True vs Predicted")
fig.show()


In [1129]:
# Overlay the true and predicted paths, using output columns 0 and 1 as the axes
true_path = go.Scatter(
    x=y[:, 0],
    y=y[:, 1],
    mode='lines',
    name='True trajectory',
    line=dict(color='blue', width=2),
)
# The predicted path is half-transparent so the true path stays visible underneath
predicted_path = go.Scatter(
    x=y_pred[:, 0],
    y=y_pred[:, 1],
    mode='lines',
    name='Predicted trajectory',
    line=dict(color='red', width=2),
    opacity=0.5,
)

fig2 = go.Figure()
fig2.add_trace(true_path)
fig2.add_trace(predicted_path)

# Legend box pinned inside the plot area, near the top-left corner
legend_box = dict(
    x=0.02,
    y=0.98,
    xanchor='left',
    yanchor='top',
    bgcolor='rgba(255,255,255,0.8)',
    bordercolor='black',
    borderwidth=1,
)

# No figure title is set here (the original title is left disabled)
fig2.update_layout(
    xaxis_title="x_pos",
    yaxis_title="y_pos",
    width=800,
    height=600,
    yaxis=dict(scaleanchor="x", scaleratio=1),  # same scale on both axes
    legend=legend_box,
)
fig2.show()


In [1130]:
import plotly.graph_objects as go

# Step plot ('hv' line shape) of the mode assignments in sequence order
mode_trace = go.Scatter(y=mode_assignments, mode='lines+markers', line=dict(shape='hv'))

fig = go.Figure()
fig.add_trace(mode_trace)

layout_options = dict(
    title="Mode Assignment Sequence",
    xaxis_title="Time step",
    yaxis_title="Mode",
    yaxis=dict(dtick=1),  # one tick per mode index
)
fig.update_layout(**layout_options)
fig.show()


In [1131]:
import plotly.graph_objects as go
import plotly.express as px

# Function to find continuous segments of each mode
def find_mode_segments(mode_assignments):
    """Split a mode sequence into maximal runs of identical consecutive modes.

    Parameters
    ----------
    mode_assignments : sequence
        Mode label per time step; must support ``len()`` and integer indexing.

    Returns
    -------
    list of dict
        One entry per run, in order, with the keys ``'mode'`` (the label),
        ``'start'`` and ``'end'`` (first and last index of the run, both
        inclusive) and ``'duration'`` (number of steps in the run).
        An empty input yields an empty list.
    """
    n_steps = len(mode_assignments)
    # Guard: the first element is read unconditionally below
    if n_steps == 0:
        return []

    def _segment(first, stop):
        # Describe the run occupying indices first .. stop-1
        return {
            'mode': mode_assignments[first],
            'start': first,
            'end': stop - 1,
            'duration': stop - first,
        }

    segments = []
    start_time = 0
    for i in range(1, n_steps):
        # A change of label closes the current run and opens a new one at i
        if mode_assignments[i] != mode_assignments[start_time]:
            segments.append(_segment(start_time, i))
            start_time = i

    # The last run always extends to the end of the sequence
    segments.append(_segment(start_time, n_steps))
    return segments

# Get mode segments (computed once; each has 'mode', 'start', 'end', 'duration')
segments = find_mode_segments(mode_assignments)
colors = px.colors.qualitative.Set1  # Use a nice color palette

# Create filled rectangles for each mode segment
fig = go.Figure()

for segment in segments:
    mode = segment['mode']
    color = colors[mode % len(colors)]

    # 'end' is the last time step inside the segment (inclusive). Drawing from
    # 'start' to 'end' gives a one-step segment zero width and leaves a gap
    # between neighbouring segments, so extend the bar half a step on each side.
    x_left = segment['start'] - 0.5
    x_right = segment['end'] + 0.5

    # Closed polygon (first point repeated) filled with the mode's color
    fig.add_trace(go.Scatter(
        x=[x_left, x_right, x_right, x_left, x_left],
        y=[mode - 0.4, mode - 0.4, mode + 0.4, mode + 0.4, mode - 0.4],
        fill='toself',
        fillcolor=color,
        line=dict(color=color, width=1),
        mode='none',  # fill only: no markers/points
        name=f'Mode {mode}',
        showlegend=False,  # legend entries are added once per mode below
        hovertemplate=f'Mode {mode}<br>Start: {segment["start"]}<br>End: {segment["end"]}<br>Duration: {segment["duration"]}<extra></extra>'
    ))

# Add legend: one empty marker trace per distinct mode
unique_modes = sorted(set(mode_assignments))
for mode in unique_modes:
    fig.add_trace(go.Scatter(
        x=[None], y=[None],
        mode='markers',
        marker=dict(size=15, color=colors[mode % len(colors)], symbol='square'),
        name=f'Mode {mode}'
    ))

fig.update_layout(
    title="Mode Activity Timeline (Filled Bars)",
    xaxis_title="Time Step",
    yaxis_title="Mode",
    yaxis=dict(
        dtick=1,
        range=[-0.5, max(mode_assignments) + 0.5]
    ),
    height=400
)

fig.show()