In [8]:
import os,sys
sys.path.append(os.path.abspath('..'))
import abc
from jittorKLike.utils import cAxis,silent_preload_jittor
silent_preload_jittor()
import jittor as jt
from jittor import nn
from functools import partial,wraps
import logging

In [2]:

# Registry mapping a lower-case activation name to a zero-argument
# constructor for the corresponding jittor module. Softmax is pre-bound
# to the project's channel axis (cAxis) so it can be built without arguments.
activation_dict=dict(
    relu=nn.ReLU,
    leakyrelu=nn.LeakyReLU,
    sigmoid=nn.Sigmoid,
    tanh=nn.Tanh,
    softmax=partial(nn.Softmax,dim=cAxis),
)

In [3]:
def get_activation(act):
    '''
    Build an activation module from its (case-insensitive) name.

    args:
      act: name of an activation registered in activation_dict, or None.
    returns:
      A freshly constructed activation module, or None when act is None.
    raises:
      TypeError: act is neither a str nor None.
      KeyError: act is a str that is not registered in activation_dict.
    '''
    # None is accepted so callers can write `self.act=get_activation(activation)`
    # directly instead of guarding the call with an if.
    if act is None:
        return None
    if not isinstance(act,str):
        raise TypeError('Expected activation to be a str or None'\
                        f', but receive: {type(act).__name__}')
    key=act.lower()
    if key not in activation_dict:
        raise KeyError(f'Unknown activation {act!r}'\
                       f', expected one of {list(activation_dict)}')
    return activation_dict[key]()

def check_activations():
    '''
    Try to construct every activation registered in activation_dict and log the result.

    Each constructor is called with no arguments. A failure is logged as a
    warning naming the activation; the function itself never raises because of
    a failing constructor. A summary line is logged at the end (info when all
    constructors succeeded, warning otherwise).
    '''
    check=True
    for name,constr in activation_dict.items():
        # Skip a None key if one is ever registered.
        if name is None: continue
        try:
            constr()
        # Only ordinary errors are treated as "activation unavailable";
        # KeyboardInterrupt/SystemExit must still propagate.
        except Exception:
            check=False
            logging.warning(f'Check_activations: Failed when constructing activation {name}')
    if check:
        logging.info('Check_activations: All activations ready.')
    else:
        logging.warning('Check_activations: Not all activations are available.')

# Run the check once when this cell/module is executed.
check_activations()

In [4]:
# Simple alias
# Layers re-exported unchanged from jittor.nn under their Keras-style names.
Embedding,Dropout=nn.Embedding,nn.Dropout

In [5]:
class Dense(nn.Module):
    '''
    Linear layer (nn.Linear) followed by an optional activation.

    args:
      inC: input feature count, passed as the first argument of nn.Linear
      units: output feature count, passed as the second argument of nn.Linear
      activation: activation name understood by get_activation, or None
    '''
    def __init__(self,inC,units,activation=None):
        super().__init__()
        self.fc=nn.Linear(inC,units)
        self.act=get_activation(activation)

    def execute(self,x):
        out=self.fc(x)
        # Without an activation the linear output is returned as is.
        return out if self.act is None else self.act(out)
    
class BasicDense(nn.Module):
    '''
    This class is for channels_last format, eg. RNN.

    Computes activation(matmul(inputs, kernel) + bias) with an explicit kernel
    and bias created by keras_ops.random_normal.

    args:
      input_dim: size of the last axis of the inputs (first kernel dimension)
      units: output size (second kernel dimension); coerced to int
      activation: activation name understood by get_activation, or None
      use_bias: when False no bias is created or added
    '''
    def __init__(self,input_dim,units,activation=None,use_bias=True):
        super().__init__()
        # Imported here as in the original code.
        # NOTE(review): presumably deferred to avoid a circular import - confirm.
        import jittorKLike.keras_ops as K

        self.units = int(units) if not isinstance(units, int) else units
        self.activation = get_activation(activation)
        self.use_bias=use_bias

        # Build the shapes from the normalised self.units so that a
        # non-int `units` (e.g. 64.0) cannot leak into the tensor shape.
        self.kernel=K.random_normal([input_dim,self.units])
        self.bias=K.random_normal([self.units]) if use_bias else None

    def execute(self,inputs):
        x=jt.matmul(inputs, self.kernel)
        if self.use_bias:
            x=x+self.bias
        if self.activation is not None:
            x=self.activation(x)
        return x

