In [5]:
import torch
import torchviz
from utils.Models import MLP

# Build a small MLP and push one random batch through it so that torchviz
# can draw the graph of the resulting output tensor.
model = MLP(input_dim=4, output_dim=2, hidden_layers=2, hidden_dim=64)
dummy_input = torch.randn(10, 4)  # 10 rows, 4 features (matches input_dim)

graph = torchviz.make_dot(
    model(dummy_input),
    params=dict(model.named_parameters()),  # name -> parameter mapping used for node labels
)
# Save the diagram in PNG format; the returned file name is shown as the cell output.
graph.render("MLP_Model", format="png")

'MLP_Model.png'

In [11]:
from torchsummary import summary

# torchsummary's `device` argument defaults to "cuda"; with that default the
# probe input is created as a CUDA tensor whenever a GPU is available.
# `model` is called with a plain CPU tensor (torch.randn) in the first cell,
# so ask for a CPU probe explicitly to avoid a device mismatch on machines
# that have a GPU.
# NOTE(review): assumes `model` has not been moved to another device in a
# cell outside this view -- confirm.
summary(model, (4,), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                   [-1, 64]             320
              ReLU-2                   [-1, 64]               0
            Linear-3                   [-1, 64]           4,160
       BatchNorm1d-4                   [-1, 64]             128
              ReLU-5                   [-1, 64]               0
            Linear-6                   [-1, 64]           4,160
       BatchNorm1d-7                   [-1, 64]             128
              ReLU-8                   [-1, 64]               0
            Linear-9                    [-1, 2]             130
Total params: 9,026
Trainable params: 9,026
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.03
Estimated Total Size (MB): 0.04
-----------------------------------------------

In [6]:
from torchinfo import summary

# Fresh MLP instance with the same hyper-parameters as the first cell.
model = MLP(
    input_dim=4,
    output_dim=2,
    hidden_layers=2,
    hidden_dim=64,
)
# Final expression of the cell: the torchinfo report for a (10, 4) input
# is rendered as the cell output.
summary(model, input_size=(10, 4))


Layer (type:depth-idx)                   Output Shape              Param #
MLP                                      [10, 2]                   --
├─Linear: 1-1                            [10, 64]                  320
├─ReLU: 1-2                              [10, 64]                  --
├─ModuleList: 1-5                        --                        (recursive)
│    └─Sequential: 2-1                   [10, 64]                  --
│    │    └─Linear: 3-1                  [10, 64]                  4,160
│    │    └─BatchNorm1d: 3-2             [10, 64]                  128
├─ReLU: 1-4                              [10, 64]                  --
├─ModuleList: 1-5                        --                        (recursive)
│    └─Sequential: 2-2                   [10, 64]                  --
│    │    └─Linear: 3-3                  [10, 64]                  4,160
│    │    └─BatchNorm1d: 3-4             [10, 64]                  128
├─ReLU: 1-6                              [10, 64]         

In [16]:
from utils.unet import UNetModel
from torchinfo import summary

# ---- UNet configuration ----------------------------------------------------
image_size = 64
image_channels = 1
num_channels = 64
num_res_blocks = 4
num_heads = 4
num_heads_upsample = -1
attention_resolutions = "8"  # comma-separated list, parsed below
dropout = 0.1
use_checkpoint = False
use_scale_shift_norm = True

channel_mult = (1, 2, 4)

# Turn every listed resolution into the ratio image_size // resolution; the
# resulting values are what the model receives as `attention_resolutions`.
attention_ds = [image_size // int(res) for res in attention_resolutions.split(",")]

kwargs = {
    "in_channels": image_channels,
    "model_channels": num_channels,
    "out_channels": image_channels,
    "num_res_blocks": num_res_blocks,
    "attention_resolutions": tuple(attention_ds),
    "dropout": dropout,
    "channel_mult": channel_mult,
    "num_classes": None,
    "use_checkpoint": use_checkpoint,
    "num_heads": num_heads,
    "num_heads_upsample": num_heads_upsample,
    "use_scale_shift_norm": use_scale_shift_norm,
}

unet = UNetModel(**kwargs)

# Random probe inputs for the summary pass. The image batch takes its channel
# count and spatial size from the configuration above, so it stays consistent
# if those values are changed; the two per-sample vectors reuse its batch size.
# NOTE(review): `torch` is not imported in this cell -- it relies on the
# import in the first cell having been executed.
# NOTE(review): the meaning of `t` and `d` is defined by UNetModel.forward in
# utils.unet (not visible here) -- confirm before renaming.
x = torch.rand(3, image_channels, image_size, image_size)
t = torch.rand(x.shape[0])
d = torch.rand(x.shape[0])
summary(unet, input_data=[x, t, d], batch_dim=0)

Layer (type:depth-idx)                        Output Shape              Param #
UNetModel                                     [3, 1, 64, 64]            --
├─Sequential: 1-1                             [3, 256]                  --
│    └─Linear: 2-1                            [3, 256]                  16,640
│    └─SiLU: 2-2                              [3, 256]                  --
│    └─Linear: 2-3                            [3, 256]                  65,792
├─Sequential: 1-2                             [3, 256]                  --
│    └─Linear: 2-4                            [3, 256]                  16,640
│    └─SiLU: 2-5                              [3, 256]                  --
│    └─Linear: 2-6                            [3, 256]                  65,792
├─ModuleList: 1-3                             --                        --
│    └─TimestepEmbedSequential: 2-7           [3, 64, 64, 64]           --
│    │    └─Conv2d: 3-1                       [3, 64, 64, 64]           640
│  