In [6]:
from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import numpy as np
import json
import io
import wandb

In [7]:
# Default maximum number of tokens kept around a span.
INPUT_WINDOW = 900


def get_start_end_pos(start, end, tokens_gaps, window=None):
    """Find the tokens covering a character span and a token window around them.

    Args:
        start: Character index where the span begins (inclusive).
        end: Character index where the span stops. Treated as exclusive,
            matching how the notebook slices ``text[begin:end]``.
        tokens_gaps: Per-token ``(char_start, char_end)`` pairs, e.g. one row
            of a tokenizer offset mapping. A batched tensor with a leading
            dimension (as produced with ``return_tensors='pt'``) is reduced
            to its first row.
        window: Maximum number of tokens to keep; ``None`` means
            ``INPUT_WINDOW``.

    Returns:
        ``(offset_left, offset_right, first, last)`` where
        ``[offset_left:offset_right]`` is the token slice to keep and
        ``first`` / ``last`` are the indices of the first and last span
        tokens relative to ``offset_left``.

    Raises:
        ValueError: If no token overlaps the requested span.
    """
    if window is None:
        window = INPUT_WINDOW

    # The only caller in this notebook passes a (1, seq, 2) tensor.
    if hasattr(tokens_gaps, 'dim') and tokens_gaps.dim() == 3:
        tokens_gaps = tokens_gaps[0]

    # Half-open interval overlap: a token that merely touches the span
    # boundary is not part of it. Special tokens with a (0, 0) offset can
    # never satisfy `tok_end > start`, so no trailing token has to be dropped.
    tokens_match = [
        i for i, (tok_start, tok_end) in enumerate(tokens_gaps)
        if tok_end > start and tok_start < end
    ]
    if not tokens_match:
        raise ValueError(f'no token overlaps the character span [{start}, {end})')

    first, last = tokens_match[0], tokens_match[-1]
    n_tokens = len(tokens_gaps)

    if n_tokens <= window:
        return 0, n_tokens, first, last

    # Split the spare room evenly on both sides; a span wider than the window
    # gets no margin instead of a negative one.
    margin = max(0, (window - (last - first)) // 2)
    offset_left = max(0, first - margin)
    # The right bound is exclusive, so it must always lie past `last`.
    offset_right = min(n_tokens, max(last + 1, last + margin))

    return offset_left, offset_right, first - offset_left, last - offset_left

In [31]:
class CultCode(Dataset):
    """Positive span examples read from the cult_code markup file.

    Relations that reference exactly one span are the training items
    (``len(ds)`` / ``ds[i]``); relations that reference several spans are
    kept apart and served through ``get_len_eval`` / ``get_item_eval``.
    Every item is raw text plus character offsets, which is the format
    ``MultiSpanValidator.forward`` and ``evaluate`` consume.

    Args:
        tokenizer: Unused. Accepted (and optional) so that both
            ``CultCode(tokenizer)`` and ``CultCode()`` keep working; the
            model tokenizes the raw text itself.
        train_on_single_tags: When true, relations carrying more than one
            tag are skipped entirely.
        path: Location of the JSON markup file.
    """

    def __init__(self, tokenizer=None, train_on_single_tags=False,
                 path='./datasets/related/cult_code.json'):
        # Close the file deterministically and do not depend on the locale
        # encoding.
        with io.open(path, encoding='utf-8') as fh:
            cult_code = json.load(fh)
        markups = cult_code['dataset']['markups']

        self.texts = []       # one entry per markup, indexed by markup position
        self.uno_span = []    # (markup index, [begin, end]) for single-span relations
        self.uno_labels = []  # first tag of each single-span relation, parallel to uno_span
        self.multi_span = []  # (markup index, [[begin, end], ...]) for multi-span relations
        for i, mrkp in enumerate(markups):
            # The accessors below look texts up by markup index.
            self.texts.append(mrkp['text'])

            for rel in mrkp['relations']:
                if train_on_single_tags and len(rel['tags']) > 1:
                    continue

                bounds = [
                    [mrkp['spans'][span]['begin'], mrkp['spans'][span]['end']]
                    for span in rel['spans']
                ]
                if len(bounds) == 1:
                    self.uno_span.append((i, bounds[0]))
                    self.uno_labels.append(rel['tags'][0] if rel['tags'] else None)
                elif len(bounds) > 1:
                    self.multi_span.append((i, bounds))

    def get_len_eval(self):
        """Return the number of multi-span relations."""
        return len(self.multi_span)

    def get_item_eval(self, index):
        """Return the text and the list of ``[begin, end]`` spans of one multi-span relation."""
        i, span = self.multi_span[index]
        return {
            'text': self.texts[i],
            'span': span
        }

    def __len__(self):
        return len(self.uno_span)

    def __getitem__(self, index):
        """Return the text and the ``[begin, end]`` span of one single-span relation."""
        i, span = self.uno_span[index]
        return {
            'text': self.texts[i],
            'span': span
        }

In [6]:
class NegativeSamples(Dataset):
    """Random character spans drawn from the markup texts.

    ``__getitem__`` ignores its index: every access draws a fresh random
    text and a fresh random span with ``torch.randint``, so results follow
    the torch RNG state.

    Args:
        train_on_single_tags: Unused; kept for signature compatibility.
        path: Location of the JSON markup file.
    """

    # Characters kept free at the end of a text when drawing the span start.
    TAIL_RESERVE = 60

    def __init__(self, train_on_single_tags=False,
                 path='./datasets/related/cult_code.json'):
        with io.open(path, encoding='utf-8') as fh:
            cult_code = json.load(fh)
        markups = cult_code['dataset']['markups']

        self.texts = [mrkp['text'] for mrkp in markups]

    def __len__(self):
        return len(self.texts) * 100

    def __getitem__(self, index):
        if not self.texts:
            raise IndexError('NegativeSamples has no texts to sample from')

        txt_ind = torch.randint(low=0, high=len(self.texts), size=(1,)).item()
        text_len = len(self.texts[txt_ind])

        # randint needs high > low; texts no longer than the reserve start at 0.
        start_high = max(1, text_len - self.TAIL_RESERVE)
        begin = torch.randint(low=0, high=start_high, size=(1,)).item()

        # NOTE(review): the upper bound is the number of texts, not a length
        # in characters. That looks unintended, but the distribution is left
        # as it was; confirm the intended span length.
        delta = torch.randint(low=0, high=len(self.texts), size=(1,)).item()

        # Never let the span run past the end of the text.
        end = min(begin + delta, text_len)

        return {
            'text': self.texts[txt_ind],
            'span': [begin, end]
        }

In [8]:
class MultiSpanValidator(torch.nn.Module):
    """Binary classifier over mean-pooled encoder states of a character span.

    Args:
        out_dim: Width of the vectors fed to the classifier head; it has to
            equal the encoder's hidden size because the head is applied to
            pooled ``last_hidden_state`` rows.
        encoder: Name passed to ``AutoTokenizer`` / ``AutoModel``.
        layers_to_freeze: Number of leading encoder layers whose parameters
            are excluded from training.
        device: Device the tokenized batch is moved to in ``forward``; keep
            it in sync with where the module itself lives.
    """

    def __init__(self, out_dim=768, encoder='microsoft/deberta-v3-base',
                layers_to_freeze=5, device='cpu'):
        super().__init__()

        self.out_dim = out_dim
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(encoder)
        self.model = AutoModel.from_pretrained(encoder)

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(out_dim, out_dim // 3),
            torch.nn.ReLU(),
            torch.nn.Linear(out_dim // 3, 2)
        )

        for layer in self.model.encoder.layer[:layers_to_freeze]:
            for param in layer.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Score one collated batch.

        Args:
            x: Mapping with ``'text'`` (list of strings) and ``'span'`` (two
                tensors holding the batch's begin and end offsets). ``end``
                is treated as exclusive.

        Returns:
            Logits of shape ``(kept, 2)``, one row per sample whose span
            could be mapped to tokens, or ``None`` when no sample could.
        """
        text = x['text']
        # Plain ints: the values are used as character indices below.
        span = torch.stack(x['span']).T.tolist()

        batch = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
        tokens_spans = []

        for i, (start, end) in enumerate(span):
            # Clamp to the text and step back onto the span's last character
            # (the previous `max` stretched every span to the end of the text).
            last_char = max(start, min(end, len(text[i])) - 1)
            token_start = batch.char_to_token(batch_or_char_index=i, char_index=start)
            token_end = batch.char_to_token(batch_or_char_index=i, char_index=last_char)
            # Characters that belong to no token cannot be pooled.
            if token_start is None or token_end is None or token_end < token_start:
                continue

            tokens_spans.append((i, token_start, token_end))

        if not len(tokens_spans):
            return None

        tens = self.model(**batch).last_hidden_state
        pooled = [
            torch.mean(tens[i, token_start:token_end + 1], dim=0)
            for i, token_start, token_end in tokens_spans
        ]

        return self.classifier(torch.stack(pooled))

In [9]:
# Instantiate the validator with its defaults (DeBERTa-v3 base encoder, CPU)
# so the tokenizer can be inspected in the cells below.
model = MultiSpanValidator()



In [20]:
# Sample English passage used by the following cells to inspect tokenizer output.
text = '''Until recently, the prevailing view assumed lorem ipsum was born as a nonsense text. “It’s not Latin, though it looks like it, and it actually says nothing,” Before & After magazine answered a curious reader, “Its ‘words’ loosely approximate the frequency with which letters occur in English, which is why at a glance it looks pretty real.”

As Cicero would put it, “Um, not so fast.”

The placeholder text, beginning with the line “Lorem ipsum dolor sit amet, consectetur adipiscing elit”, looks like Latin because in its youth, centuries ago, it was Latin.

Richard McClintock, a Latin scholar from Hampden-Sydney College, is credited with discovering the source behind the ubiquitous filler text. In seeing a sample of lorem ipsum, his interest was piqued by consectetur—a genuine, albeit rare, Latin word. Consulting a Latin dictionary led McClintock to a passage from De Finibus Bonorum et Malorum (“On the Extremes of Good and Evil”), a first-century B.C. text from the Roman philosopher Cicero.

In particular, the garbled words of lorem ipsum bear an unmistakable resemblance to sections 1.10.32–33 of Cicero’s work, with the most notable passage excerpted below:

“Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem.”

A 1914 English translation by Harris Rackham reads:

“Nor is there anyone who loves or pursues or desires to obtain pain of itself, because it is pain, but occasionally circumstances occur in which toil and pain can procure him some great pleasure.”

McClintock’s eye for detail certainly helped narrow the whereabouts of lorem ipsum’s origin, however, the “how and when” still remain something of a mystery, with competing theories and timelines.'''

In [24]:
# Number of tokens the model's tokenizer produces for the sample passage.
len(model.tokenizer.tokenize(text))

425

In [30]:
# Full encoding of the sample passage, including the per-token character offsets.
model.tokenizer(text, return_offsets_mapping=True)

{'input_ids': [1, 6550, 1104, 261, 262, 17138, 866, 6510, 24442, 358, 106652, 284, 1338, 283, 266, 13003, 1529, 260, 317, 1325, 276, 268, 298, 4996, 261, 651, 278, 1127, 334, 278, 261, 263, 278, 675, 652, 942, 261, 318, 2306, 429, 643, 3421, 4951, 266, 5348, 3684, 261, 317, 1325, 268, 534, 31826, 276, 18841, 14942, 262, 4436, 275, 319, 3527, 3080, 267, 1342, 261, 319, 269, 579, 288, 266, 9192, 278, 1127, 890, 609, 260, 318, 463, 50362, 338, 552, 278, 261, 317, 48179, 261, 298, 324, 1274, 260, 318, 279, 11736, 1529, 261, 1547, 275, 262, 683, 317, 17183, 368, 358, 106652, 72880, 2146, 266, 13595, 261, 4636, 64706, 473, 18324, 266, 27744, 59714, 510, 5173, 1632, 318, 261, 1127, 334, 4996, 401, 267, 359, 3020, 261, 6156, 824, 261, 278, 284, 4996, 260, 3155, 99660, 261, 266, 4996, 13146, 292, 67795, 271, 74640, 1676, 261, 269, 11653, 275, 10006, 262, 1271, 931, 262, 18994, 18594, 1529, 260, 344, 1769, 266, 2783, 265, 24442, 358, 106652, 261, 315, 981, 284, 57436, 293, 4636, 64706, 473, 1832

In [8]:
def evaluate(
    model,
    dts,
    step
):
    """Measure how often the model accepts the held-out multi-span relations.

    For each relation the encoder states of every span are mean-pooled, the
    span vectors are averaged, and the classifier head is applied once. A
    relation counts as a success when class 1 has the larger logit, which is
    the decision rule that matches cross-entropy training on two logits.

    Args:
        model: Object exposing ``eval()``, ``tokenizer``, ``model`` (the
            encoder), ``classifier`` and optionally ``device``.
        dts: Dataset exposing ``get_len_eval()`` and ``get_item_eval(i)``,
            the latter returning ``{'text': str, 'span': [[begin, end], ...]}``
            with exclusive ``end`` offsets.
        step: Value echoed in the printed report.

    Returns:
        The success rate in ``[0, 1]``, or NaN when there is nothing to
        evaluate. The model is left in eval mode.
    """
    model.eval()

    total = dts.get_len_eval()
    if total == 0:
        success = float('nan')
        print(f'test/success: {success}, step={step}')
        return success

    device = getattr(model, 'device', 'cpu')
    success_count = 0

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(total):
            item = dts.get_item_eval(i)
            text, spans = item['text'], item['span']
            batch = model.tokenizer(text, return_tensors='pt', padding=True).to(device)
            tens = model.model(**batch).last_hidden_state

            pooled = []
            for begin, end in spans:
                # `end` is exclusive: step back onto the span's last character.
                last_char = max(begin, min(end, len(text)) - 1)
                token_start = batch.char_to_token(batch_or_char_index=0, char_index=begin)
                token_end = batch.char_to_token(batch_or_char_index=0, char_index=last_char)
                if token_start is None or token_end is None or token_end < token_start:
                    continue
                pooled.append(torch.mean(tens[0, token_start:token_end + 1], dim=0))

            # A relation whose spans cannot be mapped to tokens counts as a miss.
            if not pooled:
                continue

            tns = torch.stack(pooled, dim=0).mean(dim=0)
            logits = model.classifier(tns.unsqueeze(0))
            success_count += int(logits[0].argmax().item() == 1)

    success = success_count / total
    print(f'test/success: {success}, step={step}')
    return success
    
def train(
    model,
    loss, 
    optim, 
    scheduler,
    clt_dtl,
    neg_dtl
):
    """Run one pass over the positive loader, pairing each batch with negatives.

    Positive outputs are labelled 1 and negative outputs 0. A step is
    skipped only when the model returns ``None`` for either batch (it could
    not map any span to tokens); every other error propagates instead of
    being silently swallowed.

    Args:
        model: Callable module returning logits of shape ``(n, 2)`` or ``None``.
        loss: Criterion called as ``loss(logits, targets)``.
        optim: Optimizer stepped after every successful batch.
        scheduler: Learning-rate scheduler stepped after every successful batch.
        clt_dtl: Iterable of positive batches; defines the length of the pass.
        neg_dtl: Iterable of negative batches; restarted when exhausted.

    Returns:
        Mean loss over the executed steps, or ``None`` if every step was skipped.

    Raises:
        ValueError: If ``neg_dtl`` yields no batches at all.
    """
    model.train()

    neg_iterator = iter(neg_dtl)
    losses = []

    for pos_batch in clt_dtl:
        neg_batch = next(neg_iterator, None)
        if neg_batch is None:
            # The negative loader ran dry before the positive one: restart it.
            neg_iterator = iter(neg_dtl)
            neg_batch = next(neg_iterator, None)
            if neg_batch is None:
                raise ValueError('neg_dtl yielded no batches')

        optim.zero_grad()

        out_pos = model(pos_batch)
        out_neg = model(neg_batch)
        if out_pos is None or out_neg is None:
            continue
        out = torch.concat([out_pos, out_neg], dim=0)

        # Targets must live on the same device as the logits.
        targets = torch.concat([
            torch.ones(len(out_pos), dtype=torch.long),
            torch.zeros(len(out_neg), dtype=torch.long),
        ]).to(out.device)

        ls = loss(out, targets)
        ls.backward()

        optim.step()
        scheduler.step()
        losses.append(ls.item())

    return sum(losses) / len(losses) if losses else None


def train_loop(
    model, 
    loss,
    optim, 
    scheduler,
    clt_dtl,
    neg_dtl,
    clt_dts,
    num_training_steps,
    eval_every=None,
    save_path='model.kpl'
):
    """Call ``train`` repeatedly and save the final weights.

    Args:
        model, loss, optim, scheduler, clt_dtl, neg_dtl: Forwarded to ``train``.
        clt_dts: Dataset handed to ``evaluate`` when evaluation is enabled.
        num_training_steps: Number of ``train`` calls. Each call is a full
            pass over ``clt_dtl``, so this counts passes, not optimizer steps.
        eval_every: Run ``evaluate`` before every ``eval_every``-th pass.
            ``None`` (default) disables evaluation, which is what the
            previously commented-out call amounted to.
        save_path: File the model's ``state_dict`` is written to.
    """
    for i in range(num_training_steps):
        if eval_every and i % eval_every == 0:
            evaluate(model, clt_dts, i)

        train(model, loss, optim, scheduler, clt_dtl, neg_dtl)

    torch.save(model.state_dict(), save_path)
    

In [9]:
# Run configuration read by the set-up and training cells below.
config = {
    'batch_size': 2,
    'shuffle': True,
    # NOTE(review): no visible cell reads this key (the DataLoader calls pass
    # config['device'] instead), and the value is a bool; confirm whether it
    # was meant to be a pin-memory on/off flag.
    'pin_memory_device': True,
    'device': 'cpu',
    'num_warmup_steps': 250,
    # Used both as the number of train() calls in train_loop and as the
    # scheduler's num_training_steps.
    'num_training_steps': 1000
}

In [10]:
# Build the model first: CultCode takes a tokenizer argument, and the model
# owns the tokenizer that is used everywhere else.
model = MultiSpanValidator(device=config['device'])
model.to(config['device'])

cltcode = CultCode(model.tokenizer)
ngtsmpl = NegativeSamples()

# `pin_memory_device` only has an effect together with `pin_memory=True`,
# which is not requested here, so it is omitted.
dtl_clt = DataLoader(cltcode, batch_size=config['batch_size'], shuffle=config['shuffle'])
dtl_neg = DataLoader(ngtsmpl, batch_size=config['batch_size'], shuffle=config['shuffle'])

loss = torch.nn.CrossEntropyLoss(reduction='mean')
optim = torch.optim.AdamW(model.parameters())

# train_loop runs config['num_training_steps'] passes over dtl_clt and the
# scheduler is stepped once per batch, so the schedule has to span passes x
# batches; otherwise the cosine keeps cycling after the nominal end.
total_optim_steps = max(1, config['num_training_steps'] * len(dtl_clt))
scheduler = get_cosine_schedule_with_warmup(optim, num_warmup_steps=config['num_warmup_steps'], num_training_steps=total_optim_steps)



In [11]:
# Run the training passes; train_loop saves the model's state_dict when it finishes.
train_loop(model, loss, optim, scheduler, dtl_clt, dtl_neg, cltcode, config['num_training_steps'])



In [12]:
from matplotlib import pyplot as plt
import json
import io

In [13]:
# Load the markup file for exploration; the context manager closes the file
# and the explicit encoding avoids depending on the platform default.
with io.open('./datasets/related/cult_code.json', encoding='utf-8') as fh:
    cult_code = json.load(fh)
markups = cult_code['dataset']['markups']

In [14]:
from collections import defaultdict

In [15]:
# Tally how often each tag id occurs among relations that reference exactly
# one span and carry exactly two tags.
labels = defaultdict(int)
for mrkp in markups:
    for rel in mrkp['relations']:
        if len(rel['spans']) != 1 or len(rel['tags']) != 2:
            continue
        for tag in rel['tags']:
            labels[tag] += 1

In [16]:
# Map each of the first 113 relation tag names to its tally from `labels`.
relation_tags = cult_code['dataset']['relation_tags']
lbl = {relation_tags[tag_id]: labels[tag_id] for tag_id in range(113)}

In [27]:
# Tag names ordered by their `lbl` value, ascending; show the third from the top.
keys = sorted(lbl, key=lbl.get)
keys[-3]

'Материальные ценности'

In [22]:
# Rebuild the mapping so each tag name carries (tally, tag id) for the first 113 tags.
tag_names = cult_code['dataset']['relation_tags']
lbl = {tag_names[tag_id]: (labels[tag_id], tag_id) for tag_id in range(113)}

In [29]:
# Look up the (tally, tag id) pair stored for one tag name.
lbl['Правосознание (гражданская активность, гражданственность)']

(64, 22)

In [90]:
# Show the first markup record to inspect its fields.
markups[0]

{'assessor': 4,
 'text': '23 февраля в войсковой части 5526 прошло чествование воспитанников военно-патриотического клуба «Крепость». Заместитель командира части по идеологической работе майор Олег Ляшук отметил, что воспитание подрастающего поколения в патриотическом ключе, в стремлении к здоровому образу жизни, уважению к традициям, культурным ценностям и исторической памяти государства является главным профилактическим фактором безнравственности и аморальности. «Вместе мы вносим огромный вклад в будущее страны и нравственное здоровье нашего общества. Служба Родине во все времена была почетна. А служить можно по-разному, и не обязательно с оружием в руках. Служить можно в любом возрасте, служить можно и парню, и девушке. Служить можно и словом, и делом!» – подытожил Олег Ляшук.',
 'spans': [{'begin': 191, 'end': 382, 'id': 0, 'tags': []},
  {'begin': 444, 'end': 576, 'id': 1, 'tags': []}],
 'relations': [{'spans': [0], 'tags': [1, 5, 10, 2, 0]},
  {'spans': [1], 'tags': [0, 9, 1]}]}

In [251]:
# Tag id used by the filter below; it is the id reported as `(64, 22)` by the
# `lbl[...]` lookup cell.
# NOTE(review): the variable name suggests a different tag; confirm that 22
# is the intended id.
TAG_ID = 22


def collect_multi_span_relations(markups, tag_id, n_tags=2):
    """Collect multi-span relations that carry `tag_id` among exactly `n_tags` tags.

    Args:
        markups: List of markup records with 'text', 'spans' and 'relations'.
        tag_id: Tag id that must be present in the relation's tags.
        n_tags: Required number of tags on the relation.

    Returns:
        A list of ``(markup index, relation index, text, [(begin, end), ...])``
        tuples in file order.
    """
    return [
        (ind, ind_2, mrkp['text'],
         [(mrkp['spans'][sp]['begin'], mrkp['spans'][sp]['end']) for sp in rel['spans']])
        for ind, mrkp in enumerate(markups)
        for ind_2, rel in enumerate(mrkp['relations'])
        if len(rel['spans']) > 1 and len(rel['tags']) == n_tags and tag_id in rel['tags']
    ]


mat_cents_1 = collect_multi_span_relations(markups, TAG_ID)

In [252]:
# Number of relations collected in mat_cents_1.
len(mat_cents_1)

7

In [30]:
# Multi-span relations with exactly two tags, one of which is tag id 22, as
# (markup index, relation index, text, [(begin, end), ...]) tuples.
mat_cents = [
    (
        ind,
        ind_2,
        mrkp['text'],
        [(mrkp['spans'][sp]['begin'], mrkp['spans'][sp]['end']) for sp in rel['spans']],
    )
    for ind, mrkp in enumerate(markups)
    for ind_2, rel in enumerate(mrkp['relations'])
    if len(rel['spans']) > 1 and len(rel['tags']) == 2 and 22 in rel['tags']
]

In [31]:
# Number of relations collected in mat_cents.
len(mat_cents)

7

In [253]:
def iter_dts(dts):
    """Print one collected relation per step and yield an empty string.

    Each element of `dts` is a (markup index, relation index, text, spans)
    tuple. For every element the indices, the full text, a separator and
    each numbered span's substring are printed before the generator pauses.
    """
    for mrkp_index, rel_index, text, spans in dts:
        report = [f'({mrkp_index}, {rel_index})', f'{text}', '=====']
        report.extend(
            f'{number}: {text[bounds[0]:bounds[1]]}'
            for number, bounds in enumerate(spans)
        )
        print('\n'.join(report))
        yield ''

# Generator that prints the mat_cents_1 relations one at a time.
# NOTE(review): this rebinds the name `iter`, hiding the builtin for the rest
# of the kernel session; `train` calls `iter(neg_dtl)` and would fail if run
# after this cell. Consider a different variable name.
iter = iter_dts(mat_cents_1)

In [261]:
# Print the next relation; raises StopIteration once the generator is exhausted.
next(iter)

StopIteration: 

In [203]:
# Number of relations in mat_cents.
# NOTE(review): the recorded output here (35) differs from the count recorded
# for the same expression earlier (7), so at least one of them is stale.
len(mat_cents)

35

In [225]:
# Hand-written (markup index, relation index) pairs; the `printing` cell
# selects relations by membership in this list.
# Presumably relations judged problematic on manual review (TODO confirm).
awkward = [
    (45, 5),
    (120, 0),
    (404, 0)
]

# Second hand-written list of (markup index, relation index) pairs; not read
# by any visible cell. NOTE(review): meaning of the name is unclear; confirm.
semigrate = [
    (94, 5),
    (151, 3),
    (156, 0),
    (170, 0),
    (172, 0),
    (256, 1),
    (281, 0),
    (281, 4),
    (362, 2),
    (389, 2),
    (450, 1),
    (541, 0),
    (562, 0),
    (624, 9),
    (665, 0)
]

In [245]:
# Relations whose (markup index, relation index) pair is listed in `awkward`,
# as (markup index, relation index, text, [(begin, end), ...]) tuples.
printing = [
    (
        ind,
        ind_2,
        mrkp['text'],
        [(mrkp['spans'][sp]['begin'], mrkp['spans'][sp]['end']) for sp in rel['spans']],
    )
    for ind, mrkp in enumerate(markups)
    for ind_2, rel in enumerate(mrkp['relations'])
    if (ind, ind_2) in awkward
]

In [223]:
# Number of relations selected for printing.
len(printing)

1

In [246]:
def iter_dts(dts):
    """Walk over collected relations, printing each one and then yielding ''.

    `dts` holds (markup index, relation index, text, spans) tuples. Per
    element this prints the index pair, the text, a '=====' separator and
    one numbered line per span with the substring it covers.
    """
    for record in dts:
        mrkp_index, rel_index, text, spans = record
        print(f'({mrkp_index}, {rel_index})')
        print(text)
        print('=====')
        for position, bounds in enumerate(spans):
            begin, end = bounds[0], bounds[1]
            print(f'{position}: {text[begin:end]}')
        yield ''

# Generator that prints the selected `printing` relations one at a time.
# NOTE(review): this rebinds the name `iter`, hiding the builtin for the rest
# of the kernel session; `train` calls `iter(neg_dtl)` and would fail if run
# after this cell. Consider a different variable name.
iter = iter_dts(printing)

In [250]:
# Print the next relation; raises StopIteration once the generator is exhausted.
next(iter)

StopIteration: 