# <center>我整个重写一遍吧

## 所有用到的环境

In [1]:
import functools;
import time;
import types;

In [2]:
import torch;

## 代码兼容性评估

### Some General Functions

In [3]:
def isFuncFromModule(func: types.FunctionType, module: str):
    return func.__module__.startswith(module);
isFuncFromModule(torch.nn.LSTM, "torch")

True

In [4]:
def getAttributes(module):
    return [i for i in dir(module) if not i.startswith("__") and callable(getattr(module, i)) or isinstance(getattr(module, i), types.ModuleType)];
getAttributes(torch.optim)

['ASGD',
 'Adadelta',
 'Adagrad',
 'Adam',
 'AdamW',
 'Adamax',
 'LBFGS',
 'NAdam',
 'Optimizer',
 'RAdam',
 'RMSprop',
 'Rprop',
 'SGD',
 'SparseAdam',
 '_functional',
 '_multi_tensor',
 'lr_scheduler',
 'swa_utils']

In [5]:
getAttributes(torch.optim.lr_scheduler.Counter)

['_keep_positive',
 'clear',
 'copy',
 'elements',
 'fromkeys',
 'get',
 'items',
 'keys',
 'most_common',
 'pop',
 'popitem',
 'setdefault',
 'subtract',
 'total',
 'update',
 'values']

In [6]:
def getAPIName(func):
    if isinstance(func, types.FunctionType) or isinstance(func, types.BuiltinFunctionType):
        apiName = func.__module__;
        apiName = apiName + "." + func.__name__;
        return apiName;
    elif isinstance(func, type):
        apiName = func.__module__;
        if func.__name__[-len(func.__name__)+1:] == apiName[-len(func.__name__)+1:]:
            apiName = apiName[:-len(func.__name__)] + func.__name__;
        return apiName;
    elif isinstance(func, types.ModuleType):
        return func.__name__;
getAPIName(torch.optim.Adam),getAPIName(torch.mul)

('torch.optim.Adam', 'torch.mul')

In [7]:
torch.tensor.__module__,getAPIName(torch.optim),torch._C.__name__

('torch', 'torch.optim', 'torch._C')

In [8]:
def isDecorated(obj: (types.FunctionType, types.ModuleType, type)):
    return getattr(obj, "isDecorated", False);
isDecorated(torch.tensor)

False

### Some General Decorators Patterns

#### 需要的环境

In [9]:
import functools;
import time;
import types;
import warnings;

#### TimerDecorator

