In [3]:
from typing import List,Tuple,Sequence,Union

from pathlib import Path
import dataloader as dl
import xarray as xr
import sys
sys.path.append('../')
import smartscan.TCP as TCP
from importlib import reload

In [None]:
reload(TCP)
class Controller(TCP.TCPClient):
    """SGM4 controller that talks through a ``TCP.TCPClient`` connection.

    Scan metadata (``ndim``, ``limits``, ``filename``, ``current_pos``) is
    cached on the instance by ``get_scan_info``.
    """

    def __init__(
            self,
            host: str,
            port: int,
            **kwargs
            ) -> None:
        super().__init__(host,port,**kwargs)
        # Scan metadata; filled in by get_scan_info().
        self.filename = None
        self.ndim = None
        self.limits = None
        self.current_pos = None

    def send_tcp_message(self, message: str) -> None:
        """Send ``message`` to the SGM4 over the client connection.

        Args:
            message: fully formatted command string.

        Raises:
            NotImplementedError: the outgoing half of the transport is not
                wired up yet.
        """
        # TODO(review): delegate to the send method of TCP.TCPClient; only
        # receive_message() is visible from this notebook, so its name is
        # not confirmed here.
        raise NotImplementedError("send_tcp_message is not implemented yet")

    def get_scan_info(self) -> None:
        """ ask SGM4 for the scan info

        commands:
        NDIM - get number of dimensions
        LIMITS - after NDIM, as you need to know how many to expect
        FILENAME - get the filename of the scan
        CURRENT_POS - where are we starting from?
        """
        self.ndim = self.NDIM()
        self.limits = self.LIMITS()
        # NOTE(review): FILENAME and CURRENT_POS are not defined on this
        # class, so the next two lines raise AttributeError until they are.
        self.filename = self.FILENAME()
        self.current_pos = self.CURRENT_POS()

    def parse_h5_file(self, filename:str | Path) -> None:
        """Read the h5 file and get the scan info

        assert the info found in the file matches the info found in the SGM4


        Args:
            filename: path to the h5 file

        Returns:
            None
        """
        raise NotImplementedError

    def send_command(self, command, *args) -> str:
        """Send a command to the SGM4 and wait for a response.

        Args:
            command: command keyword, e.g. ``'NDIM'``.
            *args: extra arguments appended to the command, space-separated.

        Returns:
            The raw response returned by ``receive_message()``.

        Raises:
            RuntimeError: if the response contains ``"INVALID"``.
        """
        message = command
        for arg in args:
            message += f' {arg}'
        # The command must actually go out before a reply can be expected;
        # it used to be built and then dropped.
        self.send_tcp_message(message)
        response = self.receive_message()
        if "INVALID" in response:
            raise RuntimeError(f"Invalid command: {command}")
        return response

    # TCP commands
    def NDIM(self) -> int:
        """Get the number of scan dimensions (the reply parsed as an int)."""
        response = self.send_command('NDIM')
        return int(response)

    def LIMITS(self) -> List[Tuple[float, ...]]:
        """Get the scan limits: one tuple of floats per line of the reply."""
        response = self.send_command('LIMITS')
        return [tuple(float(x) for x in line.split()) for line in response.splitlines()]


# Direct controller: opens a new TCP connection for every command

In [None]:
import socket
reload(TCP)

def send_tcp_message(
        host:str,
        port:int|str,
        msg:str,
        checksum:bool=False,
        buffer_size:int=1024,
        verbose:bool=False,
        timeout:float=1.0,
    ) -> str:
    """Open a TCP connection, send one message and return the reply.

    A new socket is created for every call and closed on return.

    Args:
        host: host to connect to
        port: port to connect to; a numeric string is converted to ``int``
        msg: message to send
        checksum: add a checksum to the message with ``TCP.add_checksum``
        buffer_size: maximum number of bytes read for the reply
        verbose: print out extra information
        timeout: seconds allowed for connecting and for each blocking socket
            operation; ``None`` blocks indefinitely

    Returns:
        data: response from host, passed through ``TCP.remove_checksum``

    Raises:
        socket.timeout: if connecting or receiving exceeds ``timeout``.
    """
    # socket.connect() requires an int port for AF_INET addresses.
    port = int(port)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Apply the timeout before connect() so both connect and recv are
        # bounded; the argument used to be accepted but ignored.
        s.settimeout(timeout)
        if verbose:
            print(f"Connecting to {host}:{port}")
        s.connect((host, port))
        if checksum:
            msg = TCP.add_checksum(msg)
            if verbose:
                print(f"Sending message with checksum: {msg}")
        else:
            if verbose:
                print(f"Sending message: {msg}")
        s.sendall(msg.encode())
        if verbose:
            print("Waiting for response")
        # NOTE(review): a single recv() returns at most buffer_size bytes and
        # may return only part of a reply; loop until a terminator if the
        # protocol defines one.
        data = s.recv(buffer_size).decode()
        if verbose:
            print(f"Received: {data}")
        # NOTE(review): the checksum is stripped even when checksum=False;
        # confirm the server always appends one.
        data = TCP.remove_checksum(data)
    return data

