In [1]:
from LoadingDefault import LoadData

from torch import nn
import torch.optim as optim
from torchinfo import summary

import torch
import torch.nn.functional as F



In [2]:
# Data source for the training loop. LoadData comes from the local
# LoadingDefault module; presumably `limit` caps the number of samples and
# `batch_size` sets the mini-batch size -- TODO confirm against LoadingDefault.
dataloader = LoadData(
    limit=100,
    batch_size=8,
)



In [3]:
class EntropyLimitedAutoencoder(nn.Module):
    """Convolutional autoencoder whose latent code is quantised to fixed centers.

    The encoder maps a single-channel image batch ``(N, 1, H, W)`` to a
    128-channel latent map through four stride-2 convolutions (each one halves
    the spatial size, so the latent is 16x smaller per side). The latent values
    are then pulled towards a small set of scalar ``centers`` by
    :meth:`quantise`, and the decoder upsamples back by 16x with four
    transposed convolutions. ``H`` and ``W`` should be multiples of 16 for the
    reconstruction to have the same shape as the input.

    Args:
        centers: 1-D sequence (or tensor) of quantisation levels, or ``None``
            to disable quantisation. Defaults to ``(-1, 1)``, which matches the
            ``Tanh`` range of the encoder output.
        sigma: Softness of the training-time assignment; larger values make the
            soft quantisation closer to a hard nearest-center assignment.
    """

    def __init__(self, centers=(-1.0, 1.0), sigma=10):
        super().__init__()

        # Registered as a buffer so that ``.to(device)`` / ``.cuda()`` move it
        # together with the weights. ``persistent=False`` keeps it out of the
        # state_dict, so existing checkpoints still load unchanged.
        centers_tensor = (
            None if centers is None
            else torch.as_tensor(centers, dtype=torch.float32)
        )
        self.register_buffer("centers", centers_tensor, persistent=False)
        self.sigma = sigma

        # Encoder: every conv uses kernel 3 / stride 2 / padding 1, which
        # halves the spatial size (e.g. 256 -> 128 -> 64 -> 32 -> 16).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            # Tanh bounds the latent to [-1, 1], the span of the default centers.
            nn.Tanh()
        )

        # Decoder: every transposed conv (kernel 3 / stride 2 / padding 1 /
        # output_padding 1) doubles the spatial size; the final 3x3 conv maps
        # back to one channel and Sigmoid bounds the output to [0, 1].
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Run the encoder and return the (unquantised) latent map."""
        return self.encoder(x)

    def quantise(self, y):
        """Quantise every latent value towards ``self.centers``.

        In training mode the assignment is soft (a softmax over negative
        squared distances scaled by ``self.sigma``), which keeps the operation
        differentiable. In eval mode each value is replaced by its nearest
        center exactly.

        Args:
            y: Latent tensor of any shape.

        Returns:
            A tensor with the same shape as ``y``. ``y`` itself is returned
            untouched when ``self.centers`` is ``None``.
        """
        if self.centers is None:
            return y

        # No-op when centers already live on y's device with y's dtype; guards
        # against centers that were re-assigned by hand after construction.
        centers = self.centers.to(device=y.device, dtype=y.dtype)

        # Broadcast (..., 1) against (K,) -> squared distance to every center.
        dist = (y.unsqueeze(-1) - centers) ** 2

        # NOTE: ``self.training`` is the mode flag toggled by train()/eval();
        # ``self.train`` is a bound method and therefore always truthy.
        if self.training:
            phi = F.softmax(-self.sigma * dist, dim=-1)
            return torch.sum(phi * centers, dim=-1)

        # Hard assignment: pick the nearest center for every element.
        return centers[torch.argmin(dist, dim=-1)]

    def decode(self, y):
        """Run the decoder on a latent map and return the reconstruction."""
        return self.decoder(y)

    def forward(self, x):
        """Encode ``x``, quantise the latent and decode it back to image space."""
        encoded = self.encode(x)
        limit_entropy = self.quantise(encoded)
        decoded = self.decode(limit_entropy)
        return decoded

In [4]:
# Model under training.
ae = EntropyLimitedAutoencoder()

# AdamW over all model parameters, with decoupled weight decay.
optimizer = torch.optim.AdamW(
    ae.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

# Reconstruction objective: mean squared error between output and input.
criterion = torch.nn.MSELoss()

# Last expression of the cell: display the per-layer parameter summary.
summary(ae)

Layer (type:depth-idx)                   Param #
EntropyLimitedAutoencoder                --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       1,280
│    └─BatchNorm2d: 2-2                  256
│    └─LeakyReLU: 2-3                    --
│    └─Conv2d: 2-4                       147,584
│    └─BatchNorm2d: 2-5                  256
│    └─LeakyReLU: 2-6                    --
│    └─Conv2d: 2-7                       147,584
│    └─BatchNorm2d: 2-8                  256
│    └─LeakyReLU: 2-9                    --
│    └─Conv2d: 2-10                      147,584
│    └─BatchNorm2d: 2-11                 256
│    └─Tanh: 2-12                        --
├─Sequential: 1-2                        --
│    └─ConvTranspose2d: 2-13             147,584
│    └─BatchNorm2d: 2-14                 256
│    └─LeakyReLU: 2-15                   --
│    └─ConvTranspose2d: 2-16             147,584
│    └─BatchNorm2d: 2-17                 256
│    └─LeakyReLU: 2-18               

In [5]:
num_epochs = 3

# Put mode-dependent layers (e.g. BatchNorm) in training mode explicitly, so
# re-running this cell after an `ae.eval()` elsewhere still trains correctly.
ae.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        # Each item yielded by the loader is indexable; element 0 is used as
        # the input tensor (presumably the image batch -- confirm against
        # LoadData). No device transfer happens here: training runs on
        # whatever device the loader and the model already share.
        batch = batch[0]

        optimizer.zero_grad()  # reset gradients accumulated by the previous step

        outputs = ae(batch)  # forward pass
        loss = criterion(outputs, batch)  # reconstruction error against the input

        loss.backward()  # backpropagation
        optimizer.step()  # update weights

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        print(f"Pérdida: {batch_loss:.6f}")

    # Average over the batches actually seen; this also works for loaders
    # without __len__ and avoids a ZeroDivisionError on an empty loader.
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Época [{epoch+1}/{num_epochs}], Pérdida: {avg_loss:.6f}")

Pérdida: 0.061320
Pérdida: 0.065787
Pérdida: 0.033257
Pérdida: 0.018635
Pérdida: 0.013751
Pérdida: 0.021841
Pérdida: 0.021061
Pérdida: 0.024489
Pérdida: 0.026121
Pérdida: 0.014240
Pérdida: 0.032181
Pérdida: 0.016722
Pérdida: 0.012907
Pérdida: 0.011670
Pérdida: 0.013230
Pérdida: 0.018964
Pérdida: 0.013425
Pérdida: 0.015355


KeyboardInterrupt: 