# NIN (改进版)

In [None]:
import os

# Point the HTTP(S) proxy environment variables at a local proxy (honoured by
# libraries such as urllib/requests, e.g. for the dataset download below), while
# loopback addresses bypass the proxy.
_PROXY_URL = 'http://127.0.0.1:7893'
_NO_PROXY_HOSTS = '127.0.0.1,localhost'

# Set both lower- and upper-case spellings, since different tools read different ones.
for _proxy_var in ('http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY'):
    os.environ[_proxy_var] = _PROXY_URL
for _bypass_var in ('no_proxy', 'NO_PROXY'):
    os.environ[_bypass_var] = _NO_PROXY_HOSTS

# Echo the effective HTTP proxy so the setting can be verified.
print(f"HTTP代理: {os.environ.get('http_proxy')}")

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time

In [None]:
# Configure matplotlib for a headless Linux server.
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are written to files, not displayed
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.font_manager as fm

import warnings
# NOTE(review): matplotlib usually attributes its UserWarnings (e.g. missing-glyph
# warnings) to the *calling* module rather than to matplotlib, so this module-based
# filter may not silence them -- confirm what it is meant to catch.
warnings.filterwarnings('ignore', category=UserWarning, module='matplotlib')

# Make fonts installed after matplotlib's font cache was built visible in *this*
# session. Rebuilding through the private fm._load_fontmanager() only returns a new
# manager (the one pyplot renders with stays unchanged), so instead register the
# unknown font files on the live manager via the public FontManager.addfont.
_known_font_files = {entry.fname for entry in fm.fontManager.ttflist}
for _font_path in fm.findSystemFonts():
    if _font_path in _known_font_files:
        continue
    try:
        fm.fontManager.addfont(_font_path)
    except Exception:  # best-effort: skip font files that cannot be parsed
        continue

# Chinese-capable fonts in order of preference; DejaVu Sans is the final fallback.
_cjk_fonts = ['WenQuanYi Micro Hei', 'WenQuanYi Zen Hei', 'Droid Sans Fallback', 'SimHei']
plt.rcParams['font.sans-serif'] = _cjk_fonts + ['DejaVu Sans']
# Draw negative numbers with an ASCII hyphen; some CJK fonts lack the Unicode
# minus sign (U+2212).
plt.rcParams['axes.unicode_minus'] = False

# Only report success when one of the preferred Chinese fonts is actually installed
# (assigning rcParams never fails, so it cannot signal a missing font).
_installed_font_names = {entry.name for entry in fm.fontManager.ttflist}
if any(name in _installed_font_names for name in _cjk_fonts):
    print("中文字体配置完成")
else:
    print("使用默认字体")

In [None]:
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    """Build one NIN block: a spatial convolution followed by two 1x1 convolutions.

    Every convolution is followed by a ReLU, so the block has six layers in the
    order Conv2d, ReLU, Conv2d(1x1), ReLU, Conv2d(1x1), ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by every convolution in the block.
        kernel_size: kernel size of the first (spatial) convolution.
        strides: stride of the first convolution.
        padding: padding of the first convolution.

    Returns:
        An ``nn.Sequential`` containing the six layers.
    """
    spatial_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                             stride=strides, padding=padding)
    layers = [spatial_conv, nn.ReLU()]
    # Two pointwise (1x1) convolutions: a per-pixel MLP across channels.
    for _ in range(2):
        layers.extend([nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU()])
    return nn.Sequential(*layers)

## 改进的网络架构

主要改进：
1. 减小第一层的kernel size从11改为7（更适合Fashion-MNIST）
2. 调整stride从4改为2（减少信息损失）
3. 将第一层的padding从0增加到3，使输出特征图恰好为输入的一半（224×224 → 112×112）

In [None]:
# Improved NIN network
def _build_nin_net():
    """Assemble the improved NIN classifier.

    Layer order (and therefore Sequential indices 0..11) is: three feature NIN
    blocks each followed by 3x3/stride-2 max-pooling, dropout, a 10-channel output
    NIN block, global average pooling and flatten. For a 1x224x224 input the
    spatial size goes 112 -> 55 -> 55 -> 27 -> 27 -> 13 -> 13 -> 1.
    """
    layers = []
    # (in_channels, out_channels, kernel_size, strides, padding) for the feature
    # blocks; block 1 uses 7/2/3 instead of the classic 11/4/0.
    feature_specs = [(1, 96, 7, 2, 3), (96, 256, 5, 1, 2), (256, 384, 3, 1, 1)]
    for cin, cout, k, s, p in feature_specs:
        layers.append(nin_block(cin, cout, kernel_size=k, strides=s, padding=p))
        layers.append(nn.MaxPool2d(3, stride=2))
    layers.append(nn.Dropout(p=0.5))
    # Output block: one channel per class, averaged over space -> (batch, 10) scores.
    # NOTE(review): nin_block ends with a ReLU, so these class scores are never
    # negative -- confirm this is intended.
    layers.append(nin_block(384, 10, kernel_size=3, strides=1, padding=1))
    layers.append(nn.AdaptiveAvgPool2d((1, 1)))
    layers.append(nn.Flatten())
    return nn.Sequential(*layers)


net = _build_nin_net()

# Weight initialisation
def init_weights(m):
    """Kaiming-initialise convolution weights and zero their biases.

    Meant to be passed to ``nn.Module.apply``; any module that is not an
    ``nn.Conv2d`` (or subclass) is left untouched.

    Args:
        m: a module visited by ``apply``.
    """
    # isinstance rather than an exact type() comparison, so Conv2d subclasses
    # are initialised as well.
    if isinstance(m, nn.Conv2d):
        # ReLU gain, matching the ReLU that follows every conv in nin_block.
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Run init_weights on net and all of its submodules, so every Conv2d layer gets
# Kaiming-initialised weights and zero biases.
net.apply(init_weights)

