In [1]:
from LoadingDefault import LoadData

from torch import nn
import torch.optim as optim
import torch
import torch.nn.functional as F

from NLP.nlp import LaplacianPyramid

# Training data loader built by the project-local LoadData helper.
# NOTE(review): `limit` presumably caps how many samples are loaded — confirm in LoadingDefault.
dataloader = LoadData(batch_size=8, limit=100)



In [2]:
class EntropyLimitedAutoencoder(nn.Module):
    """Convolutional autoencoder with a scalar-quantisation bottleneck.

    The encoder maps a 1-channel image to a 128-channel latent through four
    stride-2 convolutions (each halves height and width, e.g. 256x256 -> 16x16),
    squashed into [-1, 1] by a final ``Tanh``. Every latent value is then
    quantised onto ``centers`` (see ``quantise``), and the decoder upsamples the
    result back to a 1-channel image in [0, 1] (final ``Sigmoid``).

    Args:
        centers: quantisation levels (any sequence of floats). ``None`` disables
            quantisation, so ``quantise`` becomes the identity. Defaults to
            ``(-1.0, 1.0)``.
        sigma: inverse temperature of the soft assignment used in training mode.
            Larger values push the soft assignment closer to a hard one.
    """

    def __init__(self, centers=(-1.0, 1.0), sigma=10):
        super().__init__()

        # Register the centers as a buffer so that `.to(device)` / `.cuda()` move
        # them together with the weights. A plain tensor attribute would stay on
        # the CPU. `persistent=False` keeps the state_dict keys unchanged.
        if centers is None:
            self.register_buffer("centers", None, persistent=False)
        else:
            self.register_buffer(
                "centers",
                torch.as_tensor(centers, dtype=torch.float32).clone(),
                persistent=False,
            )
        self.sigma = sigma

        # Encoder: four stride-2 convs, each halving H and W (256x256 -> 16x16).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=2, padding=1),  # e.g. 256x256 -> 128x128
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),  # -> 64x64
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),  # -> 32x32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),  # -> 16x16
            nn.BatchNorm2d(128),
            nn.Tanh()  # latent in [-1, 1], the range spanned by the default centers
        )

        # Decoder: four stride-2 transposed convs (each doubles H and W), then a
        # 3x3 conv down to 1 channel and a Sigmoid so the output lies in [0, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Run the encoder on ``x`` and return the (unquantised) latent."""
        y = self.encoder(x)
        return y

    def quantise(self, y):
        """Map every element of ``y`` onto ``self.centers``.

        * Training mode (``self.training``): each element becomes a weighted
          average of the centers. The weights are ``softmax(-sigma * d**2)``
          over its squared distances ``d**2`` to every center, which keeps the
          operation differentiable.
        * Eval mode: each element is replaced by its nearest center (hard
          quantisation). On ties the lowest-index center wins.

        NOTE(review): during training the decoder only ever sees the soft values,
        never the hard symbols used in eval. A straight-through estimator is a
        common way to close that gap; it is not done here.

        Args:
            y: latent tensor of any shape.

        Returns:
            A tensor with the same shape as ``y``, or ``y`` itself if
            ``self.centers`` is ``None``.
        """
        if self.centers is None:
            return y
        # A new trailing axis of size len(centers) holds the squared distance of
        # every element to every center.
        dist = (y.unsqueeze(-1) - self.centers) ** 2
        # Must be `self.training` (a bool). `self.train` is a bound method, which
        # is always truthy, so testing it would make the hard branch unreachable.
        if self.training:
            phi = F.softmax(-self.sigma * dist, dim=-1)
            return torch.sum(phi * self.centers, dim=-1)
        # Nearest center: same result as one_hot(argmax(softmax(-big * dist))).
        return self.centers[torch.argmin(dist, dim=-1)]

    def decode(self, y):
        """Run the decoder on a (quantised) latent and return the reconstruction."""
        x = self.decoder(y)
        return x

    def forward(self, x):
        """Encode, quantise, and decode ``x``; return the reconstruction."""
        encoded = self.encode(x)
        limit_entropy = self.quantise(encoded)
        decoded = self.decode(limit_entropy)
        return decoded

