# Imports

In [None]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from scipy.signal import convolve2d
from torch import tensor, Tensor
import torchvision
from tqdm import tqdm
import os
import torchvision
from torch import nn
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import accuracy_score


# Required classes and Functions

In [None]:
def get_y_neg(y, classes=10):
    """Draw, for every label in ``y``, a random label that differs from it.

    Args:
        y: Integer tensor of class labels; each entry is expected to lie in
            ``[0, classes)``.
        classes: Number of distinct labels to sample from. Defaults to 10.

    Returns:
        A tensor with the shape and dtype of ``y`` whose entries are sampled
        uniformly from the labels other than the matching entry of ``y``,
        moved to the module-level ``device``.
    """
    # A uniform offset in [1, classes - 1], applied modulo ``classes``, hits
    # every *other* label with equal probability and can never map a label
    # back onto itself, so no per-sample Python loop is required.
    offsets = torch.randint(1, classes, y.shape, dtype=y.dtype, device=y.device)
    y_neg = (y + offsets) % classes
    return y_neg.to(device)


def overlay_y_on_x(x, y, classes=10):
    """Return a copy of ``x`` with the label ``y`` written into its leading columns.

    The first ``classes`` columns of the copy are multiplied by zero, then the
    column selected by ``y`` in every row is set to the maximum value of ``x``.
    The input ``x`` is not modified.
    """
    peak = x.max()
    row_ids = range(x.shape[0])

    overlaid = x.clone()
    # In-place multiply on the slice view, matching an augmented ``*=``.
    overlaid[:, :classes].mul_(0.0)
    overlaid[row_ids, y] = peak
    return overlaid


def get_metrics(preds, labels):
    """Score predictions against reference labels.

    Returns a dict with one entry, ``"accuracy_score"``, holding the result of
    ``accuracy_score(labels, preds)``.
    """
    metrics = {}
    metrics["accuracy_score"] = accuracy_score(labels, preds)
    return metrics


class FF_Layer(nn.Linear):
    """Linear layer that is trained on its own with a goodness-based loss.

    The layer owns its Adam optimizer, so a call to ``ff_train`` performs one
    complete optimisation step for this layer only.
    """

    def __init__(self, in_features: int, out_features: int, n_epochs: int, bias: bool, device):
        """Build the layer, move it to ``device`` and create its optimizer.

        Args:
            in_features: Size of each input row.
            out_features: Number of units in the layer.
            n_epochs: Stored on the instance as ``n_epochs``.
            bias: Whether the linear map has an additive bias.
            device: Device the parameters are moved to.
        """
        super().__init__(in_features, out_features, bias=bias)
        self.n_epochs = n_epochs
        self.relu = torch.nn.ReLU()
        # Move the parameters first so the optimizer is created against the
        # parameters as they live on the target device.
        self.to(device)
        self.opt = torch.optim.Adam(self.parameters(), lr=0.03)
        # Alias kept so callers can swap in a different goodness function.
        self.goodness = self.goodness_score

    def goodness_score(self, x_pos, x_neg, threshold=2):
        """Return the loss separating positive from negative samples.

        Goodness is the mean squared activation of each row. The loss is the
        mean of ``softplus(threshold - g_pos)`` and ``softplus(g_neg - threshold)``,
        which pushes positive goodness above ``threshold`` and negative
        goodness below it.

        Args:
            x_pos: Batch of positive inputs, one sample per row.
            x_neg: Batch of negative inputs, one sample per row.
            threshold: Goodness value separating the two groups.

        Returns:
            Scalar loss tensor.
        """
        g_pos = self(x_pos).pow(2).mean(1)
        g_neg = self(x_neg).pow(2).mean(1)
        margins = torch.cat([threshold - g_pos, g_neg - threshold])
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        return nn.functional.softplus(margins).mean()

    def forward(self, x):
        """Apply the layer to row-normalised input and rectify the result.

        Every row of ``x`` is divided by its L2 norm (plus 1e-4 to avoid a
        division by zero) before the linear map, so only the direction of the
        input is passed on.
        """
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        pre_activation = torch.mm(x_direction, self.weight.T)
        # ``self.bias`` is None when the layer was built with ``bias=False``.
        if self.bias is not None:
            pre_activation = pre_activation + self.bias.unsqueeze(0)
        return self.relu(pre_activation)

    def ff_train(self, pos_acts, neg_acts):
        """Run one optimisation step on this layer and return the loss value.

        Args:
            pos_acts: Positive inputs to this layer.
            neg_acts: Negative inputs to this layer.

        Returns:
            The loss, as a Python float, computed before the parameter update.
        """
        self.opt.zero_grad()
        goodness = self.goodness(pos_acts, neg_acts)
        goodness.backward()
        self.opt.step()

        return goodness.item()