class Controller:

    def __init__(
            self, 
            host: str, 
            port: int,
            checksum:bool=False,
            verbose:bool=True,
            timeout:float=1.0,
            buffer_size:int=1024,
            ) -> None:
        # TCP
        self.host = host
        self.port = port
        self.checksum = checksum
        self.verbose = verbose
        self.timeout = timeout
        self.buffer_size = buffer_size

        self.filename = None
        self.ndim = None
        self.limits = None
        self.current_pos = None

    def send_command(self, command, *args) -> None:
        """ send a command to the SGM4 and wait for a response"""
        message = command.upper()
        for arg in args:
            message += f' {arg}'
        response = send_tcp_message(
            host=self.host,
            port=self.port,
            msg=message,
            checksum=self.checksum,
            verbose=self.verbose,
            timeout=self.timeout,
            buffer_size=self.buffer_size,
        )
        if "INVALID" in response:
            raise RuntimeError(f"Invalid command: {command}")
        return response
    
    def get_scan_info(self) -> None:
        """ ask SGM4 for the scan info 
        
        commands:
        NDIM - get number of dimensions
        LIMITS - after NDIM, as you need to know how many to expect
        FILENAME - get the filename of the scan
        CURRENT_POS - where are we starting from?
        """
        self.ndim = self.NDIM()
        self.limits = self.LIMITS()
        self.filename = self.FILENAME()
        self.current_pos = self.CURRENT_POS()

    def parse_h5_file(self, filename:str | Path) -> None:
        """Read the h5 file and get the scan info

        assert the info found in the file matches the info found in the SGM4


        Args:
            filename: path to the h5 file

        Returns:
            None
        """
        raise NotImplementedError
        

    def LIMITS(self) -> List[Tuple[float]]:
        """ get the limits of the scan """
        response = self.send_command('LIMITS')
        split = response.split(' ')
        assert split[0] == 'LIMITS', f"Expected LIMITS, got {split[0]}"
        return [tuple(lim.split(',')) for lim in split[1:]]
    
    def NDIM(self) -> int:
        """ get the number of dimensions """
        response = self.send_command('NDIM')
        split = response.split(' ')
        assert split[0] == 'NDIM', f"Expected NDIM, got {split[0]}"
        return int(split[1])
    
    def FILENAME(self) -> str:
        """ get the filename of the scan """
        response = self.send_command('FILENAME')
        split = response.split(' ')
        assert split[0] == 'FILENAME', f"Expected FILENAME, got {split[0]}"
        return split[1]
    
    def CURRENT_POS(self) -> List[float]:
        """ get the current position """
        response = self.send_command('CURRENT_POS')
        split = response.split(' ')
        assert split[0] == 'CURRENT_POS', f"Expected CURRENT_POS, got {split[0]}"
        return [float(x) for x in split[1:]]
    
    def ADD_POINT(self, *args) -> None:
        """ add a point to the scan """
        assert len(args) == self.ndim, f"Expected {self.ndim} args, got {len(args)}"
        response = self.send_command('ADD_POINT', *args)
        split = response.split(' ')
        assert split[0] == 'ADD_POINT', f"Expected ADD_POINT, got {split[0]}"
        return None
    

In [None]:
# Controller for an SGM4 server listening on the local machine.
ctrl = Controller(host='localhost', port=12345)

In [None]:
# Query NDIM, LIMITS, FILENAME and CURRENT_POS and cache them on ctrl.
ctrl.get_scan_info()

In [None]:
# Show the scan metadata cached by get_scan_info(), one item per line.
print('\n'.join(
    f'{label}: {value}'
    for label, value in (
        ('current position', ctrl.current_pos),
        ('limits', ctrl.limits),
        ('filename', ctrl.filename),
        ('ndim', ctrl.ndim),
    )
))

In [None]:
# Request one new point; ADD_POINT expects one coordinate per scan dimension.
ctrl.ADD_POINT(200, 200)

In [None]:
# Query the SGM4 for its current position (a list of floats).
ctrl.CURRENT_POS()