In [None]:
import torch
import models.model_student as model_student

def count_parameters(model):
    """Return ``(total, trainable)`` parameter element counts for ``model``.

    ``total`` adds up ``numel()`` over everything ``model.parameters()`` yields;
    ``trainable`` adds up only the parameters whose ``requires_grad`` is set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    return total_params, trainable_params

def print_model_parameters(scale_factor=1):
    """Build the student model and print a parameter-count report.

    Args:
        scale_factor: Forwarded to ``enhance_net_nopool_student``. Defaults to
            1, the value this function previously hard-coded.

    Prints the total and trainable parameter counts, an estimated size in MB
    (4 bytes per parameter, i.e. float32), and then a per-submodule breakdown.
    Each breakdown line counts the parameters of that submodule including its
    children, so container lines are subtotals of the lines nested under them.
    """
    student_model = model_student.enhance_net_nopool_student(scale_factor)

    # Count parameters
    total_params, trainable_params = count_parameters(student_model)

    print("Student Model Parameters:")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Model size: {total_params * 4 / (1024**2):.2f} MB (assuming float32)")

    # Detailed layer-wise parameter count
    print("\nLayer-wise parameter count:")
    print("-" * 50)
    for name, module in student_model.named_modules():
        # named_modules() yields the model itself under the empty name "".
        # Its count is the total already printed above, so skip it rather than
        # emit a nameless ": N parameters" line.
        if not name:
            continue
        params = sum(p.numel() for p in module.parameters())
        if params > 0:
            print(f"{name}: {params:,} parameters")

# A notebook kernel executes cells with __name__ set to "__main__", so this
# guard passes and the report is printed every time the cell is run.
if __name__ == "__main__":
    print_model_parameters()

Student Model Parameters:
Total parameters: 161
Trainable parameters: 161
Model size: 0.00 MB (assuming float32)

Layer-wise parameter count:
--------------------------------------------------
: 161 parameters
e_conv1: 46 parameters
e_conv1.depth_conv: 30 parameters
e_conv1.point_conv: 16 parameters
e_conv2: 60 parameters
e_conv2.depth_conv: 40 parameters
e_conv2.point_conv: 20 parameters
e_conv3: 55 parameters
e_conv3.depth_conv: 40 parameters
e_conv3.point_conv: 15 parameters


In [None]:
import torch
import models.model_student as model_student  # your model definition file

# Checkpoint holding the student weights.
student_weights_path = "snapshots_Student_KD/Student_Final.pth"

# Instantiate the student architecture (adjust scale_factor to match training).
net = model_student.enhance_net_nopool_student(scale_factor=1)

# Read the checkpoint onto the CPU, then copy it into the model. The
# load_state_dict call stays last so its result is the cell's displayed value.
student_state = torch.load(student_weights_path, map_location="cpu")
net.load_state_dict(student_state)


In [4]:
# Per-layer report (output shape and parameter count) for the model currently
# bound to `net`, given an input of shape (1, 3, 512, 512). The call is the
# cell's last expression, so the returned summary is what the notebook displays.
from torchinfo import summary
summary(net, input_size=(1, 3, 512, 512))


Layer (type:depth-idx)                   Output Shape              Param #
enhance_net_nopool_student               [1, 3, 512, 512]          --
├─CSDN_Tem: 1-1                          [1, 4, 512, 512]          --
│    └─Conv2d: 2-1                       [1, 3, 512, 512]          30
│    └─Conv2d: 2-2                       [1, 4, 512, 512]          16
├─ReLU: 1-2                              [1, 4, 512, 512]          --
├─CSDN_Tem: 1-3                          [1, 4, 512, 512]          --
│    └─Conv2d: 2-3                       [1, 4, 512, 512]          40
│    └─Conv2d: 2-4                       [1, 4, 512, 512]          20
├─ReLU: 1-4                              [1, 4, 512, 512]          --
├─CSDN_Tem: 1-5                          [1, 4, 512, 512]          --
│    └─Conv2d: 2-5                       [1, 4, 512, 512]          40
│    └─Conv2d: 2-6                       [1, 4, 512, 512]          20
├─ReLU: 1-6                              [1, 4, 512, 512]          --
├─CSDN_Tem: 1-7

In [5]:
import torch
import model # your model definition file

# Checkpoint holding the Zero-DCE++ weights.
teacher_weights_path = "snapshots_Zero_DCE++/Epoch99.pth"

# Instantiate the full-size architecture (adjust scale_factor to match training).
# NOTE: this rebinds `net`, replacing the student model loaded in the earlier cell.
net = model.enhance_net_nopool(scale_factor=1)

# Read the checkpoint onto the CPU, then copy it into the model. The
# load_state_dict call stays last so its result is the cell's displayed value.
teacher_state = torch.load(teacher_weights_path, map_location="cpu")
net.load_state_dict(teacher_state)


<All keys matched successfully>

In [6]:
# Add up the element counts of every parameter tensor in the model bound to `net`.
total_params = sum(map(torch.numel, net.parameters()))
print(f"Total parameters: {total_params:,}")


Total parameters: 10,561


In [7]:
# Per-layer report (output shape and parameter count) for the model currently
# bound to `net`, given an input of shape (1, 3, 512, 512). The call is the
# cell's last expression, so the returned summary is what the notebook displays.
from torchinfo import summary
summary(net, input_size=(1, 3, 512, 512))


Layer (type:depth-idx)                   Output Shape              Param #
enhance_net_nopool                       [1, 3, 512, 512]          --
├─CSDN_Tem: 1-1                          [1, 32, 512, 512]         --
│    └─Conv2d: 2-1                       [1, 3, 512, 512]          30
│    └─Conv2d: 2-2                       [1, 32, 512, 512]         128
├─ReLU: 1-2                              [1, 32, 512, 512]         --
├─CSDN_Tem: 1-3                          [1, 32, 512, 512]         --
│    └─Conv2d: 2-3                       [1, 32, 512, 512]         320
│    └─Conv2d: 2-4                       [1, 32, 512, 512]         1,056
├─ReLU: 1-4                              [1, 32, 512, 512]         --
├─CSDN_Tem: 1-5                          [1, 32, 512, 512]         --
│    └─Conv2d: 2-5                       [1, 32, 512, 512]         320
│    └─Conv2d: 2-6                       [1, 32, 512, 512]         1,056
├─ReLU: 1-6                              [1, 32, 512, 512]         --
├─CSDN