In [2]:
def read_txt_to_dict(file_path):
    """Parse a two-column, whitespace-separated text file into a dict.

    Every line is stripped and split on whitespace. Lines that yield exactly
    two fields contribute ``{first_field: second_field}`` (both kept as
    strings; a repeated key overwrites the earlier value). Any other line,
    including an empty one, is reported on stdout and skipped.

    Args:
        file_path: Path of the UTF-8 encoded text file to read.

    Returns:
        The parsed mapping. An empty dict is returned when the file does not
        exist or when any other exception occurs while reading; a message is
        printed in both cases.
    """
    mapping = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                # Trim surrounding whitespace, then split into columns.
                stripped = raw_line.strip()
                fields = stripped.split()
                if len(fields) == 2:
                    # First column is the key, second column is the value.
                    key, val = fields
                    mapping[key] = val
                else:
                    print(f"警告：跳过格式不正确的行 - {stripped}")
    except FileNotFoundError:
        print(f"错误：文件 {file_path} 未找到！")
        return {}
    except Exception as e:
        # Best-effort reader: report the problem and fall back to an empty result.
        print(f"发生错误：{e}")
        return {}
    return mapping

In [4]:
# Label file: read_txt_to_dict maps the first column of each line to the second.
# It returns an empty dict (after printing a message) if the file cannot be read.
file_path = r"D:\python\val.txt"
data_dict = read_txt_to_dict(file_path)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.cluster import KMeans
import os

# Dataset definition
class ImageNetSubsetDataset(Dataset):
    """Dataset over the entries directly inside ``root_dir``, labelled by file name.

    Every directory entry is treated as one image file. Its label is looked
    up in ``labellist`` under the key ``"<stem>.png"``, where ``<stem>`` is
    the part of the entry name before its first ``"."``.
    """

    def __init__(self, root_dir, labellist, transform=None):
        """Index the directory and resolve one label per entry.

        Args:
            root_dir: Directory whose entries are the image files.
            labellist: Mapping from ``"<stem>.png"`` to the label value.
            transform: Optional callable applied to each loaded RGB image.

        Raises:
            KeyError: If an entry's ``"<stem>.png"`` key is not in ``labellist``.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Despite the attribute name, these are the directory entries
        # (image file names); the name is kept for compatibility.
        self.classes = os.listdir(root_dir)
        self.labellist = labellist
        self.image_paths = []
        self.labels = []
        for entry in self.classes:
            self.image_paths.append(os.path.join(root_dir, entry))
            # Label keys always carry a ".png" suffix, whatever the real extension.
            label_key = entry.split(".")[0] + ".png"
            self.labels.append(labellist[label_key])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and return ``(image, label)``.

        The transform, when given, is applied to the image before returning.
        """
        # `Image` was never imported in this file, so the previous
        # `Image.open(...)` raised NameError. torchvision's pil_loader opens
        # the file with PIL and converts it to RGB; it is imported locally so
        # this class does not rely on a module-level import.
        from torchvision.datasets.folder import pil_loader

        image = pil_loader(self.image_paths[idx])
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Preprocessing pipeline: resize, center-crop to 224, convert to a tensor, then
# normalise each channel with the mean/std below.
# NOTE(review): the mean/std look like the usual ImageNet statistics — confirm.
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the local ImageNet data
root_dir = r"D:\python\imagenet"  # replace with your local ImageNet data path
dataset = ImageNetSubsetDataset(root_dir,data_dict, transform=data_transform)
# NOTE(review): the paths suggest Windows. With num_workers > 0 the worker
# processes must be able to import the Dataset class, which presumably fails
# when it is defined in a notebook cell or in a script without an
# `if __name__ == "__main__":` guard — confirm, or use num_workers=0.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# Load the pretrained ResNet18 model and switch it to evaluation mode.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# a `weights=` argument — confirm against the installed version.
model = models.resnet18(pretrained=True)
model.eval()

# Extract feature maps
def extract_features(model, dataloader, per_layer=False):
    """Run the ResNet stem and residual stages over every batch and collect activations.

    Gradients are disabled while the batches are processed.

    Args:
        model: Network exposing ``conv1``, ``bn1``, ``relu``, ``maxpool`` and
            ``layer1`` .. ``layer4`` as callables.
        dataloader: Iterable yielding ``(images, labels)`` batches; the labels
            are ignored.
        per_layer: When False (the default, and the original behaviour), only
            the ``layer4`` output is kept and returned as one array with the
            batches concatenated along axis 0. When True, the outputs of
            ``layer1`` .. ``layer4`` are all kept and returned as a list of
            four arrays, each with axis 1 (channels, for the usual 4-D conv
            output) moved to the front, so ``result[layer][filter]`` holds
            every activation of that filter across all images.

    Returns:
        A single ``numpy.ndarray`` (``per_layer=False``) or a list of four
        ``numpy.ndarray`` objects (``per_layer=True``).

    Raises:
        ValueError: If ``dataloader`` yields no batches (nothing to concatenate).
    """
    stages = (model.layer1, model.layer2, model.layer3, model.layer4)
    last_stage = len(stages) - 1
    collected = [[] for _ in stages]
    with torch.no_grad():
        for images, _ in dataloader:
            outputs = model.conv1(images)
            outputs = model.bn1(outputs)
            outputs = model.relu(outputs)
            outputs = model.maxpool(outputs)
            for stage_idx, stage in enumerate(stages):
                outputs = stage(outputs)
                # Intermediate stages are only stored when explicitly requested,
                # so the default path keeps its original memory footprint.
                if per_layer or stage_idx == last_stage:
                    collected[stage_idx].append(outputs.cpu().numpy())
    if not per_layer:
        return np.concatenate(collected[last_stage], axis=0)
    return [
        np.moveaxis(np.concatenate(chunks, axis=0), 1, 0)
        for chunks in collected
    ]

