In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics.functional as metrics
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.datasets import fetch_openml


In [2]:
class CModel(nn.Module):
    def __init__(self, in_: int, out_: int) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, out_),
        )

    def forward(self, x):
        x = self.layers(x)
        return x


In [3]:
class MNISTDataset(Dataset):
    def __init__(self, X_data, y_data):
        super().__init__()
        try:
            X_data = X_data if isinstance(X_data, np.ndarray) else X_data.values
            y_data = y_data if isinstance(y_data, np.ndarray) else y_data.values

            X_data = (X_data / 255.0).astype(np.float32)
            y_data = y_data.astype(np.int64)

            self.feature = torch.tensor(X_data)
            self.target = torch.tensor(y_data)
        except Exception as e:
            print(e)

    def __len__(self):
        return self.target.shape[0]

    def __getitem__(self, idx):
        return self.feature[idx], self.target[idx]

    def getInOut(self):
        return self.feature.shape[1], len(torch.unique(self.target))


In [4]:
class MNISTModel:
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    loss_fn = nn.CrossEntropyLoss().to(device)

    def __init__(self, in_: int, out_: int, epochs: int, auto_stop: bool = True):
        self.epochs = epochs
        self.in_ = in_
        self.out_ = out_
        self.auto_stop = auto_stop
        self.model = CModel(in_, out_).to(MNISTModel.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.01)
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, mode="min", patience=4, eps=1e-12
        )

    def fitting(self, dataset):
        self.model.train()

        lossList = []
        for feature, target in dataset:
            feature, target = feature.to(MNISTModel.device), target.to(
                MNISTModel.device
            )
            pre_target = self.model(feature)
            loss = MNISTModel.loss_fn(pre_target, target)
            lossList.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        loss = np.mean(lossList)
        acc = metrics.accuracy(
            pre_target, target, task="multiclass", num_classes=self.out_
        )
        return loss, acc.item()

    def testing(self, dataset):
        self.model.eval()

        with torch.no_grad():
            lossList = []
            for feature, target in dataset:
                feature, target = feature.to(MNISTModel.device), target.to(
                    MNISTModel.device
                )
                pre_target = self.model(feature)
                loss = MNISTModel.loss_fn(pre_target, target)
                lossList.append(loss.item())
                break
            loss = np.mean(lossList)
            acc = metrics.accuracy(
                pre_target, target, task="multiclass", num_classes=self.out_
            )
            return loss, acc.item()

    def training(self, train_set, val_set):
        count = 0
        train_list = [[], []]
        val_list = [[], []]
        for epoch in range(self.epochs + 1):
            train_loss, train_acc = self.fitting(train_set)
            val_loss, val_acc = self.testing(val_set)
            train_list[0].append(train_loss), train_list[1].append(train_acc)
            val_list[0].append(val_loss), val_list[1].append(val_acc)
            self.scheduler.step(val_loss)

            self.__print_result(epoch, train_loss, train_acc, val_loss, val_acc, count)

            if self.auto_stop:
                if self.scheduler.num_bad_epochs >= self.scheduler.patience:
                    count += 1
                    if count >= 10:
                        self.__print_result(
                            epoch, train_loss, train_acc, val_loss, val_acc, count
                        )
                        break
        return train_list, val_list

    def __print_result(self, epoch, train_loss, train_acc, val_loss, val_acc, count):
        bad_epochs = self.scheduler.num_bad_epochs
        print(
            f"[Epoch {epoch:4d}/{self.epochs}]     [bad  {bad_epochs}]   [change {count:2d}/10]"
        )
        print(f"[Train Loss] {train_loss:.6f} [Train Accuracy] {train_acc:.6f}")
        print(f"[Valid Loss] {val_loss:.6f} [Valid Accuracy] {val_acc:.6f}")


In [5]:
mnist_fashion = fetch_openml("Fashion-MNIST", parser="auto", as_frame=False)


In [6]:
X, y = mnist_fashion.data, mnist_fashion.target
X.shape, y.shape


((70000, 784), (70000,))

In [7]:
mnist_dataset = MNISTDataset(X, y)


In [8]:
seed = torch.Generator().manual_seed(42)

trainDS, validDS, testDS = random_split(
    mnist_dataset, [50000, 10000, 10000], generator=seed
)

print(
    f"trainDS => {len(trainDS)}개, validDS => {len(validDS)}개, testDS => {len(testDS)}개"
)


trainDS => 50000개, validDS => 10000개, testDS => 10000개


In [9]:
batch = 64
trainDL = DataLoader(trainDS, batch_size=batch, drop_last=True)
validDL = DataLoader(validDS, batch_size=batch, drop_last=True)
testDL = DataLoader(testDS, batch_size=batch, drop_last=True)

