In [1]:
from comet import download_model, load_from_checkpoint
from comet.models.multitask.unified_metric import UnifiedMetric
from comet.models.utils import Prediction
from typing import Dict, Optional
from typing import List, Dict
from typing import Union, Tuple
from collections import defaultdict
import numpy
import inspect 
import torch
import torch.nn as nn  # <-- Add thiss
class CustomXCOMET(UnifiedMetric):

    
    def word_level_prob(
        self,
        subword_probs: torch.Tensor,
        batch
    ) -> List[List[Dict[str, float]]]:
        """ Returns word level probability score
        word_ids = {
    'words_id': [
        [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, None],
        [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, None]
    ],
    'mt_sentences': [
        "Can I receive my food in 10 to 15 minutes?",
        "Can you send it for 10 to 15 minutes?"
    ],
    'mt_sentences_tokenized': [
        {
            'input_ids': tensor([
                [0, 4171, 87, 53299, 759, 15381, 23, 209, 47, 423, 14633, 32, 2],
                [0, 4171, 398, 25379, 442, 100, 209, 47, 423, 14633, 32, 2, 1]
            ], device='cuda:0'),
            'label_ids': tensor([
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1]
            ], device='cuda:0'),
            'attention_mask': tensor([
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
            ], device='cuda:0'),
            'offsets': [
                [(0, 0), (0, 3), (3, 5), (5, 13), (13, 16), (16, 21), 
                 (21, 24), (24, 27), (27, 30), (30, 33), (33, 41), (41, 42), (0, 0)],
                [(0, 0), (0, 3), (3, 7), (7, 12), (12, 15), (15, 19), 
                 (19, 22), (22, 25), (25, 28), (28, 36), (36, 37), (0, 0)]
            ],
            'word_ids': [
                [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, None],
                [None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, None]
            ]
            }
                ]
                    }

        """
        tokenizer = self.encoder.tokenizer
        # ====== Reconstruct MT sentence from batch ======
        input_ids = batch[0]["input_ids"]  # Tokenized MT input
        mt_sentence = self.encoder.tokenizer.decode(
                input_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
        )
        print("mt sentence:++++++++++++++++++ ", mt_sentence)

        
        ## run over the mt sentences in the dict word_ids
        word_level_prob = []
        all_tokenized_sentences = []
       
        # Tokenize the MT sentence to get subword-to-token alignment
        tokenized = self.encoder.tokenizer(
                mt_sentence,
                return_offsets_mapping=True,
                return_tensors="pt",
                truncation=True,
                        
        )

        #print("tokenized: ", tokenized)
        subword_ids = tokenized.word_ids()
        #subword_ids = subword_ids[:seq_len]
        #subword_ids[-1] = None
        #print("subword_ids after mt sentence extractions: ", subword_ids)
        # Group subword probabilities by original tokens
        #subword_ids = item
        token_probs = {}
        #print(f"subword_probs[{index}]:  {subword_probs[index]}")
        #attention_mask = word_ids["mt_sentences_tokenized"][0]["attention_mask"][index] #did not use this
        for idx, prob in enumerate(subword_probs[0]):
            # if attention_mask[index] == 0:
            #     break
            if idx >= len(subword_ids):
                break
            subword_idx = subword_ids[idx]
            if subword_idx is None:  # Skip special tokens
                continue
            if subword_idx not in token_probs:
                token_probs[subword_idx] = []
            token_probs[subword_idx].append(prob.cpu().numpy())
        print("token_probs: ", token_probs)

        # Aggregate probabilities (average for each class)
        token_level_probs = []
        for token_idx in sorted(token_probs.keys()):
            # Stack subword probabilities for this token
            subword_probs_for_token = torch.stack([torch.tensor(p) for p in token_probs[token_idx]])

            # Compute mean across subwords (dim=0 → average over subwords, per class)
            mean_probs = torch.mean(subword_probs_for_token, dim=0)

            token_level_probs.append(mean_probs.numpy())
        print("token_level_probs: ", token_level_probs)
        # After computing token_level_probs:
        # Tolerance for floating-point errors (e.g., 1e-3 = 0.1% tolerance)
        tolerance = 1e-3
        for token_idx, probs in enumerate(token_level_probs):
            total = numpy.sum(probs)
            if not numpy.isclose(total, 1.0, atol=tolerance):
                print(f"Token {token_idx} probabilities sum to {total:.4f} (expected ~1.0)")
            else:
                print(f"Token {token_idx} probabilities sum to {total:.4f}")
                    
        # Extract word IDs (index of the original word for each token)
        #word_ids = tokenized.word_ids()[:mt_length]    #  [None, 0, 0, 1, 1, 2, ...]
        # Convert token IDs to tokens (subwords)
        tokens = tokenizer.convert_ids_to_tokens(tokenized["input_ids"][0])
        print("tokens: ",tokens)
        # Group tokens by their word ID
        word_to_tokens = {}
        #print("word_ids: ", subword_ids)
        for idx, word_id in enumerate(subword_ids):
            if word_id is None:
                continue  # Skip special tokens like [CLS], [SEP]
            if word_id not in word_to_tokens:
                word_to_tokens[word_id] = []
            word_to_tokens[word_id].append(tokens[idx])
        print("word_to_tokens: ", word_to_tokens)
        print("sorted(word_to_tokens.keys()) :", sorted(word_to_tokens.keys()))
        # Reconstruct original words from grouped tokens
        word_mapping = []
        for word_id in sorted(word_to_tokens.keys()):
            tokens = word_to_tokens[word_id]
            # Merge subwords into a single string (handles ## prefixes)
            word = tokenizer.convert_tokens_to_string(tokens).strip()
            word_mapping.append(word)
        # Print results
        print("Tokenized Words:", word_mapping)
        all_tokenized_sentences.append(word_mapping)
        # Map tokens to probabilities
        token_predictions = [
                {"token": token, "probabilities": probs.tolist()}
                for token, probs in zip(word_mapping, token_level_probs)
        ]
        # print("Token-Level Probabilities:")
        for pred in token_predictions:
            print(f"{pred['token']}: {pred['probabilities']}")
        print("first sentence finished=====================")
        word_level_prob.append(token_predictions)
        return word_level_prob, all_tokenized_sentences

    
    #word_ids = batch[-1].copy()
    #           word_level_prob, all_tokenized_sentences = self.word_level_prob(subword_probs,word_ids)
    def predict_step(
        self,
        batch: Dict[str, torch.Tensor],
        batch_idx: Optional[int] = None,
        dataloader_idx: Optional[int] = None,
    ) -> Prediction:
        """PyTorch Lightning predict_step

        Args:
            batch (Dict[str, torch.Tensor]): The output of your prepare_sample function
            batch_idx (Optional[int], optional): Integer displaying which batch this is
                Defaults to None.
            dataloader_idx (Optional[int], optional): Integer displaying which
                dataloader this is. Defaults to None.

        Returns:
            Prediction: Model Prediction
        """
        if len(batch) == 3:
            predictions = [self.forward(**input_seq) for input_seq in batch]
            # Final score is the average of the 3 scores!
            avg_scores = torch.stack([pred.score for pred in predictions], dim=0).mean(
                dim=0
            )
            batch_prediction = Prediction(
                scores=avg_scores,
                metadata=Prediction(
                    src_scores=predictions[0].score,
                    ref_scores=predictions[1].score,
                    unified_scores=predictions[2].score,
                ),
            )
            if self.word_level:
                mt_mask = batch[0]["label_ids"] != -1
                mt_length = mt_mask.sum(dim=1)
                seq_len = mt_length.max()
                subword_probs = [
                    nn.functional.softmax(o.logits, dim=2)[:, :seq_len, :] * w
                    for w, o in zip(self.input_weights_spans, predictions)
                ]
                subword_probs = torch.sum(torch.stack(subword_probs), dim=0)
                word_level_prob, all_tokenized_sentences = self.word_level_prob(subword_probs, batch)
                error_spans = self.decode(
                    subword_probs, batch[0]["input_ids"], batch[0]["mt_offsets"]
                )
                batch_prediction.metadata["error_spans"] = error_spans

        else:
            model_output = self.forward(**batch[0])
            batch_prediction = Prediction(scores=model_output.score)
            if self.word_level:
                mt_mask = batch[0]["label_ids"] != -1
                mt_length = mt_mask.sum(dim=1)
                seq_len = mt_length.max()
                subword_probs = nn.functional.softmax(model_output.logits, dim=2)[
                    :, :seq_len, :
                ]
                error_spans = self.decode(
                    subword_probs, batch[0]["input_ids"], batch[0]["mt_offsets"]
                )
                batch_prediction = Prediction(
                    scores=model_output.score,
                    metadata=Prediction(error_spans=error_spans),
                )
        return batch_prediction
