In [None]:
!pip install -q transformers datasets

In [None]:
import torch
import random
from transformers import *
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch import nn
from pylon.constraint import constraint
from pylon.sampling_solver import WeightedSamplingSolver
#from pylon.circuit_solver import SemanticLossCircuitSolver
from pylon.shaped_lazy_solver import ProductTNormSolver
from pylon.shaped_lazy_solver import GodelTNormSolver
from pylon.shaped_lazy_solver import LukasiewiczTNormSolver

#B_A0, B_A1, B_A2, B_A3, B_A4 = None, None, None, None, None

In [None]:
def process_ex(tokenizer, labels, orig_toks, v_labels, role_labels, max_seq_l=200, max_v_num=8, max_num_subtok=8):
    """Turn one sentence and its predicates into fixed-size, padded index lists.

    Args:
        tokenizer: object exposing ``cls_token``, ``sep_token``, ``pad_token_id``,
            ``tokenize(word)`` and ``convert_tokens_to_ids(tokens)``.
        labels: list of role label strings; a label's id is its position in the list.
            Id 0 is used for padding and for the CLS/SEP positions.
        orig_toks: list of the words of the sentence.
        v_labels: list of predicate positions (0-based word indices into ``orig_toks``).
        role_labels: one list of label strings per predicate, one label per word.
        max_seq_l: padded length of both the sub-token and the word axis (CLS/SEP included).
        max_v_num: maximum number of predicates kept per sentence.
        max_num_subtok: maximum number of sub-token indices recorded per word.

    Returns:
        ``(tok_idx, sub2tok_idx, v_labels, role_labels)`` where
        ``tok_idx`` is a list of ``max_seq_l`` sub-token ids,
        ``sub2tok_idx`` is ``max_seq_l`` lists of ``max_num_subtok`` sub-token positions (-1 = none),
        ``v_labels`` is a list of ``max_v_num`` predicate positions shifted by one for CLS (-1 = none),
        ``role_labels`` is ``max_v_num`` lists of ``max_seq_l`` label ids aligned with the
        CLS-prefixed word axis (position 0 is CLS).

    Raises:
        ValueError: if a role label string is not in ``labels``.
    """
    def pad(ls, length, symbol):
        """Truncate or right-pad ``ls`` with ``symbol`` to exactly ``length`` items."""
        if len(ls) >= length:
            return ls[:length]
        return ls + [symbol] * (length - len(ls))

    bos_tok, eos_tok = tokenizer.cls_token, tokenizer.sep_token
    assert(bos_tok is not None and eos_tok is not None)
    assert(len(v_labels) == len(role_labels))

    # Order predicates by position (ties keep their input order).
    order = sorted(range(len(v_labels)), key=lambda i: (v_labels[i], i))
    v_labels = [v_labels[i] for i in order]
    role_labels = [[labels.index(l) for l in role_labels[i]] for i in order]

    # Trim the sentence: at most max_seq_l - 2 words AND at most max_seq_l - 2
    # sub-tokens, so that CLS + sub-tokens + SEP always fits into max_seq_l and
    # no entry of sub2tok_idx can point past the end of tok_idx.
    budget = max_seq_l - 2
    sent_subtoks = []
    used = 0
    for word in orig_toks[:budget]:
        subtoks = tokenizer.tokenize(word)
        if used + len(subtoks) > budget:
            break
        sent_subtoks.append(subtoks)
        used += len(subtoks)
    num_words = len(sent_subtoks)

    # Keep predicates that still lie inside the trimmed sentence (they are sorted,
    # so we can stop at the first one that does not), up to max_v_num of them.
    kept_v, kept_roles = [], []
    for v_idx, roles in zip(v_labels, role_labels):
        if v_idx >= num_words or len(kept_v) >= max_v_num:
            break
        kept_v.append(v_idx)
        kept_roles.append(roles)

    # Add CLS and SEP. Everything indexed by word position has to move with them:
    # predicate positions are shifted by one, and each gold role sequence gets a
    # label-0 entry in front (CLS) and at the end (SEP) so that gold label j lines
    # up with word position j on the model side.
    toks = [bos_tok] + [p for subtoks in sent_subtoks for p in subtoks] + [eos_tok]
    tok_l = [1] + [len(subtoks) for subtoks in sent_subtoks] + [1]
    v_labels = [v + 1 for v in kept_v]
    role_labels = [[0] + roles[:num_words] + [0] for roles in kept_roles]

    # padding
    tok_idx = pad(tokenizer.convert_tokens_to_ids(toks), max_seq_l, tokenizer.pad_token_id)
    v_labels = pad(v_labels, max_v_num, -1)
    role_labels = [pad(roles, max_seq_l, 0) for roles in role_labels][:max_v_num]
    role_labels += [[0] * max_seq_l for _ in range(max_v_num - len(role_labels))]

    # For every word (CLS/SEP count as one-piece words) record the positions of its
    # sub-tokens inside toks; words with more than max_num_subtok pieces keep the first ones.
    sub2tok_idx = []
    start = 0
    for l in tok_l:
        sub2tok_idx.append(pad(list(range(start, start + l)), max_num_subtok, -1))
        start += l
    sub2tok_idx = sub2tok_idx[:max_seq_l]
    sub2tok_idx += [[-1] * max_num_subtok for _ in range(max_seq_l - len(sub2tok_idx))]

    return tok_idx, sub2tok_idx, v_labels, role_labels

