In [1]:
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda

import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

In [2]:
CLASS_NUM = 10

In [3]:
def mnist_loader(train: bool, batch_size: int) -> DataLoader:
    '''
    pytorch MNIST train & test 데이터 로더 반환
    + z-normalization
    + Flatten
    '''
    transform = Compose([
        ToTensor(),
        Normalize(mean=(0.1307, ), std=(0.3081, )),
        Lambda(lambda x: torch.flatten(x)),
    ])

    loader = DataLoader(
        MNIST(
            root='./mnist/',
            train=train,
            transform=transform,
            download=True
        ), shuffle=train, batch_size=batch_size
    )
    return loader

In [4]:
def combine_xy(x: Tensor, y: Tensor) -> Tensor:
    '''
    forward-forward Model 입력 데이터 반환

    X shape: Batch x 784(Ch * Height * Width)
    Y shape: Batch x Label
    '''
    batch_size = y.size(0)

    x_ = x.clone()
    x_[:, :CLASS_NUM] = 0.
    x_[range(batch_size), y] = x_.max()
    return x_

In [5]:
class FFLinear(nn.Linear):
    __constants__ = ('in_features', 'out_features')
    in_features: int
    out_features: int
    weight: Tensor
        
    def __init__(self, in_feature: int, out_features: int, bias: bool=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(FFLinear, self).__init__(in_feature, out_features, bias, **factory_kwargs)
        
        self.activation = nn.ReLU()
        self.optim = torch.optim.SGD(self.parameters(), lr=0.1)
        self.threshold = 2.0
        
    def forward(self, input) -> Tensor:
        out = self.__layerNorm(input)
        out = F.linear(out, self.weight, self.bias)
        return self.activation(out)
    
    def update(self, pos_x, neg_x) -> (Tensor, (Tensor, Tensor)):
        pos_out = self.forward(pos_x).pow(exponent=2).mean(dim=1) #shape: (Batch, )
        neg_out = self.forward(neg_x).pow(exponent=2).mean(dim=1) #shape: (Batch, )
        
        loss = torch.cat([-pos_out + self.threshold, neg_out - self.threshold])
        loss = torch.log(1. + torch.exp(loss)).mean()
        
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return (self.forward(pos_x).detach(), self.forward(neg_x).detach())
    
    def __layerNorm(self, input: Tensor, eps: float=1e-4) -> Tensor:
        '''
        + 참고 repository의 정규화 코드
        input / (input.norm(p=2, dim=1, keepdim=True) + 1e-4)
        
        ## https://github.com/mohammadpz/pytorch_forward_forward/blob/main/main.py
        '''
        mean_ = input.mean(dim=1, keepdim=True)
        var_ = input.var(dim=1, keepdim=True, unbiased=False) #unbiased True=(N-1), False=N
        return (input - mean_) / torch.sqrt(var_ + eps)

In [6]:
class FFModel(nn.Module):
    __constants__ = ('dims', )
    dims: list
        
    def __init__(self, dims: list, device='cpu') -> None:
        super(FFModel, self).__init__()
        self.layers = tuple(FFLinear(dims[d], dims[d+1]).to(device) for d in range(len(dims) - 1))

    def forward(self, input) -> Tensor:
        batch_size = input.size(0)
        goodness = torch.zeros(batch_size)

        out = input
        for layer in self.layers:
            out = layer(out)
            goodness += out.pow(exponent=2).mean(dim=1)
        return goodness
    
    def update(self, pos_x: Tensor, neg_x: Tensor) -> None:
        pos_out, neg_out = pos_x, neg_x
        for layer in self.layers:
            pos_out, neg_out = layer.update(pos_out, neg_out)

In [7]:
def get_neg_y(y: Tensor, class_num: int) -> Tensor:
    batch_size = y.size(0)
    
    able_idxs = torch.arange(class_num).unsqueeze(0).repeat(batch_size, 1)
    able_idxs = able_idxs[able_idxs != y.view(batch_size, 1)].view(batch_size, class_num-1)
    
    rand_idxs = torch.randint(class_num - 1, size=(batch_size, ))
    return able_idxs[range(batch_size), rand_idxs]

In [8]:
BATCH_SIZE = 1
EPOCHS = 5


model = FFModel(dims=(784, 100, 100))

for e in range(EPOCHS):
    train_loader = iter(mnist_loader(train=False, batch_size=BATCH_SIZE))
    for (pos_x, pos_y), (neg_x, neg_y) in tqdm(zip(train_loader, train_loader), total=5000):
        pos_x = combine_xy(pos_x, pos_y)

        neg_y = get_neg_y(neg_y, class_num=CLASS_NUM)
        neg_x = combine_xy(neg_x, neg_y)

        model.update(pos_x, neg_x)

100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:09<00:00, 536.03it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:09<00:00, 540.16it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:09<00:00, 548.22it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:09<00:00, 548.41it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:09<00:00, 555.32it/s]


In [9]:
train_loader = mnist_loader(train=True, batch_size=256)

model.eval()
with torch.no_grad():
    acc = list()
    for x, y in tqdm(train_loader):
        batch_size = x.size(0)

        x = x.unsqueeze(1).repeat(1, CLASS_NUM, 1).view(-1, 784)
        y_batchs = torch.arange(CLASS_NUM).repeat(batch_size)
        x = combine_xy(x, y_batchs)

        goodness = model(x).view(batch_size, -1)

        y_hat = goodness.argmax(dim=1)
        acc.extend(y_hat.eq(y).float().tolist())
        
        
print(sum(acc) / len(acc))

100%|████████████████████████████████████████████████████████████████████████████████| 235/235 [00:09<00:00, 24.67it/s]

0.8557833333333333





In [10]:
test_loader = mnist_loader(train=False, batch_size=256)

model.eval()
with torch.no_grad():
    acc = list()
    for x, y in tqdm(test_loader):
        batch_size = x.size(0)

        x = x.unsqueeze(1).repeat(1, CLASS_NUM, 1).view(-1, 784)
        y_batchs = torch.arange(CLASS_NUM).repeat(batch_size)
        x = combine_xy(x, y_batchs)

        goodness = model(x).view(batch_size, -1)

        y_hat = goodness.argmax(dim=1)
        acc.extend(y_hat.eq(y).float().tolist())
        
        
print(sum(acc) / len(acc))

100%|██████████████████████████████████████████████████████████████████████████████████| 40/40 [00:01<00:00, 25.25it/s]

0.8799



