In [None]:
import pandas as pd
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch
from PIL import Image

import os


In [None]:
# Run on the first CUDA GPU when torch reports one, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision is only requested together with the GPU; float16 support on
# CPU is incomplete in some torch versions, so the CPU path uses float32.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def _build_image_classifier(checkpoint):
    """Create an image-classification pipeline for a single checkpoint.

    Args:
        checkpoint: Model identifier handed to both ``from_pretrained`` calls
            (model weights and feature extractor).

    Returns:
        The object returned by ``pipeline``, configured with the module-level
        ``device`` and ``dtype``.
    """
    # NOTE(review): the model is instantiated before it is handed to
    # pipeline(); confirm that torch_dtype still casts the weights in the
    # installed transformers version.
    return pipeline(
        "image-classification",
        model=AutoModelForImageClassification.from_pretrained(checkpoint),
        feature_extractor=AutoFeatureExtractor.from_pretrained(checkpoint),
        device=device,
        torch_dtype=dtype,
    )


# The three classifiers are built identically and differ only in checkpoint.
nsfw_pipe = _build_image_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _build_image_classifier("cafeai/cafe_style")
aesthetic_pipe = _build_image_classifier("cafeai/cafe_aesthetic")



In [None]:
# Dataset directory, two levels above the notebook's working directory.
data_path = os.path.join("..", "..", "datasets")
csv_path = os.path.join(data_path, "full_set.csv")
full_set = pd.read_csv(csv_path)
# Prefix every image path with the dataset directory
# (presumably the CSV stores them relative to it; verify against the data).
full_set["img_path"] = full_set["img_path"].transform(
    lambda relative_path: os.path.join(data_path, relative_path)
)
full_set

In [None]:
# Open the image referenced by the row labelled 0 as a sample for the checks below.
first_image_path = full_set["img_path"][0]
img = Image.open(first_image_path)

In [None]:
# Bare expression: the notebook renders the sample image as this cell's output.
img

In [None]:
# Wrap the sample image in a list; the pipelines below are called with this list.
pil_images = [img]

In [None]:
# Classify the sample batch with the aesthetic model.
aesthetic = aesthetic_pipe(pil_images)
# Results are nested as [image][label]; show the score of the first label
# returned for the first image (presumably the top-scoring one; confirm).
first_image_labels = aesthetic[0]
first_image_labels[0]["score"]

In [None]:
# Classify the sample batch with the NSFW model.
nsfw = nsfw_pipe(pil_images)
# Results are nested as [image][label]; show the score of the first label
# returned for the first image (presumably the top-scoring one; confirm).
first_image_nsfw_labels = nsfw[0]
first_image_nsfw_labels[0]["score"]

In [None]:
# Score every image in the dataset with the aesthetic classifier, one image per call.
aesthetic_predicts = []
for path in full_set["img_path"]:
    # The context manager releases the image file once it has been classified,
    # instead of leaving one open handle per dataset row to the garbage collector.
    with Image.open(path) as image:
        # A one-element list keeps each result nested as result[0][0],
        # which is the shape the later cells index into.
        aesthetic_predicts.append(aesthetic_pipe([image]))
# Reuse full_set's index so the predictions stay aligned with the "id" and
# "rating" columns when they are combined into one frame later.
aesthetic_scores = pd.Series(aesthetic_predicts, index=full_set.index)
aesthetic_scores

In [None]:
# Preview: reduce each raw pipeline result to the score of its first label.
aesthetic_scores.transform(
    lambda per_image_result: per_image_result[0][0]["score"]
)

In [None]:
# Comparison frame: dataset rating next to the model's score for each image.
# NOTE(review): [0][0]["score"] is the score of the first label the pipeline
# returned, whichever label that is; confirm this is the value meant to be
# compared against "rating".
predicted_scores = aesthetic_scores.transform(
    lambda per_image_result: per_image_result[0][0]["score"]
)
cp_columns = {
    "id": full_set["id"],
    "rating": full_set["rating"],
    "predict": predicted_scores,
}
cp_df = pd.DataFrame(cp_columns)

In [None]:
cp_df

In [None]:
import matplotlib.pyplot as plt

# Scatter of the model's score against the dataset rating, one point per image.
# Axis labels and a title are set so the figure can be read on its own.
fig, ax = plt.subplots()
ax.scatter(cp_df["rating"], cp_df["predict"])
ax.set_xlabel("rating")
ax.set_ylabel("predict")
ax.set_title("predict vs. rating")
plt.show()

In [None]:
# Place every prediction on the line y = 0 and colour it by whether its rating
# exceeds 1, to show how the two groups are distributed along the score axis.
is_high_rating = cp_df['rating'] > 1
# The y values are sized from the frame itself so the plot keeps working when
# the number of rows changes.
plt.scatter(cp_df['predict'], [0] * len(cp_df), c=is_high_rating, cmap='coolwarm')
plt.xlabel('predict')
# The y axis carries no information (all points sit at 0), so hide its ticks.
plt.yticks([])
plt.colorbar(label='rating > 1')

# Display the plot
plt.show()