In [65]:
import torch
import numpy as np
import matplotlib.pyplot as plt

class Flatten:
  """Merge each run of ``block_size`` consecutive time steps into one vector.

  An input of shape ``(B, T, C)`` is viewed as
  ``(B, T // block_size, block_size * C)``. When that leaves exactly one
  group, the length-1 middle axis is dropped and the result is
  ``(B, block_size * C)``. The layer has no trainable parameters.
  """

  def __init__(self, block_size):
    self.block_size = block_size

  def __call__(self, x):
    batch, steps, channels = x.shape
    n = self.block_size
    groups = steps // n
    merged = x.view(batch, groups, n * channels)
    # A single group means the middle axis carries no information; drop it.
    return merged.squeeze(1) if groups == 1 else merged

  def parameters(self):
    # Nothing to train.
    return []

class BatchNorm1D:
  """Batch normalization over the last (feature) axis.

  Statistics are computed across every axis except the last, so both
  ``(B, C)`` and ``(B, T, C)`` inputs are supported. In training mode the
  current batch's mean/variance are used and folded into exponential running
  averages; with ``training = False`` the running averages are used instead.

  Note: ``Tensor.var`` defaults to the unbiased estimator, which is used here
  for both normalisation and the running variance (``torch.nn.BatchNorm1d``
  normalises with the biased one). A batch of a single row therefore yields a
  NaN variance in training mode.

  Args:
    num_features: size of the last axis of the input.
    eps: added to the variance before the square root for numerical stability.
    momentum: weight of the current batch when updating the running stats.
  """

  def __init__(self, num_features, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # Trainable affine parameters (returned by parameters()).
    self.gamma = torch.ones(num_features)
    self.beta = torch.zeros(num_features)
    # Running statistics for eval mode. The variance must start at ones
    # (identity scaling); random init could even give negative "variances".
    self.running_mean = torch.zeros(num_features)
    self.running_var = torch.ones(num_features)

  def __call__(self, x):
    # Reduce over every axis except the last (a tuple of plain ints).
    dims = tuple(range(x.ndim - 1))
    if self.training:
      # Without keepdim the stats have shape (num_features,), matching the
      # running buffers, and still broadcast against x (features are last).
      mu = x.mean(dims)
      var = x.var(dims)
      # Running stats are buffers, not part of the model's graph; updating
      # them under no_grad keeps them from chaining every step's graph.
      with torch.no_grad():
        m = self.momentum
        self.running_mean = (1 - m) * self.running_mean + m * mu
        self.running_var = (1 - m) * self.running_var + m * var
    else:
      mu = self.running_mean
      var = self.running_var

    # Normalise by the standard deviation, not the variance.
    xhat = (x - mu) / torch.sqrt(var + self.eps)
    # Keep the last output around for inspection (e.g. activation plots).
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]

class Sequential:
  """Chain layers so that the output of each one feeds the next.

  Layers are called in the order they were passed to the constructor.
  ``parameters()`` expects every layer to expose a ``parameters()`` method.
  """

  def __init__(self, *args):
    self.layers = list(args)

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

  def parameters(self):
    # Flatten every layer's parameter list into one list, in layer order.
    return [p for layer in self.layers for p in layer.parameters()]

# Smoke-test both layers on tiny inputs.
flatten = Flatten(2)
print(f'flatten: {flatten(torch.arange(12).reshape(3, 2, 2))}')

bn = BatchNorm1D(4)
print(f'bn: {bn(torch.ones((3, 4)))}')

flatten: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
bn: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