In [None]:
def process_data(tokenizer, path):
    """Load the SRL train/test files under ``path`` and build tensor datasets.

    Reads ``labels.txt`` (one label per non-empty line), then ``srl.train.txt`` and
    ``srl.test.txt``. Every non-empty line of those files has the form
    ``<predicate index> <word> ... ||| <role> ...`` with one role per word.
    Lines sharing the same sentence are merged into one example that carries
    all of that sentence's predicates.

    Returns:
        ``(train, test, labels)``: two ``TensorDataset`` objects holding
        ``(tok_idx, sub2tok_idx, v_labels, role_labels)`` and the list of label strings.
    """
    def group_ex(data):
        # Bucket (words, predicate index, roles) triples by their sentence text.
        grouped = {}
        for words, pred_pos, roles in data:
            sentence = ' '.join(words)
            entry = grouped.setdefault(sentence, {'v_labels': [], 'role_labels': []})
            entry['v_labels'].append(pred_pos)
            entry['role_labels'].append(roles)
        return grouped

    def batch_process(grouped, labels):
        # Run process_ex on each sentence and stack each of the four outputs into a long tensor.
        columns = ([], [], [], [])
        for sentence, pack in grouped.items():
            example = process_ex(tokenizer, labels, sentence.split(' '), pack['v_labels'], pack['role_labels'])
            for column, field in zip(columns, example):
                column.append(field)
        return tuple(torch.tensor(column, dtype=torch.long) for column in columns)

    # label inventory, blank lines skipped
    with open(path + '/labels.txt', 'r') as f:
        labels = [line.strip() for line in f if line.strip() != '']

    # raw examples, one list per split (train first, then test)
    all_data = []
    for name in ('srl.train.txt', 'srl.test.txt'):
        examples = []
        all_data.append(examples)
        full_path = path + '/' + name
        with open(full_path, 'r') as f:
            print('loading from', full_path)
            for line in f:
                if line.strip() == '':
                    continue
                parts = line.split('|||')
                assert(len(parts) == 2)
                head = parts[0].strip().split()
                v_idx = int(head[0])
                orig_toks = head[1:]
                roles = parts[1].strip().split()
                assert(len(orig_toks) == len(roles))
                examples.append((orig_toks, v_idx, roles))

    # group examples by sentence and tensorize each split
    train = TensorDataset(*batch_process(group_ex(all_data[0]), labels))
    test = TensorDataset(*batch_process(group_ex(all_data[1]), labels))
    return train, test, labels


