In [None]:
import pandas as pd
import numpy as np

import msgpack
import base64
import blosc

blosc.set_nthreads(2)

import time
from tenacity import *
from logzero import logger

from redis.cluster import RedisCluster, ClusterNode

# Startup (seed) nodes handed to the RedisCluster client; all listen on 6379.
redis_nodes = [
    ClusterNode(host, 6379)
    for host in (
        "10.4.208.2",
        "10.4.208.5",
        "10.4.208.6",
        "10.4.208.7",
        "10.4.208.28",
        "10.4.208.30",
        "10.4.208.31",
    )
]
# decode_responses=False keeps replies as raw bytes; the readers in this
# notebook call .decode() on hash field names themselves.
rc = RedisCluster(startup_nodes=redis_nodes, decode_responses=False)


def return_none(retry_state):
    """tenacity ``retry_error_callback`` that gives up quietly.

    Invoked once the retry budget is exhausted; its return value (``None``)
    becomes the result of the decorated call instead of an exception.
    """
    del retry_state  # unused
    return


class Data_uri:
    """Encode/decode values as self-describing ``data:`` byte strings.

    Layout: ``b"data:<type>:<enc1>:<enc2>...,<payload>"`` where ``<type>`` is
    the description of the original Python type (see ``type_desc``) and
    ``<enc*>`` are the encoders applied, in order, by ``encode``. ``decode``
    undoes the encoders in reverse order and finally applies the type
    converter when ``decoder_map`` has one for ``<type>``.
    """

    # encoder name -> callable applied by ``encode``
    encoder_map = {
        "base64": base64.b64encode,
        "msgpack": msgpack.packb,
        "blosc": lambda msg: blosc.compress(msg, clevel=1, cname="zstd"),
        "blosc_array": lambda data: blosc.pack_array(data, cname="zstd", clevel=1),
    }
    # encoder / type-description name -> callable applied by ``decode``.
    # The "Text/..." and "Numeric/..." entries restore the scalar type after
    # the payload has been unpacked.
    decoder_map = {
        "base64": base64.b64decode,
        "msgpack": msgpack.unpackb,
        "blosc": blosc.decompress,
        "blosc_array": blosc.unpack_array,
        "Text/string": lambda msg: msg.decode() if type(msg) is bytes else msg,
        "Text/bytes": lambda msg: msg.encode() if type(msg) is str else msg,
        "Numeric/int": int,
        "Numeric/float": float,
    }
    # exact Python type -> type description written into the header
    type_desc = {
        str: "Text/string",
        int: "Numeric/int",
        float: "Numeric/float",
        bytes: "Text/bytes",
        bool: "Bool",
        list: "List",
        tuple: "Tuple",
        dict: "Hash",
        np.ndarray: "Tensor",
    }

    def encode(self, data, encoding="msgpack"):
        """Encode ``data`` into a ``data:`` byte string.

        Args:
            data: value whose exact type is a key of ``type_desc``.
            encoding: comma-separated encoder names from ``encoder_map``,
                applied left to right. For numpy arrays the default
                ``"msgpack"`` is replaced by ``"blosc_array"``.

        Returns:
            bytes: ``b"data:<type>:<enc...>," + payload``.

        Raises:
            KeyError: unsupported ``type(data)`` or unknown encoder name.
        """
        original_type = self.type_desc[type(data)]
        if original_type == "Tensor" and encoding == "msgpack":
            encoding = "blosc_array"

        header = ["data", original_type]
        for name in encoding.replace(" ", "").split(","):
            data = self.encoder_map[name](data)
            header.append(name)

        return ":".join(header).encode() + b"," + data

    def decode(self, msg):
        """Decode a message produced by ``encode``.

        Args:
            msg: ``bytes`` or ``str`` message. Any other type is logged as a
                warning and returned unchanged.

        Returns:
            The decoded value. Header entries without a ``decoder_map`` entry
            (e.g. "Bool", "List", "Tuple", "Hash", "Tensor") are skipped.

        Raises:
            ValueError: ``msg`` does not start with ``data:`` or has no ``,``
                separating the header from the payload.
        """
        if isinstance(msg, str):
            # Work on bytes from here on (the header check below is bytes-based).
            msg = msg.encode()
        elif not isinstance(msg, bytes):
            logger.warning("Not supported data type: {}".format(type(msg)))
            return msg

        if not msg.startswith(b"data:"):
            raise ValueError(
                "Data_uri encoded message should starts with b'data:' but this message starts with {}".format(
                    msg[:5]
                )
            )

        data_ptr = msg.find(b",")
        if data_ptr < 0:
            # Without this guard find() == -1 would mis-slice header and payload.
            raise ValueError(
                "Data_uri encoded message has no ',' separating header and payload"
            )

        header = msg[:data_ptr].decode().split(":")
        data = msg[data_ptr + 1 :]
        # Undo the encoders last-applied-first; header[1] (the type
        # description) is therefore applied last.
        for decoder in reversed(header[1:]):
            if decoder in self.decoder_map:
                data = self.decoder_map[decoder](data)

        return data