# Load checkpoint into your custom class
path = "/storage/brno2/home/rahmang/xcomet/downloadedxcomet/models--Unbabel--XCOMET-XL/snapshots/50d428488e021205a775d5fab7aacd9502b58e64/checkpoints/model.ckpt"

model = load_from_checkpoint(path)
# from comet import download_model, load_from_checkpoint

# model_path = download_model("Unbabel/XCOMET-XL")
# model = load_from_checkpoint(model_path)

data = [
    {
        "src": "Boris Johnson teeters on edge of favour with Tory MPs",
        "mt": "Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst",
        #"ref": "Boris Johnsons Beliebtheit bei Tory-MPs steht auf der Kippe"
    }
]
# data = [
#     {
#         "src": "10 到 15 分钟可以送到吗",
#         "mt": "Can I receive my food in 10 to 15 minutes?",
#         "ref": "Can it be delivered between 10 to 15 minutes?"
#     },
#     {
#         "src": "Pode ser entregue dentro de 10 a 15 minutos?",
#         "mt": "Can you send it for 10 to 15 minutes?",
#         "ref": "Can it be delivered between 10 to 15 minutes?"
#     }
# ]
model_output = model.predict(data, batch_size=8, gpus=1)
# Segment-level scores
print (model_output.scores)

# System-level score
print (model_output.system_score)

# Score explanation (error spans)
print (model_output.metadata.error_spans)


  from .autonotebook import tqdm as notebook_tqdm
Encoder model frozen.
/storage/brno2/home/rahmang/envs/xcomet/lib/python3.11/site-packages/pytorch_lightning/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-7a094b53-c5ef-9efd-bb6d-2ffdd67ae03d]
Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.25it/s]


[0.8439337015151978]
0.8439337015151978
[[{'text': 'ist', 'confidence': 0.45241352915763855, 'severity': 'minor', 'start': 13, 'end': 17}, {'text': 'y-Abgeordneten völlig in der Gunst', 'confidence': 0.43448933959007263, 'severity': 'minor', 'start': 25, 'end': 59}]]
