In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Data configuration (tune here rather than inside the calls below).
DATA_ROOT = './data'
BATCH_SIZE = 100
# torchvision's MNIST: train=False selects the 10k-image test split, train=True the 60k training split.
# NOTE(review): this example trains on the *test* split (presumably to keep epochs short);
# set USE_TRAIN_SPLIT = True for a conventional training run.
USE_TRAIN_SPLIT = False

mnist_dataset = torchvision.datasets.MNIST(
    root=DATA_ROOT,
    train=USE_TRAIN_SPLIT,
    # ToTensor scales pixels to [0, 1], which binary cross-entropy requires of its targets.
    transform=transforms.ToTensor(),
    download=True,
)

mnist_dataloader = torch.utils.data.DataLoader(dataset=mnist_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [2]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder with a diagonal-Gaussian latent.

    Encoder: x_dim -> h_dim1 -> h_dim2 -> (mu, log_var), each of size z_dim.
    Decoder: z_dim -> h_dim2 -> h_dim1 -> x_dim, squashed into (0, 1) with a
    sigmoid so reconstructions can be scored with binary cross-entropy.

    Args:
        x_dim: Flattened input size (784 for 28x28 MNIST images).
        h_dim1: Width of the first encoder / last decoder hidden layer.
        h_dim2: Width of the second encoder / first decoder hidden layer.
        z_dim: Latent dimensionality.
    """

    def __init__(self, x_dim: int, h_dim1: int, h_dim2: int, z_dim: int):
        super().__init__()
        # Remember the flattened input size so forward/loss do not hard-code 784.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mu head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log_var head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map flattened inputs of shape (batch, x_dim) to (mu, log_var)."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)

    def sampling(self, mu, log_var):
        """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).

        `log_var` is the log-variance, hence sigma = exp(0.5 * log_var).
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, z):
        """Map latents of shape (batch, z_dim) to reconstructions in (0, 1)."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x, y=None):
        """Encode, sample and decode a batch.

        Args:
            x: Inputs whose elements flatten to (batch, x_dim), e.g. MNIST
               images of shape (batch, 1, 28, 28).
            y: Labels; unused (accepted so callers can pass a full batch).

        Returns:
            (x_recon, mu, log_var), with x_recon of shape (batch, x_dim).
        """
        # reshape (not view) also accepts non-contiguous inputs.
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

    def loss(self, batch, outputs):
        """Negative ELBO summed over the batch.

        Args:
            batch: (x, y) as passed to forward; x must lie in [0, 1], y is unused.
            outputs: (x_recon, mu, log_var) as returned by forward.

        Returns:
            dict with 'loss' (BCE + KLD), 'BCE_loss' (summed reconstruction
            binary cross-entropy) and 'KLD_loss' (closed-form
            KL(N(mu, sigma^2) || N(0, I)), summed).
        """
        x, _ = batch
        x_recon, mean, log_var = outputs

        BCE = F.binary_cross_entropy(x_recon, x.reshape(-1, self.x_dim), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

        return {'loss': BCE + KLD, 'BCE_loss': BCE, 'KLD_loss': KLD}
    

# Training example with SoftAdapt

In [4]:
from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt

# Seed once so weight init, DataLoader shuffling and the VAE's sampling noise are reproducible.
SEED = 0
torch.manual_seed(SEED)

NUM_EPOCHS = 30

model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Change 1: Create a SoftAdapt object (with your desired variant)
softadapt_object = LossWeightedSoftAdapt(beta=0.001)

# Change 2: Define how often SoftAdapt recalculates the loss-component weights.
# NOTE: in this loop recalculation is driven by the batch counter `limit` below,
# not by epochs. The previous condition
# `(epoch % epochs_to_make_updates == 0 and epoch > 1 and count >= limit) or count >= limit`
# reduces to `count >= limit`, so this value never had an effect; it is kept only
# so later references to the name still work.
epochs_to_make_updates = 5

# Per-batch history of each loss component (plain floats) since the last weight update.
values_of_component_1 = []
values_of_component_2 = []
# Initializing adaptive weights to all ones (float, so the weighted loss stays floating point).
adapt_weights = torch.tensor([1.0, 1.0])

# Recompute the weights once `limit` batches of history have been collected.
# SoftAdapt appears to pick its finite-difference order from the history length
# (the log reports "order as 100" for 101 points).
# NOTE(review): an order-100 finite difference is likely numerically noisy; a much
# shorter window may be more stable — confirm before relying on these weights.
limit = 101

count = 0

model.train()
for current_epoch in range(1, NUM_EPOCHS + 1):
    for x, y in mnist_dataloader:
        optimizer.zero_grad()
        count += 1
        x_recon, mean, log_var = model(x, y)
        loss_terms = model.loss((x, y), (x_recon, mean, log_var))

        bce_loss = loss_terms['BCE_loss']
        kld = loss_terms['KLD_loss']

        # Store detached Python floats: keeping the tensors would hold a reference
        # to every batch's autograd graph until the history is reset.
        values_of_component_1.append(bce_loss.item())
        values_of_component_2.append(kld.item())

        if count >= limit:
            # Change 3: Update weights of components
            count = 0
            first = torch.tensor(values_of_component_1, dtype=torch.float64)
            second = torch.tensor(values_of_component_2, dtype=torch.float64)
            adapt_weights = softadapt_object.get_component_weights(first, second, verbose=True)

            # Resetting the lists to start fresh (this part is optional)
            values_of_component_1 = []
            values_of_component_2 = []

        # The objective actually optimized: SoftAdapt-weighted sum of the components.
        loss = adapt_weights[0] * bce_loss + adapt_weights[1] * kld

        loss.backward()
        optimizer.step()

    # Reports the epoch's last batch: `loss` is the weighted objective, BCE/KLD are unweighted.
    print(f'Epoch [{current_epoch}/{NUM_EPOCHS}], Loss: {loss.item():.4f}, BCE Loss: {bce_loss.item():.4f}, KLD Loss: {kld.item():.4f}')



Epoch [1/30], Loss: 18579.5293, BCE Loss: 18315.3633, KLD Loss: 264.1652
==> Interpreting finite difference order as 100 sinceno explicit order was specified.
==> Interpreting finite difference order as 100 sinceno explicit order was specified.
Epoch [2/30], Loss: 15295.5254, BCE Loss: 16298.4072, KLD Loss: 988.6555
==> Interpreting finite difference order as 100 sinceno explicit order was specified.
==> Interpreting finite difference order as 100 sinceno explicit order was specified.
Epoch [3/30], Loss: 15481.9363, BCE Loss: 16080.1582, KLD Loss: 1181.3688
==> Interpreting finite difference order as 100 sinceno explicit order was specified.
==> Interpreting finite difference order as 100 sinceno explicit order was specified.
Epoch [4/30], Loss: 14472.2100, BCE Loss: 15376.2148, KLD Loss: 1160.6482
==> Interpreting finite difference order as 100 sinceno explicit order was specified.
==> Interpreting finite difference order as 100 sinceno explicit order was specified.
Epoch [5/30], Loss