In [1]:
!pwd

/auto/brno2/home/rahmang/xcomet/COMET


In [2]:
!nvidia-smi

Wed Mar 19 16:55:59 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A40                     On  |   00000000:61:00.0 Off |                    0 |
|  0%   37C    P8             22W /  300W |       1MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [1]:
from comet import download_model, load_from_checkpoint
from comet.models.multitask.unified_metric import UnifiedMetric
from comet.models.utils import Prediction
from typing import Dict, Optional
from typing import List, Dict
from typing import Union, Tuple
from collections import defaultdict
import numpy
import inspect 
import torch
import torch.nn as nn  # <-- Add thiss
class CustomXCOMET(UnifiedMetric):

    def prepare_sample(
        self, sample: List[Dict[str, Union[str, float]]], stage: str = "fit"
    ) -> Union[Tuple[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]:
        """Tokenizes input data and prepares targets for training.

        Args:
            sample (List[Dict[str, Union[str, float]]]): Mini-batch
            stage (str, optional): Model stage ('train' or 'predict'). Defaults to "fit".

        Returns:
            Union[Tuple[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]: Model input
                and targets.
        """
        # print("=================++++++++++++++++++++++++++++++++++==========================")

        # Get the caller's function name and file location
        # caller_frame = inspect.stack()[1]
        # caller_function = caller_frame.function
        # caller_file = caller_frame.filename
    
        # # Print caller details
        # print(f"prepare_sample called by: {caller_function} (from {caller_file})")
        # print("sample in the prepare_sample: ", sample)
        # print("the stage is: ", stage)
        # for k in sample[0]:
        #     print("K is: ", k)
        inputs = {k: [d[k] for d in sample] for k in sample[0]}
        print(f'''inputs["mt"]: {inputs["mt"]}''')
        # only this will return word_ids from self.encoder.prepare_sample ->
        # self.encoder.subword_tokenize as it is set self.word_level
        # for src and ref, self.word_level = False by default
        input_sequences = [
            self.encoder.prepare_sample(inputs["mt"], self.word_level, None),
        ]
        input_sequences_mt = input_sequences.copy()  # Now independent of input_sequences #added by me
        print("input input_sequences just with MT: in the prepare_sample in unified_metric ", input_sequences_mt)

        src_input, ref_input = False, False
        if ("src" in inputs) and ("src" in self.hparams.input_segments):
            input_sequences.append(self.encoder.prepare_sample(inputs["src"]))
            src_input = True

        if ("ref" in inputs) and ("ref" in self.hparams.input_segments):
            input_sequences.append(self.encoder.prepare_sample(inputs["ref"]))
            ref_input = True
        # print("input_sequences after adding source and ref: ")
        # for inp in input_sequences:
        #     print(inp)
        # print("input_sequences after adding source and ref: ")
        unified_input = src_input and ref_input
        model_inputs = self.concat_inputs(input_sequences, unified_input) #updated unified_metric's
        #concat_inputs function to return word_ids
        #print("model inputs ++++++++++++++: ", model_inputs)
        if stage == "predict":
            #print("word ids: ", model_inputs[])
            #return model_inputs["inputs"]
            #, model_inputs["word_ids"] #model_inputs["word_ids"] added by me
            all_inputs = model_inputs["inputs"] 
            #added by me
            words_id_dict = {
                "words_id": model_inputs["word_ids"],
                "mt_sentences": inputs["mt"],
                "mt_sentences_tokenized": input_sequences_mt
            }
            updated = all_inputs + (words_id_dict,)
            #print("updated dict ========", updated)
            #Update the OrderedDict
            model_inputs["inputs"] = updated
            return model_inputs["inputs"]
        scores = [float(s) for s in inputs["score"]]
        targets = Target(score=torch.tensor(scores, dtype=torch.float))

        if "system" in inputs:
            targets["system"] = inputs["system"]

        if self.word_level:
            # Labels will be the same accross all inputs because we are only
            # doing sequence tagging on the MT. We will only use the mask corresponding
            # to the MT segment.
            seq_len = model_inputs["mt_length"].max()
            targets["mt_length"] = model_inputs["mt_length"]
            targets["labels"] = model_inputs["inputs"][0]["label_ids"][:, :seq_len]

        return model_inputs["inputs"], targets
        
    def word_level_prob(
        self,
        subword_probs: torch.Tensor,
        word_ids: Dict[
        str,
        Union[
            List[List[Optional[int]]],  # words_id
            List[str],  # mt_sentences
            List[Dict[
                str,
                Union[
                    torch.Tensor,  # input_ids/label_ids/attention_mask
                    List[List[Tuple[int, int]]],  # offsets
                    List[List[Optional[int]]]  # word_ids
            ]]
        ]
    ]
        ]
    ) -> List[List[Dict[str, float]]]:
        """ Returns word level probability score
        word_ids = {
    'words_id': [
        [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, None],
        [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, None]
    ],
    'mt_sentences': [
        "Can I receive my food in 10 to 15 minutes?",
        "Can you send it for 10 to 15 minutes?"
    ],
    'mt_sentences_tokenized': [
        {
            'input_ids': tensor([
                [0, 4171, 87, 53299, 759, 15381, 23, 209, 47, 423, 14633, 32, 2],
                [0, 4171, 398, 25379, 442, 100, 209, 47, 423, 14633, 32, 2, 1]
            ], device='cuda:0'),
            'label_ids': tensor([
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1]
            ], device='cuda:0'),
            'attention_mask': tensor([
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
            ], device='cuda:0'),
            'offsets': [
                [(0, 0), (0, 3), (3, 5), (5, 13), (13, 16), (16, 21), 
                 (21, 24), (24, 27), (27, 30), (30, 33), (33, 41), (41, 42), (0, 0)],
                [(0, 0), (0, 3), (3, 7), (7, 12), (12, 15), (15, 19), 
                 (19, 22), (22, 25), (25, 28), (28, 36), (36, 37), (0, 0)]
            ],
            'word_ids': [
                [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, None],
                [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, None]
            ]
            }
                ]
                    }

        """
        tokenizer = self.encoder.tokenizer
        # ====== Reconstruct MT sentence from batch ======
        # input_ids = batch[0]["input_ids"]  # Tokenized MT input
        # mt_sentence = self.encoder.tokenizer.decode(
        #         input_ids[0],
        #         skip_special_tokens=True,
        #         clean_up_tokenization_spaces=True
        # )
        # print("mt sentence:++++++++++++++++++ ", mt_sentence)

        
        ## run over the mt sentences in the dict word_ids
        word_level_prob = []
        all_tokenized_sentences = []
        for index, item in enumerate(word_ids["words_id"]):
            mt_sentence = word_ids["mt_sentences"][index]
            print("mt_sentence: ", mt_sentence)
            # Tokenize the MT sentence to get subword-to-token alignment
            tokenized = self.encoder.tokenizer(
                    mt_sentence,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                    truncation=True,
                            
            )
        
            #print("tokenized: ", tokenized)
            #subword_ids = tokenized.word_ids()
            #subword_ids = subword_ids[:seq_len]
            #subword_ids[-1] = None
            #print("subword_ids after mt sentence extractions: ", subword_ids)
            # Group subword probabilities by original tokens
            subword_ids = item
            token_probs = {}
            print(f"subword_probs[{index}]:  {subword_probs[index]}")
            attention_mask = word_ids["mt_sentences_tokenized"][0]["attention_mask"][index]
            for idx, prob in enumerate(subword_probs[index]):
                # if attention_mask[index] == 0:
                #     break
                if idx >= len(subword_ids):
                    break
                subword_idx = subword_ids[idx]
                if subword_idx is None:  # Skip special tokens
                    continue
                if subword_idx not in token_probs:
                    token_probs[subword_idx] = []
                token_probs[subword_idx].append(prob.cpu().numpy())
            print("token_probs: ", token_probs)
      
            # Aggregate probabilities (average for each class)
            token_level_probs = []
            for token_idx in sorted(token_probs.keys()):
                # Stack subword probabilities for this token
                subword_probs_for_token = torch.stack([torch.tensor(p) for p in token_probs[token_idx]])
        
                # Compute mean across subwords (dim=0 → average over subwords, per class)
                mean_probs = torch.mean(subword_probs_for_token, dim=0)
        
                token_level_probs.append(mean_probs.numpy())
            print("token_level_probs: ", token_level_probs)
    
            # After computing token_level_probs:
    
            # Tolerance for floating-point errors (e.g., 1e-3 = 0.1% tolerance)
            tolerance = 1e-3
    
            for token_idx, probs in enumerate(token_level_probs):
                total = numpy.sum(probs)
                if not numpy.isclose(total, 1.0, atol=tolerance):
                    print(f"Token {token_idx} probabilities sum to {total:.4f} (expected ~1.0)")
                else:
                    print(f"Token {token_idx} probabilities sum to {total:.4f}")

                        
            # Extract word IDs (index of the original word for each token)
            #word_ids = tokenized.word_ids()[:mt_length]    #  [None, 0, 0, 1, 1, 2, ...]
    
            # Convert token IDs to tokens (subwords)
            tokens = tokenizer.convert_ids_to_tokens(tokenized["input_ids"][0])
            print("tokens: ",tokens)
            # Group tokens by their word ID
            word_to_tokens = {}
            #print("word_ids: ", subword_ids)
            for idx, word_id in enumerate(subword_ids):
                if word_id is None:
                    continue  # Skip special tokens like [CLS], [SEP]
                if word_id not in word_to_tokens:
                    word_to_tokens[word_id] = []
                word_to_tokens[word_id].append(tokens[idx])
            print("word_to_tokens: ", word_to_tokens)
    
            print("sorted(word_to_tokens.keys()) :", sorted(word_to_tokens.keys()))
            # Reconstruct original words from grouped tokens
            word_mapping = []
            for word_id in sorted(word_to_tokens.keys()):
                tokens = word_to_tokens[word_id]
                # Merge subwords into a single string (handles ## prefixes)
                word = tokenizer.convert_tokens_to_string(tokens).strip()
                word_mapping.append(word)
    
            # Print results
            print("Tokenized Words:", word_mapping)
            all_tokenized_sentences.append(word_mapping)
            # Map tokens to probabilities
            token_predictions = [
                    {"token": token, "probabilities": probs.tolist()}
                    for token, probs in zip(word_mapping, token_level_probs)
            ]
    
            # print("Token-Level Probabilities:")
            for pred in token_predictions:
                print(f"{pred['token']}: {pred['probabilities']}")
            print("first sentence finished=====================")
            word_level_prob.append(token_predictions)
        return word_level_prob, all_tokenized_sentences

    def correct_span(
        self,
        track_token_to_words: List[int],
        mt_offsets: List[Tuple[int, int]],
        word_ids: Dict,
        Tokenized_Words: List[List[str]]
    ):
        # track_token_to_words = [-1, -1, -1, 3, 4, -1, -1, -1, 8, -1, 10, 11, 12, 13, 14, -1]
        # mt_offsets = [[(0, 0), (0, 5), (5, 13), (13, 17), (17, 21), (21, 25), (25, 26), (26, 27), (27, 38), (38, 39), (39, 46), (46, 49), (49, 53), (53, 57), (57, 59), (0, 0)]]
        # word_ids =  [None, 0, 1, 2, 3, 4, 4, 4, 4, 4, 5, 6, 7, 8, 8, None]
        mapping = {}
        print("==========================================")
        print("mt_offsets in the correct_span: ", mt_offsets)
        for index, item in enumerate(word_ids):
            if item is None:
                continue
        
            if item in mapping:
                # Append to existing lists
                mapping[item]['subwords'].append(index)
                # last_offset = mt_offsets[0][index][1]
                last_offset = mt_offsets[index][1]
                mapping[item]['offsets'][1] = last_offset
            else:
                # Initialize new entry
                
                mapping[item] = {
                    'subwords': [index],
                    'offsets': list(mt_offsets[index])
                }
        
        print(mapping)
                
        start = False
        from collections import OrderedDict
        words_in_span = []  # Creating an ordered set
        all_word_spans = defaultdict()
        set_to_check_multiple_subwords = set()
        index = 0
        for item in track_token_to_words:
            if item == -1:
                if start == True:
                    start = False
                    text = ""
                    print(words_in_span)
                    for item in words_in_span:
                        text += f" {Tokenized_Words[item]}"
                    print(text)
                    word_span = defaultdict()
                    word_span['text'] = text.strip()
                    word_span['start'] = mapping[words_in_span[0]]['offsets'][0]
                    word_span['end'] = mapping[words_in_span[-1]]['offsets'][1]
                    print("word span: ", word_span)
                    all_word_spans[index] = word_span
                    index += 1
                    words_in_span= []
            else:
                print("item: ", item)
                start = True
                word = word_ids[item]
                if word not in set_to_check_multiple_subwords:
                    set_to_check_multiple_subwords.add(word)
                    words_in_span.append(word)
                    
        print("all_word_spans: ", all_word_spans)
        return all_word_spans


        
    def decode(
        self,
        subword_probs: torch.Tensor,
        input_ids: torch.Tensor,
        mt_offsets: torch.Tensor,
        word_id:Dict[
        str,
        Union[
            List[List[Optional[int]]],  # words_id
            List[str],  # mt_sentences
            List[Dict[
                str,
                Union[
                    torch.Tensor,  # input_ids/label_ids/attention_mask
                    List[List[Tuple[int, int]]],  # offsets
                    List[List[Optional[int]]]  # word_ids
            ]]
        ]
    ]
        ]  # Added by me
    ) -> List[Dict]:
        """Decode error spans from subwords.

        Args:
            subword_probs (torch.Tensor): probabilities of each label for each subword.
            input_ids (torch.Tensor): input ids from the model.
            mt_offsets (torch.Tensor): subword offsets.
            word_id(dict): A dictionary that contains words_id mapping to all MT sentences,
            raw MT sentences and tokenized mt sentences
        Return:
            List with of dictionaries with text, start, end, severity and a
            confidence score which is the average of the probs for that label.
        """
        print("====================== decode function ========================")
        # print("subword_probs: ", subword_probs)
        # print("input_ids: ", input_ids)
        # print("mt_offsets: ", mt_offsets)
        decoded_output = []
        decoded_output_corrected = []
        print("now the word_level_prob function is being called: ")
        word_level_prob, all_tokenized_sentences = self.word_level_prob(subword_probs,word_id)
        #all_tokenized_sentences:List[List[str]]
        print("all_tokenized_sentences: ", all_tokenized_sentences)
        #print("length of mt_offsets: ", len(mt_offsets))
        for i in range(len(mt_offsets)):
            print("the value of i is ========= before inner for loop: ", i)
            seq_len = len(mt_offsets[i])
            #print("seq_len: ", seq_len)
            error_spans, in_span, span = [], False, {}

            track_token_to_words = []
            count_index = 0
            for token_id, probs, token_offset in zip(
                input_ids[i, :seq_len], subword_probs[i][:seq_len], mt_offsets[i]
            ):  
                #print("the value of i is: ", i)
                #print("token_id :", token_id, ", probs: ", probs, ", token_offset: ", token_offset)
                if self.decoding_threshold:
                    if torch.sum(probs[1:]) > self.decoding_threshold:
                        
                        print("token_id who has higher error sums than threshold: ",token_id)
                        print("and sum of it: ", torch.sum(probs[1:]))
                        probability, label_value = torch.topk(probs[1:], 1)
                        label_value += 1  # offset from removing label 0
                    else:
                        print("token_id who has higher error sums than threshold: ",token_id)

                        # This is just to ensure same format but at this point
                        # we will only look at label 0 and its prob
                        probability, label_value = torch.topk(probs[0], 1)
                        #print("probs[0] =============", probs[0])
                else:
                    print("no decoding threshold set")
                    probability, label_value = torch.topk(probs, 1)

                # Some torch versions topk returns a shape 1 tensor with only
                # a item inside
                label_value = (
                    label_value.item()
                    if label_value.dim() < 1
                    else label_value[0].item()
                )
                label = self.label_encoder.ids_to_label.get(label_value)
                #print("===================================================")
                #print("label: ", label)
                # Label set:
                # O I-minor I-major
                # Begin of annotation span
                if label.startswith("I") and not in_span:
                    in_span = True
                    span["tokens"] = [
                        token_id,
                    ]
                    span["severity"] = label.split("-")[1]
                    span["offset"] = list(token_offset)
                    #span["offset_word] = [list(token_offset)]
                    span["confidence"] = [
                        probability,
                    ]
                    span["check severity"] = [label.split("-")[1]]
                    track_token_to_words.append(count_index)
                    #span["word_indices"] =  set(word_id) if isinstance(word_id, list) else {word_id},  # Track word indices
                # Inside an annotation span
                elif label.startswith("I") and in_span:
                    span["tokens"].append(token_id)
                    span["confidence"].append(probability)
                    # Update offset end
                    span["offset"][1] = token_offset[1]
                    #span["offset_word] = [list(token_offset)]
                    span["check severity"].append(label.split("-")[1])
                    # if isinstance(word_id, list):
                    #     span["word_indices"].update(word_id)
                    # else:
                    #     span["word_indices"].add(word_id)
                    track_token_to_words.append(count_index)
                # annotation span finished.
                elif label == "O" and in_span:
                    error_spans.append(span)
                    in_span, span = False, {}
                    track_token_to_words.append(-1)
                #added by me
                elif label == "O" and not in_span:
                    track_token_to_words.append(-1)
                count_index = count_index + 1

            print("track_token_to_words: ", track_token_to_words)

            sentence_output = []
            for span in error_spans:
                # # Collect unique word indices for the span
                # unique_word_indices = sorted(set(span["word_indices"]) - {None})  # Remove None safely

                # # Extract words belonging to this specific span
                # span_words = [tokenized_words[idx] for idx in unique_word_indices]

                sentence_output.append(
                    {
                        
                        "text": self.encoder.tokenizer.decode(span["tokens"]),
                        #"text": " ".join(span_words),  # Use words instead of tokens
                        "confidence": torch.concat(span["confidence"]).mean().item(),
                        "severity": span["severity"],
                        "start": span["offset"][0],
                        "end": span["offset"][1],
                        "check severity": span["check severity"]
                    }
                )
            decoded_output.append(sentence_output)
            print("sentence_output: ", sentence_output)
            
            #get corrected word level error span
            corrected_span = self.correct_span(track_token_to_words,mt_offsets[i], word_id["words_id"][i],all_tokenized_sentences[i])
            sentence_out_word_level = []
            count = 0
            for error_span in sentence_output:
                sentence_out_word_level.append({
                    "text": corrected_span[count]['text'],
                    "confidence": error_span['confidence'],
                    "severity": error_span["severity"],
                    "start": corrected_span[count]['start'],
                    "end": corrected_span[count]['end'],
                })
                count += 1
                    
            print("sentence_out_word_level: ",sentence_out_word_level)
            decoded_output_corrected.append(sentence_out_word_level)
        print("decoded_output_corrected", decoded_output_corrected)
        #decoded_output.append(decoded_output_corrected)
        return decoded_output, decoded_output_corrected, word_level_prob

    
    def predict_step(
        self,
        batch: Dict[str, torch.Tensor],
        batch_idx: Optional[int] = None,
        dataloader_idx: Optional[int] = None,
    ) -> Prediction:
        """PyTorch Lightning predict_step

        Args:
            batch (Dict[str, torch.Tensor]): The output of your prepare_sample function
            batch_idx (Optional[int], optional): Integer displaying which batch this is
                    Defaults to None.
            dataloader_idx (Optional[int], optional): Integer displaying which
                dataloader this is. Defaults to None.

        Returns:
            Prediction: Model Prediction
        """
        # # Get the caller's function name and file location
        # caller_frame = inspect.stack()[1]
        # caller_function = caller_frame.function
        # caller_file = caller_frame.filename
    
        # # Print caller details
        # print(f"predict_step called by: {caller_function} (from {caller_file})")
        if len(batch) == 4: # after adding word_ids, the batch length will increase by 1S
            print("batch: ", batch)
            print("i am inside when len of the batch is 3")
            #print("word ids: ", batch[-1])
            # predictions = [self.forward(**input_seq) for input_seq in batch]
            # now the length of the batch is 4, 4th dictionary is the words_id 
            #dictionary, and we will skip the forward call to it
            predictions = [self.forward(**input_seq) for input_seq in batch[:-1]]
            
            #print("predictions: ", predictions)
            avg_scores = torch.stack([pred.score for pred in predictions], dim=0).mean(dim=0)
            #print("avg scores", avg_scores)
            batch_prediction = Prediction(
                scores=avg_scores,
                metadata=Prediction(
                    src_scores=predictions[0].score,
                    ref_scores=predictions[1].score,
                    unified_scores=predictions[2].score,
                ),
            )
            if self.word_level:
                mt_mask = batch[0]["label_ids"] != -1
                mt_length = mt_mask.sum(dim=1)
                seq_len = mt_length.max() 
                subword_probs = [
                    nn.functional.softmax(o.logits, dim=2)[:, :seq_len, :] * w
                    for w, o in zip(self.input_weights_spans, predictions)
                ]
                subword_probs = torch.sum(torch.stack(subword_probs), dim=0)
                
                ########################################################3
                ## create error span using decode function
                word_ids = batch[-1].copy()
                error_spans, corrected_error_spans, word_level_prob = self.decode(
                    subword_probs, batch[0]["input_ids"], batch[0]["mt_offsets"], word_ids
                    #,word_ids, word_mapping
                )
                batch_prediction.metadata["error_spans"] = error_spans
                batch_prediction.metadata["corrected_error_spans"] = corrected_error_spans
                batch_prediction.metadata["word_level_probability"]=word_level_prob
        else:
            print("i am inside when len of the batch is not 3")
            model_output = self.forward(**batch[0])
            batch_prediction = Prediction(scores=model_output.score)
            if self.word_level:
                mt_mask = batch[0]["label_ids"] != -1
                mt_length = mt_mask.sum(dim=1)
                seq_len = mt_length.max()
                subword_probs = nn.functional.softmax(model_output.logits, dim=2)[:, :seq_len, :]
                error_spans = self.decode(
                    subword_probs, batch[0]["input_ids"], batch[0]["mt_offsets"]
                )
                batch_prediction = Prediction(
                    scores=model_output.score,
                    metadata=Prediction(error_spans=error_spans),
                )
        return batch_prediction

# Load checkpoint into your custom class
path = "/storage/brno2/home/rahmang/xcomet/downloadedxcomet/models--Unbabel--XCOMET-XL/snapshots/50d428488e021205a775d5fab7aacd9502b58e64/checkpoints/model.ckpt"

model = CustomXCOMET.load_from_checkpoint(path,strict = False)
data = [
    {
        "src": "Boris Johnson teeters on edge of favour with Tory MPs",
        "mt": "Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst",
        "ref": "Boris Johnsons Beliebtheit bei Tory-MPs steht auf der Kippe"
    }
]
# data = [
#     {
#         "src": "10 到 15 分钟可以送到吗",
#         "mt": "Can I receive my food in 10 to 15 minutes?",
#         "ref": "Can it be delivered between 10 to 15 minutes?"
#     },
#     {
#         "src": "Pode ser entregue dentro de 10 a 15 minutos?",
#         "mt": "Can you send it for 10 to 15 minutes?",
#         "ref": "Can it be delivered between 10 to 15 minutes?"
#     }
# ]
model_output = model.predict(data, batch_size=8, gpus=1)
# Segment-level scores
print (model_output.scores)

# System-level score
print (model_output.system_score)

# Score explanation (error spans)
print (model_output.metadata.error_spans)

#added by me
print (model_output.metadata.corrected_error_spans)
print (model_output.metadata.word_level_probability)

  from .autonotebook import tqdm as notebook_tqdm
Encoder model frozen.
/storage/brno2/home/rahmang/envs/xcomet/lib/python3.11/site-packages/pytorch_lightning/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-ab9f9e85-81f2-18d0-531c-1b30853fbaad]
Predicting: 0it [00:00, ?it/s]

