In [1]:
import math
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm_notebook as tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.color_palette("bright")
import matplotlib as mpl
import matplotlib.cm as cm

import torch
from torch import Tensor
from torch import nn
from torch.nn  import functional as F 
from torch.autograd import Variable
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = torch.cuda.is_available()

In [2]:
def save_checkpoint(epoch, model, optimizer):
    """
    Save model checkpoint.
    :param epoch: epoch number
    :param model: model
    :param optimizer: optimizer
    """
    state = {'epoch': epoch,
             'model': model,
             'optimizer': optimizer}
    filename = 'checkpoint_neuralode_mixup.pth.tar'
    torch.save(state, filename)

class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [3]:
# 数据集读取

epochs = 200
pre_epoch = 0
BATCH_SIZE = 128
LR = 0.01

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

train_dataset = torchvision.datasets.CIFAR10(root='data', train=True, download=False, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

test_dataset = torchvision.datasets.CIFAR10(root='data', train=False, download=False, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [4]:
def ode_solve(z0, t0, t1, f):
    """
    Simplest Euler ODE initial value solver
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())  #向上取整

    h = (t1 - t0)/n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z

In [5]:
class ODEF(nn.Module):
    def forward_with_grad(self, z, t, grad_outputs):
        """Compute f and a df/dz, a df/dp, a df/dt"""
        batch_size = z.shape[0]  #矩阵的列数

        out = self.forward(z, t)

        a = grad_outputs
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),
            allow_unused=True, retain_graph=True
        )    #自动求导
        # grad method automatically sums gradients for batch items, we have to expand them back 
        if adfdp is not None:
            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)  #关于p的参数转化为一维数组并加一个参数
            adfdp = adfdp.expand(batch_size, -1) / batch_size   #加batch_size个行 再除以batch_size
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        p_shapes = []
        flat_parameters = []
        for p in self.parameters():
            p_shapes.append(p.size())
            flat_parameters.append(p.flatten())
        return torch.cat(flat_parameters)   #所有向量放到一个维度

In [6]:
class ODEAdjoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func):
        assert isinstance(func, ODEF)      #判断func是否在ODEF类
        bs, *z_shape = z0.size()   #bs是有几行(数据维度),z_shape有几列(数据数量)
        time_len = t.size(0)  #有几行

        with torch.no_grad():  #即使输入求导，输出不求导
            z = torch.zeros(time_len, bs, *z_shape).to(z0)   #三个维度, 时间*数据
            z[0] = z0
            for i_t in range(time_len - 1):
                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)
                z[i_t+1] = z0
        ctx.func = func
        ctx.save_for_backward(t, z.clone(), flat_parameters)  #为了反向传播保留input全部信息
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """
        dLdz shape: time_len, batch_size, *z_shape
        """
        func = ctx.func
        t, z, flat_parameters = ctx.saved_tensors   #被储存的参数
        time_len, bs, *z_shape = z.size()    #z的时间，维度，数据量
        n_dim = np.prod(z_shape)      #np内元素的乘积 
        n_params = flat_parameters.size(0)  #参数数量，size[0]是有几行
        # Dynamics of augmented system to be calculated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            """
            tensors here are temporal slices
            t_i - is tensor with size: bs, 1
            aug_z_i - is tensor with size: bs, 2*n_dim + n_params + 1
            """
            #aug_z_i为原方程中的增广状态[z(t1),dL/dz(t1),0(\theta),-dL/dt]
            #t_i的维度？ time_len?
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # ignore parameters and time  
            #z_i是z(t1)，a是dL/dz
            
            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)   #括号里面是想要的维度,torch.view()的作用
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                #detach_()将计算图中节点转为叶子节点，也就是将节点.grad_fn设置为none，这样detach_()的前一个节点就不会再与当前变量连接
                t_i = t_i.detach().requires_grad_(True) 
                z_i = z_i.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape
                #func_eval是输出, 其余为a乘以导数
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)  #torch.to()
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim) 
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)  #输出新的增广状态

        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz for convenience

        with torch.no_grad():
            ## Create placeholders for output gradients
            # Prev computed backwards adjoints to be adjusted by direct gradients
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            for i_t in range(time_len-1, 0, -1):     #反向传播
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Compute direct gradients
                dLdz_i = dLdz[i_t]
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0] 
                #bmm三维tensor乘积 第一个维度一样 后两个维度矩阵乘积
                #torch.transpose(a,1,2)是交换a第2和第3维度
                #torch.unsqueeze(-1)是在最后一个维度上加一个
                #dL/dt = dL/dz*dz/dt 在对应的时间节点求

                # Adjusting adjoints with direct gradients
                adj_z += dLdz_i    #z的伴随就是dL/dz  根据节点调整 图上的
                adj_t[i_t] = adj_t[i_t] - dLdt_i   

                # Pack augmented variable
                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1) #按照bs拼接

                # Solve augmented system backwards
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)
                
                # Unpack solved backwards augmented system
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]   #不需要调整，一直加
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]
                #得出来的结果分配到新的增广状态

                del aug_z, aug_ans

            ## Adjust 0 time adjoint with direct gradients
            # Compute direct gradients 

            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

            # Adjust adjoints
            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None    #得到了到t0的增广状态

In [7]:
class NeuralODE(nn.Module):
    def __init__(self, func):
        super(NeuralODE, self).__init__()  #调用父类
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False):
        t = t.to(z0)
        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)   #只有前向传播
        if return_whole_sequence:
            return z
        else:
            return z[-1]

In [8]:
def norm(dim):
    return nn.BatchNorm2d(dim)  #norm为batchnormalization函数

def conv3x3(in_feats, out_feats, stride=1):
    return nn.Conv2d(in_feats, out_feats, kernel_size=3, stride=stride, padding=1, bias=False) #二维数据的卷积操作

def add_time(in_tensor, t):
    bs, c, w, h = in_tensor.shape
    return torch.cat((in_tensor, t.expand(bs, 1, w, h)), dim=1)  #增加时间变量

In [9]:
class ConvODEF(ODEF):
    def __init__(self, dim):
        super(ConvODEF, self).__init__()
        self.conv1 = conv3x3(dim + 1, dim)
        self.norm1 = norm(dim)
        self.conv2 = conv3x3(dim + 1, dim)
        self.norm2 = norm(dim)

    def forward(self, x, t):
        xt = add_time(x, t)
        h = self.norm1(torch.relu(self.conv1(xt)))
        ht = add_time(h, t)
        dxdt = self.norm2(torch.relu(self.conv2(ht)))
        return dxdt

In [10]:
class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [11]:
class odenet(nn.Module):
    def __init__(self, ResidualBlock, ConvODEF, num_classes=10):
        super(odenet, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        func = ConvODEF(64)
        self.layer1 = NeuralODE(func)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)   #strides=[1,1]
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)    #核为4，步长为4
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def odeNet18():

    return odenet(ResidualBlock, ConvODEF)

In [12]:
def mixup_data(x, y, alpha=1.0, use_cuda=True):

    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index,:]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(y_a, y_b, lam):
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [None]:
checkpoint = 'checkpoint_neuralode_mixup.pth.tar'
print_freq = 20

def main():
    """
    Training.
    """
    
    global start_epoch, classes, epoch, checkpoint
    
    # 初始化模型
    
    if checkpoint is None:
        start_epoch = 0
        model = odeNet18()
        optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    else:
        checkpoint = torch.load(checkpoint, map_location = 'cuda')
        start_epoch = checkpoint['epoch'] + 1
        print('\nLoaded checkpoint from epoch %d.\n' % start_epoch)
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']
    
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(start_epoch, epochs):
        
        if epoch == 100:
            optimizer = optim.SGD(model.parameters(), lr=LR*0.1, momentum=0.9, weight_decay=5e-4)
        if epoch == 150:
            optimizer = optim.SGD(model.parameters(), lr=LR*0.01, momentum=0.9, weight_decay=5e-4)
        
        train(train_loader = train_loader,
             model = model,
             criterion=criterion,
             optimizer=optimizer,
             epoch=epoch)
        save_checkpoint(epoch, model, optimizer)
        evaluate(test_loader, model)
        

def train(train_loader, model, criterion, optimizer, epoch):
    
    model = model.train()
    
    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    start = time.time()
    total = 0
    correct = 0
    alpha = 1.0
    for i, data in enumerate(train_loader, 0):
        length = len(train_loader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, alpha, use_cuda)
        optimizer.zero_grad()

        # forward + backward
        outputs = model(inputs)
        loss_func = mixup_criterion(labels_a, labels_b, lam)
        loss = loss_func(criterion, outputs)
        loss.backward()
        optimizer.step()
        
        losses.update(loss.item(), inputs.size(0))
        batch_time.update(time.time() - start)
        start = time.time()
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i, len(train_loader),
                                                    batch_time=batch_time,
                                                    data_time=data_time, loss=losses))
            f1 = open("train_odenet18_mixup.txt", "a")
            f1.write('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t\n'.format(epoch, i, len(train_loader),
                                                    batch_time=batch_time,
                                                    data_time=data_time, loss=losses))
            f1.close()
    print(correct/total)
    f2 = open("train_acc_odenet18_mixup.txt", "a")
    f2.write('odenet_cifar10_mixup测试分类准确率为：%.3f%%\n' % (100 * correct / total))
    f2.close()
def evaluate(test_loader, model):
    
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for i, data in enumerate(test_loader, 0):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # 取得分最高的那个类 (outputs.data的索引号)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print('测试分类准确率为：%.3f%%' % (100 * correct / total))
        f3 = open("test_acc_odenet18_mixup.txt", "a")
        f3.write('odenet_mixup测试分类准确率为：%.3f%%\n' % (100 * correct / total))
        f3.close()
        acc = 100. * correct / total
        best_acc = 85
        if acc > best_acc:
            f3 = open("best_acc_odenet18_mixup.txt", "w")
            f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1, acc))
            f3.close()
            best_acc = acc
if __name__ == '__main__':
    main()


Loaded checkpoint from epoch 96.

Epoch: [96][0/391]	Batch Time 7.484 (7.484)	Data Time 0.000 (0.000)	Loss 0.9755 (0.9755)	
Epoch: [96][20/391]	Batch Time 1.250 (1.544)	Data Time 0.000 (0.000)	Loss 1.2330 (0.7824)	
Epoch: [96][40/391]	Batch Time 1.250 (1.399)	Data Time 0.000 (0.000)	Loss 0.4217 (0.7986)	
Epoch: [96][60/391]	Batch Time 1.250 (1.350)	Data Time 0.000 (0.000)	Loss 0.8412 (0.8099)	
Epoch: [96][80/391]	Batch Time 1.250 (1.325)	Data Time 0.000 (0.000)	Loss 0.8991 (0.8331)	
Epoch: [96][100/391]	Batch Time 1.250 (1.310)	Data Time 0.000 (0.000)	Loss 0.5895 (0.8377)	
Epoch: [96][120/391]	Batch Time 1.250 (1.301)	Data Time 0.000 (0.000)	Loss 1.2165 (0.8438)	
Epoch: [96][140/391]	Batch Time 1.265 (1.294)	Data Time 0.000 (0.000)	Loss 0.8284 (0.8576)	
Epoch: [96][160/391]	Batch Time 1.250 (1.290)	Data Time 0.000 (0.000)	Loss 0.7445 (0.8453)	
Epoch: [96][180/391]	Batch Time 1.250 (1.285)	Data Time 0.000 (0.000)	Loss 1.2367 (0.8492)	
Epoch: [96][200/391]	Batch Time 1.250 (1.283)	Data 

Epoch: [100][160/391]	Batch Time 1.276 (1.278)	Data Time 0.000 (0.000)	Loss 1.1776 (0.8279)	
Epoch: [100][180/391]	Batch Time 1.265 (1.276)	Data Time 0.000 (0.000)	Loss 1.1710 (0.8412)	
Epoch: [100][200/391]	Batch Time 1.265 (1.275)	Data Time 0.000 (0.000)	Loss 1.1721 (0.8364)	
Epoch: [100][220/391]	Batch Time 1.265 (1.274)	Data Time 0.000 (0.000)	Loss 1.1425 (0.8273)	
Epoch: [100][240/391]	Batch Time 1.265 (1.273)	Data Time 0.000 (0.000)	Loss 1.0361 (0.8264)	
Epoch: [100][260/391]	Batch Time 1.265 (1.273)	Data Time 0.000 (0.000)	Loss 0.9151 (0.8281)	
Epoch: [100][280/391]	Batch Time 1.281 (1.272)	Data Time 0.000 (0.000)	Loss 0.7373 (0.8212)	
Epoch: [100][300/391]	Batch Time 1.250 (1.272)	Data Time 0.000 (0.000)	Loss 1.1272 (0.8218)	
Epoch: [100][320/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 0.8594 (0.8231)	
Epoch: [100][340/391]	Batch Time 1.281 (1.271)	Data Time 0.000 (0.000)	Loss 0.9923 (0.8286)	
Epoch: [100][360/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)

Epoch: [104][300/391]	Batch Time 1.267 (1.272)	Data Time 0.000 (0.000)	Loss 0.8773 (0.8090)	
Epoch: [104][320/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.2589 (0.8063)	
Epoch: [104][340/391]	Batch Time 1.281 (1.271)	Data Time 0.000 (0.000)	Loss 1.1714 (0.8050)	
Epoch: [104][360/391]	Batch Time 1.250 (1.271)	Data Time 0.000 (0.000)	Loss 1.1185 (0.8099)	
Epoch: [104][380/391]	Batch Time 1.250 (1.271)	Data Time 0.000 (0.000)	Loss 1.1850 (0.8054)	
tensor(0.5473)
测试分类准确率为：94.330%
Epoch: [105][0/391]	Batch Time 3.803 (3.803)	Data Time 0.000 (0.000)	Loss 0.7230 (0.7230)	
Epoch: [105][20/391]	Batch Time 1.265 (1.383)	Data Time 0.000 (0.000)	Loss 0.6139 (0.7288)	
Epoch: [105][40/391]	Batch Time 1.265 (1.325)	Data Time 0.000 (0.000)	Loss 0.9682 (0.8026)	
Epoch: [105][60/391]	Batch Time 1.265 (1.305)	Data Time 0.000 (0.000)	Loss 0.4187 (0.7912)	
Epoch: [105][80/391]	Batch Time 1.265 (1.295)	Data Time 0.000 (0.000)	Loss 0.9518 (0.8030)	
Epoch: [105][100/391]	Batch Time 1.265 (1.28

Epoch: [109][40/391]	Batch Time 1.250 (1.321)	Data Time 0.000 (0.000)	Loss 0.3810 (0.8304)	
Epoch: [109][60/391]	Batch Time 1.265 (1.302)	Data Time 0.000 (0.000)	Loss 1.0063 (0.7917)	
Epoch: [109][80/391]	Batch Time 1.269 (1.291)	Data Time 0.000 (0.000)	Loss 1.1519 (0.7794)	
Epoch: [109][100/391]	Batch Time 1.265 (1.285)	Data Time 0.000 (0.000)	Loss 0.9301 (0.7973)	
Epoch: [109][120/391]	Batch Time 1.265 (1.281)	Data Time 0.000 (0.000)	Loss 0.9209 (0.7933)	
Epoch: [109][140/391]	Batch Time 1.265 (1.278)	Data Time 0.000 (0.000)	Loss 1.1867 (0.7900)	
Epoch: [109][160/391]	Batch Time 1.265 (1.276)	Data Time 0.000 (0.000)	Loss 1.0913 (0.7881)	
Epoch: [109][180/391]	Batch Time 1.265 (1.274)	Data Time 0.000 (0.000)	Loss 1.0750 (0.8078)	
Epoch: [109][200/391]	Batch Time 1.250 (1.272)	Data Time 0.000 (0.000)	Loss 0.6744 (0.8058)	
Epoch: [109][220/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 0.1506 (0.7961)	
Epoch: [109][240/391]	Batch Time 1.250 (1.270)	Data Time 0.000 (0.000)	Lo

Epoch: [113][200/391]	Batch Time 1.269 (1.277)	Data Time 0.000 (0.000)	Loss 1.0880 (0.8155)	
Epoch: [113][220/391]	Batch Time 1.271 (1.276)	Data Time 0.000 (0.000)	Loss 0.8853 (0.8094)	
Epoch: [113][240/391]	Batch Time 1.253 (1.275)	Data Time 0.000 (0.000)	Loss 1.1698 (0.8150)	
Epoch: [113][260/391]	Batch Time 1.255 (1.274)	Data Time 0.000 (0.000)	Loss 1.1514 (0.8130)	
Epoch: [113][280/391]	Batch Time 1.271 (1.273)	Data Time 0.000 (0.000)	Loss 0.5474 (0.8116)	
Epoch: [113][300/391]	Batch Time 1.255 (1.272)	Data Time 0.000 (0.000)	Loss 0.6017 (0.8157)	
Epoch: [113][320/391]	Batch Time 1.270 (1.272)	Data Time 0.000 (0.000)	Loss 0.8578 (0.8129)	
Epoch: [113][340/391]	Batch Time 1.271 (1.271)	Data Time 0.000 (0.000)	Loss 1.0659 (0.8140)	
Epoch: [113][360/391]	Batch Time 1.254 (1.271)	Data Time 0.000 (0.000)	Loss 1.1186 (0.8155)	
Epoch: [113][380/391]	Batch Time 1.271 (1.270)	Data Time 0.000 (0.000)	Loss 0.2162 (0.8114)	
tensor(0.5272)
测试分类准确率为：94.810%
Epoch: [114][0/391]	Batch Time 3.894 (

Epoch: [117][340/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 1.0768 (0.7825)	
Epoch: [117][360/391]	Batch Time 1.250 (1.269)	Data Time 0.000 (0.000)	Loss 0.9252 (0.7866)	
Epoch: [117][380/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 1.0937 (0.7895)	
tensor(0.5330)
测试分类准确率为：94.260%
Epoch: [118][0/391]	Batch Time 3.801 (3.801)	Data Time 0.000 (0.000)	Loss 0.1363 (0.1363)	
Epoch: [118][20/391]	Batch Time 1.250 (1.382)	Data Time 0.000 (0.000)	Loss 0.6138 (0.6359)	
Epoch: [118][40/391]	Batch Time 1.265 (1.323)	Data Time 0.000 (0.000)	Loss 1.0780 (0.7389)	
Epoch: [118][60/391]	Batch Time 1.265 (1.303)	Data Time 0.000 (0.000)	Loss 0.9445 (0.7584)	
Epoch: [118][80/391]	Batch Time 1.265 (1.293)	Data Time 0.000 (0.000)	Loss 1.1357 (0.7904)	
Epoch: [118][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 0.7026 (0.7772)	
Epoch: [118][120/391]	Batch Time 1.265 (1.283)	Data Time 0.000 (0.000)	Loss 0.1740 (0.7905)	
Epoch: [118][140/391]	Batch Time 1.250 (1.28

Epoch: [122][80/391]	Batch Time 1.250 (1.293)	Data Time 0.000 (0.000)	Loss 0.6091 (0.8302)	
Epoch: [122][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 0.5980 (0.8214)	
Epoch: [122][120/391]	Batch Time 1.254 (1.283)	Data Time 0.000 (0.000)	Loss 1.1133 (0.8039)	
Epoch: [122][140/391]	Batch Time 1.265 (1.280)	Data Time 0.000 (0.000)	Loss 0.8285 (0.8081)	
Epoch: [122][160/391]	Batch Time 1.265 (1.278)	Data Time 0.000 (0.000)	Loss 1.0546 (0.8111)	
Epoch: [122][180/391]	Batch Time 1.250 (1.276)	Data Time 0.000 (0.000)	Loss 1.1374 (0.8181)	
Epoch: [122][200/391]	Batch Time 1.265 (1.274)	Data Time 0.000 (0.000)	Loss 1.1105 (0.8161)	
Epoch: [122][220/391]	Batch Time 1.265 (1.273)	Data Time 0.000 (0.000)	Loss 1.0083 (0.8190)	
Epoch: [122][240/391]	Batch Time 1.250 (1.272)	Data Time 0.000 (0.000)	Loss 1.0489 (0.8147)	
Epoch: [122][260/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.2131 (0.8100)	
Epoch: [122][280/391]	Batch Time 1.250 (1.271)	Data Time 0.000 (0.000)	

Epoch: [126][220/391]	Batch Time 1.265 (1.273)	Data Time 0.000 (0.000)	Loss 0.9988 (0.7768)	
Epoch: [126][240/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 1.0342 (0.7812)	
Epoch: [126][260/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.9126 (0.7783)	
Epoch: [126][280/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 0.9616 (0.7788)	
Epoch: [126][300/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.4818 (0.7651)	
Epoch: [126][320/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.9201 (0.7687)	
Epoch: [126][340/391]	Batch Time 1.250 (1.269)	Data Time 0.000 (0.000)	Loss 0.3068 (0.7688)	
Epoch: [126][360/391]	Batch Time 1.254 (1.269)	Data Time 0.000 (0.000)	Loss 0.7178 (0.7652)	
Epoch: [126][380/391]	Batch Time 1.265 (1.268)	Data Time 0.000 (0.000)	Loss 1.1669 (0.7680)	
tensor(0.5609)
测试分类准确率为：94.530%
Epoch: [127][0/391]	Batch Time 3.844 (3.844)	Data Time 0.000 (0.000)	Loss 1.0966 (1.0966)	
Epoch: [127][20/391]	Batch Time 1.250 (1

Epoch: [130][360/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.4227 (0.7654)	
Epoch: [130][380/391]	Batch Time 1.256 (1.269)	Data Time 0.000 (0.000)	Loss 0.9779 (0.7704)	
tensor(0.5479)
测试分类准确率为：94.460%
Epoch: [131][0/391]	Batch Time 3.814 (3.814)	Data Time 0.000 (0.000)	Loss 0.9274 (0.9274)	
Epoch: [131][20/391]	Batch Time 1.266 (1.382)	Data Time 0.000 (0.000)	Loss 1.0593 (0.7318)	
Epoch: [131][40/391]	Batch Time 1.265 (1.323)	Data Time 0.000 (0.000)	Loss 0.9648 (0.7879)	
Epoch: [131][60/391]	Batch Time 1.265 (1.303)	Data Time 0.000 (0.000)	Loss 0.9875 (0.8040)	
Epoch: [131][80/391]	Batch Time 1.250 (1.293)	Data Time 0.000 (0.000)	Loss 0.7845 (0.7953)	
Epoch: [131][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 1.0084 (0.7841)	
Epoch: [131][120/391]	Batch Time 1.265 (1.283)	Data Time 0.000 (0.000)	Loss 1.1280 (0.7948)	
Epoch: [131][140/391]	Batch Time 1.265 (1.280)	Data Time 0.000 (0.000)	Loss 0.2823 (0.7811)	
Epoch: [131][160/391]	Batch Time 1.269 (1.27

Epoch: [135][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 0.9647 (0.8229)	
Epoch: [135][120/391]	Batch Time 1.250 (1.282)	Data Time 0.000 (0.000)	Loss 1.1329 (0.8384)	
Epoch: [135][140/391]	Batch Time 1.271 (1.280)	Data Time 0.000 (0.000)	Loss 0.4895 (0.8444)	
Epoch: [135][160/391]	Batch Time 1.250 (1.277)	Data Time 0.000 (0.000)	Loss 1.0326 (0.8386)	
Epoch: [135][180/391]	Batch Time 1.250 (1.276)	Data Time 0.000 (0.000)	Loss 0.9520 (0.8393)	
Epoch: [135][200/391]	Batch Time 1.265 (1.274)	Data Time 0.000 (0.000)	Loss 1.0542 (0.8343)	
Epoch: [135][220/391]	Batch Time 1.250 (1.273)	Data Time 0.000 (0.000)	Loss 1.0166 (0.8224)	
Epoch: [135][240/391]	Batch Time 1.250 (1.272)	Data Time 0.000 (0.000)	Loss 0.2993 (0.8225)	
Epoch: [135][260/391]	Batch Time 1.258 (1.272)	Data Time 0.000 (0.000)	Loss 0.3538 (0.8126)	
Epoch: [135][280/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 1.0909 (0.8080)	
Epoch: [135][300/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)

Epoch: [139][240/391]	Batch Time 1.265 (1.273)	Data Time 0.000 (0.000)	Loss 0.3796 (0.7790)	
Epoch: [139][260/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.6313 (0.7853)	
Epoch: [139][280/391]	Batch Time 1.270 (1.271)	Data Time 0.000 (0.000)	Loss 0.3088 (0.7691)	
Epoch: [139][300/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 1.1191 (0.7709)	
Epoch: [139][320/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 1.1700 (0.7801)	
Epoch: [139][340/391]	Batch Time 1.264 (1.270)	Data Time 0.000 (0.000)	Loss 0.0855 (0.7811)	
Epoch: [139][360/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.9193 (0.7801)	
Epoch: [139][380/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 1.1249 (0.7741)	
tensor(0.5567)
测试分类准确率为：94.510%
Epoch: [140][0/391]	Batch Time 3.797 (3.797)	Data Time 0.000 (0.000)	Loss 0.9628 (0.9628)	
Epoch: [140][20/391]	Batch Time 1.265 (1.382)	Data Time 0.000 (0.000)	Loss 0.6781 (0.8262)	
Epoch: [140][40/391]	Batch Time 1.265 (1.

Epoch: [143][380/391]	Batch Time 1.250 (1.268)	Data Time 0.000 (0.000)	Loss 0.6530 (0.7869)	
tensor(0.5432)
测试分类准确率为：94.760%
Epoch: [144][0/391]	Batch Time 3.815 (3.815)	Data Time 0.000 (0.000)	Loss 0.6133 (0.6133)	
Epoch: [144][20/391]	Batch Time 1.250 (1.382)	Data Time 0.000 (0.000)	Loss 1.0987 (0.7311)	
Epoch: [144][40/391]	Batch Time 1.250 (1.324)	Data Time 0.000 (0.000)	Loss 0.6371 (0.7641)	
Epoch: [144][60/391]	Batch Time 1.265 (1.303)	Data Time 0.000 (0.000)	Loss 0.4681 (0.7728)	
Epoch: [144][80/391]	Batch Time 1.265 (1.293)	Data Time 0.000 (0.000)	Loss 0.7380 (0.7798)	
Epoch: [144][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 0.2279 (0.7751)	
Epoch: [144][120/391]	Batch Time 1.265 (1.283)	Data Time 0.000 (0.000)	Loss 0.1569 (0.7837)	
Epoch: [144][140/391]	Batch Time 1.265 (1.280)	Data Time 0.000 (0.000)	Loss 0.3273 (0.7914)	
Epoch: [144][160/391]	Batch Time 1.265 (1.278)	Data Time 0.000 (0.000)	Loss 0.5841 (0.8030)	
Epoch: [144][180/391]	Batch Time 1.250 (1.27

Epoch: [148][120/391]	Batch Time 1.265 (1.283)	Data Time 0.000 (0.000)	Loss 0.4365 (0.7526)	
Epoch: [148][140/391]	Batch Time 1.265 (1.280)	Data Time 0.000 (0.000)	Loss 1.2139 (0.7477)	
Epoch: [148][160/391]	Batch Time 1.265 (1.278)	Data Time 0.000 (0.000)	Loss 0.4107 (0.7551)	
Epoch: [148][180/391]	Batch Time 1.250 (1.276)	Data Time 0.000 (0.000)	Loss 0.4318 (0.7598)	
Epoch: [148][200/391]	Batch Time 1.265 (1.275)	Data Time 0.000 (0.000)	Loss 1.1082 (0.7490)	
Epoch: [148][220/391]	Batch Time 1.250 (1.273)	Data Time 0.000 (0.000)	Loss 0.8265 (0.7528)	
Epoch: [148][240/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.1192 (0.7586)	
Epoch: [148][260/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.7020 (0.7654)	
Epoch: [148][280/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 0.9364 (0.7641)	
Epoch: [148][300/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.8581 (0.7619)	
Epoch: [148][320/391]	Batch Time 1.261 (1.270)	Data Time 0.000 (0.000)

Epoch: [152][260/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 1.0618 (0.7458)	
Epoch: [152][280/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 1.1906 (0.7423)	
Epoch: [152][300/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.5857 (0.7457)	
Epoch: [152][320/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.9381 (0.7518)	
Epoch: [152][340/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.1809 (0.7491)	
Epoch: [152][360/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.7453 (0.7480)	
Epoch: [152][380/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.5117 (0.7462)	
tensor(0.5444)
测试分类准确率为：94.790%
Epoch: [153][0/391]	Batch Time 3.797 (3.797)	Data Time 0.000 (0.000)	Loss 0.9223 (0.9223)	
Epoch: [153][20/391]	Batch Time 1.265 (1.383)	Data Time 0.000 (0.000)	Loss 0.4755 (0.8224)	
Epoch: [153][40/391]	Batch Time 1.265 (1.324)	Data Time 0.000 (0.000)	Loss 1.1207 (0.8302)	
Epoch: [153][60/391]	Batch Time 1.250 (1.3

tensor(0.5575)
测试分类准确率为：94.860%
Epoch: [157][0/391]	Batch Time 3.824 (3.824)	Data Time 0.000 (0.000)	Loss 1.1227 (1.1227)	
Epoch: [157][20/391]	Batch Time 1.265 (1.383)	Data Time 0.000 (0.000)	Loss 1.0206 (0.8335)	
Epoch: [157][40/391]	Batch Time 1.250 (1.323)	Data Time 0.000 (0.000)	Loss 0.8354 (0.8609)	
Epoch: [157][60/391]	Batch Time 1.265 (1.303)	Data Time 0.000 (0.000)	Loss 0.0706 (0.8292)	
Epoch: [157][80/391]	Batch Time 1.265 (1.293)	Data Time 0.000 (0.000)	Loss 0.5832 (0.8153)	
Epoch: [157][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 0.2250 (0.8008)	
Epoch: [157][120/391]	Batch Time 1.250 (1.283)	Data Time 0.000 (0.000)	Loss 0.4664 (0.7792)	
Epoch: [157][140/391]	Batch Time 1.265 (1.280)	Data Time 0.000 (0.000)	Loss 0.0430 (0.7861)	
Epoch: [157][160/391]	Batch Time 1.265 (1.277)	Data Time 0.000 (0.000)	Loss 1.0282 (0.7924)	
Epoch: [157][180/391]	Batch Time 1.265 (1.276)	Data Time 0.000 (0.000)	Loss 1.0426 (0.7877)	
Epoch: [157][200/391]	Batch Time 1.265 (1.27

Epoch: [161][140/391]	Batch Time 1.265 (1.280)	Data Time 0.000 (0.000)	Loss 0.9775 (0.7502)	
Epoch: [161][160/391]	Batch Time 1.265 (1.278)	Data Time 0.000 (0.000)	Loss 1.0484 (0.7519)	
Epoch: [161][180/391]	Batch Time 1.265 (1.276)	Data Time 0.000 (0.000)	Loss 1.1514 (0.7672)	
Epoch: [161][200/391]	Batch Time 1.265 (1.275)	Data Time 0.000 (0.000)	Loss 1.0610 (0.7622)	
Epoch: [161][220/391]	Batch Time 1.250 (1.273)	Data Time 0.000 (0.000)	Loss 0.8252 (0.7694)	
Epoch: [161][240/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 1.0084 (0.7726)	
Epoch: [161][260/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.8399 (0.7694)	
Epoch: [161][280/391]	Batch Time 1.250 (1.271)	Data Time 0.000 (0.000)	Loss 0.4462 (0.7707)	
Epoch: [161][300/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.7263 (0.7677)	
Epoch: [161][320/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 1.0362 (0.7716)	
Epoch: [161][340/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)

Epoch: [165][280/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 1.1397 (0.7403)	
Epoch: [165][300/391]	Batch Time 1.250 (1.271)	Data Time 0.000 (0.000)	Loss 0.8412 (0.7432)	
Epoch: [165][320/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.9863 (0.7493)	
Epoch: [165][340/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.9215 (0.7485)	
Epoch: [165][360/391]	Batch Time 1.254 (1.269)	Data Time 0.000 (0.000)	Loss 0.5049 (0.7515)	
Epoch: [165][380/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.4122 (0.7517)	
tensor(0.5705)
测试分类准确率为：94.920%
Epoch: [166][0/391]	Batch Time 3.806 (3.806)	Data Time 0.000 (0.000)	Loss 0.9696 (0.9696)	
Epoch: [166][20/391]	Batch Time 1.265 (1.383)	Data Time 0.000 (0.000)	Loss 0.7997 (0.8492)	
Epoch: [166][40/391]	Batch Time 1.250 (1.323)	Data Time 0.000 (0.000)	Loss 0.0940 (0.7617)	
Epoch: [166][60/391]	Batch Time 1.265 (1.303)	Data Time 0.000 (0.000)	Loss 0.7720 (0.7663)	
Epoch: [166][80/391]	Batch Time 1.265 (1.29

Epoch: [170][20/391]	Batch Time 1.260 (1.382)	Data Time 0.000 (0.000)	Loss 0.8040 (0.7456)	
Epoch: [170][40/391]	Batch Time 1.265 (1.324)	Data Time 0.000 (0.000)	Loss 1.0455 (0.7526)	
Epoch: [170][60/391]	Batch Time 1.265 (1.304)	Data Time 0.000 (0.000)	Loss 1.1686 (0.7772)	
Epoch: [170][80/391]	Batch Time 1.265 (1.293)	Data Time 0.000 (0.000)	Loss 1.1358 (0.7770)	
Epoch: [170][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 0.7346 (0.7728)	
Epoch: [170][120/391]	Batch Time 1.265 (1.283)	Data Time 0.000 (0.000)	Loss 0.4178 (0.7686)	
Epoch: [170][140/391]	Batch Time 1.263 (1.280)	Data Time 0.000 (0.000)	Loss 1.1431 (0.7631)	
Epoch: [170][160/391]	Batch Time 1.250 (1.278)	Data Time 0.000 (0.000)	Loss 0.8178 (0.7720)	
Epoch: [170][180/391]	Batch Time 1.265 (1.276)	Data Time 0.000 (0.000)	Loss 0.9996 (0.7668)	
Epoch: [170][200/391]	Batch Time 1.265 (1.275)	Data Time 0.000 (0.000)	Loss 1.0751 (0.7746)	
Epoch: [170][220/391]	Batch Time 1.250 (1.273)	Data Time 0.000 (0.000)	Los

Epoch: [174][180/391]	Batch Time 1.265 (1.276)	Data Time 0.000 (0.000)	Loss 0.1666 (0.7410)	
Epoch: [174][200/391]	Batch Time 1.265 (1.275)	Data Time 0.000 (0.000)	Loss 0.7712 (0.7403)	
Epoch: [174][220/391]	Batch Time 1.265 (1.274)	Data Time 0.000 (0.000)	Loss 0.1314 (0.7394)	
Epoch: [174][240/391]	Batch Time 1.265 (1.273)	Data Time 0.000 (0.000)	Loss 0.7879 (0.7479)	
Epoch: [174][260/391]	Batch Time 1.271 (1.272)	Data Time 0.000 (0.000)	Loss 1.1442 (0.7597)	
Epoch: [174][280/391]	Batch Time 1.250 (1.271)	Data Time 0.000 (0.000)	Loss 0.9903 (0.7587)	
Epoch: [174][300/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 0.9307 (0.7613)	
Epoch: [174][320/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.3860 (0.7595)	
Epoch: [174][340/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 0.1988 (0.7569)	
Epoch: [174][360/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.4511 (0.7602)	
Epoch: [174][380/391]	Batch Time 1.250 (1.269)	Data Time 0.000 (0.000)

Epoch: [178][320/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.3288 (0.7985)	
Epoch: [178][340/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.3014 (0.7994)	
Epoch: [178][360/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.6098 (0.8014)	
Epoch: [178][380/391]	Batch Time 1.265 (1.268)	Data Time 0.000 (0.000)	Loss 0.9684 (0.8014)	
tensor(0.5113)
测试分类准确率为：94.900%
Epoch: [179][0/391]	Batch Time 3.796 (3.796)	Data Time 0.000 (0.000)	Loss 0.8101 (0.8101)	
Epoch: [179][20/391]	Batch Time 1.250 (1.381)	Data Time 0.000 (0.000)	Loss 0.3446 (0.6129)	
Epoch: [179][40/391]	Batch Time 1.265 (1.323)	Data Time 0.000 (0.000)	Loss 0.6116 (0.6659)	
Epoch: [179][60/391]	Batch Time 1.262 (1.303)	Data Time 0.000 (0.000)	Loss 0.5580 (0.6821)	
Epoch: [179][80/391]	Batch Time 1.265 (1.293)	Data Time 0.000 (0.000)	Loss 1.1169 (0.7209)	
Epoch: [179][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 0.7279 (0.7387)	
Epoch: [179][120/391]	Batch Time 1.265 (1.28

Epoch: [183][60/391]	Batch Time 1.250 (1.303)	Data Time 0.000 (0.000)	Loss 0.5243 (0.7964)	
Epoch: [183][80/391]	Batch Time 1.250 (1.293)	Data Time 0.000 (0.000)	Loss 0.4772 (0.7809)	
Epoch: [183][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 0.2502 (0.7763)	
Epoch: [183][120/391]	Batch Time 1.265 (1.283)	Data Time 0.000 (0.000)	Loss 0.2555 (0.7578)	
Epoch: [183][140/391]	Batch Time 1.265 (1.280)	Data Time 0.000 (0.000)	Loss 0.1765 (0.7498)	
Epoch: [183][160/391]	Batch Time 1.250 (1.277)	Data Time 0.000 (0.000)	Loss 1.0618 (0.7440)	
Epoch: [183][180/391]	Batch Time 1.265 (1.276)	Data Time 0.000 (0.000)	Loss 1.1162 (0.7470)	
Epoch: [183][200/391]	Batch Time 1.265 (1.274)	Data Time 0.000 (0.000)	Loss 0.8600 (0.7536)	
Epoch: [183][220/391]	Batch Time 1.265 (1.273)	Data Time 0.000 (0.000)	Loss 0.2430 (0.7471)	
Epoch: [183][240/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.8868 (0.7498)	
Epoch: [183][260/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	L

Epoch: [187][200/391]	Batch Time 1.265 (1.274)	Data Time 0.000 (0.000)	Loss 1.0127 (0.7412)	
Epoch: [187][220/391]	Batch Time 1.265 (1.273)	Data Time 0.000 (0.000)	Loss 0.9980 (0.7475)	
Epoch: [187][240/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.9767 (0.7478)	
Epoch: [187][260/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.6506 (0.7460)	
Epoch: [187][280/391]	Batch Time 1.265 (1.271)	Data Time 0.000 (0.000)	Loss 0.7531 (0.7471)	
Epoch: [187][300/391]	Batch Time 1.250 (1.270)	Data Time 0.000 (0.000)	Loss 0.8949 (0.7511)	
Epoch: [187][320/391]	Batch Time 1.250 (1.270)	Data Time 0.000 (0.000)	Loss 0.4309 (0.7593)	
Epoch: [187][340/391]	Batch Time 1.250 (1.269)	Data Time 0.000 (0.000)	Loss 0.4997 (0.7536)	
Epoch: [187][360/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.9235 (0.7546)	
Epoch: [187][380/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 1.0169 (0.7553)	
tensor(0.5523)
测试分类准确率为：94.950%
Epoch: [188][0/391]	Batch Time 3.809 (

Epoch: [191][340/391]	Batch Time 1.265 (1.270)	Data Time 0.000 (0.000)	Loss 1.0302 (0.7528)	
Epoch: [191][360/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 0.8155 (0.7561)	
Epoch: [191][380/391]	Batch Time 1.265 (1.269)	Data Time 0.000 (0.000)	Loss 1.0260 (0.7535)	
tensor(0.5380)
测试分类准确率为：95.100%
Epoch: [192][0/391]	Batch Time 3.812 (3.812)	Data Time 0.000 (0.000)	Loss 1.1241 (1.1241)	
Epoch: [192][20/391]	Batch Time 1.265 (1.383)	Data Time 0.000 (0.000)	Loss 1.0030 (0.7765)	
Epoch: [192][40/391]	Batch Time 1.265 (1.323)	Data Time 0.000 (0.000)	Loss 1.0515 (0.7931)	
Epoch: [192][60/391]	Batch Time 1.250 (1.303)	Data Time 0.000 (0.000)	Loss 0.6126 (0.7776)	
Epoch: [192][80/391]	Batch Time 1.265 (1.293)	Data Time 0.000 (0.000)	Loss 0.9012 (0.7766)	
Epoch: [192][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 1.0161 (0.7871)	
Epoch: [192][120/391]	Batch Time 1.265 (1.283)	Data Time 0.000 (0.000)	Loss 0.4383 (0.7703)	
Epoch: [192][140/391]	Batch Time 1.265 (1.28

Epoch: [196][80/391]	Batch Time 1.254 (1.293)	Data Time 0.000 (0.000)	Loss 0.9234 (0.8060)	
Epoch: [196][100/391]	Batch Time 1.265 (1.287)	Data Time 0.000 (0.000)	Loss 1.0388 (0.7995)	
Epoch: [196][120/391]	Batch Time 1.265 (1.283)	Data Time 0.000 (0.000)	Loss 0.8389 (0.7929)	
Epoch: [196][140/391]	Batch Time 1.265 (1.280)	Data Time 0.000 (0.000)	Loss 1.1451 (0.7808)	
Epoch: [196][160/391]	Batch Time 1.265 (1.278)	Data Time 0.000 (0.000)	Loss 0.8293 (0.7813)	
Epoch: [196][180/391]	Batch Time 1.265 (1.276)	Data Time 0.000 (0.000)	Loss 0.6185 (0.7763)	
Epoch: [196][200/391]	Batch Time 1.265 (1.274)	Data Time 0.000 (0.000)	Loss 0.8886 (0.7556)	
Epoch: [196][220/391]	Batch Time 1.261 (1.273)	Data Time 0.000 (0.000)	Loss 0.9993 (0.7531)	
Epoch: [196][240/391]	Batch Time 1.250 (1.272)	Data Time 0.000 (0.000)	Loss 0.4781 (0.7462)	
Epoch: [196][260/391]	Batch Time 1.265 (1.272)	Data Time 0.000 (0.000)	Loss 0.9864 (0.7476)	
Epoch: [196][280/391]	Batch Time 1.250 (1.271)	Data Time 0.000 (0.000)	

In [15]:
checkpoint = 'checkpoint_neuralode_mixup.pth.tar'
checkpoint = torch.load(checkpoint, map_location = 'cuda')
model = checkpoint['model']
optimizer = checkpoint['optimizer']
evaluate(test_loader, model)

测试分类准确率为：91.980%
