In [61]:
%pip install transformers
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
%pip install sentencepiece

Note: you may need to restart the kernel to use updated packages.
Looking in indexes: https://download.pytorch.org/whl/cu128
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [62]:
import torch
from transformers import LogitsProcessor, MarianTokenizer, MarianMTModel, AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, AutoConfig

In [63]:
def load_model(model_name_or_path, device):
    config = AutoConfig.from_pretrained(model_name_or_path)

    # Check model type by architecture string
    if "Marian" in config.model_type or config.architectures and any("Marian" in arch for arch in config.architectures):
        model = MarianMTModel.from_pretrained(model_name_or_path).to(device).eval()
        model_type = "marian"
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16).to(device).eval()
        model_type = "causal_lm"

    return model, model_type

In [64]:
def load_tokenizer(model_name_or_path):
    config = AutoConfig.from_pretrained(model_name_or_path)

    # Check model type by architecture string
    if "Marian" in config.model_type or config.architectures and any("Marian" in arch for arch in config.architectures):
        tokenizer = MarianTokenizer.from_pretrained(model_name_or_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    return tokenizer

In [None]:
def augment_tokenize(text, terms, tokenizer, padding_side, device="cpu"):
    #TODO: implement batching, add padding according to padding_side parameter
    vocab = tokenizer.get_vocab()
    text_tokenized = tokenizer(text).input_ids
    for term_source,term_target in terms:
        term_source_tokenized = tokenizer(term_source).input_ids[:-1]
        term_target_tokenized = list(tokenizer(text_target=term_target).input_ids)[:-1]
        term_target_tokenized = [vocab["augmentsymbol1"]] + \
            term_target_tokenized + \
            [vocab["augmentsymbol2"]]
            
        current_aug_part_index = 0
        new_text_tokenized = []
        for token in text_tokenized:
            #TODO: add check for the word continuing
            if current_aug_part_index == len(term_source_tokenized):
                new_text_tokenized += [vocab["augmentsymbol0"]] + term_source_tokenized + \
                term_target_tokenized
                current_aug_part_index = 0
            if token == term_source_tokenized[current_aug_part_index]:
                current_aug_part_index += 1
            elif current_aug_part_index > 1:
                new_text_tokenized += term_source_tokenized[0:current_aug_part_index]
                new_text_tokenized.append(token)
                current_aug_part_index = 0
            else:
                new_text_tokenized.append(token)
        text_tokenized = new_text_tokenized

    input_ids = torch.tensor([text_tokenized], device=device)  # batch dimension added
    attention_mask = torch.ones_like(input_ids, device=device)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }

In [66]:
def sum_ignore_inf(tensor):
    # Create a mask for finite values (not -inf or inf)
    finite_mask = torch.isfinite(tensor)
    
    # Apply the mask to get only finite values, setting others to 0
    finite_values = torch.where(finite_mask, tensor, torch.zeros_like(tensor))
    
    # Sum all finite values
    total_sum = torch.sum(finite_values)
    
    return total_sum

tensor = torch.randn(160, 50000)
tensor[0, 0] = -float('inf')  # Add an -inf for testing
result = sum_ignore_inf(tensor)
print(result)

tensor(1945.0933)


