In [1]:
%matplotlib inline

import glob
import clip

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import cv2
from math import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from PIL import Image
from torchvision import transforms

from IPython import display
import time

# Device string passed to clip.load() below; the notebook runs the model on a CUDA GPU.
device = "cuda"

In [2]:
def vision_attn_forward(self, x):
    """Run the visual encoder and pass the transformer's attention back out.

    Follows the encoder's forward path (patch convolution, class token,
    positional embedding, pre-norm, transformer, post-norm on the class token,
    optional projection) but goes through
    ``self.transformer.transformer_attn_forward`` so that whatever attention
    the transformer returns is handed to the caller as well.

    Args:
        x: image batch accepted by ``self.conv1``.

    Returns:
        ``(features, attn)``: ``features`` is the post-norm class-token
        embedding (multiplied by ``self.proj`` when it is set); ``attn`` is the
        second value returned by the transformer's attention-aware forward.
    """
    patches = self.conv1(x)  # shape = [*, width, grid, grid]
    batch, width = patches.shape[0], patches.shape[1]
    tokens = patches.reshape(batch, width, -1).permute(0, 2, 1)  # shape = [*, grid ** 2, width]

    # One class token per batch element (the zeros tensor only does the
    # broadcasting), prepended to the patch tokens.
    cls_token = self.class_embedding.to(tokens.dtype) + torch.zeros(
        batch, 1, width, dtype=tokens.dtype, device=tokens.device
    )
    tokens = torch.cat([cls_token, tokens], dim=1)  # shape = [*, grid ** 2 + 1, width]
    tokens = self.ln_pre(tokens + self.positional_embedding.to(tokens.dtype))

    # NLD -> LND for the transformer, then LND -> NLD on the way back.
    hidden, attn = self.transformer.transformer_attn_forward(tokens.permute(1, 0, 2))
    hidden = hidden.permute(1, 0, 2)

    # Only the class token (position 0) is normalised and returned.
    pooled = self.ln_post(hidden[:, 0, :])
    if self.proj is None:
        return pooled, attn
    return pooled @ self.proj, attn

def transformer_attn_forward(self, x):
    """Push ``x`` through every residual block in order, collecting attention.

    Args:
        x: input handed to the first block's ``resblock_attn_forward``.

    Returns:
        ``(hidden, attns)``: the output of the last block (``x`` itself when
        ``self.resblocks`` is empty) and a list holding the second return value
        of each block, in block order.
    """
    hidden = x
    collected = []
    for block in self.resblocks:
        hidden, weights = block.resblock_attn_forward(hidden)
        collected.append(weights)
    return hidden, collected

def resblock_attn_forward(self, x):
    """Residual attention block forward that also returns the attention weights.

    Computes ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))`` and returns
    the attention weights produced by ``self.attn`` alongside the block output.

    Args:
        x: token tensor for this block (the caller passes sequence-first, LND,
            layout).

    Returns:
        ``(x, attn_weights)``: the block output, same shape as the input, and
        the weights returned by ``self.attn`` (requested with
        ``average_attn_weights=True``).
    """
    attn_mask = None
    if self.attn_mask is not None:
        attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)

    normed = self.ln_1(x)
    # self.attn returns an (output, weights) pair and the unpacking below
    # already separates them. The output must be used whole: indexing it again
    # with [0] would keep only the first sequence position, which the residual
    # add would then broadcast over every token.
    attn_out, attn_weights = self.attn(
        normed,
        normed,
        normed,
        need_weights=True,
        average_attn_weights=True,
        attn_mask=attn_mask,
    )

    x = x + attn_out
    x = x + self.mlp(self.ln_2(x))

    return x, attn_weights

In [3]:
# Show the model names the installed `clip` package reports as available.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [4]:
# Load the "ViT-L/14@336px" checkpoint onto `device`. `preprocess` is the image
# transform returned alongside the model; attn_mask() applies it to each PIL frame.
model, preprocess = clip.load("ViT-L/14@336px", device=device)

In [5]:
# Attach the attention-returning forwards to this model instance as bound
# methods (`func.__get__(obj)` binds `obj` as `self`); the model's own
# `forward` methods are left untouched.
model.visual.vision_attn_forward = vision_attn_forward.__get__(model.visual)
model.visual.transformer.transformer_attn_forward = transformer_attn_forward.__get__(model.visual.transformer)

