# In [None]:
import numpy as np
import pandas as pd
import scipy.io as io
import sklearn

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset
from paddle.io import DataLoader
from paddle.io import TensorDataset

# In [None]:
class BAP(nn.Layer):
    """Parameter-free average pooling that keeps only axes 0 and 2.

    ``forward`` averages ``x`` over its last axis, then over axis 1, and
    reshapes the result to ``[x.shape[0], 1, x.shape[2], 1]`` so it broadcasts
    against a 4-D tensor with the same batch and axis-2 sizes.
    NOTE(review): assumes a 4-D input laid out as (batch, channel, band, T),
    as in the shape note next to CAFM — confirm for other callers.

    Args:
        in_features: unused; accepted only for interface compatibility.
    """

    def __init__(self, in_features=8):
        super().__init__()

    def forward(self, x):
        pooled = paddle.mean(paddle.mean(x, axis=-1), axis=1)
        target_shape = [x.shape[0], 1, x.shape[2], 1]
        return pooled.reshape(target_shape)

class CAFM(nn.Layer):
    """Fuse three feature maps ``c``, ``v`` and ``f`` with a learned gate.

    The output is ``f * BAP(c) + (f * g + v * (1 - g))``, where
    ``g = sigmoid(gamma_)`` is a single learned scalar.

    Args:
        in_channels: unused; accepted only for interface compatibility.
        Gamma: initial value of the raw gate parameter ``gamma_``. The
            effective gate at initialisation is ``sigmoid(Gamma)``.
    """

    def __init__(self, in_channels, Gamma=0):
        super().__init__()
        self.bap = BAP()
        init = paddle.nn.initializer.Constant(Gamma)
        self.gamma_ = paddle.create_parameter(
            shape=[1], dtype='float32', default_initializer=init)

    def forward(self, c, v, f):
        # Weights from c, broadcast back over f (see BAP for the output shape).
        band_weight = self.bap(c)
        modulated = f * band_weight
        gate = F.sigmoid(self.gamma_)
        # Learned convex mix between f and v.
        blended = f * gate + v * (1 - gate)
        return modulated + blended
# Shape check example; each input is laid out as (batch, channel, band, T):
# paddle.summary(CAFM(3,0),[(2,3,8,256),(2,3,8,256),(2,3,8,256)])