In [None]:
# Sanity-check the architecture: push one random 1x224x224 image through each
# top-level layer and print the resulting shape.
X = torch.rand(1, 1, 224, 224)
for stage in net:
    X = stage(X)
    print(type(stage).__name__, " out shape:\t", X.shape)

In [None]:
# Data preprocessing and loading
batch_size = 128

# Resize so the shorter side is 224 (224x224 for the square Fashion-MNIST images),
# convert to tensors, then standardise with the Fashion-MNIST mean/std.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.2860], std=[0.3530]),
])

# Download (if needed) into ./data and build the training split first, then the
# test split.
train_dataset, test_dataset = (
    datasets.FashionMNIST(root='./data', train=is_train_split, download=True, transform=transform)
    for is_train_split in (True, False)
)

# Only the training loader shuffles; the evaluation order stays fixed.
train_iter = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=2)
test_iter = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, num_workers=2)

print(f'训练集大小: {len(train_dataset)}, 测试集大小: {len(test_dataset)}')

In [None]:
# Training and evaluation helpers
def train(net, train_iter, test_iter, num_epochs, lr, device, *,
          step_size=5, gamma=0.5, max_grad_norm=1.0):
    """Train ``net`` with Adam and report per-epoch metrics.

    Args:
        net: model mapping a batch ``X`` to class logits (argmax taken over dim 1).
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` batches scored by ``evaluate_accuracy``.
        num_epochs: number of passes over ``train_iter``.
        lr: initial Adam learning rate.
        device: device the model and every batch are moved to.
        step_size: StepLR period in epochs (default 5).
        gamma: StepLR multiplicative decay (default 0.5).
        max_grad_norm: threshold for global gradient-norm clipping (default 1.0).

    Returns:
        ``(train_losses, train_accs, test_accs)``, one entry per epoch.
        ``train_losses`` is the per-sample mean cross-entropy of each epoch.

    Raises:
        ValueError: if ``train_iter`` yields no samples in an epoch.
    """
    print(f'training on {device}')
    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # Multiply the learning rate by `gamma` every `step_size` epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    loss_fn = nn.CrossEntropyLoss()

    train_losses, train_accs, test_accs = [], [], []

    for epoch in range(num_epochs):
        # evaluate_accuracy leaves the net in eval mode, so re-enable training mode.
        net.train()
        loss_sum, correct, n = 0.0, 0.0, 0
        start = time.perf_counter()

        for X, y in train_iter:
            X, y = X.to(device), y.to(device)

            y_hat = net(X)
            loss = loss_fn(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            batch_n = y.shape[0]
            # `loss` is a batch mean; weight it by the batch size so a smaller final
            # batch does not skew the epoch average.
            loss_sum += loss.item() * batch_n
            correct += (y_hat.argmax(dim=1) == y).sum().item()
            n += batch_n

        if n == 0:
            raise ValueError('train_iter yielded no samples')

        test_acc = evaluate_accuracy(net, test_iter, device)

        train_loss = loss_sum / n
        train_acc = correct / n

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        # The reported time covers the training pass plus the test evaluation; the
        # lr shown is the one used during this epoch (the scheduler steps after).
        print(f'epoch {epoch + 1}, loss {train_loss:.4f}, '
              f'train acc {train_acc:.3f}, test acc {test_acc:.3f}, '
              f'time {time.perf_counter() - start:.1f} sec, '
              f'lr {scheduler.get_last_lr()[0]:.5f}')

        scheduler.step()

    return train_losses, train_accs, test_accs

def evaluate_accuracy(net, data_iter, device):
    """Return the top-1 accuracy of ``net`` over ``data_iter``.

    Switches ``net`` to eval mode and leaves it there; callers that continue
    training must call ``net.train()`` again.

    Raises:
        ValueError: if ``data_iter`` yields no samples.
    """
    net.eval()
    correct, n = 0.0, 0

    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            correct += (net(X).argmax(dim=1) == y).sum().item()
            n += y.shape[0]

    if n == 0:
        raise ValueError('data_iter yielded no samples')
    return correct / n

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training hyper-parameters (learning rate lowered from 0.1 to 0.001).
num_epochs = 10
lr = 0.001

# Run training and keep the per-epoch metrics for plotting.
train_losses, train_accs, test_accs = train(
    net, train_iter, test_iter,
    num_epochs=num_epochs, lr=lr, device=device,
)

In [None]:
# Plot the training curves and save them to a PNG file.
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Loss curve
ax1.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.legend()
ax1.grid(True)

# Accuracy curves
ax2.plot(range(1, len(train_accs) + 1), train_accs, label='Training Accuracy')
ax2.plot(range(1, len(test_accs) + 1), test_accs, label='Test Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Accuracy')
ax2.legend()
ax2.grid(True)

fig.tight_layout()
fig.savefig('nin_training_curves.png')
# With the non-interactive Agg backend selected earlier nothing is displayed here;
# the saved PNG is the real output.
plt.show()
# show() does not free the figure on a non-interactive backend; close it so
# re-running this cell does not accumulate open figures.
plt.close(fig)

print("\n最终结果：")
print(f"训练准确率: {train_accs[-1]:.3f}")
print(f"测试准确率: {test_accs[-1]:.3f}")