In [1]:
from utils import cAxis,silent_preload_jittor
silent_preload_jittor()
import jittor as jt
from jittor import nn
from functools import partial
import logging

In [15]:

# Registry of supported activations: lower-case name -> zero-argument constructor.
# The None entry stands for "no activation".
activation_dict={None:None}
activation_dict.update(
    relu=nn.ReLU,
    leakyrelu=nn.LeakyReLU,
    sigmoid=nn.Sigmoid,
    tanh=nn.Tanh,
    # softmax normalises over cAxis (imported from utils).
    softmax=partial(nn.Softmax,dim=cAxis),
)

In [16]:
def get_activation(act):
    '''Instantiate the activation registered under ``act`` in ``activation_dict``.

    args:
      act: case-insensitive activation name (e.g. 'relu', 'Sigmoid'),
           or None meaning "no activation"
    returns:
      a freshly constructed activation module, or None when ``act`` is None
    raises:
      TypeError: if ``act`` is neither None nor a str
      KeyError: if the lower-cased name is not registered in ``activation_dict``
    '''
    # Handle None before lower-casing: str.lower(None) would raise TypeError.
    if act is None:
        return None
    if not isinstance(act,str):
        raise TypeError(f'Expected activation to be a str or None, but receive: {type(act).__name__}')
    key=act.lower()
    if key not in activation_dict:
        known=[name for name in activation_dict if name is not None]
        raise KeyError(f'Unknown activation {act!r}; expected one of {known}')
    return activation_dict[key]()

def check_activations():
    '''Try to construct every activation registered in ``activation_dict`` and log the outcome.

    Entries whose name or constructor is None are skipped. A constructor that
    raises is reported with ``logging.warning`` (including the exception) rather
    than propagated, so one unavailable activation does not stop the notebook.
    '''
    all_ok=True
    for name,constr in activation_dict.items():
        if name is None or constr is None:
            continue
        try:
            constr()
        except Exception as err:  # not a bare except: let KeyboardInterrupt/SystemExit through
            all_ok=False
            logging.warning(f'Check_activations: Failed when constructing activation {name} ({err!r})')
    if all_ok:
        logging.info('Check_activations: All activations ready.')
    else:
        logging.warning('Check_activations: Not all activations are available.')
        
# Smoke-test the activation registry whenever this cell runs; failures are logged, not raised.
check_activations()

In [4]:
class Dense(nn.Module):
    '''Keras-style fully-connected layer: ``nn.Linear`` followed by an optional activation.

    args:
      inC: number of input features
      units: number of output features
      activation: activation name looked up via get_activation, or None for a linear output
    '''
    def __init__(self,inC,units,activation=None):
        super().__init__()
        self.fc=nn.Linear(inC,units)
        # Resolve None locally so the default (linear) layer is always constructible,
        # independent of how get_activation treats None.
        self.act=None if activation is None else get_activation(activation)

    def execute(self,x):
        '''Apply the linear map, then the activation if one was configured.'''
        x=self.fc(x)
        if self.act is not None:
            x=self.act(x)
        return x

