## MLP

In [1]:
import torch
from sklearn import datasets
import matplotlib.pyplot as plt
from torch.nn import functional as F
from sklearn.model_selection import train_test_split

In [2]:
## Data preparation
# 1. Load the scikit-learn digits dataset (images are 8x8, see the shapes printed below).
dataset = datasets.load_digits()
images = dataset['images']
target = dataset['target']
# Split into training and validation sets (80% / 20%); the fixed random_state
# makes the split reproducible across runs.
X_train, X_val, y_train, y_val = train_test_split(images, target, test_size=0.2, random_state=42)
print(X_train.shape, y_train.shape)
print(X_val.shape, y_val.shape)
# 2. Preprocessing
# 2-1. One-hot encode the labels and flatten every 8x8 image into a 64-dim vector.
# F.one_hot only accepts int64 index tensors. torch.tensor() inherits the NumPy
# dtype, which is int32 on platforms where NumPy's default int is 32-bit, so the
# dtype is forced to torch.long to keep this cell portable.
y_train = F.one_hot(torch.tensor(y_train, dtype=torch.long), num_classes=10)
X_train = torch.tensor(X_train, dtype=torch.float32).reshape(-1, 64)

y_val = F.one_hot(torch.tensor(y_val, dtype=torch.long), num_classes=10)
X_val = torch.tensor(X_val, dtype=torch.float32).reshape(-1, 64)

# 2-2. Standardize the pixel values.
# The mean/std are scalars computed over the whole training set only; the same
# statistics are reused for the validation set so no validation information
# leaks into preprocessing.
X_train_mean = X_train.mean()
X_train_std = X_train.std()
X_train = (X_train - X_train_mean) / X_train_std
X_val = (X_val - X_train_mean) / X_train_std

(1437, 8, 8) (1437,)
(360, 8, 8) (360,)


### スクラッチ実装 (順伝搬のみ)

In [3]:
# Layer sizes: m training samples with n input features each,
# nh hidden units, class_num output classes.
m, n = X_train.size()
nh, class_num = 30, 10

# Weights are drawn from a standard normal and stored as (out_features, in_features).
# W1 is sampled before W2 so the random stream is consumed in the same order.
W1 = torch.randn(nh, n, requires_grad=True)            # nh x n
W2 = torch.randn(class_num, nh, requires_grad=True)    # class_num x nh

# Biases start at zero and are kept 2-D so they broadcast over the batch dimension.
b1 = torch.zeros(1, nh, requires_grad=True)            # 1 x nh
b2 = torch.zeros(1, class_num, requires_grad=True)     # 1 x class_num

In [4]:
def linear(X, W, b):
    """Affine layer: multiply X by the transpose of W, then add the bias b.

    W is laid out as (out_features, in_features), hence the transpose.
    b is added with ordinary broadcasting.
    """
    Z = X @ W.T
    return Z + b

In [5]:
def relu(Z):
    """Rectified linear unit: element-wise max(Z, 0), returned as a new tensor."""
    return torch.clamp(Z, min=0.)

In [6]:
def softmax(x):
    """Numerically stable softmax over the last dimension.

    The per-row maximum is subtracted before exponentiating so that exp()
    cannot overflow to inf for large inputs; the shift cancels out in the
    ratio, so the result is unchanged.
    """
    row_max = x.max(dim=-1, keepdim=True).values
    exp_shifted = (x - row_max).exp()
    # Small epsilon guards the division against a zero denominator.
    normalizer = exp_shifted.sum(dim=-1, keepdim=True) + 1e-10
    return exp_shifted / normalizer

In [7]:
def model(X):
    """Forward pass of the two-layer MLP: linear -> ReLU -> linear -> softmax.

    Uses the module-level parameters W1, b1 (hidden layer) and W2, b2
    (output layer) and returns the softmax output for each row of X.
    """
    hidden = relu(linear(X, W1, b1))
    logits = linear(hidden, W2, b2)
    return softmax(logits)

In [8]:
# Run the forward pass on the full training set; each row of the result is the
# softmax output of `model` for one sample.
y_train_pred = model(X_train)

In [9]:
y_train_pred
# y_train_pred.sum(dim=1)  # every row sums to 1

tensor([[1.0000e+00, 1.7962e-41, 2.2879e-22,  ..., 1.1210e-44, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.4464e-30, 1.0893e-13,  ..., 6.1829e-40, 2.5243e-39,
         2.8841e-33],
        [9.9998e-01, 1.0224e-10, 1.2434e-07,  ..., 3.4712e-35, 1.5989e-25,
         0.0000e+00],
        ...,
        [9.9919e-01, 4.7314e-32, 7.2059e-14,  ..., 2.0465e-33, 3.3547e-40,
         8.3996e-34],
        [1.0000e+00, 8.1733e-17, 6.2011e-16,  ..., 9.8235e-36, 6.5890e-40,
         1.4013e-45],
        [2.1944e-32, 1.5518e-33, 5.1905e-27,  ..., 0.0000e+00, 2.8026e-45,
         0.0000e+00]], grad_fn=<DivBackward0>)