<a href="https://colab.research.google.com/github/MachineSaver/MachineSaver/blob/main/Class_TwinProx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import minimalmodbus
import time
from datetime import datetime
import numpy as np
import struct
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
from typing import List, Dict, Tuple
import math

class TwinProx:
    """Modbus RTU client for a TwinProx capture device.

    Typical flow: ``trigger_capture()``, poll ``check_capture_engine()`` until it
    returns 0, then ``extract_data()``. That call downloads the clip, splits it
    into header / waveform / groove sections, converts it into two channels and
    plots them.
    """

    def __init__(self, port: str, slaveaddress: int):
        self.slave = minimalmodbus.Instrument(port=port, slaveaddress=slaveaddress)
        self.configure_modbus_settings()
        self.data_array = []    # raw register words downloaded by extract_data()
        self.header = []        # first 8 words of the clip
        self.groove_array = []  # groove words; used as (start, end) index pairs into the channels
        self.waveforms = []     # waveform words for both channels (divided by 100 in scale_waveforms)
        self.channel_a = []     # first half of the waveform block
        self.channel_b = []     # second half of the waveform block
        self.clipsize = 3600    # number of waveform words in a clip (both channels together)
        self.trace_a, self.trace_b, self.trace_orbit = [], [], []

    def configure_modbus_settings(self):
        """Configure the serial link: 115200 baud, 8N1, 0.1 s timeout, RTU mode."""
        self.slave.serial.baudrate = 115200
        self.slave.serial.bytesize = 8
        self.slave.serial.parity = "N"
        self.slave.serial.stopbits = 1
        self.slave.serial.timeout = 0.10
        self.slave.close_port_after_each_call = True
        self.slave.mode = minimalmodbus.MODE_RTU
        self.slave.clear_buffers_before_each_transaction = True

    def setup_capture(self, modbus_channel_ab: int = 6) -> int:
        """Write the channel selection to register 32 and return the read-back value.

        NOTE(review): the meaning of the default value 6 is not documented here — confirm
        against the device register map.
        """
        self.clipsize = 3600
        self.slave.write_register(32, modbus_channel_ab)
        return self.slave.read_register(32)

    def trigger_capture(self) -> int:
        """Clear all per-capture buffers, start a capture, and return the trigger time.

        Writes 1 to register 33. The return value is the current Unix time in whole seconds.
        """
        self.data_array, self.header, self.groove_array = [], [], []
        self.waveforms, self.channel_a, self.channel_b = [], [], []
        self.slave.write_register(33, 1)
        # int() truncates exactly like the old "first 10 characters of str(time.time())"
        # trick, without depending on the epoch having 10 digits.
        return int(time.time())

    def check_capture_engine(self) -> int:
        """Read the capture-engine status register (34); main() polls until it reads 0."""
        return self.slave.read_register(34, functioncode=3)

    def get_dataclip_size(self) -> int:
        """Read the 32-bit clip size starting at register 36."""
        return self.slave.read_long(36, functioncode=3)

    def extract_header(self) -> List[int]:
        """Pop the 8-word header off the front of ``data_array`` and return it."""
        self.header = self.data_array[:8]
        del self.data_array[:8]
        return self.header

    def extract_waves(self) -> List[int]:
        """Pop ``self.clipsize`` waveform words (both channels) off ``data_array``."""
        count = self.clipsize
        self.waveforms = self.data_array[:count]
        del self.data_array[:count]
        return self.waveforms

    def extract_grooves(self) -> List[int]:
        """Pop the groove words off ``data_array``; header word 4 holds their count."""
        count = self.header[4]
        self.groove_array = self.data_array[:count]
        # Remove exactly the words that were taken, so data_array stays aligned
        # with whatever follows the groove section.
        del self.data_array[:count]
        return self.groove_array

    def scale_waveforms(self) -> List[float]:
        """Divide every waveform word by 100 (the plot labels the result 'Mils')."""
        self.waveforms = [value / 100.0 for value in self.waveforms]
        return self.waveforms

    def split_channels(self) -> List[float]:
        """Split ``waveforms`` into channel A (first half) and channel B (second half).

        Neither ``waveforms`` nor ``data_array`` is modified. Returns channel B,
        kept for backward compatibility with the earlier return value.
        """
        half = self.clipsize // 2
        self.channel_a = self.waveforms[:half]
        self.channel_b = self.waveforms[half:2 * half]
        return self.channel_b

    def decimal_to_binary(self) -> None:
        """Replace each groove int with its binary string, zero-padded to the widest value."""
        if not self.groove_array:
            # max() of an empty list raises; there is nothing to convert.
            return None
        max_length = len(bin(max(self.groove_array))[2:])
        self.groove_array = [bin(num)[2:].zfill(max_length) for num in self.groove_array]
        return None

    def edge_removal(self) -> None:
        """Clear the leading bit of every padded binary string that starts with '1'.

        NOTE(review): because padding is to the widest value, this always clears the top
        bit of the largest value, even if no "edge" flag bit is present — confirm intended.
        """
        self.groove_array = [num.replace('1', '0', 1) if num[0] == '1' else num for num in self.groove_array]
        return None

    def binary_to_decimal(self) -> None:
        """Convert the groove binary strings back to ints."""
        self.groove_array = [int(num, 2) for num in self.groove_array]
        return None

    def update_groove_array(self) -> None:
        """Clean the groove words: to binary, clear the leading bit, back to ints."""
        self.decimal_to_binary()
        self.edge_removal()
        self.binary_to_decimal()
        return None

    def replace_with_nan(self) -> None:
        """Blank out groove regions in both channels with NaN.

        ``groove_array`` is read as consecutive (start, end) pairs. A trailing
        unpaired value is dropped. Ranges are clipped to the channel length.
        The channels become float numpy arrays (copies of the inputs).
        """
        # float dtype so NaN can be stored even if the channel data were ints.
        self.channel_a = np.array(self.channel_a, dtype=float)
        self.channel_b = np.array(self.channel_b, dtype=float)
        if len(self.groove_array) % 2 != 0:
            self.groove_array = self.groove_array[:-1]
        for start, end in zip(self.groove_array[0::2], self.groove_array[1::2]):
            # Slice assignment clips to the array length and is a no-op when end <= start.
            self.channel_a[start:end] = np.nan
            self.channel_b[start:end] = np.nan
        return None

    def create_graphs(self) -> None:
        """Show a figure with both waveforms (left) and the A-vs-B orbit (right)."""
        set_machine_saver_theme()
        self.trace_a = go.Scatter(y=self.channel_a, mode='lines+markers', name='Channel A')
        self.trace_b = go.Scatter(y=self.channel_b, mode='lines+markers', name='Channel B')
        self.trace_orbit = go.Scatter(x=self.channel_a, y=self.channel_b, mode='lines+markers', name='Orbit')
        # Legacy attribute name, kept so existing readers of `self.orbit` still work.
        self.orbit = self.trace_orbit
        self.fig = make_subplots(rows=1, cols=2, subplot_titles=('Waveform', 'Orbit Plot'))
        self.fig.add_trace(self.trace_a, row=1, col=1)
        self.fig.add_trace(self.trace_b, row=1, col=1)
        self.fig.update_yaxes(title_text='Mils', row=1, col=1)
        self.fig.update_xaxes(title_text='Time (Samples)', row=1, col=1)
        self.fig.add_trace(self.trace_orbit, row=1, col=2)
        self.fig.update_xaxes(title_text='Channel A', row=1, col=2)
        self.fig.update_yaxes(title_text='Channel B', row=1, col=2)
        self.fig.show()
        return None

    def extract_data(self) -> None:
        """Download the captured clip, decode it, and plot it.

        Each Modbus read fetches 123 holding registers starting at 49. The first
        word is a chunk index and is discarded; the other 122 are appended to
        ``data_array``.
        """
        self.data_array = []
        # NOTE(review): assumes the clip size is counted in 16-bit words — confirm.
        total_chunks = math.ceil(self.get_dataclip_size() / 122)
        for _ in tqdm(range(total_chunks), desc="Extracting Data", ncols=100):
            chunk = self.slave.read_registers(49, 123, functioncode=3)
            self.data_array.extend(chunk[1:])

        self.extract_header()
        self.extract_waves()
        self.extract_grooves()
        self.update_groove_array()
        self.scale_waveforms()
        self.split_channels()
        self.replace_with_nan()
        self.create_graphs()
        return None


