In [1]:
!pwd

/auto/brno2/home/rahmang/xcomet/COMET


In [2]:
!nvidia-smi

Sun Mar 16 21:24:29 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A40                     Off |   00000000:01:00.0 Off |                    0 |
|  0%   42C    P0             80W /  300W |     327MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [1]:
import inspect
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from comet import download_model, load_from_checkpoint
from comet.models.multitask.unified_metric import UnifiedMetric
from comet.models.utils import Prediction, Target
class CustomXCOMET(UnifiedMetric):
    """Debugging subclass of COMET's ``UnifiedMetric`` (used here for XCOMET).

    It overrides ``prepare_sample``, ``decode`` and ``predict_step`` so that the
    intermediate values of inference can be inspected. It also adds a per-span
    ``"check severity"`` list that holds the severity predicted for every subword
    of the span.

    Set ``verbose_debug = False`` on the class or on an instance to silence the
    debug prints.
    """

    # When True, print intermediate values while preparing samples, predicting
    # and decoding error spans.
    verbose_debug: bool = True

    def _debug_print(self, *args) -> None:
        """Forward ``args`` to ``print`` only when ``verbose_debug`` is enabled."""
        if self.verbose_debug:
            print(*args)

    def prepare_sample(
        self, sample: List[Dict[str, Union[str, float]]], stage: str = "fit"
    ) -> Union[Tuple[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]:
        """Tokenizes input data and prepares targets for training.

        The "mt" segment is always encoded. "src" and "ref" are encoded only when
        they are present in the sample AND listed in ``hparams.input_segments``.

        Args:
            sample (List[Dict[str, Union[str, float]]]): Mini-batch. Every dict must
                have an "mt" key. "score" is required unless ``stage == "predict"``.
                "system" is copied into the targets when present.
            stage (str, optional): Model stage. "predict" returns only the model
                inputs; any other value also builds targets. Defaults to "fit".

        Returns:
            Union[Tuple[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]: Model
                inputs for "predict", otherwise ``(model inputs, targets)``.
        """
        if self.verbose_debug:
            # Report who called us. Walking one frame back is much cheaper than
            # inspect.stack(), which loads source context for every frame.
            frame = inspect.currentframe()
            caller = frame.f_back if frame is not None else None
            if caller is not None:
                print(
                    f"prepare_sample called by: {caller.f_code.co_name} "
                    f"(from {caller.f_code.co_filename})"
                )
            del frame, caller  # avoid keeping frames alive via reference cycles
            print("sample in the prepare_sample: ", sample)
            print("the stage is: ", stage)
            for k in sample[0]:
                print("K is: ", k)

        # Transpose list-of-dicts into dict-of-lists keyed by segment name.
        inputs = {k: [d[k] for d in sample] for k in sample[0]}
        self._debug_print(f'''inputs["mt"]: {inputs["mt"]}''')
        input_sequences = [
            self.encoder.prepare_sample(inputs["mt"], self.word_level, None),
        ]
        self._debug_print("input_sequences: ", input_sequences)

        src_input, ref_input = False, False
        if ("src" in inputs) and ("src" in self.hparams.input_segments):
            input_sequences.append(self.encoder.prepare_sample(inputs["src"]))
            src_input = True

        if ("ref" in inputs) and ("ref" in self.hparams.input_segments):
            input_sequences.append(self.encoder.prepare_sample(inputs["ref"]))
            ref_input = True

        unified_input = src_input and ref_input
        model_inputs = self.concat_inputs(input_sequences, unified_input)
        if stage == "predict":
            return model_inputs["inputs"]

        scores = [float(s) for s in inputs["score"]]
        targets = Target(score=torch.tensor(scores, dtype=torch.float))

        if "system" in inputs:
            targets["system"] = inputs["system"]

        if self.word_level:
            # Labels will be the same across all inputs because we are only
            # doing sequence tagging on the MT. We will only use the mask
            # corresponding to the MT segment.
            seq_len = model_inputs["mt_length"].max()
            targets["mt_length"] = model_inputs["mt_length"]
            targets["labels"] = model_inputs["inputs"][0]["label_ids"][:, :seq_len]

        return model_inputs["inputs"], targets

    def decode(
        self,
        subword_probs: torch.Tensor,
        input_ids: torch.Tensor,
        mt_offsets: torch.Tensor,
    ) -> List[Dict]:
        """Decode error spans from subword label probabilities.

        Each MT subword is assigned the label with the highest probability. When
        ``self.decoding_threshold`` is set, a non-"O" label is chosen only if the
        summed non-"O" probability exceeds the threshold. Consecutive "I-*"
        subwords form one span, and an "O" subword closes it. A span that is still
        open when the sequence ends is also emitted.

        Args:
            subword_probs (torch.Tensor): probabilities of each label for each
                subword, indexed as ``subword_probs[i][token]``.
            input_ids (torch.Tensor): input ids from the model, indexed as
                ``input_ids[i, token]``.
            mt_offsets (torch.Tensor): per-sentence sequence of ``(start, end)``
                character offsets of the MT subwords. Its length determines how
                many subwords of each sentence are decoded.

        Returns:
            List[Dict]: One list per sentence. Each list holds dicts with these keys:
                - "text"
                - "confidence": mean top-label probability over the span
                - "severity": label of the span's first subword
                - "start"
                - "end"
                - "check severity": label of every subword in the span
        """
        self._debug_print("====================== decode function ========================")
        self._debug_print("subword_probs: ", subword_probs)
        self._debug_print("input_ids: ", input_ids)
        self._debug_print("mt_offsets: ", mt_offsets)
        decoded_output = []
        self._debug_print("length of mt_offsets: ", len(mt_offsets))
        for i in range(len(mt_offsets)):
            self._debug_print("the value of i is before inner for loop: ", i)
            seq_len = len(mt_offsets[i])
            self._debug_print("seq_len: ", seq_len)
            error_spans, in_span, span = [], False, {}
            for token_id, probs, token_offset in zip(
                input_ids[i, :seq_len], subword_probs[i][:seq_len], mt_offsets[i]
            ):
                self._debug_print("the value of i is: ", i)
                self._debug_print(
                    "token_id :", token_id, ", probs: ", probs, ", token_offset: ", token_offset
                )
                if self.decoding_threshold:
                    if torch.sum(probs[1:]) > self.decoding_threshold:
                        probability, label_value = torch.topk(probs[1:], 1)
                        label_value += 1  # offset from removing label 0
                    else:
                        # Only label 0 is considered here; topk is used just to
                        # keep the same (probability, index) format.
                        probability, label_value = torch.topk(probs[0], 1)
                        self._debug_print("probs[0] =============", probs[0])
                else:
                    probability, label_value = torch.topk(probs, 1)

                # Some torch versions' topk returns a shape-1 tensor with only
                # one item inside, others a 0-d tensor.
                label_value = (
                    label_value.item()
                    if label_value.dim() < 1
                    else label_value[0].item()
                )
                label = self.label_encoder.ids_to_label.get(label_value)
                self._debug_print("===================================================")
                self._debug_print("label: ", label)

                # Label set: O plus "I-<severity>" labels.
                # Begin of annotation span.
                if label.startswith("I") and not in_span:
                    in_span = True
                    span["tokens"] = [token_id]
                    span["severity"] = label.split("-")[1]
                    span["offset"] = list(token_offset)
                    span["confidence"] = [probability]
                    span["check severity"] = [label.split("-")[1]]

                # Inside an annotation span.
                elif label.startswith("I") and in_span:
                    span["tokens"].append(token_id)
                    span["confidence"].append(probability)
                    # Extend the span end. max() keeps a zero-width special-token
                    # offset such as (0, 0) from moving the end backwards.
                    span["offset"][1] = max(span["offset"][1], token_offset[1])
                    span["check severity"].append(label.split("-")[1])

                # Annotation span finished.
                elif label == "O" and in_span:
                    error_spans.append(span)
                    in_span, span = False, {}

            # Without this, a span that reaches the last subword would be dropped.
            if in_span:
                error_spans.append(span)

            sentence_output = []
            for span in error_spans:
                self._debug_print("span[tokens]: ", span["tokens"])
                sentence_output.append(
                    {
                        "text": self.encoder.tokenizer.decode(span["tokens"]),
                        "confidence": torch.concat(span["confidence"]).mean().item(),
                        "severity": span["severity"],
                        "start": span["offset"][0],
                        "end": span["offset"][1],
                        "check severity": span["check severity"],
                    }
                )
            decoded_output.append(sentence_output)
        return decoded_output

    @staticmethod
    def _mt_seq_len(batch) -> torch.Tensor:
        """Longest MT length in the batch.

        The length counts positions of ``batch[0]["label_ids"]`` that are not -1.
        """
        mt_mask = batch[0]["label_ids"] != -1
        return mt_mask.sum(dim=1).max()

    def predict_step(
        self,
        batch: Dict[str, torch.Tensor],
        batch_idx: Optional[int] = None,
        dataloader_idx: Optional[int] = None,
    ) -> Prediction:
        """PyTorch Lightning predict_step.

        With three input sequences, the sentence score is the mean of the three
        forward passes. In word-level mode, the subword label probabilities are
        combined using ``self.input_weights_spans``. Otherwise only ``batch[0]`` is
        used.

        Args:
            batch (Dict[str, torch.Tensor]): The output of your prepare_sample function
            batch_idx (Optional[int], optional): Integer displaying which batch this is
                Defaults to None.
            dataloader_idx (Optional[int], optional): Integer displaying which
                dataloader this is. Defaults to None.

        Returns:
            Prediction: Model Prediction. In word-level mode,
                ``metadata["error_spans"]`` holds the output of ``decode``.
        """
        if len(batch) == 3:
            self._debug_print("++++++++++++++++++++++inside predict_step function+++++++++++++++++++")
            self._debug_print("batch: ", batch)
            self._debug_print("i am inside when len of the batch is 3")
            predictions = [self.forward(**input_seq) for input_seq in batch]
            avg_scores = torch.stack([pred.score for pred in predictions], dim=0).mean(dim=0)
            batch_prediction = Prediction(
                scores=avg_scores,
                metadata=Prediction(
                    src_scores=predictions[0].score,
                    ref_scores=predictions[1].score,
                    unified_scores=predictions[2].score,
                ),
            )
            if self.word_level:
                seq_len = self._mt_seq_len(batch)
                # Weighted sum of the per-input label distributions over MT subwords.
                weighted_probs = [
                    nn.functional.softmax(o.logits, dim=2)[:, :seq_len, :] * w
                    for w, o in zip(self.input_weights_spans, predictions)
                ]
                subword_probs = torch.sum(torch.stack(weighted_probs), dim=0)
                self._debug_print("subword probs :", subword_probs)
                # TODO: word-level aggregation of subword_probs (the earlier
                # commented-out draft was removed from this method).
                batch_prediction.metadata["error_spans"] = self.decode(
                    subword_probs, batch[0]["input_ids"], batch[0]["mt_offsets"]
                )
        else:
            self._debug_print("i am inside when len of the batch is not 3")
            model_output = self.forward(**batch[0])
            batch_prediction = Prediction(scores=model_output.score)
            if self.word_level:
                seq_len = self._mt_seq_len(batch)
                subword_probs = nn.functional.softmax(model_output.logits, dim=2)[:, :seq_len, :]
                error_spans = self.decode(
                    subword_probs, batch[0]["input_ids"], batch[0]["mt_offsets"]
                )
                batch_prediction = Prediction(
                    scores=model_output.score,
                    metadata=Prediction(error_spans=error_spans),
                )
        return batch_prediction

# Load the XCOMET-XL checkpoint into the debugging subclass.
# strict=False: the checkpoint contains 'encoder.model.embeddings.position_ids',
# which is not in this model's state dict (see the loading warning in the output).
CHECKPOINT_PATH = "/storage/brno2/home/rahmang/xcomet/downloadedxcomet/models--Unbabel--XCOMET-XL/snapshots/50d428488e021205a775d5fab7aacd9502b58e64/checkpoints/model.ckpt"

model = CustomXCOMET.load_from_checkpoint(CHECKPOINT_PATH, strict=False)

data = [
    {
        "src": "Boris Johnson teeters on edge of favour with Tory MPs",
        "mt": "Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst",
        "ref": "Boris Johnsons Beliebtheit bei Tory-MPs steht auf der Kippe",
    }
]

# Use the GPU when one is visible; fall back to CPU instead of failing.
N_GPUS = 1 if torch.cuda.is_available() else 0
model_output = model.predict(data, batch_size=8, gpus=N_GPUS)

# Segment-level scores
print(model_output.scores)

# System-level score
print(model_output.system_score)

# Score explanation (error spans)
print(model_output.metadata.error_spans)

  from .autonotebook import tqdm as notebook_tqdm
Encoder model frozen.
/storage/brno2/home/rahmang/envs/xcomet/lib/python3.11/site-packages/pytorch_lightning/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-457cb514-0138-7a8c-26d6-18b5a663d37f]
Predicting: 0it [00:00, ?it/s]

