In [17]:
from LoadingDefault import LoadData, Autoencoder

from torch import nn
import torch.optim as optim
from torchinfo import summary

import torch
import torch.nn.functional as F

In [2]:
# Build the batch source with the project-local LoadData helper.
# The training cell iterates over it, takes element [0] of every item as the
# input tensor and calls len() on it to average the loss.
dataloader = LoadData()



In [None]:
class EntropyLimitedModel(nn.Module):
    """Autoencoder whose latent is quantised to a small, fixed set of centers.

    Because every latent element can only take one of ``K`` center values,
    its entropy is bounded by ``log2(K)`` bits, so the model can be optimised
    for the distortion term alone.

    Args:
        sigma: Softness of the training-time (soft) quantiser; larger values
            push the soft assignment closer to the hard nearest-center one.
        N: Number of channels of the hidden convolutional layers.
        M: Number of channels of the latent representation.
        sigmoid: If true, a ``Sigmoid`` is appended to the decoder output.
        centers: Selects the quantisation centers: ``1`` -> ``[1]``,
            ``2`` -> ``[-1, 1]``, ``5`` -> ``[-2, -1, 0, 1, 2]``. Any other
            value (including the default ``None``) disables quantisation.
    """

    def __init__(self, sigma=1, N=128, M=64, sigmoid=False, centers=None):
        super().__init__()
        self.entropy_bottleneck = None
        self.sigma = sigma

        if centers == 1:
            cent = torch.tensor([1.0])
        elif centers == 2:
            cent = torch.tensor([-1.0, 1.0])
        elif centers == 5:
            cent = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
        else:
            # Unsupported / missing setting: quantise() becomes the identity.
            cent = None
        # A buffer (not a plain attribute) so the centers follow .to(device).
        self.register_buffer('centers', cent)

        # "Same"-style padding for the 3x3 kernels used throughout.
        padding = 3 // 2

        # NOTE(review): GDN is not imported in the visible cells -- presumably
        # compressai.layers.GDN; confirm and import it before instantiating.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, N, kernel_size=3, stride=2, padding=padding),
            GDN(N),
            nn.Conv2d(N, N, kernel_size=3, stride=2, padding=padding),
            GDN(N),
            nn.Conv2d(N, N, kernel_size=3, stride=2, padding=padding),
            GDN(N),
            nn.Conv2d(N, N, kernel_size=3, stride=2, padding=padding),
            GDN(N),
            nn.Conv2d(N, M, kernel_size=3, stride=1, padding=padding))

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(M, N, kernel_size=3, stride=1, padding=padding),
            GDN(N, inverse=True),
            nn.ConvTranspose2d(N, N, kernel_size=3, stride=2, padding=padding,
                               output_padding=1),
            GDN(N, inverse=True),
            nn.ConvTranspose2d(N, N, kernel_size=3, stride=2, padding=padding,
                               output_padding=1),
            GDN(N, inverse=True),
            nn.ConvTranspose2d(N, N, kernel_size=3, stride=2, padding=padding,
                               output_padding=1),
            GDN(N, inverse=True),
            nn.ConvTranspose2d(N, 3, kernel_size=3, stride=2, padding=padding,
                               output_padding=1))

        if sigmoid:
            self.decoder.add_module('sig', nn.Sigmoid())

    def encode(self, x):
        """Map an input batch to its (unquantised) latent."""
        return self.encoder(x)

    def quantise(self, y):
        """Quantise a 4-D latent ``y`` to the configured centers.

        In training mode a differentiable soft assignment is used
        (softmax over negative squared distances scaled by ``sigma``); in
        eval mode every element is snapped to its nearest center.
        Returns a tensor with the same shape as ``y``; ``y`` itself is
        returned when no centers are configured.
        """
        if self.centers is None:
            return y

        # (B, C, H*W, 1) broadcasts against the (K,) centers -> (B, C, H*W, K).
        y_flat = y.reshape(y.size(0), y.size(1), y.size(2) * y.size(3), 1)
        dist = (y_flat - self.centers) ** 2

        # `self.training` is the mode flag; `self.train` is a bound method and
        # therefore always truthy, which made the hard branch unreachable.
        if self.training:
            phi = F.softmax(-self.sigma * dist, dim=-1)
        else:
            nearest = torch.argmin(dist, dim=-1)
            phi = F.one_hot(nearest, num_classes=self.centers.size(0))
            phi = phi.to(dist.dtype)  # one_hot yields int64

        y_hat = torch.sum(phi * self.centers, dim=-1)
        return y_hat.reshape(y.shape)

    def decode(self, y):
        """Reconstruct an input-shaped tensor from a (quantised) latent."""
        return self.decoder(y)

    def forward(self, x):
        """Encode, quantise and decode ``x``.

        Returns a dict with the reconstruction under ``'x_hat'`` and a
        ``'likelihoods'`` dict whose ``'y'`` entry is an all-zero tensor
        shaped like the latent.
        """
        y = self.encode(x)
        y_hat = self.quantise(y)
        x_hat = self.decode(y_hat)

        return {
            'x_hat': x_hat,
            'likelihoods': {
                'y': torch.zeros_like(y),
            }
        }

