In [10]:
import os
import pickle

import torch
import torch.optim as optim
from torch.nn import MSELoss

from AutoEncoderObjects import EntropyLimitedAutoencoder, MSSSIMLoss, NLPDLoss
from LoadingDefault import LoadData, LoadNoise

In [12]:
# --- Training configuration -------------------------------------------------
num_epochs = 5
SEED = 0  # re-seeded before every run so all models start from the same RNG state
MODELS_DIR = "MODELOS"  # output folder for weights, full models and metadata

dataloader = LoadData(limit=10, batch_size=8)
noise = LoadNoise(limit = 10, batch_size=8)

metadata_model = []

loss_functions = (MSELoss, NLPDLoss, MSSSIMLoss)
# Names index-aligned with `loss_functions`; used in the metadata and in file names.
loss_names = ("MSE", "NLPD", "MSSSIM")
# (label, loader) pairs; the label goes into the metadata and into file names.
datasets = (("songs", dataloader), ("noise", noise))


def _train_autoencoder(model, loader, criterion, optimizer, n_epochs, label=""):
    """Train `model` to reconstruct its own input and return the loss history.

    Args:
        model: module called as model(inputs); its output is compared to inputs.
        loader: iterable of batches; element 0 of each batch is used as the input.
        criterion: loss called as criterion(outputs, inputs).
        optimizer: optimizer over `model`'s parameters.
        n_epochs: number of full passes over `loader`.
        label: prefix printed before each per-epoch progress line.

    Returns:
        (loss_epochs, loss_batch): the mean loss of each epoch, and every batch
        loss in the order it was computed (all epochs concatenated).
    """
    loss_epochs = []
    loss_batch = []
    for epoch in range(n_epochs):
        total_loss = 0.0
        n_batches = 0
        for batch in loader:
            # NOTE(review): assumes each batch is a tuple/list whose first element
            # is the image tensor (e.g. from a TensorDataset); presumably shaped
            # (batch, 1, H, W) — the old comment said 64x1x256x256, but batch_size
            # here is 8. Confirm against LoadingDefault.
            inputs = batch[0]

            optimizer.zero_grad()
            outputs = model(inputs)
            # Autoencoder objective: compare the reconstruction with its own input.
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            loss_batch.append(loss_value)
            n_batches += 1

        # Average over the batches actually seen; an empty loader gives NaN
        # instead of a ZeroDivisionError.
        avg_loss = total_loss / n_batches if n_batches else float("nan")
        loss_epochs.append(avg_loss)
        print(f"{label}Época [{epoch+1}/{n_epochs}], Pérdida: {avg_loss:.6f}")
    return loss_epochs, loss_batch


# torch.save does not create parent directories; create the output folder up
# front so a missing folder cannot discard a finished training run.
os.makedirs(MODELS_DIR, exist_ok=True)

for data_name, data in datasets:
    for loss_name, loss_cls in zip(loss_names, loss_functions):
        # Same seed for every run so differences between loss functions are not
        # confounded by different random initialisations.
        torch.manual_seed(SEED)
        ae = EntropyLimitedAutoencoder()
        criterion = loss_cls()
        optimizer = optim.AdamW(ae.parameters(), lr=1e-3, weight_decay=1e-4)

        loss_epochs, loss_batch = _train_autoencoder(
            ae, data, criterion, optimizer, num_epochs,
            label=f"[{loss_name}-{data_name}] ",
        )

        base_name = f"{loss_name}-{data_name}"
        mod = {
            "loss_type": loss_name,
            "data": data_name,
            "epochs": num_epochs,
            "loss_epochs": loss_epochs,
            "loss_batch": loss_batch,
            "file_weights": base_name + ".pth",
            "file_full": base_name + "-completo.pth",
        }

        torch.save(ae.state_dict(), os.path.join(MODELS_DIR, mod["file_weights"]))
        # Pickles the whole module object: loading it back requires the
        # EntropyLimitedAutoencoder class to be importable.
        torch.save(ae, os.path.join(MODELS_DIR, mod["file_full"]))

        metadata_model.append(mod)

with open(os.path.join(MODELS_DIR, "metadatos_modelos.pkl"), "wb") as f:
    pickle.dump(metadata_model, f)

Época [1/5], Pérdida: 0.048857
Época [2/5], Pérdida: 0.027259
Época [3/5], Pérdida: 0.017572
Época [4/5], Pérdida: 0.014349
Época [5/5], Pérdida: 0.012027
Época [1/5], Pérdida: 479.059118
Época [2/5], Pérdida: 368.990626
Época [3/5], Pérdida: 318.824956
Época [4/5], Pérdida: 294.056974
Época [5/5], Pérdida: 275.806709
Época [1/5], Pérdida: 0.732278
Época [2/5], Pérdida: 0.462642
Época [3/5], Pérdida: 0.315845
Época [4/5], Pérdida: 0.257107
Época [5/5], Pérdida: 0.207774
Época [1/5], Pérdida: 0.111886
Época [2/5], Pérdida: 0.095356
Época [3/5], Pérdida: 0.087513
Época [4/5], Pérdida: 0.085226
Época [5/5], Pérdida: 0.083930
Época [1/5], Pérdida: 494.657003
Época [2/5], Pérdida: 426.440150
Época [3/5], Pérdida: 406.833206
Época [4/5], Pérdida: 387.005397
Época [5/5], Pérdida: 362.783045
Época [1/5], Pérdida: 0.826854
Época [2/5], Pérdida: 0.657479
Época [3/5], Pérdida: 0.510781
Época [4/5], Pérdida: 0.419398
Época [5/5], Pérdida: 0.362105
