In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from fvcore.nn import FlopCountAnalysis,parameter_count_table
class Model(nn.Module):
    """
    Just one Linear layer
    """
    def __init__(self, channel=7,ratio=1):
        super(Model, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool1d(1) #innovation
        self.fc = nn.Sequential(
                nn.Linear(7,14, bias=False),
                nn.Dropout(p=0.1),
                nn.ReLU(inplace=True) ,
                nn.Linear(14, 7, bias=False),
                nn.Sigmoid()
        )
        self.seq_len = 96
        self.pred_len = 96
        self.Linear_More_1 = nn.Linear(self.seq_len,self.pred_len * 2)
        self.Linear_More_2 = nn.Linear(self.pred_len*2,self.pred_len)
        self.relu = nn.ReLU()
        self.gelu = nn.GELU()    

        self.drop = nn.Dropout(p=0.1)
        # Use this line if you want to visualize the weights
        # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
    def forward(self, x):
        # x: [Batch, Input length, Channel]
        # x = self.Linear(x.permute(0,2,1))
        x = x.permute(0,2,1) # (B，L,C)=》(B,C,L)
        b, c, l = x.size() # (B,C,L)
        # y = self.avg_pool(x) # (B,C,L) 通过avg=》 (B,C,1)
        # print("y",y.shape)
        y = self.avg_pool(x).view(b, c) # (B,C,L) 通过avg=》 (B,C,1)
        # print("y",y.shape)
        #为了丢给Linear学习，需要view把数据展平开
        # y = self.fc(y).view(b, c, 96)
        
        y = self.fc(y).view(b,c,1)

        # print("y",y.shape)
        return (x * y).permute(0,2,1)
model = Model()
print(parameter_count_table(model))

| name                   | #elements or shape   |
|:-----------------------|:---------------------|
| model                  | 37.3K                |
|  fc                    |  0.2K                |
|   fc.0                 |   98                 |
|    fc.0.weight         |    (14, 7)           |
|   fc.3                 |   98                 |
|    fc.3.weight         |    (7, 14)           |
|  Linear_More_1         |  18.6K               |
|   Linear_More_1.weight |   (192, 96)          |
|   Linear_More_1.bias   |   (192,)             |
|  Linear_More_2         |  18.5K               |
|   Linear_More_2.weight |   (96, 192)          |
|   Linear_More_2.bias   |   (96,)              |


In [17]:
'''
计算FLOPs
'''
tensor=torch.randn(8,96,7)
FLOPs = FlopCountAnalysis(model,tensor)
print("FLOPs:",FLOPs.total())

Unsupported operator aten::adaptive_avg_pool1d encountered 1 time(s)
Unsupported operator aten::sigmoid encountered 1 time(s)
Unsupported operator aten::mul encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
Linear_More_1, Linear_More_2, drop, gelu, relu


FLOPs: 1568


In [20]:
class Model(nn.Module):
    """
    Normalization-Linear
    """
    def __init__(self):
        super(Model, self).__init__()
        self.seq_len = 96
        self.pred_len = 720
        self.Linear = nn.Linear(self.seq_len, self.pred_len)
            # Use this line if you want to visualize the weights
            # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))

    def forward(self, x):
        # x: [Batch, Input length, Channel]

        x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
        
        return x # [Batch, Output length, Channel]
model = Model()



print(parameter_count_table(model))

| name            | #elements or shape   |
|:----------------|:---------------------|
| model           | 69.8K                |
|  Linear         |  69.8K               |
|   Linear.weight |   (720, 96)          |
|   Linear.bias   |   (720,)             |


In [42]:
'''
计算FLOPs
'''
# tensor=torch.randn(,96,7)
result =model(tensor)
# FLOPs = FlopCountAnalysis(model,tensor)
# print("FLOPs:",FLOPs.total())

'\n计算FLOPs\n'