In [1]:
%pip install transformers
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

Note: you may need to restart the kernel to use updated packages.
Looking in indexes: https://download.pytorch.org/whl/cu128
Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
from transformers import MarianTokenizer, MarianMTModel, AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, AutoConfig
from transformers.generation import BeamSearchScorer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

def augment_tokenize(text, terms, tokenizer, device="cpu"):
    vocab = tokenizer.get_vocab()
    text_tokenized = tokenizer(text).input_ids
    for term_source,term_target in terms:
        term_source_tokenized = tokenizer(term_source).input_ids[:-1]
        term_target_tokenized = list(tokenizer(text_target=term_target).input_ids)[:-1]
        term_target_tokenized = [vocab["augmentsymbol1"]] + \
            term_target_tokenized + \
            [vocab["augmentsymbol2"]]
            
        current_aug_part_index = 0
        new_text_tokenized = []
        for token in text_tokenized:
            #TODO: add check for the word continuing
            if current_aug_part_index == len(term_source_tokenized):
                new_text_tokenized += [vocab["augmentsymbol0"]] + term_source_tokenized + \
                term_target_tokenized
                current_aug_part_index = 0
            if token == term_source_tokenized[current_aug_part_index]:
                current_aug_part_index += 1
            elif current_aug_part_index > 1:
                new_text_tokenized += term_source_tokenized[0:current_aug_part_index]
                new_text_tokenized.append(token)
                current_aug_part_index = 0
            else:
                new_text_tokenized.append(token)
        text_tokenized = new_text_tokenized

    input_ids = torch.tensor([text_tokenized], device=device)  # batch dimension added
    attention_mask = torch.ones_like(input_ids, device=device)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }

In [4]:
def load_model(model_name_or_path, device):
    config = AutoConfig.from_pretrained(model_name_or_path)

    # Check model type by architecture string
    if "Marian" in config.model_type or config.architectures and any("Marian" in arch for arch in config.architectures):
        model = MarianMTModel.from_pretrained(model_name_or_path).to(device).eval()
        model_type = "marian"
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16).to(device).eval()
        model_type = "causal_lm"

    return model, model_type

