In [1]:
import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append('..')
import d2lzh_pytorch as d2l

# is_available must be *called*: the bare function object is always truthy,
# which would pick CUDA even on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### 定义残差块

In [16]:
class Residual(nn.Module):
    """
    Residual block: two 3x3 convolution + BatchNorm layers with a skip connection.

    Main path:  conv1 (3x3, padding 1, stride `stride`) -> BN -> ReLU -> conv2 (3x3, padding 1) -> BN
    Shortcut:   identity, or a 1x1 convolution with stride `stride` when `use_1x1conv` is True
    Output:     ReLU(main_path(X) + shortcut(X))

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block.
        use_1x1conv: project the shortcut with a 1x1 convolution so it can match
            a main path that changes the channel count and/or spatial size.
        stride: stride of conv1 and of the 1x1 shortcut conv (2 halves height and width).

    Raises:
        ValueError: if an identity shortcut is requested while the main path changes
            the channel count or uses stride != 1. The sum would otherwise fail at
            forward time, or silently broadcast when in_channels == 1.
    """
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        if not use_1x1conv and (in_channels != out_channels or stride != 1):
            raise ValueError(
                f'identity shortcut requires in_channels == out_channels and stride == 1 '
                f'(got in_channels={in_channels}, out_channels={out_channels}, stride={stride}); '
                f'pass use_1x1conv=True to project the shortcut'
            )
        # Layers are created in the same order as before so parameter initialisation
        # consumes the RNG identically.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        """Return ReLU(bn2(conv2(ReLU(bn1(conv1(X))))) + shortcut(X))."""
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check: Module truthiness is not a reliable "is set" test
        # (e.g. an empty nn.Sequential defines __len__ == 0 and is falsy).
        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)

#### 定义resnet模块

In [17]:
def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    """
    Build one ResNet stage: an nn.Sequential of `num_residuals` Residual blocks.

    For every stage except the first, the leading Residual block uses a 1x1
    shortcut conv and stride 2. It maps in_channels -> out_channels and halves
    height/width. All remaining blocks map out_channels -> out_channels with an
    identity shortcut.

    Raises:
        AssertionError: if first_block is True and in_channels != out_channels
            (the check is an `assert`, so it is skipped under `python -O`).
    """
    if first_block:
        assert in_channels == out_channels
    # Blocks are constructed in index order, so parameter initialisation order is unchanged.
    layers = [
        Residual(in_channels, out_channels, use_1x1conv=True, stride=2)
        if idx == 0 and not first_block
        else Residual(out_channels, out_channels)
        for idx in range(num_residuals)
    ]
    return nn.Sequential(*layers)

#### 构造ResNet-18模型

In [24]:
# Stem: 7x7 conv (stride 2) -> BN -> ReLU -> 3x3 max-pool (stride 2), on single-channel input.
net = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# Four ResNet stages, two residual blocks each: (child name, in_channels, out_channels, first_block).
# Only the first stage keeps the resolution; the others change the channel count and halve H/W.
for stage_name, stage_in, stage_out, is_first in (
    ('resnet_block1', 64, 64, True),
    ('resnet_block2', 64, 128, False),
    ('resnet_block3', 128, 256, False),
    ('resnet_block4', 256, 512, False),
):
    net.add_module(stage_name, resnet_block(stage_in, stage_out, 2, first_block=is_first))
# Global average pooling -> output shape (batch_size, 512, 1, 1)
net.add_module('global_avg_pool', d2l.GlobalAvgPool2d())
# Classifier head: flatten, then a linear layer from 512 features to 10 classes.
net.add_module('fc', nn.Sequential(d2l.FlattenLayer(), nn.Linear(512, 10)))

- 构造数据观察一下形状变化

In [19]:
# Feed a dummy single-channel 224x224 image through each top-level child to print the shape changes.
# eval() + no_grad(): the probe must not update BatchNorm running statistics with random data
# or build an autograd graph. The previous train/eval mode is restored afterwards.
was_training = net.training
net.eval()
# Create the probe on the same device as the network, so the cell also works after training moved it to CUDA.
X = torch.rand((1, 1, 224, 224), device=next(net.parameters()).device)
with torch.no_grad():
    for name, layer in net.named_children():
        X = layer(X)
        print(name, 'output shape:\t', X.shape)
net.train(was_training)

0 output shape:	 torch.Size([1, 64, 112, 112])
1 output shape:	 torch.Size([1, 64, 112, 112])
2 output shape:	 torch.Size([1, 64, 112, 112])
3 output shape:	 torch.Size([1, 64, 56, 56])
resnet_block1 output shape:	 torch.Size([1, 64, 56, 56])
resnet_block2 output shape:	 torch.Size([1, 128, 28, 28])
resnet_block3 output shape:	 torch.Size([1, 256, 14, 14])
resnet_block4 output shape:	 torch.Size([1, 512, 7, 7])
global_avg_pool output shape:	 torch.Size([1, 512, 1, 1])
fc output shape:	 torch.Size([1, 10])


#### 训练模型