In [3]:
class NLPDLoss(nn.Module):
    """Loss module wrapping the project's ``LaplacianPyramid`` comparison.

    It is constructed as ``LaplacianPyramid(5, dims=1)``, and calling the module
    returns ``compare(reconstructed, original)`` from that pyramid, so it can be
    used like any ``nn.Module`` criterion.
    """

    def __init__(self):
        super().__init__()
        # Arguments are forwarded verbatim. Their meaning is defined in NLP.nlp
        # (presumably 5 pyramid levels on single-channel input — TODO confirm).
        self.lp = LaplacianPyramid(5, dims=1)

    def forward(self, reconstructed, original):
        """Return the pyramid's comparison of ``reconstructed`` against ``original``."""
        pyramid = self.lp
        return pyramid.compare(reconstructed, original)

In [4]:
# Seed torch's global RNG so weight initialisation is reproducible under
# Restart & Run All.
# NOTE(review): this does not cover any shuffling done inside LoadData (cell 1);
# seed before building the loader if that matters.
SEED = 0
torch.manual_seed(SEED)

ae = EntropyLimitedAutoencoder()
criterion = NLPDLoss()
# AdamW = Adam with decoupled weight decay.
optimizer = optim.AdamW(ae.parameters(), lr=1e-3, weight_decay=1e-4)

In [5]:
num_epochs = 3

# Put the model in training mode explicitly, so BatchNorm uses batch statistics
# and `quantise` takes its soft, differentiable branch. Without this, the loop
# would train in eval mode if an earlier cell had called `ae.eval()`.
ae.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    num_batches = 0
    # Each item yielded by the loader is indexable, and element 0 is the image
    # tensor. The batch size is set in cell 1 (batch_size=8). Presumably the
    # images are 1x256x256 — TODO confirm.
    for batch in dataloader:
        batch = batch[0]  # keep the `batch` name: later cells may rely on it

        optimizer.zero_grad()
        outputs = ae(batch)
        loss = criterion(outputs, batch)  # reconstruction vs. the input itself
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        print(f"Pérdida: {batch_loss:.6f}")
    # Average over the batches actually seen. This works for loaders without
    # __len__ and avoids a ZeroDivisionError on an empty loader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Época [{epoch+1}/{num_epochs}], Pérdida: {avg_loss:.6f}")

Pérdida: 567.971924
Pérdida: 673.571228
Pérdida: 506.436188
Pérdida: 468.200806
Pérdida: 466.642090
Pérdida: 430.664551
Pérdida: 411.428070
Pérdida: 387.182587
Pérdida: 384.587006
Pérdida: 396.997742
Pérdida: 390.436096
Pérdida: 366.394287
Pérdida: 367.102020
Pérdida: 350.306885
Pérdida: 336.681915
Pérdida: 374.526917
Pérdida: 352.703369
Pérdida: 330.961914
Pérdida: 343.229065
Pérdida: 321.024292
Pérdida: 311.774811
Pérdida: 330.822906
Pérdida: 320.142426
Pérdida: 320.691040
Pérdida: 300.427246
Época [1/3], Pérdida: 392.436295
Pérdida: 313.868622
Pérdida: 294.127197
Pérdida: 307.082184
Pérdida: 295.177521
Pérdida: 302.103271
Pérdida: 287.711884
Pérdida: 288.455505
Pérdida: 286.674072
Pérdida: 284.692871
Pérdida: 294.267792
Pérdida: 299.993042
Pérdida: 286.634644
Pérdida: 282.825012
Pérdida: 272.759674
Pérdida: 307.043579
Pérdida: 280.452606
Pérdida: 270.740662
Pérdida: 300.393951
Pérdida: 280.150299
Pérdida: 298.725800
Pérdida: 270.566528
Pérdida: 282.339874
Pérdida: 275.558624
Pérdida