In [1]:
# Clone the project repository; it provides the `listops` package imported below.
!git clone https://agadetsky@github.com/agadetsky/sochischool.git

Cloning into 'sochischool'...
remote: Enumerating objects: 239, done.[K
remote: Counting objects:  12% (1/8)[Kremote: Counting objects:  25% (2/8)[Kremote: Counting objects:  37% (3/8)[Kremote: Counting objects:  50% (4/8)[Kremote: Counting objects:  62% (5/8)[Kremote: Counting objects:  75% (6/8)[Kremote: Counting objects:  87% (7/8)[Kremote: Counting objects: 100% (8/8)[Kremote: Counting objects: 100% (8/8), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 239 (delta 1), reused 2 (delta 0), pack-reused 231[K
Receiving objects: 100% (239/239), 50.75 MiB | 28.81 MiB/s, done.
Resolving deltas: 100% (50/50), done.
Checking out files: 100% (194/194), done.


In [2]:
# Install pytorch-struct (provides `torch_struct.DependencyCRF`) straight from GitHub.
# NOTE(review): this tracks the repository HEAD, so results may change over time;
# pin a tag/commit (…/pytorch-struct@<rev>) for reproducibility.
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct

In [3]:
import sys
# Make the cloned repository importable. The /content/ prefix is presumably the
# Colab working directory; adjust this path when running elsewhere.
sys.path.append('/content/sochischool/')

In [4]:
import torch
import numpy as np
from torch_struct import DependencyCRF
from tqdm import tqdm

In [5]:
import listops.data as _data
import listops.data_processing.python.loading as _loading
import listops.model as _model
from listops.model_modules.sampler import DependencySampler, Sampler
from listops.model_modules.func_customparse import arcmask_from_lengths

In [6]:
# Build the train/val/test datasets for the "var_5_50_nosm_20000" configuration
# from the data directory inside the cloned repository.
datasets = _data.get_datasets(
    "var_5_50_nosm_20000",
    datadirpath="/content/sochischool/listops/data_processing/python/listops/data"
  )

/content/sochischool/listops/data_processing/python/listops/data/d2_ml50_nosm
/content/sochischool/listops/data_processing/python/listops/data/d3_ml50_nosm
/content/sochischool/listops/data_processing/python/listops/data/d4_ml50_nosm
/content/sochischool/listops/data_processing/python/listops/data/d5_ml50_nosm
/content/sochischool/listops/data_processing/python/listops/data/d1_ml50_nosm
maxnums
[20000, 2000, 2000]
[2, 3, 4, 5]
file path: /content/sochischool/listops/data_processing/python/listops/data/d2_ml50_nosm/train.tsv
number of skipped sentences due to length > inf: 0
number of skipped sentences due to length < 2: 0
file path: /content/sochischool/listops/data_processing/python/listops/data/d3_ml50_nosm/train.tsv
number of skipped sentences due to length > inf: 0
number of skipped sentences due to length < 2: 0
file path: /content/sochischool/listops/data_processing/python/listops/data/d4_ml50_nosm/train.tsv
number of skipped sentences due to length > inf: 0
number of skipped sen

In [7]:
# Wrap the three datasets in data loaders with 50 examples per batch.
train_loader, val_loader, test_loader  = _data.get_dataloaders(datasets, batchsize=50)

ds
<listops.data_processing.python.datasets.MultiListOpsDataset object at 0x7f59e7b7be90>
ds
<listops.data_processing.python.datasets.MultiListOpsDataset object at 0x7f59e233c0d0>
ds
<listops.data_processing.python.datasets.MultiListOpsDataset object at 0x7f59e5883050>


In [8]:
#" ".join([_loading.ix_to_word[elem] for elem in x[-6].cpu().numpy().tolist() if elem != 15])

In [9]:
class ProjectiveSampler(DependencySampler):
    """Samples dependency structures by perturbing arc scores and running a CRF.

    The arc scores ``A`` are perturbed with Gumbel or Gaussian noise and divided
    by the temperature ``tau``. The perturbed scores parameterize a
    ``torch_struct.DependencyCRF``. Soft samples are its arc marginals; hard
    samples are its (detached) argmax structure.

    Args:
        noise: ``"gumbel"`` or ``"gaussian"``.
        tau: temperature that divides the perturbed scores.
    """

    def __init__(self, noise, tau):
        assert noise in ("gumbel", "gaussian"), "unsupported noise: {!r}".format(noise)
        super(ProjectiveSampler, self).__init__("soft", noise, tau, True, False)

    def inject_noise(self, A):
        """Return ``(A + noise) / tau`` with noise drawn elementwise, shaped like ``A``.

        Raises:
            ValueError: if ``self.noise`` is not a supported noise family.
        """
        if self.noise == "gumbel":
            # -log(-log(u)), with u ~ U(0, 1) clamped away from 0 and 1.
            u = torch.distributions.utils.clamp_probs(torch.rand_like(A))
            noise = u.log().neg().log().neg()
        elif self.noise == "gaussian":
            noise = torch.randn_like(A)
        else:
            # Previously fell through and returned None silently.
            raise ValueError("unsupported noise: {!r}".format(self.noise))
        return (A + noise) / self.tau

    def sample(self, A, lengths, mode):
        """Draw a perturbed-CRF sample.

        Args:
            A: arc scores passed (after perturbation) to ``DependencyCRF``.
            lengths: sequence lengths forwarded to ``DependencyCRF``.
            mode: ``"soft"`` for arc marginals, ``"hard"`` for the detached argmax.

        Raises:
            ValueError: for any other ``mode`` (previously returned None).
        """
        # Validate before drawing noise so a bad mode does not consume RNG state.
        if mode not in ("soft", "hard"):
            raise ValueError("unsupported mode: {!r}".format(mode))
        crf = DependencyCRF(self.inject_noise(A), lengths)
        if mode == "soft":
            return crf.marginals
        return crf.argmax.detach()

In [10]:
def mask(sample, lengths):
    """Zero self-loops and length-masked entries of arc-score matrices.

    A new tensor is returned and ``sample`` itself is left untouched. The
    diagonal of every matrix is zeroed first. Then every entry flagged by
    ``arcmask_from_lengths`` (called on the diagonal-free tensor) is zeroed too.
    Presumably ``sample`` is ``(batch, n, n)``; TODO confirm against callers.
    """
    size = sample.size(1)
    # Boolean identity with a leading broadcast dimension for the batch.
    self_loops = torch.eye(size, dtype=torch.bool, device=sample.device)[None]
    without_loops = sample.masked_fill(self_loops, 0.0)
    padding = arcmask_from_lengths(without_loops, lengths)
    return without_loops.masked_fill(padding, 0.0)

In [11]:
class IndependentSampler(Sampler):
    """Samples every arc independently via logistic-noise perturbation.

    The arc scores ``A`` are perturbed with standard logistic noise and divided
    by the temperature ``tau``. Soft samples apply a sigmoid to the perturbed
    scores (a binary-Concrete relaxation). Hard samples threshold them at 0,
    with ``A`` detached. Both forward paths then pass the sample through
    ``mask`` to zero self-loops and entries flagged by ``arcmask_from_lengths``.

    Args:
        noise: noise family; only ``"logistic"`` is supported.
        tau: temperature that divides the perturbed scores.
    """

    def __init__(self, noise, tau):
        assert noise in ("logistic",), "unsupported noise: {!r}".format(noise)
        super(IndependentSampler, self).__init__("soft", noise, tau)

    def forward_train(self, A, lengths=None):
        """Return a masked soft (differentiable) sample."""
        # NOTE(review): with the default lengths=None, None reaches
        # arcmask_from_lengths via mask(); confirm that helper accepts None.
        return mask(self.sample(A, lengths, "soft"), lengths)

    def forward_eval(self, A, lengths=None):
        """Return a masked hard 0/1 sample."""
        return mask(self.sample(A, lengths, "hard"), lengths)

    def inject_noise(self, A):
        """Return ``(A + L) / tau`` with ``L`` i.i.d. standard logistic noise.

        Raises:
            ValueError: if ``self.noise`` is not ``"logistic"``.
        """
        if self.noise != "logistic":
            # Previously fell through and returned None silently.
            raise ValueError("unsupported noise: {!r}".format(self.noise))
        u = torch.distributions.utils.clamp_probs(torch.rand_like(A))
        # logit(u) = log(u) - log(1 - u) maps U(0, 1) to standard logistic noise.
        noise = u.log() - u.neg().log1p()
        return (A + noise) / self.tau

    def sample(self, A, lengths, mode):
        """Draw an unmasked sample (``lengths`` is not used here).

        Args:
            A: arc scores.
            lengths: ignored; kept for interface parity with other samplers.
            mode: ``"soft"`` (sigmoid of perturbed scores) or ``"hard"``
                (perturbed scores > 0, as float, with ``A`` detached).

        Raises:
            ValueError: for any other ``mode`` (previously returned None).
        """
        if mode == "soft":
            return self.inject_noise(A).sigmoid()
        if mode == "hard":
            return (self.inject_noise(A.detach()) > 0.0).float()
        raise ValueError("unsupported mode: {!r}".format(mode))

In [16]:
def training(m, train_loader, opt, num_epochs, device="cuda"):
    """Train ``m`` in place with cross-entropy for ``num_epochs`` epochs.

    Args:
        m: model called as ``m(x, arcs, lengths)`` that returns logits scored
            against ``y`` with ``torch.nn.functional.cross_entropy``.
        train_loader: iterable of ``(x, y, arcs, lengths, depths)`` batches.
            ``depths`` is not used.
        opt: optimizer over ``m``'s parameters; stepped once per batch.
        num_epochs: number of full passes over ``train_loader``.
        device: device each batch's tensors are moved to. The default
            ``"cuda"`` matches the previous ``.cuda()`` calls. The caller must
            put ``m`` on the same device.

    Returns:
        None. ``m`` is left in training mode.
    """
    m.train()
    # enable_grad keeps training working even if this is called under no_grad.
    with torch.enable_grad():
        for epoch in range(num_epochs):
            batches = tqdm(train_loader, desc="epoch {}/{}".format(epoch + 1, num_epochs))
            for x, y, arcs, lengths, _depths in batches:
                x, y, arcs, lengths = (t.to(device) for t in (x, y, arcs, lengths))
                opt.zero_grad()
                pred_logits = m(x, arcs, lengths)
                loss = torch.nn.functional.cross_entropy(pred_logits, y)
                loss.backward()
                opt.step()

In [17]:
def compute_metrics(sample, arcs, lengths):
    """Batch-mean arc precision and recall of ``sample`` against gold ``arcs``.

    Counts are summed over the last two dimensions (one count per example) and
    per-example ratios are averaged over the remaining dimensions. Entries are
    compared to 0/1 exactly, so ``sample`` is expected to be a hard 0/1 sample.

    Args:
        sample: predicted arc matrices, same shape as ``arcs``.
        arcs: gold arc matrices.
        lengths: unused; kept so existing call sites keep working.

    Returns:
        ``(precision, recall)`` as 0-dim float32 CPU tensors. An example with
        no predicted arcs contributes 0 to precision, and one with no gold arcs
        contributes 0 to recall. Previously these cases gave NaN, which made
        the whole batch mean NaN.
    """
    diff = sample - arcs
    # Boolean masks cast to float; no device-specific constant tensors needed.
    tp = (sample * arcs == 1.0).float().sum((-1, -2))
    fp = (diff == 1.0).float().sum((-1, -2))
    fn = (diff == -1.0).float().sum((-1, -2))

    # Counts are integers, so clamping to 1 only affects the 0/0 case.
    precision = (tp / (tp + fp).clamp(min=1.0)).mean().cpu()
    recall = (tp / (tp + fn).clamp(min=1.0)).mean().cpu()
    return precision, recall

In [18]:
def validation(m, val_loader, device="cuda"):
    """Evaluate ``m`` on ``val_loader`` without tracking gradients.

    Args:
        m: model called as ``m(x, arcs, lengths)`` that returns class logits.
            After each forward pass, ``m.sample`` is scored against the gold
            ``arcs`` with ``compute_metrics``.
        val_loader: iterable of ``(x, y, arcs, lengths, depths)`` batches.
            ``depths`` is not used.
        device: device each batch's tensors are moved to. The default
            ``"cuda"`` matches the previous ``.cuda()`` calls.

    Returns:
        Tuple ``(val_losses, val_accs, val_precs, val_recs)``: lists holding one
        Python float per batch. ``m`` is left in eval mode.
    """
    m.eval()
    val_losses, val_accs, val_precs, val_recs = [], [], [], []
    with torch.no_grad():
        for x, y, arcs, lengths, _depths in tqdm(val_loader):
            x, y, arcs, lengths = (t.to(device) for t in (x, y, arcs, lengths))
            pred_logits = m(x, arcs, lengths)
            loss = torch.nn.functional.cross_entropy(pred_logits, y)
            acc = (pred_logits.argmax(-1) == y).float().mean()
            # NOTE(review): assumes m's forward stores its latest arc sample on
            # m.sample; confirm in listops.model.
            precision, recall = compute_metrics(m.sample, arcs, lengths)

            val_losses.append(loss.item())
            val_accs.append(acc.item())
            val_precs.append(precision.item())
            val_recs.append(recall.item())
    return val_losses, val_accs, val_precs, val_recs

In [21]:
#sampler = ProjectiveSampler("gaussian", 1.0)
# Experiment settings. Seeding torch's global RNG makes the sampler noise
# (torch.rand_like) reproducible, and presumably the model initialization too.
SEED = 0
TAU = 1.0
NUM_EPOCHS = 2

torch.manual_seed(SEED)
sampler = IndependentSampler("logistic", TAU)
m = _model.get_school_model(sampler)
m.cuda()
opt = torch.optim.AdamW(m.parameters())
training(m, train_loader, opt, num_epochs=NUM_EPOCHS)

  "num_layers={}".format(dropout, num_layers))
  7%|▋         | 148/2000 [00:13<02:43, 11.36it/s]


KeyboardInterrupt: ignored

In [22]:
# Evaluate the (possibly partially trained) model; each list holds one value per validation batch.
val_losses, val_accs, val_precs, val_recs = validation(m, val_loader)

100%|██████████| 10/10 [00:08<00:00,  1.15it/s]


In [23]:
# Mean of the per-batch validation cross-entropy losses.
np.mean(val_losses)

1.897117590904236

In [24]:
# Mean of the per-batch validation accuracies.
np.mean(val_accs)

0.28240000903606416

In [25]:
# Mean of the per-batch arc precisions (each itself a mean over examples).
np.mean(val_precs)

0.06232384238392115

In [26]:
# Mean of the per-batch arc recalls (each itself a mean over examples).
np.mean(val_recs)

0.3070721298456192