In [97]:
class MultiInputLogitsProcessor(LogitsProcessor):
    def __init__(self, models, model_types, tokenizer, models_info, average_mode="probs", only_main_model=False):
        self.models = models
        self.model_types = model_types
        self.tokenizer = tokenizer
        self.models_info = models_info
        self.average_mode = average_mode
        self.current_inputs = {}  # Will store prepared inputs for each model
        self.only_main_model = only_main_model

    def viking_template(self, sentence):
        return f"<|im_start|>user\nTranslate into Finnish: {sentence}<|im_end|>\n<|im_start|>assistant\n"

    def marian_llmvoc_template(self, sentence):
        return sentence + "</s>"

    def prepare_inputs(self, src_sentences, num_beams):
        """Prepare and store model-specific inputs, expanded for beam search"""
        self.current_inputs = {}
        batch_size = len(src_sentences)
        
        for i, (model, model_type) in enumerate(zip(self.models, self.model_types)):
            info = self.models_info[i]
            
            # If this is a Marian with LLM vocab, apply template
            if model_type == "marian" and not isinstance(self.tokenizer, MarianTokenizer):
                templated_src_sentences = [self.marian_llmvoc_template(x) for x in src_sentences]
                padding_side = "right"
            elif model_type != "marian":
                templated_src_sentences = [self.viking_template(x) for x in src_sentences]
                padding_side = "left"
            else:
                templated_src_sentences = src_sentences
                padding_side = "right"
            if info.get("terms"):
                inputs = augment_tokenize(
                    templated_src_sentences, 
                    info["terms"], 
                    self.tokenizer, 
                    model.device,
                    padding_side
                )
            else:
                inputs = self.tokenizer(
                    templated_src_sentences, 
                    return_tensors="pt", 
                    padding=True,
                    truncation=True,
                    padding_side=padding_side
                ).to(model.device)
            
            # Expand inputs for beam search, except for main model
            if i != 0:
                encoder_input_ids = inputs["input_ids"].unsqueeze(1).expand(-1, num_beams, -1)
                encoder_input_ids = encoder_input_ids.reshape(batch_size * num_beams, -1)
                
                attention_mask = inputs["attention_mask"].unsqueeze(1).expand(-1, num_beams, -1)
                attention_mask = attention_mask.reshape(batch_size * num_beams, -1)
            else:
                encoder_input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]
                
            self.current_inputs[i] = {
                "encoder_input_ids": encoder_input_ids,
                "attention_mask": attention_mask,
                "original_batch_size": batch_size
            }

    def __call__(self, input_ids, scores):
        # TODO: add batched stopping for LLMs based on line break
        # Strip input from input_ids for Marian decoding
        
        if self.only_main_model:
            return scores
        
        avg_probs = self._average_probs(input_ids, scores)
        
        difference = scores-avg_probs
        summed_difference = sum_ignore_inf(difference)
        return avg_probs

    def _average_probs(self, input_ids, scores):
        """Average probabilities from all models (log space)"""
        batch_size_times_beams = input_ids.shape[0]
        vocab_size = scores.shape[-1]
        all_probs = torch.zeros((len(self.models), batch_size_times_beams, vocab_size),
                              device=scores.device)
        
        for i, (model, model_type) in enumerate(zip(self.models, self.model_types)):
            if i == 0:
                all_probs[i] = torch.exp(scores)
            else:
                logits = self._get_model_logits(
                    model,
                    model_type,
                    self.current_inputs[i]["encoder_input_ids"],
                    self.current_inputs[i]["attention_mask"],
                    input_ids,
                    i
                )
                all_probs[i] = torch.nn.functional.softmax(logits, dim=-1)
        mean_log_props = torch.log(all_probs.mean(dim=0))
        return mean_log_props
        #return torch.logsumexp(all_probs, dim=0) - torch.log(torch.tensor(len(self.models), device=scores.device))

    def _get_model_logits(self, model, model_type, encoder_inputs, attention_mask, input_ids, model_idx):
        """Get logits from a single model"""
        with torch.no_grad():
            if model_type == "marian":
                outputs = model(
                    input_ids=encoder_inputs,
                    attention_mask=attention_mask,
                    decoder_input_ids=input_ids,
                )
            else:
                input_ids_full = torch.cat([encoder_inputs, input_ids], dim=-1)
                attention_mask_full = torch.cat([
                    attention_mask,
                    torch.ones_like(input_ids)
                ], dim=-1)
                
                outputs = model(
                    input_ids=input_ids_full,
                    attention_mask=attention_mask_full,
                )
            
            logits = outputs.logits[:, -1, :]
            logits[:, [self.tokenizer.pad_token_id]] = -float('inf')
            return logits

