In [2]:
import torch.nn as nn
from numpy import repeat
from einops import rearrange
import torch
from torch import einsum

'''
# 转置，对角线对称
image = rearrange(image, 'h w c -> w h c')
# 计算矩阵的迹
a = torch.ones(3,4)
einsum('ii',a)
# 计算你在所有元素的和
einsum('i,j',a)

'''
def pair(size):
    if(type(size) !=(size,size)):
        return (size,size)
    else:
        pass
        

class PreNorm(nn.Module):
    '''
    标准化，设置可传入参数(Multi-Head Attention 和 MLP)
    '''
    def __init__(self,dim,fn) -> None:
        super().__init__()
        self.norm = nn.LayerNopprm(dim)
        self.fn = fn
    def forward(self,x,**kwargs):
        return self.fn(self.norm(x),**kwargs)



class Attention(nn.Module):
    # 默认 8 个头，每个头处理64维的信息，用于计算注意力的维度为512维
    def __init__(self,dim,heads=8,dim_head=64,dropout=0.) -> None:
        super().__init__()
        # 计算注意力的维度
        inner_dim = dim_head *heads
        project_out = not(heads ==1 and dim_head==dim)
        
        self.heads = heads
        self.scale = dim_head **-0.5
        # 将输出的最后一维归一化
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim,inner_dim*3,bias=False)
        self.to_out = nn.Sequemtial(
            nn.Linear(inner_dim,dim),
            nn.Droupout(dropout),
        )if project_out else nn.Identify()
        
    def forward(self,x):
        b,n,_,h = *x.shape,self.heads
        qkv = self.to_qkv(x).chunk(3,dim=-1)
        
        q,k,v =map(lambda t: rearrange(t,'b n (h d) -> b h n d', h=h), qkv)
        
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = self.attend(dots)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        
        # 这一步是将每一个头输出的Attention向量拼接起来，将多头输出变成一个输出，可以看到将h隐藏了起来
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim), 
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class transformer(nn.Module):
    '''
       dim:输入的最后一维
        depth:transformer中多头注意力层(Attention层)的层数
        heads:多头注意力头数
        dim_head:多头注意力层的输入(与dim区别)
        mlp_dim:mlp层隐藏层的维度
        droupout:失活神经元的比例
    '''
    def __init__(self,dim,depth,heads,dim_head,mlp_dim,dropout=0.) -> None:
        super().__init__()
        self.layers = nn.Modulelist([])
        for _ in range(depth):
            self.layer.append(nn.Modulelist([
                PreNorm(dim,Attention(dim,))
            ]))
    def forward(self,x):
        for attn,ff in self.layers:
            x = attn(x)+x
            x = ff(x)+x
            return x
class ViT(nn.Module):
    def __init__(self,*,image_size,patch_size,num_classes,dim,depth,heads,mlp_dim, \
        pool='cls', channels=1, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        # 图片patch化
        image_height,image_width = pair(image_size)
        patch_height,patch_width = pair(patch_size)
        # 保证 图片的大小和patch的长宽可以整除
        assert image_height % patch_height ==0 and image_width % patch_width ==0 
        num_patches = (image_height//patch_height)*(image_width//patch_width)
        patch_dim = channels *patch_height*patch_width
        
        assert pool in {'cls','mean'}
        
        self.to_patch_embedding = nn.Sequential(
            rearrange('b c (h p1)(w p2) -> b (h w)(p1 p2 c)',p1=patch_height,p2=patch_width),
            nn.Linear(patch_dim,dim)
        )
        # 初始化 cls_token
        # nn.Parameter() 定义可学习参数
        self.cls_token = nn.Parameter(torch.randn(1,1,dim))
        # 初始化位置信息
        self.pos_embedding = nn.Parameter(torch.randn(1,num_patches+1,dim))
        
        self.dropout = nn.Dropout(emb_dropout)
        
        # 初始化 transformer
        self.transformer = transformer(dim,depth,heads,dim_head,mlp_dim,dropout)
        
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim,num_classes)
        )


    def forward(self,img):
        # 图片patche 化
        # [b,c,(h*p1),(w*p2)] -> [b,(h*w),(p1*p2*c)] -> [b,(h*w),dim]
        x = self.to_patch_embedding(img)
        b,n,_ = x.shape()
        
        cls_tokens = repeat(self.cls_token,'() n d -> b n d',b=b)
        
        x = torch.cat((cls_tokens,x),dim=1)
        x += self.pos_embedding[:,:(n+1)]
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim=1) if self.pool=='mean' else x[:,0]
        x = self.to_latent(x)
        return self.mlp_head(x)
    



In [35]:
# 参数配置
import argparse
import torch.optim as op
import yaml