In [19]:
class EntropyLimitedAutoencoder(nn.Module):
    """Single-channel convolutional autoencoder with a binary latent.

    The encoder output passes through ``Tanh`` and is then quantised to the
    two centers ``[-1, 1]``: softly (differentiable, softness ``sigma``) in
    training mode and to the nearest center in eval mode.
    """

    def __init__(self):
        super().__init__()

        # Registered as a buffer so the centers follow .to(device)/.cuda();
        # non-persistent so state_dict keys stay the same as before.
        self.register_buffer('centers', torch.tensor([-1.0, 1.0]),
                             persistent=False)
        self.sigma = 10

        # Encoder: four stride-2 convolutions, each halving height and width.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.Tanh()  # keeps latents in (-1, 1), the range spanned by the centers
        )

        # Decoder: four stride-2 transposed convolutions, each doubling height
        # and width, followed by a 1-channel projection squashed by a sigmoid.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map an input batch to its (unquantised) latent."""
        return self.encoder(x)

    def quantise(self, y):
        """Quantise a 4-D latent ``y`` to ``self.centers``.

        Training mode uses a softmax over negative squared distances scaled
        by ``sigma`` (differentiable); eval mode snaps every element to its
        nearest center. The result has the same shape as ``y``.
        """
        if self.centers is None:
            return y

        # (B, C, H*W, 1) broadcasts against the (K,) centers -> (B, C, H*W, K).
        y_flat = y.reshape(y.size(0), y.size(1), y.size(2) * y.size(3), 1)
        dist = (y_flat - self.centers) ** 2

        # `self.training` is the mode flag; `self.train` is a bound method and
        # therefore always truthy, which made the hard branch unreachable.
        if self.training:
            phi = F.softmax(-self.sigma * dist, dim=-1)
        else:
            nearest = torch.argmin(dist, dim=-1)
            phi = F.one_hot(nearest, num_classes=self.centers.size(0))
            phi = phi.to(dist.dtype)  # one_hot yields int64

        y_hat = torch.sum(phi * self.centers, dim=-1)
        return y_hat.reshape(y.shape)

    def decode(self, y):
        """Reconstruct an input-shaped tensor from a (quantised) latent."""
        return self.decoder(y)

    def forward(self, x):
        """Encode, quantise and decode ``x``; returns the reconstruction."""
        encoded = self.encode(x)
        limit_entropy = self.quantise(encoded)
        decoded = self.decode(limit_entropy)
        return decoded

In [20]:
# Optimisation hyperparameters for the run below.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# Model, reconstruction loss (mean squared error) and optimiser.
ae = EntropyLimitedAutoencoder()
criterion = nn.MSELoss()
optimizer = optim.AdamW(
    ae.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Last expression of the cell: displays the per-layer parameter table.
summary(ae)

Layer (type:depth-idx)                   Param #
EntropyLimitedAutoencoder                --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       1,280
│    └─BatchNorm2d: 2-2                  256
│    └─LeakyReLU: 2-3                    --
│    └─Conv2d: 2-4                       147,584
│    └─BatchNorm2d: 2-5                  256
│    └─LeakyReLU: 2-6                    --
│    └─Conv2d: 2-7                       147,584
│    └─BatchNorm2d: 2-8                  256
│    └─LeakyReLU: 2-9                    --
│    └─Conv2d: 2-10                      147,584
│    └─BatchNorm2d: 2-11                 256
│    └─Tanh: 2-12                        --
├─Sequential: 1-2                        --
│    └─ConvTranspose2d: 2-13             147,584
│    └─BatchNorm2d: 2-14                 256
│    └─LeakyReLU: 2-15                   --
│    └─ConvTranspose2d: 2-16             147,584
│    └─BatchNorm2d: 2-17                 256
│    └─LeakyReLU: 2-18               

In [None]:
num_epochs = 3

for epoch in range(num_epochs):
    # Make sure the model is in training mode even if an earlier cell (or a
    # previous run of the notebook) left it in eval mode: BatchNorm must use
    # batch statistics and the quantiser its differentiable soft assignment.
    ae.train()

    total_loss = 0
    for batch in dataloader:  # per the original note, batches are 64x1x256x256
        batch = batch[0]  # keep only the input tensor (no device move happens here)

        optimizer.zero_grad()  # reset accumulated gradients

        outputs = ae(batch)  # forward pass
        loss = criterion(outputs, batch)  # reconstruction error against the input

        loss.backward()  # backpropagation
        optimizer.step()  # update the weights

        total_loss += loss.item()
        print(f"Pérdida: {loss.item():.6f}")
    avg_loss = total_loss / len(dataloader)
    print(f"Época [{epoch+1}/{num_epochs}], Pérdida: {avg_loss:.6f}")

Pérdida: 0.042193
Pérdida: 0.076499
Pérdida: 0.039777
Pérdida: 0.026595
Pérdida: 0.028007
Pérdida: 0.026625
Pérdida: 0.023293
Pérdida: 0.019201
Pérdida: 0.019451
Pérdida: 0.016311
Pérdida: 0.017346