In [10]:
def TimerDecorator(func: types.FunctionType):
    """
    **Description**
    A running timer for a function.
    
    **params**
    func(String): the function to be timed.
    
    **returns**
    wrapper: a timer decorated function.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        startTime = time.perf_counter_ns();
        result = func(*args, **kwargs);
        endTime = time.perf_counter_ns();
        costTime = (endTime - startTime) / 1000 / 1000;
        timeLog = f"{func.__name__}() cost {costTime} ms";
        # print(timeLog);
        return result, startTime, costTime;
    return wrapper;

Test Cases

In [11]:
def tst():
    t = 1;
    for i in range(5):
        t *= (i+1);
    return t;

TimerDecorator(torch.matmul)(torch.tensor([[1.0,5.0],[3.07,7.29]]), torch.tensor([[6.08,9.05],[3.4,2.8]])), "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%",\
TimerDecorator(tst)()

((tensor([[23.0800, 23.0500],
          [43.4516, 48.1955]]),
  18026340271900,
  0.1904),
 '%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%',
 (120, 18026340495200, 0.0038))

#### APIDecorator

In [12]:
def APIDecorator(func: str, module: str=None):
    """
    **Description**
    A API usage recorder for a function.
    
    **params**
    func(String): the function to be recorded.
    
    **returns**
    wrapper: a API usage recored functiion.
    """
    apiName = getAPIName(func)
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        apiName = getAPIName(func)
        result, startTimestamp, costTime= TimerDecorator(func)(*args, **kwargs);
        print(f"{apiName} starts from {startTimestamp} costs {costTime}ms.");
        return result, apiName, startTimestamp, costTime;
    if module == None:
        return wrapper;
    elif isFuncFromModule(func, module):
        return wrapper;
    else:
        # raise ValueError(f"the function `{apiName}` is not from module `{module}`");
        return func;

Test Cases

In [13]:
APIDecorator(torch.normal, "types")(torch.tensor([0,0.01])), "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%", \
APIDecorator(torch.randn, "torch")([2,3]), "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%", \
APIDecorator(torch.mul)(torch.tensor([2.5,0,1,3]), torch.tensor([2,0.9,2,4]))

torch.randn starts from 18026355410100 costs 0.054299999999999994ms.
torch.mul starts from 18026355592200 costs 0.0325ms.


(tensor([0.0717, 0.8878]),
 '%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%',
 (tensor([[ 0.8719,  0.1853, -0.7167],
          [-0.3325,  3.0851,  0.6144]]),
  'torch.randn',
  18026355410100,
  0.054299999999999994),
 '%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%',
 (tensor([ 5.,  0.,  2., 12.]), 'torch.mul', 18026355592200, 0.0325))

In [14]:
torch.optim.__name__

'torch.optim'

### Class TorchWrapper

#### 需要的环境

In [15]:
import functools;
import time;
import types;
import json;
import os;
import operator;
import pandas as pd;
import torch;

#### TorchWrapper

In [16]:
"""
*****************************
The Structure For callRecords
*****************************
callRecords
│
├── API_1
│   │
│   ├── TotalTime(ms): 150.0
│   │
│   ├── 1
│   │   ├── detailedAPIName: 
│   │   ├── StartTimestamp: 1625150800123456789
│   │   ├── CostTime(ms): 50.0
│   │   └── Arguments: (arg1, arg2, ...)
│   │
│   ├── 2
│   │   ├── detailedAPIName: 
│       ├── StartTimestamp: 1625150860123456789
│       ├── CostTime(ms): 100.0
│       └── Arguments: (arg1, arg2, ...)
│
├── API_2
│   │
│   ├── TotalTime(ms): 200.0
│   │
│   ├── 1
│   │   ├── detailedAPIName: 
│       ├── StartTimestamp: 1625150900123456789
│       ├── CostTime(ms): 200.0
│       └── Arguments: (arg1, arg2, ...)
"""

class TorchWrapper:
    """
    ********************
    Initializing Section
    ********************
    """
    
    # Some const for restoring default or checking steps.
    DEFAULT_FORMAT = "csv";
    DEFAULT_NAME_EPEC = "timestamp";
    SUPPORTED_FORMATS = ["json", "csv", "html", "xlsx"];
    SUPPORETD_NAME_SPEC = ["timestamp", "datetime", "serial"];

    class ConfigKey:
        OUT_DIR = "out_dir";
        FORMAT = "format";
        FILE_MAX_SIZE = "file_max_size";
        FILE_NAME_SPEC = "file_name_spec";

    class CallRecordKey:
        API_NAME = "APIName";
        
        class ResultKey:
            TOTAL_TIME = "TotalTime(ms)";
            CALL_NUMBER = "CallNumber";
            START_TIMESTAMP = "StartTimestamp";
            COST_TIME = "CostTime(ms)";
            ARGUMENTS = "Arguments";
        


    # initialize the wrapper
    def __init__(self, config: dict):
        self.callRecords = {};
        self.config = self.parseConfig(config);

    # restoring default or checking steps.
    def parseConfig(self, config:dict):
        """restoring default or checking steps."""
        # check output directory
        if TorchWrapper.ConfigKey.OUT_DIR not in config:
            raise ValueError("Output directory is required.");
        else:
            assert isinstance(config[TorchWrapper.ConfigKey.OUT_DIR], str);

        # check output format
        if TorchWrapper.ConfigKey.FORMAT in config:
            assert isinstance(config[TorchWrapper.ConfigKey.FORMAT], str);
            format = config[TorchWrapper.ConfigKey.FORMAT];
            if format not in TorchWrapper.SUPPORTED_FORMATS:
                raise ValueError(f"Unsupported format {format} for saving result");
        else:
            config[TorchWrapper.ConfigKey.FORMAT] = TorchWrapper.DEFAULT_FORMAT;

        # check output size limits
        if TorchWrapper.ConfigKey.FILE_MAX_SIZE in config:
            assert isinstance(config[TorchWrapper.ConfigKey.FILE_MAX_SIZE], str);
            if config[TorchWrapper.ConfigKey.FILE_MAX_SIZE][-2:] not in ["KB", "MB", "GB"]:
                raise ValueError("maxSize should be defined in the style of `myInt`KB/MB/GB");

        # check output name spec 
        if TorchWrapper.ConfigKey.FILE_NAME_SPEC in config:
            assert isinstance(config[TorchWrapper.ConfigKey.FILE_NAME_SPEC], str);
            name_spec = config[TorchWrapper.ConfigKey.FILE_NAME_SPEC];
            if name_spec not in TorchWrapper.SUPPORETD_NAME_SPEC:
                raise ValueError(f"Unsupported file name spec {name_spec}");
            else:
                config[TorchWrapper.ConfigKey.FILE_NAME_SPEC] = TorchWrapper.DEFAULT_NAME_EPEC;
      
    """
    ******************
    Decorator Section
    ******************
    """
    
    def CountDecorator(self, func: str):
        """
        **Description**
        A decorator that count the call of a function and record it in a dictionary.
        
        **params**
        func(String): the function to be recorded calling times.
        
        **return**
        wrapper: a function that has been counted calling times.
        """
        funcName = getAPIName(func);
        print(f"decorating function {funcName}");
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # initialize the record.            
            record = {
                TorchWrapper.CallRecordKey.ResultKey.CALL_NUMBER: None,
                TorchWrapper.CallRecordKey.ResultKey.START_TIMESTAMP: None,
                TorchWrapper.CallRecordKey.ResultKey.COST_TIME: None,
                TorchWrapper.CallRecordKey.ResultKey.ARGUMENTS: args
            };
            
            # catch the result.
            result, apiName, startTimestamp, costTime = APIDecorator(func)(*args, **kwargs);            
            if apiName in self.callRecords:
                callCount = len(self.callRecords[apiName].keys);
                totalTime = callRecords[apiName][TorchWrapper.CallRecordKey.ResultKey.TOTAL_TIME];
            else:
                callCount = 0;
                totalTime = 0.0;
            
            # Generate the record.
            record[TorchWrapper.CallRecordKey.ResultKey.START_TIMESTAMP] = startTimestamp;
            record[TorchWrapper.CallRecordKey.ResultKey.COST_TIME] = costTime;
            record[TorchWrapper.CallRecordKey.ResultKey.ARGUMENTS] = args;
            totalTime += costTime;
            
            # Saving the record
            self.calRecords[apiName][TorchWrapper.CallRecordKey.ResultKey.TOTAL_TIME] =totalTime;
            self.calRecords[apiName][callCount] = record;
            
            
            return result;
        wrapper.isDecorated = True;
        print(f"{funcName} decorated.");
        return wrapper;
        
    """
    ******************
    Processing Section
    ******************
    """
    
    def decorateClass(self, cls):
        """
        **Description**
        Decorates all the methods of a class with CountDecorator and records the API names.

        **params**
        cls(Class): The class whose methods are to be decorated.
        
        **returns**
        cls: The class with its methods decorated.
        """
        try:
            clsName = getAPIName(cls)
            assert isinstance(cls, type), f"`{cls}`must be a class.";
            clsName = getAPIName(cls);
            if isDecorated(cls):
                print(f"module {clsName} has been decorated, out.");
                return;
            cls.isDecorated = True;


            if getAttributes(cls):
                print(f"descensing to class {cls}");
                for name in getAttributes(cls):
                    print(f"descending into {cls.__name__}.{name}");
                    method = getattr(cls, name);
                    apiName = getAPIName(method);
                    if isinstance(method, types.FunctionType):
                        print(f"Decorating {apiName}");
                        setattr(cls, name, self.CountDecorator(method));
                    elif isinstance(method, type) and not isDecorated(method):
                        subCls = method.__name__;
                        setattr(cls, name, self.decorateClass(subCls));
        except TypeError as e:
            if "immutable type" in str(e):
                print(f"{cls} is an immutable type, out.");
            return;

    
    
    def decorateModule(self, module: types.ModuleType):
        """
        **Description**
        a function that can wrap the hole module with CountDecorator.

        **params**
        module(String): The name of module to be decorated.
        visited(Set): the module name that has been decorated.

        **returns**
        a module that has been fully decorated.
        """
        assert isinstance(module, types.ModuleType), f"`{module}`must be a module.";
        moduleName = getAPIName(module);
        if isDecorated(module):
            print(f"module {moduleName} has been decorated, out.");
            return;
        module.isDecorated = True;
            
        
        for name in getAttributes(module):
            func = getattr(module, name);
            if isinstance(func, types.ModuleType):
                setattr(module, name, self.decorateModule(func));
            elif isinstance(func, types.FunctionType):
                if not getattr(func, 'isDecorated', False):
                    print(f"{name} hasn't been decorated, decorate {name}.");
                    decoratedFunc = self.CountDecorator(func);
                    setattr(module, name, decoratedFunc);
                    print(f"{name} hasn't been decorated.")
            elif isinstance(func, type) and not isDecorated(method):
                setattr(module, name, self.decorateClass(func));
    
    """
    **************
    Saving Section
    **************
    """
    
    # parse the usable value from config
    # parse max file size limit;
    def getFileMaxSize(self, config: dict):
        """Though I don't think this is necessary.-- Frankie"""
        if TorchWrapper.ConfigKey.FILE_MAX_SIZE in config:
            maxSize = config[TorchWrapper.ConfigKey.FILE_MAX_SIZE];
            if maxSize.endwith("KB"):
                maxSize = int(maxSize[:-2]) * 1024;
            elif maxSize.endwith("MB"):
                maxSize = int(maxSize[:-2]) * (1024 ** 2);
            elif maxSize.endwith("GB"):
                maxSize = int(maxSize[:-2]) * (1024 ** 3);
            return maxSize;
  
    # get the name of file to save
    def getFileNameSuffix(self):
        file_name_spec = config[TorchWrapper.ConfigKey.FILE_NAME_SPEC];
        if file_name_spec == "timestamp":
            return time.time_ns();
        elif file_name_spec == "datetime":
            return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime());
        elif file_name_spec == "serial":
            raise NotImplementedError;
                          
    # Prepare the result saving directory
    def setPath(self, path):
        if os.path.exists(path):
            if not os.path.isdir(path):
                raise ValueError(f"Path {path} is not a directory");
        else:
            os.makedirs(path);
        return path;
    
    # get the name of file to save
    def getFileName(self, config: dict) -> str:
        """TODO: parse filename from config dictionary."""
        suffix = self.getFileNameSuffix();
        fileName = f"TorchWrapper_Result_{suffix}";
        return fileName;
    
    # get the name of DataFrame formatted callRecords to save
    def getDFFormattedCallRecords(self):
        """Formats the call records as a pandas DataFrame."""
        records = []
        for apiName, calls in self.callRecords.items():
            totalTime = calls.pop(TorchWrapper.CallRecordKey.ResultKey.TOTAL_TIME, 0);
            for callNumber, call in calls.items():
                record = {
                    TorchWrapper.CallRecordKey.API_NAME: apiName,
                    TorchWrapper.CallRecordKey.ResultKey.TOTAL_TIME: totalTime,
                    TorchWrapper.CallRecordKey.ResultKey.CALL_NUMBER: callNumber,
                    TorchWrapper.CallRecordKey.ResultKey.START_TIMESTAMP: call[TorchWrapper.CallRecordKey.ResultKey.START_TIMESTAMP],
                    TorchWrapper.CallRecordKey.ResultKey.COST_TIME: call[TorchWrapper.CallRecordKey.ResultKey.COST_TIME],
                    TorchWrapper.CallRecordKey.ResultKey.ARGUMENTS: call[TorchWrapper.CallRecordKey.ResultKey.ARGUMENTS]
                };
                records.append(record);
        return pd.DataFrame(records);
    
    # Save the result to specific file
    def saveRecord(self, config: dict):
        def saveToJson(data, path: str, fileName: str):
            """Save DataFrame formatted call records to a .json file."""
            data.to_json(f"{path}/{fileName}.json", orient='records', lines=True);
            
        def saveToCSV(data, path: str, fileName: str):
            """Save DataFrame formatted call records to a .csv file."""
            data.to_csv(f"{path}/{fileName}.csv", index=False);
            
        def saveToExcel(data, path: str, fileName: str):
            """Save DataFrame formatted call records to a .xlsx file."""
            data.to_excel(f"{path}/{fileName}.xlsx", index=False);
            
        def saveToHTML(data, path: str, fileName: str):
            """Save DataFrame formatted call records to a .html file."""
            data.to_html(f"{path}/{fileName}.html", index=False);
            
        fileName = self.getFileName(config);
        data = self.getDFFormattedCallRecords();
        outputPath = self.setPath(config[TorchWrapper.ConfigKey.OUT_DIR]);
        if config[TorchWrapper.ConfigKey.FORMAT] == "json":
            saveToJson(data, outputPath, fileName);
        elif config[TorchWrapper.ConfigKey.FORMAT] == "csv":
            saveToCSV(data, outputPath, fileName);
        elif config[TorchWrapper.ConfigKey.FORMAT] == "xlsx":
            saveToExcel(data, outputPath, fileName);
        elif config[TorchWrapper.ConfigKey.FORMAT] == "html":
            saveToHTML(data, outputPath, fileName);

    """
    ******************
    Main Usage Section
    ******************
    """
    
    def start(self, func: types.FunctionType):
        """
        **Description**
        Starts the wrapping and recording process.

        **params**
        func(FunctionType): The function to be executed and recorded.
        
        **raises**
        ValueError: If there is an error executing the function.
        """
        print(f"Starts decorating torch module.");
        self.decorateModule(torch);
        print("torch module decorating complete.");
        try:
            print(f"Starts evaluating {func.__name__}");
            func();
            
        except Exception as e:
            raise ValueError("Error executing the function.") from e;
        print("start saving results.");
        self.saveRecord(config);
        print(f"results file saved.");

