In [5]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from sklearn.neighbors import NearestNeighbors
from scipy.special import digamma as psi
import numpy as np
from tqdm import tqdm

In [6]:
# 加载CIFAR100数据集
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 将图像大小调整为ResNet50的输入尺寸
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # 归一化
])

trainset = datasets.CIFAR100(root='/home/fabien/Documents/project/2d/mdistiller/data/', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=False)

# 加载预训练的ResNet50模型
resnet50 = models.resnet50(pretrained=True)

# 修改ResNet50的全连接层以匹配CIFAR100的类别数
num_ftrs = resnet50.fc.in_features
resnet50.fc = torch.nn.Linear(num_ftrs, 100)

Files already downloaded and verified


In [11]:
# 定义函数以提取特征和logits
# ResNet-50
def extract_features_logits(model, dataloader):
    model.eval()  # 设置为评估模式
    all_features = []
    all_logits = []
    all_labels = []
    i = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Extracting features and logits"):
            i += 1
            if i > 20:
                break
            # 前向传播
            outputs = model(inputs)
            
            # 提取倒数第二层的特征
            features = model.avgpool(model.layer4(model.layer3(model.layer2(model.layer1(model.maxpool(model.relu(model.bn1(model.conv1(inputs))))))))).view(inputs.size(0), -1)
            
            all_features.extend(features.cpu().numpy())
            all_logits.extend(outputs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return np.array(all_features), np.array(all_logits), np.array(all_labels)


In [12]:
features, logits, labels = extract_features_logits(resnet50, trainloader)

Extracting features and logits:   3%|▎         | 20/782 [01:20<51:10,  4.03s/it]


In [15]:
# 使用最近邻方法估计互信息的函数
def estimate_mutual_information(k, data, labels):
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(data)
    distances, indices = nbrs.kneighbors(data)
    epsilon = distances[:, -1]
    label_distances = labels[indices]
    counts = (label_distances == labels[:, None]).sum(axis=1) - 1
    mi_estimate = psi(k) - (counts/k).mean() + psi(len(data))
    return mi_estimate

# 估计特征与标签之间的互信息
mi_features_labels = estimate_mutual_information(5, features, labels)
print(f'Estimated mutual information between features and labels: {mi_features_labels}')

# 估计logits与标签之间的互信息
mi_logits_labels = estimate_mutual_information(5, logits, labels)
print(f'Estimated mutual information between logits and labels: {mi_logits_labels}')

Estimated mutual information between features and labels: 8.385342349482837
Estimated mutual information between logits and labels: 8.440498599482837


In [8]:
# MobileNetV2
# 加载预训练的MobileNetV2模型
mobilenet_v2 = models.mobilenet_v2(pretrained=True)

class ModifiedMobileNetV2(torch.nn.Module):
    def __init__(self, original_model):
        super(ModifiedMobileNetV2, self).__init__()
        self.features = original_model.features
        self.classifier = original_model.classifier

    def forward(self, x):
        x = self.features(x)
        feature_maps = x
        x = x.mean([2, 3])  # 全局平均池化
        logits = self.classifier(x)
        return feature_maps, logits

modified_mobilenet_v2 = ModifiedMobileNetV2(mobilenet_v2)

def extract_features_logits(model, dataloader):
    model.eval()
    all_features = []
    all_logits = []
    all_labels = []

    with torch.no_grad():
        i = 0
        for inputs, labels in tqdm(dataloader, desc="Extracting features and logits"):
            i += 1
            if i > 20:
                break
            feature_maps, logits = model(inputs)
            all_features.extend(feature_maps.cpu().numpy())
            all_logits.extend(logits.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return np.array(all_features), np.array(all_logits), np.array(all_labels)

features, logits, labels = extract_features_logits(modified_mobilenet_v2, trainloader)

def estimate_mutual_information(k, data, labels):
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(data.reshape(data.shape[0], -1))
    distances, indices = nbrs.kneighbors(data.reshape(data.shape[0], -1))
    epsilon = distances[:, -1]
    label_distances = labels[indices]
    counts = (label_distances == labels[:, None]).sum(axis=1) - 1
    mi_estimate = psi(k) - (counts/k).mean() + psi(len(data))
    return mi_estimate

mi_features_labels = estimate_mutual_information(5, features, labels)
print(f'Estimated mutual information between features and labels: {mi_features_labels}')

mi_logits_labels = estimate_mutual_information(5, logits, labels)
print(f'Estimated mutual information between logits and labels: {mi_logits_labels}')

Extracting features and logits:   3%|▎         | 20/782 [00:13<08:39,  1.47it/s]


Estimated mutual information between features and labels: 8.544092349482836
Estimated mutual information between logits and labels: 8.451904849482837


In [10]:
# 加载预训练的VGG13模型
vgg13 = models.vgg13(pretrained=True)

class ModifiedVGG13(torch.nn.Module):
    def __init__(self, original_model):
        super(ModifiedVGG13, self).__init__()
        self.features = original_model.features
        self.avgpool = original_model.avgpool
        self.classifier = original_model.classifier

    def forward(self, x):
        x = self.features(x)
        feature_maps = x
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # 展平
        logits = self.classifier(x)
        return feature_maps, logits

modified_vgg13 = ModifiedVGG13(vgg13)

def extract_features_logits(model, dataloader):
    model.eval()
    all_features = []
    all_logits = []
    all_labels = []

    i = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Extracting features and logits"):
            feature_maps, logits = model(inputs)
            i += 1
            if i > 20:
                break
            all_features.extend(feature_maps.cpu().numpy())
            all_logits.extend(logits.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return np.array(all_features), np.array(all_logits), np.array(all_labels)

features, logits, labels = extract_features_logits(modified_vgg13, trainloader)

Extracting features and logits:   3%|▎         | 20/782 [01:03<40:14,  3.17s/it]


In [11]:
features.shape, logits.shape, labels.shape

((1280, 512, 7, 7), (1280, 1000), (1280,))

In [30]:
# 假设features, logits, labels已经准备好
# 需要将features从(1280, 512, 7, 7)展平为(1280, 512*7*7)

features_flattened = torch.tensor(features).reshape(features.shape[0], -1)  # 展平特征
logits = torch.tensor(logits)

# 将labels从(1280,)扩展为(1280, 1)以便和特征、logits拼接
labels_expanded = torch.tensor(labels).reshape(-1, 1).float()  # 假设labels是LongTensor，需要转换为FloatTensor用于拼接.

joint_features = torch.cat((features_flattened, labels_expanded), dim=1)
joint_logits = torch.cat((logits, labels_expanded), dim=1)

# 为边缘分布创建打乱的labels
shuffled_labels = labels[torch.randperm(labels.shape[0])]
shuffled_labels_expanded = torch.tensor(shuffled_labels).reshape(-1, 1).float()

marginal_features = torch.cat((features_flattened, shuffled_labels_expanded), dim=1)
marginal_logits = torch.cat((logits, shuffled_labels_expanded), dim=1)


  logits = torch.tensor(logits)


In [31]:
import torch.nn as nn
# 定义MINE网络的输入维度
features_dim = features_flattened.size(1) + 1  # 特征维度+标签维度
logits_dim = logits.size(1) + 1  # logits维度+标签维度
hidden_dim = 100  # 可以根据需要调整隐藏层维度

# 重新定义MINE网络以匹配新的输入维度
class MINE(nn.Module):
    def __init__(self, input_dim):
        super(MINE, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# 创建两个MINE网络实例，一个用于特征和标签，一个用于logits和标签
mine_net_features = MINE(features_dim)
mine_net_logits = MINE(logits_dim)

In [33]:
import torch.optim as optim

optimizer_features = optim.Adam(mine_net_features.parameters(), lr=1e-4)
optimizer_logits = optim.Adam(mine_net_logits.parameters(), lr=1e-4)

def mutual_information_loss(t, et):
    """计算互信息损失"""
    mi_loss = -(torch.mean(t) - torch.log(torch.mean(torch.exp(et))))
    return mi_loss

num_epochs = 1000
for epoch in range(num_epochs):
    # 计算联合分布和边缘分布的网络输出
    t_features = mine_net_features(joint_features)
    et_features = mine_net_features(marginal_features)
    loss_features = mutual_information_loss(t_features, et_features)
    
    t_logits = mine_net_logits(joint_logits)
    et_logits = mine_net_logits(marginal_logits)
    loss_logits = mutual_information_loss(t_logits, et_logits)

    # 反向传播和优化
    optimizer_features.zero_grad()
    loss_features.backward()
    optimizer_features.step()

    optimizer_logits.zero_grad()
    loss_logits.backward()
    optimizer_logits.step()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, MI Features: {-loss_features.item()}, MI Logits: {-loss_logits.item()}')

Epoch 0, MI Features: -0.006370410323143005, MI Logits: -0.10478800535202026
Epoch 10, MI Features: 0.10805314034223557, MI Logits: 0.0008315276354551315
Epoch 20, MI Features: 0.1784161478281021, MI Logits: 0.0301041416823864
Epoch 30, MI Features: 0.25960874557495117, MI Logits: 0.05236309766769409
Epoch 40, MI Features: 0.3529484272003174, MI Logits: 0.07308085262775421
Epoch 50, MI Features: 0.4553203284740448, MI Logits: 0.0926579087972641
Epoch 60, MI Features: 0.5737002491950989, MI Logits: 0.11138467490673065
Epoch 70, MI Features: 0.6998199820518494, MI Logits: 0.13082760572433472
Epoch 80, MI Features: 0.8330408334732056, MI Logits: 0.15056318044662476
Epoch 90, MI Features: 0.9693697690963745, MI Logits: 0.17055219411849976
