# English to ASL (SignWriting) and Visualization

### Setting Up

In [None]:
!git clone https://github.com/sign-language-processing/signbank-plus
# !cd signbank-plus/signbank_plus/nmt/ && git clone https://github.com/J22Melody/signwriting-translation
%pip install -r signbank-plus/requirements.txt
%pip install subword-nmt OpenNMT-py==1.2.0 install torch torchvision torchaudio

### Importing

In [None]:
import csv
import gzip
import itertools
import random
from collections import defaultdict
from pathlib import Path

from tqdm import tqdm
import importlib
module_name = 'signbank-plus.signbank_plus.load_data'
load_data_fol = importlib.import_module(module_name)
load_data = load_data_fol.load_data
load_file = load_data_fol.load_file
from signwriting.tokenizer import SignWritingTokenizer

from shutil import copy2
import subprocess
from os.path import exists, join, isfile
from os import listdir
import re


def run_command(command):
    """Run ``command`` through the shell and stream its output to stdout.

    stderr is merged into stdout and echoed line by line as it arrives, so
    long-running tools (BPE, preprocessing, training) show progress live.

    Args:
        command: A shell command line. It is executed with ``shell=True``
            because callers rely on shell features (``<``/``>`` redirection,
            ``$(...)``), so only pass trusted strings.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    print("\n" + command + "\n")
    # Popen as a context manager closes the stdout pipe and waits for the
    # process when the block exits, even if an exception interrupts the loop.
    with subprocess.Popen(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    ) as process:
        for line in process.stdout:
            print(line, end='')
    # __exit__ has already waited for the process, so returncode is set here.
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, command)

### Pre-Processing

In [None]:
# Every flag token seen while writing the parallel corpora: the "$"-prefixed
# spoken/sign language values of each instance plus any raw `extra_flags`.
# Filled by save_parallel_csv and printed at the end of pre-processing.
ALL_FLAGS: set = set()

def get_source_target(data, field="annotated_texts"):
    """Yield (SignWriting source, spoken-language target) training pairs.

    Instances are visited in a fixed pseudo-random order (seed 42), so any
    split taken from the front of the stream is reproducible. For every
    instance that has ``field``, one pair is yielded per text whose stripped
    value is non-empty, provided the instance's stripped ``sign_writing`` is
    non-empty too.

    Args:
        data: Iterable of instance dicts with ``sign_writing``,
            ``spoken_language`` and ``sign_language`` keys, and optionally
            ``puddle_id``, ``example_id`` and ``field``. It is copied before
            shuffling, so the caller's sequence is left untouched.
        field: Key of the list of spoken-language texts used as targets.

    Yields:
        Dicts with ``puddle_id`` / ``example_id`` (None when absent),
        ``flags`` (``[spoken_language, sign_language]``), ``source`` (stripped
        SignWriting) and ``target`` (stripped text).
    """
    # Shuffle a copy: same seed and same length give the same order as an
    # in-place shuffle, without reordering the caller's list.
    records = list(data)
    random.Random(42).shuffle(records)
    for instance in records:
        if field not in instance:
            continue
        for text in instance[field]:
            target = text.strip()
            if not target:
                continue
            source = instance["sign_writing"].strip()
            if not source:
                continue
            yield {
                "puddle_id": instance.get("puddle_id"),
                "example_id": instance.get("example_id"),
                "flags": [instance["spoken_language"], instance["sign_language"]],
                "source": source,
                "target": target,
            }


def get_source_target_no_test(data, field="annotated_texts"):
    """Like ``get_source_target``, but drop pairs that belong to the benchmark.

    A pair is excluded when its ``(puddle_id, example_id)`` key matches an
    instance of the ``"benchmark"`` dataset. Because this is a generator, the
    benchmark is loaded on the first iteration, not when the function is called.
    """
    benchmark_keys = set()
    for bench in load_data("benchmark"):
        benchmark_keys.add((bench['puddle_id'], bench['example_id']))
    yield from (
        pair
        for pair in get_source_target(data, field)
        if (pair['puddle_id'], pair['example_id']) not in benchmark_keys
    )


# Model 1: Original data
def get_original_data():
    """Yield non-benchmark pairs from the ``raw`` dataset, using its ``texts`` field as targets."""
    yield from get_source_target_no_test(load_data("raw"), field="texts")


# Model 2: Cleaned data
def get_cleaned_data():
    """Yield non-benchmark pairs from the raw, gpt-3.5-cleaned, manually-cleaned
    and bible datasets, using ``annotated_texts`` as targets."""
    datasets = ("raw", "gpt-3.5-cleaned", "manually-cleaned", "bible")
    yield from get_source_target_no_test(load_data(*datasets), field="annotated_texts")


# Model 3: Expanded data
def get_expanded_data():
    """Yield non-benchmark pairs from the raw, gpt-3.5-cleaned, gpt-3.5-expanded,
    manually-cleaned and bible datasets, using ``annotated_texts`` as targets."""
    datasets = ("raw", "gpt-3.5-cleaned", "gpt-3.5-expanded", "manually-cleaned", "bible")
    yield from get_source_target_no_test(load_data(*datasets), field="annotated_texts")


def get_expanded_data_en():
    """Yield non-benchmark pairs from the ``gpt-3.5-expanded.en`` dataset, using ``annotated_texts`` as targets."""
    expanded_en = load_data("gpt-3.5-expanded.en")
    yield from get_source_target_no_test(expanded_en, field="annotated_texts")


def test_set():
    """Yield benchmark pairs, one per ``gold_texts`` reference.

    Nothing is filtered out here: these are the pairs written to the test
    split that the evaluation cell scores against.
    """
    benchmark = load_file("benchmark", array_fields=["gold_texts"])
    yield from get_source_target(benchmark, field="gold_texts")


def save_parallel_csv(path: Path, data: iter, split="train", extra_flags=()):
    """Write one split of a parallel SignWriting -> spoken-language corpus.

    For every instance whose target is 1-511 characters and whose source is
    1-1023 characters, aligned lines are written to these files in ``path``:

    - ``{split}.source``: flag tokens followed by the raw SignWriting source.
    - ``{split}.source.tokenized``: flag tokens followed by the
      ``SignWritingTokenizer`` tokens.
    - ``{split}.target``: the target text.
    - ``{split}.csv``: ``source``/``target`` columns, with a header.
    - ``{split}.spoken.gz`` / ``{split}.signed.gz``: extra flags and flag
      tokens, followed by the target or the detokenized source.

    Args:
        path: Output directory. It must already exist.
        data: Iterable of dicts with ``flags``, ``source`` and ``target`` keys,
            as produced by ``get_source_target``.
        split: File-name prefix, e.g. ``"train"``, ``"dev"`` or ``"all"``.
        extra_flags: Additional flags prepended to the gzip outputs.

    Side effects:
        Adds every flag seen (the raw ``extra_flags`` and the ``$``-prefixed
        instance flags) to the module-level ``ALL_FLAGS`` set.
    """
    from contextlib import ExitStack

    path = Path(path)
    ALL_FLAGS.update(extra_flags)
    tokenizer = SignWritingTokenizer()
    # NOTE: with no extra flags this yields a leading space on every gzip line;
    # kept as-is because downstream consumers of the .gz files may expect it.
    extra_prefix = " ".join(extra_flags)

    # ExitStack closes every output (and writes the gzip trailers) even if
    # tokenization fails part-way through the data.
    with ExitStack() as stack:
        f_source = stack.enter_context(open(path / f"{split}.source", "w", encoding="utf-8"))
        f_source_tokenized = stack.enter_context(
            open(path / f"{split}.source.tokenized", "w", encoding="utf-8"))
        f_target = stack.enter_context(open(path / f"{split}.target", "w", encoding="utf-8"))
        # The csv module requires newline="" so it controls line endings itself.
        f_csv = stack.enter_context(open(path / f"{split}.csv", "w", encoding="utf-8", newline=""))
        f_spoken_gzip = stack.enter_context(gzip.open(path / f"{split}.spoken.gz", "wt", encoding="utf-8"))
        f_signed_gzip = stack.enter_context(gzip.open(path / f"{split}.signed.gz", "wt", encoding="utf-8"))

        writer = csv.DictWriter(f_csv, fieldnames=["source", "target"])
        writer.writeheader()
        for instance in tqdm(data):
            sign_writing = instance["source"]
            target = instance["target"]
            if not (0 < len(target) < 512 and 0 < len(sign_writing) < 1024):
                continue

            flag_tokens = [f"${flag}" for flag in instance["flags"]]
            ALL_FLAGS.update(flag_tokens)
            flags = " ".join(flag_tokens)

            source = flags + " " + sign_writing
            writer.writerow({"source": source, "target": target})
            f_source.write(source + "\n")
            f_target.write(target + "\n")

            tokens_source = list(tokenizer.text_to_tokens(sign_writing))
            f_source_tokenized.write(flags + " " + " ".join(tokens_source) + "\n")

            gzip_flags = extra_prefix + " " + flags
            # We detokenize the SignWriting, which removes "A" prefixes, and box placement
            detokenized_source = tokenizer.tokens_to_text(tokens_source)
            f_spoken_gzip.write(gzip_flags + " " + target + "\n")
            f_signed_gzip.write(gzip_flags + " " + detokenized_source + "\n")


def save_splits(path: Path, data: iter, extra_flags=(), dev_num=3000):
    """Write a ``dev`` split (the first ``dev_num`` pairs) and a ``train`` split (the rest).

    Args:
        path: Output directory. It is created if missing.
        data: Iterable of pairs as produced by ``get_source_target``.
        extra_flags: Sequence of extra flag strings forwarded to ``save_parallel_csv``.
        dev_num: Number of leading pairs reserved for dev. When it is 0 or
            negative, no dev split is written and everything goes to train.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    # islice only consumes from a real iterator. Wrapping in iter() makes lists
    # behave like generators, so dev pairs are never repeated in train.
    pairs = iter(data)
    if dev_num > 0:
        save_parallel_csv(path, itertools.islice(pairs, dev_num), split="dev", extra_flags=extra_flags)
    save_parallel_csv(path, pairs, split="train", extra_flags=extra_flags)


