In [1]:
import pandas as pd

# Training examples; later cells read the `id`, `content` and `trigger_words` columns.
df = pd.read_parquet('train.parquet')
# Cross-validation assignment keyed by `id`; presumably supplies the `fold` column used later — TODO confirm.
cv = pd.read_csv('cv_split.csv')
# Left join: every training row is kept even if it has no entry in the CV split file.
df = df.merge(cv, on='id', how='left')

# Test examples; later cells only read `content` from this frame.
df_test = pd.read_csv('test.csv')

In [2]:
import spacy

from spacy.training.iob_utils import biluo_to_iob, doc_to_biluo_tags
from tqdm.autonotebook import tqdm
# Register the tqdm integration with pandas (`progress_apply` is used in a later cell).
tqdm.pandas()

# Rows whose annotation is None get an empty list so every row holds an iterable of spans.
df.trigger_words = df.trigger_words.apply(lambda x: [] if x is None else x)

  from tqdm.autonotebook import tqdm


In [3]:
# Local checkpoint directory used for the tokenizer and for the bidirectional backbone.
PRETRAINED_MODEL = '/home/abazdyrev/pretrained_models/gemma2-27b-it-unmasked'
# Token budget per document: the tokenizer is called with truncation=True and this max_length.
MAX_LEN = 2560

In [4]:
from transformers import AutoTokenizer
import pandas as pd
# The tokenizer is loaded from the same checkpoint directory as the backbone model.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

In [5]:
from transformers import AutoTokenizer
import pandas as pd

# Load the tokenizer.
# NOTE(review): this repeats the identical load performed in the previous cell.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

def convert_to_seq_labeling(text, tokenizer, trigger_spans=None):
    """Tokenize ``text`` and attach a binary label for every token.

    Args:
        text: Raw document string handed to the tokenizer.
        tokenizer: Callable that accepts ``return_offsets_mapping=True`` and
            returns a mapping containing ``"input_ids"`` and
            ``"offset_mapping"`` (one ``(start, end)`` pair per token).
        trigger_spans: Optional iterable of ``(start, end)`` character spans.
            ``None`` (e.g. for unlabeled test data) yields all-zero labels.

    Returns:
        The tokenizer output with an added ``"labels"`` entry: a list with one
        int per token, 1 if the token's character range overlaps any span and
        0 otherwise. Input is truncated to ``MAX_LEN`` tokens.
    """
    tokenized_output = tokenizer(
        text, return_offsets_mapping=True, add_special_tokens=True, max_length=MAX_LEN,
        truncation=True, padding=False
    )
    offsets = tokenized_output["offset_mapping"]

    # Every token starts as 'O' (0).
    labels = [0] * len(tokenized_output["input_ids"])

    if trigger_spans is not None:
        for start, end in trigger_spans:
            for i, (tok_start, tok_end) in enumerate(offsets):
                # (0, 0) offsets cover no text (presumably special tokens such as BOS).
                if tok_start == 0 and tok_end == 0:
                    continue
                # Half-open interval overlap between the token and the trigger span.
                if tok_start < end and tok_end > start:
                    labels[i] = 1

    tokenized_output['labels'] = labels
    return tokenized_output

In [6]:
from tqdm.autonotebook import tqdm

tqdm.pandas()

def _attach_seq_labels(frame, use_trigger_spans):
    """Tokenize every row of ``frame`` in place and expand the result into columns.

    Adds a ``seq_labels`` column holding the output of ``convert_to_seq_labeling``
    for each row, then one column per key of that output (keys are taken from the
    first row). When ``use_trigger_spans`` is False no spans are passed, so the
    generated ``labels`` are all zeros.
    """
    frame['seq_labels'] = frame.progress_apply(
        lambda row: convert_to_seq_labeling(
            row['content'], tokenizer, row['trigger_words'] if use_trigger_spans else None
        ),
        axis=1,
    )
    for column in frame.seq_labels.iloc[0].keys():
        frame[column] = frame.seq_labels.apply(lambda x: x.get(column))


# Training data carries gold spans; the test data is tokenized without them.
_attach_seq_labels(df, use_trigger_spans=True)
_attach_seq_labels(df_test, use_trigger_spans=False)

  0%|          | 0/3822 [00:00<?, ?it/s]

  0%|          | 0/5735 [00:00<?, ?it/s]

In [7]:
from datasets import Dataset
import numpy as np

# Rows with fold == 4 form the validation split; all other rows are used for training.
df['is_valid'] = df.fold == 4

# Tokenizer outputs (including `labels`) plus the raw text and the gold spans.
columns = list(df.seq_labels.iloc[0].keys()) + ['content', 'trigger_words']
ds_train = Dataset.from_pandas(df[df.is_valid==0][columns].reset_index(drop=True))
ds_valid = Dataset.from_pandas(df[df.is_valid==1][columns].reset_index(drop=True))

# For the test set `trigger_words` is not selected; key names are still taken from the training frame.
columns = list(df.seq_labels.iloc[0].keys()) + ['content']
ds_test = Dataset.from_pandas(df_test[columns].reset_index(drop=True))

In [8]:
import torch

from packaging import version
import importlib.metadata

from transformers import Gemma2Model, Gemma2ForCausalLM, Gemma2PreTrainedModel, Gemma2Config
from transformers.models.gemma2.modeling_gemma2 import (
    Gemma2DecoderLayer,
    Gemma2Attention,
    Gemma2FlashAttention2,
    Gemma2SdpaAttention,
    Gemma2MLP,
    Gemma2RMSNorm,
)

from torch import nn
from transformers.utils import logging

from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.utils.import_utils import _is_package_available
from transformers.cache_utils import Cache, StaticCache
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.gemma2.modeling_gemma2 import *

