In [2]:
pip install pytorch_msssim

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch_msssim
Successfully installed pytorch_msssim-1.0.0
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.3.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [17]:
from LoadingDefault import LoadData

from torch import nn
import torch.optim as optim
import torch
import torch.nn.functional as F

from pytorch_msssim import ms_ssim

# Build the iterable consumed by the training loop using the project-local
# LoadData helper. The meaning of `limit` and `batch_size` is defined in
# LoadingDefault — presumably a cap on the data loaded and the mini-batch
# size; TODO confirm against that module.
dataloader = LoadData(limit=100, batch_size=8)



In [18]:
class EntropyLimitedAutoencoder(nn.Module):
    """Convolutional autoencoder with a quantised (entropy-limited) latent code.

    The encoder maps a single-channel image to a 128-channel latent tensor in
    [-1, 1] (Tanh output) at 1/16 of the input's spatial resolution. The
    latent values are then pulled towards a small set of quantisation
    ``centers`` before being decoded back to a single-channel image in
    [0, 1] (Sigmoid output).

    Quantisation depends on the module mode:

    * training mode (``model.train()``): a differentiable soft assignment,
      i.e. a softmax-weighted average of the centers;
    * evaluation mode (``model.eval()``): a hard assignment to the nearest
      center.

    Attributes:
        centers: 1-D tensor of quantisation levels, or ``None`` to disable
            quantisation. Held as a non-persistent buffer so it follows the
            module across ``.to(device)`` without adding a ``state_dict`` key.
        sigma: Sharpness of the soft assignment; larger values push the soft
            output closer to the hard one.
    """

    def __init__(self):
        super().__init__()

        # Non-persistent buffer: moves with the module between devices but is
        # not saved in (or required by) state_dict, so checkpoints keep the
        # same keys as before.
        self.register_buffer("centers", torch.tensor([-1.0, 1.0]), persistent=False)
        self.sigma = 10

        # Encoder: four stride-2 convolutions, each halving height and width
        # for even sizes (e.g. 256x256 -> 16x16 overall).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.Tanh()  # bound the latent to [-1, 1], the range spanned by the centers
        )

        # Decoder: four stride-2 transposed convolutions, each doubling height
        # and width, followed by a size-preserving projection to one channel.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # reconstructed pixel values in [0, 1]
        )

    def encode(self, x):
        """Run the encoder on ``x`` and return the un-quantised latent tensor."""
        y = self.encoder(x)
        return y

    def quantise(self, y):
        """Quantise a 4-D latent tensor against ``self.centers``.

        Args:
            y: Latent tensor with four dimensions (indexed as
                ``y.size(0)`` .. ``y.size(3)``).

        Returns:
            A tensor with the same shape as ``y``. In training mode each
            element is a softmax-weighted mix of the centers (differentiable);
            in evaluation mode each element is exactly its nearest center.
            If ``self.centers`` is ``None`` the input is returned unchanged.
        """
        if self.centers is None:
            return y

        # Trailing singleton axis lets the subtraction broadcast each latent
        # value against every center.
        y_flat = y.reshape(y.size(0), y.size(1), y.size(2) * y.size(3), 1)
        dist = (y_flat - self.centers) ** 2

        # `self.training` is the mode flag toggled by train()/eval().
        # (`self.train` is a bound method and therefore always truthy.)
        if self.training:
            phi = F.softmax(-self.sigma * dist, dim=-1)
            y_hat = torch.sum(phi * self.centers, dim=-1)
        else:
            # Hard assignment: pick the closest center for every element.
            nearest = torch.argmin(dist, dim=-1)
            y_hat = self.centers[nearest]

        return y_hat.reshape(y.shape)

    def decode(self, y):
        """Run the decoder on latent ``y`` and return the reconstruction."""
        x = self.decoder(y)
        return x

    def forward(self, x):
        """Encode ``x``, quantise the latent, and decode the reconstruction."""
        encoded = self.encode(x)
        limit_entropy = self.quantise(encoded)
        decoded = self.decode(limit_entropy)
        return decoded

