In [1]:
import torch
from torch import nn
from d2l import torch as d2l
from train_epoch.train import train_ch6

In [2]:
# LeNet-style network with sigmoid activations and average pooling.
# Spatial sizes below assume a 1x28x28 input. Modules are listed in a single
# flat sequence, so their indices (and state_dict keys) run 0..11.
net = nn.Sequential(*[
    # 1x28x28 -> 6x28x28 (padding=2 preserves size) -> 6x14x14
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # 6x14x14 -> 16x10x10 -> 16x5x5
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # 16*5*5 = 400 features -> 120 -> 84 -> 10 outputs
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
])

In [3]:
# Push a dummy single-channel 28x28 batch through `net` one layer at a time to
# check the intermediate shapes. A dedicated name keeps this probe from
# rebinding a generic global `X`, and no_grad avoids building an autograd graph
# for what is only a shape check.
x_probe = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
with torch.no_grad():
    for layer in net:
        x_probe = layer(x_probe)
        print(f'{layer.__class__.__name__} output shape: \t {x_probe.shape}')
assert x_probe.shape == (1, 10), f'unexpected output shape {tuple(x_probe.shape)}'

In [5]:
# Fashion-MNIST training and test iterators from d2l, 256 samples per batch.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

In [49]:
# Train the sigmoid LeNet (`net`) on the device selected by d2l.try_gpu().
num_epochs = 10
lr = 0.9
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

In [2]:
import torch
import torch.nn as nn

class LeNet(nn.Module):
    """LeNet-style CNN with ReLU activations and max pooling.

    Sized for single-channel 28x28 inputs: ``layer1`` (5x5 conv with padding 2,
    then 2x2 pool) yields 8x14x14 maps, and ``layer2`` (5x5 conv, then 2x2 pool)
    yields 16x5x5 maps, matching the ``nn.Linear(16 * 5 * 5, ...)`` input of
    ``classifier``, which produces 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 8 channels; padding=2 keeps 28x28 until pooling halves it.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 2: 8 -> 16 channels; 14x14 -> 10x10 (conv) -> 5x5 (pool).
        self.layer2 = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Fully connected head. Attribute names and Sequential indices determine
        # the state_dict keys, so they are kept unchanged (checkpoints saved under
        # models/lenet presumably rely on them).
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.Sigmoid(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        """Run both convolutional stages and the classifier; returns (batch, 10)."""
        _, features = self.get_activations(x)
        return self.classifier(features)

    def get_activations(self, x):
        """Return the outputs of ``layer1`` and ``layer2`` for input ``x``."""
        stage1 = self.layer1(x)
        return stage1, self.layer2(stage1)
# Initialize the network. Seeding first makes the initial weights repeatable
# across kernel restarts (Restart & Run All reproduces the same starting point).
torch.manual_seed(0)
lenet = LeNet()


In [7]:
# Train the ReLU/max-pool LeNet and save checkpoints under models/lenet.
# To resume, set `resume_from` to a checkpoint path
# (e.g. "models/lenet/best.ckpt"); it is forwarded to train_ch6 as `load_dir`.
lr, num_epochs = 0.9, 20
resume_from = None
resume_kwargs = {} if resume_from is None else {"load_dir": resume_from}
train_ch6(lenet, train_iter, test_iter, num_epochs, lr, d2l.try_gpu(),
          save_path="models/lenet", **resume_kwargs)

In [58]:
import matplotlib.pyplot as plt

def _plot_feature_maps(maps, layer_name, ncols=4):
    """Show every channel of a (C, H, W) CPU tensor as a grayscale subplot.

    Subplots are laid out ``ncols`` per row; grid cells left over after the
    last channel are blanked.
    """
    n_maps = maps.shape[0]
    nrows = -(-n_maps // ncols)  # ceiling division
    _, axes = plt.subplots(nrows, ncols, figsize=(16, 3 * nrows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n_maps:
            ax.imshow(maps[i].numpy(), cmap='gray')
            ax.set_title(f'{layer_name} - Filter {i+1}')
        ax.axis('off')
    plt.show()


def visualize_activations(net, device, dataloader):
    """Plot the two activation stages returned by ``net.get_activations``.

    Takes the first batch from ``dataloader``, runs it through
    ``net.get_activations`` with gradients disabled, and shows every channel of
    the first sample's two activation tensors in one figure per stage.

    Args:
        net: module exposing ``get_activations(X) -> (act1, act2)``.
        device: device the model and batch are moved to.
        dataloader: iterable of ``(X, y)`` batches; only the first is used.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    net.to(device)
    net.eval()
    try:
        X, _ = next(iter(dataloader))
    except StopIteration:
        raise ValueError("dataloader yielded no batches") from None

    with torch.no_grad():
        activations1, activations2 = net.get_activations(X.to(device))

    # Only the first sample of the batch is visualized.
    act1 = activations1[0].cpu()
    act2 = activations2[0].cpu()

    # One subplot per channel, so the grids follow the model's actual widths
    # rather than assuming 8 and 16 filters.
    _plot_feature_maps(act1, 'Layer 1')
    _plot_feature_maps(act2, 'Layer 2')

# Show the stage-1/stage-2 feature maps for the first test image. Uses the same
# device selection as training instead of a hard-coded 'cuda', which fails on
# CPU-only machines.
visualize_activations(lenet, d2l.try_gpu(), test_iter)


In [54]:
def _show_misclassified(image, true_label, predicted_label):
    """Display one (C, H, W) image tensor titled with its true and predicted labels."""
    image = image.cpu()
    # Single-channel images become 2-D so the 'gray' colormap applies;
    # multi-channel images are moved to channels-last for imshow.
    image = image[0] if image.shape[0] == 1 else image.permute(1, 2, 0)
    plt.figure(figsize=(3, 3))
    plt.imshow(image.numpy(), cmap="gray")
    plt.title(f"True: {true_label}, Pred: {predicted_label}")
    plt.axis('off')
    plt.show()


def visualize_incorrect_predictions(model, dataloader, device, classes, num_images=5):
    """Plot up to ``num_images`` misclassified samples, one small figure each.

    Args:
        model: classifier whose output is assumed to be (batch, n_classes) scores;
            the predicted label is the argmax over dim 1.
        dataloader: iterable of ``(data, target)`` batches; each ``data`` item is
            assumed to be a (C, H, W) image tensor.
        device: device the model and batches are moved to.
        classes: sequence mapping a label index to its display name.
        num_images: maximum number of misclassified samples to show.
    """
    model.eval()
    model.to(device)
    shown = 0

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)

            # nonzero(as_tuple=True)[0] is always 1-D. `.nonzero().squeeze()`
            # collapses to a 0-d tensor when a batch has exactly one mistake,
            # and iterating a 0-d tensor raises TypeError.
            wrong = (predicted != target).nonzero(as_tuple=True)[0]
            for idx in wrong[:max(num_images - shown, 0)].tolist():
                _show_misclassified(
                    data[idx],
                    classes[target[idx].item()],
                    classes[predicted[idx].item()],
                )
                shown += 1

            # Stop once the quota is met instead of running the model on
            # further batches.
            if shown >= num_images:
                return


In [59]:
# Fashion-MNIST class names, indexed by label id (all 10 classes).
classes = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Show up to 10 misclassified test images. Uses the same device selection as
# training instead of a hard-coded 'cuda'.
visualize_incorrect_predictions(lenet, test_iter, d2l.try_gpu(), classes, num_images=10)