# Shared codec instance used by the Redis readers in this notebook.
data_uri = Data_uri()


@retry(stop=stop_after_attempt(10), retry_error_callback=return_none)
def get_data_from_redis(key, add_device_name=False):
    """Read one Redis hash and return its decoded fields as a sorted Series.

    Decorated with tenacity: up to 10 attempts; after the last failure
    ``return_none`` makes the call return ``None`` instead of raising.

    Args:
        key: Redis hash key (``str`` or ``bytes``).
        add_device_name: when not ``False``, every field whose name does not
            start with ``timestamp`` is prefixed with ``"<p1>_<p2>_<p3>_"``
            built from ``key.split(":")[1:4]``.

    Returns:
        pd.Series of decoded values sorted by field name and named by the
        integer millisecond timestamp derived from the ``timestamp`` field,
        or ``None`` once all attempts fail.

    Raises (each one triggers a retry):
        IOError: the hash is empty / missing.
        KeyError: the hash has no ``timestamp`` field.
        Any decoding error from ``data_uri.decode``.
    """
    raw = rc.hgetall(key)
    if not raw:
        time.sleep(0.1)
        logger.warning("Broken source, Cannot get_data from {}".format(key))
        raise IOError("Broken source, Cannot get_data")

    if add_device_name is False:
        data = {k.decode(): data_uri.decode(v) for k, v in raw.items()}
    else:
        if isinstance(key, bytes):
            key = key.decode()
        device_name = "_".join(key.split(":")[1:4])
        try:
            data = {
                (
                    k.decode()
                    if k.startswith(b"timestamp")
                    else "_".join([device_name, k.decode()])
                ): data_uri.decode(v)
                for k, v in raw.items()
            }
        except Exception:
            # Re-raise: continuing with the undecoded, bytes-keyed hash can
            # never succeed (the "timestamp" lookup below would fail), so let
            # tenacity retry with the real cause instead.
            logger.debug(
                "Looks like decoding the data from {} failed.".format(key),
                exc_info=True,
            )
            raise

    series = pd.Series(data).sort_index()
    if "timestamp" not in series.index:
        logger.warning(
            "No 'timestamp' field in {}: {}".format(key, list(series.index))
        )
        raise KeyError("timestamp")

    timestamp = series["timestamp"]
    if isinstance(timestamp, bytes):
        timestamp = timestamp.decode()
    # NOTE(review): the digits of "<sec>.<frac>" are read as nanoseconds, i.e.
    # this assumes a 9-digit fraction such as "1714125361.195918606" — TODO
    # confirm the stored format. Integer floor division to milliseconds is
    # exact, unlike the float ``/ 1e6`` which can round up across a ms boundary.
    timestamp_ms = int(timestamp.replace(".", "")) // 10**6
    return series.rename(timestamp_ms)


