In [None]:
import os
import random
import sys

from emulate_client import *

# CONFIGS_DIR is a directory (it is joined with the JSON file name below);
# presumably provided by `from emulate_client import *`.
print(f"CONFIG DIR {CONFIGS_DIR}")
# Root of the CIANNA on-the-fly code tree. Override it with the CIANNA_OTF_DIR
# environment variable instead of editing this absolute local path.
CIANNA_OTF_DIR = os.environ.get("CIANNA_OTF_DIR",
                                "/home/gsainton/01_CODES/CIANNA_OTF/CODE")

In [None]:

"""
Client set-up for the emulated CIANNA requests.

Steps performed by this and the following cells:
    - Load the client configuration.
    - If a remote connection is specified, establish an SSH tunnel.
    - Update the local CIANNA models XML file by always retrieving the
      latest version.
    - Locate FITS images in the designated input folder.
    - Emulate client requests to the server.

Note:
    - TODO: check that the FITS file is compatible with the YOLO model.
"""
# Load configuration from JSON file
config = load_config(os.path.join(CONFIGS_DIR, "param_cianna_rts_client.json"))

# Directory where the XML files describing the requests to the server are saved
job_subdir = config.get("JOB_DIR")
if not job_subdir:
    # os.path.join(..., None) would otherwise fail with an unhelpful TypeError
    raise KeyError("'JOB_DIR' is missing from param_cianna_rts_client.json")
JOB_DIR = os.path.join(CIANNA_OTF_DIR, job_subdir)
os.makedirs(JOB_DIR, exist_ok=True)
print(f"Job directory {JOB_DIR}")

# Path to the local CIANNA models XML file
local_models_file = config.get("LOCAL_FILE_MODELS")

# Determine connection mode and set server URL accordingly.
print(40 * "-.")
# `or "local"` also covers a key that is present but null in the JSON.
client_connexion = str(config.get("CLIENT_CONNEXION") or "local").lower()
# Same integer port in both branches (it used to be cast only for "remote").
local_port = int(config.get("LOCAL_PORT", 5000))
tunnel = None
if client_connexion == "remote":
    print("Establishing remote connection via SSH tunnel...")
    tunnel = create_ssh_tunnel(
        ssh_server_ip = config.get("SSH_SERVER_IP"),
        ssh_username  = config.get("SSH_USERNAME"),
        # Prefer the CIANNA_SSH_PASSWORD environment variable over a
        # plaintext password stored in the JSON configuration.
        ssh_password  = os.environ.get("CIANNA_SSH_PASSWORD",
                                       config.get("SSH_PASSWORD")),
        remote_port   = int(config.get("REMOTE_PORT", 5000)),
        local_port    = local_port,
    )
    server_url = f"http://127.0.0.1:{tunnel.local_bind_port}"
else:
    server_url = f"http://127.0.0.1:{local_port}"

print(f"Connecting to a {client_connexion} server...")
print("Server URL:", server_url)
print(40 * "-.")

In [None]:
from pprint import pprint
from uuid import uuid4
# Normalisation functions accepted in a request. Entries must not carry
# leading/trailing spaces: user entries are stripped before the membership
# test, so the former " squared root" entry could never match.
norm_fct_avail = ["tanh", "power", "linear", "log", "squared", "squared root"]
img_format_list = ["full", "region"]
# Pixel-value bounds written to <min_pix>/<max_pix> of the request
# (unit presumably that of the FITS data -- TODO confirm with the server).
min_pix, max_pix = 0.4e-6, 0.4e-4

# Update the local CIANNA models XML file (always retrieves the latest version)
models_url = f"{server_url}/models/CIANNA_models.xml"
updated_result = update_cianna_models(models_url, local_models_file)
if updated_result is None:
    print("Error updating CIANNA models.")
    if tunnel is not None:
        tunnel.stop()
    # A cell cannot `return`: stop here, otherwise later cells would keep
    # talking to the server after the tunnel has been stopped.
    sys.exit("Error updating CIANNA models.")

# User choices -- hard-coded for now, to be selected in a GUI later
yolo_model_user   = "SDC1_Cornu_2024"
quant_user        = "FP32C_FP32A"
norm_fct_user_raw = "tanh, linear"   # comma-separated; brackets are tolerated
user_id = uuid4()