def calculate_crc(data: List[int]) -> int:
    """Return the CRC-16/MODBUS checksum of ``data``.

    The register starts at 0xFFFF and uses the reflected polynomial 0xA001,
    shifting right one bit at a time. Each item of ``data`` is XOR-ed into
    the register as-is, so items are presumably byte values (0-255).
    """
    register = 0xFFFF
    for value in data:
        register ^= value
        for _bit in range(8):
            carry = register & 1
            register >>= 1
            if carry:
                register ^= 0xA001
    return register


def update_waveforms(data: Dict[str, List[int]]) -> Dict[str, List[int]]:
    """Run the groove update on ``data``, then divide its 'waveforms' entries by 100.

    Returns the dict produced by ``update_groove_array`` with its 'waveforms'
    list replaced by the scaled values.
    """
    # NOTE(review): this calls a module-level `update_groove_array`, but the only
    # definition visible in this file is the TwinProx method of that name — confirm
    # it is defined elsewhere, otherwise this raises NameError.
    updated = update_groove_array(data)
    updated['waveforms'] = list(map(lambda raw: raw / 100.0, updated['waveforms']))
    return updated



def create_marker_sizes(channel: np.ndarray) -> List[int]:
    """Return one marker size per sample, enlarging the points where data resumes after a gap.

    A sample gets size 20 if it is not NaN and the sample just before it is NaN.
    Every other sample gets size 1. An empty channel yields an empty list.
    """
    mask = np.isnan(np.asarray(channel, dtype=float))
    if mask.size == 0:
        return []
    # np.diff on a bool array returns a bool (not_equal) result, so comparing it
    # with -1 never matched. Detect the NaN -> value transition explicitly.
    resumes = np.concatenate(([False], mask[:-1] & ~mask[1:]))
    return [20 if flag else 1 for flag in resumes]

