In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging as log
log.basicConfig(level=log.DEBUG)
import syslog
import os
import json
from nltk import word_tokenize 
import re

import torch
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any, Optional, Union
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from torch.utils.data import DataLoader
import logging

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

import json
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
class NLPUtils:
    
    LABELS = {
        'NotMentioned': 0,
        'Entailment': 1,
        'Contradiction': 2,
    }
    
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    @staticmethod
    def load_data(path: Union[str, Path]) -> Dict:
        try:
            path = Path(path)
            with path.open('r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Data file not found at: {path}")
        except json.JSONDecodeError:
            raise json.JSONDecodeError(f"Invalid JSON format in file: {path}")



    @staticmethod
    def clean_text(text: str) -> str:

        if not isinstance(text, str):
            return ""
            
        text = (text.replace('\n', ' ')
                   .strip()
                   .lower())
        
        replacements = [
            (r'\\t', ' '),
            (r'\\r', ' '),
            (r'(.)\1{2,}', r'\1'),
            (r'\s+', ' ')  
        ]
        
        for pattern, replacement in replacements:
            text = re.sub(pattern, replacement, text)
            
        return text.strip()

    @staticmethod
    def extract_hypotheses(data: Dict) -> Dict[str, str]:

        if not isinstance(data, dict) or 'labels' not in data:
            raise ValueError("Invalid data format: missing 'labels' key")
            
        return {
            key: NLPUtils.clean_text(value.get('hypothesis', ''))
            for key, value in data['labels'].items()
        }

    @staticmethod
    def get_hypothesis_index(hypothesis_name: str) -> Optional[int]:
        try:
            return int(hypothesis_name.split('-')[-1])
        except (ValueError, IndexError):
            return None


    @staticmethod
    def tokenize(text: str) -> str:
        try:
            return ' '.join(word_tokenize(text))
        except Exception as e:
            print(f"Tokenization error: {e}")
            return text

    @classmethod
    def get_labels(cls) -> Dict[str, int]:
        return cls.LABELS.copy()

In [4]:
cfg = {
    "model_name": "bert-base-uncased",
    "batch_size": 32,
    "train_path": "/kaggle/input/project-data/train (1).json",
    "test_path": "/kaggle/input/project-data/test (1).json",
    "dev_path": "/kaggle/input/project-data/dev (1).json",
    "max_length": 512,
    "models_save_dir": "/kaggle/working/saved_model",
    "results_dir": "/kaggle/working/results",
    "dataset_dir": "/kaggle/working/dataset_dir"
}

In [5]:
# create dir if not exists
from pathlib import Path
Path(cfg["models_save_dir"]).mkdir(parents=True, exist_ok=True)
Path(cfg["dataset_dir"]).mkdir(parents=True, exist_ok=True)

In [6]:
tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'])

tokenizer.save_pretrained(cfg['models_save_dir'])

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



('/kaggle/working/saved_model/tokenizer_config.json',
 '/kaggle/working/saved_model/special_tokens_map.json',
 '/kaggle/working/saved_model/vocab.txt',
 '/kaggle/working/saved_model/added_tokens.json',
 '/kaggle/working/saved_model/tokenizer.json')

In [7]:
tokenizer = AutoTokenizer.from_pretrained(cfg['models_save_dir'])


In [8]:
!pip install icecream
from icecream import ic

Collecting icecream
  Downloading icecream-2.1.3-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading icecream-2.1.3-py2.py3-none-any.whl (8.4 kB)
Installing collected packages: icecream
Successfully installed icecream-2.1.3


In [None]:
@dataclass
class Context:

    """Represents a context window in the document."""

    doc_id: int
    start_char_idx: int
    end_char_idx: int
    spans: List[Dict[str, Any]]


In [None]:
@dataclass
class DataPoint:
    
    hypothesis: str
    premise: str
    marked_beg: bool
    marked_end: bool
    nli_label: torch.Tensor
    span_labels: torch.Tensor
    doc_id: torch.Tensor
    hypothesis_id: torch.Tensor
    span_ids: torch.Tensor


In [None]:
class NLIDataset(Dataset):
    
    SPAN_TOKEN = '[SPAN]'

    def __init__(
        self,
        documents: List[Dict],
        tokenizer: PreTrainedTokenizer,
        hypothesis: Dict[str, str],
        context_sizes: List[int],
        surround_character_size: int
    ):
        self.tokenizer = tokenizer
        self.tokenizer.add_special_tokens({'additional_special_tokens': [self.SPAN_TOKEN]})
        
        contexts = self._generate_contexts(documents, context_sizes, surround_character_size)
        self.data_points = self._create_data_points(documents, contexts, hypothesis)
        self.span_token_id = self.tokenizer.convert_tokens_to_ids(self.SPAN_TOKEN)

    def _generate_contexts(
        self,
        documents: List[Dict],
        context_sizes: List[int],
        surround_character_size: int
    ) -> List[Context]:
        
        contexts = []
        
        for context_size in context_sizes:
            for doc_id, doc in enumerate(documents):
                char_idx = 0
                document_spans = doc['spans']
                
                while char_idx < len(doc['text']):
                    context = self._create_context(
                        doc_id, char_idx, context_size, document_spans
                    )
                    
                    if not contexts or context != contexts[-1]:
                        contexts.append(context)
                        
                        # Update char_idx based on spans
                        if (len(context.spans) == 1 and 
                            not context.spans[0]['marked']):
                            char_idx = context.end_char_idx - surround_character_size
                        else:
                            char_idx = context.spans[-1]['start_char_idx'] - surround_character_size
                    else:
                        char_idx = context.end_char_idx - surround_character_size
                        
        return contexts

    def _create_context(
        self,
        doc_id: int,
        char_idx: int,
        context_size: int,
        document_spans: List[Tuple[int, int]]
    ) -> Context:
        """Create a single context window."""
        context = Context(
            doc_id=doc_id,
            start_char_idx=char_idx,
            end_char_idx=char_idx + context_size,
            spans=[]
        )
        
        for span_id, (start, end) in enumerate(document_spans):
            if end <= char_idx:
                continue
                
            context.spans.append({
                'start_char_idx': max(start, char_idx),
                'end_char_idx': min(end, char_idx + context_size),
                'marked': start >= char_idx and end <= char_idx + context_size,
                'span_id': span_id
            })
            
            if end > char_idx + context_size:
                break
                
        return context

    def _create_data_points(
        self,
        documents: List[Dict],
        contexts: List[Context],
        hypothesis: Dict[str, str]
    ) -> List[DataPoint]:
        
        data_points = []
        label_dict = NLPUtils.get_labels() 
        
        for nda_name, nda_desc in hypothesis.items():
            for context in contexts:
                doc = documents[context.doc_id]
                nli_label = label_dict[doc['annotation_sets'][0]['annotations'][nda_name]['choice']]
                
                premise, span_ids, span_labels = self._process_spans(
                    context, doc, nda_name
                )
                
                if nli_label == label_dict['NotMentioned']:
                    span_labels = torch.zeros(len(span_labels), dtype=torch.long)
                else:
                    span_labels = torch.tensor(span_labels, dtype=torch.long)
                
                data_point = DataPoint(
                    hypothesis=nda_desc,
                    premise=premise,
                    marked_beg=context.spans[0]['marked'],
                    marked_end=context.spans[-1]['marked'] or len(context.spans) == 1,
                    nli_label=torch.tensor(nli_label, dtype=torch.long),
                    span_labels=span_labels,
                    doc_id=torch.tensor(context.doc_id, dtype=torch.long),
                    hypothesis_id=torch.tensor(self._get_hypothesis_idx(nda_name), dtype=torch.long),
                    span_ids=torch.tensor(span_ids, dtype=torch.long)
                )
                
                data_points.append(data_point)
                
        return data_points

    def _process_spans(
        self,
        context: Context,
        document: Dict,
        nda_name: str
    ) -> Tuple[str, List[int], List[int]]:
        
        premise = ""
        span_ids = []
        span_labels = []
        
        for span in context.spans:
            # Calculate span label
            is_relevant = int(span['span_id'] in 
                            document['annotation_sets'][0]['annotations'][nda_name]['spans'])
            span_label = 2 * is_relevant - 1  # Convert 0->-1 and 1->1
            
            if span['marked']:
                span_labels.append(span_label)
                span_ids.append(span['span_id'])
            
            premise += f" {self.SPAN_TOKEN} "
            premise += document['text'][span['start_char_idx']:span['end_char_idx']]
            
        return premise.strip(), span_ids, span_labels

    def __len__(self) -> int:
        return len(self.data_points)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        data_point = self.data_points[idx]
        
        tokenized_data = self._tokenize_inputs(data_point)
        
        span_indices = self._process_span_indices(
            tokenized_data, data_point.marked_beg, data_point.marked_end
        )
        
        span_ids = data_point.span_ids[:len(span_indices)]
        
        return {
            'input_ids': tokenized_data['input_ids'],
            'attention_mask': tokenized_data['attention_mask'],
            'token_type_ids': tokenized_data['token_type_ids'],
            'span_indices': span_indices,
            'nli_label': data_point.nli_label,
            'span_labels': data_point.span_labels[:len(span_indices)],
            'data_for_metrics': {
                'doc_id': data_point.doc_id,
                'hypothesis_id': data_point.hypothesis_id,
                'span_ids': span_ids,
            }
        }

    def _tokenize_inputs(self, data_point: DataPoint) -> Dict[str, torch.Tensor]:
        tokenized = self.tokenizer(
            [data_point.hypothesis],
            [data_point.premise],
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        
        return {
            'input_ids': tokenized['input_ids'].squeeze(),
            'attention_mask': tokenized['attention_mask'].squeeze(),
            'token_type_ids': tokenized['token_type_ids'].squeeze(),
        }

    def _process_span_indices(
        self,
        tokenized_data: Dict[str, torch.Tensor],
        marked_beg: bool,
        marked_end: bool
    ) -> torch.Tensor:
        
        span_indices = torch.where(tokenized_data['input_ids'] == self.span_token_id)[0]
        
        if not marked_beg:
            span_indices = span_indices[1:]
            
        if not marked_end or tokenized_data['attention_mask'][-1] == 0:
            span_indices = span_indices[:-1]
            
        return span_indices

    @staticmethod
    def _get_hypothesis_idx(hypothesis_name: str) -> int:
        return int(hypothesis_name.split('-')[-1])

In [None]:
@dataclass
class DataConfig:
    
    train_path: str
    dev_path: str
    test_path: str
    context_size: int = 1100
    surround_character_size: int = 50
    batch_size: int = 32
    num_workers: int = 4

In [None]:
class NLIDataManager:
    """
    Manages the loading, preprocessing, and creation of NLI datasets.
    
    Attributes:
        config: DataConfig object containing configuration parameters
        tokenizer: Tokenizer for processing text
        hypothesis: Dictionary of hypotheses extracted from training data
    """
    
    def __init__(
        self,
        config: DataConfig,
        tokenizer: PreTrainedTokenizer
    ):
        
        self.config = config
        self.tokenizer = tokenizer
        self.hypothesis = None
        self.datasets = {}
        self._validate_paths()
        
    def _validate_paths(self) -> None:
        
        for path_name, path in {
            'train': self.config.train_path,
            'dev': self.config.dev_path,
            'test': self.config.test_path
        }.items():
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"{path_name} data path does not exist: {path}"
                )
    
    def load_and_preprocess(self) -> None:

        logger.info("Loading and preprocessing datasets...")
        
        train_data = self._load_data(self.config.train_path)
        dev_data = self._load_data(self.config.dev_path)
        test_data = self._load_data(self.config.test_path)
        
        self.hypothesis = self._extract_hypothesis(train_data)
        
        train_documents = train_data['documents']
        dev_documents = dev_data['documents']
        test_documents = test_data['documents']
        
        logger.info(f"Loaded documents - Train: {len(train_documents)}, "
                   f"Dev: {len(dev_documents)}, Test: {len(test_documents)}")
        
        self.datasets = {
            'train': self._create_dataset(train_documents, 'train'),
            'dev': self._create_dataset(dev_documents, 'dev'),
            'test': self._create_dataset(test_documents, 'test')
        }
        
        logger.info("Dataset creation completed successfully")
    

    @staticmethod
    def _load_data(path: str) -> Dict:

        try:
            return NLPUtils.load_data(path)
        except Exception as e:
            logger.error(f"Error loading data from {path}: {str(e)}")
            raise
    
    @staticmethod
    def _extract_hypothesis(data: Dict) -> Dict[str, str]:
        try:
            return NLPUtils.extract_hypotheses(data) 
        except Exception as e:
            logger.error(f"Error extracting hypothesis: {str(e)}")
            raise
    
    def _create_dataset(
        self,
        documents: List[Dict],
        split: str
    ) -> NLIDataset:

        try:
            return NLIDataset(
                documents=documents,
                tokenizer=self.tokenizer,
                hypothesis=self.hypothesis,
                context_sizes=[self.config.context_size],
                surround_character_size=self.config.surround_character_size
            )
        except Exception as e:
            logger.error(f"Error creating {split} dataset: {str(e)}")
            raise
    
    def get_dataloaders(
        self,
        shuffle_train: bool = True
    ) -> Dict[str, DataLoader]:

        if not self.datasets:
            raise ValueError("Datasets not initialized. Call load_and_preprocess first.")
        
        return {
            split: DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=(shuffle_train and split == 'train'),
                num_workers=self.config.num_workers,
                pin_memory=True
            )
            for split, dataset in self.datasets.items()
        }
    
    def get_dataset(self, split: str) -> Optional[NLIDataset]:
        return self.datasets.get(split)

  data_point['span_labels'] = torch.tensor(span_labels, dtype=torch.long)