# Strip brackets and spaces BEFORE validating, so "[tanh, linear]" is accepted.
# Parsing from a separate raw variable keeps this cell safe to re-run.
norm_fct_user = [fct.replace("[", "").replace("]", "").strip()
                 for fct in norm_fct_user_raw.split(",")]
unknown_fcts = [fct for fct in norm_fct_user if fct not in norm_fct_avail]
if unknown_fcts:
    sys.exit(f"One or more normalization functions are not available: {unknown_fcts}")

print(f"Norm functions selected by the user: {norm_fct_user}")

# "full": the full raw image is sent; "region": only the coordinates are sent
img_format = "full"
if img_format not in img_format_list:
    print("Wrong choice of image type ")
    sys.exit("Wrong choice of image type ")

# Look up the selected YOLO model in the local models file
model_info = get_model_info(config.get("LOCAL_FILE_MODELS"), yolo_model_user)

if model_info is None:
    print(40 * "-.")
    print("Unknown model, please check the models available.")
    print(40 * "-.")
    # Later cells call model_info.get(...); stop here instead of failing there.
    sys.exit(f"Unknown YOLO model: {yolo_model_user}")

pprint(model_info)

In [None]:
import os
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.wcs import WCS
from astropy.visualization import (ZScaleInterval, ImageNormalize)


def plot_fits_ra_dec(image_path):
    """
    Plot the first 2-D plane of a FITS image with RA/DEC axes (WCS projection).

    Data with more than two dimensions (e.g. cubes with extra spectral or
    Stokes axes) is reduced by repeatedly taking index 0 along the leading
    axis until a 2-D plane remains.

    Parameters
    ----------
    image_path : str
        Path to the FITS file. Only the primary HDU (``hdul[0]``) is read.

    Returns
    -------
    None
        Problems (missing file, empty HDU, data that is not 2-D after the
        reduction) are reported with ``print`` and nothing is plotted.
    """
    if not os.path.exists(image_path):
        print("FITS file does not exist:", image_path)
        return

    with fits.open(image_path) as hdul:
        data = hdul[0].data
        header = hdul[0].header

        if data is None:
            print("No data in FITS file.")
            return

        # Reduce N-D data to 2-D by keeping the first plane of each leading axis
        while data.ndim > 2:
            data = data[0]

        if data.ndim != 2:
            print(f"Cannot plot image with shape {data.shape}")
            return

        # Normalize for visualization
        norm = ImageNormalize(data, interval=ZScaleInterval())

        # Keep only the celestial (RA/DEC) axes so the WCS matches the 2-D
        # data. Unlike dropping axes only when naxis == 4, this also handles
        # 3-D cubes; headers without celestial axes keep their first two axes.
        wcs = WCS(header)
        wcs = wcs.celestial if wcs.has_celestial else wcs.sub(2)

        fig, ax = plt.subplots(figsize=(8, 6), subplot_kw={"projection": wcs})
        im = ax.imshow(data, origin="lower", cmap="gray", norm=norm)

        ax.coords.grid(True, color="white", ls="dotted")
        ax.set_xlabel("Right Ascension (J2000)")
        ax.set_ylabel("Declination (J2000)")
        fig.colorbar(im, ax=ax, orientation="vertical", label="Pixel value")
        ax.set_title(os.path.basename(image_path))
        fig.tight_layout()
        plt.show()

In [None]:

# Optional seed for reproducible runs: set "RANDOM_SEED" in the JSON config.
# When it is absent, random.seed(None) keeps the previous non-deterministic behaviour.
random.seed(config.get("RANDOM_SEED"))

# Requested bounding-box size (h, w) checked against the image in the next cell
h = random.randint(50, 200)
w = random.randint(50, 200)

# Get the list of FITS images available for the test
image_folder = os.path.expanduser(config.get("IMAGE_FOLDER",
                                             "/home/gsainton/01_CODES/DIR_images"))

print(f"Looking for images in {image_folder}...")
# Sorted so that a fixed seed always picks the same file (listdir order is arbitrary)
images = sorted(os.path.join(image_folder, img)
                for img in os.listdir(image_folder) if img.endswith(".fits"))
if not images:
    print("No fits images in ", image_folder)
    if tunnel is not None:
        tunnel.stop()
    # random.choice([]) would raise IndexError; stop explicitly instead.
    sys.exit(f"No fits images in {image_folder}")

image_path = random.choice(images)
image_info = get_image_dim(image_path)

plot_fits_ra_dec(image_path)

