In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# MNIST training split, converted to tensors and served in shuffled
# mini-batches of 100 samples.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,
)

# MNIST held-out split (train=False), used for validation; order is kept
# fixed (shuffle=False).
val_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=100,
    shuffle=False,
)

class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder maps ``x_dim`` features through two ReLU hidden layers to the
    mean and log-variance of a ``z_dim``-dimensional latent; the decoder
    mirrors it and ends in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        x_dim: Number of features per flattened sample (784 for 28x28 MNIST).
        h_dim1: Width of the hidden layer next to the input/output.
        h_dim2: Width of the hidden layer next to the latent.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, x_dim: int, h_dim1: int, h_dim2: int, z_dim: int):
        super(VAE, self).__init__()

        # Remembered so forward()/loss() flatten to the configured width
        # instead of a hard-coded 784.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # latent mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # latent log-variance head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Encode flattened samples; returns ``(mu, log_var)``."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, z):
        """Decode latent samples into reconstructions squashed by a sigmoid."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x, y=None):
        """Run encode -> sample -> decode.

        Args:
            x: Input batch; reshaped to ``(-1, x_dim)`` before encoding.
            y: Unused. Kept (now optional) so callers passing labels still work.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

    def loss(self, batch, outputs):
        """Compute the loss terms for one batch.

        Args:
            batch: ``(x, y)`` pair; only ``x`` is used.
            outputs: ``(reconstruction, mu, log_var)`` as returned by forward().

        Returns:
            Dict with keys ``'loss'`` (sum of both terms), ``'recon_loss_0'``
            and ``'kl_loss'``.
        """
        x, _ = batch
        x_recon, mean, log_var = outputs

        # recon_loss / kl_loss are the module-level helpers defined next to
        # this class. Note the reconstruction term is a summed squared error,
        # not a binary cross-entropy.
        recon = recon_loss(x_recon, x.reshape(-1, self.x_dim))
        kld = kl_loss(mean, log_var)

        return {'loss': recon + kld, 'recon_loss_0': recon, 'kl_loss': kld}
    
def kl_loss(z_mean, z_log_var):
    """Return the KL term ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))``.

    The sum runs over every element of the inputs (batch and latent
    dimensions alike), so the result is a 0-dim tensor.
    """
    per_element = 1 + z_log_var - z_mean.pow(2) - z_log_var.exp()
    return -0.5 * torch.sum(per_element)
    
def recon_loss(inputs, outputs):
    """Return the squared error between the two tensors, summed over all elements."""
    summed_squared_error = F.mse_loss(input=inputs, target=outputs, reduction='sum')
    return summed_squared_error

In [5]:
from lightning_extensions import BaseModule

class VAEModule(BaseModule):
    """Training wrapper around ``VAE`` that optimises the unweighted total loss.

    NOTE(review): ``BaseModule`` comes from ``lightning_extensions`` and is not
    visible here; presumably it stores the wrapped network as ``self.model``
    and dispatches to ``step`` -- confirm.
    """

    def __init__(self):
        super().__init__(VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2))
        self.save_hyperparameters()

    def forward(self, x, y):
        """Delegate to the wrapped VAE."""
        return self.model(x, y)

    def step(self, batch, batch_idx, mode = 'train'):
        """Run one batch, log every loss term as ``{mode}_{key}``, return the total."""
        inputs, labels = batch
        outputs = self(inputs, labels)
        loss_terms = self.model.loss(batch, outputs)

        metrics = {}
        for key, val in loss_terms.items():
            metrics[f"{mode}_{key}"] = val.item()
        self.log_dict(metrics, sync_dist=True, prog_bar=True)

        return loss_terms['loss']

