# <center>我整个重写一遍吧

## 所有用到的环境

In [4]:
import functools;
import time;
import types;

## 代码兼容性评估

### Some General Functions

In [5]:
def isFuncFromModule(func: types.FunctionType, module: str):
    return func.__module__.startswith(module);
isFuncFromModule(torch.nn.LSTM, "torch")

True

In [6]:
def getAttributes(module: str):
    return [i for i in dir(module) if "__" not in i];
getAttributes(torch.optim)

['ASGD',
 'Adadelta',
 'Adagrad',
 'Adam',
 'AdamW',
 'Adamax',
 'LBFGS',
 'NAdam',
 'Optimizer',
 'RAdam',
 'RMSprop',
 'Rprop',
 'SGD',
 'SparseAdam',
 '_functional',
 '_multi_tensor',
 'lr_scheduler',
 'swa_utils']

In [7]:
def getAPIName(func):
    apiName = func.__module__;
    if func.__name__[-len(func.__name__)+1:] == apiName[-len(func.__name__)+1:]:
        apiName = apiName[:-len(func.__name__)] + func.__name__;
    else:
        apiName = apiName + "." + func.__name__;
    return apiName;
getAPIName(torch.optim.Adam),getAPIName(torch.mul)

('torch.optim.Adam', 'torch.mul')

In [8]:
def isDecorated(func: types.FunctionType):
    if hasattr(func, "isDecorated"):
        return True;
    else: 
        return False;
isDecorated(torch.tensor)

False

### Some General Decorators Patterns

#### 需要的环境

In [9]:
import functools;
import time;
import types;
import warnings;

#### TimerDecorator

