In [1]:
import sys
sys.path.append('../')

In [2]:
import torch

In [3]:
import listops.data as _data
import listops.data_processing.python.loading as _loading
import listops.model as _model

In [4]:
# Load the datasets registered under the configuration name "var_5_50_nosm_20000".
# The captured output below shows the loader reading train.tsv files from the
# d2..d5 "ml50_nosm" directories.
datasets = _data.get_datasets("var_5_50_nosm_20000")

./data_processing/python/listops/data/d2_ml50_nosm
./data_processing/python/listops/data/d3_ml50_nosm
./data_processing/python/listops/data/d4_ml50_nosm
./data_processing/python/listops/data/d5_ml50_nosm
./data_processing/python/listops/data/d1_ml50_nosm
maxnums
[20000, 2000, 2000]
[2, 3, 4, 5]
file path: ./data_processing/python/listops/data/d2_ml50_nosm/train.tsv
number of skipped sentences due to length > inf: 0
number of skipped sentences due to length < 2: 0
file path: ./data_processing/python/listops/data/d3_ml50_nosm/train.tsv
number of skipped sentences due to length > inf: 0
number of skipped sentences due to length < 2: 0
file path: ./data_processing/python/listops/data/d4_ml50_nosm/train.tsv
number of skipped sentences due to length > inf: 0
number of skipped sentences due to length < 2: 0
file path: ./data_processing/python/listops/data/d5_ml50_nosm/train.tsv
number of skipped sentences due to length > inf: 0
number of skipped sentences due to length < 2: 0
file path: ./dat

In [5]:
# Wrap the datasets in train / validation / test loaders with a batch size of 50.
train_loader, val_loader, test_loader  = _data.get_dataloaders(datasets, batchsize=50)

ds
<listops.data_processing.python.datasets.MultiListOpsDataset object at 0x7f8e98168650>
ds
<listops.data_processing.python.datasets.MultiListOpsDataset object at 0x7f8e488680d0>
ds
<listops.data_processing.python.datasets.MultiListOpsDataset object at 0x7f8e03e71090>


In [6]:
# Fields of one batch, in the order the loader yields them:
#   x       - tokens
#   y       - value
#   arcs    - true graph
#   lengths
#   depths

# Grab a single training batch to experiment with in the cells below.
x, y, arcs, lengths, depths = next(iter(train_loader))

In [7]:
from listops.data_processing.python.loading import ix_to_word

In [9]:
# Decode the sixth-from-last example of the batch back into words, skipping
# index 15 (presumably the padding token: the model's embeddings use padding_idx=15).
" ".join(ix_to_word[token] for token in x[-6].numpy().tolist() if token != 15)

'[MIN 0 0 9 7 ]'

In [10]:
from torch_struct import DependencyCRF
from listops.model_modules.sampler import DependencySampler

In [11]:
class ProjectiveSampler(DependencySampler):
    """Perturb-and-parse sampler backed by ``torch_struct.DependencyCRF``.

    Arc scores are perturbed with random noise, divided by the temperature
    ``tau`` and handed to a ``DependencyCRF``. The relaxed ("soft") sample is
    the CRF marginals; the discrete ("hard") sample is its detached argmax.
    """

    def __init__(self, noise, tau):
        """Create the sampler.

        Args:
            noise: Noise family used to perturb the scores, ``'gumbel'`` or
                ``'gaussian'``.
            tau: Temperature the perturbed scores are divided by.

        Raises:
            AssertionError: if ``noise`` is not one of the supported families.
        """
        assert noise in ('gumbel', 'gaussian'), \
            "noise must be 'gumbel' or 'gaussian', got {!r}".format(noise)
        # NOTE(review): the positional arguments ("soft", True, False) are defined
        # by DependencySampler, which is not visible here -- confirm their meaning.
        # This class relies on the base class exposing `self.noise` and `self.tau`.
        super().__init__("soft", noise, tau, True, False)

    def inject_noise(self, A):
        """Return the scores ``A`` perturbed with noise and scaled by ``1 / tau``.

        Args:
            A: Tensor of arc scores.

        Returns:
            Tensor with the same shape as ``A``.

        Raises:
            ValueError: if ``self.noise`` is not a supported noise family.
        """
        if self.noise == "gumbel":
            # Gumbel(0, 1) sample via inverse CDF: -log(-log(u)), with u clamped
            # away from 0 and 1 so neither logarithm overflows.
            u = torch.distributions.utils.clamp_probs(torch.rand_like(A))
            g = u.log().neg().log().neg()
        elif self.noise == "gaussian":
            # Previously this branch was a bare `pass`, so the method returned
            # None and DependencyCRF failed downstream. Use standard-normal noise.
            g = torch.randn_like(A)
        else:
            raise ValueError("unknown noise type: {!r}".format(self.noise))
        return (A + g) / self.tau

    def sample(self, A, lengths, mode):
        """Draw a sample of dependency arcs from the perturbed scores.

        Args:
            A: Tensor of arc scores.
            lengths: Sequence lengths forwarded to ``DependencyCRF``.
            mode: ``"soft"`` for CRF marginals, ``"hard"`` for the detached argmax.

        Returns:
            The marginals (``"soft"``) or the detached argmax (``"hard"``).

        Raises:
            ValueError: if ``mode`` is neither ``"soft"`` nor ``"hard"``
                (previously this silently returned ``None``).
        """
        if mode not in ("soft", "hard"):
            raise ValueError("unknown sampling mode: {!r}".format(mode))
        dist = DependencyCRF(self.inject_noise(A), lengths)
        if mode == "soft":
            return dist.marginals
        return dist.argmax.detach()