In [17]:
config = {
    "out_dir": "./output",
    "format": "csv",
    "file_max_size": "10MB",
    "file_name_spec": "timestamp"
};
wrapper = TorchWrapper(config)

In [18]:
def myCode():
    a = torch.randn(1, 3);
    b = torch.randn(1, 3);
    c = a + b;
    return c;

In [19]:
config = {
    "out_dir": "./output",
    "format": "csv",
    "file_max_size": "10MB",
    "file_name_spec": "timestamp"
};
wrapper = TorchWrapper(config)

wrapper.start(myCode);

Starts decorating torch module.
descensing to class <class 'torch.AliasDb'>
descending into AliasDb.dump
descending into AliasDb.has_writers
descending into AliasDb.may_contain_alias
descending into AliasDb.move_after_topologically_valid
descending into AliasDb.move_before_topologically_valid
descending into AliasDb.to_graphviz_str
descensing to class <class 'torch.AnyType'>
descending into AnyType.containedTypes
descending into AnyType.contiguous
descending into AnyType.device
descending into AnyType.dim
descending into AnyType.dtype
descending into AnyType.get
descending into AnyType.isSubtypeOf
descending into AnyType.is_interface_type
descending into AnyType.kind
descending into AnyType.requires_grad
descending into AnyType.scalarType
descending into AnyType.sizes
descending into AnyType.str
descending into AnyType.strides
descending into AnyType.symbolic_sizes
descending into AnyType.undefined
descending into AnyType.varyingSizes
descending into AnyType.with_device
descending into

  return self.fget.__get__(instance, owner)()


AssertionError: `IndividualMetrics`must be a class.

In [None]:
torch.storage

In [None]:
getAPIName(torch.storage.Storage.from_file)

In [None]:
StaticModule.IndividualMetrics