In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Training split of MNIST (train=True); files live under ./data and are
# downloaded on first use. Each image is converted with ToTensor().
mnist_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batches of 100 samples for training.
mnist_dataloader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=100,
    shuffle=True,
)

# Held-out split (train=False), used here for validation.
val_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Validation batches keep dataset order (shuffle=False).
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=100,
    shuffle=False,
)

class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input of size ``x_dim`` through two hidden
    layers to the mean and log-variance of a ``z_dim``-dimensional Gaussian.
    The decoder mirrors the encoder back to ``x_dim`` and applies a sigmoid.

    Args:
        x_dim: Number of features in one flattened input sample.
        h_dim1: Width of the hidden layer next to the input/output.
        h_dim2: Width of the hidden layer next to the latent code.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim: int, h_dim1: int, h_dim2: int, z_dim: int):
        super().__init__()

        # Remembered so forward()/loss() can flatten inputs to the right width
        # instead of assuming 784 features.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Return ``(mu, log_var)`` for a batch of flattened inputs ``x``."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is recovered from the log-variance as ``exp(0.5 * log_var)``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std  # return z sample

    def decoder(self, z):
        """Map latent codes ``z`` back to input space, squashed by a sigmoid."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x, y=None):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Input batch; it is reshaped to ``(-1, x_dim)`` before encoding.
            y: Unused. Accepted (and optional) so callers that pass the label
                batch alongside ``x`` keep working.

        Returns:
            Tuple ``(x_recon, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.view(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

    def loss(self, batch, outputs):
        """Compute the training objective for one batch.

        Args:
            batch: Pair ``(x, y)``; only ``x`` is used.
            outputs: Tuple ``(x_recon, mean, log_var)`` as returned by
                ``forward``.

        Returns:
            Dict with keys ``'loss'`` (sum of both terms), ``'BCE_loss'`` and
            ``'KLD_loss'``. The ``'BCE_loss'`` entry holds whatever the
            module-level ``recon_loss`` computes (a summed squared error in
            this file); the key name is kept for existing callers.
        """
        x, _ = batch
        x_recon, mean, log_var = outputs

        BCE = recon_loss(x_recon, x.view(-1, self.x_dim))
        KLD = kl_loss(mean, log_var)

        loss = BCE + KLD

        return {'loss': loss, 'BCE_loss': BCE, 'KLD_loss': KLD}
    
def kl_loss(z_mean, z_log_var):
    """KL-divergence term of the VAE objective, summed over every element.

    Evaluates ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))`` across all
    entries of the two tensors and returns the resulting scalar tensor.
    """
    per_element = 1 + z_log_var - z_mean ** 2 - torch.exp(z_log_var)
    return -0.5 * per_element.sum()
    
def recon_loss(inputs, outputs):
    """Summed squared error between ``inputs`` and ``outputs``.

    Both tensors are handed to ``F.mse_loss`` with ``reduction='sum'``, so the
    result is a scalar tensor totalled over every element (not averaged).
    """
    summed_squared_error = F.mse_loss(input=inputs, target=outputs, reduction='sum')
    return summed_squared_error

In [5]:
# Baseline run: both loss components carry a fixed weight of 1.
model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for current_epoch in range(1, 30):
    # --- Train: one pass over the training loader ---
    for x, y in mnist_dataloader:
        optimizer.zero_grad()
        x_recon, mean, log_var = model(x, y)
        loss = model.loss((x, y), (x_recon, mean, log_var))
        loss['loss'].backward()
        optimizer.step()

    # --- Validate: gather per-batch losses without tracking gradients ---
    with torch.no_grad():
        losses, bce_losses, kld_losses = [], [], []
        tracked = (
            (losses, 'loss'),
            (bce_losses, 'BCE_loss'),
            (kld_losses, 'KLD_loss'),
        )

        for x, y in val_dataloader:
            x_recon, mean, log_var = model(x, y)
            loss = model.loss((x, y), (x_recon, mean, log_var))
            for history, key in tracked:
                history.append(loss[key])

        # Mean of the per-batch values for each tracked quantity.
        epoch_means = [torch.mean(torch.tensor(history)) for history, _ in tracked]

        print("Weights: ", torch.tensor([1,1]))
        print("Epoch: {}, Loss: {}, BCE Loss: {}, KLD Loss: {}".format(current_epoch, *epoch_means))

Weights:  tensor([1, 1])
Epoch: 1, Loss: 4163.302734375, BCE Loss: 3809.1845703125, KLD Loss: 354.11761474609375
Weights:  tensor([1, 1])
Epoch: 2, Loss: 3935.875, BCE Loss: 3529.89306640625, KLD Loss: 405.9815673828125
Weights:  tensor([1, 1])
Epoch: 3, Loss: 3798.40869140625, BCE Loss: 3353.83349609375, KLD Loss: 444.57537841796875
Weights:  tensor([1, 1])
Epoch: 4, Loss: 3714.271484375, BCE Loss: 3254.0146484375, KLD Loss: 460.25726318359375
Weights:  tensor([1, 1])
Epoch: 5, Loss: 3666.865966796875, BCE Loss: 3198.347900390625, KLD Loss: 468.5180358886719
Weights:  tensor([1, 1])
Epoch: 6, Loss: 3642.686767578125, BCE Loss: 3136.531982421875, KLD Loss: 506.1553955078125
Weights:  tensor([1, 1])
Epoch: 7, Loss: 3600.828125, BCE Loss: 3099.433837890625, KLD Loss: 501.39385986328125
Weights:  tensor([1, 1])
Epoch: 8, Loss: 3579.992431640625, BCE Loss: 3091.715576171875, KLD Loss: 488.2769775390625
Weights:  tensor([1, 1])
Epoch: 9, Loss: 3547.362548828125, BCE Loss: 3048.224609375, KL

# Training example with SoftAdapt

In [13]:
from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt

model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Change 1: Create a SoftAdapt object (with your desired variant)
softadapt_object = LossWeightedSoftAdapt(beta=0.001)

# Change 2: Define how often SoftAdapt recalculates the loss-component weights.
# NOTE: the update condition in this cell has always reduced to
# `count >= limit`, so this epoch-based setting has no effect on training.
# The name is kept defined in case other cells refer to it.
epochs_to_make_updates = 5

# Per-iteration history of each loss component since the last weight update.
values_of_component_1 = []
values_of_component_2 = []
# Initializing adaptive weights to all ones.
adapt_weights = torch.tensor([1,1])

# Number of training iterations between weight updates; this is also how many
# loss samples per component are handed to SoftAdapt at each update.
limit = 101

count = 0

for current_epoch in range(1, 30):
    for x, y in mnist_dataloader:
        optimizer.zero_grad()
        count += 1
        x_recon, mean, log_var = model(x, y)
        loss = model.loss((x, y), (x_recon, mean, log_var))

        bce_loss = loss['BCE_loss']
        kld = loss['KLD_loss']

        # Store plain Python floats: appending the tensors themselves would
        # keep every iteration's autograd graph alive until the next reset.
        values_of_component_1.append(bce_loss.item())
        values_of_component_2.append(kld.item())

        if count >= limit:
            # Change 3: Update weights of components
            count = 0
            first = torch.tensor(values_of_component_1, dtype=torch.float64)
            second = torch.tensor(values_of_component_2, dtype=torch.float64)
            adapt_weights = softadapt_object.get_component_weights(first, second, verbose=False)

            # Resetting the lists to start fresh (this part is optional)
            values_of_component_1 = []
            values_of_component_2 = []

        # Optimise the weighted sum of the two components.
        loss = adapt_weights[0] * bce_loss + adapt_weights[1] * kld

        loss.backward()
        optimizer.step()

    # Validate (reports the unweighted losses returned by model.loss)
    with torch.no_grad():
        losses = []
        bce_losses = []
        kld_losses = []

        for x, y in val_dataloader:
            x_recon, mean, log_var = model(x, y)
            loss = model.loss((x, y), (x_recon, mean, log_var))
            losses.append(loss['loss'])
            bce_losses.append(loss['BCE_loss'])
            kld_losses.append(loss['KLD_loss'])
        print("Weights: ", adapt_weights)
        print("Epoch: {}, Loss: {}, BCE Loss: {}, KLD Loss: {}".format(current_epoch, torch.mean(torch.tensor(losses)), torch.mean(torch.tensor(bce_losses)), torch.mean(torch.tensor(kld_losses))))
        

Weights:  tensor([0.9193, 0.0807], dtype=torch.float64)
Epoch: 1, Loss: 16606.33203125, BCE Loss: 15405.9677734375, KLD Loss: 1200.3646240234375
Weights:  tensor([0.9351, 0.0649], dtype=torch.float64)
Epoch: 2, Loss: 15907.150390625, BCE Loss: 14603.6572265625, KLD Loss: 1303.4937744140625
Weights:  tensor([0.9129, 0.0871], dtype=torch.float64)
Epoch: 3, Loss: 15421.130859375, BCE Loss: 14231.5478515625, KLD Loss: 1189.5838623046875
Weights:  tensor([0.8923, 0.1077], dtype=torch.float64)
Epoch: 4, Loss: 15127.05859375, BCE Loss: 14032.9853515625, KLD Loss: 1094.0736083984375
Weights:  tensor([0.9140, 0.0860], dtype=torch.float64)
Epoch: 5, Loss: 15001.2802734375, BCE Loss: 13837.037109375, KLD Loss: 1164.2423095703125
Weights:  tensor([0.9590, 0.0410], dtype=torch.float64)
Epoch: 6, Loss: 15024.3623046875, BCE Loss: 13809.22265625, KLD Loss: 1215.1402587890625
Weights:  tensor([0.9121, 0.0879], dtype=torch.float64)
Epoch: 7, Loss: 14791.1787109375, BCE Loss: 13656.8037109375, KLD Loss: