In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Seed torch's global RNG so DataLoader shuffling, nn.Linear weight init and the
# VAE's sampling noise (torch.randn_like) are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

DATA_ROOT = './data'
BATCH_SIZE = 100

# MNIST as [0, 1] tensors via ToTensor; the training split is shuffled each epoch.
train_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, transform=transforms.ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# The held-out MNIST test split is used as the validation set (fixed order).
val_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, transform=transforms.ToTensor(), download=True)

val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

class VAE(nn.Module):
    """Fully-connected variational autoencoder.

    Encoder: x_dim -> h_dim1 -> h_dim2 -> (mu, log_var), each of size z_dim.
    Decoder: z_dim -> h_dim2 -> h_dim1 -> x_dim, followed by a sigmoid, so
    reconstructions lie in (0, 1).

    Args:
        x_dim: size of one flattened input sample (784 for 28x28 MNIST).
        h_dim1: width of the outer hidden layer (first in encoder, last in decoder).
        h_dim2: width of the inner hidden layer.
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, x_dim: int, h_dim1: int, h_dim2: int, z_dim: int):
        super().__init__()
        # Kept so forward/loss flatten inputs to the configured size instead of
        # a hard-coded 784.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map flattened inputs of shape (N, x_dim) to (mu, log_var), each (N, z_dim)."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw z = mu + eps * exp(0.5 * log_var) with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std  # z sample; differentiable w.r.t. mu and log_var

    def decoder(self, z):
        """Map latent codes (N, z_dim) to reconstructions (N, x_dim) in (0, 1)."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        # torch.sigmoid replaces the deprecated F.sigmoid (identical values).
        return torch.sigmoid(self.fc6(h))

    def forward(self, x, y):
        """Encode, sample and decode a batch.

        Args:
            x: input batch; every sample is flattened to ``x_dim`` values
               (e.g. (N, 1, 28, 28) -> (N, 784) for MNIST).
            y: labels; accepted for interface compatibility but unused.

        Returns:
            ``(x_recon, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

    def loss(self, batch, outputs):
        """Compute the VAE objective for one batch.

        Args:
            batch: ``(x, y)`` pair; only ``x`` is used.
            outputs: ``(x_recon, mean, log_var)`` as returned by ``forward``.

        Returns:
            dict with ``'loss'`` (reconstruction + KL), ``'recon_loss_0'``
            (module-level ``recon_loss``, a sum-reduced MSE) and ``'kl_loss'``
            (module-level ``kl_loss``).
        """
        x, _ = batch
        x_recon, mean, log_var = outputs

        recon = recon_loss(x_recon, x.reshape(-1, self.x_dim))
        kld = kl_loss(mean, log_var)

        return {'loss': recon + kld, 'recon_loss_0': recon, 'kl_loss': kld}
    
def kl_loss(z_mean, z_log_var):
    """Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over every element.

    Computes -1/2 * sum(1 + log_var - mu^2 - exp(log_var)); the sum runs over
    both the batch and latent dimensions, so the result is a scalar tensor.
    """
    per_element = 1 + z_log_var - z_mean.pow(2) - z_log_var.exp()
    return -0.5 * per_element.sum()
    
def recon_loss(inputs, outputs):
    """Sum-reduced squared error between two tensors.

    Despite the ``BCE`` naming used elsewhere, this is ``sum((inputs - outputs) ** 2)``
    computed via ``F.mse_loss``; the argument order does not matter.
    """
    return F.mse_loss(
        input=inputs,
        target=outputs,
        reduction='sum',
    )

In [6]:
from lightning_extensions import BaseModule

class VAEModule(BaseModule):
    """Lightning wrapper that trains a 784-512-256-2 VAE on the plain recon + KL loss.

    The VAE is handed to the project-local ``BaseModule``; it is presumably stored
    as ``self.model`` (``forward`` and ``step`` rely on that) — confirm in
    ``lightning_extensions``.
    """

    def __init__(self):
        model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
        super().__init__(model)
        self.save_hyperparameters()

    def forward(self, x, y):
        """Delegate to the wrapped VAE; returns ``(x_recon, mu, log_var)``."""
        return self.model(x, y)

    def step(self, batch, batch_idx, mode = 'train'):
        """Run one batch, log every loss term prefixed by ``mode``, return the total loss."""
        x, y = batch
        outputs = self(x, y)
        terms = self.model.loss(batch, outputs)

        metrics = {}
        for name, value in terms.items():
            metrics[f"{mode}_{name}"] = value.item()
        self.log_dict(metrics, sync_dist=True, prog_bar=True)

        return terms['loss']

from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt
class VAEModuleSoftAdapt(BaseModule):
    """Lightning wrapper that trains the VAE with SoftAdapt-weighted loss terms.

    During training, the per-batch reconstruction and KL losses are recorded.
    Once more than ``SOFTADAPT_HISTORY`` values have accumulated, a
    ``LossWeightedSoftAdapt`` object recomputes the two component weights from
    that history, and the history is cleared. ``step`` returns
    ``w_recon * recon + w_kl * kl``, with the weights starting at (1, 1).
    The logged metrics are the unweighted terms returned by ``VAE.loss``.
    """

    # The weights are recomputed once the history holds more than this many batches.
    SOFTADAPT_HISTORY = 100

    def __init__(self):
        model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
        super().__init__(model)
        # Plain (non-Module, non-Parameter) state is assigned after the base
        # initializer, following the usual nn.Module convention.
        self.softadapt_object = LossWeightedSoftAdapt(beta=0.001)
        self.reconstruction_losses = []  # Python floats, one per training batch
        self.kl_losses = []  # Python floats, one per training batch
        self.adapt_weights = torch.tensor([1, 1])
        self.save_hyperparameters()

    def forward(self, x, y):
        """Delegate to the wrapped VAE; returns ``(x_recon, mu, log_var)``."""
        return self.model(x, y)

    def step(self, batch, batch_idx, mode = 'train'):
        """Run one batch and return the SoftAdapt-weighted loss.

        Args:
            batch: ``(x, y)`` pair as yielded by the dataloader.
            batch_idx: index of the batch (not used here).
            mode: metric-name prefix. Loss history is collected, and weights are
                updated, only when ``mode == 'train'``.
        """
        x, y = batch
        x_hat = self(x, y)
        loss = self.model.loss(batch, x_hat)

        self.log_dict({f"{mode}_{key}": val.item() for key, val in loss.items()}, prog_bar=True)

        recon = loss['recon_loss_0']
        kl = loss['kl_loss']

        if mode == 'train':
            # Store detached Python floats. Appending the loss tensors themselves
            # would keep each batch's autograd graph (and device memory) alive
            # until the next reset.
            self.reconstruction_losses.append(recon.item())
            self.kl_losses.append(kl.item())

            if len(self.reconstruction_losses) > self.SOFTADAPT_HISTORY:
                first = torch.tensor(self.reconstruction_losses, dtype=torch.float64)
                second = torch.tensor(self.kl_losses, dtype=torch.float64)

                self.adapt_weights = self.softadapt_object.get_component_weights(first, second, verbose=False)

                self.reconstruction_losses = []
                self.kl_losses = []

        return self.adapt_weights[0] * recon + self.adapt_weights[1] * kl

In [None]:
from lightning_extensions import ExtendedTrainer

# Baseline Lightning run: VAE trained on the fixed, unweighted recon + KL objective.
model_name = "VAE-simple"
model = VAEModule()
trainer = ExtendedTrainer(
    project_name="MTVAEs_SoftAdapt",
    max_epochs=30,
    model_name=model_name,
)
trainer.fit(model, train_loader, val_loader)

In [None]:
from lightning_extensions import ExtendedTrainer

# SoftAdapt Lightning run: the recon and KL terms are reweighted by
# LossWeightedSoftAdapt during training.
model_name = "VAE-softadapt"
model = VAEModuleSoftAdapt()
trainer = ExtendedTrainer(
    project_name="MTVAEs_SoftAdapt",
    max_epochs=30,
    model_name=model_name,
)
trainer.fit(model, train_loader, val_loader)

In [5]:
# Plain PyTorch baseline: Adam on the unweighted recon + KL loss returned by VAE.loss.
NUM_EPOCHS = 30  # same epoch budget as the Lightning runs (max_epochs=30)

model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for current_epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        x_recon, mean, log_var = model(x, y)
        loss_terms = model.loss((x, y), (x_recon, mean, log_var))
        loss_terms['loss'].backward()
        optimizer.step()

    # Validate
    model.eval()
    with torch.no_grad():
        losses, bce_losses, kld_losses = [], [], []

        for x, y in val_loader:
            x_recon, mean, log_var = model(x, y)
            loss_terms = model.loss((x, y), (x_recon, mean, log_var))
            # Keys must match the dict returned by VAE.loss.
            losses.append(loss_terms['loss'])
            bce_losses.append(loss_terms['recon_loss_0'])
            kld_losses.append(loss_terms['kl_loss'])

        val_loss = torch.stack(losses).mean()
        val_bce = torch.stack(bce_losses).mean()
        val_kld = torch.stack(kld_losses).mean()

    # The weights are fixed at (1, 1) here; they are printed only for
    # side-by-side comparison with the SoftAdapt run.
    print("Weights: ", torch.tensor([1,1]))
    # NOTE: "BCE Loss" is the historical label; recon_loss is a sum-reduced MSE.
    print("Epoch: {}, Loss: {}, BCE Loss: {}, KLD Loss: {}".format(current_epoch, val_loss, val_bce, val_kld))

Weights:  tensor([1, 1])
Epoch: 1, Loss: 4163.302734375, BCE Loss: 3809.1845703125, KLD Loss: 354.11761474609375
Weights:  tensor([1, 1])
Epoch: 2, Loss: 3935.875, BCE Loss: 3529.89306640625, KLD Loss: 405.9815673828125
Weights:  tensor([1, 1])
Epoch: 3, Loss: 3798.40869140625, BCE Loss: 3353.83349609375, KLD Loss: 444.57537841796875
Weights:  tensor([1, 1])
Epoch: 4, Loss: 3714.271484375, BCE Loss: 3254.0146484375, KLD Loss: 460.25726318359375
Weights:  tensor([1, 1])
Epoch: 5, Loss: 3666.865966796875, BCE Loss: 3198.347900390625, KLD Loss: 468.5180358886719
Weights:  tensor([1, 1])
Epoch: 6, Loss: 3642.686767578125, BCE Loss: 3136.531982421875, KLD Loss: 506.1553955078125
Weights:  tensor([1, 1])
Epoch: 7, Loss: 3600.828125, BCE Loss: 3099.433837890625, KLD Loss: 501.39385986328125
Weights:  tensor([1, 1])
Epoch: 8, Loss: 3579.992431640625, BCE Loss: 3091.715576171875, KLD Loss: 488.2769775390625
Weights:  tensor([1, 1])
Epoch: 9, Loss: 3547.362548828125, BCE Loss: 3048.224609375, KL

# Training example with SoftAdapt

In [3]:
from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt
import wandb

NUM_EPOCHS = 30  # same epoch budget as the Lightning runs (max_epochs=30)
# SoftAdapt recomputes the component weights every SOFTADAPT_WINDOW training
# batches, using the per-batch loss values collected since the last update.
# (This is a batch count, not an epoch count.)
SOFTADAPT_WINDOW = 101

model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Change 1: Create a SoftAdapt object (with your desired variant)
softadapt_object = LossWeightedSoftAdapt(beta=0.001)

# Change 2: Collect per-batch loss values as plain floats. Appending the loss
# tensors themselves would keep every batch's autograd graph alive until the
# next weight update.
values_of_component_1 = []  # reconstruction loss history
values_of_component_2 = []  # KL loss history
# Initializing adaptive weights to all ones.
adapt_weights = torch.tensor([1, 1])

wandb.init(project="MTVAEs_SoftAdapt", name="VAE-softadapt-custom")
try:
    for current_epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            x_recon, mean, log_var = model(x, y)
            loss_terms = model.loss((x, y), (x_recon, mean, log_var))

            bce_loss = loss_terms['recon_loss_0']
            kld = loss_terms['kl_loss']

            values_of_component_1.append(bce_loss.item())
            values_of_component_2.append(kld.item())

            if len(values_of_component_1) >= SOFTADAPT_WINDOW:
                # Change 3: Update the component weights from the collected
                # history, then start a fresh window.
                first = torch.tensor(values_of_component_1, dtype=torch.float64)
                second = torch.tensor(values_of_component_2, dtype=torch.float64)
                adapt_weights = softadapt_object.get_component_weights(first, second, verbose=False)

                values_of_component_1 = []
                values_of_component_2 = []

            weighted_loss = adapt_weights[0] * bce_loss + adapt_weights[1] * kld
            weighted_loss.backward()
            optimizer.step()

        # Validate on the unweighted recon + KL loss so this run stays comparable
        # with the baseline.
        model.eval()
        with torch.no_grad():
            losses, bce_losses, kld_losses = [], [], []

            for x, y in val_loader:
                x_recon, mean, log_var = model(x, y)
                loss_terms = model.loss((x, y), (x_recon, mean, log_var))
                losses.append(loss_terms['loss'])
                bce_losses.append(loss_terms['recon_loss_0'])
                kld_losses.append(loss_terms['kl_loss'])

            val_loss = torch.stack(losses).mean()
            val_bce = torch.stack(bce_losses).mean()
            val_kld = torch.stack(kld_losses).mean()

        print("Weights: ", adapt_weights)
        # NOTE: "BCE Loss" is the historical label; recon_loss is a sum-reduced MSE.
        print("Epoch: {}, Loss: {}, BCE Loss: {}, KLD Loss: {}".format(current_epoch, val_loss, val_bce, val_kld))
        wandb.log({"val_loss": val_loss.item(), "val_recon_loss_0": val_bce.item(), "val_kl_loss": val_kld.item()})
finally:
    # Close the W&B run even if training is interrupted, so a re-run starts a clean run.
    wandb.finish()
        

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33medvardsz[0m. Use [1m`wandb login --relogin`[0m to force relogin


Weights:  tensor([0.8530, 0.1470], dtype=torch.float64)
Epoch: 1, Loss: 4242.599609375, BCE Loss: 3583.261962890625, KLD Loss: 659.33740234375
Weights:  tensor([0.8725, 0.1275], dtype=torch.float64)
Epoch: 2, Loss: 4024.90625, BCE Loss: 3352.579345703125, KLD Loss: 672.3267211914062
Weights:  tensor([0.8157, 0.1843], dtype=torch.float64)
Epoch: 3, Loss: 3856.71240234375, BCE Loss: 3211.237548828125, KLD Loss: 645.4752197265625
Weights:  tensor([0.8474, 0.1526], dtype=torch.float64)
Epoch: 4, Loss: 3793.824951171875, BCE Loss: 3117.480224609375, KLD Loss: 676.3447875976562
Weights:  tensor([0.8203, 0.1797], dtype=torch.float64)
Epoch: 5, Loss: 3738.108154296875, BCE Loss: 3070.41943359375, KLD Loss: 667.6888427734375
Weights:  tensor([0.8201, 0.1799], dtype=torch.float64)
Epoch: 6, Loss: 3698.8583984375, BCE Loss: 3006.57373046875, KLD Loss: 692.28466796875
Weights:  tensor([0.8181, 0.1819], dtype=torch.float64)
Epoch: 7, Loss: 3665.492919921875, BCE Loss: 2979.7001953125, KLD Loss: 685