## 重み初期化

In [13]:
import torch
from torch import nn, optim
from torch.nn import init
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from functools import partial
import matplotlib.pyplot as plt
import numpy as np

%load_ext autoreload
%autoreload 2
import utils

### Xavierの初期化

In [24]:
# Xavier (Glorot) initialization demo: push a random batch through 50 purely
# linear 30 -> 30 layers and check that the activation scale stays near 1.
n_out, n_in = 30, 30
X = torch.randn(60, n_in)
fan_sum = n_in + n_out
# Normal variant: W ~ N(0, 2 / (n_in + n_out))
std = torch.tensor(2. / fan_sum).sqrt()
# Uniform variant: W ~ U(-sqrt(6 / (n_in + n_out)), +sqrt(6 / (n_in + n_out)))
limit = torch.tensor(6. / fan_sum).sqrt()

for step in range(50):
    # Normal alternative: W = torch.randn(n_out, n_in) * std
    # Uniform variant in use: rescale U[0, 1) samples to [-limit, limit).
    W = torch.rand(n_out, n_in) * 2 * limit - limit
    X = X @ W.T
print(X.mean(), X.std())
X

tensor(-0.0060) tensor(1.1168)


tensor([[-0.2391, -0.5007, -0.6922,  ...,  0.3678, -0.3084, -0.7777],
        [ 0.2905,  0.2817,  0.3490,  ..., -0.2360,  0.2540,  0.5578],
        [ 0.2290,  0.0092, -0.0281,  ..., -0.0654,  0.2377,  0.3167],
        ...,
        [-0.6328, -1.2054, -1.6357,  ...,  0.9476, -0.7491, -1.8050],
        [-0.6814,  0.1882,  0.3912,  ...,  0.0442, -0.9129, -1.1293],
        [-0.2482,  1.5669,  2.2883,  ..., -0.8815, -0.5745,  0.3754]])

### Kaiming初期化

In [28]:
# Kaiming (He) initialization demo: 50 layers of "linear 30 -> 30 followed by
# ReLU", to see how the activation statistics behave with He-scaled weights.
n_out, n_in = 30, 30
X = torch.randn(60, n_in)
# Normal variant: W ~ N(0, 2 / n_in)
std = torch.tensor(2. / n_in).sqrt()
# Uniform variant: W ~ U(-sqrt(6 / n_in), +sqrt(6 / n_in))
limit = torch.tensor(6. / n_in).sqrt()

for step in range(50):
    # Normal variant in use.
    # Uniform alternative: W = torch.rand(n_out, n_in)*2*limit - limit
    W = torch.randn(n_out, n_in) * std
    # Linear map, then ReLU expressed as a clamp at zero.
    X = (X @ W.T).clamp(min = 0)
print(X.mean(), X.std())
X

tensor(0.0352) tensor(0.0544)


tensor([[0.0000, 0.0276, 0.0000,  ..., 0.0281, 0.0264, 0.0576],
        [0.0000, 0.1291, 0.0000,  ..., 0.1085, 0.1284, 0.2367],
        [0.0000, 0.1473, 0.0000,  ..., 0.1429, 0.1472, 0.2953],
        ...,
        [0.0000, 0.0892, 0.0000,  ..., 0.0884, 0.0900, 0.1847],
        [0.0000, 0.0362, 0.0000,  ..., 0.0381, 0.0312, 0.0750],
        [0.0000, 0.0513, 0.0000,  ..., 0.0625, 0.0441, 0.1235]])

### PyTorchのKaiming初期化

In [10]:
# Compare the weight statistics of a freshly constructed Conv2d (PyTorch's
# default initialization) with the same weights after Kaiming-normal init.
conv = nn.Conv2d(1, 8, kernel_size = 3, stride = 2, padding = 1)
# Default init: std = np.sqrt(1 / (3 * n_in * k * k))  -> sqrt(1/27) ~= 0.19 here
print(conv.weight.mean(), conv.weight.std())
# Use the in-place, underscore-suffixed initializer: the un-suffixed
# init.kaiming_normal is deprecated and emits a warning when called.
init.kaiming_normal_(conv.weight)
# Kaiming init: std = np.sqrt(2 / (n_in * k * k))  -> sqrt(2/9) ~= 0.47 here
print(conv.weight.mean(), conv.weight.std())

