## examples

In [4]:
from ptflops import get_model_complexity_info
from torchvision.models import resnet50
from torchvision.models import alexnet
model = resnet50()
model_alex = alexnet()

print (type(model_alex))

# 实例化model对象后直接调用下面代码
flops, params = get_model_complexity_info(model_alex, (3,224,224),as_strings=True,print_per_layer_stat=True)

print("%s |flops: %s |params: %s" % ('ResNet50',flops,params))

<class 'torchvision.models.alexnet.AlexNet'>
AlexNet(
  61.101 M, 100.000% Params, 0.716 GMac, 100.000% MACs, 
  (features): Sequential(
    2.47 M, 4.042% Params, 0.657 GMac, 91.804% MACs, 
    (0): Conv2d(0.023 M, 0.038% Params, 0.07 GMac, 9.848% MACs, 3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.027% MACs, inplace=True)
    (2): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.027% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(0.307 M, 0.503% Params, 0.224 GMac, 31.316% MACs, 64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.020% MACs, inplace=True)
    (5): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.020% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(0.664 M, 1.087% Params, 0.112 GMac, 15.681% MACs, 192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(0.0 M, 0.000% 

## Base Blocks

In [5]:
##
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
from collections import Counter

class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 relu = True, bn = False, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
        if bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            # 
            x = F.leaky_relu(x, inplace = True)
        return x
                                                   
                                                   
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input):
        return input.view(input.size(0), -1)


class Means(nn.Module):
    def __init__(self):
        super(Means, self).__init__()

    def forward(self, input):
        return torch.mean(input, dim=(1, 2, 3)).unsqueeze(-1)

        
class ZeroOuts(nn.Module):
    def __init__(self):
        super(ZeroOuts, self).__init__()

    def forward(self, x):
        batchSize = x.size()[0]
        return torch.zeros(batchSize, 4, 1, 1)

## FCN

In [63]:
class _FCN(nn.Module):
    def __init__(self, in_channel, hidden_channel): 
        super(_FCN, self).__init__()
        
        # Bottleneck
        self.encoder = nn.Sequential(
            BasicConv2d(in_channel, hidden_channel, 1, bn = True),
            BasicConv2d(hidden_channel, hidden_channel, 3, bn = True, padding = 1),
            BasicConv2d(hidden_channel, 128, 1, bn = True), 
        )
        self.shortcut = BasicConv2d(in_channel, 128, 1, bn=True)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            Flatten(),
            nn.Linear(128 * 1 * 1, 32),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.Linear(32, 1),
        )            
             
            
    def forward(self, x):                                     
        res = self.shortcut(x)
        x = self.encoder(x)
        x = x + res
        fcn_x = self.pool(x)
        fcn_x = self.fc(fcn_x)
        return fcn_x
    
###
fcn = _FCN(
    27,
    256,
)

print (type(fcn))

# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)
flops, params = get_model_complexity_info(fcn, (27,16,16),as_strings=True,print_per_layer_stat=False)
    
print("%s |flops: %s |params: %s" % ('FCN:', flops, params))



<class '__main__._FCN'>
FCN: |flops: 0.16 GMac |params: 638.66 k


## MLP

In [65]:
class _MLP(nn.Module):
    def __init__(self, in_channel, hidden_channel): 
        super(_MLP, self).__init__()
        
        # Bottleneck
        self.mlp = nn.Sequential(
             nn.Linear(in_channel, 512),
             nn.Linear(512, 256),
             nn.Linear(256, 128), 
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 1 * 1, 32),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.Linear(32, 1),
        )            
             
            
    def forward(self, x):                                     
        x = self.mlp(x)
        y = self.fc(x)
        return y
    
###
mlp = _MLP(
    27,
    256,
)

print (type(mlp))

# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)
flops, params = get_model_complexity_info(mlp, (1,27), as_strings=True, print_per_layer_stat=False)
print("%s |flops: %s |params: %s" %('MLP:',flops, params))


<class '__main__._MLP'>
MLP: |flops: 0.0 GMac |params: 182.72 k


## FPN

In [68]:
class bottleneck_block(nn.Module):
    def __init__(self, in_channel, out_channel, curr_stride):
        super(bottleneck_block, self).__init__()        
        self.shortcut = BasicConv2d(in_channel, out_channel, 1, bn=True, stride=curr_stride)
        # Bottleneck
        self.bk = nn.Sequential(
            BasicConv2d(in_channel, 64, 1, bn=True),
            # change scale
            BasicConv2d(64, 64, 3, bn=True, stride=curr_stride, padding=1),
            BasicConv2d(64, out_channel, 1, bn=True),
            )
        
    def forward(self, x):
        res = self.shortcut(x)
        x = self.bk(x)
        b_x = x + res 
        return b_x

class _FPN(nn.Module):
    def __init__(self, in_channel, hidden_channel):
        super(_FPN, self).__init__()
        # 57 - 64
        self.smooth_s = BasicConv2d(in_channel, hidden_channel, 3, bn = True, padding=1)
        # 64 - 64 stride[1,1]
        self.c1 = nn.Sequential(
            bottleneck_block(hidden_channel, hidden_channel*2, 1),
            bottleneck_block(hidden_channel*2, hidden_channel, 1),
        )
        ## 瓶颈网络: 一大一小
        # 64 - 128 stride[2,1]
        self.c2 = nn.Sequential(
            bottleneck_block(hidden_channel, hidden_channel*4, 2),
            bottleneck_block(hidden_channel*4, hidden_channel*2, 1),
        )
        # 128 - 256 stride[2,1]
        self.c3 = nn.Sequential(
            bottleneck_block(hidden_channel*2, hidden_channel*8, 2),
            bottleneck_block(hidden_channel*8, hidden_channel*4, 1),
        )        
        # 256 - 512 stride[2,1]
        self.c4 = nn.Sequential(
            bottleneck_block(hidden_channel*4, hidden_channel*16, 2),
            bottleneck_block(hidden_channel*16, hidden_channel*8, 1),
        ) 
        #
        self.top = BasicConv2d(hidden_channel*8, hidden_channel*4, 3, bn = True, padding=1)
        self.latlayer1 = BasicConv2d(hidden_channel*4, hidden_channel*4, 1, bn = True)
        self.latlayer2 = BasicConv2d(hidden_channel*2, hidden_channel*4, 1, bn = True)
        self.latlayer3 = BasicConv2d(hidden_channel, hidden_channel*4, 1, bn = True)
        self.smooth_e = BasicConv2d(hidden_channel*4, hidden_channel, 3, bn = True, padding=1)
        #
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # fc
        self.fc = nn.Sequential(
            Flatten(),
            nn.Linear(hidden_channel * 1 * 1, 32),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.Linear(32, 1),
        )
        
    def res(self, p, c):
        _,_,H,W = c.size()
        res_cp = c + F.upsample(p, size=(H,W), mode='bilinear')
        return res_cp
    
    def forward(self, x):
        # Bottom-up
        c1 = self.smooth_s(x) 
        c2 = self.c1(c1)
        c3 = self.c2(c2)
        c4 = self.c3(c3)
        c5 = self.c4(c4)
        # Top-down
        p5 = self.top(c5)
        # fuse and hidden-smooth 
        p4 = self.res(p5, self.latlayer1(c4))
        p3 = self.res(p4, self.latlayer2(c3))
        p2 = self.res(p3, self.latlayer3(c2))
        # Smooth
        p2 = self.smooth_e(p2)
        # pooling HW for p2
        p2 = self.pool(p2)
        # last layer in p
        pred = self.fc(p2)                   
        return pred
    

###
fpn = _FPN(
    27,
    64,
)

print (type(fpn))

# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)
flops, params = get_model_complexity_info(fpn, (27,16,16), as_strings=True, print_per_layer_stat=False)
    
print("%s |flops: %s |params: %s" % ('FPN:', flops, params))    

<class '__main__._FPN'>
FPN: |flops: 0.11 GMac |params: 3.16 M


  "See the documentation of nn.Upsample for details.".format(mode))