In [31]:
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print('training on ', device)
    loss = torch.nn.CrossEntropyLoss()
    batch_count = 0
    for epoch in range(1):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            print(y_hat.shape, y.shape)
            print(y)
            l = loss(y_hat, y)
            optimizer.zero_grad() # 梯度清零
            l.backward() # 计算梯度
            optimizer.step() # 迭代参数
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train_acc %.3f, test acc %.3f, %.1f sec' %
             (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

In [26]:
# Dataset root, relative to the notebook's working directory.
data_dir = './Datasets/FashionMNIST'
batch_size = 256
# Images are resized to 96x96 (presumably to keep training cheaper than at 224x224).
train_iter, test_iter = d2l.load_data_fashion_mnist(
    batch_size,
    resize=96,
    root=data_dir,
)

In [32]:
lr = 0.001
num_epochs = 5
# Adam over all network parameters; train_ch5 moves the net to `device` before training.
optimizer = optim.Adam(net.parameters(), lr=lr)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

training on  cuda
torch.Size([256, 10]) torch.Size([256])
tensor([7, 9, 5, 5, 5, 5, 5, 8, 3, 8, 9, 6, 4, 0, 1, 3, 9, 0, 1, 4, 2, 8, 1, 5,
        9, 6, 0, 5, 1, 7, 5, 2, 6, 0, 5, 6, 3, 7, 6, 1, 6, 5, 9, 3, 4, 6, 1, 0,
        1, 0, 7, 9, 9, 4, 6, 8, 3, 5, 6, 7, 9, 5, 8, 2, 5, 8, 8, 5, 6, 6, 7, 5,
        0, 9, 5, 7, 0, 0, 7, 9, 8, 4, 2, 6, 2, 4, 8, 6, 8, 6, 9, 0, 2, 4, 3, 3,
        8, 7, 6, 0, 8, 3, 3, 7, 4, 6, 7, 3, 3, 8, 9, 9, 8, 8, 2, 2, 4, 8, 6, 7,
        8, 5, 0, 4, 4, 9, 2, 3, 3, 2, 9, 1, 2, 5, 6, 2, 6, 5, 3, 0, 1, 7, 6, 2,
        3, 3, 6, 7, 5, 3, 8, 7, 7, 7, 4, 4, 6, 0, 6, 6, 2, 7, 5, 3, 9, 8, 3, 3,
        0, 6, 3, 1, 2, 2, 5, 4, 9, 7, 0, 1, 6, 5, 1, 1, 0, 1, 0, 9, 2, 9, 5, 5,
        4, 4, 1, 5, 8, 9, 5, 3, 7, 8, 9, 6, 2, 1, 4, 0, 6, 5, 7, 9, 7, 2, 6, 2,
        9, 3, 3, 6, 3, 0, 4, 6, 7, 6, 2, 4, 2, 7, 5, 5, 2, 9, 7, 8, 0, 9, 3, 2,
        1, 7, 2, 8, 4, 1, 0, 1, 5, 7, 3, 0, 0, 7, 1, 3], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([7, 9, 7, 0, 6, 3, 0,

torch.Size([256, 10]) torch.Size([256])
tensor([1, 2, 9, 7, 4, 3, 6, 3, 3, 0, 1, 9, 8, 3, 8, 2, 7, 0, 4, 7, 4, 9, 2, 3,
        8, 4, 2, 8, 1, 5, 7, 0, 7, 9, 0, 3, 0, 5, 7, 7, 1, 0, 4, 2, 1, 6, 4, 5,
        2, 5, 7, 4, 6, 3, 2, 2, 8, 4, 7, 9, 7, 8, 9, 1, 7, 4, 0, 7, 1, 8, 8, 1,
        6, 4, 6, 5, 4, 3, 2, 2, 4, 4, 8, 1, 6, 1, 9, 8, 4, 0, 0, 8, 5, 6, 5, 7,
        5, 8, 2, 2, 7, 6, 9, 3, 2, 0, 2, 5, 1, 6, 2, 9, 4, 2, 2, 9, 8, 3, 0, 7,
        4, 6, 4, 4, 0, 4, 1, 3, 8, 8, 0, 8, 2, 1, 8, 8, 6, 6, 4, 9, 2, 5, 5, 9,
        5, 7, 5, 8, 4, 0, 2, 3, 6, 3, 7, 4, 1, 0, 9, 5, 6, 5, 4, 1, 4, 3, 9, 5,
        9, 3, 3, 6, 3, 1, 8, 6, 7, 2, 3, 6, 2, 4, 6, 4, 9, 8, 5, 4, 7, 3, 8, 0,
        7, 6, 8, 1, 4, 7, 0, 1, 4, 2, 0, 2, 0, 1, 2, 0, 3, 9, 5, 2, 3, 8, 0, 7,
        2, 5, 9, 3, 4, 3, 8, 0, 7, 4, 2, 2, 0, 2, 1, 8, 1, 5, 3, 4, 7, 5, 5, 9,
        8, 1, 3, 6, 9, 6, 5, 2, 3, 0, 0, 1, 0, 4, 6, 1], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([4, 6, 0, 5, 3, 5, 3, 2, 2, 1, 1, 7, 8,

torch.Size([256, 10]) torch.Size([256])
tensor([9, 8, 2, 9, 6, 4, 5, 2, 0, 9, 9, 9, 4, 7, 9, 9, 5, 5, 4, 3, 8, 7, 5, 8,
        7, 9, 6, 0, 8, 5, 6, 5, 7, 9, 0, 7, 2, 4, 5, 0, 1, 1, 0, 0, 6, 1, 1, 8,
        4, 0, 9, 0, 6, 6, 9, 1, 6, 6, 2, 9, 0, 8, 1, 9, 8, 8, 3, 7, 1, 7, 4, 2,
        1, 2, 4, 8, 5, 1, 7, 1, 0, 5, 9, 8, 6, 2, 0, 2, 0, 8, 0, 7, 1, 0, 7, 6,
        5, 1, 8, 7, 5, 5, 3, 2, 1, 1, 0, 9, 9, 7, 9, 1, 7, 9, 1, 2, 1, 0, 5, 4,
        7, 2, 4, 0, 9, 0, 8, 6, 6, 3, 0, 3, 4, 0, 6, 5, 9, 2, 4, 8, 5, 5, 8, 0,
        8, 4, 2, 5, 6, 7, 5, 9, 9, 3, 3, 0, 6, 7, 0, 3, 3, 9, 6, 2, 0, 1, 7, 1,
        2, 3, 5, 1, 6, 7, 7, 3, 9, 8, 5, 3, 1, 0, 7, 5, 1, 0, 6, 2, 6, 3, 1, 5,
        2, 3, 8, 6, 0, 7, 4, 5, 9, 8, 6, 8, 7, 8, 2, 3, 9, 2, 7, 3, 6, 0, 9, 4,
        4, 1, 7, 6, 7, 2, 0, 6, 0, 8, 1, 1, 5, 3, 6, 8, 2, 6, 3, 5, 4, 1, 1, 7,
        6, 9, 7, 6, 3, 2, 5, 3, 4, 9, 1, 2, 0, 4, 5, 6], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([8, 7, 2, 1, 2, 0, 1, 1, 1, 6, 6, 7, 9,

torch.Size([256, 10]) torch.Size([256])
tensor([8, 5, 0, 1, 9, 7, 5, 8, 8, 4, 8, 1, 2, 5, 5, 4, 1, 8, 9, 2, 7, 1, 9, 6,
        3, 4, 4, 6, 0, 4, 5, 2, 6, 1, 2, 6, 4, 2, 8, 7, 3, 1, 0, 9, 1, 6, 8, 9,
        5, 5, 9, 9, 9, 7, 0, 6, 8, 7, 8, 0, 7, 5, 3, 9, 2, 8, 1, 5, 9, 6, 6, 1,
        4, 9, 8, 5, 7, 9, 7, 0, 8, 7, 1, 3, 1, 6, 9, 5, 1, 9, 4, 5, 8, 1, 7, 7,
        1, 8, 1, 2, 4, 5, 9, 2, 1, 2, 7, 0, 3, 5, 4, 8, 7, 2, 5, 8, 9, 8, 8, 8,
        3, 2, 5, 2, 4, 1, 7, 9, 2, 0, 7, 6, 4, 4, 6, 0, 4, 1, 4, 8, 4, 8, 4, 5,
        5, 5, 2, 7, 9, 3, 8, 0, 1, 5, 3, 2, 4, 5, 0, 2, 1, 1, 0, 8, 5, 4, 9, 7,
        8, 7, 9, 3, 4, 0, 9, 3, 6, 0, 6, 2, 7, 1, 8, 0, 3, 4, 9, 6, 8, 0, 7, 1,
        8, 0, 4, 7, 8, 5, 1, 8, 9, 4, 3, 1, 9, 6, 5, 9, 8, 9, 0, 1, 2, 5, 8, 5,
        5, 8, 8, 7, 7, 5, 1, 6, 2, 4, 0, 3, 5, 4, 1, 9, 5, 2, 6, 3, 9, 3, 4, 8,
        7, 1, 8, 4, 2, 4, 6, 1, 7, 5, 1, 9, 7, 2, 4, 2], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([0, 7, 3, 4, 8, 4, 0, 4, 4, 2, 2, 0, 0,

torch.Size([256, 10]) torch.Size([256])
tensor([6, 2, 4, 0, 2, 2, 9, 2, 6, 1, 7, 6, 2, 0, 9, 2, 0, 1, 6, 8, 5, 7, 1, 5,
        6, 3, 2, 8, 6, 7, 3, 1, 4, 5, 4, 0, 1, 7, 4, 6, 3, 5, 0, 9, 9, 0, 8, 1,
        0, 1, 9, 2, 8, 6, 5, 8, 2, 5, 7, 8, 0, 3, 5, 4, 1, 6, 9, 7, 6, 2, 4, 8,
        4, 1, 2, 3, 3, 1, 4, 9, 3, 7, 2, 9, 5, 9, 9, 2, 8, 4, 7, 7, 2, 9, 7, 4,
        5, 3, 0, 2, 8, 3, 7, 1, 0, 1, 1, 8, 7, 5, 8, 4, 8, 5, 8, 4, 8, 5, 2, 8,
        0, 2, 7, 4, 1, 0, 3, 5, 2, 6, 3, 9, 4, 8, 3, 8, 6, 2, 3, 8, 4, 9, 6, 2,
        0, 0, 6, 4, 0, 5, 6, 5, 3, 3, 7, 8, 3, 5, 4, 7, 8, 6, 1, 8, 3, 4, 1, 3,
        8, 6, 9, 3, 6, 4, 0, 9, 3, 4, 5, 2, 7, 9, 4, 8, 3, 7, 3, 7, 2, 6, 8, 5,
        2, 4, 3, 7, 7, 0, 5, 8, 0, 8, 8, 6, 6, 4, 4, 4, 3, 8, 3, 8, 8, 5, 6, 1,
        9, 1, 4, 8, 6, 3, 5, 2, 0, 6, 6, 6, 4, 4, 7, 1, 2, 1, 9, 3, 8, 6, 8, 0,
        7, 7, 8, 1, 0, 7, 9, 4, 5, 3, 7, 9, 5, 1, 2, 2], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([0, 5, 3, 3, 6, 6, 6, 9, 4, 2, 6, 1, 1,

torch.Size([256, 10]) torch.Size([256])
tensor([5, 1, 5, 0, 9, 9, 2, 2, 2, 9, 8, 5, 2, 2, 6, 1, 7, 5, 3, 0, 7, 5, 5, 1,
        7, 8, 0, 9, 4, 6, 7, 8, 3, 9, 5, 7, 7, 0, 6, 4, 0, 2, 4, 0, 7, 8, 0, 1,
        9, 4, 9, 7, 1, 6, 2, 0, 6, 4, 1, 5, 8, 3, 7, 6, 5, 0, 9, 3, 1, 8, 2, 3,
        2, 3, 4, 0, 0, 0, 4, 4, 2, 0, 9, 6, 9, 1, 0, 1, 8, 6, 9, 2, 5, 1, 2, 8,
        8, 9, 8, 6, 1, 4, 4, 1, 6, 0, 7, 3, 2, 8, 1, 8, 8, 7, 3, 3, 1, 6, 0, 7,
        3, 2, 0, 8, 9, 5, 1, 3, 7, 9, 3, 7, 6, 8, 2, 4, 5, 0, 5, 5, 4, 6, 0, 1,
        9, 1, 1, 1, 4, 2, 4, 3, 3, 0, 9, 3, 6, 5, 7, 8, 2, 5, 8, 6, 0, 8, 6, 0,
        0, 7, 2, 0, 4, 9, 7, 5, 5, 5, 2, 1, 7, 8, 3, 9, 6, 4, 3, 4, 5, 2, 8, 8,
        0, 3, 9, 2, 8, 8, 9, 2, 5, 9, 6, 7, 7, 0, 5, 6, 5, 9, 2, 0, 9, 3, 0, 6,
        0, 9, 0, 7, 5, 1, 4, 1, 2, 7, 4, 3, 3, 4, 8, 5, 1, 1, 5, 5, 1, 5, 5, 1,
        9, 0, 1, 8, 0, 9, 0, 0, 3, 9, 6, 2, 8, 5, 2, 2], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([8, 7, 7, 6, 8, 6, 2, 5, 1, 7, 2, 2, 4,

torch.Size([256, 10]) torch.Size([256])
tensor([5, 4, 0, 5, 9, 0, 0, 8, 3, 6, 7, 1, 0, 3, 3, 3, 4, 5, 9, 1, 5, 8, 5, 6,
        1, 1, 6, 9, 3, 5, 2, 9, 7, 9, 5, 5, 1, 8, 9, 2, 7, 6, 9, 7, 8, 7, 2, 7,
        1, 7, 6, 4, 8, 9, 8, 7, 8, 9, 3, 8, 5, 4, 0, 3, 0, 4, 6, 6, 2, 6, 8, 7,
        8, 6, 1, 8, 8, 9, 1, 8, 6, 2, 5, 0, 3, 1, 8, 0, 8, 4, 2, 1, 3, 0, 4, 0,
        6, 2, 5, 2, 8, 6, 6, 9, 6, 3, 9, 5, 0, 6, 7, 2, 5, 8, 4, 9, 4, 0, 2, 3,
        2, 9, 6, 6, 5, 4, 1, 6, 0, 2, 5, 6, 3, 3, 5, 0, 4, 8, 2, 7, 7, 8, 7, 5,
        6, 7, 4, 1, 4, 8, 4, 1, 4, 8, 5, 2, 6, 7, 6, 2, 7, 6, 7, 2, 1, 0, 5, 9,
        5, 6, 6, 2, 6, 0, 3, 9, 1, 9, 9, 0, 7, 3, 1, 8, 5, 3, 9, 6, 0, 4, 1, 8,
        6, 4, 5, 2, 3, 1, 5, 3, 7, 4, 4, 0, 3, 7, 5, 8, 0, 9, 3, 4, 6, 7, 9, 0,
        6, 6, 9, 5, 2, 7, 9, 6, 5, 0, 6, 6, 9, 5, 2, 2, 2, 8, 9, 7, 5, 9, 8, 8,
        3, 5, 4, 2, 8, 0, 1, 1, 2, 9, 5, 2, 7, 9, 9, 6], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([2, 1, 4, 4, 7, 9, 0, 2, 0, 0, 5, 5, 4,

torch.Size([256, 10]) torch.Size([256])
tensor([6, 0, 1, 0, 5, 3, 3, 4, 5, 5, 2, 7, 3, 3, 8, 3, 6, 9, 1, 3, 0, 0, 7, 9,
        7, 2, 4, 2, 5, 0, 6, 9, 8, 5, 9, 7, 6, 9, 7, 9, 6, 3, 5, 5, 0, 7, 2, 1,
        0, 5, 4, 9, 1, 3, 6, 7, 3, 7, 9, 2, 0, 0, 3, 1, 5, 0, 4, 1, 7, 9, 4, 5,
        7, 3, 9, 2, 8, 0, 3, 8, 7, 2, 7, 5, 5, 0, 2, 8, 7, 0, 5, 9, 5, 1, 3, 6,
        2, 8, 6, 9, 0, 1, 4, 8, 6, 1, 9, 8, 5, 2, 4, 3, 4, 4, 2, 9, 4, 3, 7, 9,
        0, 1, 0, 4, 3, 2, 8, 2, 9, 0, 3, 6, 1, 1, 2, 6, 9, 5, 0, 2, 3, 9, 7, 8,
        9, 0, 5, 0, 6, 2, 9, 4, 8, 4, 3, 4, 8, 2, 7, 7, 3, 0, 7, 0, 0, 6, 5, 3,
        0, 6, 6, 5, 6, 5, 7, 8, 3, 8, 3, 0, 7, 8, 7, 7, 8, 2, 1, 2, 0, 4, 0, 4,
        1, 4, 0, 3, 8, 6, 1, 8, 6, 7, 2, 2, 4, 1, 1, 5, 2, 5, 5, 1, 0, 8, 1, 4,
        2, 8, 9, 1, 6, 8, 3, 4, 2, 1, 3, 9, 9, 0, 4, 1, 9, 0, 1, 4, 3, 9, 0, 2,
        2, 1, 7, 0, 6, 6, 8, 3, 2, 6, 3, 6, 9, 2, 1, 2], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([7, 9, 5, 7, 6, 6, 0, 2, 1, 9, 8, 9, 4,

torch.Size([256, 10]) torch.Size([256])
tensor([4, 0, 5, 9, 4, 9, 6, 3, 8, 0, 0, 1, 5, 4, 6, 7, 7, 4, 4, 8, 4, 0, 0, 8,
        1, 8, 1, 8, 2, 9, 9, 4, 2, 7, 0, 6, 8, 2, 3, 5, 2, 1, 1, 8, 6, 3, 1, 3,
        3, 1, 6, 3, 6, 4, 5, 6, 1, 5, 3, 7, 3, 6, 1, 2, 8, 7, 8, 7, 3, 9, 1, 4,
        9, 8, 2, 5, 6, 5, 9, 2, 6, 8, 6, 1, 7, 5, 3, 3, 8, 6, 3, 8, 4, 9, 7, 1,
        1, 7, 6, 2, 7, 6, 0, 8, 2, 5, 5, 8, 8, 8, 6, 5, 3, 5, 1, 4, 3, 6, 4, 8,
        1, 2, 7, 0, 5, 9, 9, 0, 3, 8, 1, 3, 5, 7, 9, 7, 5, 2, 2, 9, 8, 7, 5, 6,
        1, 6, 3, 5, 6, 7, 7, 0, 8, 0, 7, 7, 8, 5, 4, 7, 0, 2, 7, 5, 4, 8, 5, 4,
        0, 7, 3, 2, 7, 6, 1, 2, 4, 0, 8, 3, 9, 1, 2, 8, 1, 7, 3, 0, 4, 2, 5, 4,
        8, 0, 3, 8, 9, 2, 9, 6, 1, 6, 8, 4, 4, 2, 1, 7, 2, 6, 2, 8, 8, 7, 4, 0,
        2, 0, 9, 4, 5, 4, 4, 8, 0, 1, 2, 2, 9, 8, 7, 6, 5, 6, 3, 5, 0, 8, 9, 6,
        8, 9, 6, 5, 0, 6, 9, 5, 2, 7, 9, 8, 1, 1, 5, 5], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([1, 6, 0, 8, 9, 1, 0, 1, 0, 7, 8, 8, 6,

torch.Size([256, 10]) torch.Size([256])
tensor([2, 1, 0, 2, 9, 2, 6, 5, 7, 7, 1, 5, 5, 0, 8, 1, 4, 1, 3, 9, 5, 1, 1, 6,
        9, 3, 4, 3, 6, 1, 7, 7, 8, 7, 1, 2, 5, 5, 6, 5, 7, 3, 7, 0, 6, 8, 1, 3,
        5, 8, 7, 4, 0, 2, 4, 4, 7, 7, 1, 0, 5, 6, 3, 2, 2, 2, 3, 4, 3, 3, 0, 5,
        1, 6, 2, 3, 4, 3, 1, 7, 0, 2, 4, 0, 5, 0, 2, 3, 0, 1, 5, 9, 8, 0, 5, 4,
        0, 2, 6, 9, 4, 6, 3, 6, 2, 2, 1, 5, 1, 0, 8, 1, 9, 7, 2, 8, 7, 2, 0, 0,
        9, 3, 9, 3, 9, 4, 0, 0, 9, 0, 3, 9, 0, 9, 9, 5, 1, 0, 5, 6, 8, 5, 0, 4,
        4, 1, 0, 5, 9, 6, 8, 9, 1, 2, 7, 6, 8, 4, 2, 8, 9, 4, 5, 0, 8, 5, 6, 8,
        0, 4, 0, 2, 5, 1, 6, 9, 0, 8, 9, 2, 0, 0, 4, 4, 9, 0, 6, 1, 5, 2, 0, 3,
        1, 4, 3, 2, 9, 1, 8, 0, 2, 8, 0, 1, 2, 0, 3, 1, 2, 7, 4, 5, 5, 8, 7, 2,
        5, 1, 8, 7, 5, 7, 6, 9, 1, 8, 2, 8, 1, 8, 4, 2, 0, 9, 0, 6, 1, 3, 0, 9,
        1, 8, 9, 9, 6, 7, 1, 4, 5, 4, 9, 9, 1, 9, 6, 3], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([8, 6, 7, 2, 2, 9, 9, 8, 9, 0, 1, 1, 1,

torch.Size([256, 10]) torch.Size([256])
tensor([9, 0, 0, 0, 8, 5, 0, 2, 3, 5, 3, 0, 4, 7, 5, 0, 0, 4, 8, 9, 6, 1, 2, 8,
        7, 3, 5, 0, 6, 2, 2, 3, 5, 3, 9, 0, 1, 3, 3, 0, 7, 8, 7, 3, 6, 5, 5, 8,
        4, 3, 3, 0, 3, 3, 5, 6, 7, 5, 4, 1, 6, 5, 4, 8, 4, 5, 7, 9, 8, 1, 6, 0,
        6, 6, 9, 2, 0, 0, 5, 9, 7, 3, 3, 8, 4, 1, 9, 6, 2, 8, 0, 0, 8, 8, 4, 2,
        4, 9, 7, 9, 0, 5, 7, 1, 1, 0, 6, 5, 9, 8, 9, 4, 8, 4, 6, 2, 6, 6, 2, 1,
        4, 1, 0, 0, 4, 7, 0, 9, 0, 9, 1, 7, 4, 4, 7, 7, 2, 9, 0, 3, 8, 5, 7, 9,
        2, 8, 3, 6, 7, 1, 6, 9, 2, 8, 8, 3, 2, 6, 1, 0, 3, 2, 4, 3, 4, 3, 8, 1,
        7, 4, 7, 6, 0, 8, 6, 4, 7, 7, 6, 5, 2, 6, 2, 5, 8, 5, 0, 4, 7, 4, 4, 9,
        5, 0, 3, 2, 1, 5, 0, 4, 7, 1, 4, 5, 8, 4, 2, 0, 9, 5, 1, 8, 5, 3, 2, 4,
        6, 8, 9, 0, 1, 3, 5, 0, 9, 5, 5, 4, 7, 3, 2, 0, 0, 5, 9, 8, 6, 1, 1, 6,
        4, 3, 8, 6, 8, 3, 8, 6, 5, 6, 5, 2, 0, 4, 8, 5], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([6, 5, 8, 1, 2, 1, 8, 8, 0, 7, 2, 9, 3,

torch.Size([256, 10]) torch.Size([256])
tensor([0, 7, 6, 9, 6, 4, 3, 5, 6, 5, 4, 6, 1, 9, 9, 3, 9, 3, 4, 3, 0, 6, 1, 8,
        0, 1, 2, 6, 5, 6, 7, 3, 2, 8, 8, 4, 1, 5, 3, 9, 8, 3, 5, 6, 5, 0, 0, 8,
        5, 9, 5, 9, 8, 3, 0, 0, 8, 1, 2, 2, 9, 8, 2, 6, 6, 2, 5, 9, 2, 0, 6, 4,
        0, 8, 5, 2, 5, 9, 0, 4, 4, 5, 5, 5, 1, 5, 2, 0, 6, 9, 6, 7, 1, 6, 6, 5,
        8, 7, 9, 2, 4, 1, 5, 8, 4, 3, 6, 6, 1, 5, 6, 4, 0, 1, 5, 4, 0, 9, 2, 4,
        3, 2, 7, 2, 5, 9, 8, 1, 2, 1, 4, 9, 8, 5, 9, 0, 5, 1, 8, 6, 0, 1, 1, 9,
        1, 1, 7, 2, 6, 6, 6, 2, 6, 0, 8, 5, 4, 4, 2, 2, 9, 2, 3, 6, 6, 8, 8, 4,
        3, 5, 2, 8, 9, 0, 7, 5, 9, 3, 5, 3, 4, 8, 4, 5, 4, 9, 7, 7, 6, 0, 8, 8,
        9, 0, 8, 7, 6, 5, 8, 6, 8, 2, 3, 7, 5, 3, 1, 9, 2, 3, 3, 1, 9, 4, 1, 2,
        2, 8, 5, 4, 9, 6, 7, 4, 4, 5, 4, 3, 2, 5, 1, 9, 9, 3, 0, 5, 8, 5, 5, 9,
        9, 8, 7, 8, 3, 9, 0, 6, 7, 4, 4, 0, 7, 2, 6, 6], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([9, 6, 4, 2, 8, 1, 4, 7, 0, 9, 8, 1, 2,

torch.Size([256, 10]) torch.Size([256])
tensor([6, 6, 2, 2, 4, 5, 1, 5, 8, 3, 4, 4, 7, 9, 0, 1, 8, 5, 7, 4, 7, 0, 8, 6,
        9, 5, 8, 2, 4, 1, 5, 9, 7, 7, 9, 4, 4, 5, 2, 2, 4, 6, 5, 7, 8, 6, 5, 3,
        8, 3, 6, 6, 4, 1, 6, 7, 0, 7, 1, 8, 2, 5, 1, 6, 3, 5, 7, 9, 7, 9, 0, 1,
        3, 4, 5, 5, 4, 2, 8, 7, 2, 6, 7, 5, 8, 9, 2, 5, 0, 0, 6, 8, 1, 7, 1, 4,
        3, 3, 9, 3, 4, 5, 8, 1, 7, 7, 5, 1, 4, 8, 3, 9, 6, 5, 5, 5, 6, 9, 7, 6,
        0, 0, 8, 8, 6, 3, 0, 7, 4, 5, 0, 2, 6, 3, 3, 9, 8, 7, 8, 5, 9, 3, 9, 5,
        1, 3, 8, 9, 2, 1, 7, 1, 4, 7, 5, 2, 5, 0, 3, 9, 9, 6, 3, 7, 8, 5, 5, 6,
        6, 8, 3, 6, 1, 0, 2, 9, 2, 3, 5, 4, 3, 3, 9, 8, 3, 4, 3, 2, 2, 4, 8, 5,
        8, 5, 8, 9, 7, 6, 9, 9, 7, 9, 5, 1, 9, 1, 2, 5, 7, 6, 2, 4, 4, 3, 7, 7,
        3, 6, 6, 2, 8, 9, 5, 6, 4, 9, 9, 7, 5, 1, 3, 2, 2, 8, 4, 4, 2, 5, 9, 1,
        4, 3, 2, 5, 5, 1, 3, 2, 0, 4, 5, 4, 6, 5, 3, 4], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([7, 9, 0, 2, 2, 6, 8, 8, 7, 6, 7, 8, 4,

torch.Size([256, 10]) torch.Size([256])
tensor([1, 5, 4, 3, 3, 7, 3, 6, 2, 4, 5, 6, 4, 2, 4, 3, 8, 4, 2, 4, 5, 9, 9, 7,
        7, 4, 0, 6, 3, 1, 9, 2, 8, 1, 4, 2, 0, 9, 4, 6, 6, 7, 9, 3, 0, 6, 0, 6,
        8, 1, 4, 3, 5, 4, 6, 8, 5, 9, 3, 3, 6, 2, 9, 9, 8, 7, 3, 0, 9, 9, 4, 0,
        1, 4, 6, 4, 9, 9, 1, 9, 2, 3, 2, 5, 2, 5, 7, 2, 9, 2, 7, 7, 9, 4, 2, 0,
        8, 8, 6, 2, 5, 8, 5, 5, 6, 9, 1, 6, 4, 5, 7, 4, 6, 1, 5, 6, 9, 5, 7, 0,
        3, 7, 7, 9, 9, 1, 7, 4, 7, 8, 9, 6, 4, 7, 3, 9, 8, 6, 8, 0, 9, 0, 5, 1,
        5, 0, 5, 8, 6, 6, 1, 2, 0, 7, 8, 3, 9, 4, 0, 1, 6, 1, 7, 8, 8, 9, 2, 9,
        5, 5, 4, 9, 1, 4, 1, 2, 7, 2, 0, 2, 9, 9, 3, 0, 6, 4, 4, 7, 1, 5, 5, 1,
        4, 9, 9, 8, 6, 4, 9, 3, 6, 7, 6, 4, 2, 5, 3, 4, 3, 5, 6, 1, 7, 3, 5, 5,
        4, 5, 4, 0, 9, 6, 9, 3, 8, 7, 1, 6, 6, 5, 9, 6, 9, 8, 4, 1, 2, 5, 1, 7,
        1, 7, 9, 3, 8, 9, 0, 3, 9, 4, 4, 3, 7, 3, 8, 6], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([0, 2, 6, 2, 4, 0, 5, 6, 6, 6, 6, 1, 7,

torch.Size([256, 10]) torch.Size([256])
tensor([0, 8, 2, 7, 8, 7, 2, 0, 1, 9, 2, 0, 0, 1, 2, 7, 3, 3, 0, 2, 3, 8, 2, 5,
        0, 2, 2, 6, 5, 0, 1, 4, 8, 7, 8, 6, 6, 2, 3, 1, 9, 4, 2, 9, 5, 5, 0, 0,
        4, 0, 4, 1, 6, 5, 3, 6, 0, 2, 0, 8, 9, 1, 6, 0, 3, 0, 3, 2, 9, 2, 1, 9,
        1, 8, 7, 8, 1, 8, 5, 8, 0, 6, 8, 1, 0, 3, 6, 1, 5, 5, 0, 5, 3, 7, 2, 8,
        9, 1, 5, 7, 0, 0, 0, 4, 5, 0, 2, 3, 5, 9, 3, 3, 7, 3, 3, 5, 8, 4, 8, 7,
        1, 7, 7, 1, 9, 5, 9, 4, 7, 5, 5, 3, 0, 1, 0, 4, 4, 3, 6, 1, 7, 4, 3, 0,
        2, 4, 2, 8, 9, 0, 9, 5, 4, 4, 7, 6, 4, 1, 7, 7, 1, 5, 0, 2, 5, 5, 1, 4,
        3, 9, 5, 7, 0, 1, 6, 4, 1, 6, 9, 2, 7, 6, 5, 4, 1, 0, 8, 6, 6, 6, 4, 6,
        8, 7, 3, 3, 9, 5, 6, 3, 1, 7, 6, 7, 8, 1, 8, 1, 9, 8, 9, 7, 2, 9, 1, 3,
        6, 0, 9, 3, 6, 7, 9, 0, 9, 4, 4, 6, 5, 4, 9, 6, 3, 3, 0, 4, 0, 5, 1, 2,
        5, 2, 0, 1, 0, 2, 3, 1, 7, 8, 5, 7, 5, 7, 9, 4], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([6, 2, 7, 4, 6, 1, 6, 2, 6, 3, 9, 0, 3,

torch.Size([256, 10]) torch.Size([256])
tensor([7, 1, 7, 7, 6, 4, 0, 7, 4, 6, 1, 7, 4, 5, 4, 6, 5, 1, 8, 6, 4, 2, 4, 7,
        5, 2, 1, 8, 0, 6, 0, 8, 8, 0, 1, 0, 2, 3, 9, 5, 5, 2, 7, 1, 1, 5, 0, 7,
        3, 7, 6, 5, 1, 7, 4, 4, 0, 9, 6, 3, 6, 0, 4, 9, 4, 1, 7, 6, 5, 7, 5, 8,
        2, 0, 8, 4, 2, 9, 3, 8, 2, 1, 3, 7, 7, 3, 2, 4, 5, 9, 1, 5, 2, 5, 2, 7,
        3, 8, 9, 0, 6, 5, 5, 4, 0, 0, 1, 7, 3, 4, 6, 3, 4, 7, 7, 3, 7, 2, 7, 2,
        5, 9, 7, 1, 4, 0, 2, 7, 2, 2, 4, 0, 8, 7, 2, 2, 2, 3, 1, 7, 8, 3, 4, 9,
        4, 1, 1, 1, 5, 0, 0, 2, 9, 9, 7, 2, 0, 5, 5, 5, 4, 9, 6, 5, 2, 5, 0, 9,
        1, 0, 1, 0, 8, 9, 4, 9, 3, 0, 6, 3, 1, 2, 8, 3, 4, 6, 5, 6, 5, 5, 0, 2,
        2, 3, 9, 1, 2, 0, 1, 2, 5, 1, 2, 9, 3, 4, 7, 4, 7, 6, 4, 5, 7, 3, 6, 5,
        1, 9, 8, 6, 3, 5, 6, 5, 0, 3, 1, 7, 7, 8, 2, 9, 2, 3, 9, 6, 7, 7, 7, 8,
        1, 6, 7, 5, 5, 2, 7, 4, 4, 2, 8, 6, 9, 1, 1, 0], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([2, 6, 7, 7, 1, 9, 9, 0, 1, 0, 3, 7, 6,

torch.Size([256, 10]) torch.Size([256])
tensor([2, 0, 2, 3, 0, 3, 8, 7, 8, 7, 7, 3, 5, 4, 1, 3, 1, 4, 2, 6, 7, 5, 5, 9,
        8, 5, 4, 8, 5, 1, 0, 9, 4, 1, 1, 9, 2, 1, 7, 6, 3, 8, 6, 3, 6, 5, 9, 1,
        0, 3, 3, 4, 3, 4, 5, 8, 8, 4, 8, 0, 0, 2, 0, 1, 1, 2, 4, 1, 2, 4, 6, 5,
        6, 3, 7, 9, 6, 8, 1, 2, 7, 3, 2, 3, 8, 2, 0, 9, 0, 5, 7, 8, 7, 5, 0, 7,
        0, 5, 4, 8, 5, 7, 5, 9, 2, 8, 4, 0, 0, 3, 6, 4, 9, 4, 5, 0, 6, 7, 3, 3,
        4, 3, 9, 7, 9, 3, 5, 4, 8, 0, 3, 7, 1, 6, 4, 8, 7, 5, 4, 1, 6, 9, 9, 8,
        9, 7, 7, 5, 3, 6, 9, 5, 6, 9, 5, 0, 5, 3, 8, 0, 1, 6, 6, 4, 2, 8, 9, 0,
        4, 0, 2, 2, 7, 2, 9, 5, 6, 1, 1, 1, 3, 6, 9, 4, 7, 4, 3, 8, 5, 0, 1, 4,
        1, 2, 2, 2, 5, 2, 1, 8, 4, 0, 2, 2, 4, 2, 0, 3, 7, 2, 0, 5, 6, 3, 4, 7,
        9, 3, 2, 3, 2, 8, 8, 3, 3, 0, 6, 0, 9, 8, 0, 6, 2, 4, 1, 1, 2, 9, 5, 2,
        3, 1, 6, 5, 2, 8, 8, 0, 0, 9, 9, 7, 7, 5, 9, 5], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([1, 6, 2, 1, 0, 8, 0, 7, 8, 4, 4, 4, 3,

torch.Size([256, 10]) torch.Size([256])
tensor([9, 7, 3, 9, 8, 5, 4, 3, 9, 5, 5, 2, 3, 9, 5, 9, 4, 9, 9, 3, 2, 6, 8, 5,
        1, 7, 9, 2, 3, 4, 6, 7, 7, 3, 3, 1, 4, 9, 1, 1, 2, 6, 7, 8, 0, 1, 2, 2,
        1, 9, 4, 1, 3, 8, 1, 0, 5, 7, 1, 1, 8, 6, 0, 3, 8, 8, 3, 9, 4, 7, 2, 8,
        0, 9, 6, 0, 2, 9, 7, 3, 2, 1, 6, 8, 6, 7, 4, 2, 6, 7, 6, 6, 1, 7, 0, 2,
        9, 6, 7, 3, 2, 0, 4, 8, 9, 4, 5, 1, 3, 3, 5, 2, 4, 3, 5, 1, 0, 0, 1, 0,
        3, 1, 3, 2, 0, 1, 0, 1, 7, 8, 3, 8, 2, 4, 1, 1, 0, 8, 3, 4, 6, 8, 1, 6,
        3, 0, 5, 2, 2, 5, 5, 5, 2, 4, 3, 3, 8, 1, 0, 1, 9, 0, 9, 9, 7, 3, 9, 4,
        8, 6, 6, 0, 4, 9, 9, 8, 7, 2, 3, 5, 3, 6, 7, 6, 0, 6, 2, 7, 9, 9, 4, 5,
        8, 4, 6, 5, 8, 7, 6, 9, 5, 7, 9, 4, 4, 0, 9, 2, 7, 1, 1, 2, 9, 1, 2, 2,
        0, 9, 5, 9, 3, 3, 5, 8, 0, 1, 8, 4, 8, 3, 3, 8, 6, 3, 8, 0, 4, 6, 3, 5,
        7, 5, 2, 9, 0, 0, 7, 2, 5, 6, 9, 2, 4, 6, 8, 3], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([7, 1, 2, 1, 9, 4, 3, 7, 1, 2, 8, 7, 9,

torch.Size([256, 10]) torch.Size([256])
tensor([9, 4, 4, 2, 4, 9, 4, 5, 6, 3, 2, 7, 4, 2, 9, 1, 5, 7, 1, 5, 9, 5, 5, 1,
        1, 1, 8, 1, 9, 4, 5, 1, 6, 0, 8, 5, 9, 8, 9, 9, 0, 4, 5, 1, 0, 7, 3, 9,
        0, 1, 4, 2, 1, 0, 1, 6, 7, 1, 7, 0, 4, 4, 7, 7, 2, 6, 0, 8, 6, 8, 2, 4,
        7, 2, 9, 1, 5, 7, 0, 4, 5, 2, 1, 2, 0, 7, 2, 4, 1, 8, 0, 9, 4, 9, 2, 2,
        7, 8, 0, 0, 2, 9, 1, 4, 4, 1, 8, 9, 8, 0, 0, 3, 5, 7, 3, 4, 0, 3, 8, 3,
        9, 3, 4, 8, 0, 9, 3, 0, 7, 0, 1, 3, 4, 3, 6, 7, 8, 7, 9, 2, 3, 7, 5, 0,
        8, 6, 9, 1, 8, 7, 1, 3, 4, 6, 7, 8, 1, 7, 7, 2, 6, 5, 5, 3, 6, 3, 0, 0,
        7, 4, 6, 6, 2, 4, 7, 6, 2, 4, 5, 4, 6, 8, 6, 3, 9, 2, 5, 5, 0, 8, 0, 2,
        6, 4, 9, 3, 3, 2, 4, 1, 2, 3, 4, 3, 9, 9, 8, 9, 6, 1, 4, 3, 3, 1, 4, 2,
        6, 7, 9, 2, 0, 0, 1, 8, 7, 7, 2, 2, 6, 2, 7, 7, 3, 6, 4, 6, 6, 2, 7, 9,
        1, 4, 5, 1, 6, 6, 9, 4, 5, 1, 2, 7, 7, 9, 2, 1], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([2, 3, 4, 8, 4, 8, 0, 6, 8, 4, 6, 6, 2,

torch.Size([256, 10]) torch.Size([256])
tensor([7, 5, 1, 2, 7, 3, 0, 2, 1, 5, 1, 0, 1, 0, 5, 6, 6, 8, 8, 1, 3, 4, 8, 9,
        3, 4, 8, 1, 1, 2, 7, 8, 6, 0, 7, 5, 3, 7, 7, 5, 1, 6, 7, 3, 3, 6, 7, 9,
        5, 1, 6, 3, 7, 7, 3, 4, 1, 9, 9, 0, 1, 8, 5, 8, 4, 6, 6, 8, 5, 0, 4, 9,
        8, 8, 0, 7, 7, 2, 7, 2, 1, 0, 2, 4, 0, 6, 1, 2, 5, 6, 0, 3, 5, 8, 0, 1,
        2, 2, 8, 8, 7, 7, 7, 1, 3, 0, 4, 8, 3, 7, 2, 5, 5, 0, 5, 4, 4, 7, 0, 8,
        7, 7, 2, 9, 6, 2, 2, 8, 2, 4, 2, 0, 1, 6, 1, 9, 8, 0, 3, 5, 0, 0, 6, 3,
        1, 3, 1, 3, 5, 9, 3, 1, 4, 9, 4, 9, 0, 8, 7, 3, 1, 6, 0, 0, 0, 5, 8, 8,
        4, 3, 6, 1, 3, 5, 5, 8, 2, 8, 7, 8, 5, 6, 4, 9, 2, 5, 9, 5, 2, 8, 5, 3,
        6, 3, 9, 4, 0, 7, 4, 1, 6, 7, 7, 0, 1, 3, 4, 5, 3, 2, 9, 5, 6, 2, 5, 8,
        5, 4, 9, 4, 0, 4, 2, 6, 0, 4, 9, 0, 9, 3, 8, 9, 9, 5, 0, 2, 2, 0, 9, 7,
        5, 5, 1, 4, 0, 4, 7, 2, 8, 1, 4, 0, 3, 2, 6, 5], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([5, 9, 5, 0, 4, 6, 4, 2, 3, 0, 5, 0, 2,

torch.Size([256, 10]) torch.Size([256])
tensor([7, 3, 2, 1, 9, 2, 9, 4, 9, 5, 0, 0, 2, 7, 0, 3, 9, 0, 3, 2, 2, 1, 2, 4,
        4, 9, 1, 3, 9, 8, 8, 8, 2, 4, 1, 2, 1, 9, 8, 3, 5, 3, 5, 4, 7, 5, 6, 9,
        2, 8, 6, 6, 9, 0, 4, 0, 8, 1, 2, 7, 1, 2, 6, 6, 4, 4, 9, 4, 9, 4, 5, 3,
        8, 2, 9, 8, 5, 7, 5, 6, 6, 6, 8, 6, 4, 0, 9, 8, 6, 2, 6, 3, 9, 2, 0, 9,
        8, 6, 9, 2, 4, 5, 9, 9, 0, 7, 9, 7, 1, 1, 6, 8, 0, 1, 3, 7, 3, 1, 6, 8,
        7, 7, 1, 7, 0, 1, 2, 3, 8, 5, 0, 1, 3, 0, 9, 4, 2, 5, 2, 0, 0, 4, 5, 9,
        7, 5, 0, 6, 8, 8, 0, 5, 4, 0, 6, 8, 1, 0, 2, 7, 0, 0, 6, 2, 5, 7, 2, 9,
        1, 7, 2, 7, 0, 2, 3, 0, 9, 0, 7, 5, 1, 7, 7, 4, 1, 6, 8, 6, 4, 2, 4, 0,
        3, 4, 1, 1, 8, 4, 5, 5, 8, 8, 5, 6, 5, 0, 9, 5, 9, 9, 8, 6, 9, 1, 7, 0,
        7, 3, 6, 5, 3, 5, 1, 3, 4, 0, 4, 1, 2, 5, 0, 6, 9, 4, 3, 8, 4, 0, 0, 8,
        9, 0, 2, 9, 4, 5, 5, 9, 1, 3, 2, 4, 3, 0, 4, 8], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([5, 5, 0, 8, 2, 1, 0, 3, 8, 1, 9, 0, 3,

torch.Size([256, 10]) torch.Size([256])
tensor([6, 6, 1, 5, 7, 6, 6, 8, 5, 7, 2, 6, 0, 7, 5, 2, 9, 0, 8, 7, 1, 4, 6, 1,
        0, 8, 6, 1, 0, 9, 9, 6, 5, 4, 3, 6, 6, 2, 3, 8, 1, 5, 2, 7, 9, 3, 3, 2,
        7, 1, 0, 6, 3, 0, 3, 5, 8, 8, 7, 4, 2, 8, 4, 4, 8, 6, 9, 6, 8, 6, 6, 4,
        7, 5, 2, 6, 9, 1, 5, 0, 7, 8, 9, 7, 2, 8, 5, 0, 1, 1, 1, 6, 8, 1, 1, 1,
        9, 7, 2, 1, 9, 9, 2, 2, 8, 6, 6, 6, 5, 2, 6, 2, 8, 6, 7, 8, 3, 9, 0, 6,
        4, 2, 3, 2, 8, 4, 8, 1, 2, 4, 8, 6, 2, 9, 8, 9, 2, 1, 4, 4, 9, 3, 8, 6,
        5, 7, 2, 1, 5, 8, 2, 6, 5, 1, 2, 3, 0, 5, 3, 1, 5, 6, 0, 1, 9, 5, 0, 1,
        7, 6, 2, 5, 7, 8, 0, 0, 4, 1, 1, 3, 3, 3, 4, 1, 9, 0, 7, 4, 8, 9, 7, 7,
        3, 6, 0, 2, 0, 1, 1, 5, 9, 8, 1, 2, 7, 7, 6, 2, 6, 5, 6, 9, 3, 8, 2, 0,
        3, 4, 4, 1, 1, 2, 2, 7, 1, 9, 3, 0, 7, 3, 9, 7, 0, 6, 6, 1, 2, 9, 2, 1,
        7, 4, 3, 7, 8, 1, 8, 1, 1, 2, 7, 1, 1, 9, 9, 9], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([0, 8, 0, 5, 5, 5, 5, 5, 5, 1, 7, 0, 2,

torch.Size([256, 10]) torch.Size([256])
tensor([9, 1, 9, 8, 2, 6, 4, 3, 1, 6, 7, 6, 6, 3, 8, 0, 9, 6, 4, 0, 6, 4, 4, 1,
        4, 2, 4, 2, 3, 4, 2, 8, 9, 7, 1, 1, 3, 5, 6, 7, 0, 6, 2, 9, 2, 7, 8, 6,
        7, 9, 7, 8, 9, 9, 1, 9, 4, 2, 6, 1, 1, 5, 1, 4, 1, 0, 4, 4, 7, 8, 1, 9,
        8, 5, 0, 0, 4, 1, 0, 9, 6, 3, 4, 2, 6, 8, 1, 1, 1, 3, 2, 9, 4, 5, 6, 2,
        2, 4, 7, 8, 4, 0, 1, 0, 8, 4, 7, 5, 0, 1, 3, 5, 3, 7, 2, 2, 0, 6, 3, 3,
        1, 8, 4, 8, 5, 9, 7, 3, 4, 8, 7, 8, 5, 6, 8, 5, 2, 3, 1, 8, 6, 9, 6, 8,
        2, 5, 5, 9, 5, 4, 3, 3, 8, 5, 7, 6, 0, 7, 9, 8, 7, 4, 9, 4, 4, 5, 6, 3,
        7, 5, 1, 7, 3, 0, 6, 8, 5, 1, 6, 7, 1, 5, 5, 7, 6, 3, 0, 2, 9, 3, 0, 2,
        9, 1, 7, 8, 9, 1, 6, 0, 6, 6, 5, 6, 8, 9, 7, 2, 5, 8, 5, 8, 6, 0, 4, 9,
        3, 6, 2, 7, 5, 7, 6, 4, 6, 1, 9, 1, 5, 2, 5, 5, 6, 5, 3, 5, 8, 4, 9, 9,
        5, 1, 5, 8, 7, 3, 6, 8, 3, 5, 6, 6, 2, 2, 6, 8], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([1, 3, 9, 8, 7, 2, 5, 6, 9, 1, 5, 5, 4,

torch.Size([256, 10]) torch.Size([256])
tensor([9, 2, 4, 9, 5, 4, 4, 1, 8, 6, 4, 6, 2, 8, 8, 9, 0, 1, 9, 3, 0, 4, 7, 6,
        5, 9, 6, 2, 9, 2, 2, 2, 2, 5, 9, 4, 9, 8, 8, 4, 1, 2, 0, 0, 5, 6, 3, 9,
        5, 9, 3, 2, 2, 7, 2, 4, 0, 5, 3, 4, 1, 1, 6, 2, 5, 3, 6, 9, 5, 9, 0, 5,
        6, 3, 4, 3, 2, 4, 9, 3, 3, 4, 1, 2, 2, 7, 2, 3, 2, 7, 8, 1, 5, 8, 4, 4,
        2, 8, 2, 1, 8, 2, 3, 9, 6, 8, 3, 7, 8, 0, 0, 8, 1, 3, 2, 1, 0, 5, 2, 8,
        3, 4, 3, 5, 6, 3, 2, 0, 5, 9, 1, 0, 6, 4, 0, 1, 6, 8, 7, 0, 6, 4, 6, 9,
        6, 6, 7, 2, 1, 6, 7, 9, 4, 7, 2, 1, 8, 7, 5, 7, 3, 2, 8, 4, 0, 4, 9, 6,
        1, 0, 2, 7, 8, 5, 6, 9, 5, 4, 6, 5, 2, 3, 0, 3, 2, 3, 2, 9, 4, 1, 5, 4,
        7, 8, 5, 2, 4, 6, 8, 1, 3, 4, 3, 5, 7, 8, 5, 6, 1, 7, 8, 3, 4, 3, 7, 8,
        0, 9, 8, 3, 2, 0, 3, 5, 4, 4, 5, 7, 7, 5, 1, 3, 1, 2, 1, 9, 9, 1, 2, 7,
        4, 5, 9, 3, 3, 1, 8, 9, 3, 1, 1, 6, 5, 9, 5, 9], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([4, 4, 5, 9, 2, 6, 6, 8, 4, 9, 8, 7, 6,

torch.Size([256, 10]) torch.Size([256])
tensor([8, 3, 6, 4, 5, 0, 6, 7, 6, 7, 5, 8, 2, 8, 1, 1, 1, 0, 9, 7, 5, 5, 6, 0,
        8, 0, 5, 2, 3, 2, 1, 8, 3, 9, 3, 8, 4, 1, 4, 9, 7, 9, 6, 0, 7, 9, 7, 0,
        1, 5, 7, 9, 5, 5, 2, 8, 8, 7, 7, 8, 6, 0, 8, 9, 3, 6, 0, 8, 3, 4, 4, 4,
        8, 8, 6, 4, 0, 4, 3, 5, 9, 8, 8, 2, 7, 4, 7, 6, 5, 9, 1, 7, 1, 0, 0, 2,
        0, 0, 2, 8, 4, 0, 8, 9, 8, 6, 1, 8, 1, 2, 1, 2, 1, 3, 5, 6, 8, 4, 5, 0,
        2, 5, 5, 1, 8, 5, 8, 9, 4, 3, 1, 3, 6, 7, 0, 5, 2, 9, 0, 4, 0, 4, 9, 5,
        6, 9, 8, 7, 8, 9, 8, 6, 1, 8, 4, 2, 5, 4, 4, 4, 1, 0, 6, 7, 2, 9, 0, 0,
        8, 5, 9, 7, 7, 7, 3, 1, 4, 2, 5, 7, 5, 5, 7, 2, 7, 2, 8, 7, 1, 5, 7, 5,
        4, 3, 4, 0, 2, 7, 6, 5, 9, 0, 8, 5, 2, 0, 3, 8, 5, 3, 4, 6, 4, 0, 0, 9,
        9, 7, 1, 1, 9, 6, 2, 7, 8, 7, 5, 2, 7, 0, 0, 5, 5, 1, 7, 1, 9, 1, 6, 7,
        0, 0, 1, 2, 2, 6, 8, 3, 3, 0, 6, 6, 7, 6, 2, 5], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([0, 9, 8, 8, 5, 7, 8, 1, 9, 3, 7, 6, 7,

torch.Size([256, 10]) torch.Size([256])
tensor([0, 5, 8, 3, 1, 3, 3, 2, 2, 6, 1, 3, 0, 5, 4, 3, 5, 7, 3, 7, 3, 6, 0, 1,
        0, 9, 7, 5, 1, 6, 5, 1, 9, 4, 1, 4, 0, 2, 7, 6, 7, 4, 3, 3, 1, 5, 3, 8,
        9, 4, 0, 5, 8, 1, 7, 3, 2, 9, 4, 4, 7, 3, 3, 0, 9, 7, 5, 4, 5, 5, 2, 0,
        5, 7, 1, 1, 1, 5, 9, 3, 9, 3, 6, 7, 3, 4, 2, 5, 5, 6, 6, 8, 8, 9, 8, 0,
        4, 2, 4, 2, 5, 2, 7, 9, 7, 8, 6, 3, 7, 3, 1, 1, 0, 5, 6, 9, 5, 4, 9, 0,
        0, 3, 7, 6, 8, 3, 6, 6, 3, 3, 8, 7, 6, 9, 6, 0, 7, 5, 0, 7, 3, 0, 0, 9,
        1, 3, 1, 7, 0, 2, 4, 8, 0, 7, 5, 1, 8, 2, 6, 5, 1, 4, 1, 0, 5, 7, 6, 2,
        3, 0, 5, 2, 6, 3, 1, 7, 9, 4, 6, 6, 3, 6, 5, 7, 3, 2, 0, 0, 7, 3, 0, 8,
        7, 3, 3, 1, 2, 0, 5, 1, 9, 4, 9, 2, 4, 2, 6, 0, 5, 1, 2, 9, 5, 8, 0, 6,
        0, 8, 4, 7, 7, 5, 7, 6, 2, 8, 4, 7, 8, 0, 2, 6, 7, 7, 6, 9, 6, 4, 5, 2,
        7, 4, 9, 0, 7, 3, 2, 5, 0, 6, 9, 0, 4, 0, 1, 2], device='cuda:0')
torch.Size([256, 10]) torch.Size([256])
tensor([7, 9, 2, 3, 4, 5, 9, 2, 9, 0, 3, 1, 5,

torch.Size([96, 10]) torch.Size([96])
tensor([8, 2, 5, 1, 2, 9, 8, 2, 9, 9, 7, 5, 3, 2, 5, 7, 7, 2, 3, 6, 1, 8, 0, 3,
        9, 5, 6, 7, 7, 3, 7, 0, 8, 9, 2, 6, 5, 4, 5, 2, 5, 9, 4, 9, 7, 7, 7, 6,
        2, 5, 0, 4, 1, 6, 3, 2, 7, 0, 4, 8, 4, 5, 1, 7, 7, 2, 6, 0, 0, 9, 9, 3,
        8, 9, 3, 9, 7, 6, 3, 1, 6, 2, 5, 2, 1, 4, 9, 2, 2, 3, 1, 1, 3, 1, 1, 6],
       device='cuda:0')


NameError: name 'evaluate_accuracy' is not defined