In [None]:
from PIL import Image
import torch
from torch import utils
import os
import struct
import numpy as np
from torch import nn
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay

In [None]:

def read_kmnist_train(path, is_train='train'):
    """Load one split of (K)MNIST from its raw IDX files.

    Args:
        path: directory containing ``{is_train}-labels-idx1-ubyte`` and
            ``{is_train}-images-idx3-ubyte``.
        is_train: file-name prefix of the split, e.g. ``'train'`` or ``'t10k'``.

    Returns:
        ``(images, labels, is_train)``:
        - ``images``: float32 array of shape ``(N, 1, rows, cols)`` holding the
          raw 0-255 pixel values (no normalisation is applied);
        - ``labels``: uint8 array of shape ``(N,)``;
        - ``is_train``: the prefix passed in, echoed back for logging.

    Raises:
        ValueError: if a file carries the wrong IDX magic number, or the
            headers / payloads disagree on the number of examples.
    """
    labels_path = os.path.join(path, f'{is_train}-labels-idx1-ubyte')
    images_path = os.path.join(path, f'{is_train}-images-idx3-ubyte')

    with open(labels_path, 'rb') as lbpath:
        # IDX1 header: magic (2049 for unsigned-byte vectors) and item count.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError(f"{labels_path}: bad IDX magic number {magic}, expected 2049")
        labels = np.fromfile(lbpath, dtype=np.uint8)
    if len(labels) != n:
        raise ValueError(f"{labels_path}: header announces {n} labels, file holds {len(labels)}")

    with open(images_path, 'rb') as imgpath:
        # IDX3 header: magic (2051 for unsigned-byte 3-D tensors), count, rows, cols.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError(f"{images_path}: bad IDX magic number {magic}, expected 2051")
        images = np.fromfile(imgpath, dtype=np.uint8)
    if num != len(labels):
        raise ValueError(f"{images_path}: {num} images but {len(labels)} labels")
    if images.size != num * rows * cols:
        raise ValueError(f"{images_path}: expected {num * rows * cols} pixels, found {images.size}")

    # Use the image size from the header instead of a hardcoded 28x28.
    images = images.reshape(num, 1, rows, cols)
    return images.astype(np.float32), labels, is_train

class kmnistDataset(utils.data.Dataset):
  """Map-style dataset over one (K)MNIST split loaded by ``read_kmnist_train``.

  Each item is ``(image, label)``: ``image`` is the stored float32 array of
  shape ``(1, rows, cols)`` and ``label`` is a Python ``int``. The stored
  labels are ``np.uint8``. Returning an ``int`` makes the default DataLoader
  collate them into an int64 tensor, which is the dtype
  ``nn.CrossEntropyLoss`` requires for class-index targets.
  """
  def __init__(self, file_path, is_train):
    # is_train is the split's file-name prefix, e.g. "train" or "t10k".
    self.features, self.labels, process_type = read_kmnist_train(file_path, is_train)
    print("read "+str(len(self.features))+f' {process_type} examples')

  def __getitem__(self, idx):
    return self.features[idx], int(self.labels[idx])

  def __len__(self):
    return len(self.features)

whole_set = kmnistDataset("./data", is_train="train")
length = len(whole_set)
# 80/20 train/validation split. The validation size is derived from the train
# size, so the two always sum to `length` (random_split raises otherwise).
train_size = int(0.8 * length)
validate_size = length - train_size
batch_size, lr, num_epochs = 64, 0.1, 10
# Number of DataLoader worker processes, capped at 8. os.cpu_count() can
# return None, so fall back to 1 in that case.
nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])

# Seeded generator so the train/validation split is reproducible across runs.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_dataset = torch.utils.data.random_split(
    whole_set, [train_size, validate_size], generator=split_generator)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=nw)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         pin_memory=True,
                                         num_workers=nw)

In [None]:
# model 4: MLP with four Linear layers, 784 -> 1024 -> 512 -> 128 -> 10.
def init_weights(m):
    """Re-draw the weight of an exact ``nn.Linear`` from N(0, 0.01**2).

    Biases and all other module types are left untouched. Intended for use
    with ``nn.Module.apply``.
    """
    if type(m) is nn.Linear:
        nn.init.normal_(m.weight, std=0.01)


net_4 = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024), nn.ReLU(),
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Linear(512, 128), nn.ReLU(),
    nn.Linear(128, 10),
)
net_4.apply(init_weights)
net = net_4