'''yaml?


'''
'''
_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 32
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()

def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.data_path:
        config.DATA.DATA_PATH = args.data_path
    if args.zip:
        config.DATA.ZIP_MODE = True
    if args.cache_mode:
        config.DATA.CACHE_MODE = args.cache_mode
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.accumulation_steps:
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.amp_opt_level:
        config.AMP_OPT_LEVEL = args.amp_opt_level
    if args.output:
        config.OUTPUT = args.output
    if args.tag:
        config.TAG = args.tag
    if args.eval:
        config.EVAL_MODE = True
    if args.throughput:
        config.THROUGHPUT_MODE = True

        
    if args.num_workers is not None:
        config.DATA.NUM_WORKERS = args.num_workers
        
    #set lr and weight decay
    if args.lr is not None:
        config.TRAIN.BASE_LR = args.lr
    if args.min_lr is not None:
        config.TRAIN.MIN_LR = args.min_lr
    if args.warmup_lr is not None:
        config.TRAIN.WARMUP_LR = args.warmup_lr
    if args.warmup_epochs is not None:
        config.TRAIN.WARMUP_EPOCHS = args.warmup_epochs
    if args.weight_decay is not None:
        config.TRAIN.WEIGHT_DECAY = args.weight_decay

    if args.epochs is not None:
        config.TRAIN.EPOCHS = args.epochs
    if args.dataset is not None:
        config.DATA.DATASET = args.dataset
    if args.lr_scheduler_name is not None:
        config.TRAIN.LR_SCHEDULER.NAME = args.lr_scheduler_name
    if args.pretrain is not None:
        config.MODEL.PRETRAINED = args.pretrain

    # set local rank for distributed training
    config.LOCAL_RANK = args.local_rank

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()

def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config
def parse_option():
    # ArgumentParser 将命令行和解析成Python数据类型所需的全部信息
    parser = argparse.ArgumentParser('ViT train', add_help=False)
    # add_argument 添加参数信息
    #parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
     # easy config modification
    parser.add_argument('--batch-size', default=32,type=int, help="batch size for single GPU")
    parser.add_argument('--data-path',default='/home/yy/Project/dataset/data', type=str, help='path to dataset')
    parser.add_argument('--epoch',default=250, type=int, help='path to dataset')
    parser.add_argument('--lr',default=1e-4, type=float, help='learning rate')
   # parser.add_argument('--optim',default=op.Adam,type=)
    # 设置
    args = parser.parse_args()
    print(args)
'''
def parse_option():
    # ArgumentParser 将命令行和解析成Python数据类型所需的全部信息
    parser = argparse.ArgumentParser('ViT train', add_help=False)
    # add_argument 添加参数信息
    #parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )
     # easy config modification
    parser.add_argument('--batch-size', default=32,type=int, help="batch size for single GPU")
    parser.add_argument('--data-path',default='/home/yy/Project/dataset/data', type=str, help='path to dataset')
    parser.add_argument('--epoch',default=250, type=int, help='path to dataset')
    parser.add_argument('--lr',default=1e-4, type=float, help='learning rate')
   # parser.add_argument('--optim',default=op.Adam,type=)
    # 将参数设置为元组形式返回
    args = parser.parse_known_args()
   
    
    print(type(args))
    print(args)
parse_option()

usage: ViT train [--opts OPTS [OPTS ...]] [--batch-size BATCH_SIZE]
                 [--data-path DATA_PATH] [--epoch EPOCH] [--lr LR]
ViT train: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme="hmac-sha256" --Session.key=b"e86de039-ac7e-43fe-8c63-531f9dbb6dcf" --shell=9002 --transport="tcp" --iopub=9004 --f=/home/yy/.local/share/jupyter/runtime/kernel-v2-44453S9WSNtQSaqnt.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


### yacs的学习

yacs是一个轻量级用于管理系统配置的开源库，常用于模型训练过程中的参数配置(lr,depth...),使用可读的yaml格式。

**安装**



In [3]:
%pip install yacs

Note: you may need to restart the kernel to use updated packages.


**yacs的使用**

1. 新建一个`config.yaml`文件
    ```yaml
    DATA:
        IMG_SIZE: 224
    MODEL:
        TYPE: MetaFG
        NAME: MetaFG_0
    ```
2. 新建`python`文件

    ```python
    import os
    import yaml
    from yacs.config import CfgNode as CN  
    
     
    ```


In [6]:
import torchvision
import torchvision.transforms as tf
import  torch

trans_train = torchvision.transforms.Compose([
    tf.ToTensor(),
])
traindata = torchvision.datasets.CIFAR10('/home/yy/Project/dataset/data',train=True,transform=trans_train,
                                          download=True)

train_data_size = len(traindata)
print('训练数据集长度为{}'.format(train_data_size))
# 加载数据集
dataloader = torch.utils.data.DataLoader(traindata,batch_size=32)
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
epoch = 30
total =0
predict_correct = 0.0
for i in range(epoch):
    
    for data in dataloader:
        inputs ,lables = data
        inputs,lables = inputs.to(device),lables.to(device)
        outputs = ViT(inputs)
        _,predict = torch.max(outputs.data,1)
        total += predict.size(0)
        predict_correct +=(predict==lables).sum()
        
    print('[epoch { },accuracy {}]'.format(epoch+1,predict_correct/total))
     

Files already downloaded and verified
训练数据集长度为50000


TypeError: __init__() takes 1 positional argument but 2 were given

device(type='cuda', index=0)

## Python 知识点

`*`和`**`的区别与应用
- `*`:接受元组对象中的每个元素，然后作为一个个的**位置参数**传入函数中
- `**`:接收字典对象中的每个元素，作为一个个的**关键字参数**传入函数

In [10]:
def fun1(*args):
    print(*args)
def fun2(**kwargs):
    print(kwargs)
fun1("1","2","3")
fun2(name="张三",age=21,score=100)

1 2 3
{'name': '张三', 'age': 21, 'score': 100}
