In [None]:
import torch 
from torch import nn 
from torch.utils import data
import torchvision
from torchvision import transforms, datasets
from visdom import Visdom
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.figure_factory as ff
import numpy as np
# Shared Visdom client used by every plotting cell below (created with the
# library's default connection settings).
# NOTE(review): presumably requires a Visdom server to be running already — confirm.
viz = Visdom()

In [None]:
def load_data_fashion_mnist(batch_size, train_rate=0.8, resize=None):
    """Build FashionMNIST loaders with a random train/validation split.

    The datasets are read from ``../data`` and are NOT downloaded
    (``download=False``), so the files must already be present there.

    Args:
        batch_size: Batch size used by both returned data loaders.
        train_rate: Fraction of the official training set kept for training;
            the remaining samples form the validation set.
        resize: Optional size forwarded to ``transforms.Resize``. When falsy,
            images are only converted to tensors.

    Returns:
        A 4-tuple ``(train_loader, val_loader, test_dataset, class_names)``:
        a shuffled loader over the training split, an unshuffled loader over
        the validation split, the untouched test dataset object, and the
        ``classes`` attribute of the training dataset.
    """
    # Optional resize must run before the tensor conversion.
    steps = [transforms.Resize(resize)] if resize else []
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    mnist_train = datasets.FashionMNIST(
        root='../data', train=True, transform=pipeline, download=False)
    mnist_test = datasets.FashionMNIST(
        root='../data', train=False, transform=pipeline, download=False)

    # Whatever is not assigned to training goes to validation, so the two
    # lengths always add up to the full training set.
    n_total = len(mnist_train)
    n_train = int(train_rate * n_total)
    n_val = n_total - n_train
    train_dataset, val_dataset = data.random_split(mnist_train,
                                                   [n_train, n_val])

    train_loader = data.DataLoader(train_dataset, batch_size, shuffle=True)
    val_loader = data.DataLoader(val_dataset, batch_size, shuffle=False)
    return train_loader, val_loader, mnist_test, mnist_train.classes


In [None]:
# Display one batch of training data as a grid of heatmaps
data_iter, _, _, classes = load_data_fashion_mnist(64)
batch_data = next(iter(data_iter))
# Convert (images, labels) to NumPy; squeeze drops the singleton channel axis
for i in range(2):
    batch_data[i] = batch_data[i].squeeze().numpy()
images, labels = batch_data
titles = [classes[i] for i in labels]

# Grid layout: n_rows * n_cols must equal the batch size requested above (64)
n_rows, n_cols = 4, 16
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=titles)
for r in range(n_rows):
    for c in range(n_cols):
        # Row-major index: the stride is the number of columns, so every image
        # is drawn exactly once and lines up with its subplot title.
        fig.add_trace(go.Heatmap(z=images[r * n_cols + c]),
                      row=r + 1, col=c + 1)

fig.update_layout(title_text="FashionMNIST")
fig.update_xaxes(showticklabels=False)
# Reverse the y axis so image row 0 is drawn at the top
fig.update_yaxes(showticklabels=False, autorange="reversed")
viz.plotlyplot(fig, win='datashow')

In [None]:
# Define the network
class Reshape(torch.nn.Module):
    """Layer that views its input as a batch of 1-channel 28x28 images."""

    def forward(self, x):
        """Return ``x`` viewed as ``(-1, 1, 28, 28)``; the batch size is inferred."""
        as_images = x.view(-1, 1, 28, 28)
        return as_images


# LeNet-style stack: two conv/sigmoid/avg-pool stages followed by three
# fully connected layers that end in 10 outputs.
net = nn.Sequential(
    Reshape(),
    # Stage 1: padding=2 keeps 28x28 through the 5x5 conv; pooling halves it
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Stage 2: unpadded 5x5 conv, then pooling -> 16 maps of 5x5
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Classifier head on the flattened 16 * 5 * 5 features
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)


