In [None]:
import os, sys
# Make the parent directory of the notebook's working directory importable
# (presumably the project root holding the local `Sparse` module used below).
# Normalise to an absolute, canonical path so the membership check can match an
# equivalent entry already on sys.path instead of appending a "<cwd>/.." alias.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_dir not in sys.path:
    sys.path.append(project_dir)

import numpy as np
import torch
from torch import nn

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor

In [None]:
# ToTensor yields a (1, 28, 28) float tensor in [0, 1]; Flatten(start_dim=0)
# turns it into a 784-vector so samples feed the nn.Linear layers directly.
transform = Compose([ToTensor(), nn.Flatten(start_dim=0)])
dataset = MNIST('../dataset', train=True, transform=transform, download=True)
# Shuffle so each training epoch visits the batches in a fresh random order
# rather than the fixed on-disk sample order.
loader = DataLoader(dataset, batch_size=128, shuffle=True)

In [None]:
from Sparse import ReLUWithSparsity

# Autoencoder over flattened 28x28 images: 784 -> 784 (hidden code) -> 784.
# NOTE(review): ReLUWithSparsity lives in the project-local `Sparse` module;
# beta / rho presumably weight and target a sparsity penalty on the hidden
# activations — confirm against that module.
model = nn.Sequential(
    nn.Linear(28 * 28, 28 * 28),               # encoder projection
    ReLUWithSparsity(beta=1e-6, rho=0.05),     # hidden code (probed later via model[:2])
    nn.BatchNorm1d(28 * 28),
    nn.Linear(28 * 28, 28 * 28),               # decoder projection
    nn.ReLU(inplace=True),                     # non-negative, but unbounded above
)

In [None]:
from tqdm import tqdm

n_epoch = 5
criterion = nn.MSELoss()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the model first, then build the optimizer over the device-resident
# parameters (the order recommended by the torch.optim documentation).
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)

epoch_iterator = tqdm(
    range(n_epoch),
    leave=True,
    unit="epoch",
    postfix={"tls": "%.4f" % 1},
)

model.train()
for epoch in epoch_iterator:
    # Running totals for the sample-weighted mean loss of this epoch
    # (the final batch can be smaller than batch_size).
    loss_sum, n_seen = 0.0, 0
    # NOTE: `input` (shadows the builtin) and `out` intentionally stay at module
    # level — later cells inspect the last batch through these names.
    for input, _ in loader:
        input = input.to(device)

        # Autoencoder objective: reconstruct the input itself.
        out = model(input)
        loss = criterion(out, input)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_batch = input.size(0)
        loss_sum += loss.item() * n_batch
        n_seen += n_batch
        # "tls": mean reconstruction loss over the epoch so far.
        epoch_iterator.set_postfix(tls="%.4f" % (loss_sum / n_seen))

In [None]:
from torchvision.transforms import ToPILImage

to_img = ToPILImage()
# ToPILImage converts float tensors with x*255 cast to uint8, so values above
# 1.0 are not saturated (the out-of-range cast wraps / is undefined). The
# decoder ends in an unbounded ReLU, so clamp the reconstruction to [0, 1].
# detach().cpu() keeps the conversion independent of autograd and device.
img_in = to_img(input[0].detach().cpu().reshape(1, 28, 28))
img_out = to_img(out[0].detach().cpu().clamp(0, 1).reshape(1, 28, 28))

In [None]:
from matplotlib import pyplot as plt

# Show the input digit next to its reconstruction. Both are single-channel
# images, so force a grayscale colormap (imshow would otherwise apply viridis).
fig, axes = plt.subplots(1, 2)
for ax, img, title in zip(axes, (img_in, img_out), ("input", "reconstruction")):
    ax.imshow(img, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
# Hidden code of the last training batch: first Linear layer followed by
# ReLUWithSparsity (i.e. before BatchNorm). This is an inspection-only pass,
# so skip building an autograd graph.
with torch.no_grad():
    test = model[:2](input)

# Number of active (strictly positive) hidden units for the first sample.
(test[0] > 0).sum()

In [None]:
# Hidden codes (784 values each) of the first four samples of the batch, laid
# out as 28x28 grids purely for display. A single loop replaces four
# copy-pasted subplot calls, which had skipped sample index 2.
fig, axes = plt.subplots(1, 4)
for idx, ax in enumerate(axes):
    ax.imshow(test[idx].detach().cpu().reshape(28, 28))
    ax.set_title(f"sample {idx}")
    ax.axis("off")
plt.show()