<a href="https://colab.research.google.com/github/Tsung-Hung/dummy-git/blob/master/20230316.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
# https://stackoverflow.com/questions/68478856/pytorch-batchnorm2d-calculation
# https://discuss.pytorch.org/t/how-to-use-scripting-with-custom-batchnorm/85375/6
# https://blog.csdn.net/qq_39208832/article/details/117930625
# https://yichengsu.github.io/2019/12/pytorch-batchnorm-freeze/
# pytorch BatchNorm参数详解，计算过程 - 水木清扬 - 博客园
# https://www.zhihu.com/question/487766088

In [1]:
from torch.nn.modules.batchnorm import BatchNorm1d
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
torch.manual_seed(0)
batch, sentence_length, embedding_dim = 2, 3, 4
# Random input of shape (batch, sentence_length, embedding_dim) = (2, 3, 4).
embedding = torch.randn(batch, sentence_length, embedding_dim)
print(embedding)
# Reference LayerNorm over the last (embedding) dimension; its state_dict printed below
# shows the default affine parameters weight=1, bias=0.
layer_norm = nn.LayerNorm(embedding_dim)

# Run the module: `origin` is the reference output that the hand-rolled
# BatchNorm-based variants in the next cell are compared against.
origin = layer_norm(embedding)
print(origin)
print(layer_norm.state_dict())

tensor([[[-1.1258, -1.1524, -0.2506, -0.4339],
         [ 0.8487,  0.6920, -0.3160, -2.1152],
         [ 0.4681, -0.1577,  1.4437,  0.2660]],

        [[ 0.1665,  0.8744, -0.1435, -0.1116],
         [ 0.9318,  1.2590,  2.0050,  0.0537],
         [ 0.6181, -0.4128, -0.8411, -2.3160]]])
tensor([[[-0.9539, -1.0196,  1.2137,  0.7598],
         [ 0.9075,  0.7747, -0.0791, -1.6031],
         [-0.0629, -1.1288,  1.5988, -0.4070]],

        [[-0.0732,  1.6553, -0.8299, -0.7521],
         [-0.1864,  0.2808,  1.3460, -1.4403],
         [ 1.2863,  0.3084, -0.0978, -1.4969]]],
       grad_fn=<NativeLayerNormBackward0>)
OrderedDict([('weight', tensor([1., 1., 1., 1.])), ('bias', tensor([0., 0., 0., 0.]))])


In [3]:
def Layer_norm(x, eps=1e-5):
    '''
    LayerNorm over the last dim of ``x``, emulated with a fresh BatchNorm2d.

    Every (batch, token) row becomes its own BatchNorm channel holding
    embedding_dim values (input viewed as (1, batch*sentence_length, embedding_dim, 1)),
    so each channel's batch statistics are exactly that token's mean/variance.
    The new module starts with weight=1, bias=0, so no effective affine is applied.

    Args:
        x: tensor of shape (batch, sentence_length, embedding_dim).
        eps: added to the variance before the square root, as in nn.LayerNorm.

    Returns:
        Tensor with the same shape as ``x``.

    Example:
        test = Layer_norm(embedding)
    '''
    batch, sentence_length, embedding_dim = x.shape
    x = x.reshape(1, batch * sentence_length, embedding_dim, 1)
    # A freshly built module is in training mode, so it normalizes with the batch
    # statistics; the running stats it updates are discarded together with it.
    layer_norm = nn.BatchNorm2d(x.shape[1], eps=eps, track_running_stats=True, affine=True)
    return layer_norm(x).reshape(batch, sentence_length, embedding_dim)

def Layer_norm2(x):
    '''
    LayerNorm over the last dim of ``x``, emulated with a fresh BatchNorm1d.

    The input is viewed as a single "batch" with batch*sentence_length channels,
    each holding embedding_dim values, so each channel's batch statistics are
    the per-token mean/variance.

    Example:
        test = Layer_norm2(embedding)
    '''
    b, s, d = x.shape
    bn = nn.BatchNorm1d(b * s, track_running_stats=True, affine=True, momentum=0)

    # Pin the running statistics; with momentum=0 the training-mode forward
    # below leaves them untouched.
    bn.running_mean.fill_(0)
    bn.running_var.fill_(1)

    normalized = bn(x.reshape(1, b * s, d))
    return normalized.reshape(b, s, d)

class LayerNorm2(nn.Module):
    '''
    Hand-written LayerNorm over the last dimension.

    Unlike nn.LayerNorm, eps is added to the standard deviation (``std + eps``,
    default 1e-6) rather than inside the square root, so results match
    nn.LayerNorm only approximately. ``weight``/``bias`` start at 1 / 0 and are
    frozen (requires_grad=False).

    Example:
        layer_norm2 = LayerNorm2(embedding_dim)
        test = layer_norm2(embedding)
    '''
    def __init__(self, num_features, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=False)
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` of shape (..., num_features) over its last dim, then scale and shift."""
        mean = x.mean(-1, keepdim=True)
        # Biased (population) standard deviation, as LayerNorm uses.
        std = torch.sqrt(x.var(-1, unbiased=False, keepdim=True))
        normalized = (x - mean) / (std + self.eps)
        # Apply the affine parameters; the previous per-token loop never used them.
        return normalized * self.weight + self.bias

class LayerNorm_sim(nn.Module):
    '''
    LayerNorm over the last dim, simulated with a BatchNorm1d applied per sample.

    Provided by "Gina".
    ``num_features`` must equal the sequence length (dim 1 of the input), not the
    embedding size. Each sample of shape (seq_len, embedding_dim) is fed to
    BatchNorm1d as a batch of one with seq_len channels, so every token is
    normalized over its embedding values.

    Example:
        layer_norm_sim = LayerNorm_sim(sentence_length)
        # Activate module
        test = layer_norm_sim(embedding)
    '''
    def __init__(self, num_features):
        super().__init__()
        # momentum=1: after a forward pass the running stats hold the statistics
        # of the last sample processed.
        self.BN = nn.BatchNorm1d(num_features, momentum=1, affine=True)

    def forward(self, x):
        # One BN call per sample keeps the statistics from mixing across the batch.
        return torch.cat([self.BN(sample.unsqueeze(0)) for sample in x], dim=0)

# Check the BatchNorm-based variants against the nn.LayerNorm reference `origin`.
layer_norm_sim = LayerNorm_sim(sentence_length)  # one BN channel per token
test = layer_norm_sim(embedding)
assert torch.allclose(test, origin, atol=1e-5)

test = Layer_norm2(embedding)
assert torch.allclose(test, origin, atol=1e-5)
test

tensor([[[-0.9539, -1.0196,  1.2137,  0.7598],
         [ 0.9075,  0.7747, -0.0791, -1.6031],
         [-0.0629, -1.1288,  1.5988, -0.4070]],

        [[-0.0732,  1.6553, -0.8299, -0.7521],
         [-0.1864,  0.2808,  1.3460, -1.4403],
         [ 1.2863,  0.3084, -0.0978, -1.4969]]],
       grad_fn=<ReshapeAliasBackward0>)

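# LayerNorm baseline

Train a small model that uses `nn.LayerNorm(64)`. Its losses are the reference for the BatchNorm-based replacements below.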

In [4]:
torch.manual_seed(0)
# Define the SimpleNet model
class SimpleNet(nn.Module):
    """Baseline model: LayerNorm over the embedding dim, then a linear head to one output.

    Args:
        num_classes: unused; kept for signature compatibility with the other
            SimpleNet variants in this notebook.
        seq_len: tokens per sample (default 10).
        embed_dim: embedding size normalized by the LayerNorm (default 64).
    """
    def __init__(self, num_classes=10, seq_len=10, embed_dim=64):
        super(SimpleNet, self).__init__()
        self.ln1 = nn.LayerNorm(embed_dim, elementwise_affine=True)
        # The head sees the flattened (seq_len * embed_dim) features of each sample.
        self.fc1 = nn.Linear(seq_len * embed_dim, 1)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim) -> (batch, 1)
        x = self.ln1(x)
        x = x.reshape(-1, self.fc1.in_features)
        return self.fc1(x)

# Generate a simulated dataset: 3 samples of shape (10, 64), integer targets in [0, 10)
inputs = torch.rand(3, 10, 64)
raw_inputs = inputs
labels = torch.randint(low=0, high=10, size=(3,))
inputs = torch.Tensor(inputs)
labels = torch.Tensor(labels).long()
train_dataset = torch.utils.data.TensorDataset(inputs, labels)

# Create a DataLoader to load and batch the data (a single batch holding all 3 samples)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=3, shuffle=False)

# Define the loss function and the optimizer
net = SimpleNet()
criterion = nn.L1Loss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

NUM_EPOCHS = 25

# Train the model
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        outputs = net(inputs)

        # NOTE(review): outputs is (3, 1) and labels is (3,), so L1Loss broadcasts them
        # to (3, 3) and emits a size-mismatch warning. Left unchanged so the losses stay
        # comparable with the BatchNorm cells below; the proper fix would be
        # criterion(outputs.squeeze(1), labels.float()).
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, NUM_EPOCHS, running_loss / len(train_loader)))
    print("===========================")

print('\n', net.ln1.state_dict(), '\n')

# Inference on a fresh random sample (seed 1 is reused by the BatchNorm cells below).
torch.manual_seed(1)
test_sample = torch.randn(1, 10, 64)
net.eval()
output = net(test_sample)
print('test results: ', output)

Epoch [1/50], Loss: 7.3685
Epoch [2/50], Loss: 7.1339
Epoch [3/50], Loss: 6.6881
Epoch [4/50], Loss: 6.0521
Epoch [5/50], Loss: 5.2450
Epoch [6/50], Loss: 4.2836
Epoch [7/50], Loss: 3.1832
Epoch [8/50], Loss: 1.9571
Epoch [9/50], Loss: 1.3168
Epoch [10/50], Loss: 1.1649
Epoch [11/50], Loss: 1.8080
Epoch [12/50], Loss: 2.6305
Epoch [13/50], Loss: 3.1412
Epoch [14/50], Loss: 3.3626
Epoch [15/50], Loss: 3.3238
Epoch [16/50], Loss: 3.0515
Epoch [17/50], Loss: 2.5698
Epoch [18/50], Loss: 1.9009
Epoch [19/50], Loss: 1.2984
Epoch [20/50], Loss: 1.1313
Epoch [21/50], Loss: 1.3037
Epoch [22/50], Loss: 1.5062
Epoch [23/50], Loss: 1.7159
Epoch [24/50], Loss: 1.8933
Epoch [25/50], Loss: 1.8772

 OrderedDict([('weight', tensor([0.9952, 0.9935, 0.9953, 0.9907, 0.9962, 0.9912, 0.9938, 0.9912, 0.9970,
        0.9944, 0.9961, 0.9965, 0.9958, 0.9965, 0.9892, 0.9986, 0.9935, 0.9935,
        0.9964, 0.9940, 0.9986, 0.9895, 0.9967, 0.9966, 0.9941, 0.9950, 0.9945,
        0.9918, 0.9935, 0.9931, 0.9944, 0.9

  return F.l1_loss(input, target, reduction=self.reduction)


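# Replace LayerNorm with BatchNorm

`BatchNorm1d(10)` treats each of the 10 tokens as a channel. Feeding it one sample of shape (1, 10, 64) at a time makes it normalize every token over its 64 embedding values, which is exactly what `nn.LayerNorm(64)` does. `gamma`/`beta` replace LayerNorm's affine parameters.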

In [5]:
torch.manual_seed(0)
# Define the SimpleNet model
class SimpleNet(nn.Module):
    """LayerNorm emulated with BatchNorm1d, followed by a linear head to one output.

    ``bn1`` is a non-affine BatchNorm1d with one channel per token (10). Applied to
    a single sample of shape (1, 10, 64), it normalizes each token over its 64
    embedding values, as nn.LayerNorm(64) does. This holds only while bn1 is in
    training mode; in eval mode it switches to its running statistics.
    ``gamma``/``beta`` play the role of LayerNorm's elementwise affine.
    """
    def __init__(self, num_classes=10):
        super(SimpleNet, self).__init__()
        # momentum=1: running stats are overwritten by the statistics of the last sample seen.
        self.bn1 = BatchNorm1d(10, track_running_stats=True, affine=False, momentum=1)
        self.fc1 = nn.Linear(640, 1)

        self.gamma = nn.Parameter(torch.ones(64))
        self.beta = nn.Parameter(torch.zeros(64))

    def Layer_norm1(self, x):
        """Alternative emulation that folds (batch, seq) into bn1's channel dim.

        bn1 has 10 channels, so this only lines up when batch * seq_len == 10
        (i.e. batch == 1 here). The forward pass uses ``loop_layernorm1`` instead.
        """
        batch, sentence_length, embedding_dim = x.shape
        x = x.reshape(1, batch * sentence_length, embedding_dim)
        output = self.bn1(x).reshape(batch, sentence_length, embedding_dim)

        # apply gamma and beta
        return self.gamma * output + self.beta

    def loop_layernorm1(self, x):
        """Normalize each sample separately with bn1, then apply gamma/beta.

        Returns a tensor with the same (batch, seq, dim) shape as ``x``.
        """
        normalized = [self.bn1(sample) for sample in torch.split(x, 1, dim=0)]
        # cat (not stack): each piece is already (1, seq, dim); stacking added a
        # spurious dimension and produced (batch, 1, seq, dim).
        concate = torch.cat(normalized, dim=0)
        return self.gamma * concate + self.beta

    def forward(self, x):
        x = self.loop_layernorm1(x)  # input (batch, 10, 64)
        x = x.reshape(-1, 640)
        return self.fc1(x)
# Generate a simulated dataset (same seed and draw order as the LayerNorm cell, so the
# data and the fc1 initialization match and the losses can be compared directly)
inputs = torch.rand(3, 10, 64)
raw_inputs = inputs
labels = torch.randint(low=0, high=10, size=(3,))
inputs = torch.Tensor(inputs)
labels = torch.Tensor(labels).long()
train_dataset = torch.utils.data.TensorDataset(inputs, labels)

# Create a DataLoader to load and batch the data (a single batch holding all 3 samples)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=3, shuffle=False)

# Define the loss function and the optimizer
net = SimpleNet()
criterion = nn.L1Loss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

NUM_EPOCHS = 25

# Train the model
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        outputs = net(inputs)

        # NOTE(review): outputs (3, 1) vs labels (3,) broadcast to (3, 3) inside L1Loss;
        # kept as in the LayerNorm cell so the two loss curves stay comparable.
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, NUM_EPOCHS, running_loss / len(train_loader)))
    print("===========================")

net.bn1.state_dict()

Epoch [1/50], Loss: 7.3685
Epoch [2/50], Loss: 7.1339
Epoch [3/50], Loss: 6.6881
Epoch [4/50], Loss: 6.0521
Epoch [5/50], Loss: 5.2450
Epoch [6/50], Loss: 4.2836
Epoch [7/50], Loss: 3.1832
Epoch [8/50], Loss: 1.9571
Epoch [9/50], Loss: 1.3168
Epoch [10/50], Loss: 1.1649
Epoch [11/50], Loss: 1.8080
Epoch [12/50], Loss: 2.6305
Epoch [13/50], Loss: 3.1412
Epoch [14/50], Loss: 3.3626
Epoch [15/50], Loss: 3.3238
Epoch [16/50], Loss: 3.0515
Epoch [17/50], Loss: 2.5698
Epoch [18/50], Loss: 1.9009
Epoch [19/50], Loss: 1.2984
Epoch [20/50], Loss: 1.1313
Epoch [21/50], Loss: 1.3037
Epoch [22/50], Loss: 1.5062
Epoch [23/50], Loss: 1.7159
Epoch [24/50], Loss: 1.8933
Epoch [25/50], Loss: 1.8772


OrderedDict([('running_mean',
              tensor([0.4767, 0.5704, 0.4466, 0.5212, 0.4753, 0.5117, 0.5128, 0.5090, 0.5241,
                      0.5205])),
             ('running_var',
              tensor([0.0769, 0.0829, 0.0836, 0.0928, 0.0749, 0.0774, 0.0873, 0.0764, 0.0794,
                      0.0868])),
             ('num_batches_tracked', tensor(75))])

In [6]:
def set_bn_training(m):
    """Switch ``m`` to training mode if it is a BatchNorm1d; leave any other module alone.

    Meant for ``net.apply(...)`` so only the BatchNorm layers use batch statistics
    while the rest of the model stays in eval mode.
    """
    if not isinstance(m, nn.BatchNorm1d):
        return
    m.train()

torch.manual_seed(1)
test_sample = torch.randn(1, 10, 64)
net.eval()
print(net.bn1.training)
print("Before Running", net.bn1.state_dict())

# Put only the BatchNorm layers back into training mode so bn1 normalizes with the
# test sample's own statistics (LayerNorm behaviour) instead of the running stats.
# Side effect: with momentum=1 this forward pass overwrites bn1's running stats.
net.apply(set_bn_training)
output = net(test_sample)
print("After Running", net.bn1.state_dict())
print(net.bn1.training)

print('\n', '-----------------------')
print(net.bn1.state_dict())
print(output)

False
After Running OrderedDict([('running_mean', tensor([0.4767, 0.5704, 0.4466, 0.5212, 0.4753, 0.5117, 0.5128, 0.5090, 0.5241,
        0.5205])), ('running_var', tensor([0.0769, 0.0829, 0.0836, 0.0928, 0.0749, 0.0774, 0.0873, 0.0764, 0.0794,
        0.0868])), ('num_batches_tracked', tensor(75))])
Before Running OrderedDict([('running_mean', tensor([-0.1130, -0.0192,  0.1874,  0.1517, -0.0068, -0.0377, -0.0822, -0.1019,
         0.1284, -0.0743])), ('running_var', tensor([1.0302, 0.9497, 1.2041, 1.1384, 1.1395, 0.9414, 0.9778, 1.0286, 0.6879,
        1.3984])), ('num_batches_tracked', tensor(76))])
True

 -----------------------
OrderedDict([('running_mean', tensor([-0.1130, -0.0192,  0.1874,  0.1517, -0.0068, -0.0377, -0.0822, -0.1019,
         0.1284, -0.0743])), ('running_var', tensor([1.0302, 0.9497, 1.2041, 1.1384, 1.1395, 0.9414, 0.9778, 1.0286, 0.6879,
        1.3984])), ('num_batches_tracked', tensor(76))])
tensor([[-1.1568]], grad_fn=<AddmmBackward0>)


In [7]:
# Test Batchnorm
# Reproduce the previous cell's output by hand. Normalize each of the 10 tokens of the
# same test sample (seed 1) over its 64 embedding values using the biased variance and
# bn1's eps, then apply net.gamma / net.beta and the linear head.
torch.manual_seed(1)
test_sample = torch.randn(1, 10, 64)

mean = torch.mean(test_sample, dim=2).reshape(1,10,1)
var = torch.var(test_sample, dim=2, unbiased=False).reshape(1,10,1)

norm = (test_sample-mean)/torch.sqrt(var + net.bn1.eps)
# Alternative using bn1's running stats, which (momentum=1) hold this sample's statistics
# after the previous cell. The 63/64 factor presumably converts the unbiased running_var
# back to the biased variance used for normalization.
#norm = (test_sample-net.bn1.running_mean.reshape(1,10,1))/torch.sqrt(net.bn1.running_var.reshape(1,10,1)*63/64 + net.bn1.eps)

#test = net.bn1(test_sample)
results = net.gamma * norm + net.beta
net.fc1(results.reshape(-1,640))
     

tensor([[-1.1568]], grad_fn=<AddmmBackward0>)

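# Replace LayerNorm with BatchNorm (2): custom module

Same idea as above, but with a `CustomBatchNorm` that calls `F.batch_norm(..., training=True)`. It therefore keeps using batch statistics even after `net.eval()`, and no manual switch back to training mode is needed at inference.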

In [13]:
torch.manual_seed(0)
# Define the custom batchnorm
class CustomBatchNorm(nn.Module):
    """Non-affine BatchNorm1d that always normalizes with the current batch statistics.

    Intended to be applied to one sample of shape (1, C, L) at a time. Each channel
    is then normalized over its L values, i.e. a LayerNorm over the last dim when
    the channels are tokens. Unlike nn.BatchNorm1d, ``forward`` ignores
    ``self.training`` and uses batch statistics in eval mode too.

    Args:
        num_features: number of channels C.
        eps: value added to the variance for numerical stability.
        momentum: running-stat update factor; 1 means the running stats are replaced
            by the statistics of the most recent batch.
    """
    def __init__(self, num_features, eps=1e-5, momentum=1):
        super(CustomBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # No learnable gamma/beta here: the caller applies the affine transform.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward_(self, x):
        """Manual reference implementation (``forward`` does not call it).

        Normalizes every (sample, channel) row of ``x`` (N, C, L) over its L values
        using the biased variance; no affine transform is applied. The running stats
        are updated from per-channel statistics over dims (0, 2), without autograd
        history, so the buffers keep their registered (C,) shape.
        """
        # Per-(sample, channel) statistics used for the normalization itself.
        batch_mean = x.mean(dim=2)
        batch_var = x.var(dim=2, unbiased=False)

        # Update running statistics outside autograd so the buffers do not keep the
        # graph alive, reducing over batch and length so they stay shape (C,).
        with torch.no_grad():
            channel_mean = x.mean(dim=(0, 2))
            channel_var = x.var(dim=(0, 2), unbiased=False)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * channel_mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * channel_var

        # Normalize input
        return (x - batch_mean.unsqueeze(dim=2)) / torch.sqrt(batch_var.unsqueeze(dim=2) + self.eps)

    def forward(self, x):
        # training=True on purpose: always normalize with the current batch's statistics,
        # even under net.eval(), so the module keeps behaving like LayerNorm at inference.
        # F.batch_norm also updates running_mean / running_var in place with `momentum`
        # (running_var receives the unbiased variance estimate).
        return F.batch_norm(x, running_mean=self.running_mean, running_var=self.running_var,
                            training=True, momentum=self.momentum, eps=self.eps)

# Define the SimpleNet model
class SimpleNet(nn.Module):
    """Linear head on top of a LayerNorm emulated by CustomBatchNorm.

    ``bn1`` has one channel per token (10). Each sample is fed to it on its own,
    so every token is normalized over its 64 embedding values. ``gamma`` and
    ``beta`` provide the elementwise affine of a LayerNorm(64).
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.bn1 = CustomBatchNorm(10)
        self.fc1 = nn.Linear(640, 1)
        self.gamma = nn.Parameter(torch.ones(64))
        self.beta = nn.Parameter(torch.zeros(64))

    def Layer_norm1(self, x):
        """Normalize by folding (batch, seq) into bn1's channel dim, then apply gamma/beta.

        bn1's running buffers have 10 entries, so this only lines up when
        batch * seq_len == 10 (i.e. batch == 1 here).
        """
        b, s, d = x.shape
        flat = self.bn1(x.reshape(1, b * s, d))
        return self.gamma * flat.reshape(b, s, d) + self.beta

    def loop_layernorm1(self, x):
        """Run bn1 on each sample separately and apply gamma/beta; shape is preserved."""
        pieces = torch.cat([self.bn1(piece) for piece in torch.split(x, 1, dim=0)])
        return self.gamma * pieces + self.beta

    def forward(self, x):
        # (batch, 10, 64) -> normalized -> flattened (batch, 640) -> (batch, 1)
        return self.fc1(self.loop_layernorm1(x).reshape(-1, 640))

# Generate a simulated dataset (same seed and draw order as the previous cells, so the
# data and the fc1 initialization match)
inputs = torch.rand(3, 10, 64)
raw_inputs = inputs
labels = torch.randint(low=0, high=10, size=(3,))
inputs = torch.Tensor(inputs)
labels = torch.Tensor(labels).long()
train_dataset = torch.utils.data.TensorDataset(inputs, labels)

# Create a DataLoader to load and batch the data (a single batch holding all 3 samples)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=3, shuffle=False)

# Define the loss function and the optimizer
net = SimpleNet()
criterion = nn.L1Loss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

NUM_EPOCHS = 2

# Train the model
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        outputs = net(inputs)

        # NOTE(review): outputs (3, 1) vs labels (3,) broadcast to (3, 3) inside L1Loss;
        # kept as in the earlier cells so the losses stay comparable.
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, NUM_EPOCHS, running_loss / len(train_loader)))
    print("===========================")

net.bn1.state_dict()

Epoch [1/50], Loss: 7.3685
Epoch [2/50], Loss: 7.1339


OrderedDict([('running_mean',
              tensor([0.4767, 0.5704, 0.4466, 0.5212, 0.4753, 0.5117, 0.5128, 0.5090, 0.5241,
                      0.5205])),
             ('running_var',
              tensor([0.0769, 0.0829, 0.0836, 0.0928, 0.0749, 0.0774, 0.0873, 0.0764, 0.0794,
                      0.0868]))])

In [78]:

# eval() does not change how bn1 normalizes: CustomBatchNorm.forward always passes
# training=True, so inference still uses the sample's own statistics (LayerNorm-like).
net.eval()

# Same seed and sample shape as the BatchNorm1d inference cell earlier, so the printed
# output can be compared with it.
torch.manual_seed(1)
test_sample = torch.randn(1, 10, 64)


# Side effect: with momentum=1 this forward pass overwrites bn1's running stats.
output = net(test_sample)
print(output)
print(net.bn1.state_dict())

tensor([[-1.1568]], grad_fn=<AddmmBackward0>)
OrderedDict([('running_mean', tensor([-0.1130, -0.0192,  0.1874,  0.1517, -0.0068, -0.0377, -0.0822, -0.1019,
         0.1284, -0.0743])), ('running_var', tensor([1.0302, 0.9497, 1.2041, 1.1384, 1.1395, 0.9414, 0.9778, 1.0286, 0.6879,
        1.3984]))])


## Export to ONNX

In [9]:

# Export the trained model. The model is traced once on a random (1, 10, 64) input and
# written to 'testtest.onnx' with opset 15 and named input/output tensors.
# NOTE(review): tracing presumably unrolls loop_layernorm1's per-sample Python loop, so
# the exported graph likely only supports batch size 1 — confirm before using other batches.
net.eval()
x = torch.randn(1, 10, 64)

with torch.no_grad():
    torch.onnx.export(
        net,
        x,
        'testtest.onnx',
        verbose=False,
        opset_version=15,
        input_names=["input"],
        output_names=['output'],
    )