In [68]:
class ModelGroup():
    def __init__(self, models_info, tokenizer_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = []
        self.model_types = []
        self.tokenizer = load_tokenizer(tokenizer_path)
        
        for info in models_info:
            model, model_type = load_model(info["name"], self.device)
            self.models.append(model)
            self.model_types.append(model_type)

In [98]:
class ShallowFusion:
    def __init__(self, models_info, model_group, main_model_idx=0):
        self.models_info = models_info
        self.main_model_idx = main_model_idx
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.models = model_group.models
        self.model_types = model_group.model_types
        self.tokenizer = model_group.tokenizer
    
    def translate(self, src_sentences, num_beams=4, max_length=50, only_main_model=False):
        # Initialize logits processor
        logits_processor = MultiInputLogitsProcessor(
            models=self.models,
            model_types=self.model_types,
            tokenizer=self.tokenizer,
            models_info=self.models_info,
            only_main_model=only_main_model)
        
        # Prepare all model inputs
        logits_processor.prepare_inputs(src_sentences, num_beams)
        
        # Get main model components
        main_model = self.models[self.main_model_idx]
        main_inputs = logits_processor.current_inputs[self.main_model_idx]
        
        # Generate with ensemble
        if self.model_types[self.main_model_idx] == "marian":
            outputs = main_model.generate(
                input_ids=main_inputs["encoder_input_ids"],
                attention_mask=main_inputs["attention_mask"],
                num_beams=num_beams,
                max_length=max_length,
                logits_processor=[logits_processor],
                early_stopping=True,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        else:
            outputs = main_model.generate(
                input_ids=main_inputs["encoder_input_ids"],
                attention_mask=main_inputs["attention_mask"],
                num_beams=num_beams,
                max_length=max_length,
                logits_processor=[logits_processor],
                early_stopping=True,
                eos_token_id=23,
                pad_token_id=self.tokenizer.pad_token_id
            )
        
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [93]:
models_info = [
    {"name": "../termmodel.eng-fin", "terms": None},  # Main model
    {"name": "../termmodel.eng-fin", "terms": None},  # Aux model 1
]

#model_group = ModelGroup(models_info, "LumiOpen/Viking-7B")
model_group = ModelGroup(models_info, "../termmodel.eng-fin")

In [None]:
#TODO: Why does plain generate produce junk but ensemble not? Move on to testing with terms!
ensemble = ShallowFusion(models_info, model_group)

test_sents = ["This is a test.","There was a storm in Spain.","We did not anticipate a crowd."]
num_beams = 8
translations = ensemble.translate(
    test_sents,
    num_beams=num_beams,
    max_length=100,
    only_main_model=False
)

print(translations)
#for translation in translations:
#    print(translation.split("\n")[-1])

"""base_translations = ensemble.translate(
    test_sents,
    num_beams=num_beams,
    max_length=200,
    only_main_model=True
)
print(base_translations)"""


['Tämä on testi.', 'Espanjassa oli myrsky.', 'Emme ennakoineet väkijoukkoa.']


'base_translations = ensemble.translate(\n    test_sents,\n    num_beams=num_beams,\n    max_length=200,\n    only_main_model=True\n)\nprint(base_translations)'

In [69]:
device="cuda"

model, model_type = load_model("../termmodel.eng-fin", device)
tokenizer = load_tokenizer("../termmodel.eng-fin")

test_sents = ["Firefighters at blaze caused by a Shahed drone attack on the Ukrainian Red Cross base"]
inputs = tokenizer(
                    test_sents, 
                    return_tensors="pt", 
                    padding=False,
                    truncation=True
                ).to(model.device)

for i in range(1,4):
    output = model.generate(**inputs,num_beams=i)
    print(f"{i}: {tokenizer.batch_decode(output, skip_special_tokens=True)}")

1: ['Palomiehet tulipalossa, jotka johtuivat Shahed-lennokkihyökkäyksestä Ukrainan Punaisen Ristin tukikohtaan']
2: ['Tulihälytyshävittäjät, jotka ovat saaneet alkunsa Ukrainan Punaisen Ristin tukikohdan hyökkäyksestä']
3: ['Tulihälytyshävittäjät, jotka ovat joutuneet hyökkäykseen Ukrainan Punaiseen Ristiin']
