In [None]:
# default_exp model.mocoae

In [None]:
# hide
# Development convenience: with autoreload mode 2, edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# MoCoAE

> Momentum Contrast for autoencoder-based representation learning.
> Based on, and modified from, the GitHub repository of Momentum Contrast (MoCo): https://github.com/facebookresearch/moco

In [None]:
# export
import torch
from torch import nn, optim
import torch.nn.functional as F
from deeptool.architecture import Encoder, Decoder, DownUpConv
from deeptool.utils import Tracker

In [None]:
# Load a small test batch to confirm the input shape the architecture receives.
from deeptool.parameters import get_all_args
from deeptool.dataloader import load_test_batch

args = get_all_args()
# NOTE(review): "rnnvae" is set inside the MoCoAE notebook — confirm this is the
# intended model type and not a leftover from another notebook.
args.model_type = "rnnvae"
args.batch_size = 5
# MoCoAE.__init__ only builds a Tracker when args.track is truthy.
args.track = False
batch = load_test_batch(args)
# Show the shape of the tensor stored under "img" (the key MoCoAE.forward reads).
batch["img"].shape

torch.Size([5, 3, 16, 256, 256])

In [None]:
# export


class MoCoAE(nn.Module):
    """
    Momentum Contrast Autoencoder.

    Holds a query/key pair of encoders and a query/key pair of decoders.
    The query networks are trained with a contrastive (InfoNCE) loss against a
    queue of previously seen keys; the key networks follow the query networks
    through a momentum update.
    """

    def __init__(self, device, args):
        """
        Build the networks, queues, queue pointers and optimizers.

        Args:
            device: device every sub-network and buffer is moved to.
            args: configuration namespace; this constructor reads `dim`, `n_z`,
                `moco_K`, `moco_tau`, `moco_m`, `lr` and `track`.
        """
        super(MoCoAE, self).__init__()
        self.device = device  # GPU
        self.dim = args.dim  # 2/3 Dimensional input
        self.n_z = args.n_z  # Compression

        ### MoCo specific args
        self.K = args.moco_K  # limit of the queue
        self.tau = args.moco_tau  # temperature
        self.m = args.moco_m  # momentum

        # Encoder
        self.enc_q = Encoder(args, vae_mode=False).to(self.device)  # query encoder
        self.enc_k = Encoder(args, vae_mode=False).to(self.device)  # key encoder

        # Decoder
        self.dec_q = Decoder(args).to(self.device)  # query decoder
        self.dec_k = Decoder(args).to(self.device)  # key decoder

        # Start the key networks as exact copies of the query networks
        # (copy_q2k_params also switches off gradients on the key side).
        copy_q2k_params(self.enc_q, self.enc_k)
        copy_q2k_params(self.dec_q, self.dec_k)

        # Initialise the randomised Queues for Momentum Contrastive Learning
        self.register_queue("enc_queue")
        self.register_queue("dec_queue")

        # Write positions inside the two queues, stored as buffers as well
        self.register_buffer("ptr_enc", torch.zeros(1, dtype=torch.long).to(self.device))
        self.register_buffer("ptr_dec", torch.zeros(1, dtype=torch.long).to(self.device))

        # Only the query networks are optimised by gradient descent
        self.optimizerEnc = optim.Adam(self.enc_q.parameters(), lr=args.lr)
        self.optimizerDec = optim.Adam(self.dec_q.parameters(), lr=args.lr)

        # Setup the tracker to visualize the progress
        if args.track:
            self.tracker = Tracker(args)

    @torch.no_grad()
    def register_queue(self, name: str):
        """
        Register a random queue of shape (n_z, K) as a buffer called `name`.

        Every column is L2-normalised (normalisation along dim 0). Being a
        buffer, the queue is not a trainable parameter.
        """
        queue = torch.randn(self.n_z, self.K).to(self.device)
        self.register_buffer(name, nn.functional.normalize(queue, dim=0))

    @torch.no_grad()
    def watch_progress(self, test_data, iteration):
        """
        Outsourced to Tracker
        """
        self.tracker.track_progress(self, test_data, iteration)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, mode="enc"):
        """
        Write `keys` into the queue selected by `mode` and advance its pointer.

        The queue is treated as a ring buffer: a batch that runs past the last
        column continues at column 0 instead of failing with a shape mismatch.

        Args:
            keys: tensor with one key per row, i.e. shape (batch, n_z).
            mode: 'enc' or 'dec'; selects `<mode>_queue` and `ptr_<mode>`.
        """
        queue = getattr(self, f"{mode}_queue")
        ptr_buf = getattr(self, f"ptr_{mode}")
        ptr = int(ptr_buf)

        # A batch larger than the queue can only leave its newest K keys behind.
        keys = keys[-self.K:]
        batch_size = keys.shape[0]

        # First part: from ptr up to the end of the queue.
        head = min(batch_size, self.K - ptr)
        queue[:, ptr : ptr + head] = keys[:head].T

        # Second part (only on wrap-around): continue from column 0.
        tail = batch_size - head
        if tail > 0:
            queue[:, :tail] = keys[head:].T

        ptr_buf[0] = (ptr + batch_size) % self.K

    def forward(self, data, update=True):
        """
        Run one contrastive pass over a batch and optionally train on it.

        Both queues are updated with the new keys regardless of `update`.

        Args:
            data: mapping holding the image batch under the key "img".
            update: if True, back-propagate both losses, step both optimizers
                and momentum-update the key networks.

        Returns:
            `x_kk` (output of the key decoder) when `update` is True, otherwise
            the tuple `(x_kk, tr_data)` where `tr_data` contains the scalar
            losses under "loss_enc" and "loss_dec".
        """
        # Reset Gradients
        self.optimizerEnc.zero_grad()
        self.optimizerDec.zero_grad()

        # 1. Fetch the image batch
        x = data["img"]

        # 2. Two augmented views of the batch, moved to the model device
        x_q = aug(x).to(self.device)
        x_k = aug(x).to(self.device)

        # 3. Encode: the query keeps its graph, the key is computed without gradients
        q = nn.functional.normalize(self.enc_q(x_q), dim=1)

        with torch.no_grad():
            k = nn.functional.normalize(self.enc_k(x_k), dim=1)

        # Get the InfoNCE loss:
        loss_enc = MomentumContrastiveLoss(k, q, self.enc_queue, self.tau, device=self.device)

        # Perform encoder update
        if update:
            loss_enc.backward()

            # update the Query Encoder
            self.optimizerEnc.step()

            # update the Key Encoder with Momentum update
            momentum_update(self.enc_q, self.enc_k, self.m)

        # append keys to the queue
        self._dequeue_and_enqueue(k, mode="enc")

        # 4. Decode; q is detached so the decoder loss cannot reach the query encoder
        x_qq = self.dec_q(q.detach())

        with torch.no_grad():
            x_kk = self.dec_k(k)

        # 5. Encode again using the k-network to focus on decoder only!
        qq = nn.functional.normalize(self.enc_k(x_qq), dim=1)

        with torch.no_grad():
            kk = nn.functional.normalize(self.enc_k(x_kk), dim=1)

        # Get the InfoNCE loss:
        loss_dec = MomentumContrastiveLoss(kk, qq, self.dec_queue, self.tau, device=self.device)

        # perform decoder update
        if update:
            loss_dec.backward()

            # update the Query Decoder
            self.optimizerDec.step()

            # update the Key Decoder with Momentum update
            momentum_update(self.dec_q, self.dec_k, self.m)

        # append keys to the queue
        self._dequeue_and_enqueue(kk, mode="dec")

        if update:
            return x_kk

        tr_data = {
            "loss_enc": loss_enc.item(),
            "loss_dec": loss_dec.item(),
        }
        return x_kk, tr_data

## Functions handling the queue

In [None]:
# export
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather `tensor` from every process and concatenate the results along dim 0.

    When torch.distributed is unavailable or no process group has been
    initialised (plain single-process training), a copy of the input is
    returned instead of raising.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        # Nothing to gather from: behave like a world of size one.
        return tensor.clone()

    # One receive slot per participating process.
    tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)

    return torch.cat(tensors_gather, dim=0)

## Functions Handling the K and Q Network updates:

In [None]:
# export
@torch.no_grad()
def copy_q2k_params(Q_network: nn.Module, K_network: nn.Module):
    """
    Initialise network K with the parameter values of network Q.

    Parameters are paired in iteration order; each K parameter receives a copy
    of the matching Q value and is then excluded from gradient computation.
    """
    paired = zip(Q_network.parameters(), K_network.parameters())
    for source, target in paired:
        # overwrite the values in place, then freeze the key-side parameter
        target.data.copy_(source.data)
        target.requires_grad_(False)

In [None]:
# export
@torch.no_grad()
def momentum_update(Q_network: nn.Module, K_network: nn.Module, m: float):
    """
    Move every parameter of the key network towards the query network.

    Each key parameter becomes `m * key + (1 - m) * query`; the query network
    is left untouched.
    """
    keep, take = m, 1.0 - m
    for source, target in zip(Q_network.parameters(), K_network.parameters()):
        blended = target.data * keep + source.data * take
        target.data = blended

## Data Augmentation:

In [None]:
# export
def aug(x):
    """
    Random data augmentation for an image batch.

    Placeholder: no augmentation is implemented yet, so the batch is handed
    back unchanged.
    """
    # TODO: add the actual random augmentations
    augmented = x
    return augmented

## Momentum Contrastive Loss:

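$\mathcal{L}_q = -\log \left( \frac{ \exp\left( q \cdot k_+ / \tau \right) }{ \sum_{i=0}^{K} \exp\left( q \cdot k_i / \tau \right) } \right)$

Here $q$ is the query, $k_+$ its positive key, $\tau$ the temperature, and the sum runs over the positive key plus the $K$ negative keys held in the queue.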

In [None]:
# export
ce_loss = nn.CrossEntropyLoss()


def MomentumContrastiveLoss(k, q, queue, tau, device):
    """
    Contrastive (InfoNCE) loss between queries, their keys and a queue.

    We follow the suggestion of the paper, Algorithm 1:
    https://arxiv.org/pdf/1911.05722.pdf

    Args:
        k: keys, shape (N, C); row n is the positive key for query n.
        q: queries, shape (N, C).
        queue: negative keys, shape (C, K); a detached copy is used so the
            queue itself is neither modified nor part of the graph.
        tau: temperature the logits are divided by.
        device: kept for backward compatibility. The labels are created on the
            device of the logits, so the loss also works on CPU.

    Returns:
        Scalar cross-entropy loss in which class 0 (the positive key) is the
        target for every query.
    """
    # positive logits: Nx1
    l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)

    # negative logits: NxK
    l_neg = torch.einsum("nc,ck->nk", q, queue.clone().detach())

    # logits: Nx(1+K) with temperature
    logits = torch.cat([l_pos, l_neg], dim=1) / tau

    # the positive key sits in column 0 of every row
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

    return ce_loss(logits, labels)

In [None]:
# Scratch tensors: a (3, 5) float tensor and a (3,) tensor of integer class
# indices drawn from [0, 5). Named so that the builtin `input` is not shadowed
# in the kernel namespace.
example_logits = torch.randn(3, 5, requires_grad=True)
example_target = torch.empty(3, dtype=torch.long).random_(5)
print(example_logits.shape, example_target.shape)

torch.Size([3, 5]) torch.Size([3])


In [None]:
# hide
# Export every `# export` cell of the project's notebooks to the library modules.
from nbdev.export import notebook2script

notebook2script()

Converted 00_dataloader.ipynb.
Converted 01_architecture.ipynb.
Converted 02_utils.ipynb.
Converted 03_parameters.ipynb.
Converted 04_train_loop.ipynb.
Converted 10_diagnosis.ipynb.
Converted 20_dcgan.ipynb.
Converted 21_introvae.ipynb.
Converted 22_vqvae.ipynb.
Converted 23_bigan.ipynb.
Converted 24_mocoae.ipynb.
Converted 33_rnn_vae.ipynb.
Converted 99_index.ipynb.
