In [None]:
import os, sys
# Make the parent directory of the notebook importable (presumably the project
# root holding the local `Sparse` module imported further down — confirm).
# abspath normalises "<cwd>/.." so the membership check compares a canonical
# path and re-running this cell never appends a duplicate entry.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_dir not in sys.path:
    sys.path.append(project_dir)

import numpy as np
import torch
from torch import nn

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor

In [None]:
SEED = 0
BATCH_SIZE = 128
# Seeds the global torch RNG: this drives the DataLoader shuffle below and
# any weight initialisation performed after this cell.
torch.manual_seed(SEED)

# ToTensor yields a (1, 28, 28) image tensor; Flatten(start_dim=0) turns it
# into a flat 784-vector so it can feed nn.Linear(28*28, ...) directly.
transform = Compose([ToTensor(), nn.Flatten(start_dim=0)])
dataset = MNIST('../dataset', transform=transform, download=True)
# This loader is iterated once per training epoch; shuffle so batches are not
# presented in the same fixed order every epoch.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sanity-check one batch: flattened images and one integer label per image.
x, y = next(iter(loader))
assert x.shape == (BATCH_SIZE, 28 * 28), x.shape
assert y.shape == (BATCH_SIZE,), y.shape

In [None]:
from Sparse import KWinners, SparseWeights

# Keep the layer order stable: later code indexes into `model` by position
# (e.g. model[1] is the first KWinners layer).
# NOTE(review): KWinners(n, k) presumably keeps k winning units out of n —
# confirm against the Sparse module.
hidden_layers = [
    # hidden block 1: 784 -> 128
    SparseWeights(nn.Linear(28 * 28, 128), weightSparsity=0.4),
    KWinners(128, 64),
    nn.BatchNorm1d(128),
    # hidden block 2: 128 -> 64
    SparseWeights(nn.Linear(128, 64), weightSparsity=0.4),
    KWinners(64, 32),
    nn.BatchNorm1d(64),
]
# Dense classifier over the 10 digit classes, emitting log-probabilities.
head = [nn.Linear(64, 10), nn.LogSoftmax(dim=1)]

model = nn.Sequential(*hidden_layers, *head)

In [None]:
from tqdm.auto import tqdm

n_epoch = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model first, then build the optimizer over its (device-resident)
# parameters, as recommended by the torch.optim documentation.
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The model's last layer is LogSoftmax, so NLLLoss is the matching criterion.
# CrossEntropyLoss would apply log_softmax a second time (same value, wasted work).
criterion = nn.NLLLoss()

# Explicitly enable training mode (affects BatchNorm) in case the model was
# switched to eval mode before this cell is re-run.
model.train()

epoch_iterator = tqdm(
    range(n_epoch),
    leave=True,
    unit="epoch",
    postfix={"tls": "%.4f" % 1},
)

for epoch in epoch_iterator:
    running_loss = 0.0
    n_batches = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        out = model(inputs)
        loss = criterion(out, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the mean loss over the batches seen so far this epoch
        # (presumably what "tls" — training loss — was meant to report).
        running_loss += loss.item()
        n_batches += 1
        epoch_iterator.set_postfix(tls="%.4f" % (running_loss / n_batches))

In [None]:
from matplotlib import pyplot as plt

# Inspect the duty cycles of the first KWinners layer (model[1]).
kwinners = model[1]
N_BINS = 25

# NOTE(review): the divisor 32 is unexplained here (model[1] is built as
# KWinners(128, 64)); confirm it is the intended normaliser.
# detach() so .numpy() also works if dutyCycle is part of an autograd graph.
duty_cycles = (kwinners.dutyCycle / 32).detach().cpu().numpy()

fig, ax = plt.subplots()
ax.hist(duty_cycles, bins=N_BINS)
ax.set_title("Histogram of duty cycles, entropy=" + str(float(kwinners.entropy())))
ax.set_xlabel("Duty cycle")
ax.set_ylabel("Number of units")
print('Max entropy: {}'.format(kwinners.maxEntropy()))
plt.show()