from peft import PeftModel

logger = logging.get_logger(__name__)


def is_transformers_attn_greater_or_equal_4_41():
    """Report whether the installed ``transformers`` release is 4.41.0 or newer.

    Returns ``False`` when ``_is_package_available`` reports that
    ``transformers`` is missing.
    """
    if _is_package_available("transformers"):
        installed = version.parse(importlib.metadata.version("transformers"))
        return installed >= version.parse("4.41.0")
    return False


class ModifiedGemma2Attention(Gemma2Attention):
    """``Gemma2Attention`` variant whose ``is_causal`` flag is forced to False."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Overrides whatever the parent constructor set; presumably read by the attention forward — TODO confirm.
        self.is_causal = False


class ModifiedGemma2FlashAttention2(Gemma2FlashAttention2):
    """``Gemma2FlashAttention2`` variant whose ``is_causal`` flag is forced to False."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Overrides whatever the parent constructor set; presumably read by the attention forward — TODO confirm.
        self.is_causal = False


class ModifiedGemma2SdpaAttention(Gemma2SdpaAttention):
    """``Gemma2SdpaAttention`` variant whose ``is_causal`` flag is forced to False."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Overrides whatever the parent constructor set; presumably read by the attention forward — TODO confirm.
        self.is_causal = False


# Maps `config._attn_implementation` to the matching non-causal attention class;
# ModifiedGemma2DecoderLayer looks its attention module up here.
GEMMA2_ATTENTION_CLASSES = {
    "eager": ModifiedGemma2Attention,
    "flash_attention_2": ModifiedGemma2FlashAttention2,
    "sdpa": ModifiedGemma2SdpaAttention,
}


class ModifiedGemma2DecoderLayer(Gemma2DecoderLayer):
    """Decoder layer that uses the non-causal attention classes from ``GEMMA2_ATTENTION_CLASSES``.

    Only construction is overridden; all other behaviour is inherited from
    ``Gemma2DecoderLayer``.
    """

    def __init__(self, config: Gemma2Config, layer_idx: int):
        """Build the layer's sub-modules for position ``layer_idx`` in the stack."""
        # Calls nn.Module.__init__ directly, skipping Gemma2DecoderLayer.__init__,
        # so every sub-module is created here instead of by the parent.
        nn.Module.__init__(self)
        self.config = config
        self.hidden_size = config.hidden_size

        # Attention implementation is selected by the config's `_attn_implementation` key.
        self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx
        )

        self.mlp = Gemma2MLP(config)
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Even-indexed layers are flagged as sliding-window layers.
        self.is_sliding = not bool(layer_idx % 2)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.sliding_window = config.sliding_window



class Gemma2BiModel(Gemma2Model):
    """Gemma2 backbone built from ``ModifiedGemma2DecoderLayer`` blocks with a non-causal mask.

    ``_update_causal_mask`` is overridden so that the mask it produces only
    hides padded key positions instead of future positions.
    """

    _no_split_modules = ["ModifiedGemma2DecoderLayer"]

    def __init__(self, config: Gemma2Config):
        """Create embeddings, decoder layers and final norm without calling ``Gemma2Model.__init__``."""
        Gemma2PreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                ModifiedGemma2DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the additive attention mask without any causal (triangular) component.

        Returns:
            * flash-attention-2: the 2D ``attention_mask`` unchanged if it contains
              a zero, otherwise ``None``.
            * a 4D ``attention_mask``: returned as-is after checking that its
              maximum is 0 (``ValueError`` otherwise).
            * otherwise: a mask expanded to ``(batch, 1, sequence_length, target_length)``
              that is 0 everywhere except key positions where ``attention_mask``
              is 0, which are set to the most negative value of the input dtype.

        ``output_attentions`` is accepted but not used in this method.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # Width of the mask: cache capacity when a cache is supplied, else the mask/input length.
        if past_key_values is not None:
            target_length = past_key_values.get_max_length()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            causal_mask = attention_mask
        else:
            # All zeros: no query/key pair is masked (the causal construction is kept below, commented out).
            causal_mask = torch.zeros(
                (sequence_length, target_length), dtype=dtype, device=device
            )
            # causal_mask = torch.full(
            #     (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            # )
            # if sequence_length != 1:
            #     causal_mask = torch.triu(causal_mask, diagonal=1)
            # NOTE(review): the mask is all zeros here, so this multiplication leaves it unchanged.
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # The sum is 0 exactly where both the (zero) mask and attention_mask are 0, i.e. padded keys.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask


class Gemma2BiForMNTP(Gemma2ForCausalLM):
    """``Gemma2ForCausalLM`` whose backbone is replaced by ``Gemma2BiModel``.

    ``forward`` is inherited unchanged; the class only swaps the backbone and
    adds small helpers for wrapping that backbone with PEFT.
    """

    def __init__(self, config):
        """Build the bidirectional backbone and a bias-free LM head, skipping the parent constructor."""
        Gemma2PreTrainedModel.__init__(self, config)
        self.model = Gemma2BiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # getter for PEFT model
    def get_model_for_peft(self):
        """Return the backbone (``self.model``)."""
        return self.model

    # setter for PEFT model
    def set_model_for_peft(self, model: PeftModel):
        """Replace the backbone with ``model``."""
        self.model = model

    # save the PEFT model
    def save_peft_model(self, path):
        """Call ``save_pretrained(path)`` on the current backbone."""
        self.model.save_pretrained(path)


