In [2]:
# Print NVIDIA GPU / driver status. Only meaningful on CUDA machines; the
# device-selection cell below also handles Apple MPS and CPU.
get_ipython().system('nvidia-smi')

In [3]:
!pip install git+https://github.com/StarxSky/CLIP.git
!pip install --upgrade torch==1.12.1
!pip install --upgrade torchvision

<div style='color: #216969;
           background-color: #EAF6F6;
           font-size: 200%;
           border-radius:15px;
           text-align:center;
           font-weight:600;
           border-style: solid;
           border-color: darkgreen;
           font-family: "Verdana";'>
Notebook Imports
</div>


In [46]:
import os 
import clip
import torch 
import skimage
import numpy as np 
import matplotlib.pyplot as plt

from PIL import Image
from torch.backends import mps
from torchvision.datasets import CIFAR100

# Select the compute device: Apple-silicon MPS if available, otherwise CUDA,
# otherwise the CPU.
device = torch.device(
    'mps' if mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

print(f'CLIP Version(PyTorch) :{clip.version}')
print(f'Device :{device}')

# Load the ViT-B/32 CLIP model and its matching image preprocessing transform.
# download_root presumably is where the checkpoint is downloaded/cached.
model, preprocess = clip.load(name='ViT-B/32', device=device, download_root='./Pre_Models/')
model = model.to(device)

<div style='color: #216969;
           background-color: #EAF6F6;
           font-size: 200%;
           border-radius:15px;
           text-align:center;
           font-weight:600;
           border-style: solid;
           border-color: darkgreen;
           font-family: "Verdana";'>
Importing Data
</div>


In [5]:
# Sample images bundled with skimage (keyed by file stem) and the caption each
# one should match. Insertion order is kept, so it is also the order of the entries.
descriptions = dict(
    page="a page of text about segmentation",
    chelsea="a facial photo of a tabby cat",
    astronaut="a portrait of an astronaut with the American flag",
    rocket="a rocket standing on a launchpad",
    motorcycle_right="a red motorcycle standing in a garage",
    camera="a person looking at a camera on a tripod",
    horse="a black-and-white silhouette of a horse",
    coffee="a cup of coffee on a saucer",
)

In [34]:
# Load every skimage sample image that has an entry in `descriptions`, show it in
# a grid, and keep three lists aligned by index:
#   original_images - the RGB PIL image (used for the thumbnails in the final plot)
#   images          - the CLIP-preprocessed version (model input)
#   texts           - the matching description
original_images = []
images = []
texts = []

# Grid sized from the number of descriptions (4 per row -> 2x4 for the 8 defaults)
# rather than a hard-coded 2x4 that breaks as soon as a description is added.
n_cols = 4
n_rows = (len(descriptions) + n_cols - 1) // n_cols
plt.figure(figsize=(16, 2.5 * n_rows))

# os.listdir returns entries in arbitrary order; sort so the image/text order (and
# therefore every later matrix and plot) is reproducible across runs and machines.
image_files = sorted(
    f for f in os.listdir(skimage.data_dir) if f.endswith((".png", ".jpg"))
)

for filename in image_files:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    # The context manager closes the file handle once the RGB copy has been made.
    with Image.open(os.path.join(skimage.data_dir, filename)) as img:
        image = img.convert("RGB")

    plt.subplot(n_rows, n_cols, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()


In [35]:
# Number of images actually loaded. Later cells only see these, so report any
# description for which no skimage file was found instead of failing silently.
missing = [name for name, desc in descriptions.items() if desc not in texts]
if missing:
    print(f"No skimage file found for: {missing}")
len(images)

In [36]:
# Batch the preprocessed images into a single tensor (one entry per image) and
# tokenize the prompts "This is <description>" for CLIP's text encoder.
# torch.stack joins the per-image tensors directly; the previous
# torch.tensor(np.stack(...)) round-tripped through NumPy and copied the data twice.
image_input = torch.stack(images).to(device)
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).to(device)
image_input.shape

<div style='color: #216969;
           background-color: #EAF6F6;
           font-size: 200%;
           border-radius:15px;
           text-align:center;
           font-weight:600;
           border-style: solid;
           border-color: darkgreen;
           font-family: "Verdana";'>
Feeding the Data to the Model
</div>


In [37]:
# Feed the data into CLIP. Both encoders run with autograd disabled (inference
# only), and the resulting embeddings are cast to float32 for the similarity step.
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_tokens)
    image_features, text_features = image_features.float(), text_features.float()

print(f'Image Features Shape :{image_features.shape}')
print(f'Text Features Shape :{text_features.shape}')

In [42]:
# Compute cosine similarity: L2-normalise every embedding in place, then take
# dot products. These are raw cosine scores; the model's forward pass additionally
# multiplies them by `logit_scale`.
for feats in (image_features, text_features):
    feats /= feats.norm(dim=-1, keepdim=True)

image_np = image_features.cpu().numpy()
text_np = text_features.cpu().numpy()
similarity = image_np @ text_np.T  # rows: images, columns: texts


print(f'logits_per_image shape:  {similarity.shape}') # per image
print(f'logits_per_text shape:  {similarity.T.shape}') # per text

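* The code above performs the same image–text comparison as the model's forward pass below, except that the forward pass also multiplies the cosine similarities by the learned temperature `logit_scale` (上面的代码与下面模型前向计算等效，只是前向计算的结果还额外乘以了可学习的温度系数 `logit_scale`)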

```python
with torch.no_grad():
    Image_P, logits_per_text = model(image_input, text_tokens)
print(Image_P.shape)
```
#### Here
* `Image_P` = `logit_scale * similarity`
* `logits_per_text` = `Image_P.T` (i.e. `logit_scale * similarity.T`)
#### Model Forward Code:
```python
def forward(self, image, text):
    image_features = self.encode_image(image)
    text_features = self.encode_text(text)

    # normalized features
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)  # <<<<==== shape [num_texts, embed_dim] (这里的shape为[文本数, 特征维度])

    # ....... (in CLIP: logit_scale = self.logit_scale.exp(), the learned temperature)

    similarity = logit_scale * image_features @ text_features.t()  # <<<<====将图像特征和文本特征(已经转置)相乘，结果shape为[图像数, 文本数] [Multiplying image features and text features (already transposed); result shape [num_images, num_texts]]

    logits_per_text = similarity.t()  # <<<<====这里将图像的预测结果进行转置得出文本的预测信息[Here the prediction result of the image is transposed to derive the prediction information of the text]

    return similarity, logits_per_text
```

<div style='color: #216969;
           background-color: #EAF6F6;
           font-size: 200%;
           border-radius:15px;
           text-align:center;
           font-weight:600;
           border-style: solid;
           border-color: darkgreen;
           font-family: "Verdana";'>
The Easy Way
</div>

In [49]:
# The same comparison via the model's forward pass, which returns the
# logit_scale-scaled similarities in both directions. Softmax over the last axis
# (the texts) turns each image's row into a probability distribution over the
# descriptions.
with torch.no_grad():
    logits_per_image, logits_per_text = model(image_input, text_tokens)
    Image_P = torch.softmax(logits_per_image, dim=-1)
print(Image_P.shape)

<div style='color: #216969;
           background-color: #EAF6F6;
           font-size: 200%;
           border-radius:15px;
           text-align:center;
           font-weight:600;
           border-style: solid;
           border-color: darkgreen;
           font-family: "Verdana";'>
Visualizing Image–Text Similarity
</div>


In [47]:
# Heatmap of cosine similarity with the text descriptions as rows and the images
# as columns (each image is drawn as a thumbnail above its column).
# `similarity` is (n_images, n_texts) because it was computed as
# image_features @ text_features.T. Plotting it untransposed would put image rows
# under the text labels and swap every off-diagonal value, so transpose it here.
text_image_sim = similarity.T
n_texts, n_images = text_image_sim.shape
count = n_images  # kept so any later use of `count` still works

plt.figure(figsize=(20, 14))
# Fixed colour range; values outside [0.1, 0.3] are clipped to the end colours.
plt.imshow(text_image_sim, vmin=0.1, vmax=0.3)
plt.colorbar()
plt.yticks(range(n_texts), texts, fontsize=18)
plt.xticks([])

# Thumbnails in the band above row 0 (y in [-1.6, -0.6]), one per column.
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")

# Annotate each cell with its similarity value.
for x in range(n_images):
    for y in range(n_texts):
        plt.text(x, y, f"{text_image_sim[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

plt.xlim([-0.5, n_images - 0.5])
# Inverted y-axis with extra headroom so the thumbnail band is visible.
plt.ylim([n_texts + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)
plt.show()