In [5]:
def load_tokenizer(model_name_or_path):
    config = AutoConfig.from_pretrained(model_name_or_path)

    # Check model type by architecture string
    if "Marian" in config.model_type or config.architectures and any("Marian" in arch for arch in config.architectures):
        tokenizer = MarianTokenizer.from_pretrained(model_name_or_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    return tokenizer

In [6]:
#TODO: Add batching for beam search, so that different beams are generated as batches.
# Also add batching for inputs, so that sets of sentences can be processed as a batch.
# Add past_key_ids to speed up LLM inference. 

#TODO: The vocab fixed model is being trained on Puhti, test that with LLM. Try training on puhti
# without --tsv (that caused some errors before), to see if the logical epoch weirdness (epoch being just 1M
# sentences) has an effect (is it maybe only training on the first million sentence pairs over and over
# again?).

#TODO: Start working with the pipeline, fine-tuning RAT models with the same vocab as the term model. Also
# do term models with single term only, to test ensembling performance.

class ShallowFusion():
    def __init__(self, models_info):
        self.models_info = models_info
        # === Load tokenizer and models ===
        # The tokenizer of the first model is used
        self.tokenizer = load_tokenizer(models_info[0]["name"])
        
        # tokenizer = AutoTokenizer.from_pretrained("LumiOpen/Viking-7B")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.init_models()

    def init_models(self):
        self.models = []        
        for info in self.models_info:
            model, model_type = load_model(info["name"], self.device)
            self.models.append((model,model_type))

    def initialize_inputs(self, src_sentences_per_model):
        self.encoder_input_ids_list = []
        self.attention_masks_list = []
        self.decoder_input_ids_list = []

        num_beams = 4
        self.eos_token_id = self.tokenizer.eos_token_id
        self.pad_token_id = self.tokenizer.pad_token_id  # Marian uses pad token to start decoding
        self.start_token_id = self.tokenizer.pad_token_id  # Marian uses pad token to start decoding

        for (sentence,terms),(model,model_type) in zip(src_sentences_per_model,self.models):
            # Tokenize and prepare inputs
            if terms:
                enc_inputs = augment_tokenize(sentence, terms, self.tokenizer, self.device)
            else:
                enc_inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
                
            encoder_input_ids = enc_inputs["input_ids"].expand(num_beams, -1).clone()
            attention_mask = enc_inputs["attention_mask"].expand(num_beams, -1).clone()

            self.encoder_input_ids_list.append(encoder_input_ids)
            self.attention_masks_list.append(attention_mask)

            if model_type == "marian":
                # Init decoder_input_ids with a start token
                decoder_input_ids = torch.full(
                    (num_beams, 1),
                    fill_value=self.start_token_id,
                    dtype=torch.long,
                    device=self.device,
                )
            else:
                # LLMs keep track of decoder input ids, but concat them with the input for generation
                decoder_input_ids = torch.empty((num_beams, 0), dtype=torch.long, device=self.device)
            self.decoder_input_ids_list.append(decoder_input_ids)
    
    def translate(self, src_sentences_per_model, num_beams=4, max_length=50):
        self.initialize_inputs(src_sentences_per_model)
        
        # === Beam search setup ===
        beam_scorer = BeamSearchScorer(
            batch_size=1,
            num_beams=num_beams,
            device=self.device,
            length_penalty=1.0,
            do_early_stopping=True,
            num_beam_hyps_to_keep=num_beams,
        )

        beam_scores = torch.zeros((num_beams,), dtype=torch.float, device=self.device)
        beam_scores[1:] = -1e9  # Only first beam is active at the beginning

        cur_len = 1  # decoder_input_ids starts with 1 token

        # === Step-by-step beam search loop ===
        while cur_len < max_length:
            all_log_probs = []

            # === Each model provides logits from its own source + decoder input ===
            for (model,model_type), encoder_input_ids, attention_mask, decoder_input_ids in zip(
                self.models, self.encoder_input_ids_list, self.attention_masks_list, self.decoder_input_ids_list
            ):
                with torch.no_grad():
                    if model_type == "marian":
                        outputs = model(
                            input_ids=encoder_input_ids,
                            attention_mask=attention_mask,
                            decoder_input_ids=decoder_input_ids,
                        )
                        logits = outputs.logits[:, -1, :]  # (num_beams, vocab_size)
                        logits[:, [self.pad_token_id]] = float("-inf")
                    else:
                        input_ids = torch.cat([encoder_input_ids, decoder_input_ids], dim=-1)
                        outputs = model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                        )
                        logits = outputs.logits[:, -1, :]  # (num_beams, vocab_size)

                    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                    all_log_probs.append(log_probs)

            # === Shallow fusion: average log_probs ===
            avg_log_probs = torch.stack(all_log_probs).mean(dim=0)

            # === Beam step ===
            next_beam_scores, next_tokens = torch.topk(avg_log_probs, 2, dim=1)
            next_beam_scores += beam_scores[:, None]

            next_beam_scores = next_beam_scores.view(1, -1)
            next_tokens = next_tokens.view(1, -1)
            next_indices = torch.arange(num_beams, device=self.device).repeat_interleave(2).view(1, -1)

            decoder_input_ids_for_process = self.decoder_input_ids_list[0]
            
            beam_outputs = beam_scorer.process(
                decoder_input_ids_for_process,
                next_beam_scores,
                next_tokens,
                next_indices,
                eos_token_id=self.eos_token_id,
            )

            # === Update decoder_input_ids and beam scores ===
            for i in range(len(self.models)):
                self.decoder_input_ids_list[i] = torch.cat(
                    [
                        self.decoder_input_ids_list[i][beam_outputs.data["next_beam_indices"]],
                        beam_outputs.data["next_beam_tokens"].unsqueeze(-1),
                    ],
                    dim=-1,
                )

            beam_scores = beam_outputs.data["next_beam_scores"]
            cur_len += 1

            if beam_scorer.is_done:
                break

        # === Finalize hypotheses ===
        final_outputs = beam_scorer.finalize(
            self.decoder_input_ids_list[0],
            beam_scores,
            final_beam_tokens=None,
            final_beam_indices=None,
            max_length=cur_len,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id
        )

        # === Decode translations ===
        translation = [self.tokenizer.decode(t, skip_special_tokens=False) for t in final_outputs.data["sequences"][0]]
        translations = self.tokenizer.batch_decode(final_outputs.data["sequences"], skip_special_tokens=True,)
        return translations