def save_test(path: Path, data: iter):
    """Write the test set, grouping the references of each unique source.

    The ``all.*`` files are written via ``save_parallel_csv``. Then, in
    ``path``, this writes:

    - ``test.source.unique`` / ``test.source.unique.tokenized``: one line per
      distinct tokenized source, in first-seen order.
    - ``test.target.{i}``: the i-th reference of each source, or an empty line
      when that source has fewer than ``i + 1`` references.

    Args:
        path: Output directory. It is created if missing.
        data: Iterable of pairs as produced by ``get_source_target``.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    save_parallel_csv(path, data, split="all")

    def read_lines(name):
        # Read a file written above, one stripped entry per line.
        with open(path / name, "r", encoding="utf-8") as f:
            return [line.strip() for line in f]

    source_lines = read_lines("all.source")
    source_lines_tokenized = read_lines("all.source.tokenized")
    target_lines = read_lines("all.target")

    # Tokenized source -> untokenized source (a later duplicate overwrites an earlier one).
    source_map = dict(zip(source_lines_tokenized, source_lines))

    source_target_map = defaultdict(list)
    for source, target in zip(source_lines_tokenized, target_lines):
        source_target_map[source].append(target)

    # default=0 keeps an empty test set from crashing max().
    max_references = max((len(refs) for refs in source_target_map.values()), default=0)
    print(f"Max test references: {max_references}")

    with open(path / "test.source.unique", "w", encoding="utf-8") as f_plain:
        with open(path / "test.source.unique.tokenized", "w", encoding="utf-8") as f_tokenized:
            for source in source_target_map:
                f_plain.write(source_map[source] + "\n")
                f_tokenized.write(source + "\n")

    for i in range(max_references):
        with open(path / f"test.target.{i}", "w", encoding="utf-8") as f:
            for references in source_target_map.values():
                f.write((references[i] if i < len(references) else "") + "\n")


# Set to False to skip regenerating the parallel corpora when re-running the notebook.
RUN_PREPROCESSING = True

if RUN_PREPROCESSING:
    parallel_path = Path("signbank-plus/data/parallel")

    save_test(parallel_path / "test", test_set())

    save_splits(parallel_path / "original", get_original_data())
    save_splits(parallel_path / "cleaned", get_cleaned_data())
    save_splits(parallel_path / "expanded", itertools.chain(
        get_expanded_data(),
        get_expanded_data_en(),
    ))

    # Additional datasets: train split only (dev_num=0), and read via
    # get_source_target directly, so not filtered against the benchmark.
    save_splits(parallel_path / "more", itertools.chain(
        get_source_target(load_data("sign2mint"), field="texts"),
        get_source_target(load_data("signsuisse"), field="texts"),
        get_source_target(load_data("fingerspelling"), field="texts"),
    ), dev_num=0)

    # Sorted so the printed list is stable across runs (set order depends on hash seeding).
    print("\n" + ",".join(sorted(ALL_FLAGS)))

### Training

In [None]:
def find_latest_checkpoint(checkpoint_dir):
    """Return the checkpoint path to resume training from, or None if there is none.

    Only files named exactly ``checkpoint_step_<N>.pt`` are considered. They are
    ordered by the integer step ``N``. When more than one exists, the
    second-newest is returned rather than the newest. NOTE(review): presumably
    this avoids resuming from a checkpoint that may have been written only
    partially when training was interrupted; confirm the intent.

    Args:
        checkpoint_dir: Directory containing the checkpoints.

    Returns:
        ``join(checkpoint_dir, <file name>)`` or None.
    """
    pattern = re.compile(r'checkpoint_step_(\d+)\.pt')
    checkpoints = []
    for name in listdir(checkpoint_dir):
        match = pattern.fullmatch(name)
        if match:
            checkpoints.append((int(match.group(1)), name))
    if not checkpoints:
        return None
    # Sort by the numeric step: a plain string sort puts step 10000 before 9500.
    checkpoints.sort()
    _, chosen = checkpoints[-2] if len(checkpoints) > 1 else checkpoints[-1]
    return join(checkpoint_dir, chosen)


def train(_1, _2, _3=None):
    """Prepare data for, and train, an OpenNMT-py Transformer model.

    Every preparation step is skipped when its output already exists. When
    checkpoints exist, training resumes from the one chosen by
    ``find_latest_checkpoint``, so re-running the cell continues where it stopped.

    Args:
        _1: Corpus directory with ``train``/``dev`` ``.target`` and
            ``.source.tokenized`` files, as written by ``save_splits``.
        _2: Working directory. It receives ``bpe.codes.target``, the BPE-applied
            data under ``data/``, the preprocessed ``processed.*`` files, and
            the model checkpoints under ``model/``.
        _3: Optional working directory of an earlier run. Its
            ``bpe.codes.target`` is copied in instead of learning new codes.

    Raises:
        subprocess.CalledProcessError: if any external tool fails.
    """
    import shlex

    data_dir, work_dir = _1, _2
    q = shlex.quote  # run_command uses shell=True, so every path is quoted

    # parents=True also creates work_dir and its parents.
    (Path(work_dir) / "data").mkdir(parents=True, exist_ok=True)
    (Path(work_dir) / "model").mkdir(parents=True, exist_ok=True)

    bpe_codes = f"{work_dir}/bpe.codes.target"
    if _3 is not None:
        copy2(f"{_3}/bpe.codes.target", bpe_codes)

    # Target BPE: learn codes on the training targets, then apply them to train/dev.
    if not exists(bpe_codes):
        train_target = f"{data_dir}/train.target"
        run_command(f"subword-nmt learn-bpe -s 3000 < {q(train_target)} > {q(bpe_codes)}")
    for split in ("train", "dev"):
        raw_target = f"{data_dir}/{split}.target"
        bpe_target = f"{work_dir}/data/{split}.target"
        if not exists(bpe_target):
            run_command(f"subword-nmt apply-bpe -c {q(bpe_codes)} < {q(raw_target)} > {q(bpe_target)}")

    # Source side was already tokenized by SignWritingTokenizer; copy it as-is.
    for split in ("train", "dev"):
        source_copy = f"{work_dir}/data/{split}.source"
        if not exists(source_copy):
            copy2(f"{data_dir}/{split}.source.tokenized", source_copy)

    if not exists(f"{work_dir}/processed.vocab.pt"):
        run_command(" ".join([
            "onmt_preprocess",
            "--save_data", q(f"{work_dir}/processed"),
            "--shard_size", "2000000",
            "--train_src", q(f"{work_dir}/data/train.source"),
            "--train_tgt", q(f"{work_dir}/data/train.target"),
            "--valid_src", q(f"{work_dir}/data/dev.source"),
            "--valid_tgt", q(f"{work_dir}/data/dev.target"),
            "--src_seq_length", "512",
            "--tgt_seq_length", "512",
            "--log_file_level", "DEBUG",
        ]))

    train_args = [
        "onmt_train",
        "--data", q(f"{work_dir}/processed"),
        "--save_model", q(f"{work_dir}/model/checkpoint"),
    ]
    latest_checkpoint = find_latest_checkpoint(f"{work_dir}/model/")
    if latest_checkpoint:
        train_args += ["--train_from", q(latest_checkpoint)]
    train_args += [
        "--layers", "2", "--rnn_size", "512", "--word_vec_size", "512", "--heads", "8",
        "--encoder_type", "transformer", "--decoder_type", "transformer",
        "--position_encoding", "--transformer_ff", "2048", "--dropout", "0.1",
        "--early_stopping", "10", "--early_stopping_criteria", "accuracy", "ppl",
        "--batch_size", "2048", "--accum_count", "3", "--batch_type", "tokens",
        "--max_generator_batches", "2", "--normalization", "tokens",
        "--optim", "adam", "--adam_beta2", "0.998", "--decay_method", "noam",
        "--warmup_steps", "3000", "--learning_rate", "0.5", "--max_grad_norm", "0",
        "--param_init", "0", "--param_init_glorot", "--label_smoothing", "0.1",
        # Validation and checkpoint saving both happen every 500 steps.
        "--valid_steps", "500", "--save_checkpoint_steps", "500",
        "--world_size", "1", "--gpu_ranks", "0",
    ]
    run_command(" ".join(train_args))
    

# Train one model per corpus variant written by the pre-processing cell.
# Each call skips preparation steps whose outputs already exist and resumes
# from an existing checkpoint, so re-running this cell continues training.
train("signbank-plus/data/parallel/original", "opennmt/original")
train("signbank-plus/data/parallel/cleaned", "opennmt/cleaned")
train("signbank-plus/data/parallel/expanded", "opennmt/expanded")

### Evaluation

In [None]:
def find_best_model(_2, patience=10):
    """Return the file name of the checkpoint assumed to be the best, or None.

    Training saves a checkpoint at every validation and stops early after
    ``--early_stopping 10`` validations without improvement. Under that
    assumption, the best checkpoint is the one ``patience`` checkpoints before
    the last, i.e. index ``-(patience + 1)`` of the checkpoints ordered by step.
    NOTE(review): this is a heuristic. It does not read validation scores, and
    it picks the wrong checkpoint if training stopped for any other reason.

    Args:
        _2: Working directory containing a ``model/`` sub-directory with
            ``checkpoint_step_<N>.pt`` files.
        patience: Early-stopping patience that was used for training.

    Returns:
        The checkpoint's file name (not a full path), or None when fewer than
        ``patience + 1`` checkpoints exist.
    """
    pattern = re.compile(r'checkpoint_step_(\d+)\.pt')
    checkpoints = []
    for name in listdir(join(_2, "model")):
        # Parse the step explicitly, so the ".pt" suffix and unrelated files
        # never reach int().
        match = pattern.fullmatch(name)
        if match:
            checkpoints.append((int(match.group(1)), name))
    checkpoints.sort()
    offset = patience + 1
    return checkpoints[-offset][1] if len(checkpoints) >= offset else None


def evall(_1, _2):
    """Translate the unique test sources with the selected checkpoint and score them.

    Args:
        _1: Test-set directory containing ``test.source.unique.tokenized`` and
            the ``test.target*`` reference files.
        _2: Working directory of a trained model, as passed to ``train``.

    The following files are written into ``_2``:

    - ``test.translations.bpe``: model output. It is reused if already present.
    - ``test.translations``: the same output with BPE merged back.
    - ``sacrebleu.txt``: BLEU and chrF against every reference file found under ``_1``.

    When no checkpoint is selected, an error is printed and nothing is written.
    """
    import shlex

    best_model = find_best_model(_2)

    if best_model is None:
        print("Error: Unable to find the best model.")
        return

    q = shlex.quote  # run_command uses shell=True, so every path is quoted

    # Translating (skipped when a previous run already produced the output)
    translations_bpe_path = join(_2, "test.translations.bpe")
    if not isfile(translations_bpe_path):
        run_command(" ".join([
            "onmt_translate",
            "--model", q(join(_2, "model", best_model)),
            "--src", q(join(_1, "test.source.unique.tokenized")),
            "--output", q(translations_bpe_path),
            "--gpu", "0", "--replace_unk", "--beam_size", "5",
        ]))

    # Removing BPE: drop the "@@ " joiners, then any remaining "@@"
    translations_path = join(_2, "test.translations")
    with open(translations_bpe_path, "r", encoding="utf-8") as f_in, \
            open(translations_path, "w", encoding="utf-8") as f_out:
        for line in f_in:
            f_out.write(line.replace("@@ ", "").replace("@@", ""))

    # Computing BLEU and chrF. Every test.target* file below _1 is one reference
    # set; the paths are collected here and quoted individually.
    reference_files = sorted(str(p) for p in Path(_1).rglob("test.target*") if p.is_file())
    references = " ".join(q(ref) for ref in reference_files)
    sacrebleu_output_path = join(_2, "sacrebleu.txt")
    run_command(
        f"sacrebleu {references} -i {q(translations_path)} "
        f"-m bleu chrf --width 2 > {q(sacrebleu_output_path)}"
    )
   
    
# Score each trained model against the same benchmark test split.
evall("signbank-plus/data/parallel/test", "opennmt/original")
evall("signbank-plus/data/parallel/test", "opennmt/cleaned")
evall("signbank-plus/data/parallel/test", "opennmt/expanded")