In [None]:
# model 6: MLP with six Linear layers, 784 -> 1024 -> 512 -> 256 -> 128 -> 64 -> 10.
def init_weights(m):
    """Re-draw the weight of an exact ``nn.Linear`` from N(0, 0.01**2).

    Biases and all other module types are left untouched. Intended for use
    with ``nn.Module.apply``.
    """
    if type(m) is nn.Linear:
        nn.init.normal_(m.weight, std=0.01)


net_6 = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 1024), nn.ReLU(),
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 10),
)
net_6.apply(init_weights)
net = net_6

In [None]:
# model 20: MLP with 20 Linear layers (input layer, a 6-layer wide sub-network,
# then a 13-layer tail). sub_net_6 is also used by the "model 25" cell.
def sub_net_6():
  """Return a 6-layer Linear+ReLU stack 1024->2048->4096->8192->4096->2048->1024."""
  return nn.Sequential(
                    nn.Linear(1024, 2048),
                    nn.ReLU(),
                    nn.Linear(2048, 4096),
                    nn.ReLU(),
                    nn.Linear(4096, 8192),
                    nn.ReLU(),
                    nn.Linear(8192, 4096),
                    nn.ReLU(),
                    nn.Linear(4096, 2048),
                    nn.ReLU(),
                    nn.Linear(2048, 1024),
                    nn.ReLU())

net_20 = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 1024),
                    nn.ReLU(),
                    sub_net_6(),
                    nn.Linear(1024, 512),
                    nn.ReLU(),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Linear(256, 512),
                    nn.ReLU(),
                    nn.Linear(512, 1024),
                    nn.ReLU(),
                    nn.Linear(1024, 512),
                    nn.ReLU(),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 32),
                    nn.ReLU(),
                    nn.Linear(32, 16),
                    nn.ReLU(),
                    # Input width must match the previous layer's 16 outputs
                    # (was Linear(256, 128), which fails on every forward pass).
                    nn.Linear(16, 128),
                    nn.ReLU(),
                    # Input width must match the previous 128 outputs
                    # (was Linear(512, 128)).
                    nn.Linear(128, 128),
                    nn.ReLU(),
                    nn.Linear(128, 10)
                    )
def init_weights(m):
    """Re-draw the weight of an exact ``nn.Linear`` from N(0, 0.01**2); biases untouched."""
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
net_20.apply(init_weights)
net = net_20

In [None]:
# model 25, but this 25 layers model is prone to memory overflow
# Depends on sub_net_6(), which is defined in the "model 20" cell; run that cell first.
def sub_net_5():
  """Return a 5-layer Linear+ReLU stack 1024->2048->4096->8192->2048->1024."""
  return nn.Sequential(
                    nn.Linear(1024, 2048),
                    nn.ReLU(),
                    nn.Linear(2048, 4096),
                    nn.ReLU(),
                    nn.Linear(4096, 8192),
                    nn.ReLU(),
                    nn.Linear(8192, 2048),
                    nn.ReLU(),
                    nn.Linear(2048, 1024),
                    nn.ReLU())

net_25 = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 1024),
                    nn.ReLU(),
                    sub_net_6(),
                    nn.Linear(1024, 512),
                    nn.ReLU(),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Linear(256, 512),
                    nn.ReLU(),
                    nn.Linear(512, 1024),
                    nn.ReLU(),
                    sub_net_5(),
                    nn.Linear(1024, 512),
                    nn.ReLU(),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 32),
                    nn.ReLU(),
                    nn.Linear(32, 16),
                    nn.ReLU(),
                    # Input width must match the previous layer's 16 outputs
                    # (was Linear(256, 128), which fails on every forward pass).
                    nn.Linear(16, 128),
                    nn.ReLU(),
                    # Input width must match the previous 128 outputs
                    # (was Linear(512, 128)).
                    nn.Linear(128, 128),
                    nn.ReLU(),
                    nn.Linear(128, 10)
                    )
def init_weights(m):
    """Re-draw the weight of an exact ``nn.Linear`` from N(0, 0.01**2); biases untouched."""
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
net_25.apply(init_weights)
net = net_25

In [None]:
# train
# Create the directory that checkpoints are actually written to. train_cls_process
# saves to ./weights/mlp.pth; creating ./weight/ (singular) made torch.save fail.
os.makedirs("./weights/", exist_ok=True)