from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt
class VAEModuleSoftAdapt(BaseModule):
    """Training wrapper around ``VAE`` whose loss terms are re-weighted by SoftAdapt.

    Each training step records the reconstruction and KL values. Once
    ``HISTORY_LENGTH`` values have been collected they are handed to
    ``LossWeightedSoftAdapt.get_component_weights`` and the history is cleared.
    The returned loss is ``w[0] * recon + w[1] * kl``; the logged ``*_loss``
    value stays the unweighted sum so runs remain comparable with ``VAEModule``.

    NOTE(review): ``BaseModule`` comes from ``lightning_extensions`` and is not
    visible here; presumably it stores the wrapped network as ``self.model``
    and dispatches to ``step`` -- confirm.
    """

    # Number of recorded training steps that triggers a weight update
    # (equivalent to the previous ``len(history) > 100`` check).
    # NOTE(review): the odd count looks deliberate for SoftAdapt's
    # finite-difference estimate -- confirm before changing.
    HISTORY_LENGTH = 101

    def __init__(self):
        model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
        # Initialise the base module first, then attach our own state.
        super().__init__(model)
        self.softadapt_object = LossWeightedSoftAdapt(beta=0.001)
        self.reconstruction_losses = []  # floats, one per recorded train step
        self.kl_losses = []              # floats, one per recorded train step
        self.adapt_weights = torch.ones(2)  # start with equal weights
        self.save_hyperparameters()

    def forward(self, x, y):
        """Delegate to the wrapped VAE."""
        return self.model(x, y)

    def step(self, batch, batch_idx, mode = 'train'):
        """Run one batch, update SoftAdapt state when training, return the weighted loss."""
        x, y = batch
        outputs = self(x, y)
        loss = self.model.loss(batch, outputs)
        recon = loss['recon_loss_0']
        kl = loss['kl_loss']

        if mode == 'train':
            # Plain floats: nothing from the autograd graph or the
            # accelerator is kept alive by the history.
            self.reconstruction_losses.append(recon.item())
            self.kl_losses.append(kl.item())

            if len(self.reconstruction_losses) >= self.HISTORY_LENGTH:
                first = torch.tensor(self.reconstruction_losses, dtype=torch.float64)
                second = torch.tensor(self.kl_losses, dtype=torch.float64)

                self.adapt_weights = self.softadapt_object.get_component_weights(first, second, verbose=False)

                self.reconstruction_losses = []
                self.kl_losses = []

        self.log_dict({f"{mode}_{key}": val.item() for key, val in loss.items()}, prog_bar=True)

        # SoftAdapt hands back float64 weights; cast them to the loss's dtype
        # and device so the returned loss is not silently promoted to float64.
        weights = torch.as_tensor(self.adapt_weights, dtype=recon.dtype, device=recon.device)
        return weights[0] * recon + weights[1] * kl

In [3]:
from lightning_extensions import ExtendedTrainer

# Baseline experiment: VAE trained on the plain (unweighted) sum of its loss terms.
model_name = "VAE-simple"
model = VAEModule()
trainer = ExtendedTrainer(
    project_name="MTVAEs_SoftAdapt",
    max_epochs=30,
    model_name=model_name,
)
trainer.fit(model, train_loader, val_loader)

/Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable cod

/Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /Home/siv34/edzak2974/projects/MastersThesis/src/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type | Params
-------------------------------
0 | model | VAE  | 1.1 M 
-------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.275     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


                                                                           

/Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Epoch 29: 100%|██████████| 600/600 [00:08<00:00, 73.96it/s, v_num=75st, train_loss=3.53e+3, train_recon_loss_0=3e+3, train_kl_loss=536.0, val_loss=3.43e+3, val_recon_loss_0=2.89e+3, val_kl_loss=537.0]   

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 600/600 [00:08<00:00, 73.94it/s, v_num=75st, train_loss=3.53e+3, train_recon_loss_0=3e+3, train_kl_loss=536.0, val_loss=3.43e+3, val_recon_loss_0=2.89e+3, val_kl_loss=537.0]




In [3]:
from lightning_extensions import ExtendedTrainer

