In [3]:

import hashlib
import logging
import os
import urllib
import warnings
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from model import build_model
from simple_tokenizer import SimpleTokenizer as _Tokenizer

def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

    return download_target


# URL of the ViT-B-32 checkpoint; the path segment right before the file name
# is the SHA256 digest that _download verifies the file against.
MODEL_URL = "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"

model_path = _download(MODEL_URL)
# Load the scripted module onto the CPU and switch it to evaluation mode.
model = torch.jit.load(model_path, map_location="cpu").eval()



RecursiveScriptModule(
  original_name=Multimodal
  (visual): RecursiveScriptModule(
    original_name=VisualTransformer
    (conv1): RecursiveScriptModule(original_name=Conv2d)
    (ln_pre): RecursiveScriptModule(original_name=LayerNorm)
    (transformer): RecursiveScriptModule(
      original_name=Transformer
      (resblocks): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (1): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (2): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (3): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (4): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (5): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (6): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (7): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (8): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (9): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (10): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
        (11): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
          (ln_2): RecursiveScriptModule(original_name=LayerNorm)
        )
      )
    )
    (ln_post): RecursiveScriptModule(original_name=LayerNorm)
  )
  (transformer): RecursiveScriptModule(
    original_name=Transformer
    (resblocks): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (1): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (2): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (3): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (4): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (5): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (6): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (7): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (8): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (9): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (10): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
      (11): RecursiveScriptModule(
        original_name=ResidualAttentionBlock
        (attn): RecursiveScriptModule(
          original_name=MultiheadAttention
          (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
        )
        (ln_1): RecursiveScriptModule(original_name=LayerNorm)
        (mlp): RecursiveScriptModule(
          original_name=Sequential
          (c_fc): RecursiveScriptModule(original_name=Linear)
          (gelu): RecursiveScriptModule(original_name=QuickGELU)
          (c_proj): RecursiveScriptModule(original_name=Linear)
        )
        (ln_2): RecursiveScriptModule(original_name=LayerNorm)
      )
    )
  )
  (token_embedding): RecursiveScriptModule(original_name=Embedding)
  (ln_final): RecursiveScriptModule(original_name=LayerNorm)
)

In [4]:
# Display the module tree of the loaded model (the cell's last expression is rendered as its output).
model

RecursiveScriptModule(
  original_name=Multimodal
  (visual): RecursiveScriptModule(
    original_name=VisualTransformer
    (conv1): RecursiveScriptModule(original_name=Conv2d)
    (ln_pre): RecursiveScriptModule(original_name=LayerNorm)
    (transformer): RecursiveScriptModule(
      original_name=Transformer
      (resblocks): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ResidualAttentionBlock
          (attn): RecursiveScriptModule(
            original_name=MultiheadAttention
            (out_proj): RecursiveScriptModule(original_name=_LinearWithBias)
          )
          (ln_1): RecursiveScriptModule(original_name=LayerNorm)
          (mlp): RecursiveScriptModule(
            original_name=Sequential
            (c_fc): RecursiveScriptModule(original_name=Linear)
            (gelu): RecursiveScriptModule(original_name=QuickGELU)
            (c_proj): RecursiveScriptModule(original_name=Linear)
          )
   

In [6]:
# Print the name, shape and dtype of every model parameter,
# followed by a 50-character dashed separator line.
for name, param in model.named_parameters():
    print(
        f"参数名称: {name}",
        f"参数形状: {param.shape}",
        f"参数类型: {param.dtype}",
        "-" * 50,
        sep="\n",
    )

参数名称: positional_embedding
参数形状: torch.Size([77, 512])
参数类型: torch.float32
--------------------------------------------------
参数名称: text_projection
参数形状: torch.Size([512, 512])
参数类型: torch.float16
--------------------------------------------------
参数名称: logit_scale
参数形状: torch.Size([])
参数类型: torch.float32
--------------------------------------------------
参数名称: visual.class_embedding
参数形状: torch.Size([768])
参数类型: torch.float32
--------------------------------------------------
参数名称: visual.positional_embedding
参数形状: torch.Size([50, 768])
参数类型: torch.float32
--------------------------------------------------
参数名称: visual.proj
参数形状: torch.Size([768, 512])
参数类型: torch.float16
--------------------------------------------------
参数名称: visual.conv1.weight
参数形状: torch.Size([768, 3, 32, 32])
参数类型: torch.float16
--------------------------------------------------
参数名称: visual.ln_pre.weight
参数形状: torch.Size([768])
参数类型: torch.float32
--------------------------------------------------
参数名称: visual.