In [None]:
config = DataConfig(
    train_path=cfg['train_path'],
    dev_path=cfg['dev_path'],
    test_path=cfg['test_path']
)

data_manager = NLIDataManager(config, tokenizer)

data_manager.load_and_preprocess()

dataloaders = data_manager.get_dataloaders()

train_dataset = data_manager.get_dataset('train')
dev_dataset = data_manager.get_dataset('dev')
test_dataset = data_manager.get_dataset('test')

In [13]:
ic(len(train_dataset), len(dev_dataset), len(test_dataset))


ic| len(train_dataset): 97546
    len(dev_dataset): 15385
    len(test_dataset): 28645


(97546, 15385, 28645)

In [None]:
logger = logging.getLogger(__name__)

@dataclass
class WeightResult:
    nli_weights: np.ndarray
    span_weight: float
    nli_class_distribution: dict
    span_class_distribution: dict

In [None]:
class ClassWeightCalculator:
    
    def __init__(self, exclude_span_label: int = -1):

        self.exclude_span_label = exclude_span_label
    
    def calculate_weights(
        self,
        dataset: Union[Dataset, List[dict]]
    ) -> WeightResult:

        try:
            nli_labels = self._extract_nli_labels(dataset)
            span_labels = self._extract_span_labels(dataset)
            
            if not nli_labels or not span_labels:
                raise ValueError("No labels found in dataset")
            
            nli_weights = self._compute_nli_weights(nli_labels)
            span_weight = self._compute_span_weight(span_labels)
            
            nli_distribution = self._compute_class_distribution(nli_labels)
            span_distribution = self._compute_class_distribution(span_labels)
            
            return WeightResult(
                nli_weights=nli_weights,
                span_weight=span_weight,
                nli_class_distribution=nli_distribution,
                span_class_distribution=span_distribution
            )
            
        except Exception as e:
            logger.error(f"Error calculating weights: {str(e)}")
            raise
    
    def _extract_nli_labels(self, dataset: Union[Dataset, List[dict]]) -> np.ndarray:
        try:
            if isinstance(dataset, Dataset):
                labels = [dataset[i]['nli_label'] for i in range(len(dataset))]
            else:
                labels = [x['nli_label'] for x in dataset]
            
            labels = [
                label.item() if isinstance(label, torch.Tensor) else label
                for label in labels
            ]
            
            return np.array(labels)
        
        except Exception as e:
            raise ValueError(f"Error extracting NLI labels: {str(e)}")
    
    def _extract_span_labels(self, dataset: Union[Dataset, List[dict]]) -> np.ndarray:
        try:
            span_labels = []
            
            if isinstance(dataset, Dataset):
                for i in range(len(dataset)):
                    labels = dataset[i]['span_labels']
                    if isinstance(labels, torch.Tensor):
                        labels = labels.numpy()
                    span_labels.extend(labels)
            else:
                for item in dataset:
                    labels = item['span_labels']
                    if isinstance(labels, torch.Tensor):
                        labels = labels.numpy()
                    span_labels.extend(labels)
            
            span_labels = [
                label for label in span_labels 
                if label != self.exclude_span_label
            ]
            
            return np.array(span_labels)
            
        except Exception as e:
            raise ValueError(f"Error extracting span labels: {str(e)}")
    
    def _compute_nli_weights(self, labels: np.ndarray) -> np.ndarray:

        unique_classes = np.unique(labels)
        weights = compute_class_weight(
            'balanced',
            classes=unique_classes,
            y=labels
        )
        return weights
    
    def _compute_span_weight(self, labels: np.ndarray) -> float:

        if len(labels) == 0:
            raise ValueError("No valid span labels found")
            
        n_negative = np.sum(labels == 0)
        n_positive = np.sum(labels == 1)
        
        if n_positive == 0:
            logger.warning("No positive span labels found")
            return 1.0
            
        return n_negative / n_positive
    


    @staticmethod
    def _compute_class_distribution(labels: np.ndarray) -> dict:
    
        unique, counts = np.unique(labels, return_counts=True)
        total = len(labels)
        return {
            label: {
                'count': count,
                'percentage': (count / total) * 100
            }
            for label, count in zip(unique, counts)
        }