print(f"batch_size : {batch}")
print(
    f"trainDS => {len(trainDS)}개, validDS => {len(validDS)}개, testDS => {len(testDS)}개"
)
print(
    f"trainDL => {len(trainDL)}개, validDL => {len(validDL)}개, testDL => {len(testDL)}개"
)


batch_size : 64
trainDS => 50000개, validDS => 10000개, testDS => 10000개
trainDL => 781개, validDL => 156개, testDL => 156개


In [10]:
IN, OUT = mnist_dataset.getInOut()
epoch = 1000
model = MNISTModel(IN, OUT, epoch)


In [11]:
trainList, valList = model.training(trainDL, validDL)


[Epoch    0/1000]     [bad  0]   [change  0/10]
[Train Loss] 0.547263 [Train Accuracy] 0.890625
[Valid Loss] 0.381692 [Valid Accuracy] 0.859375
[Epoch    1/1000]     [bad  0]   [change  0/10]
[Train Loss] 0.420492 [Train Accuracy] 0.906250
[Valid Loss] 0.332668 [Valid Accuracy] 0.906250
[Epoch    2/1000]     [bad  0]   [change  0/10]
[Train Loss] 0.391068 [Train Accuracy] 0.937500
[Valid Loss] 0.275842 [Valid Accuracy] 0.875000
[Epoch    3/1000]     [bad  0]   [change  0/10]
[Train Loss] 0.374178 [Train Accuracy] 0.921875
[Valid Loss] 0.174371 [Valid Accuracy] 0.937500
[Epoch    4/1000]     [bad  1]   [change  0/10]
[Train Loss] 0.366290 [Train Accuracy] 0.921875
[Valid Loss] 0.217494 [Valid Accuracy] 0.921875
[Epoch    5/1000]     [bad  2]   [change  0/10]
[Train Loss] 0.359405 [Train Accuracy] 0.937500
[Valid Loss] 0.202320 [Valid Accuracy] 0.921875
[Epoch    6/1000]     [bad  3]   [change  0/10]
[Train Loss] 0.343478 [Train Accuracy] 0.890625
[Valid Loss] 0.202430 [Valid Accuracy] 0

In [12]:
model.testing(testDL)


(0.30113381147384644, 0.921875)

In [13]:
# import matplotlib.pyplot as plt

# fig, axs = plt.subplots(2, 2, figsize=(10, 8))
# axs[0, 0].plot(lossList[0], label="Train Loss")
# axs[0, 0].legend(loc="upper left")
# axs[0, 0].set_xlabel("Epochs")
# axs[0, 0].set_ylabel("Train Loss")
# axs[0, 0].set_yticks(np.arange(0.25, 2.51, 0.25))
# axs1 = axs[0, 0].twinx()
# axs1.plot(lossList[1], label="Val Loss", color="red")
# axs1.legend(loc="upper right")
# axs1.set_ylabel("Val Loss")
# axs1.set_yticks(np.arange(0.25, 2.51, 0.25))

# axs[0, 1].plot(lossList[0], label="Train Loss")
# axs[0, 1].legend(loc="upper left")
# axs[0, 1].set_xlabel("Epochs")
# axs[0, 1].set_ylabel("Train Loss")
# axs[0, 1].set_yticks(np.arange(0.25, 2.51, 0.25))
# axs2 = axs[0, 1].twinx()
# axs2.plot(accuracyList[0], label="Train Accuracy", color="red")
# axs2.legend(loc="upper right")
# axs2.set_ylabel("Train Accuracy")
# axs2.set_yticks(np.arange(0, 1.1, 0.1))


# axs[1, 0].plot(accuracyList[0], label="Train Accuracy")
# axs[1, 0].legend(loc="upper left")
# axs[1, 0].set_xlabel("Epochs")
# axs[1, 0].set_ylabel("Train Accuracy")
# axs[1, 0].set_yticks(np.arange(0, 1.1, 0.1))
# axs3 = axs[1, 0].twinx()
# axs3.plot(accuracyList[1], label="Val Accuracy", color="red")
# axs3.legend(loc="upper right")
# axs3.set_ylabel("Val Accuracy")
# axs3.set_yticks(np.arange(0, 1.1, 0.1))


# axs[1, 1].plot(lossList[1], label="Val Loss")
# axs[1, 1].legend(loc="upper left")
# axs[1, 1].set_xlabel("Epochs")
# axs[1, 1].set_ylabel("Val Loss")
# axs[1, 1].set_yticks(np.arange(0.25, 2.51, 0.25))
# axs4 = axs[1, 1].twinx()
# axs4.plot(accuracyList[1], label="Val Accuracy", color="red")
# axs4.legend(loc="upper right")
# axs4.set_ylabel("Val Accuracy")
# axs4.set_yticks(np.arange(0, 1.1, 0.1))
# plt.tight_layout()
# plt.show()