## ConvLSTM

In [73]:
class ConvLSTMCell(nn.Module):
    def __init__(self, tsLength, inputChannelNums, hiddenChannelNums, ecsize, device, layer_flag):
        super(ConvLSTMCell, self).__init__()
        if layer_flag==True:
            self.ts_conv = nn.Conv2d(in_channels=(inputChannelNums + hiddenChannelNums),
                                     out_channels=hiddenChannelNums * 4,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
            
        else:
            self.ts_conv = nn.Conv2d(in_channels=hiddenChannelNums * 2,
                                     out_channels=hiddenChannelNums * 4,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)            
        
           
        self.ts_bn = nn.BatchNorm2d(hiddenChannelNums * 4)
        # smooth
        self._conv = nn.Sequential(
             nn.Conv2d(in_channels=hiddenChannelNums, out_channels=hiddenChannelNums, kernel_size=3, stride=1, padding = 1)
#             BasicConv2d(128, 128, 3,  padding = 1),
#             BasicConv2d(128, 32, 3, padding =1),
#             BasicConv2d(32, 32, 1),
        )
        
        self.ac = nn.ReLU(True)        
        self.pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        
        self.fc = nn.Sequential(
            Flatten(),
            nn.Linear(hiddenChannelNums, 1),
        )
        #
        self._state_height, self._state_width = ecsize, ecsize
        # LSTM W
        self.W_ci = nn.Parameter(torch.zeros(1, hiddenChannelNums, self._state_height, self._state_width))
        self.W_cf = nn.Parameter(torch.zeros(1, hiddenChannelNums, self._state_height, self._state_width))
        self.W_co = nn.Parameter(torch.zeros(1, hiddenChannelNums, self._state_height, self._state_width))
        # LSTM B
        self.b_i = nn.Parameter(torch.zeros(1, hiddenChannelNums, self._state_height, self._state_width))
        self.b_f = nn.Parameter(torch.zeros(1, hiddenChannelNums, self._state_height, self._state_width))
        self.b_c = nn.Parameter(torch.zeros(1, hiddenChannelNums, self._state_height, self._state_width))
        self.b_o = nn.Parameter(torch.zeros(1, hiddenChannelNums, self._state_height, self._state_width))
        #
        self._seq = tsLength 
        self._input_channel =  inputChannelNums
        self._hidden_channel_nums = hiddenChannelNums
        #
#         self.device = device
        self.layer_flag = layer_flag
               
    # inputs: N * C * D * h * w  
    # num_filter is channel number of h 
    def forward(self, inputs = None, states = None):
        # D * N * C * h * w origin inputs        
        if self.layer_flag==True:
            inputs = inputs.transpose(0,2).transpose(1,2)
        # 
        else:   
            inputs = torch.stack(inputs)            

        if states is None:
            # LSTM Cell and hidden
            c = torch.zeros((inputs.size(1), self._hidden_channel_nums, self._state_height,
                                  self._state_width), dtype = torch.float)
            h = torch.zeros((inputs.size(1), self._hidden_channel_nums, self._state_height,
                             self._state_width), dtype = torch.float)
            
        else:
            h, c = states

        outputs = []
        #
        for index in range(self._seq):
            if inputs is None:
                x = torch.zeros((h.size(0), self._input_channel, self._state_height,
                                      self._state_width), dtype=torch.float)
            else:
                x = inputs[index]
                
            st_x = torch.cat([x, h], dim = 1)
            #
            conved_st_x = self.ts_conv(st_x)
            conved_st_x_bn = self.ts_bn(conved_st_x)
            conv_i, conv_f, conv_c, conv_o = torch.chunk(conved_st_x_bn, 4, dim = 1)
            # save spatiotemporal features
            i = torch.sigmoid(conv_i + self.W_ci * c + self.b_i)
            f = torch.sigmoid(conv_f + self.W_cf * c + self.b_f)
            # c_{t-1} → c_t
            c = f * c + i * torch.tanh(conv_c + self.b_c)
            #
            o = torch.sigmoid(conv_o + self.W_co * c + self.b_o)
            h = o * torch.tanh(c)
            
            # output y_hat
            x_deep = self._conv(h)
            x_relu = self.ac(x_deep)
            x_pooling = self.pooling(x_relu)
            y_hat_TimeStamp_last = self.fc(x_pooling)
            #
            outputs.append(h)        
        # N * hiC * h * w (hidden in the last timeStamp) | N * 1                    
        return outputs, y_hat_TimeStamp_last
    
    
class _ConvLSTM(nn.Module):
    def __init__(self, tsLength, inputChannelNums, hiddenChannelNums, ec_size, devices):
        super(_ConvLSTM, self).__init__()
        self._w_h = ec_size
#         self.device = devices                
        # ts representation - stacked 2 layers
        self.encoderConvLSTM_layer1 = ConvLSTMCell(tsLength, inputChannelNums, hiddenChannelNums, ec_size, devices, True)
        #
        self.encoderConvLSTM_layer2 = ConvLSTMCell(tsLength, inputChannelNums, hiddenChannelNums, ec_size, devices, False)        
    def forward(self, x): 
        # ConvLSTM
        outputsL1, _ = self.encoderConvLSTM_layer1(x, states = None)
        _, ts_y_last = self.encoderConvLSTM_layer2(outputsL1, states = None)                             
        return ts_y_last
    

###
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
###    
conv_lstm = _ConvLSTM(
    6,
    27,
    256,
    16,
    device     
    )

# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)
flops, params = get_model_complexity_info(conv_lstm, (27,6,16,16), as_strings=True, print_per_layer_stat=False)
    
print("%s |flops: %s |params: %s" % ('ConvLSTM:',flops, params)) 


ConvLSTM: |flops: 13.08 GMac |params: 9.43 M


## OBA

In [61]:
"""------- single base classifer for ordinal-------"""
class BaseClassifier(nn.Module):
    def __init__(self):
        super(BaseClassifier, self).__init__()
        self.conv = nn.Sequential(
            BasicConv2d(64, 128, 3, padding = 1),
            BasicConv2d(128, 128, 3, padding = 1),
            BasicConv2d(128, 32, 1),
            BasicConv2d(32, 32, 1),
        )

        self.pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.downsampler = nn.Sequential(
            nn.Conv2d(64, 32, 1, padding=0),
            nn.BatchNorm2d(32)
        )

        self.ac = nn.ReLU(True)
        # sigmoid for multi-labels and outputing a probability
        self.fc = nn.Sequential(
            Flatten(),
            nn.Dropout(0.5),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

# 3
class OrdinalRegressionModel(nn.Module):
    def __init__(self, nClass):
        super(OrdinalRegressionModel, self).__init__()
        self.nClass = nClass
        self.boosting = nn.ModuleList()

        for i in range(self.nClass):
            oneClassifier = BaseClassifier()
            self.boosting.append(oneClassifier)

    def forward(self, x):
        # list sigmoid outputs from all classifers
        outputs = [self.boosting[i](x) for i in range(self.nClass)]
        # list → Tensor (torch.Size([1, nClass])
        return torch.cat(outputs, dim = 1)
    
# 2
class rainFallClassification(nn.Module):
    def __init__(self):
        super(rainFallClassification, self).__init__()

        self.conv = nn.Sequential(
            BasicConv2d(57, 64, 3, bn = True, padding=1),
            BasicConv2d(64, 64, 3, bn = True, padding=1),
            BasicConv2d(64, 128, 3, bn = True, padding=1),
            BasicConv2d(128, 128, 3, bn = True, padding=1),
        )

        self.downsample = nn.Sequential(
            BasicConv2d(57, 128, 1, bn = True, relu = False, padding=0)
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Sequential(
            Flatten(),
            nn.Linear(128 * 1 * 1, 32),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        residual = self.downsample(x)
        x = self.conv(x)
        x = x + residual
        x = self.pool(x)
        x = self.fc(x)
        return x


# 1
class AutoencoderBN(nn.Module):
    def __init__(self):
        super(AutoencoderBN, self).__init__()

        self.encoder = nn.Sequential(
            BasicConv2d(57, 32, 1, bn = True),
            BasicConv2d(32, 32, 3, bn = True, padding = 1),
        )
        
        self.encoderAfterNoise = nn.Sequential(
            nn.MaxPool2d(2),
            # -------------------------------------
            BasicConv2d(32, 64, 3, bn = True, padding = 1),
            BasicConv2d(64, 64, 3, bn = True, padding = 1),
        )

        self.decoder = nn.Sequential(
            BasicConv2d(64, 32, 3, bn = True, padding = 1),
            # upsample and output size → C × 29 × 29 
            nn.Upsample(size = (29, 29), mode ='bilinear',align_corners = True),
            BasicConv2d(32, 32, 3, bn = True, padding = 1),
            
            BasicConv2d(32, 57, 1, bn = True),
        )

    def forward(self, x):
        encoder = self.encoder(x)
        encoder = self.encoderAfterNoise(encoder)
        decoder = self.decoder(encoder)
        return encoder, decoder
    
autoModel = AutoencoderBN()
regressionModel = OrdinalRegressionModel(71)
rainFallClassifierModel = rainFallClassification()   
# bs = BaseClassifier()


# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)
flops_1, params_1 = get_model_complexity_info(autoModel, (57,29,29), as_strings=True, print_per_layer_stat=False)
# 71
# flops_2, params_2 = get_model_complexity_info(bs, (64,29,29), as_strings=True, print_per_layer_stat=False)
flops_3, params_3 = get_model_complexity_info(rainFallClassifierModel, (57,29,29), as_strings=True, print_per_layer_stat=False)
     
    
# MLP: |flops: 0.0 GMac |params: 177.09 k

print("%s |flops: %s |params: %s" % ('autoModel:',flops_1, params_1)) 
# print("%s |flops: %s |params: %s" % ('BaseClassifier:',flops_2, params_2)) 
print("%s |flops: %s |params: %s" % ('rainFallClassifierModel:', flops_3, params_3))

# ECB 0.0071+0.03+0.25 = 0.29 GMac | params 96.43 + 303.39 + 177.09 * 71 = 12.67M
# ERA5 0.0071 + 0.02 + 0.07 = 0.097 GMac | params 95.47 + 282.27 + 177.09 * 71 = 12.65M 

autoModel: |flops: 0.03 GMac |params: 96.43 k
rainFallClassifierModel: |flops: 0.25 GMac |params: 303.39 k


## SSAS

In [13]:
from pre_training.deformConv.layers import ConvOffset2D


## util classes
class BasicConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 relu = True, bn = False, **kwargs):
        super(BasicConv3d, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
        
        if bn:
            self.bn = nn.BatchNorm3d(out_channels)
        else:
            self.bn = None
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            # 
            x = F.leaky_relu(x, inplace = True)
        return x

## Pre-training

# C3AE_3DI   
class C3AE_3DI(nn.Module):
    def __init__(self, seq, Range, curscale, ED_out_channel):
        super(C3AE_3DI, self).__init__()
        self.curscale = curscale
        prob_len = len(Range)
        # kernal depth 2 and shared weight
        self.Conv_3d = BasicConv3d(ED_out_channel, 64, (2,3,3), bn=True, padding=(0,1,1))
        # Pooling
#         self.pooling  = nn.AdaptiveAvgPool2d(output_size=(1, 1))     
        # crop
        self.fc_prob = nn.Linear(64*curscale*curscale, prob_len)
        self.fc_pred = nn.Linear(prob_len, 1)
                                  
    def forward(self, x):
        # N1*C1*D*H_u*W_u: D → D-1 ... → 1
        flag = 0
        while (x.size(2)) >1:
            x = self.Conv_3d(x)
#             print ('curscale:', x.size(2))
        x = torch.flatten(x, start_dim=1)        
        # N1 * len(range)                         
        prob = self.fc_prob(x)
        # N1                         
        pred = self.fc_pred(prob)                          
        return prob, pred

class DeformConvNetI(nn.Module):
    def __init__(self, in_channel):
        super(DeformConvNetI, self).__init__()
        
        # shallow feature map
        self.shortcut = BasicConv2d(in_channel, 256, 1, bn=True, padding=0)
        # Bottleneck
        self.bk = nn.Sequential(
            BasicConv2d(256, 64, 1, bn=True),
            BasicConv2d(64, 64, 3, bn=True, padding=1),
            BasicConv2d(64, 256, 1, bn=True),
            )
        # smooth
        self.smo = BasicConv2d(256, 32, 3, bn=True, padding=1)
        
        # Dconv1
        self.offset1 = ConvOffset2D(32)
        self.conv1 = BasicConv2d(32, 64, 3, bn=True, padding=1)

#         # Dconv2
#         self.offset2 = ConvOffset2D(64)
#         self.conv2 = BasicConv2d(64, 128, 3, bn=True, padding=1)

        # Dconv3
        self.offset3 = ConvOffset2D(64)
        self.conv3 = BasicConv2d(64, 128, 3, bn=True, padding=1)

        # Pooling
        self.pooling  = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        
        # MLP
        self.fc = nn.Sequential(
            Flatten(),
            nn.Dropout(0.5),            
            nn.Linear(128,32),
            nn.Linear(32, 1),
            )

    def forward(self, x):
        res = self.shortcut(x)
        x = self.bk(res) + res
        
        x = self.smo(x)
  
        x = self.offset1(x)
        x = self.conv1(x)
        
#         x = self.offset2(x)
#         x = self.conv2(x)
        
        x = self.offset3(x)
        x = self.conv3(x)        
        
        x = self.pooling(x)  
        x = self.fc(x)
        return x
    
# class DeformConvNetII        
class DeformConvNetII(nn.Module):
    def __init__(self, in_channel):
        super(DeformConvNetII, self).__init__()
        
        # shallow feature map
        self.shortcut = BasicConv2d(in_channel, 256, 1, bn=True, padding=0)
        # Bottleneck
        self.bk = nn.Sequential(
            BasicConv2d(256, 64, 1, bn=True),
            BasicConv2d(64, 64, 3, bn=True, padding=1),
            BasicConv2d(64, 256, 1, bn=True),
            )
        # smooth
        self.smo = BasicConv2d(256, 32, 3, bn=True, padding=1)
        
        # Dconv1
        self.offset1 = ConvOffset2D(32)
        self.conv1 = BasicConv2d(32, 64, 3, bn=True, padding=1)

        # Dconv2
#         self.offset2 = ConvOffset2D(64)
#         self.conv2 = BasicConv2d(64, 128, 3, bn=True, padding=1)

        # Dconv3
        self.offset3 = ConvOffset2D(64)
        self.conv3 = BasicConv2d(64, 128, 3, bn=True, padding=1)
        # Pooling
        self.pooling  = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        
        # MLP branch (2 tasks)
        self.flat = Flatten()
        self.fc_dre = nn.Linear(128, 1)
        self.fc_voc = nn.Linear(128, 1)        
        

    def forward(self, x):
        res = self.shortcut(x)
        x = self.bk(res) + res
        
        x = self.smo(x)
  
        x = self.offset1(x)
        x = self.conv1(x)
        
#         x = self.offset2(x)
#         x = self.conv2(x)
        
        x = self.offset3(x)
        x = self.conv3(x)        
        
        x = self.pooling(x)
        x = self.flat(x)
        x_dre = self.fc_dre(x)
        x_voc = self.fc_voc(x)    
        x_total = torch.cat((x_dre, x_dre),-1) 
        return x_total  
    

class C3AE_3DII(nn.Module):
    def __init__(self, seq, wdRange, wsRange, curscale, ED_out_channel):
        super(C3AE_3DII, self).__init__()
        self.curscale = curscale
        prob_len_wd = len(wdRange)
        prob_len_ws = len(wsRange)
        
        self.curscale = curscale
        
        # kernal depth 2 and shared weight
        self.Conv_3d = BasicConv3d(ED_out_channel, 64, (2,3,3), bn=True, padding=(0,1,1))
        
        # Pooling
#         self.pooling  = nn.AdaptiveAvgPool2d(output_size=(1, 1))     
        # crop
        self.fc_wd_prob = nn.Linear(64*curscale*curscale, prob_len_wd)
        self.fc_ws_prob = nn.Linear(64*curscale*curscale, prob_len_ws)                   
        self.fc_wd_pred = nn.Linear(prob_len_wd, 1)
        self.fc_ws_pred = nn.Linear(prob_len_ws, 1)
                                                                    
    def forward(self, x):
        # N1*C1*D*H_u*W_u: D → D-1 ... → 1
        flag = 0
        while (x.size(2)) >1:
            x = self.Conv_3d(x)
#             print ('curscale:', x.size(2))
        x = torch.flatten(x, start_dim=1)        
        # N1 * len(range) 
        # N1*C(HW_u)                         
        prob_wd = self.fc_wd_prob(x)
        prob_ws = self.fc_ws_prob(x)
        # N1                         
        pred_wd = self.fc_wd_pred(prob_wd)
        pred_ws = self.fc_ws_pred(prob_ws)              
        return prob_wd, pred_wd, prob_ws, pred_ws
    
tem_MSM = DeformConvNetI(57)
wind_MSM = DeformConvNetII(57)

# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)
flops_1, params_1 = get_model_complexity_info(tem_MSM, (57,29,29), as_strings=True, print_per_layer_stat=False)
# 71
flops_2, params_2 = get_model_complexity_info(wind_MSM, (57,29,29), as_strings=True, print_per_layer_stat=False)


tem_MTM = C3AE_3DI(6, np.arange(-4,43,2), 18,  64)
wind_MTM = C3AE_3DII(6, np.arange(1,20,1), np.arange(5,400,5), 18, 64)

flops_3, params_3 = get_model_complexity_info(tem_MTM, (64,6,18,18), as_strings=True, print_per_layer_stat=True)
# 71
flops_4, params_4 = get_model_complexity_info(wind_MTM, (64,6,18,18), as_strings=True, print_per_layer_stat=True)

print("%s |flops: %s |params: %s" % ('tem_MSM:',flops_1, params_1)) 
print("%s |flops: %s |params: %s" % ('wind_MSM:',flops_2, params_2)) 
print("%s |flops: %s |params: %s" % ('tem_MTM:',flops_3, params_3)) 
print("%s |flops: %s |params: %s" % ('wind_MTM:',flops_4, params_4)) 


# ECB 0.21*5 = 1.05 GMac | params: 348.16 = 0.34 * 5  = 0.17 M
# ERA5 0.36 * 5 = 1.80 GMac | params 571.57 = 0.56 * 4 +2.11 = 4.35 M


C3AE_3DI(
  0.572 M, 100.000% Params, 0.359 GMac, 100.000% MACs, 
  (Conv_3d): BasicConv3d(
    0.074 M, 12.922% Params, 0.359 GMac, 99.862% MACs, 
    (conv): Conv3d(0.074 M, 12.899% Params, 0.358 GMac, 99.688% MACs, 64, 64, kernel_size=(2, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
    (bn): BatchNorm3d(0.0 M, 0.022% Params, 0.001 GMac, 0.173% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc_prob): Linear(0.498 M, 87.074% Params, 0.0 GMac, 0.138% MACs, in_features=20736, out_features=24, bias=True)
  (fc_pred): Linear(0.0 M, 0.004% Params, 0.0 GMac, 0.000% MACs, in_features=24, out_features=1, bias=True)
)
C3AE_3DII(
  2.106 M, 100.000% Params, 0.361 GMac, 100.000% MACs, 
  (Conv_3d): BasicConv3d(
    0.074 M, 3.507% Params, 0.359 GMac, 99.437% MACs, 
    (conv): Conv3d(0.074 M, 3.501% Params, 0.358 GMac, 99.265% MACs, 64, 64, kernel_size=(2, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
    (bn): BatchNorm3d(0.0 M, 0.006% Par

In [5]:
##
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
from collections import Counter
from pre_training.deformConv.layers import ConvOffset2D


## Training
class BasicConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 relu = True, bn = False, **kwargs):
        super(BasicConv3d, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
        
        if bn:
            self.bn = nn.BatchNorm3d(out_channels)
        else:
            self.bn = None
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            # 
            x = F.leaky_relu(x, inplace = True)
        return x

    
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 relu = True, bn = False, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
        if bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            # 
            x = F.leaky_relu(x, inplace = True)
        return x
                                                   
                                                   
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input):
        return input.view(input.size(0), -1)


class Means(nn.Module):
    def __init__(self):
        super(Means, self).__init__()

    def forward(self, input):
        return torch.mean(input, dim=(1, 2, 3)).unsqueeze(-1)

        
class ZeroOuts(nn.Module):
    def __init__(self):
        super(ZeroOuts, self).__init__()

    def forward(self, x):
        batchSize = x.size()[0]
        return torch.zeros(batchSize, 4, 1, 1)

    
"""------- single base classifer for ordinal-------"""
class BaseClassifier(nn.Module):
    def __init__(self):
        super(BaseClassifier, self).__init__()
        self.conv = nn.Sequential(
            BasicConv2d(64, 128, 3, padding = 1),
            BasicConv2d(128, 128, 3, padding = 1),
            BasicConv2d(128, 32, 1),
            BasicConv2d(32, 32, 1),
        )

        self.pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.downsampler = nn.Sequential(
            nn.Conv2d(64, 32, 1, padding=0),
            nn.BatchNorm2d(32)
        )

        self.ac = nn.ReLU(True)
        # sigmoid for multi-labels and outputing a probability
        self.fc = nn.Sequential(
            Flatten(),
            nn.Dropout(0.5),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
    

## 1
class uniEncoder(nn.Module):
    def __init__(self, uniScale, in_channel):
        super(uniEncoder, self).__init__()
        self.uniScale = uniScale
        # 考虑one sample BN是否得False
        self.shortcut = BasicConv2d(in_channel, 256, 1, bn=True, padding=0)
        # Bottleneck
        self.bk = nn.Sequential(
            BasicConv2d(256, 64, 1, bn=True),
            BasicConv2d(64, 64, 3, bn=True, padding=1),
            BasicConv2d(64, 256, 1, bn=True),
            )
        # smooth
        self.smo = BasicConv2d(256, 256, 3, bn=True, padding=1) 
        # uniscale-downsampling - find key factor in big region
        self.uniDown = nn.AdaptiveMaxPool2d(uniScale)   
        
    def forward(self, x, Bs, flag, isTrain):
        if flag=='D-1':
            cpu_uni_stats_l = []
            for stat_cuda in x:
                stat_cuda = self.shortcut(stat_cuda)
                feature_stat = self.bk(stat_cuda) + stat_cuda
                feature_stat = self.smo(feature_stat)
                if feature_stat.size(2)>=self.uniScale:
                    uni_stat = self.uniDown(feature_stat) 
                if feature_stat.size(2)<self.uniScale:
                    uni_stat = F.interpolate(feature_stat, size=self.uniScale, mode='bilinear')
                if isTrain:
                    uni_stat = uni_stat.detach()
                    
                cpu_uni_stats_l.append(uni_stat.cpu().numpy())
            # N(D-1)*1*C*H*W → N(D-1)*C*H*W → N*(D-1)*C*H*W → N*C*(D-1)*H*W  
            cpu_uni_stats = np.stack(cpu_uni_stats_l).squeeze(1)
            channel_size = cpu_uni_stats.shape[1]
            #
            cpu_uni_stats = np.reshape(cpu_uni_stats, (Bs, -1, channel_size, self.uniScale, self.uniScale))
            cpu_uni_stats = np.transpose(cpu_uni_stats, (0,2,1,3,4))
           
            return cpu_uni_stats
        ##     
        if flag=='D-th':
            # N*C*H*W
            x = self.shortcut(x)
            feature_stat = self.bk(x) + x
            feature_stat = self.smo(feature_stat)
            uni_stat = self.uniDown(feature_stat)
            return uni_stat
        
        
## 2     
class ED_DCNN(nn.Module):
    def __init__(self):
        super(ED_DCNN, self).__init__()
        # Dconv1 has noisy
        self.offset1 = ConvOffset2D(256)
        self.conv1 = BasicConv2d(256, 64, 3, bn=True, padding=1)
        # Dconv2
        self.offset2 = ConvOffset2D(64)
        self.conv2 = BasicConv2d(64, 64, 3, bn=True, padding=1)
        
    def forward(self, x, Bs, isTrain):
        cpu_deform_stats_l = []
        # N * C * D * H * W → N * D * C * H * W → ND * C * H * W
        feature_cha = x.size(1)
        hw = x.size(3)
        ND_x = x.transpose(1,2).contiguous().view(-1, feature_cha, hw, hw)
        for statUni_cuda in ND_x:
            # batchsize = 1
            statUni_cuda = statUni_cuda.unsqueeze(0)
            statdf_cuda = self.offset1(statUni_cuda)
            statdf_cuda = self.conv1(statdf_cuda)
            statdf_cuda = self.offset2(statdf_cuda)
            statdf_cuda = self.conv2(statdf_cuda)
            if isTrain:
                statdf_cuda = statdf_cuda.detach()            
            #
            cpu_deform_stats_l.append(statdf_cuda.cpu().numpy())
        # ND*1*C*H*W → ND*C*H*W → N*D*C*H*W → N*C*D*H*W  
        cpu_deform_stats = np.stack(cpu_deform_stats_l).squeeze(1)
        channel_size = cpu_deform_stats.shape[1]
        cpu_deform_stats = np.reshape(cpu_deform_stats, (Bs, -1, channel_size, hw, hw))
        cpu_deform_stats = np.transpose(cpu_deform_stats, (0,2,1,3,4))        
        return cpu_deform_stats 


## 3    
class CNN_3D(nn.Module):
    def __init__(self, curscale, out_channel_3d):
        super(CNN_3D, self).__init__()
        self.curscale = curscale
        self.Conv_3d = BasicConv3d(64, out_channel_3d, (2,3,3), bn=True, padding=(0,1,1))
                                  
    def forward(self, x, isTrain):
        cpu_out_stats_l = []
        for stat_ts in x:
#             stat_ts = stat_ts.unsqueeze(0)
            # 1*C1*D*H_u*W_u: D → D-1 ... → 1            
            while (stat_ts.size(2)) >1:
                stat_ts = self.Conv_3d(stat_ts)                
            if isTrain:
                stat_ts = stat_ts.detach()
            
            cpu_out_stats_l.append(stat_ts.cpu().numpy())
        # N*1*C*H*W → N*C*H*W
        cpu_out_stats = np.stack(cpu_out_stats_l).squeeze(1).squeeze(2)
        out_stats =  torch.from_numpy(cpu_out_stats.astype(np.float32)).cuda()
        return out_stats 

    
## 4
class OrdinalRegressionModel(nn.Module):
    def __init__(self, nClass):
        super(OrdinalRegressionModel, self).__init__()
        self.nClass = nClass
        self.boosting = nn.ModuleList()

        for i in range(self.nClass):
            oneClassifier = BaseClassifier()
            self.boosting.append(oneClassifier)

    def forward(self, x): 
#         print ('self.nClass:', self.nClass)
        # list sigmoid outputs from all classifers
        outputs = [self.boosting[i](x) for i in range(self.nClass)]
        # list → Tensor (torch.Size([1, nClass])
        return torch.cat(outputs, dim = 1)
        
## 5
class rainFallClassification(nn.Module):
    def __init__(self):
        super(rainFallClassification, self).__init__()
        self.conv = nn.Sequential(
            BasicConv2d(57, 64, 3, bn = True, padding=1),
            BasicConv2d(64, 64, 3, bn = True, padding=1),
            BasicConv2d(64, 128, 3, bn = True, padding=1),
            BasicConv2d(128, 128, 3, bn = True, padding=1),
        )

        self.downsample = nn.Sequential(
            BasicConv2d(57, 128, 1, bn = True, relu = False, padding=0)
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Sequential(
            Flatten(),
            nn.Linear(128 * 1 * 1, 32),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        residual = self.downsample(x)
        x = self.conv(x)
        x = x + residual
        x = self.pool(x)
        x = self.fc(x)
        return x

# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)   
uniE = uniEncoder(18, 57)
EDCNN = ED_DCNN()
CNN3D = CNN_3D(18, 64)
regressionModel = OrdinalRegressionModel(71)    
rainFallClassifierModel = rainFallClassification()

flops_1, params_1 = get_model_complexity_info(uniEncoder, (57,29,29), as_strings=True, print_per_layer_stat=False)
# flops_2, params_2 = get_model_complexity_info(ED_DCNN, (256,6,29,29), as_strings=True, print_per_layer_stat=False)
# flops_3, params_3 = get_model_complexity_info(CNN_3D, (64,6,18,18), as_strings=True, print_per_layer_stat=False)
# flops_4, params_4 = get_model_complexity_info(OrdinalRegressionModel, (64,6,18,18), as_strings=True, print_per_layer_stat=False)
# flops_5, params_5 = get_model_complexity_info(rainFallClassifierModel, (64,6,18,18), as_strings=True, print_per_layer_stat=False)


print("%s |flops: %s |params: %s" % ('uniE:',flops_1, params_1)) 
# print("%s |flops: %s |params: %s" % ('EDCNN:',flops_2, params_2)) 
# print("%s |flops: %s |params: %s" % ('CNN3D:',flops_3, params_3)) 
# print("%s |flops: %s |params: %s" % ('regressionModel:',flops_4, params_4)) 
# print("%s |flops: %s |params: %s" % ('rainFallClassifierModel:',flops_5, params_5)) 



# ECb1 0.21*5 = 1.05 GMac | params: 348.16 = 0.34 * 5  = 0.17 M
# ERA5 0.36 * 5 = 1.80 GMac | params 571.57 = 0.56 * 4 +2.11 = 4.35 M

# ECb2 9.43 + 9.43/2  + 
# ERA5 2

# ConvLSTM: |flops: 13.08 GMac |params: 9.43 M

# 0.17 + 9.43 + 9.43/2 = 14.32

# 0.17 + 9.43 + 9.43/2 = xx




AssertionError: 

## LSTM

In [12]:
# 记得方法加nn.DataParallel
class _LSTM(nn.Module):
    def __init__(self,in_channel, hidden_channel):
        super(_LSTM, self).__init__()
        self.in_channel = in_channel        
        self.hidden_channel = hidden_channel
        self.hidden_channel_after = hidden_channel-128
        #
        self.lstm1 = nn.LSTMCell(in_channel, hidden_channel)        
        self.lstm2 = nn.LSTMCell(hidden_channel, self.hidden_channel_after)        
        #
#         self.dropout = nn.Dropout(p=0.3)
        self.linear1 = nn.Linear(self.hidden_channel_after, 64)
        self.linear2 = nn.Linear(64, 1)
        
    def forward(self, x):
        # D*N*f
        seq_length = x.size(0)
        #
        h_t_1 = torch.zeros(x.size(1), self.hidden_channel)
        c_t_1 = torch.zeros(x.size(1), self.hidden_channel)
        h_t_2 = torch.zeros(x.size(1), self.hidden_channel_after)
        c_t_2 = torch.zeros(x.size(1), self.hidden_channel_after)          
        for i in range(seq_length):
            h_t_1, c_t_1 = self.lstm1(x[i], (h_t_1, c_t_1))
            h_t_2, c_t_2 = self.lstm2(h_t_1, (h_t_2, c_t_2))
            
        h21 = self.linear1(h_t_2)
        #
        y_t = self.linear2(h21).squeeze()

        return y_t
    

###

###    
lstm = _LSTM(
        27,
        256)

# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)
flops, params = get_model_complexity_info(lstm, (6,27), as_strings=True, print_per_layer_stat=False)
    
print("%s |flops: %s |params: %s" % ('LSTM:',flops, params))



'''
0.50G  1.82M

0.21G  1.71M  

'''


LSTM: |flops: 0.0 GMac |params: 497.79 k


'\n0.50G  1.82M\n\n0.21G    \n\n'

In [None]:
class GCN(nn.Module):
'''
Z = AXW
'''
    def __init__(self , A, dim_in , dim_out):
        super(GCN,self).__init__()
        self.A = A
        self.fc1 = nn.Linear(dim_in ,dim_in,bias=False)
        self.fc2 = nn.Linear(dim_in,dim_in//2,bias=False)
        self.fc3 = nn.Linear(dim_in//2,dim_out,bias=False)

    def forward(self,X):
        '''
        计算三层gcn
        '''
        X = F.relu(self.fc1(self.A.mm(X)))
        X = F.relu(self.fc2(self.A.mm(X)))
        return self.fc3(self.A.mm(X))
    
gcn = GCN(
        57,
        256)   

# HR-ECb 57 × 6 × 29 × 29 
# (57,29,29)
# ERA5 27 × 6 × 16 × 16 
# (27,16,16)
flops, params = get_model_complexity_info(gcn, (6,27), as_strings=True, print_per_layer_stat=False)
    
print("%s |flops: %s |params: %s" % ('GCN:',flops, params))    

In [17]:
pError_ele = [1,0,3,0,5]




