In [1]:
import torch
from torch import nn

def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):
    if not torch.is_grad_enabled():
        X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)
    else:
        assert len(X.shape) in (2,4)

        if len(X.shape)==2:
            mean=X.mean(dim=0)
            var=((X-mean)**2).mean(dim=0)
        else:
            mean=X.mean(dim=(0,2,3),keepdim=True)
            var=((X-mean)**2).mean(dim=(0,2,3),keepdim=True)

        X_hat=(X-mean)/torch.sqrt(var+eps)

        moving_mean=momentum*moving_mean+(1.0-momentum)*mean
        moving_var=momentum*moving_var+(1.0-momentum)*var
    Y=gamma*X_hat+beta
    return Y,moving_mean.data,moving_var.data

In [2]:
class BatchNorm(nn.Module):
    def __init__(self,num_features,num_dims):
        super(BatchNorm,self).__init__()
        if num_dims==2:
            shape=(1,num_features)
        else:
            shape=(1,num_features,1,1)
        self.gamma=nn.Parameter(torch.ones(shape))
        self.beta=nn.Parameter(torch.zeros(shape))
        self.moving_mean=torch.zeros(shape)
        self.moving_var=torch.zeros(shape)

    def forward(self,X):
        if self.moving_mean.device!=X.device:
            self.moving_mean=self.moving_mean.to(X.device)
            self.moving_var=self.moving_var.to(X.device)
        Y,self.moving_mean,self.moving_var=batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)
        return Y

In [3]:
net=nn.Sequential(
    nn.Conv2d(1,6,5),
    BatchNorm(6,num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(2,2),
    nn.Conv2d(6,16,5),
    BatchNorm(16,num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(2,2),
    nn.Flatten(),
    nn.Linear(16*4*4,120),
    BatchNorm(120,num_dims=2),
    nn.Sigmoid(),
    nn.Linear(120,84),
    BatchNorm(84,num_dims=2),
    nn.Sigmoid(),
    nn.Linear(84,10)
)

In [4]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.transforms import Resize

In [5]:
transform = transforms.Compose([
    ToTensor()
])
mnist_training = datasets.MNIST(
    root="../data",
    train=True,
    transform=transform,
    download=False
)

mnist_test = datasets.MNIST(
    root="../data",
    train=False,
    transform=transform,
    download=False
)

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 256
lr = 1
epochs = 20

train_dataloader = DataLoader(mnist_training, batch_size=BATCH_SIZE, shuffle=True)

test_dataloader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

# net = xxxNet().to(device)
net=net.to(device)


def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


net.apply(init_weights)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)

for epoch in range(epochs):
    print(
        f"epoch {epoch} \n---------------------"
    )

    for batch, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            loss, current = loss.item(), batch * len(inputs)
            print(f"loss:{loss:>7f} [{current:>5d}/ 60000]")

    with torch.no_grad():
        acc = 0
        total = 0
        for (image, label) in test_dataloader:
            image, label = image.to(device), label.to(device)
            output = net(image)
            _, pred = torch.max(output.data, 1)
            total += label.size(0)
            acc += (pred == label).sum()

        print(f"test: acc {100 * acc / total}")


epoch 0 
---------------------
loss:2.544879 [    0/ 60000]
loss:0.944075 [ 2560/ 60000]
loss:0.318252 [ 5120/ 60000]
loss:0.241130 [ 7680/ 60000]
loss:0.177844 [10240/ 60000]
loss:0.145672 [12800/ 60000]
loss:0.143897 [15360/ 60000]
loss:0.188408 [17920/ 60000]
loss:0.176945 [20480/ 60000]
loss:0.165357 [23040/ 60000]
loss:0.075873 [25600/ 60000]
loss:0.113786 [28160/ 60000]
loss:0.155481 [30720/ 60000]
loss:0.057228 [33280/ 60000]
loss:0.098955 [35840/ 60000]
loss:0.146092 [38400/ 60000]
loss:0.136419 [40960/ 60000]
loss:0.064658 [43520/ 60000]
loss:0.154321 [46080/ 60000]
loss:0.069446 [48640/ 60000]
loss:0.058521 [51200/ 60000]
loss:0.077671 [53760/ 60000]
loss:0.064381 [56320/ 60000]
loss:0.084365 [58880/ 60000]
test: acc 97.20999908447266
epoch 1 
---------------------
loss:0.046129 [    0/ 60000]
loss:0.064040 [ 2560/ 60000]
loss:0.052796 [ 5120/ 60000]
loss:0.077006 [ 7680/ 60000]
loss:0.096489 [10240/ 60000]
loss:0.121403 [12800/ 60000]
loss:0.043085 [15360/ 60000]
loss:0.0996

In [6]:
net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,))

