In [1]:
import open_clip
import torch
from torch import nn
import json
import torchvision.transforms as T
from PIL import Image
import torch.nn.functional as F

In [2]:
open_clip.list_pretrained()

[('RN50', 'openai'),
 ('RN50', 'yfcc15m'),
 ('RN50', 'cc12m'),
 ('RN50-quickgelu', 'openai'),
 ('RN50-quickgelu', 'yfcc15m'),
 ('RN50-quickgelu', 'cc12m'),
 ('RN101', 'openai'),
 ('RN101', 'yfcc15m'),
 ('RN101-quickgelu', 'openai'),
 ('RN101-quickgelu', 'yfcc15m'),
 ('RN50x4', 'openai'),
 ('RN50x16', 'openai'),
 ('RN50x64', 'openai'),
 ('ViT-B-32', 'openai'),
 ('ViT-B-32', 'laion400m_e31'),
 ('ViT-B-32', 'laion400m_e32'),
 ('ViT-B-32', 'laion2b_e16'),
 ('ViT-B-32', 'laion2b_s34b_b79k'),
 ('ViT-B-32', 'datacomp_m_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_clip_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_laion_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_image_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_text_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_basic_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_s128m_b4k'),
 ('ViT-B-32', 'datacomp_s_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_clip_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_laion_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_image_s13m_b4k'),
 ('ViT-B-32', 'commo

In [3]:
clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', 'laion2b_s32b_b82k')
tokenizer = open_clip.get_tokenizer('ViT-L-14')
pre_process = T.Compose([
                T.Resize(
                    size=(224, 224), 
                    interpolation=T.InterpolationMode.BICUBIC,
                    antialias=True),
                T.ToTensor(), 
                T.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073), 
                    std=(0.26862954, 0.26130258, 0.27577711)
                )
            ])
clip.cuda()
clip.half()
_ = clip.eval()

In [4]:
def remove_transparency(im, bg_colour=(255, 255, 255)):

    if im.mode in ('RGBA', 'LA') or (im.mode == 'P' and 'transparency' in im.info):
        alpha = im.convert('RGBA').split()[-1]
        bg = Image.new("RGB", im.size, (255, 255, 255))
        bg.paste(im, mask=alpha)
        return bg
    elif im.mode == 'P':
        bg = Image.new("RGB", im.size, (255, 255, 255))
        bg.paste(im)
        return bg
    else:
        return im

In [5]:
nsfw_sfw_keywords = json.load(open('./keywords.json'))
keywords = []
nsfw_keywords = nsfw_sfw_keywords["nsfw_keywords"]
for nsfw_keyword in nsfw_keywords:
    keywords.extend(nsfw_keyword["keywords"])
sfw_keywords = nsfw_sfw_keywords["non_nsfw_keywords"]
for sfw_keyword in sfw_keywords:
    keywords.extend(sfw_keyword["keywords"])
keywords = list(set(keywords))

In [10]:
len(keywords)
keywords = ["+13", "+18", "+17", "general audiences", "rated r", "family friendly"]

In [11]:
test_images = "./test_imgs/sfw/open_mounth.png"
image = Image.open(test_images)
text = tokenizer(keywords)

In [12]:
with torch.no_grad(), torch.cuda.amp.autocast():
    image = remove_transparency(image)
    x = pre_process(image)
    x = x.reshape(1, 3, 224, 224)
    image_features = clip.encode_image(x.cuda())
    text_features = clip.encode_text(text.cuda())
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    #text_probs = (image_features @ text_features.T)
    #text_probs /= text_probs.norm(dim=-1, keepdim=True)

In [13]:
prod_list = text_probs[0].tolist()
results = [(keyword, prob) for keyword, prob in zip(keywords, prod_list)]
results = sorted(results, key=lambda x: x[1], reverse=True)

print(results)

[('+13', 0.6041147112846375), ('+18', 0.2557982802391052), ('+17', 0.13479621708393097), ('rated r', 0.004102418664842844), ('family friendly', 0.0011215388076379895), ('general audiences', 6.682948878733441e-05)]
