In [None]:
import os

import random
import numpy as np
from collections import defaultdict

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler

import transformers
from transformers import BertModel, BertTokenizerFast, AutoModel, AutoTokenizer
import datasets

import matplotlib_inline
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
import json

from math import floor, ceil, cos, pi

from scipy.stats import spearmanr, mannwhitneyu, ks_2samp
from sklearn.metrics import roc_auc_score

import wandb
os.environ["CUDA_VISIBLE_DEVICES"]="0"
# torch.cuda.set_device(0)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats('pdf', 'svg')

In [None]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
# Log in through the Python API instead of `!wandb login $key`: interpolating the
# secret into a shell command exposes it in the process list / shell history.
wandb.login(key=user_secrets.get_secret("wandb_key"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# Authenticate with the Hugging Face Hub using the token stored in Kaggle secrets
# (reuses `user_secrets` from the W&B login cell).
import huggingface_hub
huggingface_hub.login(token=user_secrets.get_secret("HF_TOKEN"))

In [None]:
# Pretrained encoder checkpoint used for both the tokenizer and the span model;
# the commented alternatives can be swapped in for comparison runs.
PRETR = 'bert-base-uncased'
#PRETR = 'microsoft/deberta-base'
#PRETR = "SpanBERT/spanbert-base-cased"

In [None]:
# Load the drug-effect documents and flatten each one into a single markup:
# spans from the first markup's elements, de-duplicated by (begin, end) and all
# tagged 'O'. The context manager closes the file handle.
with open("/kaggle/input/ade-dataset/drug_effect.json") as f:
    de = json.load(f)
de2 = []
for document in tqdm(de['dataset_documents']):
    markup = {'markup_spans': []}
    seen = set()
    for element in document['document_markups'][0]['markup_elements']:
        for span in element['element_spans']:
            key = (span['espan_begin'], span['espan_end'])
            if key not in seen:
                markup['markup_spans'].append({
                    'espan_begin': span['espan_begin'],
                    'espan_end': span['espan_end'],
                    'espan_tags': ['O']
                })
                seen.add(key)
    de2.append({
        'document_text': document['document_text'],
        'document_markups': [markup]
    })

  0%|          | 0/4271 [00:00<?, ?it/s]

In [None]:
# Load the drug-dosage dataset; the context manager closes the file handle.
with open("/kaggle/input/ade-dataset/drug_dosage.json") as f:
    dd = json.load(f)

In [None]:
# Flatten each drug-dosage document into one markup: spans taken from the first
# markup's elements, de-duplicated by (begin, end), every span tagged 'O'.
dd2 = []
for doc in tqdm(dd['dataset_documents']):
    seen_bounds = set()
    flat_spans = []
    for element in doc['document_markups'][0]['markup_elements']:
        for espan in element['element_spans']:
            bounds = (espan['espan_begin'], espan['espan_end'])
            if bounds in seen_bounds:
                continue
            seen_bounds.add(bounds)
            flat_spans.append({
                'espan_begin': espan['espan_begin'],
                'espan_end': espan['espan_end'],
                'espan_tags': ['O']
            })
    dd2.append({
        'document_text': doc['document_text'],
        'document_markups': [{'markup_spans': flat_spans}]
    })

  0%|          | 0/213 [00:00<?, ?it/s]

In [None]:
# Load the multi-annotator NER span dataset; the context manager closes the file.
with open("/kaggle/input/kaggle-ner-span-level-2/dataset.json") as f:
    dataset = json.load(f)

In [None]:
# Merge all annotators' markups of each document into ONE markup. In it, every
# distinct (begin, end) span carries a per-annotator tag list: entry j is
# annotator j's tag, or 'O' if annotator j did not mark that span.
# NOTE: mutates `dataset` in place; re-load it before re-running this cell.
tags = set()
for i, raw_document in enumerate(dataset["dataset_documents"]):
    spans = {}
    for j, markup in enumerate(raw_document["document_markups"]):
        for span in markup["markup_spans"]:
            begin = span['espan_begin']
            end = span['espan_end']
            tag = span['espan_tags'][0]
            tags.add(tag)
            key = (begin, end)
            if key not in spans:
                # Back-fill 'O' for the j annotators processed before this span appeared.
                spans[key] = {
                    "espan_begin": begin,
                    "espan_end": end,
                    "espan_tags": ['O'] * j
                }
            spans[key]["espan_tags"].append(tag)
        # Spans that annotator j did not mark get 'O' in slot j.
        for merged in spans.values():
            if len(merged['espan_tags']) == j:
                merged['espan_tags'].append('O')
    dataset["dataset_documents"][i] = {
        "document_text": raw_document["document_text"],
        # A list of markups, like de2/dd2. ma_align_tokens and ma iterate
        # `for markup in markups`; on a bare dict that walks its keys and fails
        # on markup['markup_spans'].
        # NOTE(review): `ma` groups spans by tags[0] only (the first
        # annotator's slot) — confirm that is the intended label.
        "document_markups": [{
            "markup_spans": list(spans.values())
        }]
    }
documents = dataset['dataset_documents']

In [None]:
def ma_dataset(documents, tokenizer):
    """Tokenize documents and project their character-level spans onto tokens.

    Args:
        documents: list of dicts with 'document_text' and 'document_markups'
            (a list of markups, each holding 'markup_spans' with character
            offsets 'espan_begin'/'espan_end' and 'espan_tags').
        tokenizer: Hugging Face fast tokenizer (offset mappings are required).

    Returns:
        A `datasets.Dataset` with columns text, input_ids, attention_mask,
        offset_mapping, special_ids and markups. `markups` is rewritten by
        `ma_align_tokens` into token-index spans.
    """
    encoded = tokenizer([document['document_text'] for document in documents],
                        return_offsets_mapping=True, truncation=True,
                        max_length=512, padding='max_length')
    # Look fields up by name. Unpacking `.values()` positionally assumed the
    # tokenizer also returns token_type_ids, which not every tokenizer does.
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    offset_mapping = encoded['offset_mapping']
    special_ids = tokenizer.all_special_ids
    new_documents = datasets.Dataset.from_list([{
        "text": documents[i]["document_text"],
        "input_ids": input_ids[i],
        "attention_mask": attention_mask[i],
        "offset_mapping": offset_mapping[i],
        "markups": documents[i]['document_markups'],
        "special_ids": special_ids
    } for i in range(len(documents))])
    return new_documents.map(
        function=ma_align_tokens,
        input_columns=["text", "input_ids", "attention_mask", "offset_mapping", "markups", "special_ids"]
    )

In [None]:
def ma_align_tokens(text, input_ids, attention_mask, offset_mapping, markups, special_ids):
    """Convert the character-offset spans of every markup into token-index spans.

    A token belongs to span [espan_begin, espan_end) when its character range
    overlaps it. The resulting token span is [first, last + 1).

    Args:
        text: document text (passed through).
        input_ids: token ids of the (padded) document.
        attention_mask: passed through unchanged.
        offset_mapping: per-token (char_begin, char_end) pairs.
        markups: list of {'markup_spans': [{'espan_begin', 'espan_end', 'espan_tags'}]}.
        special_ids: ids of the tokenizer's special tokens.

    Returns:
        dict with 'text', 'input_ids', 'attention_mask' passed through and
        'markups' as a list of {'markup_spans': [{'begin', 'end', 'tags'}]},
        where begin/end are token positions (end exclusive).

    Raises:
        AssertionError: if a span overlaps no token (e.g. it lies beyond the
            truncation point).
    """
    special = set(special_ids)
    new_markups = []
    for markup in markups:
        new_spans = []
        for span in markup['markup_spans']:
            begin = span['espan_begin']
            end = span['espan_end']
            new_begin = None
            new_end = None
            for j, token_id in enumerate(input_ids):
                token_begin, token_end = offset_mapping[j]
                if token_begin >= end:
                    break
                # Skip only zero-width special tokens ([CLS]/[SEP]/[PAD]-like).
                # special_ids also contains the unknown-token id, which covers
                # real text with a real offset; skipping it shrank or lost spans.
                if token_id in special and token_begin == token_end:
                    continue
                if token_end > begin:
                    if new_begin is None:
                        new_begin = j
                    new_end = j
            assert new_begin is not None and new_end is not None, (
                f"span [{begin}, {end}) overlaps no token")
            new_spans.append({
                "begin": new_begin,
                "end": new_end + 1,
                "tags": span['espan_tags']
            })
        new_markups.append({
            "markup_spans": new_spans
        })
    return {
        "text": text,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "markups": new_markups
    }

In [None]:
# Tokenizer for PRETR. ma_dataset requests offset mappings, which Hugging Face
# only provides for "fast" tokenizers (AutoTokenizer's default).
tokenizer = AutoTokenizer.from_pretrained(PRETR)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [None]:
# Tokenize the drug-dosage documents and align their character spans to tokens.
tokenized_dd = ma_dataset(dd2, tokenizer)

Map:   0%|          | 0/213 [00:00<?, ? examples/s]

In [None]:
# offset_mapping / special_ids were only needed for span alignment.
tokenized_dd = tokenized_dd.remove_columns(['offset_mapping', 'special_ids'])
tokenized_dd

Dataset({
    features: ['text', 'input_ids', 'attention_mask', 'markups'],
    num_rows: 213
})

In [None]:
# Same pipeline for the drug-effect documents; drop the alignment-only columns.
tokenized_de = ma_dataset(de2, tokenizer)
tokenized_de = tokenized_de.remove_columns(['offset_mapping', 'special_ids'])
tokenized_de

Map:   0%|          | 0/4271 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'input_ids', 'attention_mask', 'markups'],
    num_rows: 4271
})

