In [None]:
import pathlib
import random

import numpy as np
import torch

from loguru import logger

from language_detection.data import transform_text, get_mask_from_lengths
from language_detection.model import TrainingConfig, TransformerClassifier

In [None]:
# NOTEBOOK ONLY
from dataclasses import dataclass


@dataclass
class Args:
    """Hard-coded settings used when running this notebook.

    NOTE(review): presumably mirrors the arguments a script version of this
    notebook would receive -- confirm against the non-notebook entry point.
    """

    # checkpoint the detector is restored from
    checkpoint_file: str = "experiments/wili2018/wili2018-checkpoint-000020.pt"
    # not read anywhere in the visible cells
    debug: bool = False


args = Args()

In [None]:
class LanguageDetector:
    """Predicts the language of a single text using a trained checkpoint.

    The checkpoint is expected to provide the keys read by this class:
    ``"config"``, ``"num_classes"``, ``"model_state_dict"`` and
    ``"output_mapping"``; ``"extended_labels"`` is optional.
    """

    def __init__(self, checkpoint_filepath: str):
        """Restore checkpoint, config and model, then seed the RNGs.

        Args:
            checkpoint_filepath: path of the checkpoint file to restore.

        Raises:
            ValueError: if the file does not exist or the checkpoint has no
                ``"output_mapping"`` entry.
        """
        # f-string: without it the placeholder was logged literally
        logger.info(f"initializing detector from checkpoint '{checkpoint_filepath}'...")
        self.checkpoint = self.load_checkpoint(checkpoint_filepath)
        if "output_mapping" not in self.checkpoint:
            raise ValueError("checkpoint file is missing 'output_mapping'!")
        self.config = self.load_config()
        self.model = self.load_model()
        # the first parameter tells us where load_model() placed the network
        self.device = str(next(self.model.parameters()).device)
        random.seed(self.config.seed)
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)
        logger.info(f"done! (using device '{self.device}')")

    def load_checkpoint(self, checkpoint_path: str):
        """Read the checkpoint file and return the object stored in it.

        Tensors are mapped to the CPU while loading so that a checkpoint
        written on a GPU machine can also be restored on a CPU-only one;
        ``load_model`` moves the network to the final device afterwards.

        Raises:
            ValueError: if ``checkpoint_path`` is not an existing file.
        """
        if not pathlib.Path(checkpoint_path).is_file():
            raise ValueError(f"checkpoint file '{checkpoint_path}' does not exist!")
        # NOTE(security): torch.load unpickles the file -- only load
        # checkpoints that come from a trusted source.
        return torch.load(checkpoint_path, map_location="cpu")

    def load_config(self) -> TrainingConfig:
        """Build the ``TrainingConfig`` from the checkpoint's ``"config"`` entry.

        Raises:
            RuntimeError: if no checkpoint has been loaded yet.
        """
        # getattr: the attribute does not exist before load_checkpoint() ran
        if not getattr(self, "checkpoint", None):
            raise RuntimeError("no checkpoint loaded, load checkpoint first!")
        return TrainingConfig(**self.checkpoint["config"])

    def load_model(self) -> TransformerClassifier:
        """Instantiate the classifier, restore its weights and set eval mode.

        The model is moved to ``"cuda"`` when available, otherwise ``"cpu"``.

        Raises:
            RuntimeError: if no config has been loaded yet.
        """
        if not getattr(self, "config", None):
            raise RuntimeError("no config loaded, load config first!")
        model = TransformerClassifier(num_classes=self.checkpoint["num_classes"])
        model.load_state_dict(self.checkpoint["model_state_dict"])
        device_string = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device_string)
        model.eval()
        return model

    def create_input_tensor(self, test_text: str) -> tuple[torch.Tensor, torch.Tensor]:
        """create model inputs for single sample

        Returns:
            ``(x_input, pad_mask)``: the transformed text reshaped to
            ``(1, max_length)`` on the model's device, and the mask produced
            by ``get_mask_from_lengths`` for that single sequence length.
        """
        # only the input ids and the sequence length are needed for inference
        x_input, _y_output, seq_len, _idxs = transform_text(
            text=test_text,
            is_training=False,
            max_length=self.config.max_length,
        )
        x_input = x_input.reshape(1, self.config.max_length).to(self.device)
        seq_lens = torch.tensor([seq_len])
        pad_mask = get_mask_from_lengths(seq_lens, self.config.max_length, self.device)
        return x_input, pad_mask

    def predict(self, test_text: str) -> str:
        """Return the predicted language for ``test_text``.

        The highest-scoring class index is looked up in the checkpoint's
        ``"output_mapping"``. If the checkpoint also has ``"extended_labels"``,
        the entry for that language code is returned instead of the code.
        """
        x_input, x_mask = self.create_input_tensor(test_text)
        # inference only: no autograd graph needed
        with torch.no_grad():
            clf_logits, _mlm_logits = self.model(x_input, x_mask)
        # plain int so the lookup works for list- and dict-style mappings
        pred_index = int(clf_logits.argmax(dim=1)[0])
        lang_code = self.checkpoint["output_mapping"][pred_index]
        if "extended_labels" in self.checkpoint:
            return self.checkpoint["extended_labels"][lang_code]
        return lang_code

In [None]:
# Build the detector from the checkpoint configured in `args`.
lang_detector = LanguageDetector(checkpoint_filepath=args.checkpoint_file)

In [None]:
# Dutch sample sentence
test_string = "Sportief succes kost veel geld in de Formule 1. Red Bull, het team van wereldkampioen Max Verstappen, moet volgend jaar meer dan ooit betalen om te mogen deelnemen aan de koningsklasse"
print(lang_detector.predict(test_text=test_string))

In [None]:
# German sample sentence
test_string = "Außenministerin Baerbock formuliert drei deutsche Ziele für die COP28: Mehr Tempo, mehr Solidarität, mehr Partnerschaft."
print(lang_detector.predict(test_text=test_string))

In [None]:
# Korean sample text (with some embedded Latin-script words)
test_string = "제가 이 채널에서는 어디에 썼는지 지금 당장 찾지를 못했는데, PR 만들 때 format에도 써놨습니다. 예를 들면 PR만들려고 열어 보시면 자동으로 이런 문구가 떠요(사람들이 안 볼까봐 제가 일부러 넣어뒀습니다"
print(lang_detector.predict(test_text=test_string))

In [None]:
# Japanese sample text
test_string = "日本語のニュースや番組をテレビとラジオで海外向けに放送しています。海外にお住まいの方、旅行中の方にも情報をお届けします。"
print(lang_detector.predict(test_text=test_string))

In [None]:
# Arabic sample text
test_string = "ي بي سي العربية هي شبكة لنقل الأخبار والمعلومات ومقاطع الفيديو إلى العالم عبر عدة وسائط، تشمل الإنترنت ومواقع التواصل الاجتماعي والراديو والتلفزيون ..."
print(lang_detector.predict(test_text=test_string))