In [None]:
# pylint: disable=all
if __name__ == '__main__':
    import torch
    import torchvision
    from torchvision import transforms
    import torch.nn.functional as F
    from torch import optim, nn
    import matplotlib.pyplot as plt
    import numpy as np

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Training-time augmentation (random horizontal flip, 4-pixel padded random
    # crop) followed by per-channel normalization with mean 0.5 / std 0.5,
    # which maps ToTensor's [0, 1] range to [-1, 1].
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Test images get the same normalization but no augmentation.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    batch_size = 50
    epoch_num = 20
    disp_interval = 200  # report the running training loss every N mini-batches

    folder_path = './CV2023_HW3B'
    # Save the checkpoint inside folder_path. A bare '/cifar_net_N4.pth' is an
    # absolute path at the filesystem root, which is usually not writable.
    model_path = folder_path + '/cifar_net_N4.pth'
    data_root = folder_path + '/CIFAR10_data'

    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    def imshow(img):
        """Show a normalized channel-first image tensor for two seconds, then close it.

        The tensor is copied to the CPU and the (0.5, 0.5) normalization is
        undone (x / 2 + 0.5) before plotting.
        """
        pixels = (img.cpu() / 2 + 0.5).numpy()
        # matplotlib wants channels last: (C, H, W) -> (H, W, C)
        plt.imshow(pixels.transpose(1, 2, 0))
        plt.show(block=False)
        plt.pause(2)
        plt.close()

    class Net(nn.Module):
        """Four-stage convolutional classifier for 32x32 RGB images.

        Each stage is dropout -> 3x3 conv (padding 1) -> BatchNorm -> ReLU ->
        2x2 max-pool. Four pools reduce a 32x32 input to 256 feature maps of
        size 2x2, which feed a two-layer fully connected head.

        Args:
            num_classes: number of output logits (default 10).
            dropout_p: probability used by the single shared dropout module
                (default 0.5).
        """

        def __init__(self, num_classes=10, dropout_p=0.5):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.bn1 = nn.BatchNorm2d(32)
            self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
            self.bn2 = nn.BatchNorm2d(64)
            self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
            self.bn3 = nn.BatchNorm2d(128)
            self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
            self.bn4 = nn.BatchNorm2d(256)
            self.pool = nn.MaxPool2d(2, 2)
            self.dropout = nn.Dropout(dropout_p)
            self.fc1 = nn.Linear(256 * 2 * 2, 512)
            self.fc2 = nn.Linear(512, num_classes)

        def forward(self, x):
            """Return logits of shape (batch, num_classes) for x of shape (batch, 3, 32, 32)."""
            # NOTE(review): dropout is also applied to the raw input pixels before
            # conv1; with p=0.5 that is very aggressive and may slow training.
            x = self.pool(F.relu(self.bn1(self.conv1(self.dropout(x)))))
            x = self.pool(F.relu(self.bn2(self.conv2(self.dropout(x)))))
            x = self.pool(F.relu(self.bn3(self.conv3(self.dropout(x)))))
            x = self.pool(F.relu(self.bn4(self.conv4(self.dropout(x)))))

            # Keep the batch dimension explicitly. view(-1, 1024) would silently
            # fold extra spatial positions into the batch for non-32x32 inputs;
            # flatten(x, 1) instead makes fc1 reject the wrong size.
            x = torch.flatten(x, 1)
            x = F.relu(self.fc1(self.dropout(x)))
            return self.fc2(x)

    train_model = True  # set to False to skip training and reuse the saved checkpoint
    if train_model:
        net = Net().to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.001)
        # `verbose` is deprecated in recent PyTorch; LR reductions are printed below instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

        net.train()  # dropout active, BatchNorm uses per-batch statistics
        for epoch in range(epoch_num):
            running_loss = 0.0  # reset every disp_interval batches for progress output
            epoch_loss = 0.0    # accumulated over the whole epoch for the scheduler
            num_batches = 0
            for i, (inputs, labels) in enumerate(trainloader):
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                epoch_loss += batch_loss
                num_batches += 1
                if i % disp_interval == disp_interval - 1:
                    print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / disp_interval:.3f}')
                    running_loss = 0.0

            # Drive the plateau scheduler with the mean epoch loss; the last
            # mini-batch loss alone is too noisy to detect a real plateau.
            prev_lr = optimizer.param_groups[0]['lr']
            scheduler.step(epoch_loss / max(num_batches, 1))
            new_lr = optimizer.param_groups[0]['lr']
            if new_lr < prev_lr:
                print(f'Epoch {epoch + 1}: reducing learning rate to {new_lr:.4e}.')

        print('Finished Training')
        torch.save(net.state_dict(), model_path)

    # Show the first test batch and compare the saved model's predictions with it.
    images, labels = next(iter(testloader))
    images, labels = images.to(device), labels.to(device)
    imshow(torchvision.utils.make_grid(images))
    print('GroundTruth: ', ' '.join(f'{classes[c]:5s}' for c in labels.tolist()))

    net = Net().to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    net.load_state_dict(torch.load(model_path, map_location=device))
    # Inference mode: disable dropout and make BatchNorm use its running
    # statistics; otherwise predictions are random and batch-dependent.
    net.eval()
    with torch.no_grad():
        predicted = net(images).argmax(dim=1)

    print('Predicted: ', ' '.join(f'{classes[c]:5s}' for c in predicted.tolist()))

    # Evaluate on the whole test set in a single pass. The confusion matrix
    # (rows = true class, columns = predicted class) yields both the overall
    # and the per-class accuracy, so all reported numbers agree.
    net.eval()  # dropout off, BatchNorm uses running statistics
    num_classes = len(classes)
    conf_counts = torch.zeros(num_classes * num_classes, dtype=torch.long)
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            # Encode each (true, predicted) pair as one flat index and count the
            # whole batch at once instead of indexing element by element.
            pair_idx = (labels * num_classes + predicted).cpu()
            conf_counts += torch.bincount(pair_idx, minlength=num_classes * num_classes)
    conf_matrix = conf_counts.view(num_classes, num_classes).numpy().astype(float)

    correct = int(np.trace(conf_matrix))
    total = int(conf_matrix.sum())
    print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.1f} %')

    for k, classname in enumerate(classes):
        class_total = conf_matrix[k].sum()
        # A class with no test samples has undefined accuracy; report NaN instead of dividing by zero.
        accuracy = 100 * conf_matrix[k, k] / class_total if class_total else float('nan')
        print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.matshow(conf_matrix, cmap=plt.cm.Blues, alpha=0.5)
    for (row, col), count in np.ndenumerate(conf_matrix):
        ax.text(x=col, y=row, s=int(count), va='center', ha='center', size='xx-large')

    ticks = np.arange(num_classes)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(classes, rotation=90)
    ax.set_yticklabels(classes)

    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.show(block=False)   # draw the figure without blocking the script
    plt.pause(2)            # keep it on screen for 2 seconds
    plt.close()             # then close the figure

