Based on

> Chorowski, Jan, and Jacek M. Zurada. "Learning understandable neural networks with nonnegative weight constraints." IEEE Transactions on Neural Networks and Learning Systems 26.1 (2014).
https://sci-hub.se/10.1109/TNNLS.2014.2310059

> https://stats.stackexchange.com/q/572043

In [None]:
from argparse import Namespace

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import numpy as np

# Report whether a CUDA device is usable and which CUDA version this torch
# build reports.
for _label, _value in (
    ("GPU :", torch.cuda.is_available()),
    ("CUDA:", torch.version.cuda),
):
    print(_label, _value)

In [None]:
class Net(nn.Module):
    """Single-layer classifier whose effective weights lie in (0, 1).

    The raw parameter ``w1`` is squashed by a sigmoid before it multiplies the
    flattened input, so every effective weight is strictly positive.  The
    result is passed through a log-softmax over the class dimension, which is
    the form ``F.nll_loss`` expects.

    Args:
        in_features: Number of values per flattened input sample
            (default 784, i.e. one 28x28 image).
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, in_features: int = 784, num_classes: int = 10):
        super().__init__()
        self.in_features = in_features
        self.num_classes = num_classes

        # Allocate uninitialised storage, then fill it from U(0, 1).
        # ``nn.init.uniform_`` runs under ``no_grad`` and avoids poking at
        # ``.data`` directly.
        self.w1 = nn.Parameter(torch.empty(in_features, num_classes))
        nn.init.uniform_(self.w1, 0.0, 1.0)

        self.a = nn.Sigmoid()
        self.out = nn.LogSoftmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch of inputs.

        ``x`` is flattened to ``(-1, in_features)``, so any leading shape whose
        total size is a multiple of ``in_features`` is accepted.
        """
        flat = x.reshape(-1, self.in_features)
        logits = flat @ self.a(self.w1)
        return self.out(logits)


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training pass over ``train_loader``.

    For every batch the inputs and labels are moved to ``device``, the
    negative log-likelihood loss of ``model``'s output is back-propagated and
    ``optimizer`` takes one step.  A progress line is printed every
    ``args.log_interval`` batches (including the first one).

    Args:
        args: Object exposing ``log_interval``.
        model: Module to train; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches that also
            supports ``len()`` and has a ``dataset`` attribute.
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: Epoch number, used only in the progress line.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Only report on every ``log_interval``-th batch.
        if step % args.log_interval != 0:
            continue

        samples_seen = step * len(inputs)
        samples_total = len(train_loader.dataset)
        percent_done = 100.0 * step / len(train_loader)
        message = "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
            epoch, samples_seen, samples_total, percent_done, batch_loss.item()
        )
        print(message)


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )

In [None]:
# Run configuration: batch sizes, schedule, seed and logging cadence.
args = Namespace(
    batch_size=512,
    test_batch_size=512,
    epochs=20,
    lr=1e-3,
    gamma=0.7,
    seed=31337,
    log_interval=50,
    cuda=True,
)

# Use the GPU only when it is both requested and actually available.
use_cuda = args.cuda and torch.cuda.is_available()
print("using CUDA" if use_cuda else "using CPU")

# Seed torch's global RNG from the configured seed.
torch.manual_seed(args.seed)

device_name = "cuda" if use_cuda else "cpu"
device = torch.device(device_name)

# Build the MNIST train and test loaders.
#
# ``os.path.expanduser("~")`` is used instead of ``os.environ["HOME"]`` so the
# cell does not raise ``KeyError`` on platforms where HOME is not set; when
# HOME is set on POSIX the resulting path is the same as before.
data_root = os.path.join(os.path.expanduser("~"), "workspace/ml-data")

# Both splits share the same preprocessing: conversion to a tensor only, with
# no Normalize step.
mnist_transform = transforms.Compose([transforms.ToTensor()])

# Worker/pinning options are only passed when running on CUDA.
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        data_root,
        train=True,
        download=True,
        transform=mnist_transform,
    ),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs
)

# The test split is not downloaded here; it relies on the files fetched for
# the training split above.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        data_root,
        train=False,
        transform=mnist_transform,
    ),
    batch_size=args.test_batch_size,
    shuffle=True,
    **kwargs
)

# torch.save(model.state_dict(), "mnist_cnn.pt")

In [None]:
# Build the model and optimizer, then alternate training and evaluation for
# ``args.epochs`` epochs.
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# NOTE: learning-rate decay is opt-in.  ``args.lr_decay`` is not set by
# default, so the scheduler is never stepped, the learning rate stays at
# ``args.lr`` and ``args.gamma`` has no effect.  Set ``args.lr_decay = True``
# to multiply the learning rate by ``args.gamma`` after every epoch.
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
lr_decay = getattr(args, "lr_decay", False)

for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)
    if lr_decay:
        scheduler.step()

In [None]:
# Move the trained model's parameters to host memory; the expression value
# (the module itself) is left as the cell result.
model.cpu()

In [None]:
# For every pixel, feed the model an image in which only that pixel is 1 and
# record the predicted probability of each digit.  ``fs[d, i]`` is the
# probability of digit ``d`` when only pixel ``i`` is on.
n_pixels = 28 * 28

with torch.no_grad():
    # Row i of the identity matrix is the one-hot image for pixel i, so a
    # single batched forward pass replaces one model call per (digit, pixel)
    # pair.
    one_hot_images = torch.eye(n_pixels, dtype=torch.float32)
    probs = model(one_hot_images).exp()  # log-probabilities -> probabilities

    # Transpose to (digit, pixel) and keep the float64 dtype used previously.
    fs = probs.t().numpy().astype(np.float64)

plt.figure(figsize=(20, 8))

# One 28x28 heat-map per digit, laid out on a 2x5 grid.
for digit, response in enumerate(fs):
    plt.subplot(2, 5, digit + 1)
    plt.imshow(response.reshape(28, 28))
    plt.title(digit)

plt.tight_layout()