# NOTE(review): these module-level objects are not used by train_cls_process,
# which builds its own loss and optimizer; `trainer` uses the config-cell `lr`.
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

def evaluate_accuracy_gpu(net, data_iter, device=None):
  """Return the classification accuracy of ``net`` over ``data_iter``.

  Args:
    net: model producing class scores. If it is an ``nn.Module`` it is put in
      eval mode (and left there).
    data_iter: iterable of ``(X, y)`` batches; ``X`` may be a tensor or a list
      of tensors.
    device: device to move the batches to. When omitted and ``net`` is an
      ``nn.Module``, the device of its first parameter is used.

  Returns:
    Fraction of correctly classified examples, or 0.0 if ``data_iter`` yields
    no examples.
  """
  if isinstance(net, torch.nn.Module):
    net.eval()
    if not device:
      device = next(iter(net.parameters())).device
  metric = Accumulator(2)  # (number correct, number of examples)
  # Evaluation needs no autograd graph; skipping it saves memory and time.
  with torch.no_grad():
    for X, y in data_iter:
      if isinstance(X, list):
        X = [x.to(device) for x in X]
      else:
        X = X.to(device)
      y = y.to(device)
      metric.add(accuracy(net(X), y), y.numel())
  if metric[1] == 0:
    return 0.0
  return metric[0] / metric[1]

def get_dataloader_workers():
    """Return the fixed DataLoader worker count, 4.

    NOTE(review): not referenced by the visible cells; the loaders use `nw`.
    """
    num_workers = 4
    return num_workers

class Accumulator:
  """Keep ``n`` running float sums that are always updated together."""
  def __init__(self, n):
    self.data = [0.0] * n

  def add(self, *args):
    """Add ``float(args[i])`` to sum ``i``.

    Raises:
      ValueError: if the number of values differs from the number of sums.
        Previously ``zip`` silently truncated the accumulator in that case.
    """
    if len(args) != len(self.data):
      raise ValueError(f"expected {len(self.data)} values, got {len(args)}")
    self.data = [a + float(b) for a, b in zip(self.data, args)]

  def reset(self):
    """Set every sum back to 0.0, keeping the count."""
    self.data = [0.0] * len(self.data)

  def __getitem__(self, idx):
    return self.data[idx]

def accuracy(y_hat, y):
    """Return the number (not the fraction) of correct predictions, as a float.

    A ``y_hat`` with at least two dimensions and more than one column is
    reduced to predicted indices with ``argmax`` over dimension 1. Otherwise
    it is compared to ``y`` element-wise as is, after casting to ``y``'s dtype.
    """
    is_score_matrix = y_hat.dim() > 1 and y_hat.size(1) > 1
    predictions = y_hat.argmax(axis=1) if is_score_matrix else y_hat
    matches = predictions.type(y.dtype) == y
    return float(matches.type(y.dtype).sum())