Files already downloaded and verified
Files already downloaded and verified
[1,   200] loss: 2.073
[1,   400] loss: 1.885
[1,   600] loss: 1.812
[1,   800] loss: 1.768
[1,  1000] loss: 1.722
[2,   200] loss: 1.695
[2,   400] loss: 1.662
[2,   600] loss: 1.643
[2,   800] loss: 1.640
[2,  1000] loss: 1.606
[3,   200] loss: 1.599
[3,   400] loss: 1.583
[3,   600] loss: 1.576
[3,   800] loss: 1.594
[3,  1000] loss: 1.564
[4,   200] loss: 1.565
[4,   400] loss: 1.552
[4,   600] loss: 1.540
[4,   800] loss: 1.555
[4,  1000] loss: 1.514
[5,   200] loss: 1.524
[5,   400] loss: 1.534
[5,   600] loss: 1.516
[5,   800] loss: 1.505
[5,  1000] loss: 1.506
[6,   200] loss: 1.505
[6,   400] loss: 1.497
[6,   600] loss: 1.503
[6,   800] loss: 1.480
[6,  1000] loss: 1.462
[7,   200] loss: 1.472
[7,   400] loss: 1.482
[7,   600] loss: 1.489
[7,   800] loss: 1.463
[7,  1000] loss: 1.461
Epoch 00007: reducing learning rate of group 0 to 1.0000e-04.
[8,   200] loss: 1.431
[8,   400] loss: 1.416
[8,   600] 