class Unsupervised_FF(nn.Module):
    """Stack of ``FF_Layer`` modules, each trained with its own local loss."""

    def __init__(self, n_layers: int = 2, n_neurons=500, input_size: int = 28 * 28, n_epochs: int = 100,
                 bias: bool = True, n_classes: int = 10, n_hid_to_log: int = 2, device='cuda'):
        """Create ``n_layers`` layers of ``n_neurons`` units each.

        Args:
            n_layers: Number of ``FF_Layer`` modules in the stack.
            n_neurons: Width of every layer.
            input_size: Length of a flattened input sample.
            n_epochs: Number of passes over the data in ``train_ff_layers``.
            bias: Whether the layers use an additive bias.
            n_classes: Number of labels tried in ``evaluate`` and number of
                leading input columns used for the label overlay.
            n_hid_to_log: Stored on the instance; not used by this class.
            device: Device the layers and the input batches are moved to.
        """
        super().__init__()
        self.n_hid_to_log = n_hid_to_log
        self.n_epochs = n_epochs
        self.device = device
        self.n_layers = n_layers
        self.n_classes = n_classes

        # A ModuleList (rather than a plain list) registers the layers as
        # submodules, so parameters(), state_dict(), to() and eval() see them.
        self.ff_layers = nn.ModuleList([
            FF_Layer(in_features=input_size if idx == 0 else n_neurons,
                     out_features=n_neurons,
                     n_epochs=n_epochs,
                     bias=bias,
                     device=device) for idx in range(n_layers)])

    def train_ff_layers(self, pos_dataloader):
        """Train every layer for ``n_epochs`` passes over ``pos_dataloader``.

        Each batch is flattened, overlaid with its true labels (positive
        samples) and with randomly chosen wrong labels (negative samples).
        The layers are then updated one after another, each on the outputs of
        the layer before it. The latest per-layer losses are printed every
        100 batches.

        Args:
            pos_dataloader: Iterable yielding ``(images, labels)`` batches.
        """
        loss = [0] * self.n_layers
        for epoch in range(self.n_epochs):
            for i, (pos_imgs, labels) in enumerate(pos_dataloader):
                pos_acts = torch.reshape(pos_imgs, (pos_imgs.shape[0], -1)).to(self.device)
                pos_acts = overlay_y_on_x(pos_acts, labels, self.n_classes)
                # NOTE(review): get_y_neg is called with the labels only, so it
                # draws negatives from its own built-in label range.
                neg_labels = get_y_neg(labels)
                neg_acts = overlay_y_on_x(pos_acts, neg_labels, self.n_classes)
                if i % 100 == 0:
                    print(f'[loss layer 1, loss layer 2] = {loss}')
                for idx, layer in enumerate(self.ff_layers):
                    loss[idx] = layer.ff_train(pos_acts, neg_acts)
                    # The next layer only needs the values of these outputs;
                    # computing them without a graph keeps its backward pass
                    # from running through the layers below it.
                    with torch.no_grad():
                        pos_acts = layer(pos_acts)
                        neg_acts = layer(neg_acts)

    def evaluate(self, dataloader: DataLoader, dataset_type: str = "train"):
        """Predict labels for the first batch of ``dataloader``.

        For every candidate label the batch is overlaid with that label and
        passed through all layers; the per-layer mean squared activations are
        summed, and the label with the largest sum is the prediction.

        NOTE(review): only the first batch is evaluated. This is kept as is
        because the callers in this file compare the result with the labels of
        the first batch; use a batch size covering the data set to score all
        of it.

        Args:
            dataloader: Iterable yielding ``(images, labels)`` batches.
            dataset_type: Currently unused.

        Returns:
            Tensor of predicted labels for the first batch, or ``None`` when
            ``dataloader`` yields nothing.
        """
        self.eval()
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for images, _ in dataloader:
                images = images.to(self.device)
                x = torch.reshape(images, (images.shape[0], -1))
                goodness_per_label = []
                for label in range(self.n_classes):
                    h = overlay_y_on_x(x, label, self.n_classes)
                    layer_goodness = []
                    for layer in self.ff_layers:
                        h = layer(h)
                        layer_goodness.append(h.pow(2).mean(1))
                    goodness_per_label.append(sum(layer_goodness).unsqueeze(1))
                return torch.cat(goodness_per_label, 1).argmax(1)



