In [1]:
from transformers import GitProcessor, GitForCausalLM, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from PIL import Image
import torch
import json
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
# GIT captioning checkpoint: the processor prepares pixel values and owns the
# tokenizer used for decoding; the model is moved onto the selected device.
git_processor = GitProcessor.from_pretrained("microsoft/git-large")
git_model = GitForCausalLM.from_pretrained("microsoft/git-large")
git_model = git_model.to(device)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
# TinyLlama chat checkpoint used to merge captions: loaded in half precision,
# with device placement delegated to `device_map="auto"`.
llama_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
llama_model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.float16,
    device_map="auto",
)
# Inference only; assigning the result back keeps the cell from echoing the model repr.
llama_model = llama_model.eval()

In [5]:
# Sentence-embedding model applied to the merged captions.
sbert_model = SentenceTransformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"
)

In [6]:
def clean_caption(text):
    """Drop every "[ unused0 ]" placeholder from ``text`` and trim outer whitespace.

    Args:
        text: decoded caption string.

    Returns:
        The caption without placeholders, stripped of leading/trailing whitespace.
        Whitespace that surrounded a removed placeholder inside the text is kept.
    """
    placeholder = "[ unused0 ]"
    without_placeholder = "".join(text.split(placeholder))
    return without_placeholder.strip()

In [7]:
def merge_captions(base_caption, desc1, desc2):
    """Ask the LLaMA model to fuse one base caption with two descriptive captions.

    Builds a fixed instruction prompt, runs greedy generation (up to 100 new
    tokens) and returns only the newly generated text.

    Args:
        base_caption: caption treated as ground truth in the prompt.
        desc1: first descriptive caption.
        desc2: second descriptive caption.

    Returns:
        The generated description as a stripped string, with a leading
        "Example:" / "example:" label removed if the model produced one.
    """
    prompt = f"""
Given the base caption that is true and factual:
\"{base_caption}\"

And two descriptive captions:
1) {desc1}
2) {desc2}

Write a short, coherent description that is faithful to the base caption but incorporates descriptive elements from captions 1 and 2 without contradicting the original meaning.
"""
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    with torch.no_grad():
        ids = llama_model.generate(**inputs, max_new_tokens=100, do_sample=False)

    # The generated sequence starts with the prompt tokens. Cut them off by token
    # count instead of slicing the decoded string by len(prompt): decoding is not
    # guaranteed to reproduce the prompt character-for-character, so a character
    # offset can leak prompt text into the result or cut into the answer.
    new_token_ids = ids[0][prompt_token_count:]
    result = llama_tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    for prefix in ("Example:", "example:"):
        if result.startswith(prefix):
            result = result[len(prefix):].strip()
    return result

In [8]:
# Run configuration (the final run cell assigns these names again).
BATCH_SIZE = 16      # examples collected before each process_batch call
SAVE_EVERY = 500     # flush interval, compared against the processed count
MAX_IMAGES = 5_000   # stop once this many examples have been processed
OUTPUT_FILE = "wikiart_captions_embeddings_10000_14999.jsonl"  # JSONL file results are appended to

In [9]:
def save_batch_to_jsonl(batch_data, file_path):
    """Append ``batch_data`` to ``file_path`` in JSON Lines form.

    Args:
        batch_data: iterable of JSON-serialisable items.
        file_path: destination path; opened in append mode, so existing
            content is kept and the file is created if it does not exist.
    """
    serialized_lines = (json.dumps(record) + "\n" for record in batch_data)
    with open(file_path, "a") as sink:
        sink.writelines(serialized_lines)

