In [None]:
import numpy as np
import torch
from torch import nn

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor

In [None]:
# Reproducibility: seed torch's global RNG. The shuffled batch order below draws
# from it, and so does weight initialisation in later cells when run in order.
torch.manual_seed(0)

BATCH_SIZE = 128

# ToTensor yields a (1, 28, 28) float tensor in [0, 1]; Flatten(start_dim=0)
# turns it into a 784-vector so the fully-connected models can consume it.
transform = Compose([ToTensor(), nn.Flatten(start_dim=0)])
dataset = MNIST('dataset', transform=transform, download=True)
# shuffle=True so every training epoch visits the samples in a fresh order
# (the default shuffle=False would replay the identical batch sequence each epoch).
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sanity check: one batch has the expected flattened shape.
x, y = next(iter(loader))
assert x.shape == (BATCH_SIZE, 28 * 28), x.shape
assert y.shape == (BATCH_SIZE,), y.shape

In [None]:
from modules import SparseWeights, KWinners

# Sparse classifier: two sparse-weight linear layers, each followed by a
# k-winners activation, then a dense read-out producing log-probabilities
# over the 10 MNIST classes (LogSoftmax over dim=1).
# NOTE(review): KWinners(n, k) is assumed to keep the feature width n, since the
# next layer's in_features equals n. Confirm against the local `modules` package.
model = nn.Sequential(
    SparseWeights(nn.Linear(28*28, 128), weightSparsity=.4),
    KWinners(128, 100),
    SparseWeights(nn.Linear(128, 64), weightSparsity=.4),
    KWinners(64, 32),
    nn.Linear(64, 10),
    nn.LogSoftmax(dim=1),
)

# Dense ReLU baseline, kept for comparison (swap it in to benchmark the sparse model):
# model = nn.Sequential(*[
#     nn.Linear(28*28, 512),
#     nn.ReLU(inplace=True),
#     nn.Linear(512, 128),
#     nn.ReLU(inplace=True),
#     nn.Linear(128, 64),
#     nn.ReLU(inplace=True),
#     nn.Linear(64, 10),
#     nn.LogSoftmax(dim=1)
# ])


In [None]:
from tqdm.auto import tqdm

n_epoch = 10

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model before building the optimizer, as recommended by torch.optim.
model = model.to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The model already ends in nn.LogSoftmax, so pair it with NLLLoss.
# CrossEntropyLoss would apply log-softmax a second time (redundant work).
criterion = nn.NLLLoss()

epoch_iterator = tqdm(
    range(n_epoch),
    leave=True,
    unit="epoch",
    postfix={"tls": "%.4f" % 1},
)

for epoch in epoch_iterator:
    # Accumulators for the sample-weighted mean training loss of this epoch.
    running_loss = 0.0
    n_seen = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        log_probs = model(images)
        loss = criterion(log_probs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the running mean over the epoch so far, not just the (noisy)
        # loss of the most recent batch.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        n_seen += batch_size
        epoch_iterator.set_postfix(tls="%.4f" % (running_loss / n_seen))


# AutoEncoder

In [None]:
# Sparse-weight variant of the autoencoder, kept for comparison:
# model = nn.Sequential(*[
#     SparseWeights(nn.Linear(28*28, 128), weightSparsity=.4),
#     KWinners(128, 100),
#     SparseWeights(nn.Linear(128, 64), weightSparsity=.4),
#     KWinners(64, 32),
#     nn.Linear(64, 128),
#     nn.ReLU(inplace=True),
#     nn.Linear(128, 28*28),
#     nn.ReLU(inplace=True)
# ])

# Autoencoder.
#   Encoder: 784 -> 128 -> 64, with k-winners activations.
#   Decoder: 64 -> 128 -> 784, with ReLU activations.
# Relies on KWinners having been imported from `modules` in an earlier cell.
# NOTE(review): the final ReLU is unbounded above, while MNIST pixels lie in [0, 1].
# A Sigmoid output may suit the MSE reconstruction target better. Consider it.
model = nn.Sequential(
    nn.Linear(28*28, 128),
    KWinners(128, 100),
    nn.Linear(128, 64),
    KWinners(64, 32),
    nn.Linear(64, 128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 28*28),
    nn.ReLU(inplace=True),
)


In [None]:
from tqdm.auto import tqdm

n_epoch = 10

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model before building the optimizer, as recommended by torch.optim.
model = model.to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Reconstruction objective: the target is the flattened input image itself.
criterion = nn.MSELoss()

epoch_iterator = tqdm(
    range(n_epoch),
    leave=True,
    unit="epoch",
    postfix={"tls": "%.4f" % 1},
)

for epoch in epoch_iterator:
    # Accumulators for the sample-weighted mean reconstruction loss of this epoch.
    running_loss = 0.0
    n_seen = 0
    # `input` / `out` keep these names because the reconstruction-preview cell
    # may read the last batch's values from them.
    for input, _ in loader:
        input = input.to(device)

        out = model(input)
        loss = criterion(out, input)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the running mean over the epoch so far, not just the (noisy)
        # loss of the most recent batch.
        batch_size = input.size(0)
        running_loss += loss.item() * batch_size
        n_seen += batch_size
        epoch_iterator.set_postfix(tls="%.4f" % (running_loss / n_seen))

In [None]:
from torchvision.transforms import ToPILImage

to_img = ToPILImage()

# Reconstruct a fresh batch with gradients off, instead of relying on the
# `input`/`out` tensors left over from the last training iteration. Those were
# computed in train mode, before the final optimizer step, and still carry autograd history.
was_training = model.training
# eval() so layers with train-time bookkeeping are not updated by this preview
# pass (the k-winners layers may track state in training mode; TODO confirm).
model.eval()
model_device = next(model.parameters()).device
with torch.no_grad():
    preview_batch, _ = next(iter(loader))
    preview_batch = preview_batch.to(model_device)
    preview_recon = model(preview_batch)
model.train(was_training)  # restore whatever mode the model was in

# ToPILImage scales float tensors by 255 and casts to bytes. Clamp to [0, 1] so
# reconstructions above 1 (the final ReLU is unbounded) do not overflow, and
# move to CPU before conversion.
img_in = to_img(preview_batch[0].reshape(1, 28, 28).cpu())
img_out = to_img(preview_recon[0].clamp(0, 1).reshape(1, 28, 28).cpu())


In [None]:
from matplotlib import pyplot as plt

# Input and reconstruction side by side.
# - cmap="gray": single-channel images would otherwise be drawn with the default colormap.
# - Fixed vmin/vmax: both panels share the 0-255 byte scale, so brightness is
#   comparable instead of each image being auto-stretched to its own range.
fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(6, 3))
ax_in.imshow(img_in, cmap="gray", vmin=0, vmax=255)
ax_in.set_title("Input")
ax_out.imshow(img_out, cmap="gray", vmin=0, vmax=255)
ax_out.set_title("Reconstruction")
for ax in (ax_in, ax_out):
    ax.axis("off")
plt.show()

In [None]:
# Inspect the first k-winners layer (index 1 of the current `model`).
# NOTE(review): `entropy()` is defined in the local `modules` package. It presumably
# summarises how evenly the units win; confirm against its implementation.
kwinners_layer = model[1]
kwinners_layer.entropy()