<a href="https://colab.research.google.com/github/Arseniy-Polyakov/russian_sign_language_recognition/blob/main/Sign_Language_Recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch --upgrade --quiet

In [26]:
import re
import wandb
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    BartForConditionalGeneration,
    BartTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import load_dataset

In [None]:
# Start a Weights & Biases run for this experiment.
wandb.init(project="marian-mt-translation", name="marian-mt-ru-rsl")

In [16]:
# Read the train/test CSV splits into a single DatasetDict and display its summary.
dataset = load_dataset("csv", data_files=dict(train="train.csv", test="test.csv"))
dataset

DatasetDict({
    train: Dataset({
        features: ['russian', 'rsl'],
        num_rows: 80
    })
    test: Dataset({
        features: ['russian', 'rsl'],
        num_rows: 20
    })
})

In [17]:
def tokenizing(tokenizer, max_length=128):
    """Build a batched preprocessing function suitable for ``Dataset.map``.

    Args:
        tokenizer: Tokenizer callable that accepts ``text_target=`` and exposes
            ``pad_token_id``.
        max_length: Length to which both source and target sequences are
            truncated and padded. Defaults to 128.

    Returns:
        A function ``preprocess(batch)`` that tokenizes ``batch["russian"]`` as
        the model input and ``batch["rsl"]`` as the target, and returns the
        tokenizer output with a ``"labels"`` entry in which padding positions
        are replaced by ``-100``.
    """
    def preprocess(batch):
        # `text_target=` is the supported replacement for the deprecated
        # `tokenizer.as_target_tokenizer()` context manager: the targets are
        # encoded in target mode and returned under the "labels" key.
        model_inputs = tokenizer(
            batch["russian"],
            text_target=batch["rsl"],
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )

        # Mask label padding with -100, the index ignored by PyTorch's
        # cross-entropy loss by default, so padding does not affect training.
        pad_id = tokenizer.pad_token_id
        model_inputs["labels"] = [
            [(tok if tok != pad_id else -100) for tok in sent]
            for sent in model_inputs["labels"]
        ]
        return model_inputs
    return preprocess

In [None]:
# Checkpoint to fine-tune. Other checkpoints kept here as commented-out options:
#   "facebook/bart-base"  (with BartForConditionalGeneration / BartTokenizer)
#   "google/mt5-small"
model_path = "hf-internal-testing/tiny-random-MarianMTModel"

# The tokenizer and model are loaded from the same checkpoint; the collator
# receives both of them.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

In [None]:
# Tokenize every split in batches and drop the original text columns so that
# only the tokenizer outputs remain.
preprocessing = tokenizing(tokenizer)
tokenized_dataset = dataset.map(
    preprocessing,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

In [None]:
# Training configuration.
# Evaluation has to be enabled explicitly: without an evaluation strategy,
# `eval_steps` is inert and `eval_dataset` is never used during training.
training_args = TrainingArguments(
    output_dir="marian-mt",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1000,
    learning_rate=5e-5,
    report_to="wandb",
    run_name="marian-mt-ru-rsl",
    eval_strategy="steps",
    eval_steps=500,
    # No `compute_metrics` function is passed to the Trainer below, so the
    # evaluation loss is the only metric available; a "bleu" key would never
    # be produced. Lower loss is better, hence greater_is_better=False.
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# The Trainer evaluates on the held-out "test" split every `eval_steps` steps.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)

In [None]:
# Run fine-tuning with the Trainer configured in the previous cell.
trainer.train()

In [None]:
# End the Weights & Biases run opened by `wandb.init` earlier in the notebook.
wandb.finish()

In [None]:
# Put the model in evaluation mode ahead of generating translations.
model.eval()

In [None]:
def translate_texts(texts, model, tokenizer, device="cpu", max_length=128, num_beams=5, batch_size=None):
    """Translate one or more texts with beam-search generation.

    Args:
        texts: A single string or a sequence of strings to translate.
        model: Model exposing ``to(device)`` and ``generate(...)``.
        tokenizer: Tokenizer used both to encode ``texts`` and to decode the
            generated ids via ``batch_decode``.
        device: Device the model and the encoded inputs are moved to.
        max_length: Maximum length passed to ``model.generate``.
        num_beams: Number of beams passed to ``model.generate``.
        batch_size: Number of texts per ``generate`` call. ``None`` (default)
            processes all texts in a single batch.

    Returns:
        A list with one decoded translation per input text, in input order.
        An empty input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is given and is smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError("batch_size must be a positive integer or None")

    # Normalise the input so slicing below works on a list of strings.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        # Nothing to translate; avoid calling the tokenizer with an empty batch.
        return []

    model.to(device)
    step = batch_size if batch_size is not None else len(texts)

    translations = []
    for start in range(0, len(texts), step):
        inputs = tokenizer(
            texts[start:start + step],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: gradients are not needed during generation.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3
            )

        translations.extend(
            tokenizer.batch_decode(outputs, skip_special_tokens=True)
        )
    return translations

In [None]:
# Translate every entry of the "russian" column from the CSV file and show the result.
df = pd.read_csv("giga_chat_dataset.csv")
data_for_translation = df["russian"].tolist()

translations = translate_texts(data_for_translation, model=model, tokenizer=tokenizer)
translations