prepare_sample called by: prepare_for_inference (from /auto/brno2/home/rahmang/xcomet/COMET/comet/models/base.py)
sample in the prepare_sample:  [{'src': 'Boris Johnson teeters on edge of favour with Tory MPs', 'mt': 'Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst', 'ref': 'Boris Johnsons Beliebtheit bei Tory-MPs steht auf der Kippe'}]
the stage is:  predict
K is:  src
K is:  mt
K is:  ref
inputs["mt"]: ['Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst']
sample:  ['Boris Johnson ist bei Tory-Abgeordneten völlig in der Gunst']
encoder_input:  {'input_ids': [[0, 67151, 59520, 443, 1079, 6653, 53, 9, 241832, 19, 86454, 23, 122, 25706, 271, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'offset_mapping': [[(0, 0), (0, 5), (5, 13), (13, 17), (17, 21), (21, 25), (25, 26), (26, 27), (27, 38), (38, 39), (39, 46), (46, 49), (49, 53), (53, 57), (57, 59), (0, 0)]]}
len(sample):  1
the value of i is:  0
encoder_input[i],   Encoding(num_tokens=16

Predicting DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

++++++++++++++++++++++inside predict_step function+++++++++++++++++++
batch:  ({'input_ids': tensor([[     0,  67151,  59520,    443,   1079,   6653,     53,      9, 241832,
             19,  86454,     23,    122,  25706,    271,      2,      2,  67151,
          59520,  32686,  23962,     98, 121303,    111,   1238, 141775,    678,
           6653,     53,  10646,      7,      2]], device='cuda:0'), 'attention_mask': tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True]], device='cuda:0'), 'label_ids': tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, -1, -1,
         -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]],
       device='cuda:0'), 'mt_offsets': [[(0, 0), (0, 5), (5, 13), (13, 17), (17, 21), (21, 25), (25, 26), (26, 27), (27, 38), (38, 39), (39, 46), (46, 49), (49, 53), (53, 57), (

Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.53it/s]

subword probs : tensor([[[0.3554, 0.2145, 0.2102, 0.2199],
         [0.5846, 0.1495, 0.1520, 0.1139],
         [0.6238, 0.1366, 0.1445, 0.0952],
         [0.1280, 0.1513, 0.2221, 0.4986],
         [0.3183, 0.1446, 0.2166, 0.3205],
         [0.5808, 0.1740, 0.1477, 0.0975],
         [0.4583, 0.2637, 0.2076, 0.0703],
         [0.4516, 0.2554, 0.2178, 0.0751],
         [0.2695, 0.2651, 0.2737, 0.1917],
         [0.2902, 0.2185, 0.2517, 0.2396],
         [0.1104, 0.1495, 0.2237, 0.5164],
         [0.1130, 0.1312, 0.1900, 0.5659],
         [0.1128, 0.1399, 0.2107, 0.5366],
         [0.1422, 0.1559, 0.2223, 0.4797],
         [0.1456, 0.1444, 0.1989, 0.5111],
         [0.5671, 0.1140, 0.1744, 0.1444]]], device='cuda:0')
subword_probs:  tensor([[[0.3554, 0.2145, 0.2102, 0.2199],
         [0.5846, 0.1495, 0.1520, 0.1139],
         [0.6238, 0.1366, 0.1445, 0.0952],
         [0.1280, 0.1513, 0.2221, 0.4986],
         [0.3183, 0.1446, 0.2166, 0.3205],
         [0.5808, 0.1740, 0.1477, 0.0975],
   




[0.6365968585014343]
0.6365968585014343
[[{'text': 'ist bei', 'confidence': 0.4095497727394104, 'severity': 'critical', 'start': 13, 'end': 21, 'check severity': ['critical', 'critical']}, {'text': 'Abgeordnete', 'confidence': 0.2736634612083435, 'severity': 'major', 'start': 27, 'end': 38, 'check severity': ['major']}, {'text': 'völlig in der Gunst', 'confidence': 0.5219249129295349, 'severity': 'critical', 'start': 39, 'end': 59, 'check severity': ['critical', 'critical', 'critical', 'critical', 'critical']}]]