In [None]:
# Validate the metadata returned by get_image_dim() for the selected image.
if image_info is None:
    print(f"Error: Unable to read image dimensions from {image_path}.")
    sys.exit(1)  # non-zero: this is an error, not a normal stop

image_size = image_info.get('shape', (0, 0))
# Keep the image size apart from the requested box (h, w) drawn in the
# previous cell: overwriting h/w with the image size made the check below
# always false.
image_h, image_w = image_size[0], image_size[1]

print(image_size)

# The requested box only has to fit when a region (not the full image) is sent.
if img_format == "region" and (image_h < h or image_w < w):
    print(f"Error: Image dimensions {image_size} are smaller than the requested bounding box ({h}, {w}).")
    sys.exit(1)

image_coord = image_info.get("ra_dec", (None, None))

print(image_info)


In [None]:
import xml.etree.ElementTree as ET
from datetime import datetime
from xml.sax.saxutils import escape
from xml.dom import minidom

def create_xml_param(user_id, image_info, image_path, yolo_model,
                     quantization, norm_list, min_pix, max_pix,
                     model_filename, output_path=None):
    """
    Build the XML document holding the parameters of a request to the server.

    Parameters
    ----------
    user_id : object
        Identifier of the requesting user; written with ``str()``.
    image_info : dict
        Image metadata. Reads ``"ra_dec"`` (RA, DEC) and ``"shape"`` (H, W);
        a missing entry is written as ``None`` / ``0`` respectively.
    image_path : str
        Path of the image the request refers to.
    yolo_model : str
        Name of the selected YOLO model.
    quantization : str
        Quantization setting written to <preprocessing><quantization>.
    norm_list : list
        Normalisation functions; written with ``str()`` (Python list repr).
    min_pix, max_pix : float
        Values written to <preprocessing><min_pix>/<max_pix>.
    model_filename : str
        Value written to <YOLO_Model><filename>.
    output_path : str, optional
        If given, the pretty-printed XML is also written there (UTF-8).

    Returns
    -------
    str
        The pretty-printed XML document.
    """
    # ElementTree escapes &, < and > itself when serialising; escaping again
    # beforehand (saxutils.escape) produced double-escaped text ("&amp;amp;").
    ra_dec = image_info.get("ra_dec") or (None, None)
    shape = image_info.get("shape") or (0, 0)

    root = ET.Element("YOLO_CIANNA")

    ET.SubElement(root, "USER_ID").text = str(user_id)
    ET.SubElement(root, "Timestamp").text = datetime.now().isoformat()

    image_elem = ET.SubElement(root, "Image")
    ET.SubElement(image_elem, "path").text = str(image_path)
    ET.SubElement(image_elem, "RA").text = str(ra_dec[0])
    ET.SubElement(image_elem, "DEC").text = str(ra_dec[1])
    ET.SubElement(image_elem, "H").text = str(shape[0])
    ET.SubElement(image_elem, "W").text = str(shape[1])

    yolo_elem = ET.SubElement(root, "YOLO_Model")
    ET.SubElement(yolo_elem, "name").text = str(yolo_model)
    ET.SubElement(yolo_elem, "filename").text = str(model_filename)

    preproc_elem = ET.SubElement(root, "preprocessing")
    ET.SubElement(preproc_elem, "quantization").text = str(quantization)
    ET.SubElement(preproc_elem, "normalisation").text = str(norm_list)
    ET.SubElement(preproc_elem, "min_pix").text = str(min_pix)
    ET.SubElement(preproc_elem, "max_pix").text = str(max_pix)

    # Convert to pretty XML
    rough_string = ET.tostring(root, encoding="utf-8")
    pretty_xml = minidom.parseString(rough_string).toprettyxml(indent="  ")

    # Save if requested
    if output_path:
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(pretty_xml)

    return pretty_xml

In [None]:

# Look up the selected model again so this cell can be re-run on its own.
model_info = get_model_info(config.get("LOCAL_FILE_MODELS"), yolo_model_user)

print(f"[emulate_client_request] Model info: {model_info}")

if model_info is None:
    # model_info.get("Name") below would raise AttributeError on None
    sys.exit(f"Unknown YOLO model '{yolo_model_user}': cannot build the request.")

# NOTE(review): JOB_DIR (created in the set-up cell) is described as the place
# where request XML files are saved, but no output_path is passed here -- confirm
# whether the request should also be written to JOB_DIR.
xml_data = create_xml_param(user_id, image_info, image_path, yolo_model_user,
                            quant_user, norm_fct_user, min_pix, max_pix,
                            model_info.get("Name"))