In [10]:
def process_batch(batch_examples):
    """Caption a batch of examples and embed the resulting captions.

    Steps, in order:
      1. convert every example's image to RGB and preprocess with the GIT processor;
      2. generate one greedy caption per image;
      3. generate two sampled captions per image (top_k=100, temperature=0.8);
      4. merge the three captions per image with ``merge_captions``;
      5. embed the merged captions with the sentence-transformers model.

    Args:
        batch_examples: iterable of mappings whose "image" entry supports
            ``.convert("RGB")`` (presumably a PIL image; confirm against the dataset).

    Returns:
        A list with one dict per image holding "image_id", "caption" and
        "embedding" (a plain list). Empty input yields an empty list.
        "image_id" is the image's ``filename`` attribute, or None when the
        attribute is missing.
        NOTE(review): in-memory decoded images may carry no useful filename,
        so this id may not identify the image; verify on real data.
    """
    rgb_images = []
    source_names = []
    for example in batch_examples:
        picture = example["image"]
        rgb_images.append(picture.convert("RGB"))
        source_names.append(getattr(picture, "filename", None))

    if not rgb_images:
        return []

    n_images = len(rgb_images)
    tokenizer = git_processor.tokenizer

    def _to_caption(token_ids):
        # Decode one generated sequence and strip the placeholder marker.
        return clean_caption(tokenizer.decode(token_ids, skip_special_tokens=True))

    pixel_values = git_processor(images=rgb_images, return_tensors="pt")["pixel_values"].to(device)

    # Greedy pass (do_sample=False): one caption per image.
    with torch.no_grad():
        greedy_ids = git_model.generate(pixel_values=pixel_values, max_new_tokens=30, do_sample=False)
    base_captions = [_to_caption(sequence) for sequence in greedy_ids]

    # Sampling pass (do_sample=True): two captions per image.
    with torch.no_grad():
        sampled_ids = git_model.generate(
            pixel_values=pixel_values,
            max_new_tokens=30,
            do_sample=True,
            top_k=100,
            temperature=0.8,
            num_return_sequences=2,
        )

    # Regroup the flat (n_images * 2, seq_len) output into per-image pairs.
    sampled_pairs = sampled_ids.view(n_images, 2, -1)

    merged_captions = [
        merge_captions(base, _to_caption(pair[0]), _to_caption(pair[1]))
        for base, pair in zip(base_captions, sampled_pairs)
    ]

    # Embeddings via sentence-transformers
    embeddings = sbert_model.encode(merged_captions, convert_to_numpy=True).tolist()

    return [
        {"image_id": name, "caption": caption, "embedding": vector}
        for name, caption, vector in zip(source_names, merged_captions, embeddings)
    ]


In [21]:
# dataset = load_dataset("huggan/wikiart", split="train")

# processed_count = 0
# buffer = []

# pbar = tqdm(total=MAX_IMAGES)
# batch = []

# for example in dataset:
#     if processed_count >= MAX_IMAGES:
#         break
#     batch.append(example)
#     if len(batch) == BATCH_SIZE:
#         result_batch = process_batch(batch)
#         buffer.extend(result_batch)
#         processed_count += len(result_batch)
#         pbar.update(len(result_batch))
#         batch = []

#         if processed_count % SAVE_EVERY == 0:
#             save_batch_to_jsonl(buffer, OUTPUT_FILE)
#             buffer = []

# # Save whatever is left in the buffer
# if buffer:
#     save_batch_to_jsonl(buffer, OUTPUT_FILE)

# pbar.close()
# print(f"Готово! Сохранено {processed_count} примеров в {OUTPUT_FILE}")


  0%|                                                 | 0/10000 [00:38<?, ?it/s]
100%|██████████████████████████████████| 10000/10000 [17:26:00<00:00,  6.28s/it]

Готово! Сохранено 10000 примеров в wikiart_captions_embeddings.jsonl





In [None]:
# Caption and embed the dataset window [START_INDEX, START_INDEX + MAX_IMAGES),
# appending results to OUTPUT_FILE as JSON Lines.
BATCH_SIZE = 16
SAVE_EVERY = 500
MAX_IMAGES = 5000
START_INDEX = 10000
OUTPUT_FILE = "wikiart_captions_embeddings_10000_14999.jsonl"

dataset = load_dataset("huggan/wikiart", split="train")

# Select the target window up front instead of iterating (and decoding) the
# first START_INDEX examples only to discard them. This also caps the run at
# exactly MAX_IMAGES examples rather than rounding up to a whole batch.
window_end = min(START_INDEX + MAX_IMAGES, len(dataset))
window_start = min(START_INDEX, window_end)
subset = dataset.select(range(window_start, window_end))


def _batched(examples, size):
    """Yield lists of up to ``size`` consecutive items; the last may be shorter."""
    chunk = []
    for item in examples:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


processed_count = 0
buffer = []
batch = []

pbar = tqdm(total=len(subset))
try:
    for batch in _batched(subset, BATCH_SIZE):
        result_batch = process_batch(batch)
        buffer.extend(result_batch)
        processed_count += len(result_batch)
        pbar.update(len(result_batch))

        # Flush on buffer size. The processed count advances in steps of
        # BATCH_SIZE, so testing `processed_count % SAVE_EVERY == 0` would only
        # fire at common multiples of the two values, far less often than intended.
        if len(buffer) >= SAVE_EVERY:
            save_batch_to_jsonl(buffer, OUTPUT_FILE)
            buffer = []
finally:
    # Write out whatever is buffered even if the run is interrupted, so
    # already-computed captions are not lost.
    if buffer:
        save_batch_to_jsonl(buffer, OUTPUT_FILE)
        buffer = []
    pbar.close()

print(f"Готово! Сохранено {processed_count} примеров в {OUTPUT_FILE}")

  0%|                                                  | 0/5000 [00:00<?, ?it/s]