# Training the Model and Showing the results

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The only preprocessing step is the conversion of each image to a tensor.
transform = torchvision.transforms.Compose(
    transforms=[torchvision.transforms.ToTensor()]
)

# Training split, shuffled and served in batches of 256.
pos_dataset = torchvision.datasets.MNIST(
    root='./', train=True, transform=transform, download=True
)
pos_dataloader = DataLoader(dataset=pos_dataset, batch_size=256, shuffle=True)

# Test split, served in batches of 10000 without shuffling.
test_dataset = torchvision.datasets.MNIST(
    root='./', train=False, transform=transform, download=True
)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=10000)


# Build the model with its default architecture and train it for 100 epochs.
unsupervised_ff = Unsupervised_FF(n_epochs=100, device=device)
unsupervised_ff.train_ff_layers(pos_dataloader)


[loss layer 1, loss layer 2] = [0, 0]
[loss layer 1, loss layer 2] = [0.695438027381897, 0.6961270570755005]
[loss layer 1, loss layer 2] = [0.631627082824707, 0.6682161688804626]
[loss layer 1, loss layer 2] = [0.5928605794906616, 0.6224926710128784]
[loss layer 1, loss layer 2] = [0.5431445837020874, 0.5708013772964478]
[loss layer 1, loss layer 2] = [0.4694785475730896, 0.49569588899612427]
[loss layer 1, loss layer 2] = [0.4455793797969818, 0.4524715840816498]
[loss layer 1, loss layer 2] = [0.4109407067298889, 0.4268926978111267]
[loss layer 1, loss layer 2] = [0.38530591130256653, 0.3999406099319458]
[loss layer 1, loss layer 2] = [0.369320273399353, 0.390566349029541]
[loss layer 1, loss layer 2] = [0.36510998010635376, 0.3598301410675049]
[loss layer 1, loss layer 2] = [0.35682758688926697, 0.38679641485214233]
[loss layer 1, loss layer 2] = [0.30811697244644165, 0.3351837992668152]
[loss layer 1, loss layer 2] = [0.30782175064086914, 0.3316900432109833]
[loss layer 1, loss lay

In [None]:
# Unshuffled loader over the training data, 50000 samples per batch.
train_dataloader = DataLoader(pos_dataset, batch_size=50000)

# Take the first batch of each loader and keep it on the model's device; the
# label tensors are the references the predictions are compared with.
x_train, y_train = (part.to(device) for part in next(iter(train_dataloader)))
x_te, y_te = (part.to(device) for part in next(iter(test_dataloader)))

# Accuracy is the fraction of predictions that equal the reference labels.
train_preds = unsupervised_ff.evaluate(train_dataloader)
train_accuracy = train_preds.eq(y_train).float().mean().item()
print("train accuracy:", train_accuracy)

test_preds = unsupervised_ff.evaluate(test_dataloader)
test_accuracy = test_preds.eq(y_te).float().mean().item()
print("test accuracy:", test_accuracy)

train accuracy: 0.9863199591636658
test accuracy: 0.9734999537467957