In [None]:
# Tokenize and align the multi-annotator NER documents.
tokenized_docs = ma_dataset(dataset['dataset_documents'], tokenizer)

Map:   0%|          | 0/47575 [00:00<?, ? examples/s]

In [None]:
# offset_mapping / special_ids were only needed for span alignment.
tokenized_docs = tokenized_docs.remove_columns(['offset_mapping', 'special_ids'])
tokenized_docs

Dataset({
    features: ['text', 'input_ids', 'attention_mask', 'markups'],
    num_rows: 47575
})

In [None]:
def ma(text, input_ids, attention_mask, markups):
    """Aggregate spans from several markups into scored consensus spans.

    Spans are clustered greedily. Each span joins the first cluster whose seed
    (first) span overlaps it with IoU >= 0.7; otherwise it starts a new
    cluster. Inside a cluster, spans are grouped by their first tag. Each
    group becomes one span with begin = floor(mean begin), end = ceil(mean end)
    and score = group size / number of markups.

    Returns:
        dict with 'text', 'input_ids', 'attention_mask' passed through and
        'spans': list of {'begin', 'end', 'tags', 'score'}.
    """
    def joins(seed, span):
        # Disjoint (or touching) spans never join; this also avoids a zero union.
        if seed['end'] <= span['begin'] or span['end'] <= seed['begin']:
            return False
        union = max(seed['end'], span['end']) - min(seed['begin'], span['begin'])
        inter = min(seed['end'], span['end']) - max(seed['begin'], span['begin'])
        return inter / union >= 0.7

    clusters = []
    for markup in markups:
        for span in markup['markup_spans']:
            home = next((cluster for cluster in clusters if joins(cluster[0], span)), None)
            if home is None:
                clusters.append([span])
            else:
                home.append(span)

    consensus = []
    for cluster in clusters:
        by_tag = defaultdict(list)
        for span in cluster:
            by_tag[span['tags'][0]].append(span)
        for tag, group in by_tag.items():
            size = len(group)
            consensus.append({
                'begin': floor(sum(s['begin'] for s in group) / size),
                'end': ceil(sum(s['end'] for s in group) / size),
                'tags': [tag],
                'score': size / len(markups),
            })
    return {
        "text": text,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "spans": consensus
    }

In [None]:
# Collapse each drug-effect document's markups into scored consensus spans (see `ma`).
ma_de = tokenized_de.map(function=ma, input_columns=['text', 'input_ids', 'attention_mask', 'markups'])
ma_de = ma_de.remove_columns('markups')
ma_de

Map:   0%|          | 0/4271 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'input_ids', 'attention_mask', 'spans'],
    num_rows: 4271
})

In [None]:
# Collapse each drug-dosage document's markups into scored consensus spans (see `ma`).
ma_dd = tokenized_dd.map(function=ma, input_columns=['text', 'input_ids', 'attention_mask', 'markups'])
ma_dd = ma_dd.remove_columns('markups')
ma_dd

Map:   0%|          | 0/213 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'input_ids', 'attention_mask', 'spans'],
    num_rows: 213
})

In [None]:
# Collapse each NER document's annotator markups into scored consensus spans (see `ma`).
ma_docs = tokenized_docs.map(function=ma, input_columns=['text', 'input_ids', 'attention_mask', 'markups'])

Map:   0%|          | 0/47575 [00:00<?, ? examples/s]

In [None]:
# Raw markups are no longer needed once consensus spans exist.
ma_docs = ma_docs.remove_columns('markups')
ma_docs

Dataset({
    features: ['text', 'input_ids', 'attention_mask', 'spans'],
    num_rows: 47575
})

In [None]:
# Alias (not a copy) of the consensus NER documents; split in the next cell.
res = ma_docs

In [None]:
# test_size=0.9: only 10% of the documents go to the training split; the fixed
# seed makes the shuffle reproducible.
res = res.train_test_split(test_size=0.9, shuffle=True, seed=42)
res

DatasetDict({
    train: Dataset({
        features: ['text', 'input_ids', 'attention_mask', 'spans'],
        num_rows: 4757
    })
    test: Dataset({
        features: ['text', 'input_ids', 'attention_mask', 'spans'],
        num_rows: 42818
    })
})

In [None]:
# Span-length histogram and tag inventory of the training split. `lens` drives
# negative-span length sampling; `tags` (always including 'O') is the label index.
# `defaultdict` is already imported in the first cell.
lens = defaultdict(int)
tags = {'O'}
for d in tqdm(res['train']):
    for span in d['spans']:
        lens[span['end'] - span['begin']] += 1
        tags.add(span['tags'][0])
tags = sorted(tags)
lens, tags

  0%|          | 0/4757 [00:00<?, ?it/s]

(defaultdict(int,
             {4: 1021,
              1: 6069,
              2: 2086,
              6: 200,
              3: 1051,
              5: 377,
              7: 115,
              10: 9,
              8: 58,
              9: 26,
              11: 6,
              12: 5}),
 ['O', 'art', 'eve', 'geo', 'gpe', 'nat', 'org', 'per', 'tim'])

In [None]:
class MAUniversalDataset(Dataset):
    """Span-classification dataset built from consensus (`ma`) documents.

    Every consensus span of a document is a positive example labelled with the
    index of its tag. In addition, `neg_num` negative spans per document are
    labelled 'O'. Each negative is drawn at random, avoiding spans with
    IoU >= 0.7 to any known span of the document.

    Args:
        documents: iterable of rows with 'text', 'input_ids', 'attention_mask'
            and 'spans' (as produced by `ma`).
        tags: ordered list of tag names. Labels are indices into it, and it
            must contain 'O'.
        lens: mapping span length -> count. Negative lengths are drawn with
            probability proportional to the counts.
        neg_num: number of negative spans per document.
        mode: 'train' draws negatives lazily (a fresh one on every
            __getitem__); any other value draws them once here.
        noise: optional {'scale': float, 'size': int}. Attaches a Gaussian
            noise vector, seeded by example index, to every stored example.
    """
    def __init__(self, documents, tags, lens, neg_num, mode, noise=None):
        # Seed counter for negative sampling. Each draw seeds a fresh RNG from
        # it, so the negatives are deterministic.
        # NOTE(review): every DataLoader worker gets its own copy of this
        # counter, so with num_workers > 1 workers replay the same seeds.
        self.seed = -1
        self.tags = tags
        self.lens = lens
        self.data = []
        self.reprs = []    # per document: set of "[begin, end]" strings to avoid
        self.lengths = []  # per document: token count without the 2 boundary tokens
        self.extras = []   # spans with a negative score (kept out of self.data)
        for doc_id, document in enumerate(tqdm(documents)):
            doc = []
            reprs = set()
            # Assumes one leading and one trailing special token per document.
            self.lengths.append(sum(document['attention_mask']) - 2)
            for span in document['spans']:
                begin, end = span['begin'], span['end']
                example = {
                    "text": document['text'],
                    "input_ids": document['input_ids'],
                    "attention_mask": document['attention_mask'],
                    "begin": begin,
                    "end": end,
                    "y": tags.index(span['tags'][0]),
                    'score': span['score']
                }
                if span['score'] >= 0.0:
                    doc.append(example)
                    reprs.add(repr([begin, end]))
                else:
                    self.extras.append(example)
            if mode == 'train':
                # Placeholders: __getitem__ replaces them with a fresh negative.
                for _ in range(neg_num):
                    self.data.append({
                        "doc_id": doc_id,
                        "text": document['text'],
                        "input_ids": document['input_ids'],
                        "attention_mask": document['attention_mask'],
                        "begin": None,
                        "end": None,
                        "y": None,
                        'score': 1.0
                    })
            else:
                # Fixed negatives, drawn once. Each is added to `reprs` so later
                # negatives of the same document avoid it.
                for _ in range(neg_num):
                    sample = self.negative_sampler({
                        "doc_id": doc_id,
                        "text": document['text'],
                        "input_ids": document['input_ids'],
                        "attention_mask": document['attention_mask'],
                        "begin": None,
                        "end": None,
                        "y": None,
                        'score': 1
                    }, reprs, self.lengths[doc_id])
                    reprs.add(repr([sample['begin'], sample['end']]))
                    self.data.append(sample)
            self.data += doc
            self.reprs.append(reprs)
        if noise is not None:
            for j, example in enumerate(self.data):
                example['noise'] = np.random.default_rng(j).normal(loc=0, scale=noise['scale'], size=noise['size']).tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return stored examples as-is; lazy placeholders get a freshly sampled negative."""
        el = self.data[index]
        if el['begin'] is None:
            doc_id = el['doc_id']
            el = self.negative_sampler(el, self.reprs[doc_id], self.lengths[doc_id])
        return el

    def negative_sampler(self, el, reprs, full_length, max_attempts=10000):
        """Draw a random 'O' span with IoU < 0.7 to every span in `reprs`.

        Args:
            el: example dict; only 'text', 'input_ids' and 'attention_mask'
                are copied into the result.
            reprs: set of "[begin, end]" strings (token positions, end exclusive).
            full_length: number of content tokens. Valid span positions are
                1..full_length (position 0 is the leading special token).
            max_attempts: number of draws before giving up.

        Returns:
            dict with the copied fields, 'begin'/'end' as Python ints,
            'y' = index of 'O' and 'score' = 1.

        Raises:
            ValueError: if the document has no content tokens.
            RuntimeError: if no acceptable span is found within `max_attempts`
                draws (e.g. a short document fully covered by entities).
        """
        if full_length < 1:
            raise ValueError(f'cannot sample a negative span from a document with {full_length} tokens')
        # Loop-invariant: length distribution and the parsed known spans.
        lengths = list(map(int, self.lens.keys()))
        counts = np.array(list(self.lens.values()))
        probs = counts / counts.sum()
        known = [tuple(map(int, rep[1:-1].split(', '))) for rep in reprs]

        self.seed += 1
        for _ in range(max_attempts):
            rng = np.random.default_rng(self.seed)
            length = int(rng.choice(lengths, 1, p=probs)[0])
            # [begin, begin + length) fits inside 1..full_length iff
            # begin <= full_length - length + 1. Spans longer than the document
            # fall back to begin = 1 and are clipped by `end` below.
            # int() keeps repr() parseable ("[3, 4]", not "[np.int64(3), ...]").
            begin = int(rng.choice(np.arange(1, max(full_length - length + 2, 2)), 1)[0])
            end = min(begin + length, full_length + 1)
            acceptable = True
            for cur_begin, cur_end in known:
                if cur_begin >= end or begin >= cur_end:
                    continue
                union = max(end, cur_end) - min(begin, cur_begin)
                inter = min(end, cur_end) - max(begin, cur_begin)
                if inter / union >= 0.7:
                    acceptable = False
                    break
            if acceptable:
                return {
                    "text": el['text'],
                    "input_ids": el['input_ids'],
                    "attention_mask": el['attention_mask'],
                    "begin": begin,
                    "end": end,
                    "y": self.tags.index('O'),
                    'score': 1
                }
            self.seed += 1
        raise RuntimeError(
            f'no negative span found after {max_attempts} attempts '
            f'(full_length={full_length}, {len(known)} known spans)')

In [None]:
# train:  one lazily re-sampled negative per training document.
# eval:   first 2000 held-out documents, one fixed negative each.
# noise:  the same 2000 documents, positives only, with fixed Gaussian noise
#         (scale 0.01, 768 dims) that the model adds to the span embeddings.
# dd/de:  out-of-domain drug-dosage / drug-effect spans, positives only.
train_dataset = MAUniversalDataset(res['train'], tags, lens, 1, mode='train')
eval_dataset = MAUniversalDataset(res['test'].select(range(2000)), tags, lens, 1, mode='test')
noise_dataset = MAUniversalDataset(res['test'].select(range(2000)), tags, lens, 0, mode='test', noise={'scale': 0.01, 'size': 768})
dd_dataset = MAUniversalDataset(ma_dd, tags, lens, 0, mode='test')
de_dataset = MAUniversalDataset(ma_de, tags, lens, 0, mode='test')

  0%|          | 0/4757 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/213 [00:00<?, ?it/s]

  0%|          | 0/4271 [00:00<?, ?it/s]

In [None]:
def collate_fn(batch):
    """Stack a list of span examples into a dict of tensors.

    input_ids / attention_mask / begin / end become long tensors; 'y' and
    'score' keep the dtype torch infers from their values. A 'noise' tensor is
    added only when the first example carries a 'noise' entry.
    """
    def column(key):
        return [example[key] for example in batch]

    collated = {
        "input_ids": torch.tensor(column('input_ids'), dtype=torch.long),
        "attention_mask": torch.tensor(column('attention_mask'), dtype=torch.long),
        "begin": torch.tensor(column('begin'), dtype=torch.long),
        "end": torch.tensor(column('end'), dtype=torch.long),
        "y": torch.tensor(column('y')),
        'score': torch.tensor(column('score')),
    }
    if 'noise' in batch[0]:
        collated['noise'] = torch.tensor(column('noise'))
    return collated

In [None]:
class UniversalModel(nn.Module):
    """Span classifier: encoder token embeddings -> span aggregation -> MLP head.

    Args:
        embedder: encoder called as embedder(input_ids, attention_mask). Element
            [0] of its output is used as token embeddings, presumably
            (batch, seq, hidden) — TODO confirm for non-BERT encoders.
        tokenizer: stored for convenience; not used by the model itself.
        hidden_size: width of the token embeddings.
        num_classes: number of output logits.
        aggr: span aggregation — 'mean', 'attn', 'max', 'endpoint',
            'diff-sum' or 'coherent'.
        device: device the embedder is moved to and inputs are sent to.
    """
    def __init__(self, embedder, tokenizer, hidden_size, num_classes, aggr, device):
        super().__init__()
        self.tokenizer = tokenizer
        self.embedder = embedder.to(device)
        self.num_classes = num_classes
        self.aggr = aggr
        self.hidden_size = hidden_size
        if self.aggr == 'attn':
            # Learned scoring vector for attention pooling over span tokens.
            self.v = nn.Linear(self.hidden_size, 1)
            self.flat = nn.Flatten()
            self.softmax = nn.Softmax(dim=1)
        elif self.aggr == 'endpoint' or self.aggr == 'diff-sum':
            # Both concatenate two hidden-size vectors.
            self.hidden_size *= 2
        elif self.aggr == 'coherent':
            # a dims from the start token, a dims from the end token, plus one
            # dot product of b-dim slices (2a + 2b == hidden_size when it is a
            # multiple of 32).
            self.a = 15 * self.hidden_size // 32
            self.b = self.hidden_size // 32
            self.hidden_size = 2 * self.a + 1
        self.fc = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, num_classes)
        )
        self.device = device

    def span_embedder(self, embeddings, begin, end):
        """Pool token embeddings over [begin, end) (end exclusive) for each row.

        Args:
            embeddings: token embeddings indexed as (batch, position, hidden).
            begin, end: 1-D long tensors of span bounds, one per row.

        Returns:
            (batch, self.hidden_size) span representations.

        Raises:
            ValueError: for an unknown aggregation name.
        """
        batch_size, seq_len = embeddings.shape[0], embeddings.shape[1]
        positions = torch.arange(seq_len, device=self.device).reshape(1, seq_len)
        mask = (positions >= begin[:, None]) & (positions < end[:, None])
        rows = torch.arange(batch_size, device=embeddings.device)
        if self.aggr == 'mean':
            return (embeddings * mask[:, :, None]).sum(dim=1) / (end - begin)[:, None]
        if self.aggr == 'attn':
            # Out-of-place masking: positions outside the span get -inf scores.
            alpha = self.flat(self.v(embeddings)).masked_fill(~mask, float('-inf'))
            alpha = self.softmax(alpha)
            return (embeddings * alpha[:, :, None]).sum(dim=1)
        if self.aggr == 'max':
            # Out-of-place masking instead of writing -inf into the encoder
            # output tensor that autograd tracks.
            span_embeddings = embeddings.masked_fill(~mask[:, :, None], float('-inf')).max(dim=1).values
            # Empty spans would be all -inf; zero them instead.
            return span_embeddings.masked_fill((begin >= end)[:, None], 0)
        if self.aggr == 'endpoint':
            return torch.cat([embeddings[(rows, begin)],
                              embeddings[(rows, end - 1)]], dim=1)
        if self.aggr == 'diff-sum':
            begin_emb = embeddings[(rows, begin)]
            end_emb = embeddings[(rows, end - 1)]
            return torch.cat([end_emb + begin_emb, end_emb - begin_emb], dim=1)
        if self.aggr == 'coherent':
            begin_emb = embeddings[(rows, begin)]
            end_emb = embeddings[(rows, end - 1)]
            return torch.cat([begin_emb[:, :self.a], end_emb[:, self.a:2 * self.a],
                              (begin_emb[:, 2 * self.a:2 * self.a + self.b] * end_emb[:, 2 * self.a + self.b:]).sum(dim=1).reshape(-1, 1)], dim=1)
        raise ValueError('wrong span aggregation name')

    def forward(self, input_ids, attention_mask, begin, end, *args, **kwargs):
        """Return class logits for the spans [begin, end).

        Keyword Args:
            noise: optional tensor added to the span embeddings (default None).
            retain: if True, keep the span embeddings' gradient after
                backward() (default False).

        The span embeddings actually fed to the head are stored in
        `self.span_embeddings`.
        """
        embeddings = self.embedder(input_ids.to(self.device), attention_mask.to(self.device))[0]
        self.span_embeddings = self.span_embedder(embeddings, begin.to(self.device), end.to(self.device))
        noise = kwargs.get('noise')
        if noise is not None:
            self.span_embeddings = self.span_embeddings + noise.to(self.device)
        if kwargs.get('retain', False):
            self.span_embeddings.retain_grad()
        return self.fc(self.span_embeddings)

    def get_activations_and_gradients(self):
        """Return the last span embeddings and their gradient.

        The gradient is only populated after a forward pass with retain=True
        followed by backward(); otherwise it is None.
        """
        return self.span_embeddings, self.span_embeddings.grad

    def get_embeddings(self, input_ids, attention_mask, begin, end):
        """Mean-pooled span embeddings over [begin, end), regardless of `aggr`."""
        embeddings = self.embedder(input_ids.to(self.device), attention_mask.to(self.device))[0]
        mask = torch.zeros(embeddings.shape[0], embeddings.shape[1]).to(self.device)
        for i in range(begin.shape[0]):
            mask[i, begin[i]:end[i]] = 1
        return (embeddings * mask[:, :, None]).sum(dim=1) / (end.to(self.device) - begin.to(self.device))[:, None]

    def predict_for_fixed_length(self, input_ids, attention_mask, fixed_len):
        '''
        Logits for every mean-pooled window of `fixed_len` content tokens.

        For only one text (batch size 1). Content tokens are positions
        1..attention_mask.sum() - 2, i.e. the boundary tokens are skipped.
        Inputs are not moved to `self.device`.
        '''
        embeddings = self.embedder(input_ids, attention_mask)[0]
        assert embeddings.shape[0] == 1
        span_embeddings = embeddings[0, 1:(attention_mask.sum().item() - 1)].unfold(dimension=0, size=fixed_len, step=1).permute(0, 2, 1).mean(dim=1)
        return self.fc(span_embeddings)

In [None]:
class WarmupScheduler(_LRScheduler):
    """Linear warm-up followed by linear or cosine decay to zero.

    During the first `warmup_steps` steps the LR grows linearly from 0 to the
    base LR. Afterwards it decays to 0 at `total_steps` and stays at 0.
    """
    def __init__(self, optimizer, warmup_steps, total_steps, mode='linear', last_epoch=-1):
        """
        Args:
            optimizer: wrapped optimizer.
            warmup_steps: number of warm-up steps.
            total_steps: total number of scheduler steps.
            mode: 'linear' or 'cosine' decay after warm-up.
            last_epoch: index of the last step; -1 starts fresh.

        Raises:
            ValueError: if `mode` is not 'linear' or 'cosine'.
        """
        if mode not in ('linear', 'cosine'):
            raise ValueError(f"mode must be 'linear' or 'cosine', got {mode!r}")
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.mode = mode
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Warm-up phase.
            progress = self.last_epoch / self.warmup_steps
            return [base_lr * progress for base_lr in self.base_lrs]
        # Decay phase. max(1, ...) guards total_steps == warmup_steps. Clamping
        # progress at 1 keeps the cosine from rising again past total_steps.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (self.last_epoch - self.warmup_steps) / decay_steps)
        if self.mode == 'linear':
            factor = max(0.0, 1.0 - progress)
        else:
            factor = max(0.0, 0.5 * (1.0 + cos(pi * progress)))
        return [base_lr * factor for base_lr in self.base_lrs]

In [None]:
def set_global_seed(seed: int) -> torch.Generator:
    """Seed torch (CPU and all CUDA devices), NumPy's global RNG and `random`.

    :param int seed: Seed to be set
    :return: a new torch.Generator seeded with `seed`, intended for
        DataLoader(generator=...) so that shuffling is reproducible too.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # If fully reproducible kernels are required:
    # torch.use_deterministic_algorithms(False)

    # Generator for the DataLoader.
    generator = torch.Generator()
    generator.manual_seed(seed)

    return generator