class Gemma2BiMLM(Gemma2ForCausalLM):
    """Masked-LM style head on top of ``Gemma2BiModel``.

    Unlike the inherited causal-LM forward, the loss here compares the logits at
    each position with the label at the same position (no shifting).
    """

    def __init__(self, config):
        """Build the bidirectional backbone and a bias-free LM head, skipping the parent constructor."""
        Gemma2PreTrainedModel.__init__(self, config)
        self.model = Gemma2BiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone and LM head; optionally compute an unshifted cross-entropy loss.

        Only ``input_ids``, ``attention_mask``, ``position_ids``,
        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` are
        forwarded to the backbone.

        NOTE(review): ``past_key_values``, ``inputs_embeds``, ``use_cache``,
        ``cache_position``, ``logits_to_keep`` and ``loss_kwargs`` are accepted
        but never used in this method.
        NOTE(review): the return annotation says ``CausalLMOutputWithPast`` but
        the dict-style return value is a ``MaskedLMOutput``.

        Returns:
            ``MaskedLMOutput`` with ``loss`` (``None`` when ``labels`` is not
            given), ``logits``, ``hidden_states`` and ``attentions``; or a tuple
            when ``return_dict`` is falsy.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)
        masked_lm_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(prediction_scores.device)
            loss_fct = CrossEntropyLoss()
            # Logits and labels are flattened position-by-position; there is no next-token shift.
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # getter for PEFT model
    def get_model_for_peft(self):
        """Return the backbone (``self.model``)."""
        return self.model

    # setter for PEFT model
    def set_model_for_peft(self, model: PeftModel):
        """Replace the backbone with ``model``."""
        self.model = model

    # save the PEFT model
    def save_peft_model(self, path):
        """Call ``save_pretrained(path)`` on the current backbone."""
        self.model.save_pretrained(path)


In [9]:
from transformers import XLMRobertaConfig, XLMRobertaTokenizer, XLMRobertaForMaskedLM, AutoModelForMaskedLM, Gemma2ForCausalLM
from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification


import torch
from transformers import Gemma2ForSequenceClassification, BitsAndBytesConfig
from peft import get_peft_config, prepare_model_for_kbit_training, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

# 4-bit NF4 quantization settings: double quantization off, bfloat16 compute dtype.
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=False,
   bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
# Load the checkpoint into the bidirectional backbone class, quantized with the config above.
backbone_model = Gemma2BiModel.from_pretrained(
    PRETRAINED_MODEL, torch_dtype=torch.bfloat16,
    quantization_config=nf4_config
)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

In [10]:
import torch
from transformers import Gemma2ForTokenClassification, LlamaForTokenClassification, BitsAndBytesConfig
from peft import get_peft_config, prepare_model_for_kbit_training, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType



# Label maps passed to the classification model; both map the two class ids to themselves.
id2label = {0: 0, 1: 1}
label2id = {0: 0, 1: 1}

from peft import get_peft_config, prepare_model_for_kbit_training, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType


# Same quantization settings as used for the backbone load (redefined here).
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=False,
   bnb_4bit_compute_dtype=torch.bfloat16
)

# Token-classification wrapper; its `.model` attribute is overwritten with `backbone_model` in the next cell.
# NOTE(review): this path is hard-coded and differs from PRETRAINED_MODEL.
model = Gemma2ForTokenClassification.from_pretrained(
    '/home/abazdyrev/pretrained_models/gemma-2-27b-it',
    id2label=id2label,
    label2id=label2id,
    torch_dtype=torch.bfloat16,
    quantization_config=nf4_config
)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

Some weights of Gemma2ForTokenClassification were not initialized from the model checkpoint at /home/abazdyrev/pretrained_models/gemma-2-27b-it and are newly initialized: ['score.bias', 'score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Swap the wrapper's backbone for the bidirectional Gemma2BiModel loaded earlier.
model.model = backbone_model

In [12]:
# Ask PyTorch to release cached GPU memory after the backbone swap.
torch.cuda.empty_cache()

In [13]:
# Display the module tree to check which backbone and attention classes are in place.
model

Gemma2ForTokenClassification(
  (model): Gemma2BiModel(
    (embed_tokens): Embedding(256000, 4608, padding_idx=0)
    (layers): ModuleList(
      (0-45): 46 x ModifiedGemma2DecoderLayer(
        (self_attn): ModifiedGemma2Attention(
          (q_proj): Linear4bit(in_features=4608, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4608, out_features=2048, bias=False)
          (v_proj): Linear4bit(in_features=4608, out_features=2048, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4608, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear4bit(in_features=4608, out_features=36864, bias=False)
          (up_proj): Linear4bit(in_features=4608, out_features=36864, bias=False)
          (down_proj): Linear4bit(in_features=36864, out_features=4608, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((4608,), eps=1e-06)
   

In [14]:
# Prepare the quantized model for training before LoRA adapters are attached.
model = prepare_model_for_kbit_training(model)

# LoRA adapters on all attention projections and all MLP projections.
lora_config = LoraConfig(
    r=64,  # the dimension of the low-rank matrices
    lora_alpha=32, # scaling factor for LoRA activations vs pre-trained weight activations
    lora_dropout=0.05, 
    bias='none',
    inference_mode=False,
    task_type=TaskType.TOKEN_CLS,
    target_modules=['o_proj', 'v_proj', "q_proj", "k_proj", "gate_proj", "down_proj", "up_proj"]
) 

model = get_peft_model(model, lora_config)
# Trainable Parameters
model.print_trainable_parameters()

trainable params: 456,729,602 || all params: 27,683,867,140 || trainable%: 1.6498


In [15]:
import os
import random
import numpy as np
import torch

def set_seeds(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also exports ``PYTHONHASHSEED`` and, when CUDA is available, seeds the
    CUDA generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
        

# Fix all RNG seeds before the training arguments and trainer are created.
set_seeds(seed=42)

In [16]:
import math
from transformers import Trainer, pipeline, TrainingArguments
from typing import Any
from transformers.trainer_utils import EvalPrediction


# Training configuration: evaluation and checkpointing both happen every 200 steps.
train_args = TrainingArguments(
    output_dir='model_checkpoints_gemma2_27b_mlm_finetune',
    logging_dir='./model_logs_gemma2_27b_mlm_finetune',
    learning_rate=2e-5,
    weight_decay=0.01,
    lr_scheduler_type='cosine',
    warmup_ratio=0.0,
    num_train_epochs=3,
    # 2 sequences per device x 4 accumulation steps per optimizer update.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    bf16=True,
    report_to="wandb",
    optim='adamw_8bit',
    eval_strategy='steps',
    save_strategy="steps",
    eval_steps=200,
    logging_steps=20,
    save_steps=200,
    save_total_limit=10,
    # Presumably matches the `f1` key returned by the custom compute_metrics (prefixed with `eval_`) — TODO confirm.
    metric_for_best_model='eval_f1',
    greater_is_better=True,
    load_best_model_at_end=True,
)

In [17]:
import wandb

# Initialize with team/entity
# The run name is what identifies this experiment inside the W&B project.
wandb.init(
    project="unlp-span-ident-task",
    entity="bazdyrev99-igor-sikorsky-kyiv-polytechnic-institute", 
    name='gemma2-27b-mlm-finetune'
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbazdyrev99[0m ([33mIASA-BA-Diploma-Ivan-Bashtovyi[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [18]:
from itertools import chain

# Share of tokens labeled 1 across all rows of `df` (training and validation folds together).
positive_class_balance = pd.Series(list(chain(*df.labels.tolist()))).mean()
positive_class_balance

0.22925695654588976

In [19]:
import math
from transformers import Trainer, pipeline, TrainingArguments
from typing import Any
from tqdm.autonotebook import tqdm
from transformers.trainer_utils import EvalPrediction

def extract_chars_from_spans(spans):
    """Expand ``(start, end)`` spans into the set of character positions they cover.

    Each span contributes ``start, start + 1, ..., end - 1``; positions shared
    by overlapping spans appear once.
    """
    return {position for start, end in spans for position in range(start, end)}

class SpanEvaluationTrainer(Trainer):
    """Trainer that scores token-classification output with character-level span F1.

    The decision threshold on the positive-class probability is not fixed: for
    every evaluation it is chosen so that the share of predicted-positive tokens
    is as close as possible to ``desired_positive_ratio``.
    """

    def __init__(
        self,
        model: Any = None,
        args: TrainingArguments = None,
        data_collator: Any = None,
        train_dataset: Any = None,
        eval_dataset: Any = None,
        tokenizer: Any = None,
        desired_positive_ratio: float = 0.25,
        **kwargs,
    ):
        """
        Initialize the Trainer with our custom compute_metrics.

        Args:
            desired_positive_ratio: Target share of tokens predicted as
                positive, used by the threshold search.
            All other arguments are forwarded to ``Trainer.__init__``.
        """
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            compute_metrics=self.compute_metrics,  # assign our custom compute_metrics
            **kwargs,
        )
        self.desired_positive_ratio = desired_positive_ratio

    def _calculate_inner_metric(self, gt_spans_all, pred_spans_all):
        """Compute micro-averaged character-level precision, recall and F1.

        Args:
            gt_spans_all: Per-sample gold spans; each entry is an iterable of
                ``(start, end)`` pairs or its string literal representation.
            pred_spans_all: Per-sample predicted spans as ``(start, end)`` pairs.

        Returns:
            Dict with ``precision``, ``recall`` and ``f1``.
        """
        import ast  # local import: this cell does not import ast at module level

        total_true_chars = 0
        total_pred_chars = 0
        total_overlap_chars = 0
        for true_spans, pred_spans in zip(gt_spans_all, pred_spans_all):
            if isinstance(true_spans, str):
                # literal_eval only parses Python literals, so a malformed or
                # hostile string cannot execute code (unlike eval).
                try:
                    true_spans = ast.literal_eval(true_spans)
                except (ValueError, SyntaxError, TypeError):
                    true_spans = []

            # Convert spans to sets of character indices.
            true_chars = extract_chars_from_spans(true_spans)
            pred_chars = extract_chars_from_spans(pred_spans)

            total_true_chars += len(true_chars)
            total_pred_chars += len(pred_chars)
            total_overlap_chars += len(true_chars.intersection(pred_chars))

        # Compute precision, recall, and F1.
        precision = total_overlap_chars / total_pred_chars if total_pred_chars > 0 else 0
        recall = total_overlap_chars / total_true_chars if total_true_chars > 0 else 0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        return {
            "precision": precision,
            "recall": recall,
            "f1": f1
        }

    def _find_optimal_threshold(self, probabilities, labels):
        """Finds the threshold that achieves the desired positive class balance.

        Args:
            probabilities: Array indexed as ``[sample, token, class]``.
            labels: Array indexed as ``[sample, token]``, same leading shape as
                ``probabilities``; positions equal to -100 are ignored.

        Returns:
            The value from ``np.linspace(0.01, 0.99, 100)`` whose predicted
            positive ratio is closest to ``self.desired_positive_ratio``
            (the first such value on ties).
        """
        positive_probs = np.asarray(probabilities)[:, :, 1]
        # Mask of real (non-ignored) tokens, computed once instead of per threshold.
        valid = np.asarray(labels) != -100
        total = int(valid.sum())

        optimal_th = 0.5  # Default starting point
        best_diff = float("inf")
        for thold in np.linspace(0.01, 0.99, num=100):
            total_pos = int(((positive_probs >= thold) & valid).sum())
            positive_ratio = total_pos / total if total > 0 else 0

            diff = abs(positive_ratio - self.desired_positive_ratio)
            if diff < best_diff:
                best_diff = diff
                optimal_th = thold

        return optimal_th

    @staticmethod
    def _spans_from_token_labels(token_labels, offsets):
        """Merge runs of consecutive positive tokens into ``(start, end)`` character spans."""
        spans = []
        current_span = None
        for token_label, span in zip(token_labels, offsets):
            if token_label == 1:
                if current_span is None:
                    current_span = [span[0], span[1]]  # Start a new span
                else:
                    current_span[1] = span[1]  # Extend the span to include the current token
            elif current_span is not None:
                spans.append(tuple(current_span))  # Save completed span
                current_span = None

        # If the last token was part of a span, save it
        if current_span is not None:
            spans.append(tuple(current_span))
        return spans

    def compute_metrics(self, eval_pred: EvalPrediction) -> dict:
        """Turn logits into character spans and score them against the gold spans.

        NOTE(review): offsets and gold spans are always read from
        ``self.eval_dataset``, so the result is only meaningful when
        ``eval_pred`` was produced from that same dataset in the same order.

        Returns:
            Dict with ``precision``, ``recall``, ``f1`` and the threshold used
            (``thold``).
        """
        logits, labels = eval_pred
        probabilities = torch.softmax(torch.tensor(logits), dim=-1).cpu().numpy()

        # Read the dataset columns once; they do not depend on the threshold.
        eval_dataset = self.eval_dataset
        offset_mappings = eval_dataset['offset_mapping']
        gt_spans_all = eval_dataset['trigger_words']

        #thresholds = np.linspace(0.1, 0.5, num=41)
        thresholds = [self._find_optimal_threshold(probabilities, labels)]
        best_f1 = -1
        best_metrics = None

        for thold in tqdm(thresholds):
            # Apply thresholding instead of argmax
            predictions = (probabilities[:, :, 1] >= thold).astype(int)

            # Drop positions whose label is -100.
            true_predictions = [
                [p for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]

            pred_spans_all = [
                self._spans_from_token_labels(pred, offsets)
                for pred, offsets in zip(true_predictions, offset_mappings)
            ]

            # Keep the metrics of the best-scoring threshold.
            current_metrics = self._calculate_inner_metric(gt_spans_all, pred_spans_all)
            if current_metrics['f1'] >= best_f1:
                best_f1 = current_metrics['f1']
                best_metrics = current_metrics
                best_metrics['thold'] = thold

        return best_metrics

In [20]:
from transformers import DataCollatorForTokenClassification

# Collator that batches the variable-length tokenized examples for token classification.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# The threshold search targets the positive-token share measured on `df`.
trainer = SpanEvaluationTrainer(
    model=model,
    args=train_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    data_collator=data_collator,
    tokenizer=tokenizer,
    desired_positive_ratio=positive_class_balance
)

  super().__init__(


In [None]:
# Fine-tune the LoRA adapters; evaluation and checkpointing follow `train_args`.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss,Validation Loss


In [21]:
import torch
from transformers import Gemma2ForTokenClassification, BitsAndBytesConfig
from peft import PeftModel, get_peft_config, prepare_model_for_kbit_training, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType



# Same 4-bit NF4 quantization settings as used for training.
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=False,
   bnb_4bit_compute_dtype=torch.bfloat16
)

# NOTE(review): unlike the training setup, the backbone is not replaced with Gemma2BiModel here.
model = Gemma2ForTokenClassification.from_pretrained(
    PRETRAINED_MODEL,
    id2label=id2label,
    label2id=label2id,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    quantization_config=nf4_config
)

model = prepare_model_for_kbit_training(model)

# NOTE(review): this checkpoint directory differs from the `output_dir` set in `train_args`.
model = PeftModel.from_pretrained(model, "./model_checkpoints_gemma2_27b_binary/checkpoint-600/")
# Trainable Parameters
model.print_trainable_parameters()

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

Some weights of Gemma2ForTokenClassification were not initialized from the model checkpoint at /home/abazdyrev/pretrained_models/gemma-2-27b-it/ and are newly initialized: ['score.bias', 'score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 9,218 || all params: 27,683,867,140 || trainable%: 0.0000


In [22]:
# Point the existing trainer at the reloaded model, moved to GPU and switched to eval mode.
trainer.model = model.cuda().eval()

In [22]:
# Run inference on the validation split (fold 4).
valid_preds = trainer.predict(ds_valid)

  0%|          | 0/1 [00:00<?, ?it/s]

In [23]:
# Shape check: logits carry a trailing class axis; labels share the first two axes.
valid_preds.predictions.shape, valid_preds.label_ids.shape

((764, 1440, 2), (764, 1440))

In [24]:
# Score the validation predictions at the balance-matching threshold.
trainer.compute_metrics((valid_preds.predictions, valid_preds.label_ids))

  0%|          | 0/1 [00:00<?, ?it/s]

{'precision': 0.6380175550535446,
 'recall': 0.6279042139741697,
 'f1': 0.6329204872040508,
 'thold': 0.5049494949494949}

In [None]:
# Run inference on the test set.
test_preds = trainer.predict(ds_test)

In [26]:
# Convert test logits to class probabilities along the last axis.
test_probabilities = torch.softmax(torch.tensor(test_preds.predictions), dim=-1).cpu().numpy()

In [27]:
# Threshold at which the test predictions match the desired positive-token ratio.
trainer._find_optimal_threshold(test_probabilities, test_preds.label_ids)

0.45545454545454545

In [28]:
# Mean of the validation (0.50494) and test (0.4554545) balance-matching thresholds, lowered by a hand-picked 0.12.
# NOTE(review): both values are pasted from earlier cell outputs rather than read from variables.
final_th = (0.50494+0.4554545)/2 - 0.12
final_th

0.36019725

In [29]:
def inference_aggregation(probabilities, labels, offset_mappings, thold):
    """Threshold token probabilities and merge positive runs into character spans.

    Args:
        probabilities: Array indexed as ``[sample, token, class]``; class 1 is
            compared against ``thold``.
        labels: Per-sample token labels; positions equal to -100 are dropped
            before predictions are paired with offsets.
        offset_mappings: Per-sample sequences of ``(start, end)`` offsets.
        thold: Minimum class-1 probability for a token to count as positive.

    Returns:
        One string per sample: the ``str()`` of its list of ``(start, end)``
        tuples, where consecutive positive tokens are merged into one span.
    """
    binary_preds = (probabilities[:, :, 1] >= thold).astype(int)

    serialized = []
    for pred_row, label_row, offsets in zip(binary_preds, labels, offset_mappings):
        kept = [flag for flag, lab in zip(pred_row, label_row) if lab != -100]

        spans = []
        in_run = False
        run_start = run_end = None
        for flag, char_range in zip(kept, offsets):
            if flag == 1:
                if not in_run:
                    # First positive token of a run fixes the span start.
                    run_start, in_run = char_range[0], True
                # Every positive token pushes the span end forward.
                run_end = char_range[1]
            elif in_run:
                spans.append((run_start, run_end))
                in_run = False

        # A run that reaches the last token is still open here.
        if in_run:
            spans.append((run_start, run_end))

        serialized.append(str(spans))
    return serialized

In [30]:
# Validation probabilities, decoded into span strings at the hand-tuned `final_th`.
valid_probabilities = torch.softmax(torch.tensor(valid_preds.predictions), dim=-1).cpu().numpy()
valid_results = inference_aggregation(valid_probabilities, valid_preds.label_ids, ds_valid['offset_mapping'], final_th)

In [31]:
import pandas as pd
import pandas.api.types
from sklearn.metrics import f1_score
import ast


class ParticipantVisibleError(Exception):
    """Error type reserved for problems that may be reported back to participants."""


def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str) -> float:
    """
    Compute the micro-averaged character-level F1 between predicted and gold spans.

    Every ``(start, end)`` span is expanded to the character positions
    ``start .. end - 1``. Overlap, predicted and gold position counts are summed
    over all rows present in both frames, and precision/recall/F1 are computed
    from those totals.

    Parameters:
    - solution (pd.DataFrame): Ground truth with the row ID column and ``trigger_words``.
    - submission (pd.DataFrame): Predictions with the row ID column and ``trigger_words``.
    - row_id_column_name (str): Column name for the row identifier used to align rows.

    ``trigger_words`` cells may be lists/tuples/arrays of spans or their string
    literal representation; anything else (or an unparsable string) counts as
    no spans. The input frames are not modified.

    Returns:
    - float: The character-level F1 score (0 when there is no overlap).

    Raises:
    - ValueError: If either frame lacks the row ID or ``trigger_words`` column.

    Example:
    >>> solution = pd.DataFrame({
    ...     "id": [1, 2, 3],
    ...     "trigger_words": [[(612, 622), (725, 831)], [(300, 312)], []]
    ... })
    >>> submission = pd.DataFrame({
    ...     "id": [1, 2, 3],
    ...     "trigger_words": [[(612, 622), (700, 720)], [(300, 312)], [(100, 200)]]
    ... })
    >>> score(solution, submission, "id")
    0.16296296296296295
    """
    required_columns = [row_id_column_name, "trigger_words"]
    if not all(col in solution.columns for col in required_columns):
        raise ValueError(
            f"Solution DataFrame must contain '{row_id_column_name}' and 'trigger_words' columns."
        )
    if not all(col in submission.columns for col in required_columns):
        raise ValueError(
            f"Submission DataFrame must contain '{row_id_column_name}' and 'trigger_words' columns."
        )

    def safe_parse_spans(trigger_words):
        """Return the spans as a sequence, parsing string literals safely."""
        if isinstance(trigger_words, str):
            try:
                return ast.literal_eval(trigger_words)
            except (ValueError, SyntaxError):
                return []
        if isinstance(trigger_words, (list, tuple, np.ndarray)):
            return trigger_words
        return []

    def extract_tokens_from_spans(spans):
        """Expand spans into the set of character positions they cover."""
        tokens = set()
        for start, end in spans:
            tokens.update(range(start, end))
        return tokens

    # Work on copies so the callers' frames keep their original cell values.
    solution = solution.copy()
    submission = submission.copy()

    solution["trigger_words"] = solution["trigger_words"].apply(safe_parse_spans)
    submission["trigger_words"] = submission["trigger_words"].apply(safe_parse_spans)

    # Inner join: only rows whose ID appears in both frames are scored.
    merged = pd.merge(
        solution,
        submission,
        on=row_id_column_name,
        suffixes=("_solution", "_submission")
    )

    total_true_tokens = 0
    total_pred_tokens = 0
    overlapping_tokens = 0

    for true_spans, pred_spans in zip(
        merged["trigger_words_solution"], merged["trigger_words_submission"]
    ):
        true_tokens = extract_tokens_from_spans(true_spans)
        pred_tokens = extract_tokens_from_spans(pred_spans)

        total_true_tokens += len(true_tokens)
        total_pred_tokens += len(pred_tokens)
        overlapping_tokens += len(true_tokens & pred_tokens)

    precision = overlapping_tokens / total_pred_tokens if total_pred_tokens > 0 else 0
    recall = overlapping_tokens / total_true_tokens if total_true_tokens > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return f1

In [32]:
from copy import deepcopy

# Gold spans of the validation fold, in the same row order that was used to build `ds_valid`.
df_gt = df[df.fold==4][['id', 'trigger_words']].reset_index(drop=True)
df_pred = deepcopy(df_gt)
# Predictions are assigned positionally, so `valid_results` must follow the same row order.
df_pred['trigger_words'] = valid_results
score(df_gt, df_pred, row_id_column_name='id')

0.650320400396776

In [33]:
# Decode test predictions into span strings with the same hand-tuned threshold.
test_results = inference_aggregation(test_probabilities, test_preds.label_ids, ds_test['offset_mapping'], final_th)

In [34]:
# Fill the sample submission positionally; assumes its row order matches `test.csv` — TODO confirm.
ss = pd.read_csv('sample_submission.csv')
ss['trigger_words'] = test_results

In [35]:
# Write the submission file; the `submissions/` directory is not created by this notebook.
ss.to_csv('submissions/gemma2-27b-binary-cv0.65blackmagic.csv', index=False)

In [42]:
# Re-run of the validation scoring at the balance-matching threshold (same call as earlier in the notebook).
trainer.compute_metrics((valid_preds.predictions, valid_preds.label_ids))

  0%|          | 0/1 [00:00<?, ?it/s]

{'precision': 0.6380175550535446,
 'recall': 0.6279042139741697,
 'f1': 0.6329204872040508,
 'thold': 0.5049494949494949}

In [46]:
class THoptimizer:
    """Pick a decision threshold for the positive class of a binary token classifier.

    `compute_metrics` sweeps a fixed grid of thresholds over the class-1
    probabilities, converts the thresholded token labels into character spans
    and keeps the threshold with the highest character-level F1 against the
    ground-truth spans stored in the evaluation dataset.

    Args:
        ds: Evaluation dataset. It is indexed with ``ds['offset_mapping']``
            (per-sample sequences of ``(start, end)`` pairs) and
            ``ds['trigger_words']`` (per-sample ground-truth spans, either
            already parsed or as a string literal).
        desired_positive_ratio: Target fraction of positive tokens used by
            `_find_optimal_threshold`. Optional; only that method needs it.
    """

    def __init__(self, ds, desired_positive_ratio=None):
        self.eval_dataset = ds
        # Stored unconditionally so `_find_optimal_threshold` can report a clear
        # error instead of failing with AttributeError when it was never set.
        self.desired_positive_ratio = desired_positive_ratio

    def _calculate_inner_metric(self, gt_spans_all, pred_spans_all):
        """Compute micro-averaged precision/recall/F1 over all samples.

        Args:
            gt_spans_all: Per-sample ground-truth spans. A string entry is parsed
                as a Python literal; if parsing fails the sample is treated as
                having no ground-truth spans.
            pred_spans_all: Per-sample predicted spans.

        Returns:
            dict with keys ``"precision"``, ``"recall"`` and ``"f1"``. Each value
            is 0 when its denominator is 0.
        """
        # Local import keeps this cell self-contained in the notebook.
        import ast

        total_true_chars = 0
        total_pred_chars = 0
        total_overlap_chars = 0
        for true_spans, pred_spans in zip(gt_spans_all, pred_spans_all):
            if isinstance(true_spans, str):
                try:
                    # literal_eval only accepts Python literals, so a malformed
                    # or hostile string cannot execute code (unlike eval()).
                    true_spans = ast.literal_eval(true_spans)
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    # Best effort: an unparsable entry counts as "no spans".
                    true_spans = []

            # `extract_chars_from_spans` is defined elsewhere in the notebook;
            # its results are used here as sets (len / intersection).
            true_chars = extract_chars_from_spans(true_spans)
            pred_chars = extract_chars_from_spans(pred_spans)

            total_true_chars += len(true_chars)
            total_pred_chars += len(pred_chars)
            total_overlap_chars += len(true_chars.intersection(pred_chars))

        # Micro-averaged: counts are pooled over all samples before dividing.
        precision = total_overlap_chars / total_pred_chars if total_pred_chars > 0 else 0
        recall = total_overlap_chars / total_true_chars if total_true_chars > 0 else 0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        return {
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    def _find_optimal_threshold(self, probabilities, labels):
        """Find the threshold whose positive-prediction ratio is closest to the target.

        Args:
            probabilities: Array indexed as ``probabilities[:, :, 1]`` to obtain
                the positive-class probability of every token.
            labels: Per-sample label sequences; positions equal to -100 are
                excluded from the ratio.

        Returns:
            The first value of ``np.linspace(0.01, 0.99, 100)`` that minimises
            ``abs(positive_ratio - self.desired_positive_ratio)``.

        Raises:
            ValueError: If ``desired_positive_ratio`` was not provided.
        """
        if self.desired_positive_ratio is None:
            raise ValueError(
                "desired_positive_ratio is not set; pass it to THoptimizer(...) "
                "before calling _find_optimal_threshold"
            )

        optimal_th = 0.5  # Returned only if the grid below were empty.
        best_diff = float("inf")

        for thold in np.linspace(0.01, 0.99, num=100):
            predictions = (probabilities[:, :, 1] >= thold).astype(int)

            total_pos = 0
            total = 0
            for prediction, label in zip(predictions, labels):
                kept = [p for (p, l) in zip(prediction, label) if l != -100]
                total_pos += sum(kept)
                total += len(kept)

            positive_ratio = total_pos / total if total > 0 else 0

            diff = abs(positive_ratio - self.desired_positive_ratio)
            if diff < best_diff:  # strict: ties keep the lowest threshold
                best_diff = diff
                optimal_th = thold

        return optimal_th

    @staticmethod
    def _spans_from_token_labels(token_labels, offsets):
        """Merge runs of consecutive positive tokens into ``(start, end)`` spans.

        A span starts at the first positive token's start offset and ends at
        the end offset of the last token of the run.
        """
        spans = []
        current_span = None
        for token_label, span in zip(token_labels, offsets):
            if token_label == 1:
                if current_span is None:
                    current_span = [span[0], span[1]]  # open a new span
                else:
                    current_span[1] = span[1]  # extend the open span
            elif current_span is not None:
                spans.append(tuple(current_span))  # run ended: close the span
                current_span = None

        # A run that reaches the end of the sequence is still open here.
        if current_span is not None:
            spans.append(tuple(current_span))
        return spans

    def compute_metrics(self, eval_pred) -> dict:
        """Sweep thresholds and return the metrics of the best one.

        Args:
            eval_pred: ``(logits, labels)`` pair. A softmax over the last axis
                of ``logits`` yields the class probabilities; ``labels`` marks
                ignored positions with -100.

        Returns:
            The dict from `_calculate_inner_metric` for the best threshold, with
            an extra ``"thold"`` key. On equal F1 the larger threshold wins.
        """
        eval_dataset = self.eval_dataset
        logits, labels = eval_pred
        probabilities = torch.softmax(torch.tensor(logits), dim=-1).cpu().numpy()

        # Loop-invariant column lookups, done once rather than per threshold.
        offset_mapping = eval_dataset['offset_mapping']
        gt_spans_all = eval_dataset['trigger_words']

        # Alternative: evaluate only the ratio-matching threshold with
        # thresholds = [self._find_optimal_threshold(probabilities, labels)]
        thresholds = np.linspace(0.1, 0.5, num=41)
        best_f1 = -1
        best_metrics = None

        for thold in tqdm(thresholds):
            # Apply thresholding instead of argmax.
            predictions = (probabilities[:, :, 1] >= thold).astype(int)

            # Drop ignored positions (-100).
            # NOTE(review): this assumes the remaining predictions line up
            # one-to-one with each sample's offsets — confirm against the collator.
            true_predictions = [
                [p for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, labels)
            ]

            pred_spans_all = [
                self._spans_from_token_labels(pred, offsets)
                for pred, offsets in zip(true_predictions, offset_mapping)
            ]

            current_metrics = self._calculate_inner_metric(gt_spans_all, pred_spans_all)
            if current_metrics['f1'] >= best_f1:
                best_f1 = current_metrics['f1']
                best_metrics = current_metrics
                best_metrics['thold'] = thold

        return best_metrics

In [47]:
# Build the threshold optimiser over the trainer's evaluation dataset.
thoptimizer = THoptimizer(trainer.eval_dataset)

In [48]:
# Sweep the threshold grid on the stored validation predictions; the best-threshold
# metrics dict is displayed as the cell output.
thoptimizer.compute_metrics((valid_preds.predictions, valid_preds.label_ids))

  0%|          | 0/41 [00:00<?, ?it/s]

{'precision': 0.5853007860760009,
 'recall': 0.7316582658724098,
 'f1': 0.6503469603258409,
 'thold': 0.36}

In [51]:
# Decision threshold used for the final predictions: the midpoint of two candidate thresholds.
# NOTE(review): 0.36 matches the `thold` printed by the THoptimizer sweep above; the source of
# 0.47 is not visible in this part of the notebook — TODO record where it comes from.
final_th = (0.36 + 0.47)/2
final_th

0.415

In [52]:
from copy import deepcopy

# Softmax over the last axis turns the stored validation outputs into class
# probabilities, moved to a NumPy array.
valid_probabilities = torch.tensor(valid_preds.predictions).softmax(dim=-1).cpu().numpy()

# Predictions for the validation set at the chosen threshold `final_th`.
valid_results = inference_aggregation(
    valid_probabilities,
    valid_preds.label_ids,
    ds_valid['offset_mapping'],
    final_th,
)

# Gold spans for the rows of fold 4, re-indexed from zero.
df_gt = df.loc[df['fold'] == 4, ['id', 'trigger_words']].reset_index(drop=True)

# Independent copy of the gold frame with the span column overwritten
# (positionally) by the predictions.
df_pred = deepcopy(df_gt)
df_pred['trigger_words'] = valid_results

# The last expression is displayed as the cell output.
score(df_gt, df_pred, row_id_column_name='id')

0.6469428475473636

In [53]:
# Produce the test-set predictions from the test probabilities at threshold `final_th`.
# NOTE(review): `inference_aggregation`, `test_probabilities`, `test_preds` and `ds_test` are
# defined outside this cell — confirm they exist on a fresh Restart & Run All.
test_results = inference_aggregation(test_probabilities, test_preds.label_ids, ds_test['offset_mapping'], final_th)

In [54]:
# Load the sample submission and replace its `trigger_words` column with the
# test predictions (assigned positionally, so lengths must match).
ss = pd.read_csv('sample_submission.csv').assign(trigger_words=test_results)

In [55]:
# Write the submission frame to disk; index=False keeps the pandas row index out of the CSV.
ss.to_csv('submissions/gemma2-27b-binary-cv0.646-regularized-gs.csv', index=False)

In [34]:
import pickle

# Persist the validation and test prediction objects to disk.
# Context managers make sure each file is flushed and closed as soon as the
# dump finishes (or fails), instead of leaving the handle to the garbage collector.
with open('local_preds/valid_preds_gemma2_binary.pkl', 'wb') as fh:
    pickle.dump(valid_preds, fh)

with open('local_preds/test_preds_gemma2_binary.pkl', 'wb') as fh:
    pickle.dump(test_preds, fh)