In [1]:
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from collections import namedtuple
import pdb
from pytorch_lightning.loggers import CSVLogger

In [2]:
class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

In [18]:
class BoringModelSWA(pl.LightningModule):

    def __init__(self, args,model):
        super().__init__()

        self.args = args
#         self.model = BoringModel()
        self.swa_model = AveragedModel(model)
        self.swa_start_step = self.args.swa_start_step

    def forward(self, x):
        return self.swa_model(x)

    def configure_optimizers(self):

        optim = torch.optim.AdamW(self.swa_model.module.parameters(), lr=self.args.lr,weight_decay=self.args.weight_decay)
        print(f"Total number of training steps : {self.num_training_steps}")
        sched = torch.optim.lr_scheduler.OneCycleLR(optim,max_lr=self.args.lr,total_steps=self.num_training_steps,anneal_strategy='linear')
        self.swa_scheduler = SWALR(optim, swa_lr=self.args.lr)

        lr_scheduler_config = {
            "scheduler": sched,
            "interval": "step",
            "frequency": 1,
#             "name": 'lr_scheduler',
        }
        return {'optimizer':optim,
                'scheduler':lr_scheduler_config
                }

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        
        if self.trainer.global_step > self.swa_start_step:
#             pdb.set_trace()
            self.swa_model.update_parameters(self.swa_model.module)
            self.swa_scheduler.step()
            
        return {"loss": loss}

    @property
    def num_training_steps(self) -> int:
        """Total training steps inferred from datamodule and devices."""
#         pdb.set_trace()
        if self.trainer.max_steps:
            return self.trainer.max_steps

        limit_batches = self.trainer.limit_train_batches
        batches = len(self.train_dataloader())
        batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches)     

        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)

        effective_accum = self.trainer.accumulate_grad_batches * num_devices
        return (batches // effective_accum) * self.trainer.max_epochs

In [None]:
class SWAResnet(LitResnet):
    def __init__(self, trained_model, lr=0.01):
        super().__init__()

        self.save_hyperparameters("lr")
        self.model = trained_model
        self.swa_model = AveragedModel(self.model)
        self.swa_start_step = 100

    def forward(self, x):
        out = self.swa_model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)

        if self.trainer.global_step > self.swa_start_step:
#             pdb.set_trace()
            self.swa_model.update_parameters(self.model)
            self.swa_scheduler.step()
        
        return loss
    
#     def training_epoch_end(self, training_step_outputs):
#         self.swa_model.update_parameters(self.model)

    def validation_step(self, batch, batch_idx, stage=None):
        x, y = batch
        logits = F.log_softmax(self.model(x), dim=1)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        self.swa_scheduler = SWALR(optimizer, swa_lr=self.args.lr)
        return optimizer

    def on_train_end(self):
        update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)

In [19]:
class RandomDataset(torch.utils.data.Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

In [20]:
class UpdateBNCallback(Callback):
        
    def on_train_end(self, trainer, pl_module):
        print("Updating BatchNorm weights")
        torch.optim.swa_utils.update_bn(pl_module.train_dataloader(), pl_module.swa_model)

In [21]:
lr_monitor = LearningRateMonitor(logging_interval='step',log_momentum=True)

In [22]:
logger = CSVLogger("logs", name="logging_lr")

In [23]:
Args = namedtuple('Args',['lr','weight_decay','swa_start_step'])

In [24]:
args = Args(lr=0.01,weight_decay=1e-6,swa_start_step=500)

In [25]:
model = BoringModel()

In [26]:
modelswa = BoringModelSWA(args,model)

In [27]:
train_data = torch.utils.data.DataLoader(RandomDataset(32, 64000), batch_size=4)

In [28]:
trainer = pl.Trainer(
        accelerator='dp',
        gpus=2,
        limit_train_batches=700,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        callbacks=[UpdateBNCallback(), lr_monitor],
        logger=logger
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [29]:
trainer.fit(modelswa, train_dataloaders=train_data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Total number of training steps : 350


  rank_zero_warn(
  rank_zero_warn(


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Updating BatchNorm weights



In [15]:
lr_monitor.lrs

{}

In [16]:
lr_monitor.last_momentum_values.values()

dict_values([])

In [17]:
trainer.logged_metrics

{'train_loss': -7.996602535247803, 'epoch': 0}