print(xml_data)

In [None]:
# Send the request, wait for the job to finish and download its result.
# send_xml_fits_to_server, poll_for_completion, download_result, requests and
# DESTINATION_FOLDER are presumably provided by `from emulate_client import *`.
request_number = 1
process_id = None  # stays None if the request could not be sent
try:
    # Sending is inside the try so its network errors get the same handling
    # as errors raised while polling/downloading.
    process_id = send_xml_fits_to_server(server_url, xml_data)
    if process_id is None:
        print(f"[EMULATE] Error sending request {request_number}")
    else:
        print(f"[EMULATE] Request {request_number} sent successfully with process ID: {process_id}")

        # Poll for job completion
        print(f"[EMULATE] Polling for job {process_id} completion...")
        if poll_for_completion(server_url, process_id):
            print(f"[EMULATE] Job {process_id} completed successfully.")
            print(f"[EMULATE] Downloading result for job {process_id}...")
            download_result(server_url, process_id, destination_folder=DESTINATION_FOLDER)
            print(f"[EMULATE] Result for request {request_number} downloaded successfully.")
        else:
            print(f"[EMULATE] Error: Job {process_id} did not complete successfully.")
except requests.ConnectionError as e:
    print(f"[EMULATE] Network error while sending/polling/downloading: {e}")
except requests.Timeout as e:
    print(f"[EMULATE] Timeout error: {e}")
except Exception as e:
    print(f"[EMULATE] Unexpected error: {e}")

## Functions to collect the server hardware configuration and save it as XML

In [None]:
import os
import sys
import xml.etree.ElementTree as ET
from xml.dom import minidom
from pprint import pprint

import torch
import pynvml   # Initialize NVML
import psutil   # Get system memory info
import platform # Get system information
import socket   # Get hostname

def get_hw_info():
    """
    Collect a description of the current machine: system, CPU and NVIDIA GPUs.

    Returns
    -------
    dict
        ``{"System": {...}, "cpu": {...}, "gpu": [...] or None}``. ``"gpu"``
        holds one dict per NVML device (name, memory in MB, compute
        capability, supported data types), or ``None`` when NVML reports an
        error (e.g. no NVIDIA driver available).

    Notes
    -----
    ``psutil.cpu_percent(interval=1)`` blocks for about one second.
    """
    hw_info = {}

    hw_info["System"] = {
        "Hostname": socket.gethostname(),
        "OS": platform.system(),
        "OS Version": platform.version(),
        "Platform": platform.platform(),
        "Architecture": platform.machine()
    }

    # cpu_freq() may return None where the frequency cannot be determined
    # (some VMs/containers); avoid an AttributeError on `.current`.
    cpu_freq = psutil.cpu_freq()
    hw_info['cpu'] = {
        'model': platform.processor(),
        'cores': psutil.cpu_count(logical=False),
        'threads': psutil.cpu_count(logical=True),   # used by save_hardware_to_xml
        'frequency': cpu_freq.current if cpu_freq is not None else None,
        'load': psutil.cpu_percent(interval=1),
        'RAM_GB': round(psutil.virtual_memory().total / 1e9, 2)  # used by save_hardware_to_xml
    }

    gpu_info = []
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError as e:
        print(f"Error accessing GPU information: {e}")
        gpu_info = None
    else:
        try:
            for i in range(pynvml.nvmlDeviceGetCount()):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                name = pynvml.nvmlDeviceGetName(handle)
                # Older pynvml releases return bytes here; normalise to str
                if isinstance(name, bytes):
                    name = name.decode("utf-8", errors="replace")
                memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)

                # NOTE(review): assumes NVML index i and CUDA index i are the
                # same device (not guaranteed with CUDA_VISIBLE_DEVICES or a
                # non-PCI CUDA_DEVICE_ORDER) -- confirm.
                major, minor = torch.cuda.get_device_capability(i)

                gpu_info.append({
                    "index": i,
                    "name": name,                                      # GPU name
                    "total_memory": memory_info.total / (1024 ** 2),   # Total memory in MB
                    "free_memory": memory_info.free / (1024 ** 2),     # Free memory in MB
                    "compute_capability": f"{major}.{minor}",          # Compute capability
                    "support_data_type": {                             # Supported data types
                        'fp16': major >= 7,
                        'fp32': True,
                        'int8': True,
                        'bfloat16': major >= 8
                    }
                })
        except pynvml.NVMLError as e:
            print(f"Error accessing GPU information: {e}")
            gpu_info = None
        finally:
            # Release NVML even when a per-device query failed; a failing
            # shutdown is ignored, since nothing else can be done at this point.
            try:
                pynvml.nvmlShutdown()
            except pynvml.NVMLError:
                pass

    hw_info['gpu'] = gpu_info

    return hw_info