# worker_init_fn for every DataLoader worker.
def seed_worker(worker_id):
    """Seed NumPy and `random` from the worker's torch seed.

    The torch seed is reduced modulo 2**32, the range NumPy accepts.
    `worker_id` is unused.
    """
    derived = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived)

In [None]:
@torch.no_grad()
def log_gradients(model, global_step, log_norm=True, log_hist=True):
    """Log per-parameter gradient histograms and/or L2 norms to W&B at `global_step`.

    Parameters without a gradient are skipped. Everything is sent in one
    `wandb.log` call instead of one call per parameter and metric, and each
    gradient is copied to the CPU only once.
    """
    payload = {}
    for tag, value in model.named_parameters():
        if value.grad is None:
            continue
        grad = value.grad.cpu()
        if log_hist:
            payload[f"grad/{tag}"] = wandb.Histogram(grad)
        if log_norm:
            payload[f"grad_norm/{tag}"] = torch.norm(grad)
    if payload:
        wandb.log(payload, step=global_step)

@torch.no_grad()
def log_weights(model, global_step, log_norm=True, log_hist=True):
    """Log histograms and/or L2 norms of parameter values to W&B at `global_step`.

    Only parameters whose `.grad` is currently set are logged. Everything is
    sent in one `wandb.log` call, and each parameter is copied to the CPU once.
    """
    payload = {}
    for tag, value in model.named_parameters():
        if value.grad is None:
            continue
        weight = value.cpu()
        if log_hist:
            payload[f"weight/{tag}"] = wandb.Histogram(weight)
        if log_norm:
            payload[f"weight_norm/{tag}"] = torch.norm(weight)
    if payload:
        wandb.log(payload, step=global_step)