def create_waveform_traces(channel_A: np.ndarray, channel_B: np.ndarray) -> Tuple[go.Scatter, go.Scatter]:
    """Build two line+marker time-series traces, labelled 'Channel A' and 'Channel B'."""
    labelled = ((channel_A, 'Channel A'), (channel_B, 'Channel B'))
    return tuple(
        go.Scatter(y=samples, mode='lines+markers', name=label)
        for samples, label in labelled
    )

def create_orbit_plot_trace(channel_A: np.ndarray, channel_B: np.ndarray) -> go.Scatter:
    """Build an orbit trace (channel A on x, channel B on y).

    Marker sizes come from ``create_marker_sizes(channel_A)``; marker opacity is 0.8.
    """
    marker_style = dict(size=create_marker_sizes(channel_A), opacity=0.8)
    return go.Scatter(
        x=channel_A,
        y=channel_B,
        mode='lines+markers',
        marker=marker_style,
    )

def set_machine_saver_theme():
    """Register the Machine Saver plotly theme and make it the default.

    Defines a "draft" template with a faint, rotated "Machine Saver Inc."
    watermark centred on the figure and a default scatter marker size of 2.
    Then sets the default template to "plotly_dark+draft".
    """
    pio.templates["draft"] = go.layout.Template(
        layout_annotations=[
            dict(
                name="draft watermark",
                text="Machine Saver Inc.",
                textangle=-30,
                opacity=0.1,
                font=dict(color="white", size=50),
                xref="paper",
                yref="paper",
                x=0.5,
                y=0.5,
                # A watermark is plain text; with an arrow, plotly would offset
                # the text and draw a pointer to the figure centre.
                showarrow=False,
            )
        ],
        data_scatter=[
            dict(
                marker=dict(size=2),  # default marker size for scatter traces
            )
        ]
    )
    pio.templates.default = "plotly_dark+draft"

def main(port: str = "COM4", slaveaddress: int = 5, poll_interval: float = 0.05) -> None:
    """Trigger one capture, wait for the capture engine to finish, then download and plot it.

    Args:
        port: serial port of the Modbus RTU link.
        slaveaddress: Modbus slave address of the TwinProx.
        poll_interval: seconds to sleep between status polls (0 polls back-to-back).
    """
    twinprox = TwinProx(port=port, slaveaddress=slaveaddress)
    twinprox.trigger_capture()
    # The status register reads non-zero while the capture is still running.
    # Pause between polls instead of issuing back-to-back Modbus requests.
    while twinprox.check_capture_engine() != 0:
        time.sleep(poll_interval)
    twinprox.extract_data()

if __name__ == "__main__":
    # Capture and plot repeatedly until interrupted. The guard keeps an import of
    # this module from starting the infinite loop; a notebook kernel still runs it.
    while True:
        main()