In [22]:
def ms_ssim_loss(x_pred, x_true):
    """Return ``1 - MS-SSIM`` between a predicted and a reference batch.

    Both inputs are forwarded to ``pytorch_msssim.ms_ssim`` with
    ``data_range=1.0`` and ``size_average=True``.

    Args:
        x_pred: Predicted / reconstructed images.
        x_true: Reference images.

    Returns:
        One minus the MS-SSIM score, so lower means more similar.
    """
    similarity = ms_ssim(x_pred, x_true, data_range=1.0, size_average=True)
    return 1 - similarity

class MSSSIMLoss(nn.Module):
    """Loss module computing ``1 - MS-SSIM`` between two image batches.

    It holds no parameters or state; the score comes from
    ``pytorch_msssim.ms_ssim`` called with ``data_range=1.0`` and
    ``size_average=True``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, reconstructed, original):
        """Return one minus the MS-SSIM of ``reconstructed`` against ``original``."""
        similarity = ms_ssim(
            reconstructed,
            original,
            data_range=1.0,
            size_average=True,
        )
        return 1 - similarity

In [23]:
# Objects shared with the training loop: the model, the reconstruction loss,
# and an AdamW optimiser over all of the model's parameters.
ae = EntropyLimitedAutoencoder()
criterion = MSSSIMLoss()
optimizer = optim.AdamW(
    params=ae.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

In [24]:
num_epochs = 3

# Select training mode explicitly rather than relying on whatever mode the
# kernel left the model in: if this cell is re-run after an `ae.eval()`,
# BatchNorm would otherwise use its running statistics during training.
ae.train()

for epoch in range(num_epochs):
    total_loss = 0
    # NOTE(review): the original note described batches as 64x1x256x256, but
    # the loader is created with batch_size=8 — confirm the actual shape.
    for batch in dataloader:
        batch = batch[0]  # each item is indexable; element 0 is the input tensor

        optimizer.zero_grad()  # clear gradients left over from the previous step

        outputs = ae(batch)  # forward pass
        loss = criterion(outputs, batch)  # reconstruction is compared with its own input

        loss.backward()  # backpropagation
        optimizer.step()  # update the weights

        total_loss += loss.item()
        print(f"Pérdida: {loss.item():.6f}")
    # Mean of the per-batch losses over this epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"Época [{epoch+1}/{num_epochs}], Pérdida: {avg_loss:.6f}")

Pérdida: 0.795522
Pérdida: 0.799145
Pérdida: 0.560102
Pérdida: 0.524463
Pérdida: 0.531237
Pérdida: 0.406490
Pérdida: 0.385096
Pérdida: 0.336608
Pérdida: 0.322165
Pérdida: 0.316142
Pérdida: 0.260274
Pérdida: 0.247094
Pérdida: 0.260976
Pérdida: 0.236262
Pérdida: 0.236617
Pérdida: 0.227871
Pérdida: 0.237589
Pérdida: 0.213715
Pérdida: 0.203694
Pérdida: 0.218071
Pérdida: 0.228664
Pérdida: 0.199183
Pérdida: 0.225289
Pérdida: 0.196277
Pérdida: 0.228374
Época [1/3], Pérdida: 0.335877
Pérdida: 0.183653
Pérdida: 0.181277
Pérdida: 0.196269
Pérdida: 0.192351
Pérdida: 0.187076
Pérdida: 0.175193
Pérdida: 0.170951
Pérdida: 0.173031
Pérdida: 0.175625
Pérdida: 0.176766
Pérdida: 0.183787
Pérdida: 0.165321
Pérdida: 0.172000
Pérdida: 0.169707
Pérdida: 0.161331
Pérdida: 0.165498
Pérdida: 0.156969
Pérdida: 0.157550
Pérdida: 0.145697
Pérdida: 0.151166
Pérdida: 0.152214
Pérdida: 0.164424
Pérdida: 0.143970
Pérdida: 0.158068
Pérdida: 0.146022
Época [2/3], Pérdida: 0.168237
Pérdida: 0.152827
Pérdida: 0.145155
Pé