inputs["mt"]: ['Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst']
input input_sequences just with MT: in the prepare_sample in unified_metric  [{'input_ids': tensor([[     0,  67151,  59520,    443,   1079,   6653,     53,      9, 241832,
             19,  86454,     23,    122,  25706,    271,      2]]), 'label_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'offsets': [[(0, 0), (0, 5), (5, 13), (13, 17), (17, 21), (21, 25), (25, 26), (26, 27), (27, 38), (38, 39), (39, 46), (46, 49), (49, 53), (53, 57), (57, 59), (0, 0)]], 'word_ids': [[None, 0, 1, 2, 3, 4, 4, 4, 4, 4, 5, 6, 7, 8, 8, None]]}]


Predicting DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

batch:  ({'input_ids': tensor([[     0,  67151,  59520,    443,   1079,   6653,     53,      9, 241832,
             19,  86454,     23,    122,  25706,    271,      2,      2,  67151,
          59520,  32686,  23962,     98, 121303,    111,   1238, 141775,    678,
           6653,     53,  10646,      7,      2]], device='cuda:0'), 'attention_mask': tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True]], device='cuda:0'), 'label_ids': tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, -1, -1,
         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]],
       device='cuda:0'), 'mt_offsets': [[(0, 0), (0, 5), (5, 13), (13, 17), (17, 21), (21, 25), (25, 26), (26, 27), (27, 38), (38, 39), (39, 46), (46, 49), (49, 53), (53, 57), (57, 59), (0, 0)]]}, {'input_ids': tensor([[     0,  67151,  59520,    

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Predicting DataLoader 0: 100%|██████████| 1/1 [00:01<00:00,  1.60s/it]

now the word_level_prob function is being called: 
mt_sentence:  Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst
subword_probs[0]:  tensor([[0.3554, 0.2145, 0.2102, 0.2199],
        [0.5846, 0.1495, 0.1520, 0.1139],
        [0.6238, 0.1366, 0.1445, 0.0952],
        [0.1280, 0.1513, 0.2221, 0.4986],
        [0.3183, 0.1446, 0.2166, 0.3205],
        [0.5808, 0.1740, 0.1477, 0.0975],
        [0.4583, 0.2637, 0.2076, 0.0703],
        [0.4516, 0.2554, 0.2178, 0.0751],
        [0.2695, 0.2651, 0.2737, 0.1917],
        [0.2902, 0.2185, 0.2517, 0.2396],
        [0.1104, 0.1495, 0.2237, 0.5164],
        [0.1130, 0.1312, 0.1900, 0.5659],
        [0.1128, 0.1399, 0.2107, 0.5366],
        [0.1422, 0.1559, 0.2223, 0.4797],
        [0.1456, 0.1444, 0.1989, 0.5111],
        [0.5671, 0.1140, 0.1744, 0.1444]], device='cuda:0')
token_probs:  {0: [array([0.58458394, 0.14953566, 0.15199055, 0.11388987], dtype=float32)], 1: [array([0.6237879 , 0.13660619, 0.14445591, 0.09515005], dtype=float32)




[0.6365968585014343]
0.6365968585014343
[[{'text': 'ist bei', 'confidence': 0.4095497727394104, 'severity': 'critical', 'start': 13, 'end': 21, 'check severity': ['critical', 'critical']}, {'text': 'Abgeordnete', 'confidence': 0.2736634612083435, 'severity': 'major', 'start': 27, 'end': 38, 'check severity': ['major']}, {'text': 'völlig in der Gunst', 'confidence': 0.5219249129295349, 'severity': 'critical', 'start': 39, 'end': 59, 'check severity': ['critical', 'critical', 'critical', 'critical', 'critical']}]]
[[{'text': 'ist bei', 'confidence': 0.4095497727394104, 'severity': 'critical', 'start': 13, 'end': 21}, {'text': 'Tory-Abgeordneten', 'confidence': 0.2736634612083435, 'severity': 'major', 'start': 21, 'end': 39}, {'text': 'völlig in der Gunst', 'confidence': 0.5219249129295349, 'severity': 'critical', 'start': 39, 'end': 59}]]
[[{'token': 'Boris', 'probabilities': [0.5845839381217957, 0.1495356559753418, 0.15199054777622223, 0.11388987302780151]}, {'token': 'Johnson', 'probab