In [1]:
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda

from tqdm import tqdm
from matplotlib import pyplot as plt

In [2]:
def mnist_loader(train_batch_size: int=1, test_batch_size: int=1) -> (DataLoader, DataLoader):
    '''
    pytorch MNIST train & test 데이터 로더 반환
    + z-normalization
    + Flatten
    '''
    transform = Compose([
        ToTensor(),
        Normalize(mean=(0.1307, ), std=(0.3081, )),
        Lambda(lambda x: torch.flatten(x)),
    ])
    
    train_loader = DataLoader(
        MNIST(
            root='./mnist/',
            train=True,
            transform=transform,
            download=True
        ), shuffle=True, batch_size=train_batch_size
    )
    
    test_loader = DataLoader(
        MNIST(
            root='./mnist/',
            train=False,
            transform=transform,
            download=True
        ), shuffle=False, batch_size=test_batch_size
    )
    return train_loader, test_loader

In [3]:
def create_input_data(x: Tensor, y: Tensor) -> Tensor:
    '''
    forward-forward Model 입력 데이터 반환
    
    X shape: Batch x 784(Ch * Height * Width)
    Y shape: Batch x Label
    '''
    batch_size = y.size(0)
    
    x_ = x.clone()
    x_[:, :10] = 0.
    x_[range(batch_size), y] = x_.max()
    return x_

In [23]:
class FFLinear(nn.Linear):
    __constants__ = ('in_features', 'out_features')
    in_features: int
    out_features: int
    weight: Tensor
        
    def __init__(self, in_feature: int, out_features: int, bias: bool=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(FFLinear, self).__init__(in_feature, out_features, bias, **factory_kwargs)
        
        self.activation = nn.ReLU()
        self.optim = torch.optim.SGD(self.parameters(), lr=0.01)
        self.threshold = 2.0
        
    def forward(self, input) -> Tensor:
        out = self.__layerNorm(input)
        out = F.linear(out, self.weight, self.bias)
        return self.activation(out)
    
    def update(self, pos_x, neg_x) -> Tensor:
        pos_out = self.forward(pos_x).pow(exponent=2).mean(dim=1) #shape: (Batch, )
        neg_out = self.forward(neg_x).pow(exponent=2).mean(dim=1) #shape: (Batch, )
        
        print('=' * 50)
        print(pos_out)
        print(neg_out)
        print('=' * 50)
        print(-pos_out + self.threshold)
        print(neg_out - self.threshold)
        print('=' * 50)
        print(
            1. + torch.exp(
                torch.cat([
                    -pos_out + self.threshold,
                    neg_out - self.threshold
                ])
            )
        )
        raise
    
    def __layerNorm(self, input: Tensor, eps: float=1e-4) -> Tensor:
        '''
        + 참고 repository의 정규화 코드
        input / (input.norm(p=2, dim=1, keepdim=True) + 1e-4)
        
        ## https://github.com/mohammadpz/pytorch_forward_forward/blob/main/main.py
        '''
        
        mean_ = input.mean(dim=1, keepdim=True)
        var_ = input.var(dim=1, keepdim=True, unbiased=False) #unbiased True=(N-1), False=N
        return (input - mean_) / torch.sqrt(var_ + eps)
    
    
    
pos_x = torch.randn(2, 4)
neg_x = torch.randn(2, 4)

layer = FFLinear(4, 2)

print(pos_x)
print(neg_x)
layer.update(pos_x, neg_x)

tensor([[-0.3620,  0.6799,  0.2078, -1.0088],
        [-0.0130,  1.0682, -1.3792, -1.3598]])
tensor([[-0.2595,  0.4917,  1.3721, -1.1487],
        [ 0.3802, -1.0547,  2.2308, -0.9562]])
tensor([0.5960, 0.4617], grad_fn=<MeanBackward1>)
tensor([0.7995, 0.9600], grad_fn=<MeanBackward1>)
tensor([1.4040, 1.5383], grad_fn=<AddBackward0>)
tensor([-1.2005, -1.0400], grad_fn=<SubBackward0>)
tensor([5.0716, 5.6568, 1.3010, 1.3535], grad_fn=<AddBackward0>)


RuntimeError: No active exception to reraise