In [None]:
from call_count import *
def _collect_test_batches(dataset, batch_size, num_batches, device):
    """Build up to ``num_batches`` complete, randomly sampled batches.

    Args:
        dataset: Indexable dataset whose items unpack as ``(input, target)``
            tensors.
        batch_size: Number of samples per batch.
        num_batches: Maximum number of batches to build.
        device: Device the stacked tensors are moved to.

    Returns:
        A list of ``(inputs, targets)`` tuples, each converted to half
        precision on ``device``. The list is shorter than ``num_batches``
        (possibly empty) when the dataset has too few samples.
    """
    # Manually shuffle the dataset indices on the CPU.
    indices = torch.randperm(len(dataset), device='cpu').tolist()

    # Incomplete batches are never used, so the dataset size caps the count.
    num_full = min(num_batches, len(indices) // batch_size)

    batches = []
    for start in range(0, num_full * batch_size, batch_size):
        samples = [dataset[idx] for idx in indices[start:start + batch_size]]
        batch_inputs, batch_targets = zip(*samples)
        inputs = torch.stack(batch_inputs).to(device).half()
        targets = torch.stack(batch_targets).to(device).half()
        batches.append((inputs, targets))
    return batches


def run_inference(batch_sizes=(64, 128, 256, 512), num_batches=5):
    """
    Run inference on Triton models and measure operator-level performance

    For every batch size the function samples ``num_batches`` full batches
    from ``./train_data.csv``, warms the model up, resets the call counters
    and then runs the batches inside ``monitor()`` before printing the
    collected statistics and a few sample outputs.

    Args:
        batch_sizes: Iterable of batch sizes to test. The default is an
            immutable tuple so it cannot be mutated across calls.
        num_batches: Number of batches to run per batch size.

    Returns:
        None. Results are printed. The function returns early when the
        dataset cannot be loaded.
    """
    print("Starting inference test with real dataset...")

    # Define device; CUDA-only calls are skipped on the CPU fallback, where
    # they would otherwise raise.
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.set_device(device)  # Set default device
        torch.cuda.empty_cache()  # Clear unused GPU memory

    # Try to load pre-trained model
    try:
        # Load onto the CPU first so the checkpoint does not depend on the
        # device it was saved from; the model is moved to `device` below.
        model_state_dict = torch.load("model.pt", map_location="cpu")
        pytorch_model = Net()
        pytorch_model.load_state_dict(model_state_dict)
        print("Successfully loaded pre-trained model weights")
    except Exception as e:
        # Best-effort: report why loading failed, then start from a fresh
        # model so partially loaded weights are not kept.
        print(f"Could not load pre-trained model, using default weights ({e})")
        pytorch_model = Net()

    # Convert model to use Triton implementations
    pytorch_model = pytorch_model.to(device).half()
    triton_model = convert_pytorch_to_triton_model(pytorch_model)
    triton_model = triton_model.to(device).half()

    # Set to evaluation mode
    triton_model.eval()

    # Create the dataset once (we'll reuse it for different batch sizes)
    try:
        train_dataset = GDdataset("./train_data.csv")
        print(f"Successfully loaded dataset with {len(train_dataset)} samples")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    # Test for different batch sizes
    for batch_size in batch_sizes:
        print(f"\n==== Batch Size: {batch_size} ====")

        test_batches = _collect_test_batches(
            train_dataset, batch_size, num_batches, device
        )

        # Ensure we have enough batches
        if len(test_batches) < num_batches:
            print(f"Warning: Could only extract {len(test_batches)} batches instead of {num_batches}")
        if not test_batches:
            # Nothing to warm up or measure for this batch size.
            print("No complete batch available for this batch size, skipping")
            continue

        # Verify data shape
        print(f"Input shape: {test_batches[0][0].shape}")

        # Warmup (at most 10 batches)
        for inputs, _ in test_batches[:10]:
            with torch.no_grad():
                _ = triton_model(inputs)
        if use_cuda:
            torch.cuda.synchronize()

        # Reset counters before measurement
        reset_counters()

        # Triton model inference with operator-level profiling
        with monitor():
            for inputs, _targets in test_batches:
                with torch.no_grad():
                    triton_output = triton_model(inputs)
                if use_cuda:
                    # Wait for queued GPU work before the next batch starts.
                    torch.cuda.synchronize()

        # Print operator-level statistics
        print("\nOperator-level performance statistics")
        print_statistics()

        # Print sample outputs for the last batch
        batch_to_print = min(batch_size, triton_output.size(0))
        print(f"\nSample output (first {min(5, batch_to_print)} from last batch):")
        for i in range(min(5, batch_to_print)):
            print(f"  Sample {i}: Output={triton_output[i].item():.4f}")

In [None]:
import functools
import torch
import torch.nn.functional as F
import types
import time
import pprint
# Per-function call statistics, keyed by "<module name>.<function name>".
# Each value is {"count": <number of calls>, "total_time": <seconds>}.
call_count = {}


def count_calls(func, module_name=None):
    """Wrap ``func`` so every successful call is counted and timed.

    Statistics are accumulated in the module-level ``call_count`` dict under
    the key ``"<module_name>.<function name>"``.

    Args:
        func: The callable to instrument.
        module_name: Prefix used in the statistics key. When omitted, the
            callable's own ``__module__`` is used (``"<unknown>"`` if it has
            none).

    Returns:
        A wrapper with the same call signature as ``func``. It carries an
        ``_is_decorated`` attribute set to ``True`` so callers can avoid
        wrapping it twice.
    """
    if module_name is None:
        module_name = getattr(func, "__module__", None) or "<unknown>"
    # Not every callable has __name__ (e.g. functools.partial objects);
    # fall back to the type name instead of failing on the first call.
    func_name = getattr(func, "__name__", None) or type(func).__name__
    full_name = module_name + '.' + func_name

    @functools.wraps(func)
    def wrapper_count_calls(*args, **kwargs):
        # perf_counter is monotonic and high resolution, unlike time.time,
        # which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time

        stats = call_count.setdefault(
            full_name, {"count": 0, "total_time": 0.0}
        )
        stats["count"] += 1
        stats["total_time"] += elapsed_time

        return result

    wrapper_count_calls._is_decorated = True
    return wrapper_count_calls

def set_new_attr(module, attr_name, attr):
    """Replace ``module.attr_name`` with a call-counting wrapper of ``attr``.

    ``module`` is any object with a ``__name__`` (callers in this file pass
    both modules and classes). Attributes that already carry an
    ``_is_decorated`` marker are left untouched so nothing is wrapped twice.
    """
    if hasattr(attr, "_is_decorated"):
        return

    wrapped = count_calls(attr, module.__name__)
    wrapped._is_decorated = True
    setattr(module, attr_name, wrapped)

# Recursively wrap everything reachable from a module.
def auto_decorate_module(module, visited=None):
    """Instrument the attributes of ``module`` and of its torch submodules.

    Plain functions and other callables are replaced via ``set_new_attr``,
    classes are handed to ``auto_decorate_class``, and submodules whose name
    starts with ``'torch'`` are processed recursively.

    Args:
        module: The module to instrument.
        visited: Set of module names already processed; used to stop the
            recursion from revisiting a module. A new set is created when
            omitted.
    """
    seen = set() if visited is None else visited

    name = module.__name__
    if name in seen:
        return
    seen.add(name)

    for member_name in dir(module):
        try:
            member = getattr(module, member_name)

            if isinstance(member, types.FunctionType):
                set_new_attr(module, member_name, member)
                continue

            if isinstance(member, types.ModuleType) and member.__name__.startswith('torch'):
                auto_decorate_module(member, seen)
                continue

            if isinstance(member, type):
                auto_decorate_class(member)
                continue

            if callable(member):
                set_new_attr(module, member_name, member)
        except AttributeError:
            # Attributes that cannot be read or replaced are skipped.
            continue


def auto_decorate_class(cls):
    """Instrument the methods of ``cls`` with call counting.

    Every attribute that is a plain Python function is wrapped via
    ``set_new_attr``. A fixed list of arithmetic operator methods is wrapped
    as well, even when they are not plain functions. Attributes that raise
    ``AttributeError`` or ``TypeError`` while being read or replaced are
    skipped.
    """
    # Operator overloads that get wrapped regardless of their type.
    operator_methods = ['__add__', '__mul__', '__sub__', '__truediv__', '__matmul__', '__pow__', '__mod__']

    for member_name in dir(cls):
        try:
            member = getattr(cls, member_name)
            if isinstance(member, types.FunctionType) or member_name in operator_methods:
                set_new_attr(cls, member_name, member)
        except (AttributeError, TypeError):
            continue

In [None]:
import swat
import os
import numpy as np
import pandas as pd
import sys
import dlpy
from dlpy import Sequential
from dlpy import *
from dlpy.model import TextParms
from dlpy.blocks import Bidirectional
from dlpy.applications import TextClassification
from dlpy.network import *
from dlpy.utils import *
from dlpy.applications import *
from dlpy.model import *
from dlpy.images import *
from dlpy.layers import *
# Host name handed to swat.CAS below.
cashost = 'sas-cas-server-default-client'

# Open the CAS connection. 5570 is the second positional argument
# (presumably the port - confirm against the SWAT docs); the password is
# read from the ACCESS_TOKEN environment variable.
conn = swat.CAS(
    cashost,
    5570,
    password=os.environ.get('ACCESS_TOKEN'),
)

# Load the CSV file into the 'casuser' caslib as table 'jixie_train_data'.
conn.loadTable(
    path='jixie_train_data.csv',
    casout={'name': 'jixie_train_data', 'caslib': 'casuser'},
    importOptions={'fileType': 'csv'},
)

# Handle to the loaded table; the trailing bare expression shows its shape
# when this runs as the last statement of a notebook cell.
tb = conn.CASTable('jixie_train_data', caslib='casuser')
tb.shape