In [1]:
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda

from tqdm import tqdm
from matplotlib import pyplot as plt

In [2]:
def mnist_loader(train_batch_size: int=1, test_batch_size: int=1) -> (DataLoader, DataLoader):
    '''
    pytorch MNIST train & test 데이터 로더 반환
    + z-normalization
    + Flatten
    '''
    transform = Compose([
        ToTensor(),
        Normalize(mean=(0.1307, ), std=(0.3081, )),
        Lambda(lambda x: torch.flatten(x)),
    ])
    
    train_loader = DataLoader(
        MNIST(
            root='./mnist/',
            train=True,
            transform=transform,
            download=True
        ), shuffle=True, batch_size=train_batch_size
    )
    
    test_loader = DataLoader(
        MNIST(
            root='./mnist/',
            train=False,
            transform=transform,
            download=True
        ), shuffle=False, batch_size=test_batch_size
    )
    return train_loader, test_loader

In [3]:
def create_input_data(x: Tensor, y: Tensor) -> Tensor:
    '''
    forward-forward Model 입력 데이터 반환
    
    X shape: Batch x 784(Ch * Height * Width)
    Y shape: Batch x Label
    '''
    batch_size = y.size(0)
    
    x_ = x.clone()
    x_[:, :10] = 0.
    x_[range(batch_size), y] = x_.max()
    return x_

In [4]:
class FFLinear(nn.Linear):
    __constants__ = ('in_features', 'out_features')
    in_features: int
    out_features: int
    weight: Tensor
        
    def __init__(self, in_feature: int, out_features: int, bias: bool=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(FFLinear, self).__init__(in_feature, out_features, bias, **factory_kwargs)
        
        self.activation = nn.ReLU()
        self.optim = torch.optim.SGD(self.parameters(), lr=0.01)
        self.threshold = 2.0
        
    def forward(self, input) -> Tensor:
        return F.linear(input, self.weight, self.bias)
    
    def update(self, pos_x, neg_x) -> Tensor:
        pass
    
    def __layerNorm(self, input: Tensor, eps: float=1e-4) -> Tensor:
        '''
        + 참고 repository의 정규화 코드
        input / (input.norm(p=2, dim=1, keepdim=True) + 1e-4)
        
        ## https://github.com/mohammadpz/pytorch_forward_forward/blob/main/main.py
        '''
        
        mean_ = input.mean(dim=1, keepdim=True)
        var_ = input.var(dim=1, keepdim=True, unbiased=False) #unbiased True=(N-1), False=N
        return (input - mean_) / torch.sqrt(var + eps)

In [18]:
tens = torch.randn(2, 5)
layerNorm = torch.nn.LayerNorm(5, elementwise_affine = False)

print(tens)
print(layerNorm(tens))

mean_ = tens.mean(dim=1, keepdim=True)
var = tens.var(dim=1, keepdim=True, unbiased=False)

print((tens - mean_) / torch.sqrt(var + 1e-4))

tensor([[ 1.0942,  0.3430, -1.4909, -0.3613, -2.2951],
        [ 1.7743,  0.0248,  0.1637,  0.2644,  0.5107]])
tensor([[ 1.3389,  0.7242, -0.7765,  0.1479, -1.4346],
        [ 1.9362, -0.8251, -0.6059, -0.4470, -0.0582]])
tensor([[ 1.3389,  0.7242, -0.7764,  0.1479, -1.4345],
        [ 1.9360, -0.8250, -0.6059, -0.4470, -0.0582]])


In [5]:
tens = torch.randn(2, 5)

print(tens)
print(tens / (tens.norm(p=2, dim=1, keepdim=True)))

tensor([[-1.7091,  0.8794,  1.9976,  2.6119,  2.8200],
        [ 1.1504,  0.9328,  1.7525, -1.1793,  1.8903]])
tensor([[-0.3606,  0.1856,  0.4215,  0.5511,  0.5950],
        [ 0.3597,  0.2917,  0.5480, -0.3687,  0.5910]])


In [6]:
train_loader, test_loader = mnist_loader(train_batch_size=4, test_batch_size=1)

for x, y in train_loader:
    break

input_ = create_input_data(x, y)
print(input_)
print(input_ / (input_.norm(p=2, dim=1, keepdim=True) + 1e-4)) #-> (∑|x|**p)**(1/p)

mean_ = input_.mean(dim=1, keepdim=True)
var = input_.var(dim=1, keepdim=True, unbiased=True)

print((input_ - mean_) / torch.sqrt(var + 1e-4))

tensor([[ 0.0000,  2.8215,  0.0000,  ..., -0.4242, -0.4242, -0.4242],
        [ 2.8215,  0.0000,  0.0000,  ..., -0.4242, -0.4242, -0.4242],
        [ 2.8215,  0.0000,  0.0000,  ..., -0.4242, -0.4242, -0.4242],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.4242, -0.4242, -0.4242]])
tensor([[ 0.0000,  0.1170,  0.0000,  ..., -0.0176, -0.0176, -0.0176],
        [ 0.0937,  0.0000,  0.0000,  ..., -0.0141, -0.0141, -0.0141],
        [ 0.0767,  0.0000,  0.0000,  ..., -0.0115, -0.0115, -0.0115],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0155, -0.0155, -0.0155]])
tensor([[ 0.1538,  3.4660,  0.1538,  ..., -0.3442, -0.3442, -0.3442],
        [ 2.5507, -0.0786, -0.0786,  ..., -0.4740, -0.4740, -0.4740],
        [ 1.9563, -0.2632, -0.2632,  ..., -0.5969, -0.5969, -0.5969],
        [ 0.0065,  0.0065,  0.0065,  ..., -0.4267, -0.4267, -0.4267]])