def train_cls_process(net, train_iter, test_iter, num_epochs, lr, device, save_path=None):
    """Train ``net`` with SGD and cross-entropy, checkpointing the best epoch.

    The network is first re-initialised with Xavier-uniform weights on every
    ``nn.Linear`` / ``nn.Conv2d`` layer, so any earlier initialisation is
    overwritten. After each epoch the accuracy on ``test_iter`` is measured.
    Whenever it beats the best so far, the state dict is saved.

    Args:
        net: ``nn.Module`` classifier producing class scores.
        train_iter: iterable of ``(X, y)`` training batches, ``y`` = class indices.
        test_iter: iterable of ``(X, y)`` validation batches.
        num_epochs: number of passes over ``train_iter``.
        lr: SGD learning rate.
        device: device to train on.
        save_path: checkpoint file; defaults to ``./weights/mlp.pth``.

    Side effects:
        Writes the best checkpoint to ``save_path``, and writes one
        ``"<train_loss>, <val_acc>"`` line per epoch to
        ``./test/mlp__pro_params.txt``. Parent directories are created as needed.
    """
    if save_path is None:
        save_path = "./weights/mlp.pth"
    log_path = f"./test/mlp__pro_params.txt"

    def init_weights(m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)
    print("training on", device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()

    # Create output directories up front so a fresh checkout does not fail at
    # the first save / at the final log write.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    result = ""
    best_acc = 0
    # Defined up front so the summary print works even with no epochs/batches.
    train_l = train_acc = test_acc = float("nan")
    for epoch in range(num_epochs):
        metric = Accumulator(3)  # (summed loss, number correct, number of examples)
        net.train()
        for X, y in train_iter:
            optimizer.zero_grad()
            # CrossEntropyLoss needs int64 class indices.
            X, y = X.to(device), y.to(device).long()
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            # .item() takes the scalar out of the autograd graph.
            metric.add(l.item() * X.shape[0], accuracy(y_hat, y), X.shape[0])
        if metric[2] > 0:
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
        test_acc = evaluate_accuracy_gpu(net, test_iter)

        # This check lives inside the epoch loop. It was previously dedented
        # out of the loop, so only the final epoch could ever be checkpointed.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(net.state_dict(), save_path)

        print(f'epoch {epoch+1}, train loss {train_l:f}, val_acc {test_acc:f}')
        result += f"{train_l}, {test_acc}\n"
    print(f'train loss {train_l:.3f}, train acc {train_acc:.3f}, '
            f'val acc {test_acc:.3f}, best acc {best_acc:.3f}' )
    with open(log_path, "w") as f:
        f.write(result)

# Train on the first GPU when one is available, otherwise on the CPU.
device = (torch.device('cuda:0') if torch.cuda.is_available()
          else torch.device("cpu"))
# NOTE(review): lr=0.001 here differs from the config-cell `lr` (0.1); confirm which is intended.
train_cls_process(net, train_loader, val_loader, num_epochs,
                  lr=0.001, device=device)

In [None]:
# test
# Evaluate the checkpoint saved by train_cls_process on the t10k split. Writes
# the comma-separated predicted classes and the mean cross-entropy to ./test/.
if __name__ == '__main__':
    test_imgs, test_labels, is_train = read_kmnist_train(
        "./data",is_train="t10k"
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    eval_batch_size = 256  # images per forward pass; lower it if the model runs out of memory
    pre_clas = []
    total_loss = 0
    model = net.to(device)
    model_weight_path = f"./weights/mlp.pth"
    model.load_state_dict(torch.load(model_weight_path,map_location=device))
    model.eval()
    # reduction='sum' so that dividing by the example count below yields the
    # per-example mean loss. This name no longer clobbers the global `loss`.
    loss_function = torch.nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        # Batched inference instead of one forward pass per image.
        for start in range(0, len(test_imgs), eval_batch_size):
            stop = start + eval_batch_size
            batch_imgs = torch.from_numpy(test_imgs[start:stop]).to(device)
            batch_labels = torch.from_numpy(test_labels[start:stop].astype(np.int64)).to(device)
            output = model(batch_imgs)
            # argmax of the logits equals argmax of their softmax.
            pre_cla = output.argmax(dim=1).cpu()
            total_loss += loss_function(output, batch_labels).cpu()
            pre_clas.extend(str(p) for p in pre_cla.tolist())

    avg_loss = total_loss/int(len(test_imgs))
    loss_ = str(avg_loss.numpy())

    result = ",".join(pre_clas)


    if not os.path.exists("./test/"):
        os.makedirs("./test/")

    with open(f"./test/mlp_result.txt", "w") as f:
        f.write(result)

    with open(f"./test/mlp_avgloss.txt", "w") as f:
        f.write(loss_)
    print(f"the test based on mlp is done, and its average loss value is {loss_}")


In [None]:
# Predicted and true class ids, stored as comma-separated integers.
# NOTE(review): neither file is written by the cells above (the test cell writes
# ./test/mlp_result.txt); confirm where these files come from.
y_pred_mlp = np.loadtxt('./test/best_mlp_pred_result.txt', dtype=int, delimiter=',')
y_true = np.loadtxt('./test_label.txt', dtype=int, delimiter=',')

# Entry [i, j] counts samples of true class i predicted as class j.
mlp = confusion_matrix(y_true, y_pred_mlp)
print("Confusion matrix of MLP:")
print(mlp)
print("Accuracy: ", accuracy_score(y_true, y_pred_mlp))

disp = ConfusionMatrixDisplay(confusion_matrix=mlp)
disp.plot(
    include_values=True,           # print the count inside every cell of the matrix
    cmap="viridis",                # sklearn's default colormap
    ax=None,                       # sklearn's default: create a new axes
    xticks_rotation="horizontal",  # sklearn's default tick-label rotation
    values_format="d",             # format the displayed counts as integers
)
plt.show()