In [None]:
import torch
import random
import numpy as np

from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import warnings
from collections.abc import Mapping

from transformers import (
    BertTokenizer, BertTokenizerFast
)

In [None]:
# Scratch check: a dict's values unpack positionally, in insertion order.
a = dict(a=1, b=2)
c, d = a.values()
c

In [None]:
def tolist(x):
    """Return `x` as a (possibly nested) Python list.

    A `list` is returned as-is (same object, no copy). Any object exposing a
    `.numpy()` method is first converted through it -- duck-typed so tensor
    libraries do not have to be imported here -- and the result's `.tolist()`
    is returned. Anything else must provide `.tolist()` itself.
    """
    if isinstance(x, list):
        return x
    array_like = x.numpy() if hasattr(x, "numpy") else x
    return array_like.tolist()

class ProcessorForWholeWordMask(torch.nn.Module):
    """Build whole-word-masked MLM batches from tokenized examples.

    `forward` collates `input_ids` into one tensor, picks whole words
    (a token plus its following "##" continuation pieces) to mask, and returns
    the corrupted inputs together with the MLM labels.

    Args:
        tokenizer: Tokenizer used for padding information, id -> token lookup
            and special-token detection. A warning is emitted when it is not a
            `BertTokenizer` / `BertTokenizerFast`.
        mlm_probability: Fraction of the tokens of each sequence to select.
        pad_to_multiple_of: Optional. When set, collated batches are padded so
            their length is a multiple of this value. `None` (the default)
            keeps the previous behaviour: pad only up to the longest example.
    """

    def __init__(self, tokenizer, mlm_probability, pad_to_multiple_of=None):
        super().__init__()
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability
        self.pad_to_multiple_of = pad_to_multiple_of

    def _torch_collate_batch(self, examples):
        """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.

        Args:
            examples: Non-empty sequence of 1-D tensors, or of lists / tuples /
                numpy arrays of ints (converted to `torch.long` tensors).

        Returns:
            A 2-D tensor with one row per example, padded with
            `tokenizer.pad_token_id` on `tokenizer.padding_side`.

        Raises:
            ValueError: If padding is required but the tokenizer has no pad token.
        """
        pad_to_multiple_of = self.pad_to_multiple_of

        # Tensorize if necessary.
        if isinstance(examples[0], (list, tuple, np.ndarray)):
            examples = [torch.tensor(e, dtype=torch.long) for e in examples]

        length_of_first = examples[0].size(0)

        # Fast path: nothing to pad, so a plain stack is enough.
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
            return torch.stack(examples, dim=0)

        # Padding is needed, so the tokenizer must define a pad token.
        if self.tokenizer._pad_token is None:
            raise ValueError(
                "You are attempting to pad samples but the tokenizer you are using"
                f" ({self.tokenizer.__class__.__name__}) does not have a pad token."
            )

        # Create the full tensor (pre-filled with the pad id) and copy the data in.
        max_length = max(x.size(0) for x in examples)
        if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            # Round up to the next multiple.
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
        result = examples[0].new_full([len(examples), max_length], self.tokenizer.pad_token_id)
        pad_right = self.tokenizer.padding_side == "right"
        for i, example in enumerate(examples):
            if pad_right:
                result[i, : example.shape[0]] = example
            else:
                result[i, -example.shape[0] :] = example
        return result

    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy.

        Tokens starting with "##" are grouped with the preceding token so a
        word is either selected entirely or not at all. "[CLS]" and "[SEP]"
        are never candidates. Candidate order is randomised with
        `random.shuffle`, so results depend on the `random` module state.

        NOTE(review): padding tokens are not skipped here. They count towards
        the budget (`len(input_tokens) * mlm_probability`) and can be selected;
        `torch_mask_tokens` clears them afterwards. On heavily padded input
        this can leave fewer real tokens selected than the budget -- confirm
        this is intended.

        Args:
            input_tokens: Token strings of one sequence.
            max_predictions: Upper bound on the number of selected tokens.

        Returns:
            A list of 0/1 ints, one per input token (1 = selected).

        Raises:
            ValueError: If the internal bookkeeping of selected indexes is inconsistent.
        """
        if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
            warnings.warn(
                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
                "Please refer to the documentation for more information."
            )

        # Group token positions into words: a "##" piece joins the previous group.
        cand_indexes = []
        for i, token in enumerate(input_tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue

            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        if len(covered_indexes) != len(masked_lms):
            raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
        mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
        return mask_labels

    def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        `mask_labels` carries the whole-word-mask selection; positions are
        masked according to it rather than sampled independently.

        Args:
            inputs: Tensor of token ids. It is modified in place (selected
                positions are overwritten) and also returned.
            mask_labels: Tensor of the same shape as `inputs`; non-zero marks a
                selected position. It is left untouched (a copy is edited).

        Returns:
            `(inputs, labels)` where `labels` holds the original id at selected
            positions and -100 everywhere else.

        Raises:
            ValueError: If the tokenizer has no mask token.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        device = labels.device

        # Work on a copy so the caller's `mask_labels` tensor is not overwritten.
        probability_matrix = mask_labels.clone()

        # Never mask special tokens or padding, even if they were selected.
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(
            torch.tensor(special_tokens_mask, dtype=torch.bool, device=device), value=0
        )
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time (half of the remaining 20%), we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool() & masked_indices & ~indices_replaced
        )
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, device=device)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def forward(self, examples):
        """Turn tokenized examples into a whole-word-masked MLM batch.

        Args:
            examples: One of
                * a mapping with an "input_ids" entry holding one id sequence
                  per example,
                * a sequence of mappings, each with "input_ids" (and optionally
                  "chinese_ref"),
                * a sequence of raw id sequences.

        Returns:
            `{"input_ids": masked_inputs, "labels": labels}`.
        """
        if isinstance(examples, Mapping):
            input_ids = examples["input_ids"]
            examples = [{"input_ids": e} for e in input_ids]
        elif isinstance(examples[0], Mapping):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = self._torch_collate_batch(input_ids)

        mask_labels = []
        for e in examples:
            ref_tokens = [self.tokenizer._convert_id_to_token(token_id) for token_id in tolist(e["input_ids"])]

            # For Chinese tokens, extra info is needed to mark sub-words, e.g. [喜,欢] -> [喜, ##欢]
            if "chinese_ref" in e:
                ref_positions = set(tolist(e["chinese_ref"]))  # set: O(1) membership test
                for i in range(len(e["input_ids"])):
                    if i in ref_positions:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))

        batch_mask = self._torch_collate_batch(mask_labels)
        inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)

        return {"input_ids": inputs, "labels": labels}
        

        


In [None]:
def get_pretrained_tokenizer(from_pretrained):
    """Load a `BertTokenizer` for `from_pretrained`.

    Lower-casing is enabled exactly when the name contains "uncased".
    When `torch.distributed` is initialized, rank 0 performs the load first
    while every rank waits at a barrier; afterwards each rank loads the
    tokenizer itself and returns it.
    """
    lower_case = "uncased" in from_pretrained

    def _load():
        return BertTokenizer.from_pretrained(from_pretrained, do_lower_case=lower_case)

    dist = torch.distributed
    if dist.is_initialized():
        if dist.get_rank() == 0:
            _load()  # result discarded; rank 0 goes first before the others proceed
        dist.barrier()
    return _load()
# Shared tokenizer and whole-word-mask processor used by the cells below
# (mlm_probability=0.5 selects half of each sequence).
tokenizer = get_pretrained_tokenizer(from_pretrained="bert-base-uncased")
mlm_processor = ProcessorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.5)

In [None]:
# Encode each caption on its own with identical settings (fixed length 40,
# special-tokens mask included), then gather them as a list of examples.
caption1 = 'I am so cool!'
caption2 = 'That giraffe'
encoding1, encoding2 = [
    tokenizer(
        caption,
        padding="max_length",
        truncation=True,
        max_length=40,
        return_special_tokens_mask=True,
    )
    for caption in (caption1, caption2)
]
examples = [encoding1, encoding2]

In [None]:
# Batched variant: tokenize both captions in a single call.
# Note that this rebinds `caption1` (now a list) and `examples`.
caption1 = ['I am so cool!', 'That giraffe']
examples = tokenizer(
    caption1, padding="max_length", truncation=True, max_length=40, return_special_tokens_mask=True
)

In [None]:
# Display the current value of `examples`.
examples

In [None]:
import torch

In [None]:
# Relative path of the checkpoint file read by the following `torch.load` cell.
ckpt_path = 'checkpoints/mae_pretrain_vit_base.pth'

In [None]:
# Load the checkpoint with all tensors mapped to CPU.
# NOTE(review): torch.load may unpickle arbitrary objects -- only use it on checkpoint files from a trusted source.
ckpt = torch.load(ckpt_path, map_location='cpu')
# The weights are stored under the "model" key of the checkpoint.
state_dict = ckpt['model']

In [1]:
import torch
from models_cook import ContinualModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import argparse
def get_args_parser():
    """Create the argument parser holding the pre-training options.

    The parser is built with `add_help=False` and exposes, each with a default:
    `--image_size`, `--model`, `--drop_path_rate` (architecture) and
    `--max_text_len`, `--max_text_len_of_initckpt`, `--vocab_size`,
    `--mlm_probability` (language modeling).

    Returns:
        The configured `argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser('ContinualTransformer pre-training', add_help=False)

    # (option name, default value, type converter), kept in registration order.
    architecture_options = (
        ('image_size', 224, int),
        ('model', 'vlmo_base_patch16', str),
        ('drop_path_rate', 0.1, float),
    )
    language_modeling_options = (
        ('max_text_len', 196, int),
        ('max_text_len_of_initckpt', 196, int),
        ('vocab_size', 30522, int),
        ('mlm_probability', 0.15, float),
    )

    for option_name, default_value, converter in architecture_options + language_modeling_options:
        parser.add_argument(f'--{option_name}', default=default_value, type=converter)
    return parser
# parse_known_args() returns (namespace, unrecognised_args); only the namespace is kept,
# so command-line entries the parser does not know are ignored instead of raising.
# Presumably this is what lets the cell run under a notebook kernel's own argv -- confirm.
config = get_args_parser().parse_known_args()[0]


In [3]:
# Dict form of the same settings, kept for reference only; the argparse namespace `config` is what gets passed below.
# config={'image_size': 224, 'model': 'vlmo_base_patch16', 'drop_path_rate': 0.1, 'max_text_len': 196, 'vocab_size':30522, 'max_text_len_of_initckpt': 196, 'mlm_probability': 0.15}
m = ContinualModel(config=config)

window_size: (14, 14)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
from util.misc import convert_init_ckpt

# Run the project's checkpoint converter on the BEiT checkpoint for model `m`.
# NOTE(review): presumably returns a state dict matching `m` (it is saved under "model" in the next cell) -- confirm in util.misc.
a=convert_init_ckpt('checkpoints/beit_base_patch16_224_pt22k_ft22kto1k.pth', module=m, config=config)

In [5]:
# Write the converted result to a new file, wrapped under a "model" key.
torch.save({"model": a}, 'checkpoints/beit_base_patch16_224_pt22k_ft22kto1k_transfertovlmo.pth')

In [6]:
]]

AttributeError: 'ContinualModel' object has no attribute 'mlm_loss'