In [1]:
import dill
import pickle
import struct
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

# View the logs with: tensorboard --logdir=/opt/logs --port=6007
writer = SummaryWriter(log_dir="/opt/logs/task2-mnist-torch", flush_secs=30)

# Use the device type of the current accelerator when one is available,
# otherwise fall back to the CPU.
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cpu"
print(f"device: {device}")

device: cuda


In [2]:
# Adapter class that exposes the `interview` evaluation interface on the model.
class CIFAR_Net(nn.Module):
    """Convolutional classifier for 3x32x32 images with 10 output logits.

    The network is three blocks of (Conv3x3 -> BatchNorm -> ReLU) x 2 followed
    by a 2x2 max-pool, then a two-layer fully connected head. All layers live
    in one flat ``nn.Sequential`` named ``net``; the layer order must not
    change, otherwise ``state_dict`` keys of existing checkpoints stop matching.
    """

    @staticmethod
    def unpickle(file):
        """Load and return the pickled object stored at path ``file``.

        Strings pickled by Python 2 are decoded as ``bytes``
        (``encoding="bytes"``), so dictionary keys come back as e.g. ``b"data"``.
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only pass files from a trusted source.
        with open(file, "rb") as f:
            return pickle.load(f, encoding="bytes")

    @staticmethod
    def one_hot(labels, num_classes):
        """Return a ``(len(labels), num_classes)`` float64 one-hot matrix.

        Args:
            labels: 1-D sequence of integer class indices.
            num_classes: number of columns of the result.

        Raises:
            IndexError: if a label is outside the valid column range.
        """
        # Explicit integer dtype so that an empty ``labels`` still yields a
        # valid (integer) index array.
        label_idx = np.asarray(labels, dtype=np.int64)
        one_hot_labels = np.zeros((len(label_idx), num_classes))
        # One vectorised assignment instead of a Python-level loop.
        one_hot_labels[np.arange(len(label_idx)), label_idx] = 1
        return one_hot_labels

    def interview(self, eval_datafile_path, device, batch_size=1024):
        """Evaluate the model on a pickled batch file and return accuracy in percent.

        The file must unpickle to a mapping with ``b"data"`` (array reshapeable
        to ``(-1, 3, 32, 32)``) and ``b"labels"`` (one integer class per image).

        Args:
            eval_datafile_path: path of the pickled evaluation batch.
            device: device on which the evaluation runs. The model itself is
                moved there too, so it always matches the inputs.
            batch_size: number of images per forward pass; bounds peak memory
                regardless of how many images the file holds.

        Returns:
            Percentage (0-100) of images whose arg-max prediction equals the label.

        Raises:
            ValueError: if ``batch_size`` is not positive, the file holds no
                samples, or the image and label counts differ.

        Side effects:
            Leaves the model on ``device`` and in eval mode.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        data_batch = self.unpickle(eval_datafile_path)
        eval_images = np.asarray(data_batch[b"data"]).reshape(-1, 3, 32, 32)
        # Predictions are compared against the class indices directly; a
        # one-hot round trip followed by argmax would give the same indices.
        eval_labels = torch.from_numpy(np.asarray(data_batch[b"labels"], dtype=np.int64))

        num_samples = len(eval_labels)
        if num_samples == 0:
            raise ValueError("evaluation file contains no samples")
        if len(eval_images) != num_samples:
            raise ValueError(
                f"image/label count mismatch: {len(eval_images)} images, {num_samples} labels"
            )

        # The inputs are moved to `device`, so the parameters must live there
        # as well; otherwise the forward pass fails with a device mismatch.
        self.to(device)
        self.eval()

        correct = 0
        with torch.no_grad():
            # Mini-batches keep activation memory bounded. In eval mode
            # BatchNorm uses running statistics and Dropout is disabled, so
            # per-sample predictions do not depend on the batch composition.
            for start in range(0, num_samples, batch_size):
                stop = start + batch_size
                images = torch.from_numpy(eval_images[start:stop]).float().to(device)
                targets = eval_labels[start:stop].to(device)
                pred = self(images)
                correct += (torch.argmax(pred, dim=1) == targets).sum().item()
        return correct / num_samples * 100

    def __init__(self):
        """Build the layer stack and initialise conv / linear weights."""
        super().__init__()
        self.net = nn.Sequential(
            #
            # Block 1: (3,32,32) -> (64,32,32) -> (64,16,16)
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            #
            # Block 2: (64,16,16) -> (128,16,16) -> (128,8,8)
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            #
            # Block 3: (128,8,8) -> (256,8,8) -> (256,4,4)
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
            #
            # Classifier head: flatten (256,4,4) -> 4096 -> 1024 -> 10 logits
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 10),
        )

        # Weight initialisation: Kaiming-normal (fan_out, ReLU gain) for the
        # convolutions; Xavier-normal weights and zero biases for the linear
        # layers. Conv biases and BatchNorm layers keep PyTorch's defaults.
        for m in self.net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return the 10 class logits for a batch ``x`` of shape (N, 3, 32, 32)."""
        return self.net(x)

In [3]:
# Path of the pickled model; the same path is written back to further below.
model_cifar_path = "task2-cifar.pkl"

# NOTE(security): dill/pickle can execute arbitrary code while loading;
# only open checkpoint files from a trusted source.
with open(model_cifar_path, "rb") as checkpoint_file:
    model_cifar = dill.load(checkpoint_file)
# Move the restored model onto the device selected in the setup cell.
model_cifar = model_cifar.to(device)

In [4]:
# Instantiate the current CIFAR_Net definition (the one that provides `interview`);
# its weights are freshly initialised by __init__ at this point.
model = CIFAR_Net()

In [5]:
# Copy the state dict of the model restored from disk into the new instance.
# load_state_dict is strict by default, so any key mismatch between the two
# architectures raises instead of being silently ignored.
model.load_state_dict(model_cifar.state_dict())

<All keys matched successfully>

In [6]:
# Serialise completely in memory BEFORE opening the destination. Opening with
# "wb" truncates the file immediately, and model_cifar_path is the very
# checkpoint that was loaded above, so a failure inside dill.dump(model, f)
# would leave the only copy of the model empty or corrupted.
model_bytes = dill.dumps(model)
with open(model_cifar_path, "wb") as f:
    f.write(model_bytes)