In [2]:
import torch
import clip
from PIL import Image
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt

In [3]:
!python3 --version

'python3' is not recognized as an internal or external command,
operable program or batch file.


In [4]:

def interpret_vit(image, text, model, device, index=None):
    logits_per_image, logits_per_text = model(image, text)
    print(logits_per_image)
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
    if index is None:
        index = np.argmax(logits_per_image.cpu().data.numpy(), axis=-1)
    one_hot = np.zeros((1, logits_per_image.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot.cpu() * logits_per_image)
    model.zero_grad()
    one_hot.backward(retain_graph=True)

    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    for blk in image_attn_blocks:
        grad = blk.attn_grad
        cam = blk.attn_probs
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.clamp(min=0).mean(dim=0)
        R += torch.matmul(cam, R)
    R[0, 0] = 0
    image_relevance = R[0, 1:]

    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    image_relevance = image_relevance.reshape(1, 1, 7, 7)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
    image_relevance = image_relevance.reshape(224, 224).cpu().data.numpy()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)

    plt.imshow(vis)
#     plt.show()

In [5]:
import streamlit as st

from torchray.attribution.grad_cam import grad_cam
import torch
import clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt

st.sidebar.header('Options')
alpha = st.sidebar.radio("select alpha", [0.5, 0.7, 0.8], index=1)
layer = st.sidebar.selectbox("select saliency layer", ['layer4.2.relu'], index=0)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_rn, preprocess = clip.load("RN50", device=device, jit=False)

def interpret_rn(image, text, model, device, index=None):   
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features_norm = image_features.norm(dim=-1, keepdim=True)
    image_features_new = image_features / image_features_norm
    text_features_norm = text_features.norm(dim=-1, keepdim=True)
    text_features_new = text_features / text_features_norm
    logit_scale = model.logit_scale.exp()
    logits_per_image = logit_scale * image_features_new @ text_features_new.t()
    probs = logits_per_image.softmax(dim=-1).cpu().detach().numpy().tolist()
    
    text_prediction = (text_features_new * image_features_norm)
    image_relevance = grad_cam(model.visual, image.type(model.dtype), text_prediction, saliency_layer=layer)
        
#     image_relevance = grad_cam(model.visual, image.type(model.dtype), image_features, saliency_layer=layer)

    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    image_relevance = image_relevance.reshape(1, 1, 7, 7)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
    image_relevance = image_relevance.reshape(224, 224).cpu().data.numpy()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)

    plt.imshow(vis)

2024-03-22 12:11:39.443 
  command:

    streamlit run c:\Users\born1\Desktop\Github\CLIP-ViL GRADCAM\CLIP-ViL-GradCAM\CLIP_virtual_env\Lib\site-packages\ipykernel_launcher.py [ARGUMENTS]


In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [7]:
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'

In [38]:
img_id = 'COCO_val2014_000000393267'
MSCOCO_IMG_ROOT = "/rscratch/data/coco_2014/images"

# COCO_val2014_000000393267 What color is the woman's shirt on the left? {'black': 1, 'blonde': 0.3}
import os
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
ori_preprocess = Compose([
        Resize((224), interpolation=Image.BICUBIC),
    CenterCrop(size=(224, 224)),
        ToTensor()])
# img_path = os.path.join(MSCOCO_IMG_ROOT, "val2014", img_id + ".jpg")
# img_path = 'id.png'
# image = ori_preprocess(Image.open(img_path))
# print(preprocess)

from matplotlib import rc

#----------------------------------------------------------------------------------------------------------------------------------------
import cv2
from transformers import CLIPProcessor, CLIPModel
# Load CLIP model and processor
model_name = "openai/clip-vit-base-patch32"  # Replace with your chosen CLIP model
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name).to(device)

# Function for danger classification in a frame
def classify_frame(frame, text_prompts):
  """
  Classifies a video frame based on provided text prompts using CLIP.

  Args:
      frame: The video frame as a NumPy array.
      text_prompts: List of text prompts representing danger and safety categories.

  Returns:
      A string indicating the predicted class ("danger" or "safe").
  """
  
  # Pre-process the frame (resize, normalize) as required by your CLIP model
  # ... (refer to your CLIP model's documentation for pre-processing steps)
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

  # Resize the frame (check the model's documentation for required dimensions)
  if model_name == "openai/clip-vit-base-patch32":
      # Assuming CLIP-Vit-Base-Patch32 expects 224x224
      frame = cv2.resize(frame, (224, 224))
  else:
      # Modify based on your CLIP model's requirements
      raise ValueError(f"Unsupported CLIP model: {model_name}")

  # Normalize pixel values (assuming model expects values between 0 and 1)
  frame = frame.astype(np.float32) / 255.0

  # Convert frame to a tensor and transpose dimensions (B, G, R) -> (C, H, W)
  frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0)

  
  # Convert frame to a tensor and move it to device
  #frame_tensor = ...  # Conversion based on your pre-processing steps (e.g., to torch.tensor)
  frame_tensor = frame_tensor.to(device)

  vocabulary={"danger":1,"safe":2}

  # Convert text prompts to integer indices (assuming a vocabulary exists)
  text_indices = [vocabulary[prompt] for prompt in text_prompts]
  text = torch.tensor(text_indices).to(device)

  # Assuming text prompts are already pre-processed with integer indices
  text = torch.tensor(text_prompts).to(device)

  # Forward pass through CLIP model (assuming it takes both image and text)
  with torch.no_grad():
      logits = model(frame_tensor, text)

  # Get the class with the highest score (assuming higher logits indicate danger)
  predicted_class = torch.argmax(logits).item()
  predicted_class_text = text_prompts[predicted_class]

  return predicted_class_text

