In [1]:
# Load model directly
from transformers import AutoTokenizer, AutoModel
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
import nltk
import evaluate
from PIL import Image
tokenizer = AutoTokenizer.from_pretrained("tuman/vit-rugpt2-image-captioning")


Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Collecting responses<0.19 (from evaluate)
  Downloading responses-0.18.0-py3-none-any.whl.metadata (29 kB)
Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
   ---------------------------------------- 0.0/84.1 kB ? eta -:--:--
   -------------- ------------------------- 30.7/84.1 kB 1.4 MB/s eta 0:00:01
   ---------------------------------------- 84.1/84.1 kB 1.2 MB/s eta 0:00:00
Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Installing collected packages: responses, evaluate
Successfully installed evaluate-0.4.1 responses-0.18.0


tokenizer_config.json:   0%|          | 0.00/780 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.81M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.27M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/217 [00:00<?, ?B/s]

In [None]:
model = VisionEncoderDecoderModel.from_pretrained("tuman/vit-rugpt2-image-captioning", cache_dir='E:\\ggames')

In [3]:
feature_extractor = ViTFeatureExtractor.from_pretrained("tuman/vit-rugpt2-image-captioning", cache_dir='E:\\ggames')



In [4]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
dt = pd.read_csv('train.csv', sep=';', encoding='utf-8')
path_train = '/home/jupyter/datasphere/project/train'

In [5]:
dt_non_nan = dt[dt['description'].notna()].reset_index()
dt_non_nan['path'] = path_train + '/' + (dt_non_nan['object_id']).astype(str) + '/' + dt_non_nan['img_name']
dt_non_nan = dt_non_nan.drop(columns=['index', 'name', 'group', 'img_name', 'object_id'])

In [6]:
train_x, test_x, train_y, test_y = train_test_split(dt_non_nan['path'].values, 
                                                    dt_non_nan['description'].values, 
                                                    test_size=0.1)
test_x, valid_x, test_y, valid_y = train_test_split(test_x, 
                                                    test_y, 
                                                    test_size=0.05)

In [7]:
def get_pixels(image_paths):
  images = []
  for image_path in image_paths:
    i_image = Image.open(image_path)
    if i_image.mode != "RGB":
      i_image = i_image.convert(mode="RGB")

    images.append(i_image)

  pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
  return pixel_values


In [8]:
def get_label(texts, max_target_length):
    return tokenizer(texts,
                     return_tensors='pt',
                     padding='max_length',
                     max_length=max_target_length,
                     truncation=True).input_ids

In [9]:
from torch.utils.data import Dataset
class Custom(Dataset):
    def __init__(self, X, y, max_target_length=128):
        self.max_target_length = max_target_length
        self.all_files = X
        self.all_texts = y

    def __len__(self):
        return len(self.all_files)

    def __getitem__(self, idx):
        pixel_values = get_pixels([self.all_files[idx]])
        # add labels (input_ids) by encoding the text
        labels = get_label(self.all_texts[idx], self.max_target_length)[0]
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

In [10]:
train_dataset = Custom(train_x, train_y)
test_dataset = Custom(test_x, test_y)
valid_dataset = Custom(valid_x, valid_y)

In [11]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    preds = ["\n".join(nltk.sent_tokenize(pred, language='russian')) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label, language='russian')) for label in labels]

    return preds, labels

In [12]:
import evaluate
metric = evaluate.load("rouge")

Downloading builder script: 100%|██████████| 6.27k/6.27k [00:00<00:00, 2.68MB/s]


In [13]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    return result

In [14]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    output_dir="./image-captioning-output",
    save_steps=3000,
    report_to='clearml'
)

from transformers import default_data_collator

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
)

In [17]:
trainer.train(resume_from_checkpoint=True)

ClearML Task: overwriting (reusing) task id=da48a4eeeed04a57a55124ed1ea8ad7b
2024-04-13 19:33:11,763 - clearml.Task - INFO - No repository found, storing script code instead
ClearML results page: https://app.clear.ml/projects/bed312fc3f7f41c9a552f47d265832db/experiments/da48a4eeeed04a57a55124ed1ea8ad7b/output/log