In [10]:
def TimerDecorator(func: types.FunctionType):
    """
    **Description**
    A running timer for a function.
    
    **params**
    func(String): the function to be timed.
    
    **returns**
    wrapper: a timer decorated function.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        startTime = time.perf_counter_ns();
        result = func(*args, **kwargs);
        endTime = time.perf_counter_ns();
        costTime = (endTime - startTime) / 1000 / 1000;
        timeLog = f"{func.__name__}() cost {costTime} ms";
        # print(timeLog);
        return result, startTime, costTime;
    return wrapper;

Test Cases

In [11]:
def tst():
    t = 1;
    for i in range(5):
        t *= (i+1);
    return t;

TimerDecorator(torch.matmul)(torch.tensor([[1.0,5.0],[3.07,7.29]]), torch.tensor([[6.08,9.05],[3.4,2.8]])), "%%%%%%%%%%%%%%%%%%%%%%%%%",\
TimerDecorator(tst)()

((tensor([[23.0800, 23.0500],
          [43.4516, 48.1955]]),
  609541406512600,
  2.0538000000000003),
 '%%%%%%%%%%%%%%%%%%%%%%%%%',
 (120, 609541408669800, 0.005900000000000001))

#### APIDecorator

In [12]:
def APIDecorator(func: str, module: str=None):
    """
    **Description**
    A API usage recorder for a function.
    
    **params**
    func(String): the function to be recorded.
    
    **returns**
    wrapper: a API usage recored functiion.
    """
    apiName = getAPIName(func)
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        apiName = getAPIName(func)
        result, startTimestamp, costTime= TimerDecorator(func)(*args, **kwargs);
        return result, apiName, startTimestamp, costTime;
    if module == None:
        return wrapper;
    elif func.__module__.startswith(module):
        return wrapper;
    else:
        # raise ValueError(f"the function `{apiName}` is not from module `{module}`");
        return func;

Test Cases

In [13]:
APIDecorator(torch.normal, "types")(torch.tensor([0,0.01])), "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%", \
APIDecorator(torch.randn, "torch")([2,3]), "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%", \
APIDecorator(torch.mul)(torch.tensor([2.5,0,1,3]), torch.tensor([2,0.9,2,4]))

(tensor([0.2437, 1.2762]),
 '%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%',
 (tensor([[ 0.0903,  0.0890,  2.5897],
          [ 0.7614,  0.2895, -1.5856]]),
  'torch.randn',
  609544966311900,
  0.2326),
 '%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%',
 (tensor([ 5.,  0.,  2., 12.]), 'torch.mul', 609544966770700, 0.0437))

In [14]:
torch.mul.__module__

'torch'

### Class TorchWrapper

#### 需要的环境

In [1]:
import functools;
import time;
import types;
import json;
import os;
import operator;
import pandas as pd;
import torch;

#### TorchWrapper

In [2]:
"""
*****************************
The Structure For callRecords
*****************************
callRecords
│
├── API_1
│   │
│   ├── TotalTime(ms): 150.0
│   │
│   ├── 1
│   │   ├── StartTimestamp: 1625150800123456789
│   │   ├── CostTime(ms): 50.0
│   │   └── Arguments: (arg1, arg2, ...)
│   │
│   ├── 2
│       ├── StartTimestamp: 1625150860123456789
│       ├── CostTime(ms): 100.0
│       └── Arguments: (arg1, arg2, ...)
│
├── API_2
│   │
│   ├── TotalTime(ms): 200.0
│   │
│   ├── 1
│       ├── StartTimestamp: 1625150900123456789
│       ├── CostTime(ms): 200.0
│       └── Arguments: (arg1, arg2, ...)
"""

class TorchWrapper:
    """
    ********************
    Initializing Section
    ********************
    """
    
    # Some const for restoring default or checking steps.
    DEFAULT_FORMAT = "csv";
    DEFAULT_NAME_EPEC = "timestamp";
    SUPPORTED_FORMATS = ["json", "csv", "html"];
    SUPPORETD_NAME_SPEC = ["timestamp", "datetime", "serial"];

    class ConfigKey:
        OUT_DIR = "out_dir";
        FORMAT = "format";
        FILE_MAX_SIZE = "file_max_size";
        FILE_NAME_SPEC = "file_name_spec";

    class CallRecordKey:
        API_NAME = "APIName";
        
        class ResultKey:
            TOTAL_TIME = "TotalTime(ms)";
            CALL_NUMBER = "CallNumber";
            START_TIMESTAMP = "StartTimestamp";
            COST_TIME = "CostTime(ms)";
            ARGUMENTS = "Arguments";
        


    # initialize the wrapper
    def __init__(self, config: dict):
        self.callRecords = {};
        self.config = self.parseConfig(config);

    # restoring default or checking steps.
    def parseConfig(self, config:dict):
        """restoring default or checking steps."""
        # check output directory
        if TorchWrapper.ConfigKey.OUT_DIR not in config:
            raise ValueError("Output directory is required.");
        else:
            assert isinstance(config[TorchWrapper.ConfigKey.OUT_DIR], str);

        # check output format
        if TorchWrapper.ConfigKey.FORMAT in config:
            assert isinstance(config[TorchWrapper.ConfigKey.FORMAT], str);
            format = config[TorchWrapper.ConfigKey.FORMAT];
            if format not in TorchWrapper.SUPPORTED_FORMATS:
                raise ValueError(f"Unsupported format {format} for saving result");
        else:
            config[TorchWrapper.ConfigKey.FORMAT] = TorchWrapper.DEFAULT_FORMAT;

        # check output size limits
        if TorchWrapper.ConfigKey.FILE_MAX_SIZE in config:
            assert isinstance(config[TorchWrapper.ConfigKey.FILE_MAX_SIZE], str);
            if config[TorchWrapper.ConfigKey.FILE_MAX_SIZE][-2:] not in ["KB", "MB", "GB"]:
                raise ValueError("maxSize should be defined in the style of `myInt`KB/MB/GB");

        # check output name spec 
        if TorchWrapper.ConfigKey.FILE_NAME_SPEC in config:
            assert isinstance(config[TorchWrapper.ConfigKey.FILE_NAME_SPEC], str);
            name_spec = config[TorchWrapper.ConfigKey.FILE_NAME_SPEC];
            if name_spec not in TorchWrapper.SUPPORETD_NAME_SPEC:
                raise ValueError(f"Unsupported file name spec {name_spec}");
            else:
                config[TorchWrapper.ConfigKey.FILE_NAME_SPEC] = TorchWrapper.DEFAULT_NAME_EPEC;
      
    """
    ******************
    Decorator Section
    ******************
    """
    def CountDecorator(func: str):
        """
        **Description**
        A decorator that count the call of a function and record it in a dictionary.
        
        **params**
        func(String): the function to be recorded calling times.
        
        **return**
        wrapper: a function that has been counted calling times.
        """
        @functools.warps(func)
        def wrapper(*args, **kwargs):
            # initialize the record.            
            record = {
                TorchWrapper.CallRecordKey.ResultKey.CALL_NUMBER: None,
                TorchWrapper.CallRecordKey.ResultKey.START_TIMESTAMP: None,
                TorchWrapper.CallRecordKey.ResultKey.COST_TIME: None,
                TorchWrapper.CallRecordKey.ResultKey.ARGUMENTS: args
            };
            
            # catch the result.
            result, apiName, startTimestamp, costTime = APIDecorator(func)(*args, **kwargs);            
            if apiName in self.callRecords:
                callCount = len(self.callRecords[apiName].keys);
                totalTime = callRecords[apiName][TorchWrapper.CallRecordKey.ResultKey.TOTAL_TIME];
            else:
                callCount = 0;
                totalTime = 0.0;
            
            # Generate the record.
            record[TorchWrapper.CallRecordKey.ResultKey.START_TIMESTAMP] = startTimestamp;
            record[TorchWrapper.CallRecordKey.ResultKey.COST_TIME] = costTime;
            record[TorchWrapper.CallRecordKey.ResultKey.ARGUMENTS] = args;
            totalTime += costTime;
            
            # Saving the record
            self.calRecords[apiName][TorchWrapper.CallRecordKey.ResultKey.TOTAL_TIME] =totalTime;
            self.calRecords[apiName][callCount] = record;
            
            
            return result;
        wrapper.isDecorated = True;
        return wrapper;
        
    """
    ******************
    Processing Section
    ******************
    """
    
    def decorateModule(self, module: types.ModuleType, visited=None: set):
        """
        **Description**
        a function that can wrap the hole module with CountDecorator.

        **params**
        module(String): The name of module to be decorated.
        visited(Set): the module name that has been decorated.

        **returns**
        a module that has been fully decorated.
        """
        assert isinstance(module), "`module`must be a module.";
        
        if visited = None:
            visited = None;
        elif module in visited:
            return;
        visited.add(getAPIName(module))
            
        
        for name in getAttribututes(module):
            print("descending into {} from ")
            func = getattr(module, name);
            if isinstance(func, types.ModuleType):
                setattr(module, func.__name__, decorateModule(func, visited));
            elif isinstance(func, types.FunctionType):
                decoratedFunc = CountDecorator(func);
                setattr(module, func.__name__, decoratedFunc);
    
    """
    **************
    Saving Section
    **************
    """
    
    # parse the usable value from config
    # parse max file size limit;
    def getFileMaxSize(self, config: dict):
        """Though I don't think this is necessary.-- Frankie"""
        if TorchWrapper.ConfigKey.FILE_MAX_SIZE in config:
            maxSize = config[TorchWrapper.ConfigKey.FILE_MAX_SIZE];
            if maxSize.endwith("KB"):
                maxSize = int(maxSize[:-2]) * 1024;
            elif maxSize.endwith("MB"):
                maxSize = int(maxSize[:-2]) * (1024 ** 2);
            elif maxSize.endwith("GB"):
                maxSize = int(maxSize[:-2]) * (1024 ** 3);
            return maxSize;
  
                          
    def getFileNameSuffix(self, file_name_spec: str):
        if file_name_spec == "timestamp":
            return time.time_ns();
        elif file_name_spec == "datetime":
            return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime());
        elif file_name_spec == "serial":
            raise NotImplementedError;
                          
    # Prepare the result saving directory
    def setPath(path):
        if os.path.exists(path):
            if not os.path.isdir(path):
                raise ValueError(f"Path {path} is not a directory");
        else:
            os.makedirs(path);
        return path;
    
        
        
    # 
    def saveRecord(config: dict, fileName: str):
        def getFileName():
            
        def saveToJson(fileName: str):
            """TODO: save DataFrame formatted callRecorded to a .json file."""
            
        def saveToCSV(fileName: str):
            """TODO: save DataFrame formatted callRecorded to a .csv file."""
            
        def saveToExcel(fileName: str):
            """TODO: save DataFrame formatted callRecorded to a .excel file."""
            
        def saveToHTML(fileName: str):
            """TODO: save DataFrame formatted callRecorded to a .html file"""
            
        def getFileName(config: dict) -> str:
        
        
        getFileName(config);
        data = self.getDFFormattedCallRecords();

    """
    ******************
    Main Usage Section
    ******************
    """
    
        def start(self, func):
            """
            **Description**
            """
            config = self.parseConfig(config);
            self.decorateModule(torch);
            func()
            
            fileName = self.getFileNameSuffix(config);
            #  self.saveReport(self.callRecords);
        

SyntaxError: invalid syntax (3343111851.py, line 156)

In [356]:
gA = getAttributes

In [359]:
ts = "tensor"
tc = torch
ttensor = TimerDecorator(torch.tensor)
ttensor([2,3]) + ttensor([6,9])

(tensor([2, 3]),
 548149788238200,
 1.2635,
 tensor([6, 9]),
 548149789541200,
 0.2058)

In [371]:
setattr(torch, "tensor", APIDecorator(torch.tensor))

In [372]:
torch.tensor([2,3])

(tensor([2, 3]), 'torch.tensor', 607726794436900, 1.0397)

In [247]:
import warnings
warnings.warn("Hi")



In [343]:
class cal():
    def __init__(self):
        self.FFF = 6;
        self.TTT = 7;
        
    def demo(self):
        return self.FFF * self.TTT;
    
    def add(self, a, b):
        return a + b;
    def minus(self, a, b):
        return a - b;

In [1]:
import torch

In [2]:
torch.tensor([2,3])

tensor([2, 3])

In [16]:
getAPIName(torch.tensor)

'torch.tensor'

In [None]:
wrapper = TorchWrapper