# SoftAdapt experiment: same VAE, loss terms re-weighted during training.
model_name = "VAE-softadapt"
model = VAEModuleSoftAdapt()
trainer = ExtendedTrainer(
    project_name="MTVAEs_SoftAdapt",
    max_epochs=30,
    model_name=model_name,
)
trainer.fit(model, train_loader, val_loader)

/Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable cod

/Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /Home/siv34/edzak2974/projects/MastersThesis/src/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type | Params
-------------------------------
0 | model | VAE  | 1.1 M 
-------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.275     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


                                                                           

/Home/siv34/edzak2974/.conda/envs/pytorch2.1/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Epoch 29: 100%|██████████| 600/600 [00:08<00:00, 71.30it/s, v_num=lk5g, train_loss=3.48e+3, train_recon_loss_0=2.75e+3, train_kl_loss=731.0, val_loss=3.5e+3, val_recon_loss_0=2.76e+3, val_kl_loss=740.0] 

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 600/600 [00:08<00:00, 71.28it/s, v_num=lk5g, train_loss=3.48e+3, train_recon_loss_0=2.75e+3, train_kl_loss=731.0, val_loss=3.5e+3, val_recon_loss_0=2.76e+3, val_kl_loss=740.0]




In [9]:
# Baseline in plain PyTorch: both loss terms carry a fixed weight of 1.
model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# range(1, 30) -> epochs 1..29
for current_epoch in range(1, 30):
    # Train: one optimiser step per mini-batch on the unweighted total loss.
    for x, y in train_loader:
        optimizer.zero_grad()
        outputs = model(x, y)
        loss = model.loss((x, y), outputs)
        loss['loss'].backward()
        optimizer.step()

    # Validate: collect each loss term per batch, then report the batch means.
    with torch.no_grad():
        history = {'loss': [], 'recon_loss_0': [], 'kl_loss': []}

        for x, y in val_loader:
            outputs = model(x, y)
            loss = model.loss((x, y), outputs)
            for key in history:
                history[key].append(loss[key])

        print("Weights: ", torch.tensor([1,1]))
        print("Epoch: {}, Loss: {}, Recon Loss: {}, KL Loss: {}".format(
            current_epoch,
            torch.mean(torch.tensor(history['loss'])),
            torch.mean(torch.tensor(history['recon_loss_0'])),
            torch.mean(torch.tensor(history['kl_loss'])),
        ))

Weights:  tensor([1, 1])
Epoch: 1, Loss: 4177.669921875, Recon Loss: 3823.641357421875, KL Loss: 354.0292053222656
Weights:  tensor([1, 1])
Epoch: 2, Loss: 3928.5625, Recon Loss: 3505.71875, KL Loss: 422.8433837890625
Weights:  tensor([1, 1])
Epoch: 3, Loss: 3830.398193359375, Recon Loss: 3402.633056640625, KL Loss: 427.7652282714844
Weights:  tensor([1, 1])
Epoch: 4, Loss: 3739.007568359375, Recon Loss: 3274.516357421875, KL Loss: 464.4911804199219
Weights:  tensor([1, 1])
Epoch: 5, Loss: 3709.9306640625, Recon Loss: 3242.3603515625, KL Loss: 467.5703125
Weights:  tensor([1, 1])
Epoch: 6, Loss: 3667.258056640625, Recon Loss: 3202.8046875, KL Loss: 464.4531555175781
Weights:  tensor([1, 1])
Epoch: 7, Loss: 3654.521240234375, Recon Loss: 3198.266357421875, KL Loss: 456.2549133300781
Weights:  tensor([1, 1])
Epoch: 8, Loss: 3625.47412109375, Recon Loss: 3150.82958984375, KL Loss: 474.64422607421875
Weights:  tensor([1, 1])
Epoch: 9, Loss: 3601.628173828125, Recon Loss: 3099.199951171875,

# Training example with SoftAdapt

In [10]:
from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt
import wandb

model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Change 1: Create a SoftAdapt object (with your desired variant)
softadapt_object = LossWeightedSoftAdapt(beta=0.001)