# transformer_attn_forward calls `resblock_attn_forward` on every block, so each
# residual block needs its own bound copy.
for layer in model.visual.transformer.resblocks:
    layer.resblock_attn_forward = resblock_attn_forward.__get__(layer)


In [6]:
def attn_mask(frame, lastw, threshold=0.0025, sharpness=6, momentum=0.75):
    """Fade a camera frame towards its mean colour wherever CLIP's class token is not attending.

    Args:
        frame: H x W x 3 uint8 array in OpenCV's BGR channel order, as returned
            by ``cv2.VideoCapture.read``.
        lastw: blend weights returned by the previous call, shape
            ``(336, 336, 1)``; pass zeros for the first frame.
        threshold: attention value that maps to full visibility; the weight is
            ``clip((attention / threshold) ** sharpness, 0, 1)``.
        sharpness: exponent applied to the normalised attention.
        momentum: fraction of ``lastw`` kept when smoothing over time
            (``w = (1 - momentum) * w_new + momentum * lastw``).

    Returns:
        ``(image, w)``: the 336 x 336 x 3 uint8 masked image, in the same (BGR)
        channel order as the input so it can go straight to ``cv2.imshow``, and
        the smoothed weights to pass back in as ``lastw`` on the next call.
    """
    size = 336

    # Centre square crop; for a 640x480 frame this is exactly [:480, 80:560].
    height, width = frame.shape[:2]
    side = min(height, width)
    top, left = (height - side) // 2, (width - side) // 2
    crop = frame[top:top + side, left:left + side, :]

    shown = Image.fromarray(crop).resize((size, size), Image.NEAREST)
    im2 = np.array(shown).astype(np.float32)

    # OpenCV delivers BGR, but PIL/CLIP preprocessing treats the array as RGB.
    # Flip the channels for the model input only; the displayed image stays BGR.
    model_input = Image.fromarray(np.ascontiguousarray(np.asarray(shown)[:, :, ::-1]))

    with torch.no_grad(), torch.autocast("cuda"):
        z = preprocess(model_input).unsqueeze(0).to(device)
        _, attn = model.visual.vision_attn_forward(z)

        # Last layer, class-token query row, patch-token columns only
        # (same element the original picked after concatenating every layer).
        cls_attn = attn[-1][-1, 0, 1:]
        grid = int(sqrt(cls_attn.shape[0]))

        # F.upsample_bilinear is deprecated; this is its documented equivalent.
        heat = F.interpolate(
            cls_attn.view(1, 1, grid, grid),
            scale_factor=shown.width // grid,
            mode="bilinear",
            align_corners=True,
        )
        heat = heat.detach().cpu().numpy()[0, 0]

    w = np.clip((heat / threshold) ** sharpness, 0, 1)[:, :, np.newaxis]

    # Exponential smoothing over time to reduce flicker between frames.
    w = (1.0 - momentum) * w + momentum * lastw

    # Attended pixels keep their colour, the rest fade to the frame's mean colour.
    im2 = w * im2 + np.mean(im2, axis=(0, 1), keepdims=True) * (1 - w)
    im2 = np.clip(im2, 0, 255).astype(np.uint8)

    return im2, w

In [7]:
# Live demo: grab webcam frames, mask them by CLIP attention, show the result.
vid = cv2.VideoCapture(0)

# Smoothed blend weights carried from frame to frame; start fully faded.
lastw = np.zeros((336, 336, 1))

try:
    if not vid.isOpened():
        raise RuntimeError("Could not open camera 0")

    while True:
        ret, frame = vid.read()
        if not ret:
            # read() failed (camera unplugged / stream ended); frame is unusable.
            break

        frame, lastw = attn_mask(frame, lastw)
        cv2.imshow('frame', frame)

        # Press 'q' in the video window to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Runs on normal exit, on errors and on a kernel interrupt alike, so the
    # camera and the preview window are never left dangling.
    vid.release()
    cv2.destroyAllWindows()



KeyboardInterrupt: 

In [None]:
# Manual cleanup cell: release the camera and close any OpenCV windows. Can be
# run by hand after the capture cell was interrupted (see the KeyboardInterrupt
# output above).
vid.release()
cv2.destroyAllWindows()