In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

mnist_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

mnist_dataloader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=100, shuffle=True)

In [6]:
class VAE(nn.Module):
    def __init__(self, x_dim: int, h_dim1: int, h_dim2: int, z_dim: int):
        super(VAE, self).__init__()
        
        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
        
    def encoder(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h) # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return F.sigmoid(self.fc6(h)) 
    
    def forward(self, x, y):
        mu, log_var = self.encoder(x.view(-1, 784))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

    def loss(self, batch, outputs):
        x, y = batch
        x_recon, mean, log_var  = outputs

        BCE = F.binary_cross_entropy(x_recon, x.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        
        loss = BCE + KLD

        return { 'loss': loss, 'BCE_loss': BCE, 'KLD_loss': KLD}
    

In [11]:
from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt

model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)

# Change 1: Create a SoftAdapt object (with your desired variant)
softadapt_object = LossWeightedSoftAdapt(beta=0.1)

# Change 2: Define how often SoftAdapt calculate weights for the loss components
epochs_to_make_updates = 5

values_of_component_1 = []
values_of_component_2 = []
# Initializing adaptive weights to all ones.
adapt_weights = torch.tensor([1,1])

for current_epoch in range(1, 10):
    for x, y in mnist_dataloader:
        x_recon, mean, log_var = model(x, y)
        loss = model.loss((x, y), (x_recon, mean, log_var))

        bce_loss = loss['BCE_loss']
        kld = loss['KLD_loss']

        values_of_component_1.append(bce_loss)
        values_of_component_2.append(kld)

        if current_epoch % epochs_to_make_updates == 0 and current_epoch > 0:
            # Change 3: Update weights of components
            adapt_weights = softadapt_object.get_component_weights(torch.Tensor(values_of_component_1), 
                                                                 torch.Tensor(values_of_component_2),
                                                                 verbose=False,
                                                                 )
                                           
        
            # Resetting the lists to start fresh (this part is optional)
            values_of_component_1 = []
            values_of_component_2 = []



OverflowError: int too large to convert to float