In [None]:
def print_weight_statistics(result: WeightResult) -> None:
    
    print("\nClass Weight Statistics:")
    print("\nNLI Classification:")
    
    for class_idx, weight in enumerate(result.nli_weights):
        dist = result.nli_class_distribution.get(class_idx, {})
        print(f"Class {class_idx}:")
        print(f"  Weight: {weight:.3f}")
        print(f"  Count: {dist.get('count', 0)}")
        print(f"  Percentage: {dist.get('percentage', 0):.2f}%")
    
    print("\nSpan Classification:")
    print(f"Positive/Negative Weight: {result.span_weight:.3f}")
    
    for class_idx, stats in result.span_class_distribution.items():
        print(f"Class {class_idx}:")
        print(f"  Count: {stats['count']}")
        print(f"  Percentage: {stats['percentage']:.2f}%")

In [None]:
calculator = ClassWeightCalculator(exclude_span_label=-1)

weights = calculator.calculate_weights(train_dataset)

print_weight_statistics(weights)

nli_weights = torch.tensor(weights.nli_weights, device=NLPUtils.DEVICE)
span_weight = torch.tensor(weights.span_weight, device=NLPUtils.DEVICE)

In [16]:
ic(nli_weights, span_weight)


ic| nli_weights: [0.9712447975785092, 0.6134852801519468, 2.9380440348182284]
    span_weight: 24.93889485618809