class con_bn_relu_maxpool(nn.Layer):
    """Conv2D -> BatchNorm2D -> ReLU -> MaxPool2D building block.

    The pooling window and stride are both ``(1, 2)``, so pooling only
    downsamples the last axis, by a factor of 2.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution and batch norm.
        kernel_size, stride, padding: forwarded to ``nn.Conv2D``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super().__init__()
        conv = nn.Conv2D(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding)
        norm = nn.BatchNorm2D(out_channels)
        act = nn.ReLU()
        pool = nn.MaxPool2D(kernel_size=(1, 2), stride=(1, 2))
        # The attribute name ``CBRM`` and the layer order determine the
        # state-dict keys, so both must stay as they are.
        self.CBRM = nn.Sequential(conv, norm, act, pool)

    def forward(self, x):
        return self.CBRM(x)

class Classifier(nn.Layer):
    """Two-layer MLP head that outputs softmax probabilities.

    Pipeline: Flatten -> Linear(in_features, hidden) -> ReLU
    -> Linear(hidden, num_outputs) -> Softmax (paddle default axis=-1).

    Args:
        in_features: size of the flattened input.
        out_features: ``(hidden, num_outputs)``; only items 0 and 1 are read.
            The default is a tuple rather than a list so it cannot be shared
            and mutated across calls. Any indexable pair still works.
        drop: currently ignored. The dropout layer is disabled in this head,
            unlike in ``Classifier_``.
    """

    def __init__(self, in_features, out_features=(128, 3), drop=0.4):
        super(Classifier, self).__init__()
        hidden, num_outputs = out_features[0], out_features[1]
        # Child order (0..4) fixes the state-dict keys; keep it stable.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_outputs),
            nn.Softmax(),
        )

    def forward(self, x):
        return self.fc(x)
    
class Classifier_(nn.Layer):
    """Two-layer MLP head that outputs raw logits (no softmax).

    Pipeline: Flatten -> Dropout(drop) -> Linear(in_features, hidden)
    -> ReLU -> Linear(hidden, num_outputs).

    Args:
        in_features: size of the flattened input.
        out_features: ``(hidden, num_outputs)``; only items 0 and 1 are read.
            The default is a tuple rather than a list so it cannot be shared
            and mutated across calls. Any indexable pair still works.
        drop: dropout probability applied to the flattened input.
    """

    def __init__(self, in_features, out_features=(128, 3), drop=0.4):
        super(Classifier_, self).__init__()
        hidden, num_outputs = out_features[0], out_features[1]
        # Child order (0..4) fixes the state-dict keys; keep it stable.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(drop),
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_outputs),
        )

    def forward(self, x):
        return self.fc(x)
    
class B_att(nn.Layer):
    """Sigmoid attention over the flattened, last-axis-averaged input, with a residual connection.

    ``forward`` averages ``x`` over its last axis and flattens everything after
    the batch axis. It maps the result through
    Linear(in_features, 2*in_features) -> ReLU -> Linear -> Sigmoid to get
    weights in (0, 1), reshapes them to ``x.shape[:3] + [1]``, and returns
    ``x * w + x``.

    Args:
        in_features: length of the flattened vector. For a 4-D input this
            must equal ``x.shape[1] * x.shape[2]``.
    """

    def __init__(self, in_features=8):
        super().__init__()
        hidden = in_features * 2
        self.weight = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=hidden),
            nn.ReLU(),
            nn.Linear(in_features=hidden, out_features=in_features),
            nn.Sigmoid(),
        )
        self.flatten = nn.Flatten()

    def forward(self, x):
        squeezed = self.flatten(paddle.mean(x, axis=-1))
        scale = self.weight(squeezed).reshape(x.shape[:3] + [1])
        return x * scale + x
    
class CAVFNet(nn.Layer):
    """Three-branch CNN that fuses inputs ``v`` and ``c`` through CAFM blocks.

    Structure:
      - Branch "c" runs ``c`` through layer11 -> layer12 -> layer13.
      - Branch "v" runs ``v`` through layer21 -> layer22 -> layer23.
      - The fusion branch starts from ``v``. After each of its stages
        (layer31/32/33), it is mixed with the other two branches by
        cafm1/2/3.

    The output is a per-sample convex combination of the softmax
    probabilities of branch "v" (``fc2``) and the fusion branch (``fc3``).
    The two mixing weights come from ``self.fc``, which ends in a softmax
    over 2 outputs.

    The classifier heads hard-code ``32*8*4`` input features, so each
    branch must end as a 32x8x4 map per sample. With the stage settings
    below, an input of (B, 1, 8, 256) produces that.
    NOTE(review): confirm that this is the intended input shape.

    Args:
        num_classes: number of output classes.
    """

    def __init__(self, num_classes=8):
        super(CAVFNet, self).__init__()
        # Creation order is kept as-is: it fixes parameter initialisation
        # order and the state-dict keys.
        self.layer11 = con_bn_relu_maxpool(1,  16, (3,3), (1,2) ,(1,1) )
        self.layer12 = con_bn_relu_maxpool(16, 32, (3,3), (1,2) ,(1,1) )
        self.layer13 = con_bn_relu_maxpool(32, 32, (3,3), (1,2) ,(1,1) )
        self.layer21 = con_bn_relu_maxpool(1,  16, (3,3), (1,2) ,(1,1) )
        self.layer22 = con_bn_relu_maxpool(16, 32, (3,3), (1,2) ,(1,1) )
        self.layer23 = con_bn_relu_maxpool(32, 32, (3,3), (1,2) ,(1,1) )
        self.layer31 = con_bn_relu_maxpool(1,  16, (3,3), (1,2) ,(1,1) )
        self.layer32 = con_bn_relu_maxpool(16, 32, (3,3), (1,2) ,(1,1) )
        self.layer33 = con_bn_relu_maxpool(32, 32, (3,3), (1,2) ,(1,1) )
        self.cafm1 = CAFM(16,0.5)
        self.cafm2 = CAFM(32,0.5)
        self.cafm3 = CAFM(32,0.5)
        # NOTE: `ba` and `fc1` are constructed but not used in forward().
        # They are kept only so existing saved state dicts still load.
        self.ba = B_att(8)

        self.fc1 = Classifier_(32*8*4,[128,num_classes],0.4)
        self.fc2 = Classifier_(32*8*4,[128,num_classes],0.4)
        self.fc3 = Classifier_(32*8*4,[128,num_classes],0.4)
        # Gate: 2 weights per sample that are non-negative and sum to 1.
        self.fc = nn.Sequential(
            nn.Linear(in_features=num_classes*2,out_features=2),
            nn.Softmax() )

    def forward(self, v, c):
        """Return fused class probabilities of shape ``[batch, num_classes]``.

        Args:
            v: input to branch "v" and to the fusion branch.
            c: input to branch "c"; it only influences the output through
                the CAFM weights.

        Returns:
            ``p_v * g[:, 0:1] + p_f * g[:, 1:2]``. Here ``p_v`` and ``p_f``
            are the softmax outputs of ``fc2`` and ``fc3``, and ``g`` is the
            output of ``self.fc``.
        """
        stages = (
            (self.layer11, self.layer21, self.layer31, self.cafm1),
            (self.layer12, self.layer22, self.layer32, self.cafm2),
            (self.layer13, self.layer23, self.layer33, self.cafm3),
        )
        frames_c, frames_v, fusion = c, v, v
        for conv_c, conv_v, conv_f, cafm in stages:
            frames_c = conv_c(frames_c)
            frames_v = conv_v(frames_v)
            fusion = cafm(frames_c, frames_v, conv_f(fusion))

        # Branch "c" is not classified. Its fc1 head output never reached
        # the result, so it is not evaluated (this saves compute and avoids
        # running its dropout).
        logits_v = self.fc2(frames_v)
        logits_f = self.fc3(fusion)
        gate = self.fc(paddle.concat([logits_v, logits_f], axis=1))
        prob_v = F.softmax(logits_v, axis=-1)
        prob_f = F.softmax(logits_f, axis=-1)
        # The 0:1 / 1:2 slices keep the column axis, so each weight
        # broadcasts as [batch, 1] across classes.
        return prob_v * gate[:, 0:1] + prob_f * gate[:, 1:2]