# Alpha-CLIP in BLIP-Diffusion

## Prepare Environment
You need to set up the [LAVIS](https://github.com/salesforce/LAVIS) environment first, since it provides the [BLIP-Diffusion](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion) model. Then run this notebook inside the LAVIS environment.

```python
import torch
import collections
from PIL import Image
from lavis.models import load_model_and_preprocess
from torchvision import transforms
import types
import cv2
import numpy as np
import os
from tqdm import tqdm
from copy import deepcopy

alpha = None  # global alpha map, read by rewrited_forward as the Alpha-CLIP alpha input
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(device)  # only valid when a GPU is present
model, vis_preprocess, txt_preprocess = load_model_and_preprocess("blip_diffusion", "base", device=device, is_eval=True)
```

## Plugin Alpha-CLIP
Alpha-CLIP can replace the original CLIP visual encoder used in BLIP-Diffusion. For simplicity, we rewrite the `forward` function of that visual encoder: `rewrited_forward` uses an extra alpha convolution layer (`conv1_alpha`) to add the alpha map to the CLIP model's patch embedding.
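In symbols, the patch-embedding stage of the rewritten forward can be written as follows, where $W_{\mathrm{rgb}}$ denotes the weights of `conv1`, $W_{\alpha}$ those of `conv1_alpha`, and $\ast$ a convolution whose kernel size and stride both equal the patch size:

$$
\mathbf{z} = W_{\mathrm{rgb}} \ast \mathbf{x} + W_{\alpha} \ast \boldsymbol{\alpha},
\qquad \mathbf{x} \in \mathbb{R}^{B \times 3 \times H \times W},\;
\boldsymbol{\alpha} \in \mathbb{R}^{B \times 1 \times H \times W}
$$

The fallback constant 1.9231 that the code uses when no alpha is given appears to correspond to an all-ones mask (the whole image in focus), assuming the alpha map is normalized with mean 0.5 and standard deviation 0.26:

$$
\frac{1 - 0.5}{0.26} \approx 1.9231
$$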

```python
def rewrited_forward(self, x: torch.Tensor):
    # Read the notebook-level `alpha` map (shape [B or 1, 1, H, W]) without overwriting it.
    a = alpha
    if a is None:
        print(f"[Warning] in {type(self)} forward: no alpha input when using Alpha-CLIP, alpha is expected!")
        # Constant map of (1 - 0.5) / 0.26 ≈ 1.9231, i.e. the whole image in focus
        a = torch.ones_like(x[:, [0], :, :]) * 1.9231
    a = a.to(device=x.device, dtype=self.conv1_alpha.weight.dtype)
    x = self.conv1(x)  # shape = [*, width, grid, grid]
    x = x + self.conv1_alpha(a)  # add the alpha-map embedding to the RGB patch embedding
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
    x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
    x = x + self.positional_embedding.to(x.dtype)
    x = self.ln_pre(x)

    x = x.permute(1, 0, 2)  # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD

    return x
```
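To see what the `conv1_alpha` branch does in isolation, here is a minimal sketch with toy sizes (the width, patch size and image size are illustrative assumptions, not the real ViT-L/14 configuration). Because `conv1_alpha` uses the same kernel size and stride as `conv1`, both produce the same patch grid and can simply be summed; zeroing the alpha weights reduces the fused embedding to the plain RGB embedding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes, assumed for illustration only (ViT-L/14 itself uses width 1024 and 14x14 patches).
width, patch, size = 8, 4, 16

# RGB patch embedding and a 1-channel alpha patch embedding with identical kernel/stride.
conv1 = nn.Conv2d(3, width, kernel_size=patch, stride=patch, bias=False)
conv1_alpha = nn.Conv2d(1, width, kernel_size=patch, stride=patch, bias=False)

x = torch.randn(2, 3, size, size)      # batch of 2 RGB images
alpha = torch.randn(2, 1, size, size)  # one alpha map per image

with torch.no_grad():
    base = conv1(x)                    # plain CLIP patch embedding
    fused = base + conv1_alpha(alpha)  # same [B, width, grid, grid] shape, so they add element-wise
print(fused.shape)  # torch.Size([2, 8, 4, 4])

# With all-zero alpha weights the alpha branch contributes nothing.
with torch.no_grad():
    conv1_alpha.weight.zero_()
    print(torch.equal(base + conv1_alpha(alpha), base))  # True
```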

Next, bind `rewrited_forward` to the CLIP visual encoder used in BLIP-Diffusion so that it replaces the original `forward` function, and load the Alpha-CLIP weights into that encoder.
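`types.MethodType` binds a plain function to one specific object, so only that instance gets the new behavior. A minimal sketch on a hypothetical toy module (`Encoder` and `patched_forward` are illustrative names, not part of LAVIS) shows that calling the patched instance dispatches to the new `forward`, because the instance attribute shadows the class method, while other instances are unaffected:

```python
import types
import torch
import torch.nn as nn

class Encoder(nn.Module):  # hypothetical stand-in for the visual encoder
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    # `self` is bound to the instance passed to types.MethodType
    return x * 10

a, b = Encoder(), Encoder()
a.forward = types.MethodType(patched_forward, a)  # stored on `a` only

x = torch.ones(2)
print(a(x))  # tensor([10., 10.])
print(b(x))  # tensor([2., 2.])
```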

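```python
# Alpha-CLIP visual-encoder checkpoint; map to CPU so loading also works without a GPU
state_dict = torch.load('clip_l14_grit+mim_fultune_6xe.pth', map_location='cpu')
converted_dict = collections.OrderedDict()
for k, v in state_dict.items():
    # Rename the packed QKV projection keys to nn.MultiheadAttention's parameter names
    if 'in_proj.weight' in k:
        converted_dict[k.replace('in_proj.weight', 'in_proj_weight')] = v
    elif 'in_proj.bias' in k:
        converted_dict[k.replace('in_proj.bias', 'in_proj_bias')] = v
    else:
        converted_dict[k] = v

visual = model.blip.visual_encoder
# New 1-channel patch embedding for the alpha map, matching conv1's geometry, device and dtype
visual.conv1_alpha = torch.nn.Conv2d(in_channels=1,
                                     out_channels=visual.conv1.out_channels,
                                     kernel_size=visual.conv1.kernel_size,
                                     stride=visual.conv1.stride,
                                     bias=False).to(device=visual.conv1.weight.device, dtype=visual.conv1.weight.dtype)
# Bind rewrited_forward to this encoder instance only
visual.forward = types.MethodType(rewrited_forward, visual)
# strict=False silently skips mismatched keys, so report them
result = visual.load_state_dict(converted_dict, strict=False)
print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)
```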
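The key renaming targets the parameter names of `torch.nn.MultiheadAttention`, which stores the packed query/key/value projection as flat `in_proj_weight` and `in_proj_bias` tensors; presumably the Alpha-CLIP checkpoint keeps that projection in a submodule named `in_proj` instead. The sketch below uses a hypothetical, randomly filled checkpoint and a hypothetical `convert_keys` helper. It also shows why checking the result of `load_state_dict(..., strict=False)` matters: without renaming, the QKV weights are simply skipped and no error is raised.

```python
import collections
import torch
import torch.nn as nn

def convert_keys(state_dict):
    """Hypothetical helper: rename 'in_proj.weight/bias' keys to nn.MultiheadAttention's flat names."""
    converted = collections.OrderedDict()
    for k, v in state_dict.items():
        k = k.replace("in_proj.weight", "in_proj_weight").replace("in_proj.bias", "in_proj_bias")
        converted[k] = v
    return converted

embed_dim = 8
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=2)

# Hypothetical checkpoint that stores the packed QKV projection under a submodule named `in_proj`.
fake_ckpt = {
    "in_proj.weight": torch.randn(3 * embed_dim, embed_dim),
    "in_proj.bias": torch.randn(3 * embed_dim),
    "out_proj.weight": torch.randn(embed_dim, embed_dim),
    "out_proj.bias": torch.randn(embed_dim),
}

# Without renaming, strict=False leaves the QKV projection untouched and raises no error.
raw = mha.load_state_dict(fake_ckpt, strict=False)
print(sorted(raw.missing_keys), sorted(raw.unexpected_keys))
# ['in_proj_bias', 'in_proj_weight'] ['in_proj.bias', 'in_proj.weight']

# After renaming, every key matches.
fixed = mha.load_state_dict(convert_keys(fake_ckpt), strict=False)
print(fixed.missing_keys, fixed.unexpected_keys)  # [] []
```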

After the steps above, Alpha-CLIP replaces the original CLIP in BLIP-Diffusion, and the model can perform region-focused image variation guided by the alpha map stored in the global `alpha` variable.
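To focus on a region, the global `alpha` needs to hold a normalized alpha map. The hypothetical helper below is a sketch, not the Alpha-CLIP preprocessing itself: it assumes a 224×224 input resolution and normalization with mean 0.5 and standard deviation 0.26, chosen so that an all-ones mask maps to the 1.9231 default used in `rewrited_forward`. It resizes a binary H×W mask with bilinear interpolation and returns a `[1, 1, size, size]` tensor.

```python
import numpy as np
import torch
import torch.nn.functional as F

def make_alpha(mask, size=224, mean=0.5, std=0.26):
    """Hypothetical helper: binary HxW mask -> normalized [1, 1, size, size] alpha map.

    size=224 and mean/std=0.5/0.26 are assumptions; with them an all-ones
    mask becomes (1 - 0.5) / 0.26 ≈ 1.9231.
    """
    a = torch.as_tensor(np.asarray(mask), dtype=torch.float32)[None, None]  # 1x1xHxW
    a = F.interpolate(a, size=(size, size), mode="bilinear", align_corners=False)
    return (a - mean) / std

mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 1  # focus on the central square
a = make_alpha(mask)
print(a.shape)  # torch.Size([1, 1, 224, 224])
```

In the notebook, the map would then be assigned to the module-level variable before generation, for example `alpha = make_alpha(mask).to(device)`. A batch-1 map broadcasts across the batch when `rewrited_forward` adds the two patch embeddings.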