([0.9712447975785092, 0.6134852801519468, 2.9380440348182284],
 24.93889485618809)

In [None]:
from transformers import PreTrainedModel, PretrainedConfig

class ContractNLIConfig(PretrainedConfig):
    
    def __init__(self, nli_weights = [1, 1, 1], span_weight = 1, lambda_ = 1, bert_model_name = cfg['model_name'], num_labels = len(NLPUtils.get_labels()), ignore_span_label = 2, **kwargs):
        super().__init__(**kwargs)
        self.bert_model_name = bert_model_name
        self.num_labels = num_labels
        self.lambda_ = lambda_
        self.ignore_span_label = ignore_span_label
        self.nli_weights = nli_weights
        self.span_weight = span_weight

In [None]:
from transformers import AutoModel
from torch import nn

class ContractNLI(PreTrainedModel):
    config_class = ContractNLIConfig

    def __init__(self, config):
        super().__init__(config)
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        self.bert.resize_token_embeddings(self.bert.config.vocab_size + 1, pad_to_multiple_of=8)
        self.bert.eval()
        for param in self.bert.parameters():
            param.requires_grad = False

        self.embedding_dim = self.bert.config.hidden_size
        self.num_labels = config.num_labels
        self.lambda_ = config.lambda_
        self.nli_criterion = nn.CrossEntropyLoss(weight=torch.tensor(self.config.nli_weights, dtype=torch.float32))
        self.span_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.config.span_weight, dtype=torch.float32))

        self.span_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 4, self.embedding_dim * 4),
            nn.ReLU(),
            nn.Linear(self.embedding_dim * 4, self.embedding_dim * 2),
            nn.ReLU(),
            nn.Linear(self.embedding_dim * 2, 1)
        )

        self.nli_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 4, self.embedding_dim * 4),
            nn.ReLU(),
            nn.Linear(self.embedding_dim * 4, self.embedding_dim * 2),
            nn.ReLU(),
            nn.Linear(self.embedding_dim * 2, self.num_labels)
        )

        self.init_weights()

    def _init_weights(self, module):
        
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.bert.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, input_ids, attention_mask, token_type_ids, span_indices):
        outputs = self.bert(input_ids, attention_mask, token_type_ids, output_hidden_states=True).hidden_states[-4:]
        outputs = torch.stack(outputs, dim=0)
        outputs = outputs.permute([1, 2, 0, 3])
        outputs = outputs.reshape([outputs.shape[0], outputs.shape[1], -1])

        gather = torch.gather(outputs, 1, span_indices.unsqueeze(2).expand(-1, -1, outputs.shape[-1]))

        masked_gather = gather[span_indices != 0]
        span_logits = self.span_classifier(masked_gather)
        nli_logits = self.nli_classifier(outputs[:, 0, :])

        return span_logits, nli_logits

