In [5]:
import torchvision
from torchvision import transforms
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from tqdm import tqdm

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 96, kernel_size=(11, 11), stride=4)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=(5, 5), padding=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=(3, 3), padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=(3, 3), padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1)
        self.fc1 = nn.Linear(256*6*6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.conv1(x))
        # [batch, 1, 227, 227] -> [batch, 96, 55, 55]
        out = self.pool(out)
        # [batch, 96, 55, 55] -> [batch, 96, 27, 27]
        out = self.relu(self.conv2(out))
        # [batch, 96, 27, 27] -> [batch, 256, 27, 27]
        out = self.pool(out)
        # [batch, 256, 27, 27] -> [batch, 256, 13, 13]
        out = self.relu(self.conv3(out))
        # [batch, 256, 13, 13] -> [batch, 384, 13, 13]
        out = self.relu(self.conv4(out))
        # [batch, 384, 13, 13] -> [batch, 384, 13, 13]
        out = self.relu(self.conv5(out))
        # [batch, 384, 13, 13] -> [batch, 256, 13, 13]
        out = self.pool(out)
        # [batch, 256, 13, 13] -> [batch, 256, 6, 6]
        out = out.reshape(-1, 256*6*6)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        out = self.fc3(out)
        return out


def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)

In [7]:
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 256
lr = 0.1
num_epoch = 10
LR = 1  # 这个学习率是My_LeNet.py中使用的

# softmax 超参数
num_inpus = 1*28*28
num_outputs = 10

# 多层感知机参数据
input_size, hidden_size, output_size = 784, 256, 10

# Dropout
p = 0.2

In [4]:
trans = transforms.Compose([transforms.Resize((227, 227)),
                            transforms.ToTensor()])
mnist_train = torchvision.datasets.FashionMNIST(root='../data', train=True,
                                                transform=trans,
                                                download=True)
mnist_test = torchvision.datasets.FashionMNIST(root='../data', train=False,
                                                transform=trans,
                                                download=True)

train_iter = DataLoader(mnist_train, batch_size, shuffle=True)
test_iter = DataLoader(mnist_test, batch_size, shuffle=False)
Model = Net().to(device)
Model.apply(init_weights)  # 不用这个初始化的话,效果很差
optimizer = optim.SGD(Model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss().to(device)
step = 0

for epoch in range(num_epoch):
    Model.train()
    train_accuracy_list = []
    for batch_x, batch_y in tqdm(train_iter):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        out = Model(batch_x)
        loss = criterion(out, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        y_pred = torch.argmax(out, dim=1)
        accuracy = accuracy_score(batch_y.detach().cpu().numpy(), y_pred.detach().cpu().numpy())
        train_accuracy_list.append(accuracy)

        step += 1

    train_accuracy_list = np.array(train_accuracy_list)
    avg_acc = train_accuracy_list.mean()
    tqdm.write(f'epoch:{epoch + 1}, train_accuracy:{avg_acc}')

    with torch.no_grad():
        Model.eval()
        test_accuracy_list = []
        for batch_x, batch_y in test_iter:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            out = Model(batch_x)
            y_pred = torch.argmax(out, dim=1)
            accuracy = accuracy_score(batch_y.detach().cpu().numpy(), y_pred.detach().cpu().numpy())
            test_accuracy_list.append(accuracy)

        test_accuracy_list = np.array(test_accuracy_list)
        test_acc = np.mean(test_accuracy_list)
        tqdm.write(f'epoch:{epoch + 1},test_accuracy{test_acc}')



Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw



100%|██████████| 235/235 [01:26<00:00,  2.71it/s]


epoch:1, train_accuracy:0.5260139627659575
epoch:1,test_accuracy0.7705078125


100%|██████████| 235/235 [01:27<00:00,  2.70it/s]


epoch:2, train_accuracy:0.8189273049645391
epoch:2,test_accuracy0.824609375


100%|██████████| 235/235 [01:27<00:00,  2.70it/s]


epoch:3, train_accuracy:0.8632480053191489
epoch:3,test_accuracy0.82783203125


100%|██████████| 235/235 [01:26<00:00,  2.71it/s]


epoch:4, train_accuracy:0.882313829787234
epoch:4,test_accuracy0.874609375


100%|██████████| 235/235 [01:27<00:00,  2.70it/s]


epoch:5, train_accuracy:0.8931626773049646
epoch:5,test_accuracy0.884765625


100%|██████████| 235/235 [01:26<00:00,  2.70it/s]


epoch:6, train_accuracy:0.9019946808510638
epoch:6,test_accuracy0.86376953125


100%|██████████| 235/235 [01:26<00:00,  2.70it/s]


epoch:7, train_accuracy:0.9114860372340425
epoch:7,test_accuracy0.89208984375


100%|██████████| 235/235 [01:26<00:00,  2.70it/s]


epoch:8, train_accuracy:0.9179909131205672
epoch:8,test_accuracy0.889453125


100%|██████████| 235/235 [01:27<00:00,  2.70it/s]


epoch:9, train_accuracy:0.9242519946808511
epoch:9,test_accuracy0.90556640625


100%|██████████| 235/235 [01:27<00:00,  2.70it/s]


epoch:10, train_accuracy:0.9326241134751774
epoch:10,test_accuracy0.8923828125