In [12]:
# Gumbel-noise sampler with temperature tau = 1.0.
sampler = ProjectiveSampler("gumbel", 1.0)

In [13]:
# Build the model around the sampler defined above.
m = _model.get_school_model(sampler)

  "num_layers={}".format(dropout, num_layers))


In [16]:
# Smoke test: one forward pass on the batch currently held in x / arcs / lengths.
out = m(x, arcs, lengths)

In [23]:
# AdamW over all model parameters, using the optimizer's default hyperparameters.
opt = torch.optim.AdamW(m.parameters())

In [31]:
# One pass over the training loader, printing loss and accuracy per batch.
# Put the model in training mode explicitly: a later cell calls m.eval(), so
# re-running this cell afterwards would otherwise train with dropout disabled.
m.train()
for batch_idx, (x, y, arcs, lengths, depths) in enumerate(train_loader):
    opt.zero_grad()

    # Gradients are enabled explicitly in case they were switched off globally.
    with torch.set_grad_enabled(True):
        pred_logits = m(x, arcs, lengths)
        loss = torch.nn.functional.cross_entropy(pred_logits, y)
        loss.backward()
        opt.step()
    print("loss: {:.3f}".format(loss.item()))

    # Batch accuracy of the arg-max class against the target value.
    acc = (pred_logits.argmax(-1) == y).float().mean()
    print("acc: {:.3f}".format(acc.item()))
    print()

loss: 1.457
acc: 0.420

loss: 1.687
acc: 0.460

loss: 1.527
acc: 0.320



KeyboardInterrupt: 

In [32]:
# Index of the batch the training loop had reached when it was interrupted.
batch_idx

3

In [33]:
# NOTE(review): scratch arithmetic; presumably a count times the batch size of 50
# used for the loaders -- confirm what it represents or delete the cell.
2000 * 50

100000

In [35]:
# Number of batches in the validation loader.
len(val_loader)

10

In [36]:
# Take the first validation batch. This rebinds x / y / arcs / lengths / depths,
# which previously held training data.
x, y, arcs, lengths, depths = next(iter(val_loader))

In [37]:
# Switch the model to evaluation mode (e.g. disables the Dropout layers shown in
# the printed architecture). The training cell must call m.train() again before
# being re-run.
m.eval()

Model(
  (base): ModelBase(
    (sampler): ProjectiveSampler()
    (computer): KipfMLPGNN(
      (embd): Embedding(16, 60, padding_idx=15)
      (msg_fc1): ModuleList(
        (0): Linear(in_features=120, out_features=60, bias=True)
      )
      (msg_fc2): ModuleList(
        (0): Linear(in_features=60, out_features=60, bias=True)
      )
    )
    (decoder): Decoder(
      (net): Sequential(
        (0): Linear(in_features=60, out_features=60, bias=True)
        (1): ReLU()
        (2): Linear(in_features=60, out_features=60, bias=True)
        (3): ReLU()
        (4): Linear(in_features=60, out_features=10, bias=True)
      )
    )
  )
  (archer): Archer(
    (embd): Embedding(16, 60, padding_idx=15)
    (head_lstm): LSTM(60, 60, batch_first=True, dropout=0.1)
    (head_dropout): Dropout(p=0.1, inplace=False)
    (modif_lstm): LSTM(60, 60, batch_first=True, dropout=0.1)
    (modif_dropout): Dropout(p=0.1, inplace=False)
  )
)

In [38]:
# Forward pass on the validation batch.
# NOTE(review): not wrapped in torch.no_grad(), so autograd state is still
# tracked for this evaluation pass -- confirm whether that is intended.
pred = m(x, arcs, lengths)