In [None]:
from transformers import Trainer

class ContractNLITrainer(Trainer):
    def __init__(self, *args, data_collator=None, **kwargs):
        super().__init__(*args, data_collator=data_collator, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        span_label = inputs.pop('span_labels')
        nli_label = inputs.pop('nli_label')
        inputs.pop('data_for_metrics')

        outputs = model(**inputs)
        span_logits, nli_logits = outputs[0], outputs[1]
        
        mask = span_label != -1
        span_label = span_label[mask]
        span_logits = span_logits[mask]
        
        span_label = span_label.float()
        span_logits = span_logits.float()
        
        span_label = span_label.view(-1)
        span_logits = span_logits.view(-1)        

        # if len(true_span_labels) == 0 or len(pred_span_labels) != len(true_span_labels):
        #     span_loss = torch.tensor(0, dtype=torch.float32, device=NLPUtils.DEVICE)
        # else:
        #     span_loss = self.model.span_criterion(pred_span_labels, true_span_labels)
        
        if len(span_label) == 0:
            span_loss = torch.tensor(0, dtype=torch.float32, device=NLPUtils.DEVICE)
        else:
            span_loss = self.model.span_criterion(span_logits, span_label)

        nli_loss = self.model.nli_criterion(nli_logits, nli_label)

        if torch.isnan(nli_loss):
            nli_loss = torch.tensor(0, dtype=torch.float32, device=NLPUtils.DEVICE)

        if torch.isnan(span_loss):
            span_loss = torch.tensor(0, dtype=torch.float32, device=NLPUtils.DEVICE)

        loss = span_loss + self.model.lambda_ * nli_loss

        if loss.item() == 0:
            loss = torch.tensor(0, dtype=torch.float32, device=NLPUtils.DEVICE, requires_grad=True)

        return (loss, outputs) if return_outputs else loss

    @staticmethod
    def collate_fn(features):
        span_indices_list = [feature['span_indices'] for feature in features]
        max_len = max([len(span_indices) for span_indices in span_indices_list])
        span_indices_list = [torch.cat([span_indices, torch.zeros(max_len - len(span_indices), dtype=torch.long)]) for span_indices in span_indices_list]

        span_ids_list = [feature['data_for_metrics']['span_ids'] for feature in features]
        max_len = max([len(span_ids) for span_ids in span_ids_list])
        
        # pad to get the doc id and hypothesis id for each input while evaluating
        span_ids_list = [torch.cat([span_ids, torch.full((max_len - len(span_ids),), -1)]) for span_ids in span_ids_list]
        
        input_ids = torch.stack([feature['input_ids'] for feature in features])
        attention_mask = torch.stack([feature['attention_mask'] for feature in features])
        token_type_ids = torch.stack([feature['token_type_ids'] for feature in features])
        span_indices = torch.stack(span_indices_list)
        nli_label = torch.stack([feature['nli_label'] for feature in features])
        span_label = torch.cat([feature['span_labels'] for feature in features], dim=0)
        data_for_metrics = {
            'doc_id': torch.stack([feature['data_for_metrics']['doc_id'] for feature in features]),
            'hypothesis_id': torch.stack([feature['data_for_metrics']['hypothesis_id'] for feature in features]),
            'span_ids': torch.stack(span_ids_list),
        }

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'span_indices': span_indices,
            'nli_label': nli_label,
            'span_labels': span_label,
            'data_for_metrics': data_for_metrics,
        }

