In [1]:
import socket
import numpy as np

class DataStream:
    """Iterator over blocks of neural-signal samples received over TCP.

    Iterating the object connects to ``(IP, PORT)``, sends ``b'start'`` and
    then yields lists of ``data_group_len`` time points. Each time point is a
    list of ``used_channels`` floats: the first ``used_channels`` of the
    ``total_channels`` hardware channels.

    Assumed wire format: one packet per newline-terminated CSV line, e.g. for
    the defaults (total_channels=8, pkg_groups=5):
        timestamp, marker, ch0_t0, ch1_t0, ..., ch7_t0, ch0_t1, ..., ch7_t4
        -> 2 + pkg_groups * total_channels fields in total
    """

    def __init__(self, IP, PORT,
                 buffer_size=1024,
                 total_channels=8,      # total number of hardware channels
                 used_channels=7,       # channels actually used (the first N)
                 pkg_groups=5,          # time points contained in each network packet
                 data_group_len=250):   # time points returned by each __next__
        """Store connection and packet-format settings; no network I/O happens here.

        Args:
            IP: server host name or address.
            PORT: server TCP port.
            buffer_size: maximum number of bytes read per ``recv`` call.
            total_channels: channels present in every time point of a packet.
            used_channels: leading channels kept from each time point.
            pkg_groups: time points carried by one packet (CSV line).
            data_group_len: time points returned by each ``__next__``.

        Raises:
            ValueError: if any size is < 1 or ``used_channels > total_channels``.
        """
        # Set these first so close()/__del__ work even if validation fails.
        self.is_running = False
        self.socket = None

        for name, value in (("buffer_size", buffer_size),
                            ("total_channels", total_channels),
                            ("used_channels", used_channels),
                            ("pkg_groups", pkg_groups),
                            ("data_group_len", data_group_len)):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        if used_channels > total_channels:
            raise ValueError(
                f"used_channels ({used_channels}) cannot exceed "
                f"total_channels ({total_channels})")

        self.ip = IP
        self.port = PORT
        self.buffer_size = buffer_size
        self.total_channels = total_channels
        self.used_channels = used_channels
        self.pkg_groups = pkg_groups
        self.data_group_len = data_group_len

        self._buffer_bytes = b""   # received bytes not yet parsed (unterminated tail line)
        self._data_buffer = []     # parsed time points, each a list of used_channels floats

    def __iter__(self):
        """(Re)connect to the server and reset all buffers.

        An already-running stream is closed first, so iterating again starts
        a fresh connection.
        """
        if self.is_running:
            self.close()
        self.is_running = True
        self._connect()
        self._buffer_bytes = b""
        self._data_buffer = []
        return self

    def _connect(self):
        """Open the TCP connection (5 s timeout) and send the ``start`` command.

        Raises:
            Exception: whatever socket creation/connect/send raised; the
                half-open socket is closed and ``is_running`` is cleared first.
        """
        try:
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.settimeout(5)
            self.socket.connect((self.ip, self.port))
            # sendall: send() may transmit only part of the command.
            self.socket.sendall(b'start')
            print(f"[DataStream] Connected to {self.ip}:{self.port}")
        except Exception as e:
            print(f"[DataStream] Connection Error: {e}")
            self.close()  # release the half-open socket; clears is_running
            raise

    def close(self):
        """Stop the stream and close the socket (best effort, idempotent)."""
        self.is_running = False
        # getattr: also safe on an instance whose __init__ did not complete.
        sock = getattr(self, "socket", None)
        if sock is not None:
            try:
                sock.close()
            except OSError:
                pass  # closing is best effort
            self.socket = None

    def __next__(self):
        """Return the next ``data_group_len`` time points.

        Blocks until enough complete packets have arrived. Read timeouts are
        retried indefinitely. Surplus time points are kept for the next call.

        Raises:
            StopIteration: if the stream is not running, or after any
                receive/parse error (including the server closing the
                connection); in the error case the stream is closed first.
        """
        if not self.is_running:
            raise StopIteration

        while len(self._data_buffer) < self.data_group_len:
            try:
                chunk = self.socket.recv(self.buffer_size)
                if not chunk:
                    raise ConnectionError("Server closed connection.")
                self._consume(chunk)
            except socket.timeout:
                continue  # no data within the socket timeout; keep waiting
            except Exception as e:
                print(f"[DataStream] Receive/Parse Error: {e}")
                self.close()
                raise StopIteration from e

        result = self._data_buffer[:self.data_group_len]
        del self._data_buffer[:self.data_group_len]  # keep the surplus for later
        return result

    def _consume(self, chunk):
        """Append raw bytes and parse every complete line now in the buffer."""
        self._buffer_bytes += chunk
        complete, newline, tail = self._buffer_bytes.rpartition(b'\n')
        if not newline:
            return  # no complete line yet
        # Lines are split as bytes and decoded only once complete, so a
        # multi-byte UTF-8 character cut by a recv() boundary is reassembled
        # instead of being dropped by errors='ignore'.
        self._buffer_bytes = tail
        for raw_line in complete.split(b'\n'):
            rows = self._parse_line(raw_line.decode('utf-8', errors='ignore'))
            if rows is not None:
                self._data_buffer.extend(rows)

    def _parse_line(self, line):
        """Parse one CSV packet line into ``pkg_groups`` time points.

        Returns:
            A list of ``pkg_groups`` lists of ``used_channels`` floats, or
            ``None`` for blank, too-short or non-numeric lines (skipped).
        """
        line = line.strip()
        if not line:
            return None
        n_values = self.pkg_groups * self.total_channels
        fields = line.split(',')
        if len(fields) < 2 + n_values:
            return None  # incomplete packet
        try:
            # Skip the first two fields (timestamp, marker).
            values = list(map(float, fields[2:2 + n_values]))
        except ValueError:
            return None  # conversion failed
        # (pkg_groups, total_channels) -> keep the first used_channels columns.
        arr = np.array(values, dtype=np.float32).reshape(
            self.pkg_groups, self.total_channels)
        return arr[:, :self.used_channels].tolist()

    def __del__(self):
        self.close()

In [None]:
# Demo: consume the online stream block by block until interrupted (Ctrl+C).
stream = DataStream(
    IP="127.0.0.1", PORT=9600,
    total_channels=8, used_channels=7,
    pkg_groups=5, data_group_len=250,
)

try:
    for block in stream:  # each block: 250 time points x 7 channels
        samples = np.asarray(block)  # shape (250, 7)
        print("Received:", samples.shape)
        # Process `samples` here...
except KeyboardInterrupt:
    pass  # Ctrl+C ends the demo quietly
finally:
    stream.close()