In [2]:

from torchsummary import summary
from torch import nn
import torch
from models.SFMCNN import SFMCNN
from models.RGB_SFMCNN import RGB_SFMCNN
from models.RGB_SFMCNN_V2 import RGB_SFMCNN_V2
from dataloader import get_dataloader
from config import *
import matplotlib



matplotlib.use('Agg')

with torch.no_grad():
    # Load the dataset
    train_dataloader, test_dataloader = get_dataloader(dataset=config['dataset'], root=config['root'] + '/data/',
                                                       batch_size=config['batch_size'],
                                                       input_size=config['input_shape'])
    # Gather the whole test split into `images` / `labels`.
    # Batches are collected in lists and concatenated once: calling torch.cat on a
    # growing tensor inside the loop copies everything seen so far on every batch.
    # The leading 1-D empty float tensor reproduces the original accumulator, so the
    # result dtype follows the same promotion and an empty loader still yields
    # empty tensors.
    image_batches, label_batches = [torch.tensor([])], [torch.tensor([])]
    for imgs, lbls in test_dataloader:
        image_batches.append(imgs)
        label_batches.append(lbls)
    images = torch.cat(image_batches)
    labels = torch.cat(label_batches)
    print(images.shape, labels.shape)

    # Load the model
    models = {'SFMCNN': SFMCNN, 'RGB_SFMCNN': RGB_SFMCNN, 'RGB_SFMCNN_V2': RGB_SFMCNN_V2}
    # NOTE(review): the checkpoint name is hard-coded independently of arch['name'];
    # confirm they always refer to the same architecture.
    checkpoint_filename = 'RGB_SFMCNN_V2_best'
    checkpoint = torch.load(f'../pth/{config["dataset"]}_pth/{checkpoint_filename}.pth', weights_only=True)
    model = models[arch['name']](**dict(config['model']['args']))
    model.load_state_dict(checkpoint['model_weights'])
    model.cpu()
    model.eval()
    summary(model, input_size=(config['model']['args']['in_channels'], *config['input_shape']), device='cpu')
    print(model)

    # Test the model on (at most) the first batch_num test images
    batch_num = 1000
    pred = model(images[:batch_num])
    y = labels[:batch_num]
    # Labels are 2-D (e.g. torch.Size([900, 9]) in the output), presumably one-hot,
    # so the predicted and true classes are both taken with argmax over dim 1.
    correct = (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()
    print("Test Accuracy: " + str(correct / len(pred)))

Code/runs/train/exp
torch.Size([900, 3, 28, 28]) torch.Size([900, 9])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
        RGB_Conv2d-1             [-1, 30, 6, 6]               0
        RGB_Conv2d-2             [-1, 30, 6, 6]               0
          triangle-3             [-1, 30, 6, 6]               0
          triangle-4             [-1, 30, 6, 6]               0
     cReLU_percent-5             [-1, 30, 6, 6]               0
     cReLU_percent-6             [-1, 30, 6, 6]               0
               SFM-7             [-1, 30, 3, 3]               0
        RBF_Conv2d-8            [-1, 225, 3, 3]           6,750
          triangle-9            [-1, 225, 3, 3]               0
    cReLU_percent-10            [-1, 225, 3, 3]               0
              SFM-11            [-1, 225, 3, 1]               0
       RBF_Conv2d-12            [-1, 625, 3, 1]         140,625
         triangle-13            [

In [32]:
def calculate_RM(layers, layer_num, images):
    """
    Compute the response map (RM) of layer ``layer_num`` for all images.

    Parameters:
    - layers: mapping from layer name to a callable (e.g. an nn.Module)
    - layer_num: key into ``layers`` selecting the sub-network to run
    - images: input batch passed directly to ``layers[layer_num]``

    Returns:
    - a 2-D tensor of shape (filter_count, N * H * W): the layer output with the
      channel axis moved first and everything else flattened, one row per filter.

    The forward pass runs under ``torch.no_grad()``. The result is only used for
    statistics and plotting, so tracking gradients over the whole test set would
    just waste memory.
    """
    with torch.no_grad():
        RM = layers[layer_num](images)

    # Assumes a 4-D (N, C, H, W) output; the permute below requires exactly 4 dims.
    filter_count = RM.shape[1]
    reshape_RM = RM.permute(1, 0, 2, 3).reshape(filter_count, -1)

    return reshape_RM


def get_stats(reshape_RM):
    """
    Compute statistics for one layer: overall stats on the raw values and
    per-channel stats on min-max normalized values.

    Parameters:
    - reshape_RM: reshaped feature map of shape (channels, pixels), e.g. the
      output of ``calculate_RM``

    Returns:
    - channel_stats: dict of per-channel lists (one entry per row):
      mean, std, max, min, kurtosis (scipy default, Fisher/excess), skewness, and
      the fraction of values above 0.99, above 0.9 and below 0.1. All of these are
      computed after normalizing with the layer-wide min and max (not per channel).
      A constant layer (max == min) is normalized to all zeros instead of NaN.
    - overall_stats: dict with mean, std, max and min of the raw layer values.
    """
    # Local import so this function does not depend on the module-level
    # `import scipy.stats` that appears further down the cell.
    import scipy.stats

    overall_stats = {
        "mean": float(reshape_RM.mean()),
        "std": float(reshape_RM.std()),
        "max": float(reshape_RM.max()),
        "min": float(reshape_RM.min()),
    }

    # Normalize to [0, 1] using the layer-wide range. A zero range would divide
    # by zero and make every statistic NaN, so map a constant layer to 0.
    min_value = reshape_RM.min()
    value_range = reshape_RM.max() - min_value
    if value_range == 0:
        normalized_reshape_RM = torch.zeros_like(reshape_RM)
    else:
        normalized_reshape_RM = (reshape_RM - min_value) / value_range

    # Per-channel summary statistics
    channel_mean_values = normalized_reshape_RM.mean(dim=1)
    channel_std_values = normalized_reshape_RM.std(dim=1)
    channel_max_values = normalized_reshape_RM.max(dim=1).values
    channel_min_values = normalized_reshape_RM.min(dim=1).values

    # Per-channel kurtosis and skewness (converted to numpy once for scipy)
    normalized_np = normalized_reshape_RM.detach().numpy()
    channel_kurtosis_values = scipy.stats.kurtosis(normalized_np, axis=1)
    channel_skewness_values = scipy.stats.skew(normalized_np, axis=1)

    # Fraction of each channel's values beyond fixed thresholds of the normalized range
    pixel_count = normalized_reshape_RM.shape[1]
    ratio_above_0_99 = ((normalized_reshape_RM > 0.99).sum(dim=1) / pixel_count).tolist()
    ratio_above_0_9 = ((normalized_reshape_RM > 0.9).sum(dim=1) / pixel_count).tolist()
    ratio_below_0_1 = ((normalized_reshape_RM < 0.1).sum(dim=1) / pixel_count).tolist()

    channel_stats = {
        "mean": channel_mean_values.tolist(),
        "std": channel_std_values.tolist(),
        "max": channel_max_values.tolist(),
        "min": channel_min_values.tolist(),
        "kurtosis": channel_kurtosis_values.tolist(),
        "skewness": channel_skewness_values.tolist(),
        "ratio_above_0.99": ratio_above_0_99,
        "ratio_above_0.9": ratio_above_0_9,
        "ratio_below_0.1": ratio_below_0_1,
    }

    return channel_stats, overall_stats

import numpy as np
import scipy.stats
import torch

def create_metrics_dict():
    """
    Build the table of named layer metrics.

    Every metric is a callable taking a stats dict (per-channel lists, as in the
    channel stats from ``get_stats``) and returning the fraction of channels
    meeting the condition in its name. An empty list raises ZeroDivisionError.

    Returns:
    - dict mapping metric name -> callable(stats) -> float
    """
    def share_where(key, condition):
        # Metric factory: fraction of entries in stats[key] satisfying `condition`.
        def metric(stats):
            values = stats[key]
            hits = sum(1 for value in values if condition(value))
            return hits / len(values)
        return metric

    return {
        # fraction of channels whose maximum normalized response exceeds 0.9
        'each channel max > 0.9': share_where("max", lambda v: v > 0.9),
        # fraction of channels where less than 10% of responses exceed 0.99
        'ratio_above_0.99 < 0.1': share_where("ratio_above_0.99", lambda v: v < 0.1),
        # fraction of channels where less than 80% of responses exceed 0.9
        'ratio_above_0.9 < 0.8': share_where("ratio_above_0.9", lambda v: v < 0.8),
        # fraction of channels where not every response is below 0.1
        'ratio_below_0.1 < 1': share_where("ratio_below_0.1", lambda v: v < 1),
    }

def calculate_layer_metrics(stats):
    """
    Evaluate every metric from ``create_metrics_dict`` on one layer's stats.

    Parameters:
    - stats: stats dict handed unchanged to each metric callable

    Returns:
    - dict mapping metric name -> value. A metric whose computation raises is
      reported via print and recorded as None; the remaining metrics still run.
    """
    def evaluate(metric_name, metric_func):
        # Best-effort evaluation: report the failure instead of aborting the layer.
        try:
            return metric_func(stats)
        except Exception as e:
            print(f"Error calculating {metric_name}: {e}")
            return None

    return {
        metric_name: evaluate(metric_name, metric_func)
        for metric_name, metric_func in create_metrics_dict().items()
    }


In [30]:
def plot_channel_histograms(raw, plot_shape=(5, 6), save_dir='.', layer_num = 'layer', xlim=(0, 1), space_count = 5):
    """
    Draw one histogram per channel in a grid of subplots and save it as a PNG.

    Bars are coloured with the viridis colormap according to each bin's position
    on the x axis. Subplots carry no tick labels; only a figure-level title is shown.

    Parameters:
    - raw: tensor of shape (channels, values); row i is histogrammed in subplot i
    - plot_shape: (rows, cols) of the subplot grid, default (5, 6). Channels beyond
      rows * cols are not drawn.
    - save_dir: output directory (created if missing), default current directory
    - layer_num: layer name used in the title and in the file name
      ``{layer_num}_{space_count}_channel_histograms.png``
    - xlim: x-axis and histogram range, default (0, 1). A zero-width range is
      widened by 0.5 on each side so the bins stay meaningful.
    - space_count: number of equal-width bins across ``xlim``, default 5
    """
    import matplotlib.pyplot as plt
    import os
    import numpy as np

    data = raw.detach().numpy()
    num_channels = data.shape[0]

    low, high = xlim
    if low == high:
        # e.g. a dead/constant layer: avoid identical x-limits and zero-width bins
        low, high = low - 0.5, high + 0.5

    bins = np.linspace(low, high, space_count + 1)
    # A bin's colour depends only on its centre, so compute the colours once.
    norm = plt.Normalize(low, high)
    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_colors = [plt.cm.viridis(norm(center)) for center in bin_centers]

    # Slightly taller figure to leave room for the title
    fig = plt.figure(figsize=(20, 17))
    fig.suptitle(f"{layer_num} x range : {xlim}, space: {space_count}", fontsize=32, fontweight='bold', y=0.98)

    rows, cols = plot_shape
    for i in range(min(num_channels, rows * cols)):
        ax = fig.add_subplot(rows, cols, i + 1)
        _, _, patches = ax.hist(data[i], bins=bins, edgecolor='black')
        for color, patch in zip(bin_colors, patches):
            patch.set_facecolor(color)
        ax.set_xlim(low, high)  # same x range for every channel
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()

    os.makedirs(save_dir, exist_ok=True)

    full_path = os.path.join(save_dir, f'{layer_num}_{space_count}_channel_histograms.png')
    fig.savefig(full_path, dpi=300)
    plt.close(fig)  # close this figure explicitly to free memory

    print(f"Histogram saved to {full_path}")

In [None]:

# Layers to probe: key into `layers`, whether the input is passed through
# model.gray_transform first, and the histogram grid (None = choose automatically).
layers_infos = [
    {"layer_num": "RGB_convs_0", "is_gray": False, "plot_shape": (5, 6)},
    {"layer_num": "RGB_convs_1", "is_gray": False, "plot_shape": None},
    {"layer_num": "RGB_convs_2", "is_gray": False, "plot_shape": None},
    {"layer_num": "Gray_convs_0", "is_gray": True, "plot_shape": (7, 10)},
    {"layer_num": "Gray_convs_1", "is_gray": True, "plot_shape": None},
    {"layer_num": "Gray_convs_2", "is_gray": True, "plot_shape": None},
]


def _truncated_views(convs):
    """
    Return three progressively deeper probes of a conv stack:
    - view 0: stage 0 on its own (annotated as "before spatial merging"),
    - view 1: stages 0-1 followed by stage 2 without its last module
      (annotated as "before spatial merging"),
    - view 2: the whole stack.
    The returned modules share parameters with ``convs``.
    """
    return (
        convs[0],
        nn.Sequential(*convs[:2], convs[2][:-1]),
        nn.Sequential(*convs),
    )


layers = {
    f'{branch}_convs_{depth}': view
    for branch, convs in (('RGB', model.RGB_convs), ('Gray', model.Gray_convs))
    for depth, view in enumerate(_truncated_views(convs))
}



In [33]:
def get_layer_stats(model, layers, layer_num, is_gray = False, batch = None):
    """
    Compute the channel-activation metrics for one probed layer.

    Parameters:
    - model: the loaded network; only ``model.gray_transform`` is used (when is_gray)
    - layers: dict of probe sub-networks keyed by layer name
    - layer_num: key into ``layers``
    - is_gray: if True, feed ``model.gray_transform(batch)`` instead of ``batch``
    - batch: input images; when None, falls back to the notebook-global ``images``
      loaded in the first cell (the original behaviour)

    Returns:
    - dict mapping metric name -> value, as produced by ``calculate_layer_metrics``
    """
    if batch is None:
        batch = images  # notebook global from the data-loading cell

    # Statistics only: no gradients needed.
    with torch.no_grad():
        input_images = model.gray_transform(batch) if is_gray else batch
        raw = calculate_RM(layers, layer_num, input_images)
        channel_stats, _overall_stats = get_stats(raw)

    # The metrics read the per-channel lists. overall_stats shares the keys
    # mean/std/max/min with channel_stats and would be overwritten by it in any
    # merge, so only channel_stats is passed.
    return calculate_layer_metrics(channel_stats)

def get_all_layers_stats(model, layers, layers_infos):
    """
    Run ``get_layer_stats`` for every entry of ``layers_infos``.

    Returns:
    - dict keyed by each entry's "layer_num" (in ``layers_infos`` order), mapping
      to that layer's metrics dict
    """
    return {
        info["layer_num"]: get_layer_stats(model, layers, info["layer_num"], info["is_gray"])
        for info in layers_infos
    }
        

stats = get_all_layers_stats(model, layers, layers_infos)
# Header line, then the metrics dict on its own line
print("Final Stats: ", stats, sep="\n")

Final Stats: 
{'RGB_convs_0': {'each channel max > 0.9': 0.16666666666666666, 'ratio_above_0_99 < 0.1': 0.9666666666666667, 'ratio_above_0_9 < 0.8': 1.0, 'ratio_below_0_1 < 1': 0.8333333333333334}, 'RGB_convs_1': {'each channel max > 0.9': 0.013333333333333334, 'ratio_above_0_99 < 0.1': 1.0, 'ratio_above_0_9 < 0.8': 1.0, 'ratio_below_0_1 < 1': 0.8933333333333333}, 'RGB_convs_2': {'each channel max > 0.9': 0.6608, 'ratio_above_0_99 < 0.1': 1.0, 'ratio_above_0_9 < 0.8': 0.6512, 'ratio_below_0_1 < 1': 0.6608}, 'Gray_convs_0': {'each channel max > 0.9': 0.4, 'ratio_above_0_99 < 0.1': 0.9571428571428572, 'ratio_above_0_9 < 0.8': 1.0, 'ratio_below_0_1 < 1': 0.9571428571428572}, 'Gray_convs_1': {'each channel max > 0.9': 0.6064, 'ratio_above_0_99 < 0.1': 1.0, 'ratio_above_0_9 < 0.8': 1.0, 'ratio_below_0_1 < 1': 0.9216}, 'Gray_convs_2': {'each channel max > 0.9': 0.013877551020408163, 'ratio_above_0_99 < 0.1': 1.0, 'ratio_above_0_9 < 0.8': 1.0, 'ratio_below_0_1 < 1': 0.6824489795918367}}


In [28]:
def plot_layer_graph(model, layers, layer_num, is_gray = False, plot_shape = None, save_dir= './output'):
    """
    Save per-channel histograms of one probed layer's raw responses.

    Parameters:
    - model: the loaded network; only ``model.gray_transform`` is used (when is_gray)
    - layers: dict of probe sub-networks keyed by layer name
    - layer_num: key into ``layers``; also used in the figure title and file name
    - is_gray: if True, feed ``model.gray_transform(images)`` instead of ``images``
    - plot_shape: (rows, cols) subplot grid; if None, the smallest near-square grid
      that fits every channel is used
    - save_dir: output directory

    Input is the notebook-global ``images`` tensor. The histogram x range is the
    layer-wide [min, max] of the raw responses, split into 25 bins.
    """
    import math

    with torch.no_grad():
        input_images = model.gray_transform(images) if is_gray else images
        raw = calculate_RM(layers, layer_num, input_images)

    if plot_shape is None:
        # int(sqrt(n)) ** 2 can be smaller than n (e.g. 70 -> 8x8 = 64), which
        # silently drops channels; round up so every channel gets a subplot.
        channel_count = raw.shape[0]
        cols = math.isqrt(channel_count)
        if cols * cols < channel_count:
            cols += 1
        cols = max(cols, 1)
        rows = max(-(-channel_count // cols), 1)  # ceil division
        plot_shape = (rows, cols)

    # Only the layer-wide range is needed, so skip the full get_stats pass
    # (per-channel moments, kurtosis, skewness).
    xlim = (float(raw.min()), float(raw.max()))

    plot_channel_histograms(raw, plot_shape = plot_shape, save_dir=save_dir, layer_num=layer_num, xlim=xlim, space_count = 25)

def plot_all_layers_graph(model, layers, layers_infos, save_dir = './output'):
    """
    Call ``plot_layer_graph`` for every entry of ``layers_infos`` in order,
    printing a progress line before each layer.
    """
    for info in layers_infos:
        name, gray, shape = info["layer_num"], info["is_gray"], info["plot_shape"]
        print(f"plotting {name} graph")
        plot_layer_graph(model, layers, name, gray, shape, save_dir)

# Write every probed layer's histogram grid into this directory
save_dir = './output'
plot_all_layers_graph(model, layers, layers_infos, save_dir=save_dir)

plotting RGB_convs_0 graph
Histogram saved to ./output\RGB_convs_0_25_channel_histograms.png
plotting RGB_convs_1 graph
Histogram saved to ./output\RGB_convs_1_25_channel_histograms.png
plotting RGB_convs_2 graph
Histogram saved to ./output\RGB_convs_2_25_channel_histograms.png
plotting Gray_convs_0 graph
Histogram saved to ./output\Gray_convs_0_25_channel_histograms.png
plotting Gray_convs_1 graph
Histogram saved to ./output\Gray_convs_1_25_channel_histograms.png
plotting Gray_convs_2 graph
Histogram saved to ./output\Gray_convs_2_25_channel_histograms.png
