In [None]:
import os
import cv2
import torch
import numpy as np
import sys
import shutil
from datetime import datetime
import glob
import gc
import time
# import spaces         # only for web demo

from pi3.utils.geometry import se3_inverse, homogenize_points, depth_edge
from pi3.models.pi3 import Pi3
from pi3.utils.basic import load_images_as_tensor
# import torch._dynamo
# torch._dynamo.config.accumulated_cache_size_limit = 10240



In [None]:
# Single-device benchmark: load Pi3, run it `num_run` times on one image set
# and report the mean latency of the fastest runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Pi3.from_pretrained("yyfz233/Pi3").to(device).eval()
model = torch.compile(model)

target_dir = "/root/repos/Pi3/input_images_20250814_090447_171967"
imgs = load_images_as_tensor(os.path.join(target_dir, "images"), interval=1, PIXEL_LIMIT=255000).to(device)

print("Running model inference...")
dtype = torch.bfloat16
num_run = 10
times = []
# Autocast and synchronisation follow `device`, so the CPU fallback chosen
# above is actually usable instead of failing on a hard-coded "cuda:0".
with torch.no_grad(), torch.amp.autocast(device, dtype=dtype):
    for run_idx in range(num_run):
        # CUDA kernels are launched asynchronously: synchronise before and
        # after the forward pass so the wall-clock interval covers the work.
        if device == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        predictions = model(imgs[None])  # add the batch dimension
        if device == "cuda":
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        times.append(t1 - t0)
        if run_idx == 0:
            # Printed outside the timed region so I/O does not skew the measurement.
            print(f"imgs[None] shape: {imgs[None].shape}")

# Drop the 5 slowest runs (the first one includes torch.compile warm-up).
times_sorted = sorted(times)
times_filtered = times_sorted[:-5] if len(times_sorted) > 5 else times_sorted
avg_time = sum(times_filtered) / len(times_filtered)
print(f"Average inference time over {len(times_filtered)} runs (top 5 removed): {avg_time:.4f} seconds")

In [None]:
def run_inference(img_tensors):
    """Benchmark the global `model` on `img_tensors` and print the mean latency.

    Relies on the notebook globals `model`, `dtype` and `num_run` defined in
    the benchmark cell. The forward pass is executed `num_run` times on
    "cuda:0"; the 5 slowest runs are discarded before averaging.

    Args:
        img_tensors: image tensor as returned by `load_images_as_tensor`;
            a leading batch dimension is added before calling the model.

    Returns:
        The predictions of the last run, or None when `num_run` is not positive.
    """
    # Use the argument (the original silently benchmarked the global `imgs`)
    # and move it to the device once, outside the timed loop.
    batch = img_tensors[None].to("cuda:0")
    print(f"imgs[None] shape: {batch.shape}")

    # Local list: appending to the global `times` mixed in measurements from
    # earlier cells, including the compilation warm-up run.
    run_times = []
    predictions = None
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=dtype):
        for _ in range(num_run):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            predictions = model(batch)
            torch.cuda.synchronize()
            run_times.append(time.perf_counter() - t0)

    if not run_times:
        # num_run <= 0: nothing was measured, avoid dividing by zero.
        print("No inference runs were executed.")
        return None

    # Remove the 5 slowest runs when there are enough samples.
    times_sorted = sorted(run_times)
    times_filtered = times_sorted[:-5] if len(times_sorted) > 5 else times_sorted
    avg_time = sum(times_filtered) / len(times_filtered)
    print(f"Average inference time over {len(times_filtered)} runs (top 5 removed): {avg_time:.4f} seconds")
    return predictions

In [None]:
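# NOTE: this cell only works when run as a script (e.g. test.py). In Jupyter it
# fails because the spawned child processes cannot look up `worker` on the
# notebook's `__main__` module (see the AttributeError in the cell output).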


import multiprocessing
import time
import random
from concurrent.futures import ProcessPoolExecutor
from itertools import count
from pi3.utils.geometry import se3_inverse, homogenize_points, depth_edge
from pi3.models.pi3 import Pi3
from pi3.utils.basic import load_images_as_tensor
import os
import sys
import torch

# Configuration for the multi-GPU benchmark and the client/server simulation.
NUM_GPUS = 3                       # number of GPU workers
PROCESSING_TIME_MIN_MS = 50        # lower bound of the simulated processing time
PROCESSING_TIME_MAX_MS = 200       # upper bound of the simulated processing time
CLIENT_REQUEST_INTERVAL_MS = 30    # pause between two client requests
TOTAL_REQUESTS = 20000             # number of requests the client sends


# One CUDA device string per GPU worker: "cuda:0", "cuda:1", "cuda:2".
device_list = [f"cuda:{gpu_index}" for gpu_index in range(NUM_GPUS)]
# model_list = [
#     torch.compile(Pi3.from_pretrained("yyfz233/Pi3").to(device_list[i]).eval())
#     for i in range(NUM_GPUS)
# ]
# The sequential loading above is replaced by per-process loading in worker(),
# so that every GPU loads its parameters in parallel.

import torch, torch.multiprocessing as mp, os

model_dict = {}  # per-process cache: rank -> compiled model

def worker(rank, data):
    """Benchmark Pi3 inference on GPU `rank`; intended to run in its own process.

    Loads and compiles the model on first use for this rank, runs 10 forward
    passes on `data` and writes the mean latency of the fastest runs (the 5
    slowest are discarded) to stdout.

    Args:
        rank: index into `device_list`, also used as the CUDA device index.
        data: image tensor; a leading batch dimension is added before the
            forward pass.
    """
    # With the "spawn" start method the child re-imports the main module but
    # does not execute its `if __name__ == "__main__"` body, so settings made
    # there never reach the process that compiles and runs the model.
    torch._dynamo.config.capture_scalar_outputs = True
    torch.set_float32_matmul_precision('high')

    torch.cuda.set_device(rank)
    if model_dict.get(rank) is None:
        sys.stdout.write(f"Loading model on gpu-{rank}\n")
        sys.stdout.flush()
        model_dict[rank] = torch.compile(
            Pi3.from_pretrained("yyfz233/Pi3").to(device_list[rank]).eval()
        )
    model = model_dict[rank]
    data = data.to(device_list[rank])
    dtype = torch.bfloat16
    num_run = 10
    times = []
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=dtype):
        for _ in range(num_run):
            # Synchronise around the forward pass: CUDA launches are async.
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            model(data[None])  # add batch dimension
            torch.cuda.synchronize()
            times.append(time.perf_counter() - t0)

    # Remove the 5 slowest runs (the first one includes compilation).
    times_sorted = sorted(times)
    times_filtered = times_sorted[:-5] if len(times_sorted) > 5 else times_sorted
    avg_time = sum(times_filtered) / len(times_filtered)
    sys.stdout.write(f"GPU-{rank} Average inference time over {len(times_filtered)} runs (top 5 removed): {avg_time:.4f} seconds\n")
    sys.stdout.flush()

    sys.stdout.write(f"Processed data on gpu-{rank}\n")
    sys.stdout.flush()


if __name__ == "__main__":
    # force=True: set_start_method raises RuntimeError when a start method has
    # already been chosen (e.g. by another cell/section or a re-run).
    mp.set_start_method("spawn", force=True)
    torch._dynamo.config.capture_scalar_outputs = True
    torch.set_float32_matmul_precision('high')
    world_size = torch.cuda.device_count()
    # Start only as many workers as there are visible GPUs *and* configured
    # devices; iterating over world_size alone indexes past data_list when
    # more than NUM_GPUS devices are present.
    num_workers = min(world_size, NUM_GPUS, len(device_list))
    target_dir = "/root/repos/Pi3/input_images_20250814_090447_171967"

    procs = []
    if num_workers > 0:
        # Read the images from disk once, then copy them to every device.
        imgs = load_images_as_tensor(os.path.join(target_dir, "images"), interval=1, PIXEL_LIMIT=255000)
        data_list = [imgs.to(device_list[i]) for i in range(num_workers)]

        for rank in range(num_workers):
            p = mp.Process(target=worker, args=(rank, data_list[rank]))
            p.start()
            procs.append(p)

        for p in procs:
            p.join()
    else:
        print("No CUDA devices available; no workers started.")

    # A crashed child only shows up in its exit code; do not report success
    # when workers died (e.g. when `worker` cannot be unpickled).
    failed_ranks = [rank for rank, p in enumerate(procs) if p.exitcode != 0]
    if failed_ranks:
        print(f"Workers failed on rank(s) {failed_ranks}.")
    else:
        print("All finished.")





  from .autonotebook import tqdm as notebook_tqdm


Loading images from directory: /root/repos/Pi3/input_images_20250814_090447_171967/images
Found 2 images/frames. Processing...
All images will be resized to a uniform size: (672, 378)
Loading images from directory: /root/repos/Pi3/input_images_20250814_090447_171967/images
Found 2 images/frames. Processing...
All images will be resized to a uniform size: (672, 378)
Loading images from directory: /root/repos/Pi3/input_images_20250814_090447_171967/images
Found 2 images/frames. Processing...
All images will be resized to a uniform size: (672, 378)
All finished.


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/envs/pi3/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pi3/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'worker' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/envs/pi3/lib/python3.11/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pi3/lib/python3.11/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'worker' on <module

: 

In [None]:
# NOTE(review): `run_model_on_gpu` is not defined in any visible cell of this
# notebook; presumably it came from a deleted cell. TODO confirm or remove,
# otherwise this raises NameError on a fresh kernel.
run_model_on_gpu(0)

In [None]:
def get_time_ms():
    """Return the current wall-clock time as whole milliseconds (int)."""
    now_ms = time.time() * 1000
    return int(now_ms)

def gpu_worker(task_id, gpu_id, imgs):
    """Simulate processing a set of images on one GPU.

    Sleeps for a random duration between PROCESSING_TIME_MIN_MS and
    PROCESSING_TIME_MAX_MS; `imgs` is accepted but not inspected.

    Returns:
        A dict with the task id, the gpu id, the status "processed" and the
        start/end timestamps plus their difference in milliseconds.
    """
    started_at = get_time_ms()
    simulated_ms = random.randint(PROCESSING_TIME_MIN_MS, PROCESSING_TIME_MAX_MS)
    time.sleep(simulated_ms / 1000.0)
    finished_at = get_time_ms()

    return dict(
        task_id=task_id,
        gpu_id=gpu_id,
        status="processed",
        start_processing_time=started_at,
        end_processing_time=finished_at,
        processing_duration_ms=finished_at - started_at,
    )

def server_process(request_queue, result_queue):
    """Server loop: dispatch incoming tasks to a pool of GPU workers.

    Reads `(task_id, image_data, client_send_time)` tuples from
    `request_queue` until the string "STOP" arrives, submits each task to a
    ProcessPoolExecutor and forwards every outcome (result dict or error dict)
    to `result_queue`. Tasks still running when "STOP" is received are waited
    for and reported before the pool is shut down.

    Args:
        request_queue: queue the client puts tasks and the final "STOP" on.
        result_queue: queue the results are written to.
    """
    import queue  # function-scope import: only needed for queue.Empty

    poll_interval_s = 0.005  # how long to block on the request queue per loop
    gpu_pool = ProcessPoolExecutor(max_workers=NUM_GPUS)
    gpu_iterator = count()
    futures = {}  # future -> (task_id, server_receive_time)

    def _publish(future):
        """Send the outcome of `future` to the client (blocks until it is done)."""
        task_id, server_receive_time = futures.pop(future)
        try:
            result = future.result()
        except Exception as e:
            print(f"[Server] Fehler bei der Verarbeitung von Aufgabe: {e}")
            # Report the failure to the client so it can count the task.
            result_queue.put({
                "task_id": task_id,
                "status": "error",
                "error_message": str(e),
            })
            return
        result["server_receive_time"] = server_receive_time
        result["server_send_time"] = get_time_ms()
        result_queue.put(result)

    try:
        while True:
            # Block briefly instead of spinning on Queue.empty().
            try:
                task = request_queue.get(timeout=poll_interval_s)
            except queue.Empty:
                task = None

            if task is not None:
                if isinstance(task, str) and task == "STOP":
                    break

                task_id, image_data, client_send_time = task
                server_receive_time = get_time_ms()

                # Round-robin GPU assignment; gpu_worker expects
                # (task_id, gpu_id, imgs), so gpu_id must be passed along.
                gpu_id = next(gpu_iterator) % NUM_GPUS
                future = gpu_pool.submit(gpu_worker, task_id, gpu_id, image_data)
                futures[future] = (task_id, server_receive_time)

            # Forward everything that has finished in the meantime.
            for future in [f for f in futures if f.done()]:
                _publish(future)

        # "STOP" received: in-flight tasks must still be reported, otherwise
        # the client never gets all of its answers.
        for future in list(futures):
            _publish(future)
    finally:
        gpu_pool.shutdown()
    print("[Server] Server wurde heruntergefahren.")

def client_process(request_queue, result_queue):
    """Client process: send requests at a fixed rate and report the results.

    Puts TOTAL_REQUESTS tasks of the form `(task_id, imgs, client_send_time)`
    on `request_queue`, one every CLIENT_REQUEST_INTERVAL_MS, and prints every
    result read from `result_queue`. Results are consumed while requests are
    still being sent so that the receive timestamp reflects the real arrival
    time. "STOP" is sent only after all answers have been received.

    Args:
        request_queue: queue the server reads tasks from.
        result_queue: queue the server writes result/error dicts to.
    """
    import queue  # function-scope import: only needed for queue.Empty

    print("[Client] Client wird gestartet...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    target_dir = "/root/repos/Pi3/input_images_20250814_090447_171967"
    imgs = load_images_as_tensor(os.path.join(target_dir, "images"), interval=1, PIXEL_LIMIT=255000).to(device)

    send_times = {}  # task_id -> client_send_time, for end-to-end latency
    completed_requests = 0
    interval_s = CLIENT_REQUEST_INTERVAL_MS / 1000.0

    def _report(result):
        """Print one result together with its end-to-end duration."""
        client_receive_time = get_time_ms()
        sent_at = send_times.pop(result['task_id'], None)

        print("\n" + "="*50)
        print(f"[{client_receive_time}ms] [Client] Ergebnis für {result['task_id']} erhalten.")

        if result['status'] == 'processed':
            # Total duration is measured from the moment the client sent the
            # request; fall back to the processing start if it is unknown.
            reference = sent_at if sent_at is not None else result['start_processing_time']
            total_duration = client_receive_time - reference
            print(f"  - GPU-Verarbeitungsdauer: {result['processing_duration_ms']}ms")
            print(f"  - Gesamtdauer:              {total_duration}ms")
        else:
            print(f"  - Fehler bei der Verarbeitung: {result['error_message']}")
        print("="*50 + "\n")

    for i in range(TOTAL_REQUESTS):
        task_id = f"Task-{i+1}"

        client_send_time = get_time_ms()
        send_times[task_id] = client_send_time
        request_queue.put((task_id, imgs, client_send_time))
        print(f"[{client_send_time}ms] [Client] Anfrage {task_id} gesendet.")

        # Use the pause until the next request to collect finished results
        # instead of sleeping; this keeps the request rate unchanged.
        deadline = time.monotonic() + interval_s
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                result = result_queue.get(timeout=remaining)
            except queue.Empty:
                break
            _report(result)
            completed_requests += 1

    # Wait (blocking, no busy loop) for the answers that are still missing.
    while completed_requests < TOTAL_REQUESTS:
        _report(result_queue.get())
        completed_requests += 1

    # Tell the server to shut down only after every answer has arrived.
    request_queue.put("STOP")

    print("[Client] Client hat alle Antworten erhalten und wird beendet.")

if __name__ == "__main__":
    # force=True: without it set_start_method raises RuntimeError when a start
    # method was already set earlier in the same interpreter.
    multiprocessing.set_start_method('spawn', force=True)

    # Queues for the communication between the two processes.
    requests = multiprocessing.Queue()
    results = multiprocessing.Queue()

    # Create the server and client processes.
    server = multiprocessing.Process(target=server_process, args=(requests, results))
    client = multiprocessing.Process(target=client_process, args=(requests, results))

    # Start the processes.
    server.start()
    client.start()

    # Wait for both processes to finish.
    client.join()
    server.join()

    # A crashed child is only visible through its exit code.
    if client.exitcode != 0 or server.exitcode != 0:
        print(f"Simulation fehlgeschlagen (Exit-Codes: Client={client.exitcode}, Server={server.exitcode}).")
    else:
        print("Simulation abgeschlossen.")