def get_raw_data(device_key, timestamp=None, add_device_name=False):
    """Fetch the raw-data hash recorded for ``device_key`` at ``timestamp``.

    Steps:
      1. Resolve ``timestamp`` (default: the latest value stored at
         ``timestamp:last:raw_data:<device_key>``).
      2. Poll every 0.1 s until that latest timestamp is at least 5 s past the
         requested one; give up if it stays at its initial value for more
         than 10 polls.
      3. Look up the hash key in the sorted set
         ``timestamp:raw_data:<device_key>:<int(timestamp)>`` with scores
         within +/-0.005 of ``timestamp`` (up to 11 queries, 0.1 s apart).
      4. Read and decode that hash with ``get_data_from_redis``.

    Args:
        device_key: device identifier, ``str`` or ``bytes``
            (e.g. "detector:eh1:jungfrau2").
        timestamp: requested time, anything ``float()`` accepts; ``None``
            means the latest available.
        add_device_name: forwarded to ``get_data_from_redis``.

    Returns:
        The pd.Series from ``get_data_from_redis``, or ``None`` (with a logged
        warning) when the device timestamp is missing/unreadable, not
        updating, or no matching entry shows up.
    """
    # Decode first: formatting bytes into the key templates would yield "b'...'".
    if isinstance(device_key, bytes):
        device_key = device_key.decode()
    last_ts_key = "timestamp:last:raw_data:{}".format(device_key)

    def read_last_ts():
        # AttributeError here (None.decode()) means the key does not exist.
        return float(rc.get(last_ts_key).decode())

    try:
        timestamp = read_last_ts() if timestamp is None else float(timestamp)

        last_ts = read_last_ts()
        last_ts_init = last_ts
        not_update_count = 0
        # NOTE(review): presumably waits so that everything for ``timestamp``
        # has been written before reading — TODO confirm the 5 s margin.
        while timestamp > last_ts - 5:
            time.sleep(0.1)
            last_ts = read_last_ts()
            not_update_count += last_ts_init == last_ts
            if not_update_count > 10:
                logger.warning(
                    "Data: {} looks like that is not updating...".format(device_key)
                )
                return None
    except Exception:
        logger.warning(
            "Data: {} looks like that is not in the redis".format(device_key)
        )
        return None

    zset_key = "timestamp:raw_data:{}:{}".format(device_key, int(timestamp))
    max_attempts = 11  # same number of ZRANGEBYSCORE queries as before
    for attempt in range(max_attempts):
        if attempt:
            # Give the writer time to publish instead of busy-polling.
            time.sleep(0.1)
        key_list = rc.zrangebyscore(zset_key, timestamp - 0.005, timestamp + 0.005)
        if key_list:
            # Checked before giving up, so a hit on the last attempt is kept.
            break
    else:
        logger.warning(
            "Data: {}:{} looks like that is not yet...".format(device_key, timestamp)
        )
        return None

    return get_data_from_redis(key_list[0], add_device_name)

In [None]:
import matplotlib.pyplot as plt

In [None]:
jungfrau = get_raw_data("detector:eh1:jungfrau2")
# get_raw_data returns None on failure; stop here with a clear message rather
# than an opaque TypeError/AttributeError in the cells below.
assert jungfrau is not None, (
    "get_raw_data returned None for detector:eh1:jungfrau2 (see logged warnings)"
)
jungfrau

In [None]:
# Field name must match the Redis hash field exactly (spelling included).
jungfrau.loc["dark_bkg_substrated"]

In [None]:
# Zoomed view of the jungfrau2 image; ROI bounds and colour limits are the
# original hand-picked values.
fig, ax = plt.subplots()
im = ax.imshow(jungfrau["image"][350:450, 600:800], vmin=-2, vmax=20)
fig.colorbar(im, ax=ax, label="pixel value")
ax.set(
    title="jungfrau2 image, rows 350:450, cols 600:800",
    xlabel="column (pixels from 600)",
    ylabel="row (pixels from 350)",
)
plt.show()

In [None]:
# Pixel-value distribution in [-2, 2]: 41 edges -> 40 bins of width 0.1,
# log-scaled counts.
fig, ax = plt.subplots()
ax.hist(np.ravel(jungfrau["image"]), bins=np.linspace(-2, 2, 41))
ax.set_yscale("log")
ax.set(title="jungfrau2 pixel-value histogram", xlabel="pixel value", ylabel="count")
plt.show()

In [None]:
# NOTE(review): this cell repeats the previous histogram cell exactly;
# consider deleting one of them.
fig, ax = plt.subplots()
ax.hist(jungfrau.image.flatten(), bins=np.linspace(start=-2, stop=2, num=41))
ax.set_yscale("log")
plt.show()

In [None]:
# Latest raw-data timestamp stored for jungfrau3 (raw bytes, or None if absent).
rc.get("timestamp:last:raw_data:detector:eh1:jungfrau3")

In [None]:
# NOTE(review): `qbpm` is not defined in any cell visible in this notebook, so
# this raises NameError after "Restart & Run All". Presumably it held the
# result of a get_raw_data(...) call for a QBPM device — TODO add that cell.
qbpm["timestamp_info.XFEL_HX_BEAM"]

In [None]:
# Difference of two nanosecond-resolution timestamps. As float literals these
# 19-digit values exceed float64 precision (~16 significant digits), losing
# the sub-microsecond digits; subtracting integer nanoseconds is exact.
frame_interval_s = (1714125361_195918606 - 1714125361_179272456) / 1e9
frame_interval_s

In [None]:
# Period of a 60 Hz repetition rate, in seconds (compare with the timestamp
# difference computed in the previous cell).
period_60hz_s = 1 / 60
period_60hz_s

In [None]:
# Field names containing "RATE_".
list(filter(lambda field: "RATE_" in field, jungfrau.index))

In [None]:
# Redis SCAN may return the same key more than once, so de-duplicate; sort for
# a stable listing and use rich display instead of print().
jungfrau_s1_keys = sorted(set(rc.scan_iter("*measurement*detector:eh1:jungfrauS1*")))
jungfrau_s1_keys