# Change 2: Define how often SoftAdapt calculates weights for the loss components
# NOTE(review): the original update condition
#   (epoch % epochs_to_make_updates == 0 and epoch > 1 and count >= limit) or count >= limit
# reduces to `count >= limit`, so this epoch interval is never consulted;
# the cadence is set by `limit` (in training steps) below.
epochs_to_make_updates = 5

# Per-step history of the two loss terms, stored as plain floats so no
# autograd graph is kept alive between weight updates.
values_of_component_1 = []
values_of_component_2 = []
# Initializing adaptive weights to all ones.
adapt_weights = torch.tensor([1,1])

# Number of training steps between two weight updates.
limit = 101

count = 0

for current_epoch in range(1, 30):
    for x, y in train_loader:
        optimizer.zero_grad()
        count += 1
        x_recon, mean, log_var = model(x, y)
        loss = model.loss((x, y), (x_recon, mean, log_var))

        bce_loss = loss['recon_loss_0']
        kld = loss['kl_loss']

        values_of_component_1.append(bce_loss.item())
        values_of_component_2.append(kld.item())

        if count >= limit:
            # Change 3: Update weights of components
            count = 0
            first = torch.tensor(values_of_component_1, dtype=torch.float64)
            second = torch.tensor(values_of_component_2, dtype=torch.float64)
            adapt_weights = softadapt_object.get_component_weights(first, second, verbose=False)

            # Resetting the lists to start fresh (this part is optional)
            values_of_component_1 = []
            values_of_component_2 = []

        # SoftAdapt returns float64 weights; cast to the loss dtype/device so
        # the training loss is not silently promoted to float64.
        weights = torch.as_tensor(adapt_weights, dtype=bce_loss.dtype, device=bce_loss.device)
        loss = weights[0] * bce_loss + weights[1] * kld

        loss.backward()
        optimizer.step()

    # Validate (reports the unweighted loss terms, averaged over batches)
    with torch.no_grad():
        losses = []
        bce_losses = []
        kld_losses = []

        for x, y in val_loader:
            x_recon, mean, log_var = model(x, y)
            loss = model.loss((x, y), (x_recon, mean, log_var))
            losses.append(loss['loss'].item())
            bce_losses.append(loss['recon_loss_0'].item())
            kld_losses.append(loss['kl_loss'].item())
        print("Weights: ", adapt_weights)
        print("Epoch: {}, Loss: {}, Recon Loss: {}, KL Loss: {}".format(current_epoch, torch.mean(torch.tensor(losses)), torch.mean(torch.tensor(bce_losses)), torch.mean(torch.tensor(kld_losses))))
        

Weights:  tensor([0.8314, 0.1686], dtype=torch.float64)
Epoch: 1, Loss: 4298.5791015625, Recon Loss: 3638.10595703125, KL Loss: 660.4732666015625
Weights:  tensor([0.8357, 0.1643], dtype=torch.float64)
Epoch: 2, Loss: 4009.738525390625, Recon Loss: 3350.79150390625, KL Loss: 658.9473266601562
Weights:  tensor([0.8220, 0.1780], dtype=torch.float64)
Epoch: 3, Loss: 3932.1123046875, Recon Loss: 3269.55126953125, KL Loss: 662.5613403320312
Weights:  tensor([0.8353, 0.1647], dtype=torch.float64)
Epoch: 4, Loss: 3866.88134765625, Recon Loss: 3165.371826171875, KL Loss: 701.5097045898438
Weights:  tensor([0.7787, 0.2213], dtype=torch.float64)
Epoch: 5, Loss: 3794.3525390625, Recon Loss: 3155.98779296875, KL Loss: 638.365234375
Weights:  tensor([0.8174, 0.1826], dtype=torch.float64)
Epoch: 6, Loss: 3740.163818359375, Recon Loss: 3067.777587890625, KL Loss: 672.3861694335938
Weights:  tensor([0.8183, 0.1817], dtype=torch.float64)
Epoch: 7, Loss: 3714.518798828125, Recon Loss: 3017.746826171875,