In [None]:
# Real-Time Predictor and Streaming Plotter for LSTM Joint Angle Model

import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import deque

class RealTimePredictor:
    """Run a sequence model over a sliding window of incoming frames.

    Frames are pushed one at a time with :meth:`update`. Once ``seq_len``
    frames have been collected, every call normalizes the whole window with
    ``scaler_x``, runs ``model`` on it, and returns the prediction for the
    last timestep mapped back to original units with ``scaler_y``.
    """

    def __init__(self, model, scaler_x, scaler_y, seq_len=500, input_dim=99):
        """
        Args:
            model: torch module; switched to eval mode here. For an input of
                shape ``(1, seq_len, input_dim)`` its output is expected to be
                ``(1, seq_len, output_dim)`` (last timestep is used) or
                ``(1, output_dim)``.
            scaler_x: object whose ``transform`` is applied to the
                ``(seq_len, input_dim)`` window of raw frames.
            scaler_y: object whose ``inverse_transform`` is applied to the
                ``(1, output_dim)`` normalized prediction.
            seq_len: number of frames in the sliding window.
            input_dim: number of values per frame.
        """
        self.model = model.eval()
        self.scaler_x = scaler_x
        self.scaler_y = scaler_y
        self.seq_len = seq_len
        self.input_dim = input_dim

        # Rolling window of the most recent `seq_len` frames; the oldest frame
        # is dropped automatically once the deque is full.
        self.buffer = deque(maxlen=seq_len)

    def _device(self):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first = next(iter(self.model.parameters()), None)
        return first.device if first is not None else torch.device("cpu")

    def update(self, new_frame_raw):
        """Add one frame and return the latest prediction once the window is full.

        Args:
            new_frame_raw: array-like with ``input_dim`` raw (unnormalized)
                values for one frame; any shape with that many elements is
                accepted and flattened.

        Returns:
            np.ndarray of shape ``(output_dim,)`` with the denormalized
            prediction for the newest frame, or ``None`` while fewer than
            ``seq_len`` frames have been received.

        Raises:
            ValueError: if the frame does not contain exactly ``input_dim`` values.
        """
        # Copy the frame: a caller that reuses one array for every frame would
        # otherwise make all buffered entries alias the same (latest) data.
        frame = np.array(new_frame_raw, dtype=np.float64).reshape(-1)
        if frame.size != self.input_dim:
            raise ValueError(
                f"expected a frame with {self.input_dim} values, got {frame.size}"
            )
        self.buffer.append(frame)

        if len(self.buffer) < self.seq_len:
            return None  # Buffer not filled yet

        window = np.stack(self.buffer)  # (seq_len, input_dim)
        norm_window = self.scaler_x.transform(window)

        # Build the batch directly on the model's device so GPU models work too.
        input_tensor = torch.as_tensor(
            np.asarray(norm_window).reshape(1, self.seq_len, self.input_dim),
            dtype=torch.float32,
            device=self._device(),
        )

        with torch.no_grad():
            pred_norm = self.model(input_tensor)

        # A (1, seq_len, output_dim) output keeps only its last timestep;
        # a (1, output_dim) output is already a single-step prediction.
        pred_last = pred_norm[0, -1] if pred_norm.dim() == 3 else pred_norm[0]
        pred_last = pred_last.cpu().numpy()

        return self.scaler_y.inverse_transform(pred_last.reshape(1, -1)).flatten()

# ====================
# Streaming Plotter
# ====================

class StreamingPlotter:
    """Live line plots of selected channels of a streaming prediction vector.

    One subplot is created per entry in ``channels``; each shows the most
    recent ``window_size`` values passed to :meth:`update_plot`.
    """

    def __init__(self, joint_angle_columns, channels=None, window_size=500,
                 *, ylim=(-100, 100), figsize=(12, 8), pause=0.01):
        """
        Args:
            joint_angle_columns: sequence of names indexed by channel; used for
                line labels and subplot titles.
            channels: indices of the prediction vector to plot. Defaults to
                ``(0, 5, 10)``.
            window_size: number of most recent values kept and shown per channel.
            ylim: fixed ``(bottom, top)`` y-axis limits, or ``None`` to rescale
                the y axis to the visible data on every update.
            figsize: figure size passed to ``plt.subplots``.
            pause: seconds passed to ``plt.pause`` after each redraw.

        Raises:
            ValueError: if ``channels`` is empty.
        """
        # None sentinel instead of a mutable list default shared across calls.
        channels = list((0, 5, 10) if channels is None else channels)
        if not channels:
            raise ValueError("channels must contain at least one index")

        self.joint_angle_columns = joint_angle_columns
        self.channels = channels
        self.window_size = window_size
        self.ylim = ylim
        self.pause = pause

        self.predictions = {ch: deque(maxlen=window_size) for ch in channels}

        # squeeze=False always returns a 2-D array of Axes, so a single
        # channel needs no special casing.
        self.fig, axs = plt.subplots(len(channels), 1, figsize=figsize, squeeze=False)
        self.axs = list(axs[:, 0])

        self.lines = []
        for ax, ch in zip(self.axs, channels):
            name = joint_angle_columns[ch]
            line, = ax.plot([], [], label=f"{name}")
            ax.set_xlim(0, window_size)
            if ylim is not None:
                ax.set_ylim(*ylim)
            ax.set_xlabel("Frames")
            ax.set_ylabel("Angle (deg)")
            ax.set_title(f"Real-Time Prediction: {name}")
            ax.grid(True)
            ax.legend()
            self.lines.append(line)

    def update_plot(self, new_pred):
        """Append one prediction vector and redraw the figure.

        Args:
            new_pred: indexable prediction vector (e.g. ``np.ndarray`` of shape
                ``(output_dim,)``); only the entries at ``self.channels`` are used.
        """
        for line, ax, ch in zip(self.lines, self.axs, self.channels):
            history = self.predictions[ch]
            history.append(new_pred[ch])

            ydata = list(history)
            line.set_data(list(range(len(ydata))), ydata)
            ax.set_xlim(0, self.window_size)
            if self.ylim is None:
                # Rescale only the y axis; x stays pinned to the window.
                ax.relim()
                ax.autoscale_view(scalex=False, scaley=True)

        # plt.pause refreshes only the *active* figure, which may not be this
        # one, so request a redraw of this figure explicitly.
        self.fig.canvas.draw_idle()
        plt.pause(self.pause)

# ====================
# Example Usage (after loading a frame-by-frame input)
# ====================
# predictor = RealTimePredictor(model, scaler_x=dataset.scaler_x, scaler_y=dataset.scaler_y)
# plotter = StreamingPlotter(joint_angle_columns)

# for new_frame in incoming_frames:
#     pred_angles = predictor.update(new_frame)
#     if pred_angles is not None:
#         plotter.update_plot(pred_angles)

# plt.show()