In [7]:
class BatchedShallowFusion():
    def __init__(self, models_info):
        self.models_info = models_info
        # === Load tokenizer and models ===
        # The tokenizer of the first model is used
        self.tokenizer = load_tokenizer(models_info[0]["name"])
        
        # tokenizer = AutoTokenizer.from_pretrained("LumiOpen/Viking-7B")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.init_models()

    def init_models(self):
        self.models = []        
        for info in self.models_info:
            model, model_type = load_model(info["name"], self.device)
            self.models.append((model,model_type))

    def initialize_inputs(self, src_sentences_per_model, batch_size, num_beams):
        self.encoder_input_ids_list = []
        self.attention_masks_list = []
        self.decoder_input_ids_list = []

        self.eos_token_id = self.tokenizer.eos_token_id
        self.pad_token_id = self.tokenizer.pad_token_id  # Marian uses pad token to start decoding
        self.start_token_id = self.tokenizer.pad_token_id  # Marian uses pad token to start decoding

        for (sentences, terms_list), (model, model_type) in zip(src_sentences_per_model, self.models):
            # Tokenize and prepare inputs for all sentences in the batch
            if terms_list and terms_list[0]:  # Check if terms are provided
                # Assuming augment_tokenize can handle batch processing or needs to be modified
                enc_inputs = [augment_tokenize(sentence, terms, self.tokenizer, self.device) 
                              for sentence, terms in zip(sentences, terms_list)]
                encoder_input_ids = torch.cat([x["input_ids"] for x in enc_inputs], dim=0)
                attention_mask = torch.cat([x["attention_mask"] for x in enc_inputs], dim=0)
            else:
                enc_inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(self.device)
                encoder_input_ids = enc_inputs["input_ids"]
                attention_mask = enc_inputs["attention_mask"]
            
            # Expand for beam search
            encoder_input_ids = encoder_input_ids.unsqueeze(1).expand(-1, num_beams, -1).reshape(batch_size * num_beams, -1)
            attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1).reshape(batch_size * num_beams, -1)

            self.encoder_input_ids_list.append(encoder_input_ids)
            self.attention_masks_list.append(attention_mask)

            if model_type == "marian":
                # Init decoder_input_ids with a start token for each beam in each batch
                decoder_input_ids = torch.full(
                    (batch_size * num_beams, 1),
                    fill_value=self.start_token_id,
                    dtype=torch.long,
                    device=self.device,
                )
            else:
                # LLMs keep track of decoder input ids, but concat them with the input for generation
                decoder_input_ids = torch.empty((batch_size * num_beams, 0), dtype=torch.long, device=self.device)
            self.decoder_input_ids_list.append(decoder_input_ids)
    
    def translate(self, src_sentences_per_model, num_beams=4, max_length=50):
        # Determine batch size from input
        batch_size = len(src_sentences_per_model[0][0])
        self.initialize_inputs(src_sentences_per_model, batch_size, num_beams)
        
        # === Beam search setup ===
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=self.device,
            length_penalty=1.0,
            do_early_stopping=True,
            num_beam_hyps_to_keep=num_beams,
        )

        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=self.device)
        beam_scores[:, 1:] = -1e9  # Only first beam is active at the beginning
        beam_scores = beam_scores.view(-1)  # Flatten to (batch_size * num_beams,)

        cur_len = 1  # decoder_input_ids starts with 1 token

        # === Step-by-step beam search loop ===
        while cur_len < max_length:
            all_log_probs = []

            # === Each model provides logits from its own source + decoder input ===
            for (model, model_type), encoder_input_ids, attention_mask, decoder_input_ids in zip(
                self.models, self.encoder_input_ids_list, self.attention_masks_list, self.decoder_input_ids_list
            ):
                with torch.no_grad():
                    if model_type == "marian":
                        outputs = model(
                            input_ids=encoder_input_ids,
                            attention_mask=attention_mask,
                            decoder_input_ids=decoder_input_ids,
                        )
                        logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)
                        logits[:, [self.pad_token_id]] = float("-inf")
                    else:
                        input_ids = torch.cat([encoder_input_ids, decoder_input_ids], dim=-1)
                        outputs = model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                        )
                        logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)

                    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                    all_log_probs.append(log_probs)

            # === Shallow fusion: average log_probs ===
            avg_log_probs = torch.stack(all_log_probs).mean(dim=0)

            # === Beam step ===
            next_beam_scores, next_tokens = torch.topk(avg_log_probs, 2, dim=1)
            
            # Reshape scores and tokens for beam processing
            next_beam_scores = next_beam_scores.view(batch_size, num_beams * 2)
            next_tokens = next_tokens.view(batch_size, num_beams * 2)
            
            # Prepare beam scores for addition
            beam_scores = beam_scores.view(batch_size, num_beams)
            beam_scores = beam_scores.unsqueeze(-1).expand(-1, -1, 2).reshape(batch_size, num_beams * 2)
            next_beam_scores = next_beam_scores + beam_scores

            # Prepare for beam scorer
            next_indices = torch.arange(num_beams, device=self.device).repeat(2 * batch_size).view(batch_size, -1)

            # the shape is incorrect, should be batch_size * num_beams
            decoder_input_ids_for_process = self.decoder_input_ids_list[0]
            
            beam_outputs = beam_scorer.process(
                decoder_input_ids_for_process,
                next_beam_scores,
                next_tokens,
                next_indices,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id
            )

            # === Update decoder_input_ids and beam scores ===
            for i in range(len(self.models)):
                self.decoder_input_ids_list[i] = torch.cat(
                    [
                        self.decoder_input_ids_list[i][beam_outputs.data["next_beam_indices"]],
                        beam_outputs.data["next_beam_tokens"].unsqueeze(-1),
                    ],
                    dim=-1,
                )
                
                # Update the attention mask for LLMs (no need to update Marian masks, since they use
                # the same input masks throughout)
                if self.models[i][1] != "marian":
                    new_mask = torch.ones_like(beam_outputs.data["next_beam_tokens"].unsqueeze(-1))
                    self.attention_masks_list[i] = torch.cat([
                        self.attention_masks_list[i],
                        new_mask
                    ], dim=-1)

            beam_scores = beam_outputs.data["next_beam_scores"].view(-1)
            cur_len += 1

            if beam_scorer.is_done.all():
                break

        # === Finalize hypotheses ===
        final_outputs = beam_scorer.finalize(
            self.decoder_input_ids_list[0],
            beam_scores,
            final_beam_tokens=None,
            final_beam_indices=None,
            max_length=cur_len,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id
        )

        # === Decode translations ===
        translations = self.tokenizer.batch_decode(final_outputs.data["sequences"], skip_special_tokens=True)
        return translations

