In [1]:
using PyCall

torch = pyimport("torch")
clip = pyimport("clip")
Image = pyimport("PIL").Image

PyObject <module 'PIL.Image' from 'C:\\Users\\user\\.pyenv\\pyenv-win\\versions\\3.8.10\\lib\\site-packages\\PIL\\Image.py'>

In [2]:
# Load the CLIP model, on the GPU when CUDA is available and on the CPU otherwise
device = if torch.cuda.is_available()
    "cuda"
else
    "cpu"
end
model, preprocess = clip.load("ViT-B/32", device = device)

(PyObject CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
          )
          (ln_1): La

In [4]:
using JSON

# Prepare the inputs: the preprocessed image with a leading batch axis added,
# and the tokenized candidate tags
image = let img = Image.open("./img/shortpants2013051823.jpg")
    preprocess(img).unsqueeze(0).to(device)
end
# Tag list for Danbooru / Safebooru, from https://github.com/rezoo/illustration2vec
tag_list = JSON.parsefile("./tag_list.json")
text = clip.tokenize(tag_list).to(device)

PyObject tensor([[49406,   272,  1611,  ...,     0,     0,     0],
        [49406,  5797, 49407,  ...,     0,     0,     0],
        [49406,  1538,  2225,  ...,     0,     0,     0],
        ...,
        [49406,  2996, 49407,  ...,     0,     0,     0],
        [49406, 30733, 49407,  ...,     0,     0,     0],
        [49406, 33228, 49407,  ...,     0,     0,     0]], device='cuda:0')

In [5]:
# Placeholder bound before the block so that the assignment made inside
# `@pywith` rebinds this variable and the result is still available afterwards
probs = []

# Run the PyTorch model with gradient tracking disabled (inference only)
@pywith torch.no_grad() begin
    # The forward pass returns the logits directly, so no separate
    # encode_image / encode_text calls are needed here
    logits_per_image, logits_per_text = model(image, text)
    # Softmax over the last axis, then move the result to the CPU and convert it
    # to an array
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
end

probs

1×1539 Matrix{Float16}:
 4.3e-5  1.3e-6  1.53e-5  1.395e-5  6.5e-6  …  3.93e-6  7.8e-6  8.46e-6

In [6]:
# Flatten the 1×N probability matrix into a plain vector
# (Matrix{Float16} => Vector{Float16})
probs = reshape(probs, :)

1539-element Vector{Float16}:
 4.3e-5
 1.3e-6
 1.53e-5
 1.395e-5
 6.5e-6
 0.0003712
 0.00011146
 9.7e-6
 2.0e-7
 4.0e-7
 3.04e-6
 3.76e-6
 5.54e-6
 ⋮
 3.16e-6
 6.25e-5
 1.67e-6
 6.0e-8
 3.93e-6
 8.0e-7
 0.00010973
 7.194e-5
 0.00010306
 3.93e-6
 7.8e-6
 8.46e-6

In [7]:
# Indices of the k largest predicted probabilities, highest first
k = 10
max_k_indices = partialsortperm(probs, 1:k; order = Base.Order.Reverse)

10-element view(::Vector{Int64}, 1:10) with eltype Int64:
 1304
 1097
 1291
 1306
 1526
 1300
 1523
 1125
 1105
 1329

In [8]:
# Print the top entries in descending order as tag_list[index]=>value
foreach(max_k_indices) do idx
    println(tag_list[idx], "=>", probs[idx])
end

queen's blade rebellion=>0.3071
queen's blade=>0.1521
koihime musou=>0.04358
summon night=>0.03967
senjou no valkyria 1=>0.0345
senjou no valkyria=>0.02483
miniskirt pirates=>0.02444
sword girls=>0.0168
ragnarok online=>0.01102
fire emblem: souen no kiseki=>0.01068


In [9]:
"""
    predictImageTag(imagefile::AbstractString, taglist::Vector; modelname::AbstractString = "ViT-B/32", k::Integer = 10)

Score every tag in `taglist` against the image with a CLIP model and return the `k`
most probable tags.

# Arguments

- `imagefile::AbstractString`: path of the image file to run inference on
- `taglist::Vector`: candidate tags; passed as-is to `clip.tokenize`

# Keywords

- `modelname::AbstractString`: model name passed to `clip.load`
- `k::Integer`: how many top results to return; clamped to `0:length(taglist)` so a
  `k` larger than the number of tags returns all of them instead of raising

# Returns

- A vector of `tag => probability` pairs ordered by descending probability.
"""
function predictImageTag(
    imagefile::AbstractString,
    taglist::Vector;
    modelname::AbstractString = "ViT-B/32",
    k::Integer = 10,
)
    device = torch.cuda.is_available() ? "cuda" : "cpu"
    model, preprocess = clip.load(modelname; device = device)

    # unsqueeze(0) adds a leading batch axis to the preprocessed image
    image = preprocess(Image.open(imagefile)).unsqueeze(0).to(device)
    text = clip.tokenize(taglist).to(device)

    # Declared up front so the assignment inside `@pywith` binds this function-level
    # variable rather than one local to the macro's block
    local probs
    @pywith torch.no_grad() begin
        # The forward pass returns the logits directly; the text-side logits are
        # not needed, and neither are separate encode_image / encode_text calls
        logits_per_image, _ = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    end

    # Flatten to one score per tag, then keep at most as many results as there are tags
    scores = vec(probs)
    topk = clamp(k, 0, length(scores))
    return map(partialsortperm(scores, 1:topk; rev = true)) do index
        taglist[index] => scores[index]
    end
end

# Example: predict tags for the sample image with the default model and default k (10)
predictImageTag("./img/shortpants2013051823.jpg", tag_list)

10-element Vector{Pair{String, Float16}}:
      "queen's blade rebellion" => 0.3071
                "queen's blade" => 0.1521
                "koihime musou" => 0.04358
                 "summon night" => 0.03967
         "senjou no valkyria 1" => 0.0345
           "senjou no valkyria" => 0.02483
            "miniskirt pirates" => 0.02444
                  "sword girls" => 0.0168
              "ragnarok online" => 0.01102
 "fire emblem: souen no kiseki" => 0.01068