In [None]:
import torch
import torchvision
from torchvision import transforms
from torch.nn import functional as F
import torch.nn as nn
from matplotlib import pyplot as plt
import numpy as np
import collections
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from moco_model import MoCo
import time
import itertools
import os

# Configure PyTorch's CUDA caching allocator through its environment variable:
# `max_split_size_mb:128` asks the allocator not to split cached blocks larger
# than 128 MB (see the PyTorch CUDA memory-management notes).
# NOTE(review): presumably added to limit fragmentation-related OOMs; the
# variable should be set before the first CUDA allocation to take effect —
# confirm nothing above this cell already touches the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"


In [None]:
class ModifiedMOCO(MoCo):
    """MoCo variant that also feeds a fixed pool of negative samples into the key queue.

    On every queue update, one chunk of ``negative_samples`` is encoded with the
    key encoder and enqueued together with the current batch of keys.

    :param base_encoder: backbone forwarded unchanged to ``MoCo``.
    :param negative_samples: indexable collection of negative inputs; it is sliced
        along its first axis and the slice is passed to ``self.encoder_k``
        (assumed to be a tensor — TODO confirm against the caller).
    :param dim: feature dimension, forwarded to ``MoCo``.
    :param K: queue length, forwarded to ``MoCo``.
    :param m: momentum coefficient, forwarded to ``MoCo``.
    :param T: softmax temperature, forwarded to ``MoCo``.
    :param mlp: whether to use an MLP head, forwarded to ``MoCo``.
    :param ns_batch_size: number of negative samples encoded per queue update
        (defaults to 384, the value that used to be hard-coded).
    :raises ValueError: if ``ns_batch_size`` is not a positive integer.
    """

    def __init__(self, base_encoder, negative_samples, dim=512, K=1024, m=0.999, T=0.07, mlp=True,
                 ns_batch_size=384):
        if ns_batch_size < 1:
            raise ValueError(f"ns_batch_size must be a positive integer, got {ns_batch_size}")
        super().__init__(base_encoder=base_encoder, dim=dim, K=K, m=m, T=T, mlp=mlp, three_channel=True, pretrained=True)
        self.negative_samples = negative_samples
        self.ns_batch_size = ns_batch_size
        # Read position inside `negative_samples`; a buffer so it is saved with the model state.
        self.register_buffer('ns_ptr', torch.zeros(1, dtype=torch.long))
        self.accompany_index = 25
        # Only whole chunks are ever read, so a trailing partial chunk is ignored.
        self.valid_ns = (negative_samples.shape[0] // ns_batch_size) * ns_batch_size

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Enqueue the current keys together with the features of the next chunk of negative samples.

        The negative samples are consumed sequentially, one chunk of ``ns_batch_size``
        per call, wrapping around after the last full chunk. Because a whole chunk is
        added on every call the queue turns over quickly; enqueuing only a few batches'
        worth of negatives may give a more stable queue.

        The keys and negative features are shuffled together and written into the queue
        starting at the current pointer, wrapping around the end of the queue when
        necessary. The pointer then advances by the number of columns actually written,
        so the next call does not overwrite what was just enqueued.

        :param keys: normalized key features of the current batch, one row per sample
            (assumed shape ``(batch, dim)`` to match the queue — TODO confirm).
        :return: None; ``self.queue``, ``self.queue_ptr`` and ``self.ns_ptr`` are updated in place.
        :raises ValueError: if ``negative_samples`` holds fewer than ``ns_batch_size`` items.
        """
        if self.valid_ns == 0:
            raise ValueError(
                f"negative_samples must contain at least {self.ns_batch_size} samples "
                f"to fill one chunk"
            )

        queue_len = self.K
        ptr = int(self.queue_ptr)
        ns_ptr = int(self.ns_ptr)

        # Encode the next chunk of negatives with the key encoder. The whole method runs
        # under the `torch.no_grad()` decorator, so no extra context manager is needed.
        # `.to(...)` is a no-op when the samples already live on the keys' device.
        chunk = self.negative_samples[ns_ptr: ns_ptr + self.ns_batch_size].to(keys.device)
        negative_features = F.normalize(self.encoder_k(chunk), dim=1)
        self.ns_ptr[0] = (ns_ptr + self.ns_batch_size) % self.valid_ns

        # Mix keys and negatives so neither group occupies a contiguous region of the queue.
        enqueue = torch.cat([keys, negative_features], dim=0)
        enqueue = enqueue[torch.randperm(enqueue.shape[0], device=enqueue.device)]

        # The queue cannot hold more than `queue_len` columns; keep a random subset if needed
        # (the rows are already shuffled, so the first `queue_len` rows are such a subset).
        if enqueue.shape[0] > queue_len:
            enqueue = enqueue[:queue_len]
        count = enqueue.shape[0]

        # Write with wrap-around: fill up to the end of the queue first, then continue
        # from column 0. A single slice assignment would fail with a shape mismatch
        # whenever `ptr + count` exceeds the queue length.
        # NOTE(review): assumes the inherited `self.queue` has shape (dim, K) — confirm in MoCo.
        head = min(count, queue_len - ptr)
        self.queue[:, ptr: ptr + head] = enqueue[:head].T
        if head < count:
            self.queue[:, : count - head] = enqueue[head:].T

        self.queue_ptr[0] = (ptr + count) % queue_len
