In [1]:
!git clone https://github.com/William-Chittavong/torchscale.git

fatal: destination path 'torchscale' already exists and is not an empty directory.


In [None]:
# LEFT OFF: the `hook` module is imported from the wrong path; fix those import paths.

In [2]:
import inspect

from torchscale.model.BEiT3 import BEiT3

# inspect.signature shows only the constructor's parameters (with defaults);
# __code__.co_varnames would also list its local variable names.
print(inspect.signature(BEiT3.__init__))


ModuleNotFoundError: No module named 'hook'

In [3]:
# NOTE(review): this %cd is relative, so re-running the cell may move one level
# deeper (the printed cwd ends in torchscale/torchscale). Consider an absolute path.
%cd torchscale

/home/william/Documents/GitHub/torchscale/torchscale


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
import math
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_ as __call_trunc_normal_

from torchscale.model.BEiT3 import BEiT3
from torchscale.architecture.config import EncoderConfig

# modeling_utils: weight-init helper, BEiT-3 encoder configs and the base model wrapper

In [4]:
def trunc_normal_(tensor, mean=0., std=1.):
    """Fill ``tensor`` in place from a normal(mean, std) truncated to [-std, std].

    Thin wrapper over timm's ``trunc_normal_``. The cut-off interval is the
    absolute range [-std, std], not [mean - std, mean + std]. Returns None.
    """
    lower, upper = -std, std
    __call_trunc_normal_(tensor, mean=mean, std=std, a=lower, b=upper)


def _encoder_config(
        img_size, patch_size, drop_path_rate, checkpoint_activations,
        mlp_ratio, vocab_size, embed_dim, attention_heads, num_layers,
):
    """Build the multiway BEiT-3 ``EncoderConfig`` shared by every model size."""
    return EncoderConfig(
        img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
        layernorm_embedding=False, normalize_output=True, no_output_layer=True,
        drop_path_rate=drop_path_rate, encoder_embed_dim=embed_dim,
        encoder_attention_heads=attention_heads,
        encoder_ffn_embed_dim=int(embed_dim * mlp_ratio), encoder_layers=num_layers,
        checkpoint_activations=checkpoint_activations,
    )


def get_base_config(
        img_size=224, patch_size=16, drop_path_rate=0,
        checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
):
    """Config for BEiT-3 base: 768-dim embeddings, 12 heads, 12 layers.

    Extra ``**kwargs`` are accepted (so callers can pass model options through)
    but are ignored here.
    """
    return _encoder_config(
        img_size, patch_size, drop_path_rate, checkpoint_activations,
        mlp_ratio, vocab_size, embed_dim=768, attention_heads=12, num_layers=12,
    )


def get_large_config(
        img_size=224, patch_size=16, drop_path_rate=0,
        checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
):
    """Config for BEiT-3 large: 1024-dim embeddings, 16 heads, 24 layers.

    Extra ``**kwargs`` are accepted but ignored, as in ``get_base_config``.
    """
    return _encoder_config(
        img_size, patch_size, drop_path_rate, checkpoint_activations,
        mlp_ratio, vocab_size, embed_dim=1024, attention_heads=16, num_layers=24,
    )


class BEiT3Wrapper(nn.Module):
    """Base module owning a ``BEiT3`` backbone (``self.beit3``) and its weight init.

    Task-specific models subclass this and add their own heads.
    """

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        self.beit3 = BEiT3(args)
        self.apply(self._init_weights)

    def fix_init_weight(self):
        """Divide residual-output weights of encoder layer i (1-based) by sqrt(2 * i).

        Walks ``self.beit3.encoder.layers`` and rescales every ``weight`` under
        ``self_attn.out_proj`` and every ``weight`` of a submodule named ``fc2``
        (this covers the multiway ``A``/``B`` experts when present).
        NOTE(review): module names follow the printed model tree; confirm for MoE layers.
        """
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.beit3.encoder.layers, start=1):
            for name, param in layer.named_parameters():
                parts = name.split(".")
                if parts[-1] != "weight":
                    continue
                is_attn_out = parts[:2] == ["self_attn", "out_proj"]
                is_ffn_out = "fc2" in parts
                if is_attn_out or is_ffn_out:
                    rescale(param.data, layer_id)

    def get_num_layers(self):
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'}

    def _init_weights(self, m):
        """Truncated-normal Linear weights (std 0.02), zero biases, unit LayerNorm scale."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has weight/bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

In [5]:
import numpy as np
import torch.nn.functional as F

import torch.distributed as dist

def is_dist_avail_and_initialized():
    """True only if torch.distributed is compiled in and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()

def get_world_size():
    """Number of distributed workers, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """This worker's rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


class GatherLayer(torch.autograd.Function):
    """
    All-gather a tensor from every worker while staying differentiable.

    A bare torch.distributed.all_gather cuts the gradients. Here, the incoming
    gradients for all gathered copies are summed across workers, and this
    worker's slice is returned as the gradient of its input.
    """

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        buffers = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(buffers, x)
        return tuple(buffers)

    @staticmethod
    def backward(ctx, *grads):
        my_rank = dist.get_rank()
        stacked = torch.stack(grads)
        dist.all_reduce(stacked)
        return stacked[my_rank]


def gather_features(
        image_features,
        text_features,
):
    """Collect image and text features from all workers, concatenated along dim 0.

    Goes through GatherLayer so the result stays attached to the autograd graph.
    Returns ``(all_image_features, all_text_features)``.
    """
    gathered = []
    for local in (image_features, text_features):
        gathered.append(torch.cat(GatherLayer.apply(local)))
    return gathered[0], gathered[1]


# The implementation code is modified from open_clip (https://github.com/mlfoundations/open_clip.git)
class ClipLoss(nn.Module):
    """Symmetric image-text contrastive (cross-entropy) loss in the open_clip style.

    Args:
        cache_labels: reuse the per-device target tensor while the batch size
            is unchanged.
        rank: this worker's rank; offsets the targets when features are gathered.
        world_size: number of workers; features are gathered when it is > 1.

    ``forward`` returns ``(loss, logits_per_image, logits_per_text)``.
    """

    def __init__(
            self,
            cache_labels=False,
            rank=0,
            world_size=1,
    ):
        super().__init__()
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size

        # Target cache: device -> labels, valid for batches of prev_num_logits rows.
        self.prev_num_logits = 0
        self.labels = {}

    def forward(self, image_features, text_features, logit_scale):
        device = image_features.device
        gathering = self.world_size > 1
        if gathering:
            image_pool, text_pool = gather_features(image_features, text_features)
        else:
            image_pool, text_pool = image_features, text_features

        # Local rows are scored against the (possibly gathered) other modality.
        logits_per_image = logit_scale * image_features @ text_pool.T
        logits_per_text = logit_scale * text_features @ image_pool.T

        num_logits = logits_per_image.shape[0]
        reuse_cached = self.prev_num_logits == num_logits and device in self.labels
        if reuse_cached:
            labels = self.labels[device]
        else:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if gathering:
                # Positives sit at this rank's offset inside the gathered batch.
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits

        image_loss = F.cross_entropy(logits_per_image, labels)
        text_loss = F.cross_entropy(logits_per_text, labels)
        total_loss = (image_loss + text_loss) / 2
        return total_loss, logits_per_image, logits_per_text




In [7]:

class BEiT3ForRetrieval(BEiT3Wrapper):
    """BEiT-3 dual encoder for image-text retrieval.

    Images and texts are encoded in separate passes of the shared backbone.
    The CLS token of each pass goes through a modality-specific linear head
    and is L2-normalized. The loss is ``ClipLoss`` with a learnable temperature.
    """

    def __init__(
            self,
            args,
            hook=None,
            **kwargs
    ):
        super(BEiT3ForRetrieval, self).__init__(args=args)
        # Hook manager passed to the backbone on the vision pass; PRSLogger registers
        # callbacks on it. None is passed through as-is (TODO confirm the backbone accepts it).
        self.hook_manager = hook
        embed_dim = args.encoder_embed_dim
        self.language_head = nn.Linear(embed_dim, embed_dim, bias=False)
        self.vision_head = nn.Linear(embed_dim, embed_dim, bias=False)
        self.language_head.apply(self._init_weights)
        self.vision_head.apply(self._init_weights)
        self.criterion = ClipLoss(
            rank=get_rank(),
            world_size=get_world_size(),
        )
        # Temperature stored in log space, initialised to 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, image=None, text_description=None, padding_mask=None, only_infer=False, **kwargs):
        """Embed an image and/or a text and optionally compute the contrastive loss.

        Args:
            image: image batch for the vision pass, or None.
            text_description: token ids for the text pass, or None.
            padding_mask: text padding positions forwarded to the backbone.
            only_infer: if True, skip the loss and return only the embeddings.

        Returns:
            ``(vision_cls, language_cls)`` when ``only_infer`` (either may be None),
            otherwise ``(loss, vision_cls, language_cls)``.

        Raises:
            ValueError: if the loss is requested but one modality is missing.
        """
        if not only_infer and (image is None or text_description is None):
            raise ValueError(
                "both image and text_description are required to compute the loss; "
                "pass only_infer=True to embed a single modality"
            )

        vision_cls = None
        if image is not None:
            outputs = self.beit3(
                textual_tokens=None,
                visual_tokens=image,
                text_padding_position=None,
                hook=self.hook_manager,
            )
            cls_token = outputs["encoder_out"][:, 0, :]
            vision_cls = F.normalize(self.vision_head(cls_token), dim=-1)

        language_cls = None
        if text_description is not None:
            # The text pass is run without the hook manager.
            outputs = self.beit3(
                textual_tokens=text_description,
                visual_tokens=None,
                text_padding_position=padding_mask,
            )
            cls_token = outputs["encoder_out"][:, 0, :]
            language_cls = F.normalize(self.language_head(cls_token), dim=-1)

        if only_infer:
            return vision_cls, language_cls
        loss, _, _ = self.criterion(vision_cls, language_cls, self.logit_scale.exp())
        return loss, vision_cls, language_cls




def create_beit3_retrieval_model(model_size='base', hook_manager=None, img_size=224, **kwargs):
    """
    Create a BEiT3 model for retrieval tasks.

    Args:
    model_size (str): 'base' or 'large'
    hook_manager: hook manager handed to BEiT3ForRetrieval (stored as
        ``model.hook_manager``); may be None
    img_size (int): Image size (assuming square images)
    **kwargs: forwarded both to the config builder (which reads patch_size,
        drop_path_rate, checkpoint_activations, mlp_ratio, vocab_size and
        ignores the rest) and to BEiT3ForRetrieval (which ignores them)

    Returns:
    BEiT3ForRetrieval: The created model

    Raises:
    ValueError: if model_size is not 'base' or 'large'
    """
    config_builders = {'base': get_base_config, 'large': get_large_config}
    if model_size not in config_builders:
        raise ValueError("model_size must be either 'base' or 'large'")

    args = config_builders[model_size](img_size=img_size, **kwargs)
    return BEiT3ForRetrieval(args, hook=hook_manager, **kwargs)


In [None]:
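# TODO: work out why the **kwargs handling is wrong (create_beit3_retrieval_model forwards **kwargs to both the config builder and BEiT3ForRetrieval).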

In [8]:
# NOTE(review): HookManager is not imported in any cell shown here, so this cell
# raises NameError on a fresh kernel; import it from the forked torchscale hook module.
hook = HookManager()
model_size = "base"
img_size = 224

retrieve_model = create_beit3_retrieval_model(
    model_size=model_size, hook_manager=hook, img_size=img_size
)

In [9]:
# Display the module tree to check the attribute paths (beit3 -> encoder -> layers ...)
# that the hook patterns and PRSLogger rely on.
retrieve_model

BEiT3ForRetrieval(
  (beit3): BEiT3(
    (text_embed): TextEmbedding(64010, 768)
    (vision_embed): VisionEmbedding(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (encoder): Encoder(
      (dropout_module): Dropout(p=0.0, inplace=False)
      (embed_positions): MutliwayEmbedding(
        (A): PositionalEmbedding(199, 768)
        (B): PositionalEmbedding(1024, 768)
      )
      (layers): ModuleList(
        (0-11): 12 x EncoderLayer(
          (self_attn): MultiheadAttention(
            (k_proj): MultiwayNetwork(
              (A): Linear(in_features=768, out_features=768, bias=True)
              (B): Linear(in_features=768, out_features=768, bias=True)
            )
            (v_proj): MultiwayNetwork(
              (A): Linear(in_features=768, out_features=768, bias=True)
              (B): Linear(in_features=768, out_features=768, bias=True)
            )
            (q_proj): MultiwayNetwork(
              (A): Linear(in_features=768, out_featu

In [22]:
class PRSLogger(object):
    """Projected-residual-stream (PRS) logger for the BEiT-3 vision pass.

    Callbacks registered on ``model.hook_manager`` collect, for the CLS token,
    each layer's attention contribution and each MLP-style contribution, plus
    the mean/std of the encoder's output layer norm. ``finalize`` turns them
    into additive components of the image embedding. Each piece gets its share
    of the final layer norm, is projected by the model's ``vision_head``, and is
    divided by the norm of the resulting embedding.

    Shape comments follow the original annotations ([b, n, h, d] etc.) and
    assume the hook outputs of the forked torchscale (TODO confirm).
    """

    def __init__(self, model, embed_dim, device):
        self.current_layer = 0
        self.device = device
        self.attentions = []
        self.mlps = []
        self.post_ln_std = None
        self.post_ln_mean = None
        self.model = model
        # Kept for interface compatibility; the projection itself comes from the head.
        self.embed_dim = embed_dim
        # Share the model's trained head. A fresh nn.Linear here would project every
        # component through random weights.
        self.vision_head = model.vision_head

    @staticmethod
    def _vision_branch(module):
        """Return the ``A`` expert of a multiway module, or the module itself.

        NOTE(review): assumes ``A`` is the vision expert, as the image-sized
        ``embed_positions.A`` suggests. Confirm for the forked backbone.
        """
        return module.A if hasattr(module, "A") else module

    def _final_layer_norm(self):
        """Vision branch of the encoder's output layer norm."""
        encoder = self.model.beit3.encoder
        norm = getattr(encoder, "layer_norm", None)
        if norm is None:
            norm = getattr(encoder, "layernorm", None)
        if norm is None:
            raise AttributeError(
                "encoder has no output layer norm ('layer_norm' / 'layernorm')"
            )
        return self._vision_branch(norm)

    def _project(self, x):
        """Apply the vision head's weight to the last dim of ``x`` on ``self.device``.

        The retrieval head is built with bias=False, so this purely linear map keeps
        the components additive.
        """
        weight = self.vision_head.weight.detach().to(self.device)
        return F.linear(x, weight)

    @torch.no_grad()
    def compute_attentions(self, ret):
        """Log the CLS row of one layer's attention output, spreading its bias over n*h."""
        out_proj = self.model.beit3.encoder.layers[self.current_layer].self_attn.out_proj
        bias_term = self._vision_branch(out_proj).bias

        self.current_layer += 1
        return_value = ret[:, 0].detach().cpu()  # [b, n, h, d]
        self.attentions.append(
            return_value
            + bias_term[np.newaxis, np.newaxis, np.newaxis].cpu()
            / (return_value.shape[1] * return_value.shape[2])
        )
        return ret

    @torch.no_grad()
    def compute_mlps(self, ret):
        """Log the CLS row of an MLP-style contribution."""
        self.mlps.append(ret[:, 0].detach().cpu())  # [b, d]
        return ret

    @torch.no_grad()
    def log_post_ln_mean(self, ret):
        self.post_ln_mean = ret.detach().cpu()  # [b, 1]
        return ret

    @torch.no_grad()
    def log_post_ln_std(self, ret):
        self.post_ln_std = ret.detach().cpu()  # [b, 1]
        return ret

    def register_hooks(self):
        """Register the logging callbacks on ``self.model.hook_manager``.

        Hook names/patterns must match those emitted by the forked torchscale
        (not visible here). TODO confirm.
        """
        hooks = self.model.hook_manager
        # Attention output of every encoder layer.
        hooks.register("encoder.layer.*.self_attn.out_proj_post*", self.compute_attentions)
        # Dense FFN outputs.
        hooks.register("encoder.layer.not_moe.ffn.fc2_post", self.compute_mlps)
        # Mixture-of-experts FFN outputs.
        hooks.register("encoder.layer.moe.expert.*.ffn.fc2_post", self.compute_mlps)
        # TODO: confirm why the output of layer 0's self-attention layer norm is
        # logged as an extra MLP-style component.
        hooks.register("encoder.layer.0.self_attn_layer_norm.*.ln_post", self.compute_mlps)
        # Statistics of the encoder's output layer norm, consumed by _normalize_*.
        hooks.register("encoder.layer_norm_post.mean", self.log_post_ln_mean)
        hooks.register("encoder.layer_norm_post.sqrt_var", self.log_post_ln_std)

    def _normalize_mlps(self):
        """Apply each MLP component's share of the final layer norm ([b, m, d])."""
        len_intermediates = self.attentions.shape[1] + self.mlps.shape[1]
        ln = self._final_layer_norm()
        # post_ln_mean / post_ln_std are [b, 1]; one new axis broadcasts over [b, m, d].
        mean = self.post_ln_mean[:, :, np.newaxis].to(self.device)
        std = self.post_ln_std[:, :, np.newaxis].to(self.device)
        mean_centered = self.mlps - mean / len_intermediates
        weighted_mean_centered = ln.weight.detach().to(self.device) * mean_centered
        weighted_mean_by_std = weighted_mean_centered / std
        bias_term = ln.bias.detach().to(self.device) / len_intermediates
        return weighted_mean_by_std + bias_term

    def _normalize_attentions(self):
        """Apply each attention component's share of the final layer norm ([b, l, n, h, d])."""
        len_intermediates = self.attentions.shape[1] + self.mlps.shape[1]
        normalization_term = self.attentions.shape[2] * self.attentions.shape[3]  # n * h
        ln = self._final_layer_norm()
        mean = self.post_ln_mean[:, :, np.newaxis, np.newaxis, np.newaxis].to(self.device)
        std = self.post_ln_std[:, :, np.newaxis, np.newaxis, np.newaxis].to(self.device)
        mean_centered = self.attentions - mean / (len_intermediates * normalization_term)
        weighted_mean_centered = ln.weight.detach().to(self.device) * mean_centered
        weighted_mean_by_std = weighted_mean_centered / std
        bias_term = ln.bias.detach().to(self.device) / (
            len_intermediates * normalization_term
        )
        return weighted_mean_by_std + bias_term

    @torch.no_grad()
    def finalize(self, representation=None):
        """We calculate the post-ln scaling, project it and normalize by the last norm.

        Args:
            representation: optional [b, d] un-normalized image embedding whose
                norm is used for scaling. Defaults to the sum of all projected
                components, so the returned pieces then add up to a unit vector.

        Returns:
            ``(attentions [b, l, n, h, d], mlps [b, m, d])`` on ``self.device``.
            Call ``reinit`` before logging the next batch.
        """
        self.attentions = torch.stack(self.attentions, dim=1).to(self.device)  # [b, l, n, h, d]
        self.mlps = torch.stack(self.mlps, dim=1).to(self.device)  # [b, m, d]
        projected_attentions = self._project(self._normalize_attentions())
        projected_mlps = self._project(self._normalize_mlps())

        if representation is None:
            representation = (
                projected_attentions.sum(dim=(1, 2, 3)) + projected_mlps.sum(dim=1)
            )
        # Divide by the embedding norm (not per component) so the parts stay additive.
        norm = representation.detach().to(self.device).norm(dim=-1)
        return (
            projected_attentions / norm[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis],
            projected_mlps / norm[:, np.newaxis, np.newaxis],
        )

    def reinit(self):
        """Reset all per-batch state so the next forward pass starts at layer 0."""
        self.current_layer = 0
        self.attentions = []
        self.mlps = []
        self.post_ln_mean = None
        self.post_ln_std = None
        torch.cuda.empty_cache()

def hook_prs_logger(model, embed_dim, device):
    """Build a PRSLogger for ``model``, register its callbacks and return it."""
    logger = PRSLogger(model, embed_dim, device)
    logger.register_hooks()
    return logger

In [23]:
from torchvision import transforms
from transformers import XLMRobertaTokenizer
from PIL import Image

from unilm.beit3.utils import load_model_and_may_interpolate

# Defined here because img_tensor is moved to it below and this cell runs
# before the later `device` cell on a fresh kernel.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = XLMRobertaTokenizer("beit3.spm")

image_path = "/home/william/project/images/catdog.png"
# The patch embedding is Conv2d(3, ...), so force 3 channels (PNGs are often RGBA/paletted).
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

# BEiT-3 evaluation preprocessing: bicubic resize, then per-channel normalization
# with mean = std = 0.5 (timm's IMAGENET_INCEPTION_MEAN / IMAGENET_INCEPTION_STD).
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

img_tensor = transform(image).unsqueeze(0).requires_grad_(True)
img_tensor = img_tensor.to(device)

# Generic BEiT-3 base checkpoint (not loaded here).
checkpoint_path = "https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224.pth"

# Image-text contrastive checkpoint, loaded into the retrieval model.
image_text_contrastive_checkpoint = "https://github.com/addf400/files/releases/download/beit3/beit3_base_itc_patch16_224.pth"

load_model_and_may_interpolate(image_text_contrastive_checkpoint, retrieve_model, model_key='model', model_prefix='')




Load ckpt from https://github.com/addf400/files/releases/download/beit3/beit3_base_itc_patch16_224.pth
Load state_dict by model_key = model


In [24]:

# Use torch.device on both branches so downstream code always sees one type.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [25]:
retrieve_model.to(device)
retrieve_model.eval()

# Read the width from the model config instead of hard-coding 768 (wrong for 'large').
encoder_embed_dim = retrieve_model.args.encoder_embed_dim

prs = hook_prs_logger(retrieve_model, encoder_embed_dim, device)


In [27]:
with torch.no_grad():
    # Clear state left by any earlier (possibly failed) run; otherwise the layer
    # counter and collected tensors carry over and misalign layer indices.
    prs.reinit()
    # Image-only pass: without text the contrastive loss cannot be computed.
    retrieve_model(image=img_tensor, only_infer=True)
    attentions, mlps = prs.finalize()


AttributeError: 'BEiT3ForRetrieval' object has no attribute 'hook'