<a href="https://colab.research.google.com/github/Arseniy-Polyakov/russian_sign_language_recognition/blob/main/Sign_Language_Recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch --upgrade --quiet

In [None]:
import re
import wandb
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    BartForConditionalGeneration,
    BartTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import load_dataset

In [None]:
# Start the Weights & Biases run that the training cell reports to.
wandb.init(
    # Enter project's name here (For instance: nllb-ru-rsl)
    name="",
    # Enter project's folder here (For instance: nllb-translation)
    project="",
)

In [None]:
# Load the two CSV files from the working directory as the "train" and
# "test" splits, then display the resulting object as the cell output.
dataset = load_dataset(
    "csv",
    data_files=dict(train="train.csv", test="test.csv"),
)
dataset

In [None]:
# Hugging Face Hub id (or local path) of the seq2seq checkpoint to fine-tune.
# For example: "facebook/bart-base", "google/mt5-small",
# "Helsinki-NLP/opus-mt-ru-en", "facebook/nllb-200-distilled-600M",
# "facebook/mbart-large-50".
model_path = ""

if not model_path:
    # Fail early with a readable message instead of an opaque loader error.
    raise ValueError("Set model_path to a seq2seq checkpoint before running this cell.")

# The Auto* classes pick the tokenizer/model classes that match the checkpoint,
# so every checkpoint listed above (including facebook/mbart-large-50) goes
# through the same code path. The former special case loaded mBART-50 with the
# BART-specific classes, which do not match that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Collator used by the Trainer to assemble batches for this tokenizer/model pair.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
# Language codes used by the tokenizer for the source and target texts.
# NOTE(review): the accepted codes depend on the checkpoint chosen in
# model_path (e.g. NLLB checkpoints use codes such as "rus_Cyrl", mBART-50
# uses "ru_RU") — confirm against the tokenizer's documentation.
tokenizer.src_lang = "" # Enter source language code here. For instance: rus_Cyrl (NLLB) or ru_RU (mBART-50)
tokenizer.tgt_lang = "" # Enter target language code here. For instance: rus_Cyrl (NLLB) or ru_RU (mBART-50)

In [None]:
def tokenizing(batch):
    """Tokenize one batch of Russian source texts and their RSL targets.

    Intended to be used with ``dataset.map(..., batched=True)``.

    Args:
        batch: Mapping with a "russian" entry (source texts) and an "rsl"
            entry (target texts), each a list of strings.

    Returns:
        The output of the notebook-level ``tokenizer`` for the batch, with
        the targets passed as ``text_target``. Sequences are truncated to
        128 tokens and are NOT padded.
    """
    # No padding here on purpose: padding="max_length" would also pad the
    # labels with the pad token id, and those positions would then count
    # towards the training loss. The DataCollatorForSeq2Seq passed to the
    # Trainer pads every batch dynamically instead.
    return tokenizer(
        batch["russian"],
        text_target=batch["rsl"],
        truncation=True,
        max_length=128,
    )

In [None]:
# Sanity check: show the source and target language codes, one per line.
print(tokenizer.src_lang, tokenizer.tgt_lang, sep="\n")

In [None]:
# Tokenize every split in batches and drop the original columns (taken from
# the train split) so that only the tokenizer outputs remain.
tokenized_dataset = dataset.map(
    tokenizing,
    remove_columns=dataset["train"].column_names,
    batched=True,
)

In [None]:
# Hyperparameters, output location and logging options for the fine-tuning run.
training_args = TrainingArguments(
    # --- output and experiment tracking ---
    output_dir="",  # Enter project's directory here
    run_name="",  # Enter project's name here
    report_to="wandb",
    # --- optimisation ---
    learning_rate=2e-5,
    num_train_epochs=250,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # --- evaluation / model selection ---
    # NOTE(review): this cell sets no evaluation strategy and passes no
    # compute_metrics to the Trainer, so eval_steps and the "bleu" metric
    # may never actually be used — confirm.
    eval_steps=100,
    metric_for_best_model="bleu",
    greater_is_better=True,
    # --- data handling ---
    remove_unused_columns=False,
)

# Trainer wiring: model, tokenizer, collator and the tokenized train/test splits.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
)

In [None]:
# Run the fine-tuning loop with the Trainer configured in the previous cell.
trainer.train()

In [None]:
# Close the Weights & Biases run that was opened with wandb.init().
wandb.finish()

In [None]:
# Switch the fine-tuned model to evaluation mode before running inference.
model.eval()

In [None]:
def translate_texts(texts, model, tokenizer, device="cpu", max_length=128, num_beams=5, batch_size=32):
    """Translate texts with a seq2seq model using beam search.

    The inputs are processed in chunks of ``batch_size`` so that a long list
    (for example a whole CSV column) is never tokenized and generated as one
    single batch.

    Args:
        texts: A string or an iterable of strings to translate.
        model: Model exposing ``to(device)`` and ``generate(...)``. The
            function does not change its train/eval mode.
        tokenizer: Tokenizer matching ``model``; it is called on each chunk
            and its ``batch_decode`` turns generated ids back into text.
        device: Device the model and the input tensors are moved to.
        max_length: ``max_length`` passed to ``model.generate``.
        num_beams: Number of beams passed to ``model.generate``.
        batch_size: Number of texts per generation call; ``None`` processes
            everything in one batch.

    Returns:
        A list with one decoded translation per input text, in input order.
        An empty input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Normalise the input: a single string is treated as a one-element batch.
    texts = [texts] if isinstance(texts, str) else list(texts)
    if not texts:
        # Nothing to translate; the tokenizer cannot handle an empty batch.
        return []

    if batch_size is None:
        batch_size = len(texts)
    elif batch_size < 1:
        raise ValueError("batch_size must be a positive integer or None")

    model.to(device)

    translations = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        inputs = tokenizer(
            chunk,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        # The tensors must live on the same device as the model.
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=3
            )

        translations.extend(
            tokenizer.batch_decode(outputs, skip_special_tokens=True)
        )
    return translations

In [None]:
# Load the sentences to translate from the "russian" column of the CSV file.
df = pd.read_csv("giga_chat_dataset.csv")
data_for_translation = df["russian"].to_list()

# Generate on the GPU when one is available instead of always moving the
# fine-tuned model to the CPU (the default device of translate_texts).
translations = translate_texts(
    data_for_translation,
    model,
    tokenizer,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Show a short preview only; the full result stays available in `translations`.
translations[:10]