Unsupported key of type '<class 'int'>' found when connecting dictionary. It will be converted to str
 43%|████▎     | 3001/6942 [00:01<00:02, 1645.36it/s]
  0%|          | 0/13 [00:00<?, ?it/s][A
 15%|█▌        | 2/13 [00:01<00:07,  1.46it/s][A
 23%|██▎       | 3/13 [00:02<00:09,  1.06it/s][A
 31%|███       | 4/13 [00:04<00:10,  1.14s/it][A
 38%|███▊      | 5/13 [00:05<00:10,  1.33s/it][A
 46%|████▌     | 6/13 [00:07<00:09,  1.35s/it][A
 54%|█████▍    | 7/13 [00:08<00:08,  1.34s/it][A
 62%|██████▏   | 8/13 [00:09<00:06,  1.35s/it][A
 43%|████▎     | 3001/6942 [00:17<00:02, 1645.36it/s]
 77%|███████▋  | 10/13 [00:12<00:04,  1.41s/it][A
 85%|████████▍ | 11/13 [00:14<00:02,  1.37s/it][A
 92%|█████████▏| 12/13 [00:15<00:01,  1.36s/it][A
                                                     
 43%|████▎     | 3001/6942 [00:22<00:02, 1645.36it/s]
100%|██████████| 13/13 [00:17<00:00,  1.35s/it][A
                                               [A

{'eval_loss': 0.816318929195404, 'eval_rouge1': 0.0373931623931624, 'eval_rouge2': 0.0, 'eval_rougeL': 0.038461538461538464, 'eval_rougeLsum': 0.038461538461538464, 'eval_runtime': 20.291, 'eval_samples_per_second': 2.563, 'eval_steps_per_second': 0.641, 'epoch': 1.3}


 43%|████▎     | 3002/6942 [00:22<00:42, 93.73it/s]  
  0%|          | 0/13 [00:00<?, ?it/s][A
 15%|█▌        | 2/13 [00:01<00:07,  1.56it/s][A
 23%|██▎       | 3/13 [00:02<00:09,  1.08it/s][A
 31%|███       | 4/13 [00:04<00:10,  1.12s/it][A
 38%|███▊      | 5/13 [00:05<00:10,  1.29s/it][A
 46%|████▌     | 6/13 [00:06<00:09,  1.30s/it][A
 54%|█████▍    | 7/13 [00:08<00:07,  1.32s/it][A
 62%|██████▏   | 8/13 [00:09<00:06,  1.32s/it][A
 69%|██████▉   | 9/13 [00:11<00:05,  1.34s/it][A
 77%|███████▋  | 10/13 [00:12<00:04,  1.39s/it][A
 85%|████████▍ | 11/13 [00:14<00:03,  1.59s/it][A
 92%|█████████▏| 12/13 [00:15<00:01,  1.50s/it][A
                                                   
 43%|████▎     | 3002/6942 [00:41<00:42, 93.73it/s]
100%|██████████| 13/13 [00:17<00:00,  1.45s/it][A
                                               [A

{'eval_loss': 0.8156265616416931, 'eval_rouge1': 0.0373931623931624, 'eval_rouge2': 0.0, 'eval_rougeL': 0.038461538461538464, 'eval_rougeLsum': 0.038461538461538464, 'eval_runtime': 18.9542, 'eval_samples_per_second': 2.743, 'eval_steps_per_second': 0.686, 'epoch': 1.3}


 43%|████▎     | 3003/6942 [00:43<01:36, 40.92it/s]
  0%|          | 0/13 [00:00<?, ?it/s][A
 15%|█▌        | 2/13 [00:01<00:06,  1.59it/s][A
 23%|██▎       | 3/13 [00:02<00:08,  1.11it/s][A
 31%|███       | 4/13 [00:03<00:09,  1.10s/it][A
 38%|███▊      | 5/13 [00:05<00:10,  1.27s/it][A
 46%|████▌     | 6/13 [00:06<00:09,  1.29s/it][A
 54%|█████▍    | 7/13 [00:08<00:07,  1.30s/it][A
 62%|██████▏   | 8/13 [00:09<00:06,  1.32s/it][A
 69%|██████▉   | 9/13 [00:10<00:05,  1.34s/it][A
 77%|███████▋  | 10/13 [00:12<00:04,  1.39s/it][A
 85%|████████▍ | 11/13 [00:13<00:02,  1.37s/it][A
 92%|█████████▏| 12/13 [00:15<00:01,  1.35s/it][A
                                                   
 43%|████▎     | 3003/6942 [01:01<01:36, 40.92it/s]
100%|██████████| 13/13 [00:16<00:00,  1.33s/it][A
                                               [A

{'eval_loss': 0.8149785399436951, 'eval_rouge1': 0.0373931623931624, 'eval_rouge2': 0.0, 'eval_rougeL': 0.038461538461538464, 'eval_rougeLsum': 0.038461538461538464, 'eval_runtime': 18.1427, 'eval_samples_per_second': 2.866, 'eval_steps_per_second': 0.717, 'epoch': 1.3}


 43%|████▎     | 3004/6942 [01:02<02:48, 23.44it/s]
  0%|          | 0/13 [00:00<?, ?it/s][A
 15%|█▌        | 2/13 [00:01<00:07,  1.51it/s][A
 23%|██▎       | 3/13 [00:02<00:09,  1.08it/s][A
 31%|███       | 4/13 [00:04<00:10,  1.12s/it][A
 38%|███▊      | 5/13 [00:05<00:10,  1.30s/it][A
 46%|████▌     | 6/13 [00:07<00:09,  1.31s/it][A
 54%|█████▍    | 7/13 [00:08<00:07,  1.32s/it][A
 62%|██████▏   | 8/13 [00:09<00:06,  1.32s/it][A
 69%|██████▉   | 9/13 [00:11<00:05,  1.34s/it][A
 77%|███████▋  | 10/13 [00:12<00:04,  1.37s/it][A
 85%|████████▍ | 11/13 [00:13<00:02,  1.35s/it][A
 92%|█████████▏| 12/13 [00:15<00:01,  1.33s/it][A
                                                   
 43%|████▎     | 3004/6942 [01:20<02:48, 23.44it/s]
100%|██████████| 13/13 [00:16<00:00,  1.34s/it][A
                                               [A

{'eval_loss': 0.8142771124839783, 'eval_rouge1': 0.0373931623931624, 'eval_rouge2': 0.0, 'eval_rougeL': 0.038461538461538464, 'eval_rougeLsum': 0.038461538461538464, 'eval_runtime': 18.2099, 'eval_samples_per_second': 2.856, 'eval_steps_per_second': 0.714, 'epoch': 1.3}


 43%|████▎     | 3005/6942 [01:20<04:30, 14.55it/s]
  0%|          | 0/13 [00:00<?, ?it/s][A
 15%|█▌        | 2/13 [00:01<00:07,  1.49it/s][A
 23%|██▎       | 3/13 [00:02<00:09,  1.04it/s][A
 31%|███       | 4/13 [00:04<00:10,  1.12s/it][A
 38%|███▊      | 5/13 [00:05<00:10,  1.28s/it][A
 46%|████▌     | 6/13 [00:06<00:08,  1.29s/it][A
 54%|█████▍    | 7/13 [00:08<00:07,  1.31s/it][A
 62%|██████▏   | 8/13 [00:09<00:06,  1.31s/it][A
 69%|██████▉   | 9/13 [00:10<00:05,  1.32s/it][A
 77%|███████▋  | 10/13 [00:12<00:04,  1.36s/it][A
 85%|████████▍ | 11/13 [00:13<00:02,  1.36s/it][A
 92%|█████████▏| 12/13 [00:15<00:01,  1.34s/it][A
                                                   
 43%|████▎     | 3005/6942 [01:39<04:30, 14.55it/s]
100%|██████████| 13/13 [00:16<00:00,  1.34s/it][A
                                               [A

{'eval_loss': 0.8137597441673279, 'eval_rouge1': 0.0373931623931624, 'eval_rouge2': 0.0, 'eval_rougeL': 0.038461538461538464, 'eval_rougeLsum': 0.038461538461538464, 'eval_runtime': 18.2538, 'eval_samples_per_second': 2.849, 'eval_steps_per_second': 0.712, 'epoch': 1.3}


ValueError: Unsupported number of image dimensions: 0