## Spans computation for a given text

    def preprocess_spans(self, tokens, ner, classes_to_id):

        max_len = self.base_config.max_len

        if len(tokens) > max_len:
            length = max_len
            tokens = tokens[:max_len]
        else:
            length = len(tokens)

        spans_idx = []
        for i in range(length):
            spans_idx.extend([(i, i + j) for j in range(self.max_width)])

        dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)

        # 0 for null labels
        span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
        spans_idx = torch.LongTensor(spans_idx)

        # mask for valid spans
        valid_span_mask = spans_idx[:, 1] > length - 1

        # mask invalid positions
        span_label = span_label.masked_fill(valid_span_mask, -1)

        return {
        'tokens': tokens,
        'span_idx': spans_idx,
        'span_label': span_label,
        'seq_length': length,
        'entities': ner,
        }

    
## Span Representation layer 
 
### Functions for span representation

    def extract_elements(sequence, indices):
        B, L, D = sequence.shape
        K = indices.shape[1]

        # Expand indices to [B, K, D]
        expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)

        # Gather the elements
        extracted_elements = torch.gather(sequence, 1, expanded_indices)

        return extracted_elements
    
def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
            """
            Creates a projection layer with specified configurations.
            """
            if out_dim is None:
                out_dim = hidden_size

            return nn.Sequential(
                nn.Linear(hidden_size, out_dim * 4),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_dim * 4, out_dim)
                )

### Span Representation layer (default)
    class SpanMarkerV0(nn.Module):
    """
    Marks and projects span endpoints using an MLP.
    """

        def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
            super().__init__()
            self.max_width = max_width
            self.project_start = create_projection_layer(hidden_size, dropout)
            self.project_end = create_projection_layer(hidden_size, dropout)

            self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)

        def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
            B, L, D = h.size()

            start_rep = self.project_start(h)
            end_rep = self.project_end(h)

            start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
            end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])

            cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()

            return self.out_project(cat).view(B, L, self.max_width, D)


In [14]:
from collections import defaultdict
from typing import List, Tuple, Dict

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import random


class InstructBase(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.max_width = config['max_width']
        self.base_config = config

    def get_dict(self, spans, classes_to_id):
        dict_tag = defaultdict(int)
        for span in spans:
            if span[2] in classes_to_id:
                dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
        return dict_tag

    def preprocess_spans(self, tokens, ner, classes_to_id):

        max_len = self.base_config['max_len']

        if len(tokens) > max_len:
            length = max_len
            tokens = tokens[:max_len]
        else:
            length = len(tokens)

        spans_idx = []
        for i in range(length):
            spans_idx.extend([(i, i + j) for j in range(self.max_width)])

        dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)

        # 0 for null labels
        span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
        spans_idx = torch.LongTensor(spans_idx)

        # mask for valid spans
        valid_span_mask = spans_idx[:, 1] > length - 1

        # mask invalid positions
        span_label = span_label.masked_fill(valid_span_mask, -1)

        return {
            'tokens': tokens,
            'span_idx': spans_idx,
            'span_label': span_label,
            'seq_length': length,
            'entities': ner,
        }

    def collate_fn(self, batch_list, entity_types=None):
        # batch_list: list of dict containing tokens, ner
        if entity_types is None:
            negs = self.get_negatives(batch_list, 100)
            class_to_ids = []
            id_to_classes = []
            for b in batch_list:
                # negs = b["negative"]
                random.shuffle(negs)

                # negs = negs[:sampled_neg]
                max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)

                if max_neg_type_ratio == 0:
                    # no negatives
                    neg_type_ratio = 0
                else:
                    neg_type_ratio = random.randint(0, max_neg_type_ratio)

                if neg_type_ratio == 0:
                    # no negatives
                    negs_i = []
                else:
                    negs_i = negs[:len(b['ner']) * neg_type_ratio]

                # this is the list of all possible entity types (positive and negative)
                types = list(set([el[-1] for el in b['ner']] + negs_i))

                # shuffle (every epoch)
                random.shuffle(types)

                if len(types) != 0:
                    # prob of higher number shoul
                    # random drop
                    if self.base_config.random_drop:
                        num_ents = random.randint(1, len(types))
                        types = types[:num_ents]

                # maximum number of entities types
                types = types[:int(self.base_config.max_types)]

                # supervised training
                if "label" in b:
                    types = sorted(b["label"])

                class_to_id = {k: v for v, k in enumerate(types, start=1)}
                id_to_class = {k: v for v, k in class_to_id.items()}
                class_to_ids.append(class_to_id)
                id_to_classes.append(id_to_class)

            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i]) for i, b in enumerate(batch_list)
            ]

        else:
            class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
            id_to_classes = {k: v for v, k in class_to_ids.items()}
            batch = [
                self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids) for b in batch_list
            ]

        span_idx = pad_sequence(
            [b['span_idx'] for b in batch], batch_first=True, padding_value=0
        )

        span_label = pad_sequence(
            [el['span_label'] for el in batch], batch_first=True, padding_value=-1
        )

        return {
            'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
            'span_idx': span_idx,
            'tokens': [el['tokens'] for el in batch],
            'span_mask': span_label != -1,
            'span_label': span_label,
            'entities': [el['entities'] for el in batch],
            'classes_to_id': class_to_ids,
            'id_to_classes': id_to_classes,
        }

    @staticmethod
    def get_negatives(batch_list, sampled_neg=5):
        ent_types = []
        for b in batch_list:
            types = set([el[-1] for el in b['ner']])
            ent_types.extend(list(types))
        ent_types = list(set(ent_types))
        # sample negatives
        random.shuffle(ent_types)
        return ent_types[:sampled_neg]

    def create_dataloader(self, data, entity_types=None, **kwargs):
        return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)

    def set_sampling_params(self, max_types, shuffle_types, random_drop, max_neg_type_ratio, max_len):
        """
        Sets sampling parameters on the given model.

        Parameters:
        - model: The model object to update.
        - max_types: Maximum types parameter.
        - shuffle_types: Boolean indicating whether to shuffle types.
        - random_drop: Boolean indicating whether to randomly drop elements.
        - max_neg_type_ratio: Maximum negative type ratio.
        - max_len: Maximum length parameter.
        """
        self.base_config.max_types = max_types
        self.base_config.shuffle_types = shuffle_types
        self.base_config.random_drop = random_drop
        self.base_config.max_neg_type_ratio = max_neg_type_ratio
        self.base_config.max_len = max_len

