<a href="https://colab.research.google.com/github/LiangShuLing/TensorFlowLearning/blob/main/jax/jax_ResNet50_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import numpy.random as npr
import jax.numpy as jnp
from jax import jit,grad,random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import AvgPool,BatchNorm,Dense,FanInSum,FanOut,Flatten, GeneralConv,Identity,MaxPool,Relu,LogSoftmax

jax.experimental.stax里面包含很多函数，大部分都是定义网络的函数，比如Dense,GeneralConv, Conv--实现了GeneralConv;
GeneralConvTranspose, 

池化： AvgPool, MaxPool

BatchNorm层：参数标准化

FanOut(b): fan-out layer,对输入进行泛化处理，整体乘以参数b

FanInSum(): FanIn sum层，对输入进行求和处理

函数：实现了jax.nn里面的函数
Tanh，Relu，Exp，LogSoftmax，Softmax，Softplus，Sigmoid ，Elu，LeakyRelu，Selu，Gelu

shape_dependent(make_layer): 
  """Combinator to delay layer constructor pair until input shapes are known.

In [6]:
#通过Conv层实现ConvBlock
def ConvBlock(kernel_size,filters,strides=(2,2)):
  ks=kernel_size
  filter1,filter2,filter3=filters

  Main=stax.serial(
      Conv(filter1,(1,1),strides),BatchNorm(),Relu,
      Conv(filter2,(ks,ks),padding='SAME'),BatchNorm(),Relu,
      Conv(filter3,(1,1)),BatchNorm())
  #定义网络结构，类似于keras的Sequential容器,下面再定义一个网络结构
  Shortcut=stax.serial(Conv(filter3,(1,1),strides),BatchNorm())
  return stax.serial(FanOut(2),stax.parallel(Main,Shortcut),FanInsum,Relu)  #再次拼接模型并返回



In [7]:
def IdentityBlock(kernel_size,filters):
  ks=kernel_size
  filter1,filter2=filters
  def make_main(input_shape):
    return stax.serial(
        Conv(filter1,(1,1)),BatchNorm(),Relu,
        Conv(filter2,(ks,ks),padding='SAME'),BatchNorm(),Relu,
        Conv(input_shape[3],(1,1)),BatchNorm())
    
    Main=stax.shape_dependent(make_main)   #等到输入确定后再工作的延迟层
    return stax.serial(FanOut(2),stax.parallel(Main,Identity),FanInSum,Relu)


In [8]:
 #num_classes定义输出层多少类
def ResNet50(num_classes): 
  return stax.serial(
      GeneralConv(('HWCN','OIHW','NHWC'),64,(7,7),(2,2),'SAME'),
      BatchNorm(),Relu,MaxPool((3,3),strides=(2,2)),
      ConvBlock(3,[64,64,256],strides=(1,1)),
      IdentityBlock(3, [64, 64]),
      IdentityBlock(3, [64, 64]),
      ConvBlock(3, [128, 128, 512]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      IdentityBlock(3, [128, 128]),
      ConvBlock(3, [256, 256, 1024]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      IdentityBlock(3, [256, 256]),
      ConvBlock(3, [512, 512, 2048]),
      IdentityBlock(3, [512, 512]),
      IdentityBlock(3, [512, 512]),
      AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax
  ) 



In [9]:
def main_function():
  rng_key=random.PRNGKey(0)
  batch_size=8
  num_classes=1001
  input_shape=(224,224,3,batch_size)   #batch_size也定义了input_shape里面，同时后面synth_batches函数也使用了batch_size
  step_size=0.1
  num_steps=10

#batch-->输入的train_db
  def loss(params,batch):
    inputs,targets=batch
    logits=predict_fun(params,inputs)    #前向计算
    return -jnp.sum(logits*targets)      #计算target与预测值logits的不同

  def accuracy(params,batch):
    inputs,targets=batch
    target_class=jnp.argmax(target,axis=1)  #取最大值的索引，也就是预测的class类型
    predicted_class=jnp.argmax(predict_fun(params,inputs),axis=-1)
    return jnp.mean(predicted_class==target_class)  #如果类型相同就返回1否则0，返回求取平均值就是预测精度

  def synth_batches():
    rng=npr.RandomState(0)
    while True:
      images=rng.rand(*input_shape).astype('float32')     #通过随机数生成一张图片，实际上是通过导入获取的输入
      labels=rng.randint(num_classes,size=(batch_size,1)) #labels取0-1000的随机数[1,3,5,....34]; shape=[batchsize,1],这是一个二维数组
      onehot_labels=labels==jnp.arange(num_classes)       #num_classess定义了某一类，比如1001，jnp.arange(1001)=[0,1,....,1001]，#与上面的二维数组对比可以返回batch_size个行，num_classess个列，对应相同数字为true，其他为false,  见下面的例子
      yield images, oneHot_label

      #简单地讲，yield 的作用就是把一个函数变成一个 generator，带有 yield 的函数不再是一个普通函数，Python 解释器会将其视为一个 generator，
      # 调用函数不会执行简单执行函数，而是返回一个 iterable 对象


#定义反向传播计算，opt_state储存了params
  @jit
  def update(i, opt_state,batch):
    params=get_params(opt_state)      
    return opt_update(i,grad(loss)(params,batchs),opt_state)


  opt_init,opt_update,get_params=optimizers.momentum(step_size,mass=0.9)  #实例化优化器,返回参数初始化，更新函数，以及获取参数函数
  batches=synth_batches()         #获取input以及label，如果有自己的数据需要手动修改synth_batch函数

  init_fun,predic_fun=ResNet50(num_classes)   #初始化参数模型与网络模型
  _,init_params=init_fun(rng_key,input_shape) #通过随机key与shape初始化参数

  opt_state=opt_init(init_params)             #传入初始化参数，再次用优化器初始化方法再次初始化一次
  for i in range(num_steps):                  #训练
    opt_state=update(i,opt_state,next(batches))
  
  trained_params=get_params(opt_state)



                           
        






SyntaxError: ignored

In [10]:
import jax.numpy as jnp
import numpy as np
def func():
  a=np.random.randint(10,size=(5,1))
  b=jnp.arange(10)
  oneHot_label=a==b
  return oneHot_label
input_shape=[6,6,3]
batch=np.random.rand(*input_shape).astype('float32')
print(func())
print(batch)

[[False False False False False False False False  True False]
 [False False False False  True False False False False False]
 [False False False False False False False False  True False]
 [False False False False False False False False  True False]
 [False  True False False False False False False False False]]
[[[0.996387   0.87808317 0.8637553 ]
  [0.21135399 0.9317122  0.46110785]
  [0.1203807  0.6731061  0.5242455 ]
  [0.6146163  0.32530114 0.02769888]
  [0.2761006  0.05915179 0.06862331]
  [0.14095677 0.52993107 0.22641619]]

 [[0.9442227  0.6381997  0.3695001 ]
  [0.9416853  0.5664792  0.6032861 ]
  [0.7807215  0.06147199 0.3829453 ]
  [0.6002664  0.18294187 0.62751067]
  [0.43741965 0.4259239  0.19758964]
  [0.2659596  0.6253737  0.9951261 ]]

 [[0.8369199  0.22072919 0.07965129]
  [0.9641116  0.38260308 0.29199386]
  [0.94261    0.14893198 0.11456693]
  [0.68703055 0.27786544 0.5604089 ]
  [0.5024627  0.59090036 0.04744342]
  [0.88360524 0.94454193 0.41223547]]

 [[0.4826714