In [None]:
def evaluate(model, test_data, batch_size=8, device=torch.device('cpu')):
    """Print IoU and unique-core-role violation statistics of ``model`` on ``test_data``.

    Args:
        model: module called as ``model(tok_idx, sub2tok_idx, v_labels, role_labels)``
            and returning ``(loss, logits)``; ``logits`` is indexed here as
            ``[example, predicate position, word position, label]``.
        test_data: dataset yielding ``(tok_idx, sub2tok_idx, v_labels, role_labels)``.
        batch_size: evaluation batch size.
        device: device the batches are moved to (the model is expected to live there already).

    The model is switched to eval mode for the duration of the call (so dropout is
    disabled) and every sub-module's previous train/eval flag is restored afterwards.
    When there are no examples / no predicates the corresponding statistic is printed as nan.
    """
    test_data_loader = DataLoader(test_data, sampler=SequentialSampler(test_data), batch_size=batch_size)
    ex_cnt = 0
    predicate_cnt = 0
    iou_accumulator = 0.0
    global_sat = 0.0

    # Remember the mode of every sub-module: a plain model.train() afterwards would
    # also flip sub-modules that were deliberately left in eval mode.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            for batch in test_data_loader:
                tok_idx, sub2tok_idx, v_labels, role_labels = (t.to(device) for t in batch[:4])
                _, logits = model(tok_idx, sub2tok_idx, v_labels, role_labels)

                batch_l = tok_idx.shape[0]
                # number of words (rows with a first sub-token) and of predicates per example
                orig_l = (sub2tok_idx[:, :, 0] != -1).sum(-1).long()
                v_l = (v_labels != -1).sum(-1)
                for i in range(batch_l):
                    v_i = v_labels[i, :v_l[i]]
                    p = logits[i, v_i, :orig_l[i]].argmax(-1)    # predicted label index per (predicate, word)
                    g = role_labels[i, :v_l[i], :orig_l[i]]

                    # for each predicate, all tokens must satisfy the constraint
                    satisfied = unique_role_check(p)
                    global_sat += float(satisfied.all(-1).sum())
                    predicate_cnt += int(v_l[i])

                    # intersection over union without counting mutual O's (label 0):
                    # zeros are replaced by different sentinels so they never match
                    if p.sum() > 0 or g.sum() > 0:
                        intersection = (p.masked_fill(p == 0, -1) == g.masked_fill(g == 0, -2)).sum()
                        union = ((p != 0) | (g != 0)).sum()
                        if union > 0:
                            iou_accumulator += float(intersection) / float(union)
                ex_cnt += batch_l
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    iou = iou_accumulator / ex_cnt if ex_cnt > 0 else float('nan')
    violation = 1 - global_sat / predicate_cnt if predicate_cnt > 0 else float('nan')
    print('test set iou', iou)
    print('Global percent of predicates that violate the unique core role constraint', violation)