tensor(-0.0162, grad_fn=<MeanBackward0>) tensor(0.2042, grad_fn=<StdBackward0>)
tensor(-0.0536, grad_fn=<MeanBackward0>) tensor(0.4938, grad_fn=<StdBackward0>)


  init.kaiming_normal(conv.weight)


### batch norm + Kaiming init

In [21]:
def get_conv_model():
    """Build a small CNN classifier with batch normalization.

    Four stages of Conv2d(kernel 3, stride 2, padding 1) -> BatchNorm2d -> ReLU
    grow the channels 1 -> 4 -> 8 -> 16 -> 32. For a 1x28x28 input the feature
    maps are 4x14x14, 8x7x7, 16x4x4 and 32x2x2. Global average pooling and
    Flatten reduce that to a 32-vector, and a final Linear layer maps it to 10
    outputs.

    Returns:
        nn.Sequential: the assembled, freshly initialized model.
    """
    widths = [1, 4, 8, 16, 32]
    stack = []
    # One stride-2 stage per consecutive (in, out) channel pair.
    for c_in, c_out in zip(widths[:-1], widths[1:]):
        stack.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
        stack.append(nn.BatchNorm2d(c_out))
        stack.append(nn.ReLU())
    # Head: 32x2x2 -> 32x1x1 -> 32 -> 10
    stack.extend([
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(32, 10),
    ])
    return nn.Sequential(*stack)
conv_model = get_conv_model()

In [22]:
# Re-initialize the weights of every Linear / Conv2d layer in conv_model with
# Kaiming-normal init. Other layers (BatchNorm, ReLU, pooling, ...) and all
# biases are left as constructed.
kaiming_targets = (nn.Linear, nn.Conv2d)
for layer in conv_model:
    if not isinstance(layer, kaiming_targets):
        continue
    init.kaiming_normal_(layer.weight)

In [23]:
# Data preparation: FashionMNIST images are converted to tensors and then
# normalized per channel as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Train and validation splits share the root directory, download flag and transform.
make_fmnist = partial(
    torchvision.datasets.FashionMNIST,
    './fmnist_data',
    download=True,
    transform=transform,
)
train_dataset = make_fmnist(train=True)
val_dataset = make_fmnist(train=False)

# Both loaders use the same batch size and worker count; only training is shuffled.
make_loader = partial(DataLoader, batch_size=1024, num_workers=4)
train_loader = make_loader(train_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)

# Plain SGD over all parameters of conv_model.
opt = optim.SGD(conv_model.parameters(), lr=0.6)

In [24]:
# Wrap conv_model with the project-local ActivationStatistics helper from utils.
# NOTE(review): presumably this attaches hooks that record per-layer activation
# statistics during training; its implementation is not visible here, so confirm.
act_stats = utils.ActivationStatistics(conv_model)

In [25]:
# Train conv_model with the project-local utils.learn helper, using SGD (opt)
# and cross-entropy loss; it returns three values unpacked as training losses,
# validation losses and validation accuracies.
# NOTE(review): the trailing 3 is presumably the number of epochs (the logged
# output shows epochs 0-2); confirm against utils.learn's signature.
train_losses, val_losses, val_accuracies = utils.learn(conv_model, train_loader, val_loader, opt, F.cross_entropy, 3)

                                                          

epoch: 0: train error: 0.9445504849239931, validation error: 0.6044596910476685, validation accuracy: 0.7850247144699096


                                                          

epoch: 1: train error: 0.502600534992703, validation error: 0.49538477063179015, validation accuracy: 0.8227080702781677


                                                          

epoch: 2: train error: 0.43476225814576874, validation error: 0.47066461741924287, validation accuracy: 0.8229113519191742