In [6]:
class Conv2D(nn.Module):
    '''
    nn.Conv2d followed by an optional activation, with Keras-style padding names.

    args:
      inC, outC: input / output channel counts, forwarded to nn.Conv2d
      kernel_size: int or pair of ints
      stride: forwarded to nn.Conv2d
      padding: 'valid', 'same' (case-insensitive), an int, or a tuple/list of ints
      activation: activation name understood by get_activation, or None
      **kwargs: forwarded unchanged to nn.Conv2d
    '''
    def __init__(self,inC,outC,kernel_size=3,stride=1,padding='valid',activation=None,**kwargs):
        super().__init__()

        padding=self._check_padding(padding,kernel_size)

        self.conv=nn.Conv2d(inC,outC,
                            kernel_size=kernel_size,stride=stride,padding=padding,**kwargs)
        self.act=get_activation(activation)

    def execute(self,x):
        x=self.conv(x)
        if self.act is not None:
            x=self.act(x)
        return x

    def _check_padding(self,pad,kernel_size):
        '''
        Translate a padding specification into a numeric padding.

        returns:
          0 for 'valid'; a tuple of (k-1)//2 per kernel dimension for 'same';
          an int padding unchanged; a tuple/list padding as a tuple.
        raises:
          KeyError: pad is a str other than 'valid'/'same'.
          TypeError: pad is not a str, int, tuple or list.
        '''
        if isinstance(pad,str):
            pad=pad.lower()
            if pad == 'valid':
                return 0
            if pad == 'same':
                # NOTE(review): (k-1)//2 presumably preserves the spatial size only
                # for odd kernel sizes with stride 1 - confirm for other settings.
                k=(kernel_size,)*2 if isinstance(kernel_size,int) else kernel_size
                return tuple((size-1)//2 for size in k)
            raise KeyError('Expected padding with str-type to be one of [\'valid\',\'same\']'\
                          f', but receive: {pad}')
        # isinstance needs a type or a tuple of types; the previous list argument
        # raised TypeError for every numeric padding.
        if isinstance(pad,int):
            return pad
        if isinstance(pad,(tuple,list)):
            return tuple(pad)
        raise TypeError('Expected padding to be a str, an int or a tuple of ints'\
                        f', but receive: {type(pad).__name__}')

In [7]:
class Conv2DTranspose(Conv2D):
    '''
    nn.ConvTranspose followed by an optional activation.

    Reuses Conv2D.execute and Conv2D._check_padding; only the wrapped
    convolution differs. Extra keyword arguments are forwarded to nn.ConvTranspose.
    '''
    def __init__(self,inC,outC,kernel_size=3,stride=1,padding='valid',output_padding=0,activation=None,**kwargs):
        # Deliberately skip Conv2D.__init__ (it would build an nn.Conv2d) and
        # run the initializer of the next class after Conv2D in the MRO.
        super(Conv2D,self).__init__()

        resolved_padding=self._check_padding(padding,kernel_size)

        self.conv=nn.ConvTranspose(
            inC,outC,
            kernel_size=kernel_size,
            stride=stride,
            padding=resolved_padding,
            output_padding=output_padding,
            **kwargs)
        self.act=get_activation(activation)

In [8]:
class Pooling(nn.Module):
    # Common base of the pooling wrappers. Its __init__ ignores its arguments;
    # subclasses in this file copy its signature documentation via functools.wraps.
    def __init__(self,pool_size=2,stride=None):
        '''
        args:
          pool_size (name from keras): default 2 (from keras)
          stride: default None
        '''
        super().__init__()

class MaxPool2D(Pooling):
    '''Max pooling layer wrapping nn.MaxPool2d; pool_size is passed as kernel_size.'''
    @wraps(Pooling.__init__)
    def __init__(self,pool_size=2,stride=None):
        super().__init__()
        pool_args=dict(kernel_size=pool_size,stride=stride)
        self.pool=nn.MaxPool2d(**pool_args)

    def execute(self,x):
        pooled=self.pool(x)
        return pooled
    
class AvgPool2D(Pooling):
    '''Average pooling layer wrapping nn.AvgPool2d; pool_size is passed as kernel_size.'''
    @wraps(Pooling.__init__)
    def __init__(self,pool_size=2,stride=None):
        super().__init__()
        pool_args=dict(kernel_size=pool_size,stride=stride)
        self.pool=nn.AvgPool2d(**pool_args)

    def execute(self,x):
        pooled=self.pool(x)
        return pooled

# Alternative spelling of the same class.
AveragePool2D=AvgPool2D

In [9]:
class UpSample2D(nn.Module):
    '''Upsampling layer wrapping nn.Upsample.'''
    def __init__(self,scale_factor=2,interpolation='nearest'):
        '''
        args:
          scale_factor (name from jittor/torch): default 2 (from keras)
          interpolation (name from opencv/keras): passed to nn.Upsample as `mode`
        '''
        super().__init__()
        upsample_args=dict(scale_factor=scale_factor,mode=interpolation)
        self.upsample=nn.Upsample(**upsample_args)

    def execute(self,x):
        upsampled=self.upsample(x)
        return upsampled

# Alternative spelling of the same class.
UpSampling2D=UpSample2D

In [23]:
class _RNNFactory(nn.Module,metaclass=abc.ABCMeta):
    '''
    Abstract base of the recurrent wrappers.

    Subclasses must define __init__ and set `self.rnn` (the wrapped recurrent
    module) and `self.return_sequence` (bool) before execute is called.
    '''
    @abc.abstractmethod
    def __init__(self,**kwargs):
        super().__init__()

    def execute(self,x):
        # The wrapped module returns a sequence-like result; only element 0 is used.
        outputs=self.rnn(x)[0]
        if not self.return_sequence:
            # NOTE(review): this takes index 0 along axis 1. Whether that is the
            # intended "final step" depends on the tensor layout of self.rnn
            # (not visible here) - confirm.
            return outputs[:,0,:]
        return outputs

In [28]:
class RNN(_RNNFactory):
    '''
    Wrapper around nn.RNN.

    args:
      inC, outC: first two positional arguments of nn.RNN
      return_sequence: when True execute returns the full output of the wrapped module
      bidirectional: forwarded to nn.RNN
      *args, **kwargs: forwarded unchanged to nn.RNN
    '''
    def __init__(self,inC,outC,return_sequence=True,bidirectional=False,*args,**kwargs):
        super().__init__()
        self.return_sequence=return_sequence
        self.rnn=nn.RNN(inC,outC,*args,bidirectional=bidirectional,**kwargs)
        
class GRU(_RNNFactory):
    '''
    Wrapper around nn.GRU.

    args:
      inC, outC: first two positional arguments of nn.GRU
      return_sequence: when True execute returns the full output of the wrapped module
      bidirectional: forwarded to nn.GRU
      *args, **kwargs: forwarded unchanged to nn.GRU
    '''
    def __init__(self,inC,outC,return_sequence=True,bidirectional=False,*args,**kwargs):
        super().__init__()
        self.return_sequence=return_sequence
        self.rnn=nn.GRU(inC,outC,*args,bidirectional=bidirectional,**kwargs)
        
class LSTM(_RNNFactory):
    '''
    Wrapper around nn.LSTM.

    args:
      inC, outC: first two positional arguments of nn.LSTM
      return_sequence: when True execute returns the full output of the wrapped module
      bidirectional: forwarded to nn.LSTM
      *args, **kwargs: forwarded unchanged to nn.LSTM
    '''
    def __init__(self,inC,outC,return_sequence=True,bidirectional=False,*args,**kwargs):
        super().__init__()
        self.return_sequence=return_sequence
        self.rnn=nn.LSTM(inC,outC,*args,bidirectional=bidirectional,**kwargs)

In [10]:
class Flatten(nn.Module):
    '''Layer form of nn.flatten(inputs,1,-1): flattens every axis from axis 1 onwards.'''
    def execute(self,inputs):
        flattened=nn.flatten(inputs,start_dim=1,end_dim=-1)
        return flattened
    
class Reshape(nn.Module):
    '''
    Reshape every axis after axis 0 to a fixed target shape.

    args:
      shape: iterable with the target sizes of the axes after axis 0
    '''
    def __init__(self,shape):
        super().__init__()
        self.shape=list(shape)

    def execute(self,inputs):
        # Keep the size of axis 0 and append the configured shape.
        target_shape=[inputs.shape[0],*self.shape]
        return jt.reshape(inputs,target_shape)

In [11]:
class Concat(nn.Module):
    '''Layer form of jt.concat along the project's channel axis (cAxis).'''
    def execute(self,inputs):
        merged=jt.concat(inputs, dim=cAxis)
        return merged