# In [1]:
import torch 
import torchvision
import torch.optim as optim
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
import torch.nn.functional as F
#import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

import wandb

# Credentials must not live in source code. Authenticate with `wandb login`
# or export WANDB_API_KEY in the environment before launching this script.
# NOTE(review): an API key was previously committed at this spot; it is exposed
# and should be revoked in the wandb account settings.
wandb.init(project="mnist", entity="jaspertan")

import torch
import os

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Convert images to float tensors, then standardize them.
# Normalize(mean, std): 0.1307 is the dataset mean and 0.3081 its standard
# deviation (not the variance).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
data_train = torchvision.datasets.MNIST(root="./work/mnist/",
                                        transform=transform,
                                        train=True,
                                        download=True)
data_test = torchvision.datasets.MNIST(root="./work/mnist/",
                                       transform=transform,
                                       train=False,
                                       download=True)

import matplotlib.pyplot as plt
import numpy as np

# 展示图像的函数

def imshow(img, mean=0.1307, std=0.3081):
    """Undo the dataset normalization and draw ``img`` in a new matplotlib figure.

    Args:
        img: 3-D CPU tensor laid out as (channels, height, width), e.g. the
            output of ``torchvision.utils.make_grid``.
        mean: mean used by the ``Normalize`` transform (defaults to the value
            used for the MNIST datasets in this file).
        std: standard deviation used by the ``Normalize`` transform (same
            default source as ``mean``).
    """
    img = img * std + mean  # invert Normalize: x_norm * std + mean
    print(img.shape)
    npimg = img.numpy()
    plt.figure(figsize=(20, 20))
    # matplotlib expects (height, width, channels); torch stores (channels, height, width).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Preview a few random training images together with their labels.
# The shared `trainloader` / `batchsize` are created later in this file, so a
# small dedicated loader is built here instead of referencing them too early.
num_preview = 8
preview_loader = torch.utils.data.DataLoader(data_train, batch_size=num_preview,
                                             shuffle=True)
dataiter = iter(preview_loader)
# DataLoader iterators are advanced with the builtin next(), not `.next()`.
images, labels = next(dataiter)

# Show the images as a single grid.
print(images.shape)
imshow(torchvision.utils.make_grid(images))
# Print the label of every image in the grid (same count as shown).
print(' '.join('%5s' % label.item() for label in labels))

class Net(nn.Module):
    """LeNet-5-style CNN: two conv/pool stages followed by three linear layers.

    Expects single-channel 28x28 inputs shaped [batchsize, 1, 28, 28] and
    returns raw class scores (logits) shaped [batchsize, 10].
    """

    def __init__(self):
        super().__init__()
        # Attribute names determine the state_dict keys used for checkpoints.
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # x: [batchsize, 1, 28, 28]
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # [batchsize, 6, 14, 14]
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # [batchsize, 16, 5, 5]
        # Flatten each sample while keeping the batch dimension. Unlike
        # view(-1, 400), an input of the wrong spatial size now fails in fc1
        # instead of silently producing a different batch size.
        x = torch.flatten(x, 1)                     # [batchsize, 16*5*5]
        x = F.relu(self.fc1(x))                     # [batchsize, 120]
        x = F.relu(self.fc2(x))                     # [batchsize, 84]
        return self.fc3(x)                          # [batchsize, 10]


net = Net()  # separate instance, used here only to print the architecture
print(net)

# The network that is actually trained, placed on the selected device.
model = Net()
model = model.to(device)
# Adam with its default hyper-parameters.
optimizer = optim.Adam(params=model.parameters())

def train(model, device, loader, optimizer, epoch, log_interval=100):
    """Train ``model`` for one epoch over ``loader``.

    Every ``log_interval`` batches (starting with the first), the current
    batch's loss and accuracy are printed and logged to wandb, and the model
    is evaluated on the global ``testloader`` via ``test``.

    Args:
        model: network to optimize; its outputs are scored with CrossEntropyLoss.
        device: device each batch is moved to.
        loader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        log_interval: number of batches between logging/evaluation steps.

    Raises:
        ValueError: if ``log_interval`` is not a positive integer.
    """
    if log_interval <= 0:
        raise ValueError("log_interval must be positive, got %r" % (log_interval,))

    criterion = nn.CrossEntropyLoss()
    model.train()
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        if i % log_interval == 0:
            # Accuracy of this batch's (pre-update) predictions.
            predict = output.max(1).indices
            correct = torch.sum(predict == labels).item()
            batch_acc = correct / len(labels)
            wandb.log({"Train Loss": loss.item(), "Train Acc": batch_acc})
            print('[%d, %5d] loss: %.3f Acc:%.3f' %
                  (epoch + 1, i + 1, loss.item(), batch_acc))
            test(model, device, testloader, epoch)
            # `test` puts the model in eval mode; switch back so the remaining
            # batches of the epoch train in training mode (matters for
            # dropout / batch-norm layers).
            model.train()

def test(model, device, loader, epoch):
    """Evaluate ``model`` on every batch of ``loader`` and report the results.

    Prints the mean per-sample cross-entropy loss and the accuracy, and logs
    them to wandb as "Test Loss" / "Test Acc". The model's previous
    train/eval mode is restored afterwards, so calling this mid-training does
    not leave the model in eval mode.

    Args:
        model: network to evaluate.
        device: device each batch is moved to.
        loader: iterable of ``(inputs, labels)`` batches.
        epoch: unused; kept so the signature mirrors ``train``.

    Returns:
        ``(avg_loss, accuracy)`` over all evaluated samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    # Sum (not mean) per batch, so dividing by the sample count gives an exact
    # per-sample average even when the last batch is smaller.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    total_loss = 0.0
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.to(device)
                output = model(inputs)
                total_loss += criterion(output, labels).item()
                predict = output.max(1).indices
                correct += torch.sum(predict == labels).item()
                total += labels.size(0)
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("test(): loader yielded no samples")
    avg_loss = total_loss / total
    accuracy = correct / total
    print('Avg Loss : %.3f , Accuracy : %.3f [%d/%d] \n' % (avg_loss, accuracy, correct, total) )
    wandb.log({"Test Loss": avg_loss, "Test Acc": accuracy})
    return avg_loss, accuracy

# Record the hyper-parameters in the wandb run config.
# Reuse the run started earlier in this file when one is active; a second
# wandb.init may otherwise finish it and start a new run, depending on the
# wandb version and environment.
if wandb.run is None:
    wandb.init(project="mnist")
wandb.watch(model, log="all")
config = wandb.config
config.batch_size = 64         # batch size for training
config.test_batch_size = 1000  # batch size for evaluation
config.epochs = 6              # number of training epochs

batchsize = config.batch_size
trainloader = torch.utils.data.DataLoader(data_train, batch_size=batchsize,
                                          shuffle=True, num_workers=2)
# Evaluation computes no gradients, so it uses the larger configured test batch size.
testloader = torch.utils.data.DataLoader(data_test, batch_size=config.test_batch_size,
                                         shuffle=False, num_workers=2)

# Train for the configured number of epochs (train() also runs periodic evaluation).
for epoch in range(config.epochs):
    train(model, device, trainloader, optimizer, epoch)

# Save the learned parameters and upload the file to the wandb run.
# NOTE: despite the ".h5" extension, torch.save writes PyTorch's own
# serialization format, not HDF5.
weights = model.state_dict()
torch.save(weights, "model.h5")
wandb.save('model.h5')