In [None]:
from transformers import TrainingArguments
from transformers import EarlyStoppingCallback

training_args = TrainingArguments(
    auto_find_batch_size=True,
    output_dir=cfg['results_dir'],   
    num_train_epochs=10,           
    gradient_accumulation_steps=4,   
    logging_strategy='epoch',
    eval_steps=2,
    save_steps=2,
    logging_steps=2,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=True,
    label_names=['nli_label', 'span_labels', 'data_for_metrics'],
    report_to='none',
)



In [21]:
def wandb_hp_space(trial):
    return {
        "method": "random",
        "metric": {
            "name": "eval/loss",
            "goal": "minimize"
        },
        "parameters": {
            "learning_rate": {
                "values": [1e-5, 3e-5, 5e-5]
            },
            "lambda_": {
                "values": [0.05, 0.1, 0.4]
            },
        }
    }

In [22]:
def model_init(trial):
    if trial is None:
        return ContractNLI(ContractNLIConfig(nli_weights=nli_weights, span_weight=span_weight))

    return ContractNLI(ContractNLIConfig(nli_weights=nli_weights, span_weight=span_weight, lambda_=trial['lambda_']))

In [None]:

trainer = ContractNLITrainer(
    model=None,                       
    args=training_args,              
    train_dataset=train_dataset,        
    eval_dataset=dev_dataset,     
    data_collator=ContractNLITrainer.collate_fn,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)],
    model_init=model_init,
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [28]:
trainer.train()


Epoch,Training Loss,Validation Loss
0,2.6348,3.922571
1,2.2836,4.243315
2,2.1869,2.967618
4,2.0903,5.793064
5,2.0722,5.055171


TrainOutput(global_step=7456, training_loss=2.2392635099877616, metrics={'train_runtime': 7060.2097, 'train_samples_per_second': 56.32, 'train_steps_per_second': 1.759, 'total_flos': 8.353250792946893e+16, 'train_loss': 2.2392635099877616, 'epoch': 5.9995976664655})

# Metric


In [24]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import logging as log
log.basicConfig(level=log.DEBUG)

In [25]:
import sys
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

os.environ['WANDB_ENTITY'] = 'contract-nli-db'
os.environ['WANDB_PROJECT'] = 'contract-nli-metric'
os.environ['WANDB_LOG_MODEL'] = 'end'

In [27]:
import json

cfg = {
    "train_path": "/kaggle/input/project-data/train (1).json",
    "test_path": "/kaggle/input/project-data/test (1).json",
    "dev_path": "/kaggle/input/project-data/dev (1).json",
    "model_name": "bert-base-uncased",
    "max_length": 512,
    "models_save_dir": "/kaggle/input/anlp-project-trained-model/checkpoint",
    "dataset_dir": "./scratch/shu7bh/contract_nli/dataset",
    "results_dir": "./scratch/shu7bh/contract_nli/results",
    "trained_model_dir": "/kaggle/input/anlp-project-trained-model/",
    "batch_size": 32
}

cfg

{'train_path': '/kaggle/input/project-data/train (1).json',
 'test_path': '/kaggle/input/project-data/test (1).json',
 'dev_path': '/kaggle/input/project-data/dev (1).json',
 'model_name': 'bert-base-uncased',
 'max_length': 512,
 'models_save_dir': '/kaggle/input/anlp-project-trained-model/checkpoint',
 'dataset_dir': './scratch/shu7bh/contract_nli/dataset',
 'results_dir': './scratch/shu7bh/contract_nli/results',
 'trained_model_dir': '/kaggle/input/anlp-project-trained-model/',
 'batch_size': 32}

In [28]:
# create dir if not exists
from pathlib import Path
Path(cfg["models_save_dir"]).mkdir(parents=True, exist_ok=True)
Path(cfg["dataset_dir"]).mkdir(parents=True, exist_ok=True)

In [29]:
tokenizer = AutoTokenizer.from_pretrained(cfg['model_name'])



In [30]:
!pip install icecream
from icecream import ic

  pid, fd = os.forkpty()
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [47]:
dev_data = load_data(os.path.join(cfg['dev_path']))
test_data = load_data(os.path.join(cfg['test_path']))

hypothesis = get_hypothesis(dev_data)

dev_data = dev_data['documents']
test_data = test_data['documents']

# dev_data = dev_data[:50]
# test_data = test_data[:50]

ic.disable()

