* 装饰器的使用
* 关于*args和**kwargs

In [None]:
# 0号装饰器，过于简单实用性不高
def dec_0(f):

    return f

In [2]:
# 1号装饰器，加入可执行的操作
def dec_1(*opts):
    def dec(f):
        # do something
        return f
    return dec

In [None]:
def dec_2(f):
    def wrapper(*args, **kwargs):
        # do something
        return f(*args, **kwargs)
    return wrapper

In [3]:
def dec_3(*opt):    # 配置函数
    def dec(f):     # 装饰器函数
        def wrapper(*args, **kwargs):   # 包装函数
            # do somthing
            return f(*args, **kwargs)
        return wrapper
    return dec

In [6]:
# def task(f):
#     global tasks
#     tasks.append(f)
#     return f

def task(name=''):
    def _task(f):
        global tasks
        tasks.append(f)
        if name:
            setattr(f, 'name', name)
        else:
            setattr(f, 'name', f.__name__)
        return f
    
    if callable(name):
        return task()(name)
    
    return _task

def action():
    for task in tasks:
        print(task, ':', task())

@task
def play():
    return 'playing ...'

# action()

NameError: name 'tasks' is not defined

In [None]:
def count(f):
    counter = 0
    def wrapper(*args, **kwargs):
        nonlocal counter
        counter += 1
        return f(*args, **kwargs) + ' ' + str(counter)
    return wrapper



一个关于使用decorator自动添加绝对路径的实例

In [None]:
import os
import sys
import time
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

from torch.utils.tensorboard import SummaryWriter   # Launch Tensorboard
from torchsummary import summary

# add SASNet/ChangeFormer/ to environment 
def set_env(vis_main):
    def set_env_main():
        pythonpath = os.path.abspath(os.getcwd())
        changeformer_path = os.path.join(pythonpath, 'ChangeFormer')
        print(changeformer_path)
        vis_main(path = changeformer_path)

    return set_env_main

def model_tensorboard(model, device):
    # path to log
    current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) # time info
    log_dir = f'SASNet/attention_vis/{current_time}'    # log path
    log_comment = f'harry'                              # comment info
    print(next(model.parameters()).device)              # check cuda info
    inputs = torch.randn(1, 3, 256, 256).to(device)
    # writer = SummaryWriter(log_dir=log_dir, comment=log_comment)
    summary(model=model, input_size=[(3, 256, 256), (3, 256, 256)], device='cuda')
    print('model:'.center(50,"*"))
    print(model)

    # writer.add_graph(model=model, input_to_model=(inputs, inputs))
    # writer.close()

@set_env
def main(path:str):
    # set path
    sys.path.insert(0, path)
    from models.networks import BASE_Transformer

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    changeformer_net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
                                with_pos='learned', enc_depth=1, dec_depth=8)
    model_tensorboard(changeformer_net.to(device), device)

if __name__ == '__main__':
    main()