@torch.no_grad()
def vog(model):
    """Return the L2 norm (CPU scalar tensor) of each parameter gradient.

    Parameters without a gradient are skipped; order follows named_parameters().
    """
    return [
        torch.norm(param.grad.cpu())
        for _, param in model.named_parameters()
        if param.grad is not None
    ]

def get_activations(net, batch, loss_fn, optimizer, device=None):
    """Return d(loss)/d(span embeddings) for one batch as a nested list.

    Args:
        net: span model whose forward(input_ids, attention_mask, begin, end,
            retain=..., noise=...) stores `span_embeddings` on itself.
        batch: dict produced by `collate_fn`.
        loss_fn: per-example loss (reduction='none').
        optimizer: used only to clear parameter gradients first.
        device: where batch tensors go; defaults to `net.device` (or 'cpu').

    The loss is the score-weighted mean of the per-example losses, as in train().
    """
    if device is None:
        device = getattr(net, 'device', 'cpu')
    optimizer.zero_grad()
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    begin = batch['begin'].to(device)
    end = batch['end'].to(device)
    y = batch['y'].to(device)
    score = batch['score'].to(device)
    # Argument order matches the model's forward; noise is passed explicitly.
    out = net(input_ids, attention_mask, begin, end, retain=True, noise=None)
    loss = torch.dot(loss_fn(out, y), score.to(out.dtype)) / score.sum()
    loss.backward()
    return net.span_embeddings.grad.detach().cpu().tolist()

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
def evaluate(net, val_dataloader, loss_fn, optimizer, device):
    """Evaluate `net` on one loader and collect span-embedding gradients.

    Args:
        net: span model exposing `num_classes` and, after each forward,
            `span_embeddings`.
        val_dataloader: loader of `collate_fn` batches. An optional 'noise'
            entry is forwarded to the model.
        loss_fn: per-example loss (reduction='none').
        optimizer: only used to clear parameter gradients before evaluating.
        device: device the batch tensors are moved to.

    Returns:
        (loss, accuracy, precision_macro, precision_micro, recall_macro,
        recall_micro, f1_macro, f1_micro, val_grads):
        - loss: score-weighted mean loss.
        - metrics: score-weighted sklearn scores over labels 0..num_classes-1.
        - val_grads: one row per example, the gradient of its batch's weighted
          loss w.r.t. its span embedding.

    The network's previous train/eval mode is restored on exit.
    """
    was_training = net.training
    net.eval()
    optimizer.zero_grad()
    val_grads = []
    loss = 0
    count = 0
    trues = []
    preds = []
    scores = []
    try:
        for batch in tqdm(val_dataloader, leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            begin = batch['begin'].to(device)
            end = batch['end'].to(device)
            y = batch['y'].to(device)
            score = batch['score'].to(device)
            noise = batch['noise'] if 'noise' in batch else None

            out = net(input_ids, attention_mask, begin, end, retain=False, noise=noise)
            loss_cur = torch.dot(loss_fn(out, y), score.to(out.dtype)) / score.sum()
            loss += loss_cur.item() * score.sum().item()
            trues += y.cpu().tolist()
            preds += torch.argmax(out, dim=-1).cpu().tolist()
            scores += score.cpu().tolist()
            count += score.sum().item()
            # Differentiate w.r.t. the span embeddings only. Parameter .grad
            # buffers are never filled, so evaluation leaves no stale gradients.
            (span_grad,) = torch.autograd.grad(loss_cur, net.span_embeddings)
            val_grads += span_grad.detach().cpu().tolist()
    finally:
        # train() calls net.train() only once per epoch, so leaving the net in
        # eval mode would disable dropout for the rest of the epoch.
        net.train(was_training)
    accuracy = accuracy_score(trues, preds, sample_weight=scores)
    precision_macro = precision_score(trues, preds, labels=range(net.num_classes), sample_weight=scores, average='macro', zero_division=0)
    precision_micro = precision_score(trues, preds, labels=range(net.num_classes), sample_weight=scores, average='micro', zero_division=0)
    recall_macro = recall_score(trues, preds, labels=range(net.num_classes), sample_weight=scores, average='macro', zero_division=0)
    recall_micro = recall_score(trues, preds, labels=range(net.num_classes), sample_weight=scores, average='micro', zero_division=0)
    f1_macro = f1_score(trues, preds, labels=range(net.num_classes), sample_weight=scores, average='macro', zero_division=0)
    f1_micro = f1_score(trues, preds, labels=range(net.num_classes), sample_weight=scores, average='micro', zero_division=0)

    return loss / count, accuracy, precision_macro, precision_micro, recall_macro, recall_micro, f1_macro, f1_micro, val_grads

In [None]:
import gc
def train(config, net, optimizer, scheduler, loss_fn, train_dataloader, val_dataloader, other_loaders, device, log_iterations, max_grad_norm=1.0):
    """Train `net` with periodic evaluation, W&B logging and checkpointing.

    Every `log_iterations` steps (starting at step 0) it:
    - logs gradient/weight statistics and the training loss;
    - evaluates on `val_dataloader` and on each loader in `other_loaders`;
    - saves a checkpoint when the mean of validation macro- and micro-F1 improves.

    Args:
        config: dict with 'name', 'epoch_num', 'checkpoint_dir' and
            'checkpoint_name'; the whole dict is sent to W&B as run config.
        net, optimizer, scheduler: model and its optimisation state. The
            scheduler is stepped once per batch.
        loss_fn: per-example loss (reduction='none'), weighted by batch 'score'.
        train_dataloader, val_dataloader: training / validation loaders.
        other_loaders: dict name -> extra evaluation loader.
        device: device batches are moved to.
        log_iterations: evaluation / logging period in optimizer steps.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        (val_grads, other_grads): per evaluation round, the span-embedding
        gradients from `evaluate` for the validation loader and for each
        extra loader.
    """
    wandb.init(
        project="universal",
        name=config['name'],
        config=config
    )

    global_step = 0
    net = net.to(device)

    best_f1 = 0
    os.makedirs(config['checkpoint_dir'], exist_ok=True)
    checkpoint_path = os.path.join(config['checkpoint_dir'], config['checkpoint_name'])

    val_grads = []
    other_grads = {name: [] for name in other_loaders}
    try:
        for epoch in tqdm(range(config['epoch_num'])):
            net.train()
            for batch in tqdm(train_dataloader, leave=False):
                optimizer.zero_grad()

                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                begin = batch['begin'].to(device)
                end = batch['end'].to(device)
                y = batch['y'].to(device)
                scores = batch['score'].to(device)

                out = net(input_ids, attention_mask, begin, end, retain=False, noise=None)
                # Score-weighted mean of the per-example losses.
                train_loss = torch.dot(loss_fn(out, y), scores.to(out.dtype)) / scores.sum()
                train_loss.backward()

                torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
                optimizer.step()
                scheduler.step()

                if global_step % log_iterations == 0:
                    torch.cuda.empty_cache()
                    gc.collect()
                    log_gradients(net, global_step)
                    log_weights(net, global_step)
                    wandb.log({"train/loss": train_loss.item()}, step=global_step)

                    (eval_loss, accuracy, precision_macro, precision_micro, recall_macro,
                     recall_micro, f1_macro, f1_micro, grads) = evaluate(
                        net=net,
                        val_dataloader=val_dataloader,
                        loss_fn=loss_fn,
                        optimizer=optimizer,
                        device=device
                    )
                    val_grads.append(grads)
                    metrics = {
                        "eval/loss": eval_loss, 'eval/accuracy': accuracy,
                        'eval/precision_macro': precision_macro, 'eval/precision_micro': precision_micro,
                        'eval/recall_macro': recall_macro, 'eval/recall_micro': recall_micro,
                        'eval/f1_macro': f1_macro, 'eval/f1_micro': f1_micro,
                    }
                    for name, loader in other_loaders.items():
                        other_res = evaluate(
                            net=net,
                            val_dataloader=loader,
                            loss_fn=loss_fn,
                            optimizer=optimizer,
                            device=device
                        )
                        other_grads[name].append(other_res[-1])
                        # Log the extra loaders' headline metrics next to the validation ones.
                        metrics[f'{name}/loss'] = other_res[0]
                        metrics[f'{name}/accuracy'] = other_res[1]
                        metrics[f'{name}/f1_macro'] = other_res[6]
                        metrics[f'{name}/f1_micro'] = other_res[7]
                    wandb.log(metrics, step=global_step)

                    f1 = (f1_macro + f1_micro) / 2
                    if f1 > best_f1:
                        best_f1 = f1
                        torch.save({
                            'epoch': epoch,
                            'iter': global_step,
                            'model_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'f1': f1,
                        }, checkpoint_path)

                    # Evaluation may leave the net in eval mode; go back to
                    # training mode so dropout stays active for the rest of the epoch.
                    net.train()

                global_step += 1
    finally:
        # Close the W&B run even if training raises.
        wandb.finish()
    return val_grads, other_grads

In [None]:
aggr = 'mean'
config = {
    'seed'           : 42,
    'lr'             : 1e-4,
    'epoch_num'      : 2,
    'batch_size'     : 16,
    'val_batch_size' : 16,
    'name'           : f'exp {PRETR} {aggr}',
    'checkpoint_dir' : './checkpoints',
    'checkpoint_name': 'MLP.pth',
    'hidden_size'    : 768,   # must equal the encoder's hidden size
    'num_classes'    : len(tags)
}

g = set_global_seed(config['seed'])
# NOTE(review): re-seeds the shared DataLoader generator to 0, so shuffling
# does not follow config['seed']. Kept as-is to reproduce earlier runs.
g.manual_seed(0)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

train_dataloader = DataLoader(
    train_dataset,
    batch_size     = config['batch_size'],
    shuffle        = True,
    drop_last      = True,
    num_workers    = 3,
    worker_init_fn = seed_worker,
    generator      = g,
    collate_fn     = collate_fn
)


def make_eval_loader(eval_set):
    """Sequential, non-dropping, single-process loader used for every evaluation set."""
    return DataLoader(
        eval_set,
        batch_size     = config['val_batch_size'],
        shuffle        = False,
        drop_last      = False,
        num_workers    = 0,
        worker_init_fn = seed_worker,
        generator      = g,
        collate_fn     = collate_fn
    )


val_dataloader = make_eval_loader(eval_dataset)
other_loaders = {
    'noise': make_eval_loader(noise_dataset),
    'drug_dosage': make_eval_loader(dd_dataset),
    'drug_effect': make_eval_loader(de_dataset),
}

embedder = AutoModel.from_pretrained(PRETR)
net = UniversalModel(embedder, tokenizer, config['hidden_size'], config['num_classes'], aggr, device)
# Move the whole model, head included, to the selected device; a hard-coded
# 'cuda' fails on CPU-only machines.
net.to(device)
optimizer = optim.AdamW(net.parameters(), lr=config['lr'])
total_steps = config['epoch_num'] * len(train_dataloader)
scheduler = WarmupScheduler(
    optimizer,
    warmup_steps=int(0.1 * total_steps),
    total_steps=total_steps,
    mode='cosine'
)
loss_fn = nn.CrossEntropyLoss(reduction='none')

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [None]:
# Train with evaluation every 300 steps. Returns the span-embedding gradients
# collected at each evaluation round (validation + each extra loader).
val_grads, other_grads = train(config, net, optimizer, scheduler, loss_fn, train_dataloader, val_dataloader,
                               other_loaders, device, log_iterations=300)

In [None]:
import pickle

# Keep the validation gradients next to the other splits so a single pickle
# holds everything. NOTE: this mutates `other_grads` in place.
other_grads['val'] = val_grads
# Use a context manager so the file is flushed and closed deterministically.
with open('/kaggle/working/grads.pkl', 'wb') as grads_file:
    pickle.dump(other_grads, grads_file)

In [None]:
import kagglehub

# Publish the checkpoint written by train() as a new version of the Kaggle model.
# The version notes record which encoder and span aggregation were used.
MODEL_HANDLE = 'taband58/kaggle_universal/pyTorch/2b'
CHECKPOINT_PATH = '/kaggle/working/checkpoints/MLP.pth'
kagglehub.model_upload(
    MODEL_HANDLE,
    CHECKPOINT_PATH,
    'Apache 2.0',
    version_notes=f"{PRETR} {aggr}",
)

Uploading Model https://www.kaggle.com/models/taband58/kaggle_universal/pyTorch/2b ...
Starting upload for file /kaggle/working/checkpoints/MLP.pth


Uploading: 100%|██████████| 1.32G/1.32G [01:01<00:00, 21.4MB/s]  

Upload successful: /kaggle/working/checkpoints/MLP.pth (1GB)





Your model instance version has been created.
Files are being processed...
See at: https://www.kaggle.com/models/taband58/kaggle_universal/pyTorch/2b


In [None]:
# Restore model and optimizer state from a stored checkpoint (presumably version 7
# of the Kaggle model, per the path). Tensors are mapped onto the same `device`
# the model uses rather than a hard-coded 'cuda'.
state_dict = torch.load("/kaggle/input/kaggle_universal/pytorch/2b/7/MLP.pth", map_location=device)
net.load_state_dict(state_dict['model_state_dict'])
optimizer.load_state_dict(state_dict['optimizer_state_dict'])

In [None]:
# Evaluate on the validation split, using `device` (as train() does) instead of a hard-coded 'cuda'.
res = evaluate(net, val_dataloader, loss_fn, optimizer, device)

  0%|          | 0/413 [00:00<?, ?it/s]

In [None]:
# Display everything evaluate() returned except its last element
# (presumably a bulky per-example payload — TODO confirm).
eval_metrics = res[:-1]
eval_metrics

(0.2889057270376002,
 0.9107746656566315,
 0.7137675571560931,
 0.9107746656566315,
 0.6178072214371854,
 0.9107746656566315,
 0.6312116074921346,
 0.9107746656566315)

In [None]:
def process_set(net0, loader):
    """Run `net0` over every batch of `loader` and collect per-example outputs.

    Parameters
    ----------
    net0 : model exposing `.device`. It is called as
        ``net0(input_ids, attention_mask, begin, end, retain=False, noise=...)``
        and leaves the span representations of the last call in
        ``net0.span_embeddings``.
    loader : iterable of batch dicts with keys 'input_ids', 'attention_mask',
        'begin', 'end', 'y', 'score' and, optionally, 'noise'.

    Returns
    -------
    tuple ``(embeddings, answers, y, ids, margin, score, begin, end)`` of numpy
    arrays, each concatenated over batches along axis 0:
        embeddings : ``net0.span_embeddings`` for every example.
        answers    : model outputs, one row per example, one column per class.
        y          : gold class indices.
        ids        : integer positions of examples whose gold tag is not 'O'
                     (an empty integer array if there are none).
        margin     : gold-class output minus the largest output of any other
                     class. Negative means some other class scored higher.
        score, begin, end : copied from the batches.

    Notes
    -----
    Puts `net0` into eval mode and does not switch it back. Reads the global
    `tags` to map class indices to tag names.
    """
    device = net0.device
    embeddings, answers, y, score, begin, end = [], [], [], [], [], []
    net0.eval()
    with torch.no_grad():
        for batch in tqdm(loader):
            noise = batch['noise'].to(device) if 'noise' in batch else None
            outputs = net0(batch['input_ids'].to(device), batch['attention_mask'].to(device),
                           batch['begin'].to(device), batch['end'].to(device),
                           retain=False, noise=noise)
            answers.append(outputs.cpu().numpy())
            # Read right after the forward pass: the model overwrites it on every call.
            embeddings.append(net0.span_embeddings.cpu().numpy())
            y.append(batch['y'].cpu().numpy())
            score.append(batch['score'].cpu().numpy())
            begin.append(batch['begin'].cpu().numpy())
            end.append(batch['end'].cpu().numpy())
    embeddings = np.concatenate(embeddings, axis=0)
    answers = np.concatenate(answers, axis=0)
    y = np.concatenate(y, axis=0)
    score = np.concatenate(score, axis=0)
    begin = np.concatenate(begin, axis=0)
    end = np.concatenate(end, axis=0)

    rows = np.arange(len(answers))
    # Hide the gold class so that argmax picks the strongest competing class.
    others = answers.copy()
    others[rows, y] = -1e20
    margin = answers[rows, y] - answers[rows, others.argmax(axis=1)]
    # flatnonzero always returns an integer index array, even when it is empty.
    ids = np.flatnonzero([tags[label] != 'O' for label in y])
    return embeddings, answers, y, ids, margin, score, begin, end

def calc_rocauc(diffs):
    """ROC-AUC for separating the 'val' scores from each other split's scores.

    Parameters
    ----------
    diffs : dict mapping split name -> array of per-example scores. It must
        contain 'val' whenever it holds any other key.

    Returns
    -------
    dict mapping every split name other than 'val' to the ROC-AUC of a
    classifier that labels 'val' examples 0 and that split's examples 1, using
    the score as the decision value. The dict is empty when `diffs` has no
    split besides 'val'.
    """
    def _auc_against_val(other_scores):
        reference = diffs['val']
        labels = np.concatenate([np.zeros_like(reference), np.ones_like(other_scores)])
        return roc_auc_score(labels, np.concatenate([reference, other_scores]))

    return {name: _auc_against_val(values)
            for name, values in diffs.items()
            if name != 'val'}

def calc_mah(cov_matrix, means, embeddings):
    """Minimum squared Mahalanobis distance from each embedding to the class means.

    Parameters
    ----------
    cov_matrix : (D, D) covariance matrix shared by all classes; must be invertible.
    means : sequence of C class-mean vectors, each of length D.
    embeddings : (N, D) array of points to score.

    Returns
    -------
    (N,) float array whose n-th entry is
    ``min_c (x_n - mu_c)^T cov_matrix^{-1} (x_n - mu_c)`` (squared distance; no sqrt).

    Raises
    ------
    ValueError
        If `means` is empty.
    numpy.linalg.LinAlgError
        If `cov_matrix` is singular.
    """
    if len(means) == 0:
        raise ValueError("calc_mah needs at least one class mean")
    embeddings = np.asarray(embeddings)
    # Invert once, outside the per-class loop.
    precision = np.linalg.inv(cov_matrix)
    scores = np.zeros((len(embeddings), len(means)))
    # Iterate over the means actually supplied rather than a global label list.
    for i, mean in enumerate(means):
        centered = embeddings - np.asarray(mean)[None, :]
        # Row-wise quadratic form, i.e. diag(centered @ P @ centered.T), without
        # materialising the (N, N) product.
        scores[:, i] = ((centered @ precision) * centered).sum(axis=1)
    return scores.min(axis=1)

In [None]:
# Iterate the training set in a fixed order (no shuffling) and collect the span
# embeddings and gold labels used below to fit the means/covariance for calc_mah.
train_loader = DataLoader(
    train_dataset,
    batch_size=config['val_batch_size'],
    shuffle=False,
    drop_last=False,
    num_workers=0,
    worker_init_fn=seed_worker,
    generator=g,
    collate_fn=collate_fn,
)

res = process_set(net, train_loader)
embeddings, y = res[0], res[2]

  0%|          | 0/987 [00:00<?, ?it/s]

In [None]:
# Covariance and per-class means of the training span embeddings, used by calc_mah.
# NOTE(review): np.cov centres on the global mean, so this is the total covariance,
# not the class-centred (tied) covariance of the usual Mahalanobis OOD detector.
# If the latter is intended, subtract each row's class mean before np.cov.
# NOTE(review): a class index with no training rows gives a NaN mean (with a warning).
cov_matrix = np.cov(embeddings, rowvar=False)
means = [embeddings[y == cls].mean(axis=0) for cls in range(len(tags))]

In [None]:
# Run the model over the validation split and every additional split.
# NOTE: registers the validation loader in `other_loaders` as a side effect.
other_loaders['val'] = val_dataloader
processed = {
    split: process_set(net, other_loaders[split])
    for split in ('val', 'noise', 'drug_dosage', 'drug_effect')
}

  0%|          | 0/413 [00:00<?, ?it/s]

  0%|          | 0/288 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/686 [00:00<?, ?it/s]

In [None]:
# Validation-only variant of processing the splits.
# NOTE(review): this rebinds `processed` to a dict holding only 'val'. Results
# for 'noise', 'drug_dosage' and 'drug_effect' computed earlier are discarded,
# so calc_rocauc later has nothing to compare against and returns {}.
other_loaders['val'] = val_dataloader
processed = {'val': process_set(net, other_loaders['val'])}

  0%|          | 0/413 [00:00<?, ?it/s]

In [None]:
n_eval = len(eval_dataset)
# Sentence of every validation example, in dataset order. val_dataloader does
# not shuffle, so positions line up with processed['val'].
texts = [eval_dataset[idx]['text'] for idx in range(n_eval)]
# Decoded text of each example's token span input_ids[begin:end].
spans = []
for i in range(n_eval):
    el = eval_dataset[i]
    begin, end = el['begin'], el['end']
    spans.append(tokenizer.decode(el['input_ids'][begin:end]))

In [None]:
import pickle

# Load gradient data saved by an earlier run for all four splits.
# NOTE: pickle.load can execute arbitrary code; only load files you produced.
with open('/kaggle/input/other-grads-s-endpoint/grads.pkl', 'rb') as grads_file:
    grads = pickle.load(grads_file)

# Variance-of-gradients score per example: variance over axis 0 (skipping its
# first entry), then mean over the next axis. Axis 0 is presumably the
# checkpoint/snapshot axis — TODO confirm against train().
diffs_vog = {}
for key in ['val', 'noise', 'drug_dosage', 'drug_effect']:
    diffs_vog[key] = np.array(grads[key])[1:].var(axis=0).mean(axis=1)
# diffs_vog['val'] = diffs_vog['val'][processed['val'][3]]

In [None]:
import pickle

# Alternative gradient dump, validation split only. This rebinds `grads` and
# `diffs_vog`, replacing anything loaded earlier.
# NOTE: pickle.load can execute arbitrary code; only load files you produced.
with open('/kaggle/input/other-grads-b-mean/grads.pkl', 'rb') as grads_file:
    grads = pickle.load(grads_file)
diffs_vog = {}
for key in ['val']:
    diffs_vog[key] = np.array(grads[key])[1:].var(axis=0).mean(axis=1)

In [None]:
# Per-example scores for every processed split, oriented so that larger means
# less confident or less typical:
#   margin_diffs     : negated classification margin
#   abs_margin_diffs : negated |margin| (closeness to the decision boundary)
#   mh_diffs         : minimum squared Mahalanobis distance to a class mean
margin_diffs = {}
abs_margin_diffs = {}
mh_diffs = {}
for split, outputs in processed.items():
    margin = outputs[4]
    margin_diffs[split] = -margin
    abs_margin_diffs[split] = -np.abs(margin)
    mh_diffs[split] = calc_mah(cov_matrix, means, outputs[0])

In [None]:
# Ten non-'O' validation spans with the highest VoG score, printed in ascending
# order. Each entry shows score, span text, gold tag and predicted tag, then the sentence.
ids = processed['val'][3]
diffs = diffs_vog['val']
val_gold = processed['val'][2]
val_outputs = processed['val'][1]
for pos in diffs[ids].argsort()[-10:]:
    ex = ids[pos]
    print(diffs[ex], spans[ex], tags[val_gold[ex]], tags[val_outputs[ex].argmax()])
    print(texts[ex])

4.169192010762451e-07 home office org O
The Home Office organized seven working groups to prepare the report on the root causes of the July 7 suicide attacks by British Muslims that killed 52 people in London .
4.2320633356716694e-07 cabinet org org
He says any measures must have Cabinet approval .
4.304048317013869e-07 helmund province geo geo
They say the policemen disappeared in Helmund Province .
4.3479030695632983e-07 helmand geo geo
An exchange of gunfire between British troops and Afghan police in southern Helmand province Thursday left an Afghan policeman dead .
4.3940648382627613e-07 tomb geo O
Mr. Bush will make remarks and lay a wreath at the Tomb of the Unknowns which contains the remains of unidentified U.S. service members who died in World Wars I and II - and in the Korean War .
4.401264387790166e-07 helmand geo geo
The British Defense Ministry says a British soldier serving with the NATO force in Afghanistan has been shot dead in southern Helmand province .
4.4484732388

In [None]:
# Ten non-'O' validation spans (positions in `ids`) with the highest negated
# margin, in ascending order. Each entry shows score, span, gold tag and
# predicted tag, then the sentence.
diffs = margin_diffs['val']
ranked = diffs[ids].argsort()[-10:]
for i in ranked:
    row = ids[i]
    gold_tag = tags[processed['val'][2][row]]
    pred_tag = tags[processed['val'][1][row].argmax()]
    print(diffs[row], spans[row], gold_tag, pred_tag)
    print(texts[row])

7.6897073 ebola nat geo
In the past , the Democratic Republic of Congo has endured outbreaks of both Marburg and Ebola , two types of hemorrhagic fever caused by viruses that can attack the central nervous system and cause bleeding from the eyes , ears , and other parts of the body .
7.8210087 arab geo gpe
Pakistan 's army has conducted a series of counter-terrorism operations in North and South Waziristan in the past three years , aimed at trapping Arab , Afghan and Central Asian militants with links to the Taliban and al-Qaida .
7.902258 azerbaijani geo gpe
The current global economic slowdown presents some challenges for the Azerbaijani economy as oil prices remain below their mid-2008 highs , highlighting Azerbaijan 's reliance on energy exports and lackluster attempts to diversify its economy .
7.9583364 british org gpe
Originally settled by Polynesian emigrants from surrounding island groups , the Tokelau Islands were made a British protectorate in 1889 .
7.9825544 mw org O
Nepal

In [None]:
# Ten non-'O' validation spans (positions in `ids`) closest to the decision
# boundary, i.e. highest -|margin|, in ascending order. Each entry shows score,
# span, gold tag and predicted tag, then the sentence.
diffs = abs_margin_diffs['val']
for i in diffs[ids].argsort()[-10:]:
    idx = ids[i]
    gold_and_pred = (tags[processed['val'][2][idx]], tags[processed['val'][1][idx].argmax()])
    print(diffs[idx], spans[idx], *gold_and_pred)
    print(texts[idx])

-0.023964643 bild am sonntag org org
Ms. Merkel told the Bild am Sonntag newspaper the government is doing all it can to rescue Susanne Osthoff and her Iraqi driver , who disappeared nine days ago , on November 25 .
-0.017394066 general bozize org per
Though the government has the tacit support of civil society groups and the main parties , a wide field of candidates contested the municipal , legislative , and presidential elections held in March and May of 2005 in which General BOZIZE was affirmed as president .
-0.016691208 uganda ' s health ministry org org
Uganda 's Health Ministry says the country has confirmed its first case of H1N1 swine flu .
-0.016061783 " caliph of cologne org O
Metin Kaplan , also known as the " Caliph of Cologne , " faces charges of trying to overthrow Turkey 's constitutional order .
-0.015525579 u. n. org org
He said in a statement that he is deeply concerned by continued threats against U.N. personnel and by reports that more violent protests and attacks

In [None]:
# Ten non-'O' validation spans (positions in `ids`) with the largest minimum
# Mahalanobis distance, in ascending order. Each entry shows score, span, gold
# tag and predicted tag, then the sentence.
diffs = mh_diffs['val']
val_split = processed['val']
for i in diffs[ids].argsort()[-10:]:
    k = ids[i]
    print(diffs[k], spans[k], tags[val_split[2][k]], tags[val_split[1][k].argmax()])
    print(texts[k])

9366.179746382695 from tim tim
A national shopping research group , ShopperTrak RCT corp. , reported Friday 's total sales at $ 8 billion , down about 0.9 percent from last year .
9659.56802364299 people gpe org
The modern country of Mongolia , however , represents only part of the Mongols ' historical homeland ; more ethnic Mongolians live in the Inner Mongolia Autonomous Region in the People 's Republic of China than in Mongolia .
10003.488265525782 tomb geo O
Mr. Bush will make remarks and lay a wreath at the Tomb of the Unknowns which contains the remains of unidentified U.S. service members who died in World Wars I and II - and in the Korean War .
10027.933075800538 corruption org org
The report also calls on African governments to commit to transparency and to ratify the U.N. Covenant on Corruption .
10596.28947828256 greek art gpe
The busy season meant forecasters exhausted their list of names , forcing them to use the Greek alphabet to name storms for the first time .
11485.539

In [None]:
# ROC-AUC of val-vs-other-split separation for each scoring method
# (VoG, margin, |margin|, Mahalanobis), one line per method.
for method_scores in (diffs_vog, margin_diffs, abs_margin_diffs, mh_diffs):
    print(calc_rocauc(method_scores))

{}
{}
{}
{}


In [None]:
# Spearman rank correlation between every pair of validation scores, in the
# order (VoG, margin), (VoG, |margin|), (VoG, Mah), (margin, |margin|),
# (margin, Mah), (|margin|, Mah).
val_score_sets = [diffs_vog['val'], margin_diffs['val'], abs_margin_diffs['val'], mh_diffs['val']]
for first in range(len(val_score_sets)):
    for second in range(first + 1, len(val_score_sets)):
        print(spearmanr(val_score_sets[first], val_score_sets[second]))

SignificanceResult(statistic=0.6758830642461333, pvalue=0.0)
SignificanceResult(statistic=0.6441229857305452, pvalue=0.0)
SignificanceResult(statistic=0.39805498436710224, pvalue=8.599102796705937e-250)
SignificanceResult(statistic=0.9654502156000898, pvalue=0.0)
SignificanceResult(statistic=0.2031913925699588, pvalue=1.5864339015325852e-62)
SignificanceResult(statistic=0.19619873025088622, pvalue=2.4443886817512198e-58)


In [None]:
# Spearman correlation of each validation score with 1 - the per-example
# `score` field returned by process_set.
one_minus_score = 1 - processed['val'][5]
for method_scores in (diffs_vog, margin_diffs, abs_margin_diffs, mh_diffs):
    print(spearmanr(method_scores['val'], one_minus_score))

SignificanceResult(statistic=0.011747357534651006, pvalue=0.33968372929725743)
SignificanceResult(statistic=0.01560611012005201, pvalue=0.20463666021217125)
SignificanceResult(statistic=0.014843153353402128, pvalue=0.2276508580262295)
SignificanceResult(statistic=0.006459881217910198, pvalue=0.5995641909510832)


In [None]:
# Split validation examples by their `score` field at SCORE_THRESHOLD and
# compare each method's score distribution between the two groups with a
# two-sample Kolmogorov-Smirnov test (score >= threshold first, then < threshold).
SCORE_THRESHOLD = 0.7
mask = processed['val'][5] < SCORE_THRESHOLD
for method_scores in (diffs_vog, margin_diffs, abs_margin_diffs, mh_diffs):
    vals = method_scores['val']
    print(ks_2samp(vals[~mask], vals[mask]))

KstestResult(statistic=0.4585984045238817, pvalue=0.11294687977040595, statistic_location=1.8456918788426468e-08, statistic_sign=1)
KstestResult(statistic=0.4502171059274967, pvalue=0.1253603096277013, statistic_location=-3.1282942, statistic_sign=1)
KstestResult(statistic=0.47702716348581237, pvalue=0.08925335592223553, statistic_location=-3.1282942, statistic_sign=1)
KstestResult(statistic=0.2952135716449561, pvalue=0.5761146385398366, statistic_location=739.8628451785771, statistic_sign=1)


In [None]:
# Mean of each method's validation score in the unmasked group, then the masked group.
for method_scores in (diffs_vog, margin_diffs, abs_margin_diffs, mh_diffs):
    vals = method_scores['val']
    print(vals[~mask].mean(), vals[mask].mean())

1.2005555081930867e-07 7.093571001935776e-08
-5.194937 -3.1974075
-5.6114206 -4.1497297
1944.2404437000073 1763.2661517050146


In [None]:
# ROC-AUC of each method's validation score for telling masked examples
# (label 1) from unmasked ones (label 0).
# Labels follow the concatenation order below: unmasked first, then masked.
group_labels = np.concatenate([np.zeros(len(processed['val'][5]) - mask.sum()), np.ones(mask.sum())])
for method_scores in (diffs_vog, margin_diffs, abs_margin_diffs, mh_diffs):
    vals = method_scores['val']
    print(roc_auc_score(group_labels, np.concatenate([vals[~mask], vals[mask]])))

0.6126426335453903
0.6495506412198323
0.642254872260931
0.5619256790871453
