Sparse Autoenconder - Versão Linear

In [25]:
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..'))
from utils.load_mnist import load_mnist

In [26]:
import torch
import torch.nn as nn 
import torch.optim as optim
import matplotlib.pyplot as plt
from utils.load_mnist import load_mnist

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [28]:
class SparseLinearAE(nn.Module):
    def __init__(self, input_dim, latent_dim, sparsity_coeff=0.01):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
            nn.Sigmoid()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid()
        )
        self.sparsity_coeff = sparsity_coeff

    def forward(self, x):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed, latent

    def loss(self, x, reconstructed, latent):
        mse_loss = nn.MSELoss()(reconstructed, x)
        l1_loss = torch.norm(latent, p=1)
        return mse_loss + self.sparsity_coeff * l1_loss

In [29]:
def train(model, train_loader, epochs=50, lr=1e-3):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for batch in train_loader:
            x, _ = batch
            x = x.view(x.size(0), -1).to(device)
            reconstructed, latent = model(x)
            loss = model.loss(x, reconstructed, latent)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

In [30]:
if __name__ == "__main__":
    train_loader = load_mnist()
    model = SparseLinearAE(28*28, 32)
    train(model, train_loader)
    torch.save(model.state_dict(), 'tests/sparse_linear.pth')

Epoch [1/50], Loss: 0.0669
Epoch [2/50], Loss: 0.0685
Epoch [3/50], Loss: 0.0682
Epoch [4/50], Loss: 0.0670
Epoch [5/50], Loss: 0.0631
Epoch [6/50], Loss: 0.0652
Epoch [7/50], Loss: 0.0659
Epoch [8/50], Loss: 0.0709
Epoch [9/50], Loss: 0.0616
Epoch [10/50], Loss: 0.0678
Epoch [11/50], Loss: 0.0653
Epoch [12/50], Loss: 0.0624
Epoch [13/50], Loss: 0.0660
Epoch [14/50], Loss: 0.0705
Epoch [15/50], Loss: 0.0658
Epoch [16/50], Loss: 0.0708
Epoch [17/50], Loss: 0.0653
Epoch [18/50], Loss: 0.0672
Epoch [19/50], Loss: 0.0608
Epoch [20/50], Loss: 0.0640
Epoch [21/50], Loss: 0.0675
Epoch [22/50], Loss: 0.0772
Epoch [23/50], Loss: 0.0663
Epoch [24/50], Loss: 0.0659
Epoch [25/50], Loss: 0.0656
Epoch [26/50], Loss: 0.0663
Epoch [27/50], Loss: 0.0709
Epoch [28/50], Loss: 0.0620
Epoch [29/50], Loss: 0.0728
Epoch [30/50], Loss: 0.0672
Epoch [31/50], Loss: 0.0692
Epoch [32/50], Loss: 0.0711
Epoch [33/50], Loss: 0.0686
Epoch [34/50], Loss: 0.0725
Epoch [35/50], Loss: 0.0639
Epoch [36/50], Loss: 0.0666
E