# ExplanatoryGraph indexes its input as [layer][filter], and the configured
# filter counts (64, 128, 256, 512) are the widths of layer1..layer4. The
# layer4-only, batch-first array would make the first index select an image
# rather than a layer, so request the per-layer, channel-first structure.
# NOTE(review): this keeps all four stages in memory and gives every filter far
# more samples to cluster; reduce the dataset if that is too heavy.
cnn_feature_maps = extract_features(model, dataloader, per_layer=True)

# Explanatory-graph definition
class ExplanatoryGraph:
    """Per-filter pattern centres stored as ``graph[layer][filter] -> 1-D array``.

    The feature maps passed to :meth:`build_graph` and :meth:`train` must be
    indexable as ``cnn_feature_maps[layer_idx][filter_idx]`` and each such item
    must support ``reshape(-1, 1)``.
    """

    def __init__(self, num_layers, num_filters_per_layer, num_patterns_per_filter):
        """Store the graph dimensions and start with an empty graph.

        Args:
            num_layers: Number of layers to iterate over.
            num_filters_per_layer: Sequence giving the filter count of each layer.
            num_patterns_per_filter: Number of clusters fitted for every filter.
        """
        self.graph = {}
        self.num_patterns_per_filter = num_patterns_per_filter
        self.num_filters_per_layer = num_filters_per_layer
        self.num_layers = num_layers

    def build_graph(self, cnn_feature_maps):
        """Fit an initial set of patterns for every (layer, filter) pair."""
        for layer in range(self.num_layers):
            maps_for_layer = cnn_feature_maps[layer]
            layer_nodes = {}
            self.graph[layer] = layer_nodes
            filter_count = self.num_filters_per_layer[layer]
            for filt in range(filter_count):
                # Disentangle part patterns by clustering the filter's activations.
                layer_nodes[filt] = self._disentangle_patterns(maps_for_layer[filt])

    def _disentangle_patterns(self, feature_map):
        """Return the flattened component means of a Gaussian mixture fit.

        The feature map is flattened into a column of scalar samples first.
        """
        mixture = GaussianMixture(n_components=self.num_patterns_per_filter)
        samples = feature_map.reshape(-1, 1)
        mixture.fit(samples)
        return mixture.means_.flatten()

    def train(self, cnn_feature_maps, num_iterations=100):
        """Re-estimate every stored pattern array ``num_iterations`` times.

        Requires :meth:`build_graph` to have populated ``self.graph`` for all
        (layer, filter) pairs that are visited.
        """
        for _ in range(num_iterations):
            for layer in range(self.num_layers):
                for filt in range(self.num_filters_per_layer[layer]):
                    current = self.graph[layer][filt]
                    # Refresh the pattern parameters in place.
                    self._update_patterns(current, cnn_feature_maps[layer][filt])

    def _update_patterns(self, patterns, feature_map):
        """Overwrite ``patterns`` in place with freshly fitted k-means centres.

        This is a simplified stand-in for an EM-style update; a real
        implementation would need considerably more detail.
        """
        clusterer = KMeans(n_clusters=self.num_patterns_per_filter)
        samples = feature_map.reshape(-1, 1)
        clusterer.fit(samples)
        patterns[:] = clusterer.cluster_centers_.flatten()

# Build the explanatory graph
num_layers = 4
num_filters_per_layer = [64, 128, 256, 512]  # set according to the ResNet18 architecture
num_patterns_per_filter = 5  # disentangle 5 patterns per filter
explanatory_graph = ExplanatoryGraph(num_layers, num_filters_per_layer, num_patterns_per_filter)
# build_graph indexes its argument as cnn_feature_maps[layer_idx][filter_idx].
explanatory_graph.build_graph(cnn_feature_maps)

# Train the explanatory graph (uses the default num_iterations=100)
explanatory_graph.train(cnn_feature_maps)

# Print the results: one line per (layer, filter) with its pattern array
print("Explanatory Graph:")
for layer_idx in range(num_layers):
    for filter_idx in range(num_filters_per_layer[layer_idx]):
        patterns = explanatory_graph.graph[layer_idx][filter_idx]
        print(f"Layer {layer_idx}, Filter {filter_idx}: {patterns}")