In [4]:
class Conv2D(nn.Module):
    '''Keras-style 2-D convolution: ``nn.Conv2d`` followed by an optional activation.

    args:
      inC, outC: input / output channel counts
      kernel_size: int or pair of ints, forwarded to nn.Conv2d
      stride: forwarded to nn.Conv2d
      padding: 'valid' (no padding), 'same' (symmetric padding of dilation*(k-1)//2
               per spatial dim), or an explicit int / tuple / list forwarded to nn.Conv2d
      activation: activation name looked up via get_activation, or None
      **kwargs: extra keyword arguments for nn.Conv2d (e.g. dilation, groups, bias)
    '''
    def __init__(self,inC,outC,kernel_size=3,stride=1,padding='valid',activation=None,**kwargs):
        super().__init__()

        # 'same' must account for dilation, since it enlarges the effective kernel.
        padding=self._check_padding(padding,kernel_size,kwargs.get('dilation',1))

        self.conv=nn.Conv2d(inC,outC,
                            kernel_size=kernel_size,stride=stride,padding=padding,**kwargs)
        # Resolve None locally so the default layer is always constructible.
        self.act=None if activation is None else get_activation(activation)

    def execute(self,x):
        '''Apply the convolution, then the activation if one was configured.'''
        x=self.conv(x)
        if self.act is not None:
            x=self.act(x)
        return x

    def _check_padding(self,pad,kernel_size,dilation=1):
        '''Translate a Keras-style padding spec into a value accepted by nn.Conv2d.

        args:
          pad: 'valid' / 'same' (case-insensitive), an int, or a tuple/list of ints
          kernel_size: int or sequence of ints
          dilation: int or sequence of ints (default 1)
        returns:
          0 for 'valid'; a tuple of per-dimension paddings for 'same';
          explicit int/tuple paddings unchanged (lists are converted to tuples)
        raises:
          KeyError: for an unrecognised padding string
          TypeError: for any other padding type
        '''
        if isinstance(pad,str):
            mode=pad.lower()
            if mode=='valid':
                return 0
            if mode=='same':
                k=[kernel_size]*2 if isinstance(kernel_size,int) else list(kernel_size)
                d=[dilation]*len(k) if isinstance(dilation,int) else list(dilation)
                # Effective kernel is d*(k-1)+1; symmetric padding keeps ceil(in/stride)
                # outputs only when that effective size is odd.
                return tuple(di*(ki-1)//2 for ki,di in zip(k,d))
            raise KeyError('Expected padding with str-type to be one of [\'valid\',\'same\']'\
                          f', but receive: {mode}')
        if isinstance(pad,list):
            return tuple(pad)
        if isinstance(pad,(int,tuple)):
            return pad
        raise TypeError(f'Expected padding to be a str, int, tuple or list, but receive: {type(pad).__name__}')

In [None]:
class MaxPool2D(nn.Module):
    '''Keras-style 2-D max pooling built on ``nn.MaxPool2d``.

    args:
      pool_size (keras naming): pooling window, forwarded as ``kernel_size``; default 2 as in keras
      strides (keras naming): forwarded as ``stride``; None is passed through unchanged
        (presumably nn.MaxPool2d then falls back to the window size, as keras does — confirm)
    '''
    def __init__(self,pool_size=2,strides=None):
        super().__init__()
        pool_kwargs=dict(kernel_size=pool_size,stride=strides)
        self.pool=nn.MaxPool2d(**pool_kwargs)

    def execute(self,x):
        '''Max-pool the input with the configured window.'''
        pooled=self.pool(x)
        return pooled

In [5]:
class UpSample2D(nn.Module):
    '''Fixed-factor upsampling layer built on ``nn.Upsample``.

    args:
      scale_factor (jittor/torch naming): forwarded to nn.Upsample; default 2 as in keras
      interpolation (opencv/keras naming): forwarded to nn.Upsample as ``mode``
    '''
    def __init__(self,scale_factor=2,interpolation='nearest'):
        super().__init__()
        upsample_kwargs={'scale_factor':scale_factor,'mode':interpolation}
        self.upsample=nn.Upsample(**upsample_kwargs)

    def execute(self,x):
        '''Upsample the input with the configured factor and mode.'''
        upsampled=self.upsample(x)
        return upsampled

# Keras spells this layer "UpSampling2D"; expose that name as an alias.
UpSampling2D=UpSample2D

In [6]:
class Concat(nn.Module):
    '''Parameter-free layer joining a sequence of tensors along ``cAxis``.

    ``cAxis`` comes from utils (presumably the channel axis — confirm there).
    '''
    def execute(self,inputs):
        '''Concatenate ``inputs`` (a sequence of tensors) along cAxis.'''
        axis=cAxis
        joined=jt.concat(inputs,dim=axis)
        return joined