ic(len(dev_data), len(test_data))
dev_dataset = NLIDataset(dev_data, tokenizer, hypothesis, [1100], 50)
test_dataset = NLIDataset(test_data, tokenizer, hypothesis, [1100], 50)

ic.enable()

del dev_data
del test_data
del hypothesis

  data_point['span_labels'] = torch.tensor(span_labels, dtype=torch.long)


In [48]:
print(len(dev_dataset))
print(len(test_dataset))

15385
28645


In [49]:
from sklearn.metrics import precision_recall_curve
import numpy as np
def get_micro_average_precision_at_recall(y_true, y_pred, recall_level):
    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    return np.interp(recall_level, recall[::-1], precision[::-1])

In [None]:
# Import numpy and sklearn.metrics
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import precision_score


def calculate_micro_average_precision(y_true, y_pred):

    num_classes = len(np.unique(y_true))
    
    if num_classes == 0:
        return 0.0

    average_precision = 0.0

    for class_idx in range(num_classes):
        y_true_indices = np.where(y_true == class_idx)
        average_precision += ic(precision_score(
            y_true[y_true_indices], y_pred[y_true_indices], average="micro"
        ))

    return average_precision / num_classes

In [None]:
from sklearn.metrics import f1_score
def calculate_f1_score_for_class(y_true, y_pred, class_idx):
    
    y_true_indices = np.where(y_true == class_idx)
    
    return f1_score(
        y_true[y_true_indices], y_pred[y_true_indices], average="macro"
    )

In [None]:
def precision_at_recall(y_true, y_scores, recall_threshold):
    precision, recall, threshold = precision_recall_curve(y_true, y_scores)
    idx = (np.abs(recall - recall_threshold)).argmin() 
    ic(threshold[idx])
    return precision[idx]

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    auto_find_batch_size=True,
    output_dir=cfg['results_dir'],  
    num_train_epochs=10,       
    gradient_accumulation_steps=4,
    logging_strategy='epoch',
    # eval_steps=0.25,
    # save_steps=0.25,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    load_best_model_at_end=True,
    # fp16=True,
    label_names=['nli_label', 'span_labels', 'data_for_metrics'],
    report_to='none',
)



In [54]:
cfg['trained_model_dir']

'/kaggle/input/anlp-project-trained-model/'

In [None]:
artifact_dir = '/kaggle/input/fully-trained-model/checkpoint-12194'  
model = ContractNLI.from_pretrained(artifact_dir).to(NLPUtils.DEVICE)

In [None]:
from transformers import Trainer
from sklearn.metrics import accuracy_score
from sklearn.metrics import average_precision_score
from tqdm import tqdm
import numpy as np