In [None]:
class SRLModel(torch.nn.Module):
    """Transformer encoder with a pairwise (predicate, argument) role classifier.

    For every pair of word positions ``(v, a)`` the model produces a score vector
    over the role labels; the training loss only uses the rows of the gold predicates.
    """

    def __init__(self, t_type, labels):
        """
        Args:
            t_type: name or path passed to ``AutoModel`` / ``AutoTokenizer.from_pretrained``.
            labels: list of role label strings; its length fixes the output size.
        """
        super(SRLModel, self).__init__()
        self.transformer = AutoModel.from_pretrained(t_type)
        self.tokenizer = AutoTokenizer.from_pretrained(t_type)

        self.labels = labels
        self.num_label = len(labels)
        self.hidden_size = self.transformer.config.hidden_size
        # scores a concatenated (predicate, argument) representation
        self.g_va = nn.Sequential(
            nn.Linear(self.hidden_size*2, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU())
        self.f_v = nn.Linear(self.hidden_size, self.hidden_size)    # predicate view of a word
        self.f_a = nn.Linear(self.hidden_size, self.hidden_size)    # argument view of a word
        self.label_layer = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.hidden_size, self.num_label))

    # use the idx (batch_l, seq_l, rs_l) (2nd dim) to select the middle dim of the content (batch_l, seq_l, d)
    #   the result has shape (batch_l, seq_l, rs_l, d)
    def batch_index2_select(self, content, idx, nul_idx):
        """Gather rows of ``content`` per example; entries equal to ``nul_idx`` yield zero vectors."""
        idx = idx.long()
        rs_l = idx.shape[-1]
        batch_l, seq_l, d = content.shape
        content = content.contiguous().view(-1, d)
        # offset every example's indices into the flattened (batch_l * seq_l, d) table
        shift = torch.arange(0, batch_l).to(idx.device).long().view(batch_l, 1, 1)
        shift = shift * seq_l
        shifted = idx + shift
        rs = content[shifted].view(batch_l, seq_l, rs_l, d)
        # whatever row a nul entry happened to pick is zeroed out here
        mask = (idx != nul_idx).unsqueeze(-1)
        return rs * mask.to(rs)


    def forward(self, input_ids, sub2tok_idx, v_labels, role_labels):
        """Score every (predicate, argument) pair and compute the gold-predicate loss.

        Args:
            input_ids: (batch_l, seq_l) sub-token ids, padded with the tokenizer's pad id.
            sub2tok_idx: (batch_l, seq_l, rs_l) sub-token positions of each word, -1 = none.
            v_labels: (batch_l, max_v_num) predicate word positions, -1 = none.
            role_labels: (batch_l, max_v_num, seq_l) gold label ids per predicate and word.

        Returns:
            ``(loss, logits)``: ``loss`` has shape (1,) and is the per-example mean
            cross-entropy averaged over the batch (examples without predicates
            contribute 0); ``logits`` has shape (batch_l, seq_l, seq_l, num_label).
        """
        # Do not let real tokens attend to padding. If the tokenizer has no pad id,
        # fall back to the encoder's default (no mask).
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        attention_mask = (input_ids != pad_id).long() if pad_id is not None else None
        enc = self.transformer(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
        enc = self.batch_index2_select(enc, sub2tok_idx, nul_idx=-1).sum(2) # summing subtoks -> (batch_l, seq_l, hidden_size)
        (batch_l, seq_l, hidden_size) = enc.shape

        v_enc = self.f_v(enc.view(-1, self.hidden_size)).view(batch_l, seq_l, 1, self.hidden_size)
        a_enc = self.f_a(enc.view(-1, self.hidden_size)).view(batch_l, 1, seq_l, self.hidden_size)

        # pair every predicate candidate (dim 1) with every argument candidate (dim 2)
        va_enc = torch.cat([
            v_enc.expand(batch_l, seq_l, seq_l, self.hidden_size),
            a_enc.expand(batch_l, seq_l, seq_l, self.hidden_size)], dim=-1)
        va_enc = self.g_va(va_enc.view(-1, self.hidden_size*2))
        va_enc = va_enc.view(batch_l, seq_l, seq_l, self.hidden_size)

        logits = self.label_layer(va_enc.view(-1, hidden_size)).view(batch_l, seq_l, seq_l, self.num_label)

        # loss using gold predicates
        # Start from a zero that lives on the logits' device and is attached to the
        # graph, so backward() works even if no example in the batch has a predicate.
        loss = logits[:0].sum().reshape(1)
        criterion = torch.nn.CrossEntropyLoss(reduction='mean')
        orig_l = (sub2tok_idx[:, :, 0] != -1).sum(-1).tolist()
        v_l = (v_labels != -1).sum(-1).tolist()
        for i in range(batch_l):
            if v_l[i] == 0 or orig_l[i] == 0:
                # the mean over an empty selection would be NaN
                continue
            v_i = v_labels[i, :v_l[i]]
            p = logits[i, v_i, :orig_l[i]]
            g = role_labels[i, :v_l[i], :orig_l[i]]
            loss = loss + criterion(p.reshape(-1, self.num_label), g.reshape(-1))
        loss = loss / batch_l
        return loss, logits

In [None]:
def unique_role(y):
    # Constraint function: `train` wraps it with pylon's `constraint(unique_role, solver)`
    # and calls the result on the per-predicate logits of one example.
    # `y` is treated as 3-D here; the last axis is indexed with the label ids
    # B_A0..B_A4, which are notebook-level globals that must be defined before
    # this function is evaluated.
    # NOTE(review): presumably y is (num predicates, num words, num labels) and
    # holds per-label truth values / probabilities supplied by the solver -- confirm
    # against the pylon solver in use.

    # Inside this function `torch` is pylon's lazy_torch, not the real torch module.
    from pylon import lazy_torch as torch
    shape = y.size()
    # Small offset before the log() below (presumably to keep log away from exact zeros).
    y_ext = y + 1e-6
    # Add an axis and repeat along it: y_ext[b, i, j, l] is y[b, j, l] (in log space),
    # i.e. for every position i a copy of the whole sequence.
    y_ext = y_ext.unsqueeze(1).tile(1, shape[1], 1, 1).log()
    # Add -1e6 wherever i == j, so that after exp() a position's own entry is ~0
    # and only the *other* positions j != i remain.
    y_ext = y_ext + torch.eye(shape[1]).unsqueeze(0).unsqueeze(3).tile(shape[0], 1, 1, shape[2]) * -1e6
    y_ext = y_ext.exp()

    # One clause per core role. Each has the form
    #   y[.., i, ROLE] <= all_j not(y_ext[.., i, j, ROLE])
    # NOTE(review): `<=` presumably encodes implication here ("if position i has
    # ROLE then no other position j has it") -- confirm with pylon's lazy_torch semantics.
    b_a0 = y[:, :, B_A0] <= y_ext[:, :, :, B_A0].logical_not().all(2)
    b_a1 = y[:, :, B_A1] <= y_ext[:, :, :, B_A1].logical_not().all(2)
    b_a2 = y[:, :, B_A2] <= y_ext[:, :, :, B_A2].logical_not().all(2)
    b_a3 = (y[:, :, B_A3]) <= y_ext[:, :, :, B_A3].logical_not().all(2)
    b_a4 = (y[:, :, B_A4]) <= y_ext[:, :, :, B_A4].logical_not().all(2)
    # Conjunction of the five per-role clauses.
    return b_a0.logical_and(b_a1).logical_and(b_a2).logical_and(b_a3).logical_and(b_a4)

# TODO, use unique_role function instead
def unique_role_check(y_pred, core_labels=None):
    """Check, per position, that a predicted core role is not repeated elsewhere in its row.

    Args:
        y_pred: (batch_l, seq_l) integer tensor of predicted label ids; each row is
            checked independently.
        core_labels: iterable of label ids that may occur at most once per row.
            Defaults to the notebook globals ``[B_A0, B_A1, B_A2, B_A3, B_A4]``,
            which are only looked up when this argument is omitted.
            Id 0 cannot be used as a core label because 0 is the value used to
            blank out a position's own entry.

    Returns:
        Bool tensor of shape (batch_l, seq_l): True where the position either does
        not carry a core label or is the only position in its row carrying it.
    """
    if core_labels is None:
        core_labels = [B_A0, B_A1, B_A2, B_A3, B_A4]

    batch_l, seq_l = y_pred.shape
    # others[b, i, j] = label at position j, as seen from position i ...
    others = y_pred.unsqueeze(1).expand(batch_l, seq_l, seq_l).clone()    # clone: expand() shares memory
    # ... with position i itself blanked out to 0 ('O')
    others.diagonal(0, 1, 2).zero_()

    satisfied = torch.ones(y_pred.shape, dtype=torch.bool, device=y_pred.device)
    for label in core_labels:
        has_label = (y_pred == label)
        nobody_else = (others != label).all(2)
        # implication: has_label -> nobody_else
        satisfied = satisfied & (~has_label | nobody_else)
    return satisfied

In [None]:
def train(model, solver, train_data, 
            lr=5e-5, batch_size=8, seed=1, grad_clip=1.0, lambda_constr=0.1, epoch=1,
            use_constr=False, device=torch.device('cpu')):
    """Fine-tune ``model`` on ``train_data``, optionally adding the unique-role constraint loss.

    Args:
        model: module called as ``model(tok_idx, sub2tok_idx, v_labels, role_labels)``
            and returning ``(loss, logits)``.
        solver: pylon solver used to turn ``unique_role`` into a loss; only used
            when ``use_constr`` is True.
        train_data: dataset yielding ``(tok_idx, sub2tok_idx, v_labels, role_labels)``.
        lr: peak learning rate (decayed linearly to 0 over all updates, no warmup).
        batch_size: training batch size.
        seed: seed for ``random`` and ``torch`` (and CUDA when ``device`` is a CUDA device).
        grad_clip: max gradient norm.
        lambda_constr: weight of the constraint loss.
        epoch: number of passes over ``train_data``.
        use_constr: whether to add the constraint loss.
        device: device the model and the batches are moved to.

    Returns:
        The trained model (moved to ``device``, left in training mode).
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    train_loader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=batch_size)

    # create optimizer; biases and LayerNorm weights are kept in their own group without decay
    weight_decay = 0
    no_decay = ['bias', 'LayerNorm.weight']
    named_params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    optimizer_grouped_parameters = [
        {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
        {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    total_updates = epoch * len(train_loader)
    # torch's AdamW rather than the one re-exported by transformers, which is deprecated
    # upstream; weight decay is set explicitly per group above, so the differing
    # default does not matter.
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-8)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_updates)

    # the wrapped constraint does not change between steps, so build it once
    constrain_func = constraint(unique_role, solver) if use_constr else None

    update_cnt = 0
    loss_accumulator = 0.0
    model = model.to(device)
    # from_pretrained() hands back the encoder in eval mode; make sure every
    # sub-module (including its dropout layers) is in training mode.
    model.train()
    model.zero_grad()
    for epoch_id in range(epoch):
        for batch in train_loader:
            tok_idx, sub2tok_idx, v_labels, role_labels = batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3].to(device)

            loss, logits = model(tok_idx, sub2tok_idx, v_labels, role_labels)

            if use_constr:
                batch_l = logits.shape[0]
                orig_l = (sub2tok_idx[:, :, 0] != -1).sum(-1).long()
                v_l = (v_labels != -1).sum(-1)
                for i in range(batch_l):
                    num_v = int(v_l[i])
                    if num_v == 0:
                        # nothing to constrain, and dividing by zero would poison the loss
                        continue
                    v_i = v_labels[i, :num_v]
                    p = logits[i, v_i, :orig_l[i]]
                    c_loss = constrain_func(p)
                    loss = loss + c_loss * lambda_constr / num_v

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            lr_scheduler.step()
            model.zero_grad()

            loss_accumulator += (loss.item())
            update_cnt += 1

            if update_cnt % 100 == 0:
                print('trained {0} steps, avg loss {1:4f}'.format(update_cnt, float(loss_accumulator/update_cnt)))

    return model

In [None]:
# Device selection: keep using the second GPU (index 1) when the machine has one,
# otherwise fall back to the first GPU, and to the CPU when CUDA is unavailable,
# so the notebook also runs on single-GPU and CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda", 1)
elif torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device('cpu')
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

# data
print('processing data...')
train_data, test_data, labels = process_data(tokenizer, './examples/srl/')
# Label ids of the core roles; unique_role and unique_role_check read these globals.
B_A0 = labels.index('B-ARG0')
B_A1 = labels.index('B-ARG1')
B_A2 = labels.index('B-ARG2')
B_A3 = labels.index('B-ARG3')
B_A4 = labels.index('B-ARG4')

# Run 1: baseline, supervised loss only.
print('initializing models and solvers...')
solver = ProductTNormSolver()
model = SRLModel('distilbert-base-uncased', labels)

# train
print('training on gold data...')
model = train(model, solver, train_data, lr=5e-5, epoch=2, use_constr=False, batch_size=6, device=device)

print('evaluating on test set...')
evaluate(model, test_data, device=device)

# Run 2: fresh model, same settings plus the unique-core-role constraint loss.
solver = ProductTNormSolver()
model = SRLModel('distilbert-base-uncased', labels)

# train
print('training on gold data with constraint...')
model = train(model, solver, train_data, lr=5e-5, epoch=2, use_constr=True, batch_size=6, lambda_constr=0.01, device=device)

print('evaluating on test set...')
evaluate(model, test_data, device=device)