In [None]:
# === Define models and their unique source inputs ===
models_info = [
    {
        "name": "LumiOpen/Viking-7B"
    },
    {
        "name": "../converted-vocab_fix_model_viking7",
    },
]

#shallow_fusion = BatchedShallowFusion(models_info)
shallow_fusion = ShallowFusion(models_info)



In [None]:
def viking_template(sentence):
    return (f"<|im_start|>user\nTranslate into Finnish: {sentence}<|im_end|>\n<|im_start|>assistant\n",[])

def marian_llmvoc_template(sentence):
    return (sentence + "</s>", [])

In [None]:
import os
os.environ["PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT"] = "100"

# TODO: Test using custom logits processor

"""test_sents = [
    "This is a test.",
    "The ice was melting.",
    "Tunnels shudder as the animals fly through them.",
    "Tunnels shudder as the bats fly through them."]"""

test_sents = [
    "Tunnels shudder as the bats fly through them."]
input_sentences = [
    list(zip(*[viking_template(test_sent) for test_sent in test_sents])),
    list(zip(*[marian_llmvoc_template(test_sent) for test_sent in test_sents]))
]
print(input_sentences)
#translations = shallow_fusion.translate(input_sentences)
translations = shallow_fusion.translate((viking_template(test_sents[0]), marian_llmvoc_template(test_sents[0])))

print("\n🌍 Shallow Fusion Translations (Multi-Input):")
print("\n".join(translations))

[[('<|im_start|>user\nTranslate into Finnish: Firefighters at blaze caused by a Shahed drone attack on the Ukrainian Red Cross base.<|im_end|>\n<|im_start|>assistant\n',), ([],)], [('Firefighters at blaze caused by a Shahed drone attack on the Ukrainian Red Cross base.</s>',), ([],)]]

🌍 Shallow Fusion Translations (Multi-Input):
Palomiehet, jotka ovat joutuneet hyökkäyksestä Ukrainan Punaisen Ristin tukikohtaan.
Palomiehet, jotka ovat joutuneet hyökkäyksestä Ukrainan Punaisen Ristin tukikohtaan.[2]
Palomiehet, jotka ovat joutuneet hyökkäyksestä Ukrainan Punaisen Ristin tukikohtaan.[1]
Palomiehet, jotka ovat joutuneet hyökkäyksestä Ukrainan Punaisen Ristin tukikohtaan.[1][2]