class ContractNLIMetricTrainer(ContractNLITrainer):
    def __init__(self, *args, data_collator=None, **kwargs):
        super().__init__(*args, data_collator=data_collator, **kwargs)

    def evaluate(self, eval_dataset=None, ignore_keys=None):
        self.model.eval()
        self.dataloader = ic(self.get_eval_dataloader(eval_dataset))

        eval_nli_labels = []
        eval_nli_preds = []
        true_labels_per_span = {}
        probs_per_span = {}

        nli_metrics = {}

        for inputs in tqdm(self.dataloader):
            inputs = self._prepare_inputs(inputs)
            span_labels = inputs.pop('span_labels')
            nli_labels = inputs.pop('nli_label')
            data_for_metrics = inputs.pop('data_for_metrics')

            span_indices_to_consider = torch.where(span_labels != -1)[0]

            with torch.no_grad():
                outputs = self.model(**inputs)
                span_logits, nli_logits = outputs[0], outputs[1]

                span_labels = span_labels.float()
                span_logits = span_logits.float()
                
                span_labels = span_labels.view(-1)
                span_logits = span_logits.view(-1)

                # start_index = 0
                
                indices_considered = 0 # total number of span indices considered

                # find the corresponding span index in data_for_metrics['span_ids'] considering -1 to be padding index
                # ic(span_index)
                for i, span_index_row in enumerate(data_for_metrics['span_ids']):
                    current_index = 0 # current row's first -1 index
                    # ic(span_index_row)
                    first_minus_one_index = torch.where(span_index_row == -1)[0]
                    # ic(first_minus_one_index)
                    if len(first_minus_one_index) == 0:
                        first_minus_one_index = len(span_index_row)
                    else:
                        first_minus_one_index = first_minus_one_index[0].item()

                    key = str(data_for_metrics['doc_id'][i].item())+ '-' + str(data_for_metrics['hypothesis_id'][i].item())

                    # mask span_labels and span_logits for the current row
                    mask = span_labels[indices_considered:indices_considered+first_minus_one_index] != -1
                    span_logits_masked = span_logits[indices_considered:indices_considered+first_minus_one_index][mask]

                    spans_contribution = torch.sum(torch.sigmoid(span_logits_masked)) / (len(span_logits_masked)) 

                    if key in nli_metrics:
                        nli_metrics[key]['spans_contribution'].append(spans_contribution)
                        nli_metrics[key]['nli_logits'].append(nli_logits[i])
                    else:
                        nli_metrics[key] = {}
                        nli_metrics[key]['true_nli_labels'] = nli_labels[i]
                        nli_metrics[key]['spans_contribution'] = [spans_contribution]
                        nli_metrics[key]['nli_logits'] = [nli_logits[i]]
                    
                    current_index = first_minus_one_index
                    indices_considered += current_index
                    
                    # ic(indices_considered)
                    # ic(current_index)
                    cnt = 0 # count to keep track of the number of span indices added in dictionary
                    
                    for span_index in span_indices_to_consider:

                        if span_index < indices_considered:
                            cnt += 1
                            value_index = span_index - (indices_considered - current_index)
                            doc_id = data_for_metrics['doc_id'][i]
                            hypothesis_id = data_for_metrics['hypothesis_id'][i]
                            span_id = data_for_metrics['span_ids'][i][value_index]
                            key = str(doc_id)+ '-' + str(hypothesis_id)+ '-' + str(span_id)
                            true_labels_per_span[key] = span_labels[span_index]
                            if key in probs_per_span:
                                probs_per_span[key].append(torch.sigmoid(span_logits[span_index]))
                                # probs_per_span[key].append(span_logits[value_index])
                            else:
                                probs_per_span[key] = [torch.sigmoid(span_logits[span_index])]
                                # probs_per_span[key] = [span_logits[value_index]]
                        else: 
                            break 
                    
                    span_indices_to_consider = span_indices_to_consider[cnt:]

                # eval_span_preds = torch.tensor(eval_span_preds.squeeze(1), dtype=torch.long)

                nli_preds = torch.argmax(torch.softmax(nli_logits, dim=1), dim=1)
                eval_nli_labels.extend(nli_labels.cpu().numpy())
                eval_nli_preds.extend(nli_preds.cpu().numpy())

        eval_span_labels = []
        eval_span_preds = []

        for key in true_labels_per_span:
            eval_span_labels.append(true_labels_per_span[key].item())
            eval_span_preds.append(torch.mean(torch.stack(probs_per_span[key])).item())

        ##### For NLI probablities #####

        # for key in nli_metrics:
        #     nli_metrics[key]['nli_logits'] = torch.stack(nli_metrics[key]['nli_logits'])
        #     nli_metrics[key]['spans_contribution'] = torch.stack(nli_metrics[key]['spans_contribution'])

        #     span_sum = torch.sum(nli_metrics[key]['spans_contribution'])
        #     spans_contribution = nli_metrics[key]['spans_contribution'].transpose(0, -1) @ nli_metrics[key]['nli_logits']

        #     eval_nli_preds.append(torch.argmax(torch.softmax(spans_contribution/span_sum, dim=0)).item())
        #     eval_nli_labels.append(nli_metrics[key]['true_nli_labels'].item())

        ##### END #####

        eval_nli_acc = accuracy_score(eval_nli_labels, eval_nli_preds)

        ic.enable()
        
        ic(list(zip(eval_span_labels, eval_span_preds)))
        # ic(len(eval_span_labels), len(eval_span_preds))
        # ic(sum(eval_span_labels), sum(eval_span_preds))

        # find threshold for 80% recall
        # precision, recall, thresholds = precision_recall_curve(eval_span_labels, eval_span_preds)


        mAP = (average_precision_score(eval_span_labels, eval_span_preds, pos_label=0) + average_precision_score(eval_span_labels, eval_span_preds, pos_label=1))/2

        # mAP = average_precision_score(torch.tensor(true_span_labels), torch.tensor(pred_span_labels))
        precision_at_80_recall = precision_at_recall(torch.tensor(eval_span_labels), torch.tensor(eval_span_preds), 0.8)
        f1_score_for_entailment = calculate_f1_score_for_class(torch.tensor(eval_nli_labels), torch.tensor(eval_nli_preds), NLPUtils.get_labels()['Entailment'])
        f1_score_for_contradiction = calculate_f1_score_for_class(torch.tensor(eval_nli_labels), torch.tensor(eval_nli_preds), NLPUtils.get_labels()['Contradiction'])
        
        return {
            'mAP' : mAP,
            'precision_at_80_recall' : precision_at_80_recall,
            'nli_acc': eval_nli_acc,
            'f1_score_for_entailment': f1_score_for_entailment,
            'f1_score_for_contradiction': f1_score_for_contradiction
        }

In [None]:
trainer = ContractNLIMetricTrainer(
    model=model,                      
    args=training_args,               
    # train_dataset=train_dataset,
    eval_dataset=dev_dataset,      
    data_collator=ContractNLIMetricTrainer.collate_fn,
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [None]:
ic.disable()
# ic.enable()
results = trainer.evaluate()

In [2]:
results

{'mAP': 0.584293051888368,
 'precision_at_80_recall': 0.3567567567567567,
 'nli_acc': 0.6554621848739496,
 'f1_score_for_entailment': 0.30112590299277603,
 'f1_score_for_contradiction': 0.2663243589845487}