# Video processing loop
video_path = "dangerous accident Very lucky Survival cctv camera.mp4"
cap = cv2.VideoCapture(video_path)

font = cv2.FONT_HERSHEY_SIMPLEX  # Font for danger label

danger_frame_count = 0
safe_frame_count = 0

while True:
  ret, frame = cap.read()
  if not ret:
    break

  # Danger classification using CLIP on the current frame
  predicted_class = classify_frame(frame, ["danger", "safe"])

  # Update danger/safe frame counts based on prediction
  if predicted_class == "danger":
      danger_frame_count += 1
      # Add danger label to the frame
      cv2.putText(frame, "Danger!", (10, 30), font, 1, (0, 0, 255), 2, cv2.LINE_AA)
  else:
      safe_frame_count += 1

  # Display the frame with potential danger label
  cv2.imshow("Video", frame)

  # Exit on 'q' key press
  if cv2.waitKey(1) & 0xFF == ord('q'):
      break

#------------------------------------------------------------------------------------------------------------------------------------------

# font = {
#     'size': 32,
# }
# import matplotlib
# matplotlib.rcParams['mathtext.fontset'] = 'custom'
# matplotlib.rcParams['mathtext.rm'] = 'Bitstream Vera Sans'
# matplotlib.rcParams['mathtext.it'] = 'Bitstream Vera Sans:italic'
# matplotlib.rcParams['mathtext.bf'] = 'Bitstream Vera Sans:bold'
# # matplotlib.rcParams['mathtext.size'] = 16

# # {'cursive', 'fantasy', 'monospace', 'sans', 'sans serif', 'sans-serif', 'serif'}
# plt.figure(figsize=(16, 16))
# plt.tight_layout()
# plt.subplot(131)
# plt.imshow(image.permute(1, 2, 0))
# plt.axis('off')
# plt.title("(a) Original", **font, y=-0.15)

# # plt.savefig('/rscratch/sheng.s/clip_boi/clip_vqa_starting/visual/sample_1_ori.pdf', bbox_inches='tight')
# # plt.show()
# texts = ["face"]


# image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
# text = clip.tokenize(texts).to(device)
# print(color.BOLD + color.PURPLE + color.UNDERLINE + 'text: ' + texts[0] + color.END)
# plt.subplot(132)
# plt.axis('off')
# plt.title("(b) ViT-B/32", **font,y=-0.15)
# interpret_vit(model=model, image=image, text=text, device=device, index=0)
# plt.subplot(133)
# plt.axis('off')
# plt.title("(c) RN50", **font,y=-0.15)
# interpret_rn(model=model_rn, image=image, text=text, device=device, index=0)



"""
image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
text = clip.tokenize(texts).to(device)
plt.subplot(133)
print(color.BOLD + color.PURPLE + color.UNDERLINE + 'text: ' + texts[0] + color.END)
interpret_rn(model=model_rn, image=image, text=text, device=device, index=0)
plt.axis('off')
plt.title("(c) RN50", **font,y=-0.15)
plt.tight_layout()

plt.savefig('sample_all.pdf', bbox_inches='tight')"""

ValueError: too many dimensions 'str'

In [9]:
images = []
from glob import glob

image_files = glob('./bottle/*.jpg')

# Load and store images in a list
images = [preprocess(Image.open(file)).unsqueeze(0).to(device) for file in image_files]
images = np.stack(images)
images = torch.Tensor(images).squeeze(1)
texts= ["bottle", "mom", "lion"]
text = clip.tokenize(texts).to(device)
logits_per_image, logits_per_text = model(images, text)
probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()

In [30]:
probs

array([[9.80874002e-01, 1.74526330e-02, 1.67341577e-03],
       [9.95607436e-01, 3.55899474e-03, 8.33532307e-04],
       [9.92620230e-01, 6.58456888e-03, 7.95182423e-04],
       [9.92818892e-01, 5.86325396e-03, 1.31783565e-03],
       [9.93724763e-01, 4.45985748e-03, 1.81542360e-03],
       [9.91071224e-01, 6.50541391e-03, 2.42342823e-03],
       [9.91384864e-01, 5.98575547e-03, 2.62938719e-03],
       [9.93631899e-01, 5.23412926e-03, 1.13408861e-03],
       [9.76863801e-01, 2.09404062e-02, 2.19580322e-03],
       [9.89538431e-01, 8.15730635e-03, 2.30429950e-03],
       [9.91728187e-01, 7.03415601e-03, 1.23772700e-03],
       [9.90314782e-01, 8.53700563e-03, 1.14823144e-03],
       [9.93409157e-01, 5.49768517e-03, 1.09319668e-03],
       [9.91212189e-01, 6.31303200e-03, 2.47475691e-03],
       [9.95156586e-01, 3.37180356e-03, 1.47161086e-03],
       [9.93618011e-01, 4.93897311e-03, 1.44309725e-03],
       [9.95293438e-01, 3.74574005e-03, 9.60841309e-04],
       [9.94298518e-01, 4.70601