In [15]:
config = {
  "lr_encoder": "1e-5",
  "lr_others": "5e-5",
  "num_steps": 30000,
  "warmup_ratio": 0.1,
  "train_batch_size": 8,
  "eval_every": 5000,
  "max_width": 12,
  "model_name": "microsoft/deberta-v3-large",
  "fine_tune": True,
  "subtoken_pooling": "first",
  "hidden_size": 768,
  "span_mode": "markerV0",
  "dropout": 0.4,
  "max_neg_type_ratio": 3,
  "name": "abl",
  "size_sup": -1,
  "max_types": 25,
  "shuffle_types": True,
  "random_drop": True,
  "max_len": 384,
}

In [16]:
ib = InstructBase(config=config)

In [17]:
ib.

InstructBase()

In [36]:
import torch
import torch.nn.functional as F
from torch import nn

def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
    """
    Creates a projection layer with specified configurations.
    """
    if out_dim is None:
        out_dim = hidden_size

    return nn.Sequential(
        nn.Linear(hidden_size, out_dim * 4),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(out_dim * 4, out_dim)
    )

def extract_elements(sequence, indices):
    B, L, D = sequence.shape
    print("B,L,D = ", B, L, D)
    K = indices.shape[1]

    # Expand indices to [B, K, D]
    expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)

    # Gather the elements
    extracted_elements = torch.gather(sequence, 1, expanded_indices)

    return extracted_elements

In [41]:
class SpanMarkerV0(nn.Module):
    """
    Marks and projects span endpoints using an MLP.
    """

    def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
        super().__init__()
        self.max_width = max_width
        self.project_start = create_projection_layer(hidden_size, dropout)
        self.project_end = create_projection_layer(hidden_size, dropout)

        self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)

    def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
        B, L, D = h.size()

        start_rep = self.project_start(h)
        end_rep = self.project_end(h)

        start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
        end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])

        print(start_rep.shape)
        print(end_rep.shape)

        print()

        print(span_idx[:, :, 0])
        print(span_idx[:, :, 1])

        print()

        print(start_span_rep.shape)
        print(end_span_rep.shape)



        cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()

        print(cat.shape)

        return self.out_project(cat)

        return self.out_project(cat).view(B, L, self.max_width, D)



In [44]:
# Given inputs
word_rep = torch.randn(1, 6, 10)  # Assuming 6 tokens with each token represented by a 10-dimensional vector
span_idx = torch.LongTensor([[[0, 0],
                              [0, 1],
                              [0, 2],
                              [1, 1],
                              [1, 2],
                              [1, 3],
                              [2, 2],
                              [2, 3],
                              [2, 4],
                              [3, 3],
                              [3, 4],
                              [3, 5],
                              [4, 4],
                              [4, 5],
                              [5, 5]]])

# Create an instance of SpanMarkerV0
span_marker = SpanMarkerV0(hidden_size=10, max_width=3, dropout=0.4)

# Compute span_rep
span_rep = span_marker.forward(word_rep, span_idx)

print("###############")
print(span_rep.shape)  # Output shape
print(span_rep)         # Output tensor


B,L,D =  1 6 10
B,L,D =  1 6 10
torch.Size([1, 6, 10])
torch.Size([1, 6, 10])

tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5]])
tensor([[0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 5]])

torch.Size([1, 15, 10])
torch.Size([1, 15, 10])
torch.Size([1, 15, 20])
###############
torch.Size([1, 15, 10])
tensor([[[-3.0157e-01, -4.7236e-02, -2.1270e-01,  1.3546e-01,  8.2573e-03,
          -4.8060e-02,  1.6734e-01,  1.5996e-01, -2.6834e-02,  3.7026e-01],
         [-2.6844e-01, -4.5104e-02, -3.4430e-01, -5.4507e-02,  4.8877e-02,
          -8.3942e-02,  2.0335e-01,  1.9588e-01,  1.1190e-03,  3.3777e-01],
         [-2.3895e-01,  4.7283e-03, -1.2314e-01,  1.1788e-01,  1.6451e-02,
           3.1235e-02,  2.7326e-01,  1.2244e-01, -4.0133e-02,  2.9034e-01],
         [-1.2960e-01, -1.5413e-01, -2.6885e-01, -9.8466e-02,  2.3240e-02,
          -2.3104e-02,  1.5348e-01,  1.4566e-01,  5.1083e-02,  2.8492e-01],
         [-1.7766e-01, -1.7208e-02, -1.9379e-01, -2.3842e-02,  2.7209e-02,
          -1.0550