In [None]:
# One training epoch
def train_epoch(model, train_loader, optimizer, loss_fn, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Data loader yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor that supports ``backward()``.
        epoch: Epoch number, used only in the progress message.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss averaged over the
        number of batches, and the number of correct predictions divided by
        the dataset size.
    """
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    progress_msg = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

    model.train()
    loss_sum, hits = 0, 0
    for step, (inputs, targets) in enumerate(train_loader):
        logits = model(inputs)
        # Labels must live on the same device as the model output
        targets = targets.to(logits.device)
        batch_loss = loss_fn(logits, targets)

        # Reset gradients, back-propagate, update the parameters
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Accumulate running loss and number of correct predictions
        loss_sum += batch_loss.item()
        hits += (logits.argmax(1) == targets).float().sum().item()

        # Report progress once every 100 batches
        if step % 100 != 0:
            continue
        print(progress_msg.format(
            epoch, step * len(inputs), n_samples,
            100. * step / n_batches, batch_loss.item()))

    return loss_sum / n_batches, hits / n_samples
# Evaluation epoch (used on the validation split)


def val_epoch(model, val_loader, loss_fn):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        val_loader: Data loader yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss averaged over the
        number of batches, and the number of correct predictions divided by
        the dataset size. A summary line is also printed.
    """
    n_samples = len(val_loader.dataset)
    n_batches = len(val_loader)

    model.eval()
    loss_sum, hits = 0, 0
    # Gradients are not needed for evaluation; disabling them saves memory
    with torch.no_grad():
        for inputs, targets in val_loader:
            logits = model(inputs)
            targets = targets.to(logits.device)
            loss_sum += loss_fn(logits, targets).item()
            hits += (logits.argmax(1) == targets).float().sum().item()

    mean_loss = loss_sum / n_batches
    summary = '\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
    print(summary.format(mean_loss, hits, n_samples,
                         100. * hits / n_samples))
    return mean_loss, hits / n_samples


# Cross-entropy is the loss used for this classification problem
loss_fn = nn.CrossEntropyLoss()
# Train on the GPU when one is available, otherwise stay on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Train in DataParallel (DP) mode. DataParallel requires the wrapped module's
# parameters to already be on its first device, so move the network first.
# The isinstance guard keeps the cell idempotent: re-running it must not nest
# one DataParallel wrapper inside another.
if not isinstance(net, nn.DataParallel):
    net = nn.DataParallel(net.to(device))
# Parameters are updated with the Adam optimizer (created after the move so
# that it references the parameters on their final device)
trainer = torch.optim.Adam(net.parameters(), lr=0.01)
# Number of training epochs
num_epochs = 25
# Load the dataset
train_iter, val_iter, test_data, _ = load_data_fashion_mnist(256)
# Per-epoch history of loss and accuracy
train_loss, train_accuracy = [], []
val_loss, val_accuracy = [], []
for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    epoch_train_loss, epoch_train_acc = train_epoch(
        net, train_iter, trainer, loss_fn, epoch)
    train_loss.append(epoch_train_loss)
    train_accuracy.append(epoch_train_acc)
    epoch_val_loss, epoch_val_acc = val_epoch(net, val_iter, loss_fn)
    val_loss.append(epoch_val_loss)
    val_accuracy.append(epoch_val_acc)


In [None]:
# Epoch indices (0-based), duplicated column-wise so that each of the two
# curves in a plot gets its own x column.
x = np.arange(num_epochs)
xx = np.column_stack((x, x))
# One column per curve: training first, validation second
y1 = np.column_stack((train_loss, val_loss))
y2 = np.column_stack((train_accuracy, val_accuracy))

# Options common to both Visdom line plots
_shared_opts = dict(xlabel='epoch', markers=True, markersize=8)

loss_opts = dict(_shared_opts, legend=['train_loss',  'val_loss'],
                 ylabel='loss', title='Loss')
viz.line(y1, xx, win='Loss', opts=loss_opts)

acc_opts = dict(_shared_opts, legend=['train_acc',  'val_acc'],
                ylabel='Acc', title='Accuracy')
viz.line(y2, xx, win='Accuracy', opts=acc_opts)


In [None]:
from sklearn.metrics import accuracy_score,confusion_matrix
def test_model(model, test_X, test_y):
    """Report test accuracy and send a confusion-matrix heatmap to Visdom.

    Relies on the notebook-level names ``classes`` (class labels used for the
    heatmap axes) and ``viz`` (Visdom client).

    Args:
        model: Trained module; it is switched to evaluation mode.
        test_X: Tensor of test inputs, passed to ``model`` in a single call.
        test_y: Tensor of ground-truth class indices.
            NOTE(review): assumed to be indices into ``classes`` — confirm.

    Returns:
        None. The accuracy is printed and the heatmap is shown in the Visdom
        window ``'heatmap'``.
    """
    model.eval()
    # Inference only: skip autograd bookkeeping for the whole-test-set forward
    # pass, which would otherwise retain every intermediate activation.
    with torch.no_grad():
        output = model(test_X)
    pre_lab = torch.argmax(output, 1)
    test_y = test_y.cpu().numpy()
    pre_lab = pre_lab.cpu().numpy()
    acc = accuracy_score(test_y, pre_lab)
    print("在测试集上的预测精度为:", acc)
    # Fix the label set so the matrix is always len(classes) x len(classes),
    # even if some class never occurs in the targets or the predictions;
    # otherwise its shape would not match the axis labels passed below.
    conf_mat = confusion_matrix(test_y, pre_lab,
                                labels=list(range(len(classes))))
    fig = ff.create_annotated_heatmap(
        z=conf_mat, x=classes, y=classes,
        annotation_text=np.around(conf_mat, decimals=2),
        colorscale='YlGnBu')
    fig.update_layout(title='混淆矩阵')
    fig.update_xaxes(side="bottom")
    # Reverse the y axis so the first class is at the top
    fig.update_yaxes(autorange="reversed")
    viz.plotlyplot(fig, win='heatmap')
# Convert the raw test images to float and scale by 1/255.
# NOTE(review): presumably the stored pixels are 0-255 intensities, so this
# matches the [0, 1] range produced by ToTensor during training — confirm.
test_data_X = test_data.data.type(torch.FloatTensor) / 255.0
# Add the channel axis so the input passed to the model is (N, 1, H, W).
# (Previously the result was stored under a misspelled name and never used.)
test_data_X = torch.unsqueeze(test_data_X, dim=1)
test_data_y = test_data.targets  # ground-truth labels of the test set
test_model(net, test_data_X, test_data_y)