(tensor([3.9467, 2.6483, 2.2780, 1.6499, 2.1272, 3.9457], device='cuda:0',
        grad_fn=<ReshapeAliasBackward0>),
 tensor([-3.3139,  0.1249,  1.6347, -1.7970, -0.5650, -2.7569], device='cuda:0',
        grad_fn=<ReshapeAliasBackward0>))

In [7]:
net=nn.Sequential(
    nn.Conv2d(1,6,5),
    nn.BatchNorm2d(6),
    nn.Sigmoid(),
    nn.MaxPool2d(2,2),
    nn.Conv2d(6,16,5),
    nn.BatchNorm2d(16),
    nn.Sigmoid(),
    nn.MaxPool2d(2,2),
    nn.Flatten(),
    nn.Linear(16*4*4,120),
    nn.BatchNorm1d(120),
    nn.Sigmoid(),
    nn.Linear(120,84),
    nn.BatchNorm1d(84),
    nn.Sigmoid(),
    nn.Linear(84,10)
).to(device)
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)


net.apply(init_weights)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)

for epoch in range(epochs):
    print(
        f"epoch {epoch} \n---------------------"
    )

    for batch, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            loss, current = loss.item(), batch * len(inputs)
            print(f"loss:{loss:>7f} [{current:>5d}/ 60000]")

    with torch.no_grad():
        acc = 0
        total = 0
        for (image, label) in test_dataloader:
            image, label = image.to(device), label.to(device)
            output = net(image)
            _, pred = torch.max(output.data, 1)
            total += label.size(0)
            acc += (pred == label).sum()

        print(f"test: acc {100 * acc / total}")


epoch 0 
---------------------
loss:2.610928 [    0/ 60000]
loss:0.774725 [ 2560/ 60000]
loss:0.248046 [ 5120/ 60000]
loss:0.226470 [ 7680/ 60000]
loss:0.281178 [10240/ 60000]
loss:0.134376 [12800/ 60000]
loss:0.177315 [15360/ 60000]
loss:0.128242 [17920/ 60000]
loss:0.103134 [20480/ 60000]
loss:0.120544 [23040/ 60000]
loss:0.104558 [25600/ 60000]
loss:0.129680 [28160/ 60000]
loss:0.120950 [30720/ 60000]
loss:0.126245 [33280/ 60000]
loss:0.091974 [35840/ 60000]
loss:0.062910 [38400/ 60000]
loss:0.127991 [40960/ 60000]
loss:0.111706 [43520/ 60000]
loss:0.085169 [46080/ 60000]
loss:0.061671 [48640/ 60000]
loss:0.078981 [51200/ 60000]
loss:0.091333 [53760/ 60000]
loss:0.061117 [56320/ 60000]
loss:0.115154 [58880/ 60000]
test: acc 98.0199966430664
epoch 1 
---------------------
loss:0.081720 [    0/ 60000]
loss:0.100422 [ 2560/ 60000]
loss:0.109519 [ 5120/ 60000]
loss:0.085373 [ 7680/ 60000]
loss:0.043798 [10240/ 60000]
loss:0.061911 [12800/ 60000]
loss:0.110254 [15360/ 60000]
loss:0.03308

KeyboardInterrupt: 