# Collect the hardware description of this machine and display it,
# keeping the System / cpu / gpu sections in insertion order.
print("Hardware Information:")
hw_info = get_hw_info()
pprint(hw_info, indent=1, width=80, sort_dicts=False)


def save_hardware_to_xml(info, output_path="server_hardware_config.xml"):
    """
    Write a hardware description (as built by ``get_hw_info``) to an XML file.

    Parameters
    ----------
    info : dict
        Must contain ``"cpu"`` (keys model, cores, threads, frequency, load,
        RAM_GB). May contain ``"System"`` (flat dict) and ``"gpu"`` (list of
        per-GPU dicts, or ``None`` when no GPU information is available).
    output_path : str
        Destination file, written as UTF-8 with an XML declaration.
    """
    root = ET.Element("MachineConfig")

    # System metadata. Keys such as "OS Version" contain spaces, which are
    # illegal in XML tag names (they produced malformed XML), so spaces
    # become underscores.
    sys_elem = ET.SubElement(root, "System")
    for key, val in info.get("System", {}).items():
        ET.SubElement(sys_elem, str(key).replace(" ", "_")).text = str(val)

    # CPU
    cpu = info["cpu"]
    cpu_elem = ET.SubElement(root, "CPU")
    for field in ("model", "cores", "threads", "frequency", "load"):
        ET.SubElement(cpu_elem, field).text = str(cpu[field])

    ram_elem = ET.SubElement(cpu_elem, "RAM", unit="GB")
    ram_elem.text = str(cpu["RAM_GB"])

    # GPU -- get_hw_info() stores None here when NVML fails; treat as no GPU
    gpus = info.get("gpu") or []
    gpu_root = ET.SubElement(root, "GPU", count=str(len(gpus)))

    for gpu in gpus:
        gpu_elem = ET.SubElement(gpu_root, "GPU", id=str(gpu["index"]))

        # Older pynvml versions return the name as bytes, which ElementTree
        # cannot serialise as text.
        name = gpu["name"]
        if isinstance(name, bytes):
            name = name.decode("utf-8", errors="replace")
        ET.SubElement(gpu_elem, "name").text = str(name)

        mem_elem = ET.SubElement(gpu_elem, "memory", unit="MB")
        mem_elem.text = str(round(gpu["total_memory"], 2))

        ET.SubElement(gpu_elem, "free_memory_MB").text = str(round(gpu["free_memory"], 2))
        ET.SubElement(gpu_elem, "compute_capability").text = str(gpu["compute_capability"])

        supports_elem = ET.SubElement(gpu_elem, "Supports")
        for dtype, supported in gpu["support_data_type"].items():
            ET.SubElement(supports_elem, dtype).text = str(supported).lower()

    # Pretty formatting
    indent_xml(root)
    ET.ElementTree(root).write(output_path, encoding="utf-8", xml_declaration=True)
    print(f"[OK] Hardware configuration saved to: {output_path}")



def indent_xml(elem, level=0):
    """
    Indent an ElementTree element in place for pretty printing.

    Two spaces are used per nesting level. Existing non-whitespace text and
    tails are left untouched; the tail of the top-level element (level 0) is
    never modified.
    """
    pad = "\n" + "  " * level

    def _blank(text):
        # True for None, "" or whitespace-only strings
        return not (text and text.strip())

    if len(elem) > 0:
        if _blank(elem.text):
            elem.text = pad + "  "
        for sub_elem in elem:
            indent_xml(sub_elem, level + 1)
        # The closing tag of `elem` follows its last child's tail
        last_child = elem[-1]
        if _blank(last_child.tail):
            last_child.tail = pad
    if level and _blank(elem.tail):
        elem.tail = pad

# Reuse the snapshot collected above instead of probing the hardware again:
# get_hw_info() blocks ~1 s in psutil.cpu_percent(interval=1) and re-initialises NVML.
save_hardware_to_xml(hw_info, output_path="server_hardware_config.xml")