In [2]:
from argparse import ArgumentParser
from copy import deepcopy
from typing import Any, Union

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch.nn import functional as F
from torch.optim import Adam

from transformers import BertModel, BertConfig

from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

In [3]:
class Data2Vec(LightningModule):
    """PyTorch Lightning implementation of Data2Vec by Meta AI (text, BERT backbone).

    Codebase extended from the BYOL implementation of `Annika Brundyn <https://github.com/annikabrundyn>`_
    found at https://github.com/PyTorchLightning/lightning-bolts/tree/master/pl_bolts/models/self_supervised/byol

    A student ``BertModel`` is trained with an MSE loss to regress the layer-averaged hidden states
    produced by a teacher ``BertModel``. The teacher receives no gradients; it is updated after every
    training batch by ``BYOLMAWeightUpdate`` as an exponential moving average of the student.

    Model implemented by:
        - `Haris Jabbar <https://github.com/maveriq>`_

    .. warning:: Work in progress. This implementation is still being verified.

    TODOs:
        - Implement data augmentation / masking pipeline
        - Verify implementation
        - Implement selectable top K layers instead of all (current)
        - Restrict the loss to masked positions and normalise teacher targets, as in the paper

    Example::

        model = Data2Vec(num_classes=2)

        dm = ...  # TODO: datamodule yielding (raw_input, masked_input) batches

        trainer = Trainer()
        trainer.fit(model, datamodule=dm)

    .. _Data2Vec: https://arxiv.org/abs/2202.03555
    .. _BYOL: https://arxiv.org/pdf/2006.07733.pdf
    """

    def __init__(
        self,
        num_classes,
        learning_rate: float = 0.2,
        weight_decay: float = 1.5e-6,
        input_height: int = 32,
        batch_size: int = 32,
        num_workers: int = 0,
        warmup_epochs: int = 10,
        max_epochs: int = 1000,
        # base_encoder: Union[str, torch.nn.Module] = "resnet50",
        **kwargs
    ):
        """
        Args:
            num_classes: number of classes; only stored in ``hparams``, not used by the model itself
            learning_rate: Adam learning rate
                (NOTE(review): 0.2 is very high for Adam; the CLI default below is 1e-3)
            weight_decay: optimizer weight decay
            input_height: leftover from the BYOL image template; only stored in ``hparams``
            batch_size: the batch size; only stored in ``hparams``
            num_workers: number of workers; only stored in ``hparams``
            warmup_epochs: num of epochs for scheduler warm up
            max_epochs: max epochs for scheduler
            kwargs: extra keyword arguments; not used directly by this class
        """
        super().__init__()
        self.save_hyperparameters(ignore="base_encoder")

        # Only the architecture config is loaded; BertModel(config) starts from random weights.
        config = BertConfig.from_pretrained("bert-base-uncased")
        # Required so shared_step can average the hidden states of every layer.
        config.output_hidden_states = True

        self.teacher_network = BertModel(config)
        # Student starts as an exact copy of the teacher.
        self.student_network = deepcopy(self.teacher_network)
        # The teacher is only ever changed by the EMA callback (which writes ``.data``), never by
        # backprop. Freeze it *after* the deepcopy so the student stays trainable.
        for param in self.teacher_network.parameters():
            param.requires_grad_(False)

        # NOTE(review): ``BYOLMAWeightUpdate`` is looked up at call time. The pl_bolts class imported at
        # the top presumably expects BYOL's attribute names; this module relies on the notebook-local
        # redefinition that reads ``student_network``/``teacher_network``, so instantiate the model only
        # after that cell has run.
        self.weight_callback = BYOLMAWeightUpdate()
        self.loss_fn = torch.nn.MSELoss()

    def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Apply the EMA teacher update after every training batch.

        ``dataloader_idx`` defaults to 0 because some Lightning versions call this hook without it.
        """
        # Add callback for user automatically since it's key to the EMA weight update
        self.weight_callback.on_train_batch_end(self.trainer, self, outputs, batch, batch_idx, dataloader_idx)

    def forward(self, x):
        """Encode ``x`` with the student network and return the ``BertModel`` output.

        Args:
            x: passed positionally to ``BertModel`` (presumably ``input_ids`` of shape
                (batch, seq_len) -- TODO confirm against the datamodule)
        """
        return self.student_network(x)

    def shared_step(self, batch, batch_idx):
        """Compute the Data2Vec regression loss for one batch.

        The student encodes the masked input; the teacher encodes the unmasked input without
        gradients. For each network, the hidden states of all layers are averaged, and the MSE
        between the student's and teacher's averages is returned.

        Args:
            batch: pair ``(raw_input, masked_input)``; presumably token-id tensors of identical shape
                (TODO confirm once the datamodule exists)
            batch_idx: index of the batch (unused)

        Returns:
            scalar MSE loss tensor
        """
        raw_input, masked_input = batch

        # Student sees the masked sequence; teacher sees the full sequence and provides targets.
        output_student = self.student_network(masked_input)
        with torch.no_grad():
            output_teacher = self.teacher_network(raw_input)

        # stack (not cat) adds a new leading "layer" axis, so mean(0) averages over layers only and
        # keeps each sample's own target; cat along dim 0 would also have averaged over the batch.
        embed_student = torch.stack(output_student.hidden_states, 0).mean(0)
        embed_teacher = torch.stack(output_teacher.hidden_states, 0).mean(0)

        # Final loss
        return self.loss_fn(embed_student, embed_teacher)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss."""
        loss = self.shared_step(batch, batch_idx)

        # ``self.log`` takes (name, value); a dict would have to go through ``log_dict``.
        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss."""
        loss = self.shared_step(batch, batch_idx)

        self.log("valid_loss", loss)

        return loss

    def configure_optimizers(self):
        """Adam on the student's parameters with linear warm-up followed by cosine annealing.

        The teacher is excluded because it is updated only by the EMA callback.
        """
        optimizer = Adam(
            self.student_network.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs
        )
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return an ``ArgumentParser`` extending ``parent_parser`` with this model's CLI options.

        NOTE(review): ``--dataset`` and ``--meta_dir`` are carried over from the BYOL image template.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--online_ft", action="store_true", help="run online finetuner")
        parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"])

        # Kept from the BYOL template: pre-parses the options declared so far; the result is unused.
        (args, _) = parser.parse_known_args()

        # Data
        parser.add_argument("--data_dir", type=str, default=".")
        parser.add_argument("--num_workers", default=8, type=int)

        # optim
        parser.add_argument("--batch_size", type=int, default=256)
        parser.add_argument("--learning_rate", type=float, default=1e-3)
        parser.add_argument("--weight_decay", type=float, default=1.5e-6)
        parser.add_argument("--warmup_epochs", type=float, default=10)

        # Model
        parser.add_argument("--meta_dir", default=".", type=str, help="path to meta.bin for imagenet")

        return parser

In [4]:
import math
from typing import Sequence, Union

from pytorch_lightning import Callback, LightningModule, Trainer
from torch import Tensor
from torch.nn import Module

In [5]:
class BYOLMAWeightUpdate(Callback):
    """Exponential-moving-average weight update rule from BYOL, adapted to student/teacher naming.

    Your model should have:
        - ``self.student_network``: the network trained by the optimizer
        - ``self.teacher_network``: the network updated only by this callback

    After every training batch the teacher's parameters are overwritten with
    ``tau * teacher + (1 - tau) * student``. BYOL claims this keeps the online (here: student)
    network from collapsing.

    .. note:: ``tau`` is increased from ``initial_tau`` towards 1.0 along a cosine schedule over the
        planned number of training steps (``len(train_dataloader) * max_epochs``).

    Example::

        # model must have 2 attributes
        model = Model()
        model.student_network = ...
        model.teacher_network = ...
        trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
    """

    def __init__(self, initial_tau: float = 0.996):
        """
        Args:
            initial_tau: starting tau. Auto-updates with every training step
        """
        super().__init__()
        self.initial_tau = initial_tau
        # tau used for the next EMA update; starts at initial_tau.
        self.current_tau = initial_tau

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Blend the student into the teacher, then advance ``current_tau``.

        ``dataloader_idx`` defaults to 0 because some Lightning versions call this hook without it.
        """
        # get networks
        student_net = pl_module.student_network
        teacher_net = pl_module.teacher_network

        # update weights (uses the tau computed after the previous batch)
        self.update_weights(student_net, teacher_net)

        # update tau after
        self.current_tau = self.update_tau(pl_module, trainer)

    def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
        """Return tau for ``pl_module.global_step`` on a cosine schedule.

        ``tau = 1 - (1 - initial_tau) * (cos(pi * step / max_steps) + 1) / 2``, which equals
        ``initial_tau`` at step 0 and 1.0 at ``max_steps``. When the planned number of steps cannot
        be determined (``max_epochs`` is None/0/negative or the dataloader is empty), the current tau
        is returned unchanged instead of dividing by zero or producing a meaningless schedule.
        """
        max_epochs = trainer.max_epochs
        if not max_epochs or max_epochs < 0:
            return self.current_tau
        max_steps = len(trainer.train_dataloader) * max_epochs
        if max_steps <= 0:
            return self.current_tau
        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
        return tau

    def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
        """Overwrite ``target_net``'s parameters with their EMA towards ``online_net``'s.

        Parameters are paired positionally, so both networks must share the same architecture.

        Args:
            online_net: network providing the new weights (the student)
            target_net: network whose parameter ``.data`` is replaced (the teacher)
        """
        tau = self.current_tau
        # apply MA weight update; ``.data`` keeps the update outside autograd
        for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
            target_p.data = tau * target_p.data + (1 - tau) * online_p.data