In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)

val_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=100, shuffle=False)

In [2]:
from lightning_extensions import BaseModule
from models import VAE
from loss import SoftAdaptModule

def kl_loss(z_mean, z_log_var):
        return -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
    
def recon_loss(inputs, outputs):
    return F.mse_loss(inputs, outputs, reduction='sum')

class VAEModule(BaseModule):
    def __init__(self):
        model = VAE()
        super().__init__(model)
        self.save_hyperparameters()

    def forward(self, x, y):
        return self.model(x, None, y)

    def step(self, batch, batch_idx, mode = 'train'):
        x, y = batch
        outputs, outputs_masked, z, z_mean, z_log_var = self(x, y)
        loss = {}
        loss['recon_loss_0'] = recon_loss(x, outputs[0])
        loss['kl_loss'] = kl_loss(z_mean, z_log_var)
        loss['loss'] = loss['recon_loss_0'] + loss['kl_loss']
        
        self.log_dict({f"{mode}_{key}": val.item() for key, val in loss.items()}, sync_dist=True, prog_bar=True)
        return loss['loss']

from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt
class VAEModuleSoftAdapt(BaseModule):
    def __init__(self):
        model = VAE()
        self.softadapt_object = LossWeightedSoftAdapt(beta=0.001)
        super().__init__(model)
        self.save_hyperparameters()
        self.soft_adapt = SoftAdaptModule()

    def forward(self, x, y):
        return self.model(x, None, y)

    def step(self, batch, batch_idx, mode = 'train'):
        x, y = batch
        outputs, outputs_masked, z, z_mean, z_log_var = self(x, y)
        loss = {}
        loss['recon_loss_0'] = recon_loss(x, outputs[0])
        loss['kl_loss'] = kl_loss(z_mean, z_log_var)
        loss['loss'] = loss['recon_loss_0'] + loss['kl_loss']

        self.log_dict({f"{mode}_{key}": val.item() for key, val in loss.items()}, sync_dist=True, prog_bar=True)
        return self.soft_adapt([loss['recon_loss_0'], loss['kl_loss']], mode == 'train')

In [None]:
from lightning_extensions import ExtendedTrainer

model = VAEModule()
model_name = "VAE-convolutional"
trainer = ExtendedTrainer(project_name="MTVAEs_SoftAdapt", max_epochs=30, model_name=model_name)
trainer.fit(model, train_loader, val_loader)

In [3]:
from lightning_extensions import ExtendedTrainer

model = VAEModuleSoftAdapt()
model_name = "VAE-convolutional-softadapt"
trainer = ExtendedTrainer(project_name="MTVAEs_SoftAdapt", max_epochs=30, model_name=model_name)
trainer.fit(model, train_loader, val_loader)



AttributeError: cannot assign module before Module.__init__() call