# torch 源码抄写

只抄写部分我感兴趣的-------> 任性

In [4]:
import warnings
import torch
from torch.nn import Linear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn import functional as F

In [5]:
class ReLU(Module):
    """Module wrapper around ``F.relu``.

    Args:
        inplace: forwarded to ``F.relu`` as its ``inplace`` flag.
    """
    __constants__ = ['inplace']

    def __init__(self, inplace=False):
        super(ReLU, self).__init__()
        self.inplace = inplace

    def forward(self, input):
        """Return ``F.relu`` applied to ``input``."""
        activated = F.relu(input, inplace=self.inplace)
        return activated

    def extra_repr(self):
        """Return the extra text shown in the module's repr."""
        # Only the non-default setting is worth showing.
        if self.inplace:
            return 'inplace=True'
        return ''

In [6]:
class RReLU(Module):
    """Module wrapper around ``F.rrelu``.

    Args:
        lower: forwarded to ``F.rrelu`` as ``lower``.
        upper: forwarded to ``F.rrelu`` as ``upper``.
        inplace: forwarded to ``F.rrelu`` as ``inplace``.
    """
    __constants__ = ['lower', 'upper', 'inplace']

    def __init__(self, lower=1./8, upper=1./3, inplace=False):
        super(RReLU, self).__init__()
        self.inplace = inplace
        self.lower = lower
        self.upper = upper

    def forward(self, input):
        """Return ``F.rrelu`` applied to ``input`` using the module's settings."""
        # ``self.training`` is passed through so the functional can tell
        # train mode from eval mode.
        return F.rrelu(
            input,
            lower=self.lower,
            upper=self.upper,
            training=self.training,
            inplace=self.inplace,
        )

    def extra_repr(self):
        """Return the extra text shown in the module's repr."""
        pieces = [
            'lower={}'.format(self.lower),
            'upper={}'.format(self.upper),
        ]
        if self.inplace:
            pieces.append('inplace=True')
        return ', '.join(pieces)

In [7]:
class Sigmoid(Module):
    """Module wrapper around ``torch.sigmoid``."""

    def forward(self, input):
        """Return ``torch.sigmoid`` applied to ``input``."""
        squashed = torch.sigmoid(input)
        return squashed

In [8]:
class Tanh(Module):
    """Module wrapper around ``torch.tanh``."""

    def forward(self, input):
        """Return ``torch.tanh`` applied to ``input``."""
        squashed = torch.tanh(input)
        return squashed

In [11]:
class MutiheadAttention(Module):
    """Multi-head attention layer built on ``F.multi_head_attention_forward``.

    The module owns the projection weights/biases and forwards them, together
    with the inputs, to ``F.multi_head_attention_forward``.

    Args:
        embed_dim: model dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout value handed to the functional.
        bias: if true, create ``in_proj_bias`` and give ``out_proj`` a bias.
        add_bias_kv: if true, create learnable ``bias_k`` / ``bias_v``.
        add_zero_attn: flag handed to the functional unchanged.
        kdim: feature size of ``key``; defaults to ``embed_dim``.
        vdim: feature size of ``value``; defaults to ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, kdim=None,
                 vdim=None):
        super(MutiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # When key/value share the query's feature size a single packed
        # projection matrix is used; otherwise separate matrices are needed.
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # NOTE(review): this packed weight is always registered, but
        # _reset_parameters only initialises it when q/k/v share embed_dim.
        # It is kept unconditionally so state_dict keys stay unchanged.
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
            # The value projection maps vdim -> embed_dim, so its input size
            # must be vdim (previously kdim, which broke kdim != vdim).
            self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise the projection weights and zero the biases."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)
        # in_proj_bias and out_proj.bias are both controlled by the same
        # ``bias`` constructor flag, so one check covers both.
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        """Run attention via ``F.multi_head_attention_forward``.

        Args:
            query, key, value: input tensors, passed through unchanged.
                Presumably laid out as (seq, batch, feature) as in the torch
                documentation -- confirm against the installed torch version.
            key_padding_mask: passed through unchanged.
            need_weights: passed through unchanged.
            attn_mask: passed through unchanged.

        Returns:
            Whatever ``F.multi_head_attention_forward`` returns.
        """
        # Extra keyword arguments are only needed when separate q/k/v
        # projection matrices exist.
        separate_proj_kwargs = {}
        if hasattr(self, '_qkv_same_embed_dim'):
            if self._qkv_same_embed_dim is False:
                separate_proj_kwargs = dict(
                    use_separate_proj_weight=True,
                    q_proj_weight=self.q_proj_weight,
                    k_proj_weight=self.k_proj_weight,
                    v_proj_weight=self.v_proj_weight,
                )
        else:
            # The attribute is always set by __init__; it can only be missing
            # on an instance that did not go through this constructor.
            warnings.warn('A new version of MultiheadAttention module has been implemented. '
                          'Please re-train your model with the new module',
                          UserWarning)

        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, **separate_proj_kwargs)