In [40]:
# Shape of the arc sample stored on the model (presumably set by the forward
# pass above -- confirm): (batch, length, length).
m.sample.size()

torch.Size([1000, 28, 28])

In [62]:
def _count_true(cond):
    """Count the True entries of each matrix in a boolean batch, as floats."""
    return cond.float().sum((-1, -2))


def _safe_ratio(numer, denom, default):
    """Element-wise ``numer / denom`` with ``default`` wherever ``denom`` is zero.

    Both arguments are non-negative counts, so clamping the denominator to 1
    only affects entries that are replaced by ``default`` anyway.
    """
    return torch.where(denom > 0,
                       numer / denom.clamp(min=1.0),
                       torch.full_like(numer, default))


def compute_metrics(x, sample, arcs, lengths):
    """Compare sampled arcs with gold arcs and return batch-averaged metrics.

    Every metric is computed per example and then averaged over the batch.
    Entries are matched by exact comparison against 0.0 / 1.0, so only binary
    entries of ``sample`` are counted as positives or negatives.

    The "clean" variants ignore every arc whose head or modifier is the
    closing-bracket token ``]`` (looked up in ``_loading.word_to_ix``).

    Examples with an empty denominator use a default instead of producing NaN:
    1 for IoU and recall, 0 for precision. The clean metrics always did this;
    the plain metrics previously returned NaN in that case.

    Args:
        x: Token indices, shape (batch, length).
        sample: Predicted arc matrices, shape (batch, length, length).
        arcs: Gold arc matrices, same shape as ``sample``.
        lengths: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(iou, cliou, precision, clprecision, recall, clrecall,
        parse_acc, clparse_acc)`` of scalar tensors.
    """
    # Per-entry outcome masks shared by the plain and the clean metrics.
    difference = sample - arcs
    hit = sample * arcs == 1.0      # predicted and gold
    spurious = difference == 1.0    # predicted, not gold
    missed = difference == -1.0     # gold, not predicted

    tp = _count_true(hit)
    fp = _count_true(spurious)
    fn = _count_true(missed)

    iou = _safe_ratio(tp, tp + fp + fn, 1.0).mean().cpu()
    # Precision (attachment score).
    precision = _safe_ratio(tp, tp + fp, 0.0).mean().cpu()
    recall = _safe_ratio(tp, tp + fn, 1.0).mean().cpu()
    # Fraction of examples whose whole arc matrix is reproduced exactly.
    parse_acc = (sample == arcs).all(-1).all(-1).float().mean()

    # Clean metrics: mask out rows and columns that belong to the "]" token.
    close_ix = _loading.word_to_ix[']']
    clean_mask = (x != close_ix).unsqueeze(1).expand_as(arcs)  # broadcast over rows
    clean_mask = clean_mask & clean_mask.transpose(1, 2)

    cltp = _count_true(hit & clean_mask)
    clfp = _count_true(spurious & clean_mask)
    clfn = _count_true(missed & clean_mask)

    cliou = _safe_ratio(cltp, cltp + clfp + clfn, 1.0).mean().cpu()
    clprecision = _safe_ratio(cltp, cltp + clfp, 0.0).mean().cpu()
    clrecall = _safe_ratio(cltp, cltp + clfn, 1.0).mean().cpu()
    clparse_acc = (sample * clean_mask == arcs * clean_mask).all(-1).all(-1).float().mean()

    if torch.isnan(cliou) or torch.isnan(clprecision) or torch.isnan(clrecall) or torch.isnan(clparse_acc):
        print('Found NaN in cl metrics')

    return iou, cliou, precision, clprecision, recall, clrecall, parse_acc, clparse_acc

In [64]:
# Score the arc sample stored on the model against the gold arcs of the current batch.
(iou, cliou,
 precision, clprecision,
 recall, clrecall,
 parse_acc, clparse_acc) = compute_metrics(x, m.sample, arcs, lengths)

In [83]:
# Fraction of examples whose sampled arc matrix equals the gold matrix in every
# entry (the same expression compute_metrics uses for parse_acc).
(m.sample == arcs).all(-1).all(-1).float().mean()

tensor(0.)

In [73]:
# Fraction of individual matrix entries on which sample and gold arcs agree,
# averaged over the whole batch tensor.
# NOTE(review): padded positions presumably count as agreements, which would
# inflate this number -- confirm before reporting it.
(m.sample == arcs).float().mean()

tensor(0.9748)

In [81]:
# Total number of trainable scalar parameters in the model.
sum(p.numel() for p in m.parameters() if p.requires_grad)

79330