<a href="https://colab.research.google.com/github/DmitriyValetov/nlp_course_project/blob/master/ria_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dataset

In [0]:
from google.colab import drive
import os

drive.mount('/content/drive')  # expose Google Drive under /content/drive
root = '/content/drive/My Drive/'
# Preprocessed RIA corpus: raw samples, word-count vocabulary, encoded samples.
data_fn, vocab_fn, enc_fn = (
  'stop_lem_norm_sents_ria.json.gz',
  'stop_lem_norm_sents_ria.json_vocab.json',
  'stop_lem_norm_sents_ria.json_enc.json.gz',
)
data_path, vocab_path, enc_path = (
  os.path.join(root, data_fn),
  os.path.join(root, vocab_fn),
  os.path.join(root, enc_fn),
)
# Report whether every expected file is present before anything loads it.
print('\n'.join(f'Check {p}: {os.path.exists(p)}'
                for p in (data_path, vocab_path, enc_path)))

In [0]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import os
import json
import gzip
from pprint import pprint
from tqdm.notebook import tqdm
import numpy as np
import gc
import shutil

class RiaDataset(Dataset):
  """RIA news dataset of encoded (text sentences, title sentences) samples.

  Each sample is ``[x1, x2]``: ``x1`` is a list of encoded text sentences and
  ``x2`` a list of encoded title sentences. Every sentence is a numpy array of
  token ids framed by ``<sos>`` ... ``<eos>`` (see :meth:`encode`).

  Args:
    data: path to the raw corpus, JSON lines (optionally ``.gz``) whose records
      carry ``'text'`` and ``'title'`` lists of sentence strings. Required when
      ``vocab`` or ``enc_data`` is ``None``.
    max_vocab: number of most frequent words kept; the special tokens are
      added on top of it.
    max_samples: maximum number of JSON lines read from any file.
    vocab: ``None`` to count the words of ``data`` and save the counts to
      ``os.path.splitext(data)[0] + '_vocab.json'``; a path to such a JSON
      file; or a ``{word: count}`` dict used as is.
    enc_data: ``None`` to encode ``data`` and save the result to
      ``os.path.splitext(data)[0] + '_enc.json.gz'``; a path to such an
      encoded JSON-lines file; or an already-encoded list of samples.
    special_chars: tokens given the first ids; :meth:`encode` relies on
      ``'<unk>'``, ``'<sos>'`` and ``'<eos>'`` being among them.

  Raises:
    ValueError: if ``data`` is needed but missing, or if a vocabulary word
      collides with a special token.
  """
  def __init__(self, data=None, max_vocab=10000, max_samples=500,
               vocab=None, enc_data=None,
               special_chars=('<pad>', '<unk>', '<sos>', '<eos>')):
    super().__init__()
    # Vocabulary dict {word: count}
    if vocab is None:  # Create vocabulary
      print('Creating vocabulary')
      if data is None:
        raise ValueError('Requires raw data path for vocabulary creating')
      vocab = {}
      bad_ids = []
      for cnt, n in enumerate(self._iter_json_lines(data, max_samples)):
        if self._is_valid(n):
          for s in (*n['text'], *n['title']):
            for w in s.split():
              vocab[w] = vocab.get(w, 0) + 1
        else:
          bad_ids.append(cnt)
      print(f'bad texts: {len(bad_ids)} {bad_ids}')
      vocab_path = os.path.splitext(data)[0] + '_vocab.json'
      print(f'Saving vocabulary to {vocab_path}')
      with open(vocab_path, 'w', encoding='utf-8') as f:
        json.dump(vocab, f)
    elif isinstance(vocab, str):  # Load vocabulary
      print(f'Loading vocabulary from {vocab}')
      with open(vocab, encoding='utf-8') as f:
        vocab = json.load(f)
    # else: a ready {word: count} dict, used as is
    total = sum(vocab.values())
    print('Full vocabulary')
    print(f'unique: {len(vocab)}, total: {total}')
    # Special tokens take the first ids, then words by descending count
    # (sorted() is stable, so ties keep the vocabulary's insertion order).
    top_words = sorted(vocab.items(), key=lambda x: x[1], reverse=True)[:max_vocab]
    self.itos = list(special_chars) + [w for w, _ in top_words]
    self.stoi = {x: i for i, x in enumerate(self.itos)}
    if len(self.itos) != len(self.stoi):
      raise ValueError('Vocabulary contains duplicated or special-token words')
    top_total = sum(c for _, c in top_words)
    share = top_total / total * 100 if total else 0.0  # empty vocab -> 0%
    print('Reduced vocabulary')
    print(f'unique: {len(self.stoi)}, total: {top_total} ~ {share:6.3f}% of full')
    # Samples [list of np.arrays (text), list of np.arrays (title)]
    self.samples = []
    if enc_data is None:  # Encode samples
      print('Encoding samples')
      if data is None:
        raise ValueError('Requires raw data path for samples encoding')
      bad_ids = []
      for cnt, n in enumerate(self._iter_json_lines(data, max_samples)):
        if self._is_valid(n):
          x1 = [self.encode(x) for x in n['text']]
          x2 = [self.encode(x) for x in n['title']]
          self.samples.append([x1, x2])
        else:
          bad_ids.append(cnt)
      print(f'bad texts: {len(bad_ids)} {bad_ids}')
      self._save_encoded(os.path.splitext(data)[0] + '_enc.json')
    elif isinstance(enc_data, str):  # Load encodings
      print(f'Loading encoded samples from {enc_data}')
      for s in self._iter_json_lines(enc_data, max_samples):
        x1 = [np.array(x) for x in s[0]]
        x2 = [np.array(x) for x in s[1]]
        self.samples.append([x1, x2])
    else:  # already-encoded samples, used as is
      self.samples = enc_data
    print(f'samples: {len(self.samples)}')

  @staticmethod
  def _open(path):
    """Open ``path`` as UTF-8 text, transparently gunzipping ``.gz`` files."""
    if os.path.splitext(path)[1] == '.gz':
      return gzip.open(path, 'rt', encoding='utf-8')
    return open(path, encoding='utf-8')

  @staticmethod
  def _iter_json_lines(path, max_samples):
    """Yield up to ``max_samples`` JSON records, one per line of ``path``.

    Prints a notice when the file holds fewer than ``max_samples`` lines.
    """
    n_read = 0
    with RiaDataset._open(path) as f:
      for _ in tqdm(range(max_samples), desc='loading samples'):
        line = next(f, None)
        if line is None:
          print(f'max_samples {max_samples} > len dataset {n_read}')
          return
        n_read += 1
        yield json.loads(line)

  @staticmethod
  def _is_valid(record):
    """A raw record is usable only if both its text and title are non-empty."""
    return len(record['text']) != 0 and len(record['title']) != 0

  def _save_encoded(self, path):
    """Write ``self.samples`` as gzip-compressed JSON lines to ``path + '.gz'``."""
    comp_path = path + '.gz'
    print(f'Saving encoded samples to {comp_path}')
    with gzip.open(comp_path, 'wt', encoding='utf-8') as f:
      for s in tqdm(self.samples):
        # numpy arrays are not JSON-serializable; store them as lists
        f.write(json.dumps(s, default=lambda x: x.tolist()) + '\n')

  def __len__(self):
    return len(self.samples)

  def __getitem__(self, i):
    return self.samples[i]

  def encode(self, s):
    """Map a whitespace-tokenized sentence to ids ``<sos> ... <eos>``.

    Words outside the vocabulary become ``<unk>``.
    """
    unk = self.stoi['<unk>']
    ids = [self.stoi.get(x, unk) for x in s.split()]
    return np.array([self.stoi['<sos>'], *ids, self.stoi['<eos>']])

  def decode(self, s):
    """Map a sequence of ids back to their tokens."""
    return [self.itos[x] for x in s]

class Collate():
  """Collate (text, title) samples into padded ``[bx1, bx2]`` id batches.

  For each sample, text sentences ``1 .. n_x1`` are taken, skipping sentence 0
  (usually date and place); a text with a single sentence falls back to that
  sentence. Every taken text sentence is paired with every title sentence, so
  one sample can contribute several rows.

  Args:
    n_x1: maximum number of text sentences taken per sample (>= 1 assumed).
    padding_value: id used to right-pad both ``bx1`` and ``bx2``.
  """
  def __init__(self, n_x1=1, padding_value=0):
    self.n_x1 = n_x1
    self.padding_value = padding_value

  def __call__(self, batch):
    """Return ``[bx1, bx2]``, each a [rows, max_len] tensor padded with
    ``padding_value``; row ``r`` of ``bx1`` is paired with row ``r`` of ``bx2``."""
    bx1, bx2 = [], []
    for x1s, x2s in batch:
      # skip first sentence (usually date and place)
      end = min(len(x1s), 1 + self.n_x1)
      start = 1 if end > 1 else 0
      for x1 in x1s[start:end]:
        for x2 in x2s:
          bx1.append(torch.as_tensor(x1))
          bx2.append(torch.as_tensor(x2))
    pad = torch.nn.utils.rnn.pad_sequence
    return [pad(bx1, batch_first=True, padding_value=self.padding_value),
            pad(bx2, batch_first=True, padding_value=self.padding_value)]

# Create vocabulary and encodings  # max_samples=1003869 max_vocab=50000
# ds = RiaDataset(data=data_path, max_vocab=50000, max_samples=2003869, 
#                 vocab=None, enc_data=None)
# Load vocabulary and encodings
# (reads the precomputed vocab_path / enc_path files, so no raw data is needed)
ds = RiaDataset(data=None, max_vocab=50000, max_samples=1003869, 
                vocab=vocab_path, enc_data=enc_path) 
# unique words: 601956, total words: 198199252
# vocabulary: 50004, words in vocabulary: 194385367

# Seed the global torch RNG before random_split below so the
# train/test/val partition is reproducible across runs.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Debug loop: decode collated batches one by one to eyeball the pairs.
# dl = DataLoader(ds, batch_size=1, num_workers=1, shuffle=False, 
#                 drop_last=False, collate_fn=Collate())
# for s in tqdm(dl, desc='dataset'):
#   bx1, bx2 = s
#   assert bx1.size()[0] == bx2.size()[0]
#   print(bx1.size(), bx2.size())
#   for x1, x2 in zip(bx1, bx2):
#     print(ds.decode(x1), ds.decode(x2))
#   print([ds.decode(x) for x in bx1])
#   print([ds.decode(x) for x in bx2])
# 70% train / 20% test / remainder (~10%) validation, by number of samples.
train_len = int(0.7*len(ds))
test_len = int(0.2*len(ds))
val_len = len(ds) - train_len - test_len
lens = [train_len, test_len, val_len]
print(lens, sum(lens), len(ds))  # sanity check: sum(lens) == len(ds)
train_ds, test_ds, val_ds = random_split(ds, lens)
# Collate(n_x1=1) pairs one text sentence of each sample with every title
# sentence, so a collated batch may hold more rows than batch_size.
train_dl = DataLoader(train_ds, batch_size=2, num_workers=1, 
                      shuffle=True, drop_last=False, 
                      collate_fn=Collate(n_x1=1))
test_dl = DataLoader(test_ds, batch_size=3, num_workers=1, 
                     shuffle=False, drop_last=False, 
                     collate_fn=Collate(n_x1=1))
val_dl = DataLoader(val_ds, batch_size=3, num_workers=1, 
                    shuffle=False, drop_last=False, 
                    collate_fn=Collate(n_x1=1))
# Debug loops: check that every collated batch has matching x1/x2 row counts.
# for s in tqdm(train_dl, desc='train'):
#   bx1, bx2 = s
#   assert bx1.size()[0] == bx2.size()[0]
#   # print(bx1.size(), bx2.size())
# for s in tqdm(test_dl, desc='test'):
#   bx1, bx2 = s
#   assert bx1.size()[0] == bx2.size()[0]
#   # print(bx1.size(), bx2.size())
# for s in tqdm(val_dl, desc='val'):
#   bx1, bx2 = s
#   assert bx1.size()[0] == bx2.size()[0]
#   # print(bx1.size(), bx2.size())
gc.collect()  # release garbage left over from loading the dataset

In [0]:
# %matplotlib inline
# import matplotlib.pyplot as plt
# import seaborn as sns
# import pandas as pd

# max_vocab = 1000
# with open(vocab_path) as f:
#   vocab = json.load(f)
# top_words = dict([(w, c) for w, c in sorted(vocab.items(), key=lambda x: x[1],
#                                             reverse=True)][:max_vocab])
# # # sns.distplot(list(map(len, vocab.keys())), kde=False)
# # # plt.figure()
# # # sns.distplot(list(map(len, top_words.keys())), kde=True)
# # # plt.figure()
# sns.distplot(list(vocab.values()), kde=False, hist_kws={'log': True})
# sns.distplot(list(top_words.values()), kde=False, hist_kws={'log': True})
# ax = sns.distplot(list(vocab.values()), kde=False, hist=True, 
#              hist_kws={'log': True}, norm_hist=False)
# # ax.set_xscale('log')
# ax = sns.distplot(list(top_words.values()), kde=False, hist=True, 
#              hist_kws={'log': True}, norm_hist=False)
# # ax.set_xscale('log')
# plt.figure()
# plt.title('Sentences lengths')
# sns.kdeplot([len(y) for x in ds.samples for y in x[0]], 
#             label='texts', shade=True, bw=2)
# sns.kdeplot([len(y) for x in ds.samples for y in x[1]], 
#             label='titles', shade=True, bw=2)
# plt.legend()
# plt.figure()
# sns.jointplot(x='titles lengths', y='texts lengths', data=pd.DataFrame({
#     "texts lengths": [sum(map(len, x[0])) for x in ds.samples],
#     "titles lengths": [sum(map(len, x[1])) for x in ds.samples]}
# ), kind='kde', n_levels=10)
# sns.jointplot(x='titles 1 sentense length', y='texts 1 sentense length',
#               data=pd.DataFrame({
#     "texts 1 sentense length": [len(x[0][0]) for x in ds.samples],
#     "titles 1 sentense length": [len(x[1][0]) for x in ds.samples]}
# ), kind='kde', n_levels=10)

# Model

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class EncoderRNN(nn.Module):
  """Embedding + RNN/GRU/LSTM encoder over padded token-id batches.

  Args:
    num_embeddings, embedding_dim, padding_idx: ``nn.Embedding`` settings.
      ``padding_idx`` is also the id treated as padding when ``pack`` is set
      (0 when ``padding_idx`` is None).
    rnn_type: ``'RNN'``, ``'GRU'`` or ``'LSTM'``.
    hidden_size, num_layers, bidirectional: recurrent layer settings.
    rnn_dropout: dropout between stacked recurrent layers.
    dropout: dropout applied to the final hidden (and cell) state.
    pack: run the RNN on a packed sequence so padding steps are skipped; the
      per-step outputs are then returned as a ``PackedSequence``.
  """
  def __init__(self, num_embeddings, embedding_dim, padding_idx,
               rnn_type, hidden_size, num_layers=1, rnn_dropout=0,
               bidirectional=False, dropout=0, pack=False):
    super().__init__()
    rnn_map = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
    self.embedding = nn.Embedding(num_embeddings=num_embeddings,
                                  embedding_dim=embedding_dim,
                                  padding_idx=padding_idx)
    self.rnn = rnn_map[rnn_type](input_size=embedding_dim,
                                 hidden_size=hidden_size,
                                 num_layers=num_layers,
                                 batch_first=True,
                                 dropout=rnn_dropout,
                                 bidirectional=bidirectional)
    self.dropout = nn.Dropout(dropout)
    self.pack = pack

  def _pad_id(self):
    """Id counted as padding: the embedding's ``padding_idx``, else 0."""
    pad = self.embedding.padding_idx
    return 0 if pad is None else pad

  def forward(self, x1):  # [B, L]
    """Encode ``x1`` ([B, L] token ids).

    Returns ``(ht, hn)`` for RNN/GRU or ``(ht, (hn, cn))`` for LSTM, where
    ``ht`` is [B, L, ND*H] (a ``PackedSequence`` when ``pack``) and ``hn`` /
    ``cn`` are [NL*ND, B, H] (ND = 2 if bidirectional else 1, NL = layers).
    """
    x = self.embedding(x1)  # [B, L] -> [B, L, E]
    if self.pack:  # [B, L, E] -> Packed [B, L, E]
      # Length = number of non-padding ids (assumes right padding and no
      # all-padding rows). pack_padded_sequence needs CPU int64 lengths,
      # even when x lives on the GPU.
      lengths = torch.sum(x1 != self._pad_id(), dim=1).cpu()
      x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    # [B, L, E] -> (Packed) [B, L, ND*H], ([NL*ND, B, H], ([NL*ND, B, H]))
    if isinstance(self.rnn, nn.LSTM):
      ht, (hn, cn) = self.rnn(x)
      return ht, (self.dropout(hn), self.dropout(cn))
    ht, hn = self.rnn(x)
    return ht, self.dropout(hn)


class AttentionDecoderRNN(nn.Module):
  """RNN/GRU/LSTM decoder with optional attention over encoder outputs.

  Args:
    num_embeddings, embedding_dim, padding_idx: ``nn.Embedding`` settings;
      ``padding_idx`` (0 when None) also marks padding in ``x1`` / ``x2``.
    rnn_type, hidden_size, num_layers, rnn_dropout, bidirectional: recurrent
      layer settings. A bidirectional output is averaged over directions.
    dropout: dropout before the output layers.
    out_hidden: if > 0, size of an extra linear layer before the vocabulary
      projection.
    attn_type: ``'dot'`` / ``'cos'`` (raw scores, padding set to 0),
      ``'dist'`` (L1-normalized inverse Euclidean distance), ``'soft_dot'`` /
      ``'soft_cos'`` (softmax), ``'soft_dist'`` (softmin) or ``'none'``.
    pack: decoder input and encoder outputs are packed sequences.

  Raises:
    ValueError: for an unknown ``attn_type``.
  """
  def __init__(self, num_embeddings, embedding_dim, padding_idx,
               rnn_type, hidden_size, num_layers=1, rnn_dropout=0,
               bidirectional=False, dropout=0, out_hidden=0,
               attn_type='soft_dot', pack=False):
    super().__init__()
    rnn_map = {'RNN': nn.RNN, 'GRU': nn.GRU, 'LSTM': nn.LSTM}
    attn_types = ['dot', 'cos', 'dist', 'soft_dot', 'soft_cos', 'soft_dist', 'none']
    if attn_type not in attn_types:
      raise ValueError(f'Unknown attn_type {attn_type!r}; expected one of {attn_types}')
    self.attn_type = attn_type
    self.embedding = nn.Embedding(num_embeddings=num_embeddings,
                                  embedding_dim=embedding_dim,
                                  padding_idx=padding_idx)
    self.rnn = rnn_map[rnn_type](input_size=embedding_dim,
                                 hidden_size=hidden_size,
                                 num_layers=num_layers,
                                 batch_first=True,
                                 dropout=rnn_dropout,
                                 bidirectional=bidirectional)
    # Attention concatenates [decoder state, attended encoder state].
    out_input = hidden_size if self.attn_type == 'none' else 2*hidden_size
    if out_hidden > 0:
      self.out_hidden = nn.Linear(out_input, out_hidden)
      self.out = nn.Linear(out_hidden, num_embeddings)
    else:
      self.out_hidden = None
      self.out = nn.Linear(out_input, num_embeddings)
    self.dropout = nn.Dropout(dropout)
    self.softmax = nn.LogSoftmax(dim=2)
    self.pack = pack

  def _pad_id(self):
    """Id counted as padding: the embedding's ``padding_idx``, else 0."""
    pad = self.embedding.padding_idx
    return 0 if pad is None else pad

  @staticmethod
  def _merge_directions(h):
    """Average both directions: [B, L, 2*H] -> [B, L, 2, H] -> [B, L, H]."""
    B, L, NDH = h.size()
    return h.view(B, L, 2, NDH // 2).mean(2)

  def _project(self, h):
    """Dropout -> optional hidden layer -> vocabulary log-probabilities."""
    x = self.dropout(h)
    if self.out_hidden is not None:
      x = self.out_hidden(x)
    return self.softmax(self.out(x))

  def _attention(self, ht2, ht1, pad_mask):
    """Weights [B, L2, L1] of decoder states ``ht2`` over encoder ``ht1``."""
    kind = self.attn_type
    if kind in ('dist', 'soft_dist'):
      # Euclidean distance; padding pushed to +inf so it gets zero weight.
      attn = torch.cdist(ht2.contiguous(), ht1.contiguous())
      attn = torch.masked_fill(attn, pad_mask, float('inf'))
      if kind == 'soft_dist':
        return F.softmin(attn, 2)
      attn = F.threshold(attn, threshold=1e-6, value=1e-6)  # avoid 1/0
      return F.normalize(1 / attn, p=1, dim=2)  # inverse distance, to [0, 1]
    if kind in ('cos', 'soft_cos'):
      # Unit-length states turn the dot product into cosine similarity.
      ht2 = F.normalize(ht2, p=2, dim=2)
      ht1 = F.normalize(ht1, p=2, dim=2)
    attn = torch.einsum('bih,bjh->bij', ht2, ht1)  # dot product
    if kind.startswith('soft_'):
      return F.softmax(attn.masked_fill(pad_mask, float('-inf')), 2)
    return attn.masked_fill(pad_mask, 0)

  def forward(self, x2, h1, x1):
    """Decode ``x2`` conditioned on the encoder output ``h1``.

    Args:
      x2: decoder input ids, [B, L2].
      h1: encoder output ``(ht1, hn1)``; ``ht1`` is [B, L1, ND*H] (packed when
        ``pack``) and ``hn1`` is used as the decoder RNN's initial state.
      x1: encoder input ids, [B, L1]; padding positions get no attention.

    Returns:
      ``(y2, attn)``: log-probabilities [B, L2, num_embeddings] and attention
      weights [B, L2, L1] (``None`` when ``attn_type == 'none'``).
    """
    pad = self._pad_id()
    x = self.embedding(x2)  # [B, L] -> [B, L, E]
    if self.pack:
      # pack_padded_sequence needs CPU int64 lengths, even on the GPU.
      lengths = torch.sum(x2 != pad, dim=1).cpu()
      x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    ht1, hns1 = h1
    ht2, hns2 = self.rnn(x, hns1)
    if self.pack:
      # Unpack to the full input widths so outputs line up with x1/x2 (and
      # pad_mask) even when every row of the batch ends in padding.
      ht1, _ = pad_packed_sequence(ht1, batch_first=True, total_length=x1.size(1))
      ht2, _ = pad_packed_sequence(ht2, batch_first=True, total_length=x2.size(1))
    if self.rnn.bidirectional:
      ht2 = self._merge_directions(ht2)
    if self.attn_type == 'none':
      return self._project(ht2), None
    if self.rnn.bidirectional:
      ht1 = self._merge_directions(ht1)
    # Mask encoder positions that hold padding: [B, L1] -> [B, L2, L1]
    pad_mask = (x1 == pad).unsqueeze(1).expand(-1, x2.size(1), -1)
    attn = self._attention(ht2, ht1, pad_mask)
    hw1 = torch.einsum('bij,bjh->bih', attn, ht1)  # weighted h1 [B, L2, H]
    ha = torch.cat((ht2, hw1), 2)  # [B, L2, H+H]
    return self._project(ha), attn


class EncoderDecoder(nn.Module):
  """Sequence-to-sequence model: ``EncoderRNN`` feeding ``AttentionDecoderRNN``.

  ``enc_*`` / ``dec_*`` arguments configure the encoder / decoder embeddings
  and dropouts. ``rnn_type``, ``hidden_size``, ``num_layers``,
  ``bidirectional`` and ``pack`` are shared by both halves, because the
  encoder's final state is the decoder RNN's initial state. ``out_hidden``
  and ``attn_type`` configure the decoder only.
  """
  def __init__(self,
               enc_num_embeddings, enc_embedding_dim, enc_padding_idx,
               dec_num_embeddings, dec_embedding_dim, dec_padding_idx,
               rnn_type, hidden_size, num_layers=1, out_hidden=0,
               enc_rnn_dropout=0, dec_rnn_dropout=0,
               bidirectional=False, enc_dropout=0, dec_dropout=0,
               attn_type='dot', pack=False):
    super().__init__()
    shared = dict(rnn_type=rnn_type, hidden_size=hidden_size,
                  num_layers=num_layers, bidirectional=bidirectional,
                  pack=pack)
    self.encoder = EncoderRNN(num_embeddings=enc_num_embeddings,
                              embedding_dim=enc_embedding_dim,
                              padding_idx=enc_padding_idx,
                              dropout=enc_dropout,
                              rnn_dropout=enc_rnn_dropout,
                              **shared)
    self.decoder = AttentionDecoderRNN(num_embeddings=dec_num_embeddings,
                                       embedding_dim=dec_embedding_dim,
                                       padding_idx=dec_padding_idx,
                                       dropout=dec_dropout,
                                       rnn_dropout=dec_rnn_dropout,
                                       out_hidden=out_hidden,
                                       attn_type=attn_type,
                                       **shared)

  def forward(self, x1, x2):
    """Return the decoder's ``(log-probabilities, attention)`` for source
    ids ``x1`` [B, L1] and decoder input ids ``x2`` [B, L2]."""
    encoded = self.encoder(x1)
    return self.decoder(x2, encoded, x1)


def external_attn(x1, x2, attn_dict=None):
  """Build a fixed attention map from a token-alignment dictionary.

  Decoder step ``j`` is aligned through the token it should predict next,
  ``x2[:, j + 1]`` (the last step has no next token and uses 0). Encoder
  position ``k`` gets weight 1 when ``x1[:, k]`` is listed in
  ``attn_dict[next token]``, and each row is then L1-normalized (all-zero
  rows stay zero).

  Args:
    x1: encoder token ids, [B, L1].
    x2: decoder token ids, [B, L2].
    attn_dict: ``{decoder token: [encoder tokens]}``; defaults to the toy
      mapping ``{1: [1], 2: [2], 4: [4, 3]}``.

  Returns:
    Float tensor [B, L2, L1] on ``x1``'s device.
  """
  if attn_dict is None:
    attn_dict = {1: [1], 2: [2], 4: [4, 3]}
  allowed = {k: set(v) for k, v in attn_dict.items()}
  B, L1 = x1.size()
  L2 = x2.size(1)
  attn = []
  for row1, row2 in zip(x1.tolist(), x2.tolist()):
    steps = []
    for j in range(L2):
      target = row2[j + 1] if j + 1 < L2 else 0  # decoder shift
      hits = allowed.get(target, ())
      steps.append([1. if t in hits else 0. for t in row1])
    attn.append(steps)
  attn = torch.tensor(attn, dtype=torch.float, device=x1.device).reshape(B, L2, L1)
  return F.normalize(attn, p=1, dim=2)  # to [0, 1]

from IPython.display import Image
# Image(make_dot(loss).render('loss', format='png'))
from tqdm.notebook import tqdm
from torch import optim
import numpy as np

def plot_attention(a, x1, x2, shift=True, mask=True,
                   suptitle='Attention', figsize=None,
                   decoder_x1=None, decoder_x2=None, tight=False,
                   labels=True):
  """Draw one attention heat-map per batch element, side by side.

  Args:
    a: attention weights indexed as [B, L2, L1] (e.g. ``attn.cpu().numpy()``).
    x1, x2: id arrays [B, L1] and [B, L2] used for the tick labels.
    shift: drop the first x1 token with its attention column, and label each
      decoder step with the x2 token that follows it (first x2 token and last
      attention row dropped).
    mask: keep only positions whose id is non-zero.
    suptitle, figsize: figure title and size.
    decoder_x1, decoder_x2: optional callables turning an id array into tick
      labels (e.g. ``ds.decode``).
    tight: call ``plt.tight_layout()`` at the end.
    labels: name the axes ``x1`` / ``x2``.
  """
  # %matplotlib inline
  import matplotlib.pyplot as plt
  from mpl_toolkits.axes_grid1 import make_axes_locatable

  n_plots = a.shape[0]
  fig, axs = plt.subplots(1, n_plots, figsize=figsize)
  # plt.subplots hands back a bare Axes (not an array) for a single plot.
  axs = axs if isinstance(axs, np.ndarray) else [axs]
  fig.suptitle(suptitle)
  for i, ax in zip(range(n_plots), axs):
    if shift:
      att, ids1, ids2 = a[i, :-1, 1:], x1[i, 1:], x2[i, 1:]
    else:
      att, ids1, ids2 = a[i], x1[i], x2[i]
    if mask:
      keep1 = np.flatnonzero(ids1)
      keep2 = np.flatnonzero(ids2)
      att = att[keep2, :][:, keep1]
      ids1, ids2 = ids1[keep1], ids2[keep2]
    im = ax.imshow(att, cmap='gray')
    ax.set_xticks(np.arange(len(ids1)))
    ax.set_yticks(np.arange(len(ids2)))
    if labels:
      ax.set_xlabel('x1')
      ax.set_ylabel('x2')
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    xtick_labels = ids1 if decoder_x1 is None else decoder_x1(ids1)
    ytick_labels = ids2 if decoder_x2 is None else decoder_x2(ids2)
    ax.set_xticklabels(xtick_labels, rotation=90)
    ax.set_yticklabels(ytick_labels)
    fig.colorbar(im, ax=ax, fraction=0.05, pad=0.05)
  if tight:
    plt.tight_layout()


# torch.backends.cudnn.enabled=False
# torch.backends.cudnn.deterministic=True
# torch.autograd.set_detect_anomaly(True)

# enc_dec_config = {
#   'enc_num_embeddings': 6,
#   'enc_embedding_dim': 2,
#   'enc_padding_idx': 0,
#   'dec_num_embeddings': 6,
#   'dec_embedding_dim': 2,
#   'dec_padding_idx': 0,
#   'rnn_type': 'RNN',
#   'hidden_size': 2,
#   'num_layers': 1,
#   'bidirectional': False,
#   'enc_rnn_dropout': 0,
#   'dec_rnn_dropout': 0,
#   'enc_dropout': 0,
#   'dec_dropout': 0,
#   'attn_type': 'soft_dist',
#   'out_hidden': 16,
#   'pack': True
# }
# # <pad>, <unk>, <go>, <eos>
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# seed = 0
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# model = EncoderDecoder(**enc_dec_config).to(device)
# x1 = torch.tensor([[2, 1, 4, 3, 0], [2, 1, 4, 1, 3]]).to(device)
# x2 = torch.tensor([[2, 4, 1, 5, 3, 0], [2, 4, 5, 1, 1, 3]]).to(device)
# opt = optim.Rprop(model.parameters(), lr=0.01)
# loss_fn = torch.nn.NLLLoss(ignore_index=0)  # 0 is <PAD>
# pbar = tqdm(range(200))
# for i in pbar:
#   opt.zero_grad()
#   t = x2[:,1:].flatten(start_dim=0)
#   y2, attn = model(x1, x2)
#   p = y2[:,:-1,:].flatten(start_dim=0, end_dim=1)
#   loss = loss_fn(p, t)
#   loss.backward()
#   opt.step()
#   with torch.no_grad():
#     idx = torch.nonzero(t).view(-1)
#     acc = torch.sum(torch.argmax(p, 1)[idx] == t[idx]).float()/idx.size()[0]
#     ps = torch.argmax(y2[:,:-1,:], 2).detach().cpu().numpy()
#     ts = x2[:,1:].detach().cpu().numpy()
#     for sp, st in zip(ps, ts):
#       t_set = set(st) - {0, 1, 2, 3}
#       eos = sp[sp == 3][0] if 3 in sp else len(sp)
#       p_eos = sp[:eos]
#       p_set = set(p_eos) - {0, 1, 2, 3}
#       i_set = t_set.intersection(p_set)
#       pre = len(i_set)/len(p_set) if len(p_set) != 0 else 0
#       rec = len(i_set)/len(t_set) if len(t_set) != 0 else 0
#       f1 = 2*pre*rec/(pre + rec) if pre + rec != 0 else 0
#       # print(st, sp, p_eos)
#       # print(t_set, p_set)
#       # print(f'precision: {pre}, recall: {rec}, f1: {f1}')
#     if i % 25 == 0:
#       # print(attn)
#       # print(p)
#       # print(torch.argmax(p, 1))
#       # print(t)
#       if attn is not None:
#          plot_attention(attn.detach().cpu().numpy(), 
#                         x1.detach().cpu().numpy(), 
#                         x2.detach().cpu().numpy(),
#                         labels=True, tight=False,
#                         shift=False, mask=False, figsize=(10, 5))
#   pbar.set_description(f'loss: {loss:.3f}, acc: {acc:.3f}')
#   if acc == 1:
#     break
# # Image(make_dot(attn).render('attn', format='png'))

# Train

## GPU/CPU stats

In [0]:
!pip install gputil  # for monitoring
!pip install psutil  # for monitoring
!pip install humanize  # for monitoring

!pip install optuna  # for train

!pip install rouge_score  # for metrics

In [0]:
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
import subprocess
# Show the GPU status table (the output is printed even if the command fails).
print(subprocess.getoutput('nvidia-smi'))

In [0]:
# Import packages
import os, sys, humanize, psutil, GPUtil

# Define function
def mem_report(verbose=True):
  """Report available host RAM and per-GPU memory.

  Args:
    verbose: print a human-readable line for the CPU and for each GPU.

  Returns:
    ``(cmf, gmfs, gmts, gmus)``: the available host memory as reported by
    ``psutil``, then per-GPU lists of free memory, total memory (GPUtil
    values, scaled by 1e6 for display, i.e. treated as MB) and memory
    utilization fraction.
  """
  fmt = humanize.naturalsize
  cpu_free = psutil.virtual_memory().available
  cpu_total = psutil.virtual_memory().total
  if verbose:
    print(f"CPU mem Free: {fmt(cpu_free)} / {fmt(cpu_total)}")
  free_mb, total_mb, utils = [], [], []
  for idx, gpu in enumerate(GPUtil.getGPUs()):
    free, total, util = gpu.memoryFree, gpu.memoryTotal, gpu.memoryUtil
    free_mb.append(free)
    total_mb.append(total)
    utils.append(util)
    if verbose:
      print(f'GPU {idx} mem free: {fmt(free*1000000)} / {fmt(total*1000000)} util: {int(util*100)}%')
  return cpu_free, free_mb, total_mb, utils
mem_report()

## Beam

In [0]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns

# Reduction name -> torch reducer, used by the Beam / BeamBkw decoders;
# 'none' maps to None, meaning "do not reduce".
reduction_map = dict(sum=torch.sum, mean=torch.mean, none=None)

class Forced():
  """Teacher-forced decoding: a single model call on the full decoder input."""
  def __init__(self, **kwargs):
    """Accept and ignore any keyword options."""

  def __call__(self, model, bx1, bx2):
    """Return the argmax ids of every step but the last, plus the model's
    attention output."""
    log_probs, attn = model(bx1, bx2)
    predicted = log_probs[:, :-1, :].argmax(dim=2)
    return predicted, attn

class Greedy():
  """Greedy autoregressive decoding.

  Starting from ``sos``, the growing prefix is fed to the model and the argmax
  of the last step is appended, until every row contains ``eos`` or
  ``max_len`` tokens were generated. Rows that finished early keep receiving
  tokens until the whole batch is done.
  """
  def __init__(self, sos=2, eos=3, max_len=30, **kwargs):
    self.sos, self.eos, self.max_len = sos, eos, max_len
    # Kept for compatibility; decoding itself follows the input's device.
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  def __call__(self, model, bx1, bx2=None):
    """Decode ``bx1`` ([B, L1] ids); ``bx2`` is ignored.

    Returns:
      ``(tokens, attn)``: generated ids [B, T] without the ``sos`` column and
      the attention from the last model call (``None`` if none was made).
    """
    batch_size = bx1.size(0)  # input batch_size
    # Prefix of <sos> on bx1's device so model inputs never mix devices.
    bx2 = torch.full((batch_size, 1), self.sos, dtype=bx1.dtype, device=bx1.device)
    attn = None  # stays None when the loop below never runs (max_len <= 0)
    # stop when predictions len > max_len or all have <eos> token
    while bx2.size(1) - 1 < self.max_len and not torch.all(torch.any(bx2 == self.eos, dim=1)):
      by2p, attn = model(bx1, bx2)
      next_bx2 = torch.argmax(by2p[:, -1, :], dim=1, keepdim=True)
      bx2 = torch.cat((bx2, next_bx2), 1)
    return bx2[:, 1:], attn

class Beam():
  """Batched beam search with multi-token look-ahead.

  Each outer step expands every unfinished prefix ``beam_depth`` tokens ahead:
  ``first_beam_width`` candidates at the first depth and ``beam_width`` at
  the following ones. Each expanded beam is scored by reducing its per-token
  log-probabilities with ``depth_reduction``. The step then keeps the best
  beam per input sample, or, with ``batch_reduction``, the single best beam of
  the whole batch copied to every sample. Decoding stops when every row holds
  ``eos`` or ``max_len`` tokens were generated.

  Args:
    sos, eos: start / end token ids.
    max_len: maximum number of generated tokens.
    beam_width, first_beam_width: candidates kept per beam at depth > 0 / 0.
    beam_depth: look-ahead depth of every outer step.
    depth_reduction: ``'sum'`` or ``'mean'`` of a beam's token scores.
      ``'none'`` is rejected, because every beam needs a scalar score.
    beam_reduction, batch_reduction: ``'sum'`` / ``'mean'`` pooling of
      predictions across sibling beams / across the batch, or ``'none'``.
    max_input_size: above this many input ids (``B*L1 + B*L2``), the model
      is called on chunks of the batch.
    verbose: print batch sizes at every depth.
    plot_dist, plot_size: plot (seaborn) the top ``plot_size`` scores.

  Raises:
    ValueError: if ``depth_reduction`` is ``'none'``.
  """
  def __init__(self, sos=2, eos=3, max_len=30, beam_width=2, beam_depth=2,
               depth_reduction='sum', beam_reduction='none',
               batch_reduction='none', max_input_size=30000,
               first_beam_width=10, verbose=0,
               plot_dist=False, plot_size=100, **kwargs):
    self.sos, self.eos, self.max_len = sos, eos, max_len
    self.beam_width, self.beam_depth = beam_width, beam_depth
    self.depth_reduction = reduction_map[depth_reduction]
    if self.depth_reduction is None:
      # Beams are ranked by one number each; 'none' previously crashed later
      # with "'NoneType' object is not callable".
      raise ValueError("depth_reduction must be 'sum' or 'mean', got 'none'")
    self.beam_reduction = reduction_map[beam_reduction]
    self.batch_reduction = reduction_map[batch_reduction]
    # Kept for compatibility; the search itself follows the input's device.
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.max_input_size = max_input_size
    self.first_beam_width = first_beam_width
    self.plot_size = plot_size
    self.plot_dist = plot_dist
    self.verbose = verbose

  def __call__(self, model, bx1, bx2=None):
    """Decode ``bx1`` ([B, L1] ids); ``bx2`` is ignored.

    Returns:
      ``(tokens, attn)``: generated ids [B, T] without the ``sos`` column and
      the attention from the last model call (``None`` if none was made).
    """
    batch_size = bx1.size(0)  # input batch_size
    device = bx1.device  # build every search tensor next to the inputs
    attn = None  # returned as is when no model call happens (e.g. max_len == 0)
    bx2 = torch.full((batch_size, 1), self.sos, dtype=bx1.dtype).to(device)  # batch with <sos>
    # stop when len decoder output > max_len or all decoder outputs have <eos> token
    while bx2.size(1) - 1 < self.max_len and not torch.all(torch.any(bx2 == self.eos, axis=1)):
      beam_scores = torch.empty((batch_size, 0)).to(device)  # scores for each beam
      bx1t = bx1 if batch_size == 1 else bx1.clone()
      for i in range(self.beam_depth):
        eos_mask = torch.any(bx2 == self.eos, axis=1)  # beams with <eos>
        cur_batch_size = torch.sum(~eos_mask)  # beams without <eos>
        cur_input_size = bx1t.numel() + bx2.numel()  # B*L1 + B*L2
        cur_beam_width = self.beam_width if i > 0 else self.first_beam_width
        if self.verbose:
          print(f'{i} {bx1t.size()}, {bx2.size()}, {cur_batch_size}, {cur_input_size}')
        if bx2.size(1) - 1 >= self.max_len or cur_batch_size == 0:
          break
        if cur_input_size > self.max_input_size:  # split into parts
          n_parts = cur_input_size // self.max_input_size + 1
          bx1t_parts = torch.chunk(bx1t[~eos_mask], n_parts)
          bx2_parts = torch.chunk(bx2[~eos_mask], n_parts)
          masked_by2p = torch.empty((0, cur_beam_width)).to(device)
          masked_bx2 = torch.empty((0, cur_beam_width), dtype=bx2.dtype).to(device)
          if self.plot_dist:
            fig, axs = plt.subplots(ncols=2)
          for j, (bx1t_p, bx2_p) in enumerate(zip(bx1t_parts, bx2_parts)):
            if self.verbose:
              print(f'{j+1}/{n_parts} {bx1t_p.size()}, {bx2_p.size()}, {bx1t_p.numel() + bx2_p.numel()}')
            by2p_p, attn = model(bx1t_p, bx2_p)
            masked_by2p_p, masked_bx2_p = torch.topk(by2p_p[:,-1,:],
                                                     cur_beam_width)
            masked_by2p = torch.cat((masked_by2p, masked_by2p_p))
            masked_bx2 = torch.cat((masked_bx2, masked_bx2_p))
            if self.plot_dist:
              sns.distplot(torch.topk(by2p_p[:,-1,:], self.plot_size)[0],
                           hist=True, kde=False, label=f'{j+1}/{n_parts}',
                           ax=axs[1])
              for k in range(by2p_p.size(0)):
                sns.distplot(torch.topk(by2p_p[k,-1,:], self.plot_size)[0],
                             hist=True, kde=False, ax=axs[0])
          if self.plot_dist:
            fig.suptitle(f'step {bx2.size(1)} top {self.plot_size} scores')
            axs[0].set_title(f'by beam')
            axs[1].set_title(f'by part')
            plt.legend()
        else:
          by2p, attn = model(bx1t[~eos_mask], bx2[~eos_mask])  # predict
          if i > 0 and self.beam_reduction is not None:  # beam reduction
            if i == 1:
              prev_prev_batch_size = batch_size
            else:
              prev_prev_batch_size = batch_size*self.first_beam_width*self.beam_width**(i-1)
            by2p = by2p.view(prev_prev_batch_size, cur_beam_width, by2p.size(1), -1)  # split predictions into batches
            by2p = self.beam_reduction(by2p, axis=1) # cumulative beams predictions
            by2p = by2p.repeat(cur_beam_width, 1, 1)  # return to prev batch size
          if self.batch_reduction is not None:  # batch reduction
            if i > 0:
              prev_batch_size = batch_size*self.first_beam_width*self.beam_width**(i-1)
            else:
              prev_batch_size = batch_size
            by2p = self.batch_reduction(by2p, axis=0) # cumulative batch predictions
            masked_by2p, masked_bx2 = torch.topk(by2p[-1,:], cur_beam_width)  # beams to top k last predictions
            masked_bx2 = masked_bx2.repeat(prev_batch_size, 1)  # return to prev batch size
            masked_by2p = masked_by2p.repeat(prev_batch_size, 1)  # return to prev batch size
          else:
            if self.plot_dist:
              fig, axs = plt.subplots(ncols=2)
              sns.distplot(torch.topk(by2p[:,-1,:], self.plot_size)[0],
                           hist=True, kde=False, ax=axs[1])
              for k in range(by2p.size(0)):
                sns.distplot(torch.topk(by2p[k,-1,:], self.plot_size)[0],
                             hist=True, kde=False, ax=axs[0])
              fig.suptitle(f'step {bx2.size(1)} top {self.plot_size} scores')
              axs[0].set_title(f'by beam')
              axs[1].set_title(f'cumulative')
            masked_by2p, masked_bx2 = torch.topk(by2p[:,-1,:], cur_beam_width)  # beams to top k last predictions
        # Rows that already hold <eos> get placeholder candidates: score -inf
        # and token 0. masked_by2p exists on both branches (by2p does not
        # when the batch was split into parts).
        next_by2p = torch.full((bx2.size(0), cur_beam_width),
                               float('-inf'), dtype=masked_by2p.dtype).to(device)
        next_bx2 = torch.full((bx2.size(0), cur_beam_width),
                              0, dtype=bx2.dtype).to(device)
        next_by2p[~eos_mask] = masked_by2p  # from beams without <eos>
        next_bx2[~eos_mask] = masked_bx2  # from beams without <eos>
        # update bx2 and scores
        new_batch_size = batch_size*self.first_beam_width*self.beam_width**i
        next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
        next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
        beam_scores = torch.repeat_interleave(beam_scores, cur_beam_width, 0)  # increase batch for new scores
        bx2 = torch.repeat_interleave(bx2, cur_beam_width, 0)  # increase batch for new beams
        bx2 = torch.cat((bx2, next_bx2), 1)  # add beams
        beam_scores = torch.cat((beam_scores, next_by2p), 1)  # add beams scores
        if batch_size == 1:
          bx1t = bx1.expand(new_batch_size, -1)  # increase batch for new beams
        else:
          bx1t = torch.repeat_interleave(bx1t, cur_beam_width, 0)  # increase batch for new beams
      if self.verbose:
        print(f'{bx1t.size()}, {bx2.size()}, {bx1t.numel() + bx2.numel()}')
      # -inf placeholders (steps after <eos>) are left out of the beam score.
      pad_mask = beam_scores == float('-inf')
      beam_scores[pad_mask] = 0
      if self.depth_reduction == torch.sum:
        beam_scores = torch.sum(beam_scores, dim=1)  # cumulative beams scores
      else:  # torch.mean over the real (non-placeholder) steps only
        beam_lens = torch.sum(~pad_mask, dim=1).clamp(min=1)  # avoid 0/0 -> NaN
        beam_scores = torch.sum(beam_scores, dim=1) / beam_lens
      if self.plot_dist:
        top_beams = min(self.plot_size, beam_scores.size(0))
        plt.figure()
        sns.distplot(torch.topk(beam_scores, top_beams)[0], hist=True, kde=False)
        plt.title(f'step {bx2.size(1)} top {top_beams} beams scores')
      if self.batch_reduction is None:
        beam_scores = beam_scores.view(batch_size, -1)  # split scores into batches
        best_beams = torch.argmax(beam_scores, axis=1, keepdim=True)  # best beams
        bx2 = bx2.view(batch_size, int(bx2.size(0)/batch_size), -1)  # split beams into batches
        # Gather the best beam of every sample (return to input batch_size):
        # index [B, 1] -> [B, 1, L] so gather along the beam axis picks rows.
        best_beams = best_beams.unsqueeze(2).expand(best_beams.size(0),
                                                    best_beams.size(1),
                                                    bx2.size(2))
        bx2 = torch.gather(bx2, 1, best_beams)
        bx2 = bx2.view(batch_size, -1)
      else:
        best_beams = torch.argmax(beam_scores, axis=0, keepdim=True)  # best beams
        bx2 = bx2[best_beams]  # best beam of all batches
        bx2 = bx2.repeat(batch_size, 1)  # return to input batch size
    return bx2[:,1:], attn

class BeamBkw():
  """Look-ahead beam decoder with a backward re-ranking pass.

  Each decoding round unrolls ``model`` ``beam_depth`` steps ahead, keeping the
  ``beam_width`` highest-scoring next tokens at every step, which yields
  ``beam_width**beam_depth`` candidate continuations per input row.  The
  candidates are then re-ranked position by position, from the last look-ahead
  step back to the first, keeping ``bkw_beam_width`` tokens per position; the
  continuation with the best ``depth_reduction`` score is appended to the
  output.

  Args:
    sos: id of the start-of-sequence token every output row starts with.
    eos: id of the end-of-sequence token; decoding stops once every row of
      the output contains it.
    max_len: decoding stops once at least this many tokens were generated
      (a round appends ``beam_depth`` tokens, so the result can be longer).
    beam_width: number of tokens kept per look-ahead step.
    beam_depth: number of look-ahead steps per round.
    depth_reduction, beam_reduction, batch_reduction: keys looked up in the
      module-level ``reduction_map``.  For ``beam_reduction`` and
      ``batch_reduction`` a ``None`` entry disables that reduction
      (presumably what ``'none'`` maps to — confirm against ``reduction_map``).
    bkw_beam_width: tokens kept per position in the backward pass; clamped to
      ``beam_width**beam_depth``.
    **kwargs: ignored, so a whole evaluation config dict can be passed in.
  """

  def __init__(self, sos=2, eos=3, max_len=30, beam_width=2, beam_depth=2,
               depth_reduction='sum', beam_reduction='none', 
               batch_reduction='none', bkw_beam_width=2, **kwargs):
    self.sos, self.eos, self.max_len = sos, eos, max_len
    self.beam_width, self.beam_depth = beam_width, beam_depth
    self.bkw_beam_width = bkw_beam_width
    self.depth_reduction = reduction_map[depth_reduction]
    self.beam_reduction = reduction_map[beam_reduction]
    self.batch_reduction = reduction_map[batch_reduction]
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The backward pass cannot keep more beams than the forward pass produced.
    n_fwd_beams = self.beam_width**self.beam_depth
    if self.bkw_beam_width > n_fwd_beams:
      print(f'Warning! Backward beam width {self.bkw_beam_width} greater '
            f'than beam_width**beam_depth {n_fwd_beams}')
      print(f'setting backward beam width to {n_fwd_beams}')
      self.bkw_beam_width = n_fwd_beams

  def __call__(self, model, bx1, bx2=None):
    """Decode ``bx1`` and return ``(tokens, attn)``.

    ``bx2`` is accepted for interface compatibility with the other decoders
    and is ignored: decoding always starts from a column of <sos> tokens.
    ``tokens`` is the generated batch with the leading <sos> stripped;
    ``attn`` is the second value returned by the last ``model`` call
    (``None`` if no decoding round ran).

    ``model(x1, x2)`` is expected to return ``(scores, attn)`` where
    ``scores`` is presumably (batch, len, vocab) log-probabilities — the code
    indexes it as ``[:, -1, :]`` and runs ``topk`` over the last axis.
    """
    batch_size = bx1.size(0)  # input batch_size
    bx2 = torch.full((batch_size, 1), self.sos, dtype=bx1.dtype).to(self.device)  # batch with <sos>
    attn = None  # returned unchanged if the loop below never runs
    # stop when len decoder output > max_len or all decoder outputs have <eos> token
    while bx2.size(1) - 1 < self.max_len and not torch.all(torch.any(bx2 == self.eos, axis=1)):
      # Forward pass: unroll beam_depth steps, beam_width tokens per step.
      fwd_scores = torch.empty((batch_size, 0)).to(self.device)  # scores for each beam
      fwd_bx2 = bx2.clone()
      bx1t = bx1.clone()
      for i in range(self.beam_depth):
        new_batch_size = batch_size*self.beam_width**(i+1)
        by2p, attn = model(bx1t, fwd_bx2)  # predict
        if i > 0 and self.beam_reduction is not None:  # beam reduction
          prev_prev_batch_size = batch_size*self.beam_width**(i-1)
          by2p = by2p.view(prev_prev_batch_size, self.beam_width, by2p.size(1), -1)  # split predictions into batches
          by2p = self.beam_reduction(by2p, axis=1) # cumulative beams predictions
          by2p = by2p.repeat(self.beam_width, 1, 1)  # return to prev batch size
        if self.batch_reduction is not None:
          prev_batch_size = batch_size*self.beam_width**(i)
          by2p = self.batch_reduction(by2p, axis=0) # cumulative batch predictions
          next_by2p, next_bx2 = torch.topk(by2p[-1,:], self.beam_width)  # beams to top k last predictions
          next_bx2 = next_bx2.repeat(prev_batch_size, 1)  # return to prev batch size
          next_by2p = next_by2p.repeat(prev_batch_size, 1)  # return to prev batch size
        else:
          next_by2p, next_bx2 = torch.topk(by2p[:,-1,:], self.beam_width)  # beams to top k last predictions
        next_bx2 = next_bx2.view(new_batch_size, 1)  # new beams
        next_by2p = next_by2p.view(new_batch_size, 1)  # new scores
        fwd_scores = torch.repeat_interleave(fwd_scores, self.beam_width, 0)  # increase batch for new scores
        fwd_bx2 = torch.repeat_interleave(fwd_bx2, self.beam_width, 0)  # increase batch for new beams
        fwd_bx2 = torch.cat((fwd_bx2, next_bx2), 1)  # add beams
        fwd_scores = torch.cat((fwd_scores, next_by2p), 1)  # add beams scores
        bx1t = torch.repeat_interleave(bx1t, self.beam_width, 0)  # increase batch for new beams
      # Backward pass: re-rank the candidates from the last look-ahead step
      # to the first.  The accumulators are created on the same device as the
      # forward results so the torch.cat calls below also work on CUDA.
      if self.batch_reduction is None:
        fwd_scores = fwd_scores.view(batch_size, self.beam_width**self.beam_depth, -1)  # split scores into batches
        fwd_bx2 = fwd_bx2.view(batch_size, self.beam_width**self.beam_depth, -1)  # split beams into batches
        bkw_bx2 = torch.empty((batch_size, self.beam_width**self.beam_depth, 0), dtype=bx1.dtype,
                              device=fwd_bx2.device)  # tokens for each beam
        bkw_scores = torch.empty((batch_size, self.beam_width**self.beam_depth, 0),
                                 device=fwd_scores.device)  # scores for each beam
        mask = torch.full((fwd_bx2.size(0), fwd_bx2.size(1), 1), False, dtype=torch.bool,
                          device=fwd_scores.device)
        n_beams_per_batch = self.beam_width**self.beam_depth
        for i in range(fwd_scores.size(2)):
          old_n_beams_per_batch = n_beams_per_batch*self.bkw_beam_width**(i)
          new_n_beams_per_batch = n_beams_per_batch*self.bkw_beam_width**(i+1)
          cur_values = fwd_bx2[:,:,-i-1:-i if i > 0 else None]
          cur_values = torch.repeat_interleave(cur_values, self.bkw_beam_width**i, 1)
          cur_scores = fwd_scores[:,:,-i-1:-i if i > 0 else None]
          cur_scores = torch.repeat_interleave(cur_scores, self.bkw_beam_width**i, 1)
          cur_scores[mask] = float('-inf')  # mask from prev step
          best_scores, best_beams = torch.topk(cur_scores, self.bkw_beam_width, 1)  # best beams
          best_values = torch.gather(cur_values, 1, best_beams).view(batch_size, -1)
          best_values = best_values.view(best_values.size(0), self.bkw_beam_width, -1)
          best_values = torch.repeat_interleave(best_values, old_n_beams_per_batch, 1)
          cur_values = cur_values.repeat(1, self.bkw_beam_width, 1)
          mask = cur_values != best_values
          bkw_bx2 = bkw_bx2.repeat(1, self.bkw_beam_width, 1)  # increase batch for new beams
          bkw_bx2 = torch.cat((best_values, bkw_bx2), 2)
          best_scores = torch.repeat_interleave(best_scores, old_n_beams_per_batch, 1)
          bkw_scores = bkw_scores.repeat(1, self.bkw_beam_width, 1)  # increase batch for new beams
          bkw_scores = torch.cat((best_scores, bkw_scores), 2)
        beam_scores = self.depth_reduction(bkw_scores, axis=2) # cumulative beams scores
        best_beams = torch.argmax(beam_scores, axis=1, keepdim=True)  # best beams
        best_beams = best_beams.unsqueeze(2).expand(best_beams.size(0), best_beams.size(1), bkw_bx2.size(2))
        bkw_bx2 = torch.gather(bkw_bx2, 1, best_beams) # best beam of all batches
        bkw_bx2 = bkw_bx2.view(batch_size, -1)
        bx2 = torch.cat((bx2, bkw_bx2), 1)  # <sos> + bkw_bx2
      else:  # batch reduction
        bkw_bx2 = torch.empty((new_batch_size, 0), dtype=bx1.dtype,
                              device=fwd_bx2.device)  # tokens for each beam
        bkw_scores = torch.empty((new_batch_size, 0), device=fwd_scores.device)  # scores for each beam
        mask = torch.full((new_batch_size, 1), False, dtype=torch.bool, device=fwd_scores.device)
        fwd_batch_size = new_batch_size
        for i in range(fwd_scores.size(1)):
          prev_batch_size = fwd_batch_size*self.bkw_beam_width**(i)
          new_batch_size = fwd_batch_size*self.bkw_beam_width**(i+1)
          cur_values = fwd_bx2[:,-i-1:-i if i > 0 else None]
          cur_values = torch.repeat_interleave(cur_values, self.bkw_beam_width**i, 0)
          cur_scores = fwd_scores[:,-i-1:-i if i > 0 else None]
          cur_scores = torch.repeat_interleave(cur_scores, self.bkw_beam_width**i, 0)
          cur_scores[mask] = float('-inf')  # mask from prev step
          best_scores, best_beams = torch.topk(cur_scores, self.bkw_beam_width, 0)  # best beams
          best_values = cur_values[best_beams]
          mask = cur_values != best_values
          mask = mask.view(new_batch_size, -1)
          best_values = best_values.view(self.bkw_beam_width, -1)
          best_values = torch.repeat_interleave(best_values, prev_batch_size, 0)
          bkw_bx2 = bkw_bx2.repeat(self.bkw_beam_width, 1)  # increase batch for new beams
          bkw_bx2 = torch.cat((best_values, bkw_bx2), 1)
          best_scores = torch.repeat_interleave(best_scores, prev_batch_size, 0)
          bkw_scores = bkw_scores.repeat(self.bkw_beam_width, 1)  # increase batch for new beams
          bkw_scores = torch.cat((best_scores, bkw_scores), 1)
        # bkw batch reduction
        bkw_scores = self.depth_reduction(bkw_scores, axis=1) # cumulative beams scores
        best_beams = torch.argmax(bkw_scores, axis=0, keepdim=True)  # best beams
        bkw_bx2 = bkw_bx2[best_beams]  # best beam of all batches
        bkw_bx2 = bkw_bx2.repeat(batch_size, 1)  # return to input batch size
        bx2 = torch.cat((bx2, bkw_bx2), 1)  # <sos> + bkw_bx2
    return bx2[:,1:], attn

## Metric

In [0]:
from rouge_score import rouge_scorer
from pprint import pprint

def _tokens_before_eos(seq, eos):
  """Return the prefix of the 1-D numpy array ``seq`` before its first ``eos`` (whole array if absent)."""
  eos_positions = np.where(seq == eos)[0]
  return seq[:eos_positions[0]] if len(eos_positions) else seq


def metrics(bx2, by2, eos=3, special_ids=(0, 1, 2, 3)):
  """Compute per-sample overlap metrics between target and predicted token ids.

  Rows of ``bx2`` (targets) and ``by2`` (predictions) are paired with ``zip``,
  so extra rows in the larger batch are ignored.  Each row is cut at its
  first ``eos`` token and then compared:

  * ``set_precision`` / ``set_recall`` / ``set_f1``: overlap of the sets of
    distinct ids, after removing ``special_ids``;
  * ``relen``: prediction length / target length (0 for an empty target);
  * ``rouge1_*``, ``rouge2_*``, ``rougeL_*`` (``_precision``, ``_recall``,
    ``_f1``): ROUGE from ``rouge_score`` on the space-joined id strings.

  Args:
    bx2: target token-id tensor, one sequence per row.
    by2: predicted token-id tensor, one sequence per row.
    eos: id of the end-of-sequence token.
    special_ids: ids ignored by the set metrics; the default 0-3 presumably
      matches <pad>/<unk>/<sos>/<eos> of the dataset vocabulary.

  Returns:
    dict mapping metric name to a list with one float per sample.
  """
  scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
  special = set(special_ids)
  scores = {}
  ps = by2.detach().cpu().numpy()
  ts = bx2.detach().cpu().numpy()
  for sp, st in zip(ps, ts):
    st_eos = _tokens_before_eos(st, eos)
    sp_eos = _tokens_before_eos(sp, eos)
    # set metrics
    t_set = set(st_eos) - special
    p_set = set(sp_eos) - special
    i_set = t_set & p_set
    pre = len(i_set)/len(p_set) if p_set else 0
    rec = len(i_set)/len(t_set) if t_set else 0
    f1 = 2*pre*rec/(pre + rec) if pre + rec != 0 else 0
    rl = len(sp_eos)/len(st_eos) if len(st_eos) != 0 else 0
    # ROUGE works on text, so the ids are joined into whitespace-separated strings
    t_str = ' '.join(map(str, st_eos))
    p_str = ' '.join(map(str, sp_eos))
    rs = scorer.score(t_str, p_str)
    scores.setdefault('set_precision', []).append(pre)
    scores.setdefault('set_recall', []).append(rec)
    scores.setdefault('set_f1', []).append(f1)
    scores.setdefault('relen', []).append(rl)
    for k, v in rs.items():
      scores.setdefault(f'{k}_precision', []).append(v.precision)
      scores.setdefault(f'{k}_recall', []).append(v.recall)
      scores.setdefault(f'{k}_f1', []).append(v.fmeasure)
  return scores

# Sanity check of `metrics` on two hand-written rows: x2 holds the targets,
# y2 the predictions (token 3 is the default eos passed to `metrics`).
x2 = torch.tensor([
    [1, 5, 3, 4, 4, 3, 3, 3, 3],
    [5, 3, 0, 4, 3, 0, 0, 0, 0],
])
y2 = torch.tensor([
    [3, 1, 4, 3, 4, 6, 7, 4],
    [0, 5, 4, 3, 3, 0, 0, 3],
])
metrics(x2, y2)

## Run

In [0]:
import humanize
import psutil
import GPUtil

def run(train, device, model, opt, loss_fn, dl, ds, print_triples=1, 
        plot_attn=False, inference=None, metric=False, pin_memory=False,
        max_samples=None):
  """Run one pass of ``model`` over ``dl``, training or evaluating.

  Args:
    train: if True, optimise ``model`` with ``opt`` on every batch
      (teacher forcing); otherwise evaluate under ``torch.no_grad()``.
    device: device every batch is moved to.
    model: called as ``model(x1, x2)`` and expected to return
      ``(scores, attn)``; ``scores`` is presumably (batch, len, vocab)
      log-probabilities, since it is fed to ``loss_fn`` and argmax'ed on
      axis 2 — TODO confirm.
    opt: optimizer; used only when ``train`` is True.
    loss_fn: loss applied to the flattened predictions/shifted targets.
    dl: iterable of ``(x1, x2)`` token-id batches.
    ds: object whose ``decode`` turns an id sequence into text for printing.
    print_triples: number of (input, target, prediction) triples from the
      last batch to print; 0 disables printing.
    plot_attn: plot the attention of the first sample of the last batch.
    inference: optional decoder called as ``inference(model, x1, x2)``
      during evaluation; when None the model is teacher-forced and the loss
      is recorded instead.
    metric: compute ``metrics`` for every batch.
    pin_memory: use non-blocking host-to-device copies.
    max_samples: stop after this many *batches* (None = the whole loader).

  Returns:
    dict mapping score name to a list: one ``loss`` float per batch and one
    value per sample for every metric.
  """
  scores = {}
  pbar = tqdm(dl, total=max_samples)
  batch_cnt = 0
  # Bound up front so the reporting code below never sees unbound names.
  x1 = x2 = y2 = y2p = attn = None
  if train:
    model.train()
    for x1, x2 in pbar:
      opt.zero_grad()
      x1 = x1.to(device, non_blocking=pin_memory)
      x2 = x2.to(device, non_blocking=pin_memory)
      t = x2[:,1:].flatten(0)  # targets: every token after the first
      y2p, attn = model(x1, x2)
      p = y2p[:,:-1,:].flatten(0, 1)  # predictions aligned with `t`
      loss = loss_fn(p, t)
      loss.backward()
      opt.step()
      # loss.item() is a single float, so append it (extend() would raise).
      scores.setdefault('loss', []).append(loss.item())
      y2 = None
      if metric:
        y2 = torch.argmax(y2p[:,:-1,:], 2)
        bs = metrics(x2[:,1:], y2)  # batch scores
        r1f = sum(bs['rouge1_f1'])/len(bs['rouge1_f1'])
        r2f = sum(bs['rouge2_f1'])/len(bs['rouge2_f1'])
        rlf = sum(bs['rougeL_f1'])/len(bs['rougeL_f1'])
        sf = sum(bs['set_f1'])/len(bs['set_f1'])
        rl = sum(bs['relen'])/len(bs['relen'])
        for score, values in bs.items():
          scores.setdefault(score, []).extend(values)
        # host RAM and GPU memory utilisation, shown in the progress bar
        cmt = psutil.virtual_memory().total
        cma = psutil.virtual_memory().available
        cmu = (cmt-cma)/cmt
        gmu = GPUtil.getGPUs()[0].memoryUtil if len(GPUtil.getGPUs()) > 0 else 0
        pbar.set_description(f'train {int(cmu*100)}%/{int(gmu*100)}% \
          l: {loss.item():.3}, \
          sf: {int(sf*100)}, r1f: {int(r1f*100)}, r2f: {int(r2f*100)}, \
          rlf: {int(rlf*100)}, rl: {int(rl*100)}')
      else:
        pbar.set_description(f'loss: {loss.item():.3}')
      batch_cnt += 1
      if max_samples is not None and batch_cnt == max_samples:
        break
  else:
    model.eval()
    with torch.no_grad():
      for x1, x2 in pbar:
        x1 = x1.to(device, non_blocking=pin_memory)
        x2 = x2.to(device, non_blocking=pin_memory)
        if inference is not None:
          y2, attn = inference(model, x1, x2)
        else:  # forced with loss (e.g. for training)
          t = x2[:,1:].flatten(0)
          y2p, attn = model(x1, x2)
          p = y2p[:,:-1,:].flatten(0, 1)
          loss = loss_fn(p, t)
          y2 = torch.argmax(y2p[:,:-1,:], 2)
          scores.setdefault('loss', []).append(loss.item())
          pbar.set_description(f'loss: {loss.item():.3}')
        if metric:
          bs = metrics(x2[:,1:], y2)  # batch scores
          for score, values in bs.items():
            scores.setdefault(score, []).extend(values)
        batch_cnt += 1
        if max_samples is not None and batch_cnt == max_samples:
          break
  if batch_cnt == 0:
    return scores  # empty loader: nothing to print or plot
  if print_triples:
    if y2 is None:
      y2 = torch.argmax(y2p[:,:-1,:], 2)
    dy2 = map(ds.decode, y2.detach().cpu().numpy())
    dx2 = map(ds.decode, x2[:,1:].detach().cpu().numpy())
    dx1 = map(ds.decode, x1[:,1:].detach().cpu().numpy())
    for src, target, pred in list(zip(dx1, dx2, dy2))[:print_triples]:
      print('')
      print(src)
      print(target)
      print(pred)
  if plot_attn:
    plot_attention(attn.detach().cpu().numpy()[:1], 
                   x1.detach().cpu().numpy()[:1], 
                   x2.detach().cpu().numpy()[:1],
                   decoder_x1 = ds.decode,
                   decoder_x2 = ds.decode,
                   figsize=(10, 10),
                   shift=True,
                   mask=True,
                   suptitle=None)
  return scores

In [0]:
def tensors_mem():
  """Print type, size, dtype, device and layout of every live torch tensor.

  Scans the objects tracked by Python's garbage collector; useful for
  finding tensors that keep (GPU) memory alive between trials.
  """
  for obj in gc.get_objects():
    try:
      if isinstance(obj, torch.Tensor):
        # Tensor has no `memory_format` attribute; requesting it raised
        # AttributeError for every tensor, so nothing was ever printed.
        print(type(obj), obj.size(), obj.dtype, obj.device, obj.layout)
    except Exception:
      # Best effort: some gc-tracked objects (e.g. lazy proxies) raise on
      # isinstance()/attribute access; skip them but let KeyboardInterrupt through.
      continue
tensors_mem()

## Main

In [0]:
import torch
import optuna
from tqdm.notebook import tqdm
from pprint import pprint
import random
import numpy as np
from shutil import copyfile
import GPUtil
from torch import optim
from torchsummary import summary

def objective(trial):
  """Optuna objective: load (or build) the model, optionally train it, then score it.

  The model and its config come from the checkpoint on Google Drive; a fresh
  model config is sampled from ``trial`` only when ``checkpoint`` is None.
  All configs are recorded as trial user attributes.  When ``train`` is True
  the model is trained for ``n_epoches`` epochs with a validation pass after
  each one.  Finally up to ``eval_config['max_samples']`` test batches are
  decoded with the configured inference strategy, and the averaged metrics
  are stored as user attributes.

  Relies on notebook globals: ``ds``, ``EncoderDecoder``, ``Collate``,
  ``Beam``, ``BeamBkw``, ``Greedy``, ``Forced`` and ``run``.

  Returns:
    The mean validation loss of the last epoch when training, otherwise the
    mean test ROUGE-1 F1.
  """
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  trial.set_user_attr('device', str(device))
  if device.type == "cuda":
    # report GPU memory utilisation before and after releasing cached blocks
    print(f'gmu: {"% ".join(str(int(x.memoryUtil*100)) for x in GPUtil.getGPUs())}%')
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    print(torch.cuda.memory_stats(device=device))
    print(torch.cuda.memory_summary(device=device))
    print(f'gmu: {"% ".join(str(int(x.memoryUtil*100)) for x in GPUtil.getGPUs())}%')
  opt_fn_map = {
    'SGD': optim.SGD,
    'Adam': optim.Adam,
    'Adagrad': optim.Adagrad,
    'ASGD': optim.ASGD,
    'Adamax': optim.Adamax,
    'SparseAdam': optim.SparseAdam,
    'AdamW': optim.AdamW,
    'Adadelta': optim.Adadelta,
    'LBFGS': optim.LBFGS,
    'RMSprop': optim.RMSprop,
    'Rprop': optim.Rprop
  }
  eval_fn_map = {
      'Beam': Beam,
      'BeamBkw': BeamBkw,
      'Greedy': Greedy,
      'Forced': Forced
  }
  # fn = 'model.pth'
  fn = f'model_{trial.study.study_name}_{trial.number}.pth'
  print(fn)
  # checkpoint = None
  checkpoint = torch.load('/content/drive/My Drive/model_ria_base.pth',
                          map_location=device)
  train = False  # ?
  # train_config = checkpoint['train_config']
  train_config = {
    'seed': trial.suggest_int('seed', 0, 0),
    'n_epoches': trial.suggest_int('n_epoches', 1, 1),
    'lr': trial.suggest_loguniform('lr', 1e-3, 1e-3),
      # 'opt_fn': trial.suggest_categorical('opt_fn', ['SGD', 'Adam', 'Adagrad', 
      #                                           'ASGD', 'Adamax',
      #                                           'AdamW', 'Adadelta',
      #                                           'RMSprop', 'Rprop']),
    'opt_fn': trial.suggest_categorical('opt_fn', ['Adam']),
    'batch_size': trial.suggest_int('batch_size', 1, 1),
    'weight_decay': None,
    'pin_memory': True,
    # 'weight_decay': trial.suggest_loguniform('weight_decay', 1e-6, 1e-6),
    'n_x1': trial.suggest_int('n_x1', 1, 1)
  }
  eval_config = {
      # 'eval_fn': trial.suggest_categorical('eval_fn', ['Beam', 'BeamBkw']),
      # 'eval_fn': trial.suggest_categorical('eval_fn', ['Beam']),
      'eval_fn': 'Beam',
      'max_samples': 100}
  if eval_config['eval_fn'] in ['Beam', 'BeamBkw']:
    eval_config['beam_width'] = trial.suggest_int('beam_width', 2, 3)
    eval_config['beam_depth'] = trial.suggest_int('beam_depth', 8, 8)
    # eval_config['beam_reduction'] = trial.suggest_categorical('beam_reduction', ['sum', 'mean', 'none']),
    eval_config['beam_reduction'] = 'none'
    eval_config['max_input_size'] = 20000  # BATCH*(LEN1+LEN2) ~ BATCH*(10+30)
    eval_config['batch_reduction'] = 'none'
    eval_config['max_len'] = 14
    eval_config['plot_size'] = 10
    eval_config['plot_dist'] = False
    eval_config['verbose'] = False
    # eval_config['first_beam_width'] = eval_config['beam_width']
    eval_config['first_beam_width'] = trial.suggest_int('first_beam_width', 10, 10)
    # eval_config['depth_reduction'] = trial.suggest_categorical('depth_reduction', ['sum', 'mean'])
    eval_config['depth_reduction'] = 'mean'
    if eval_config['eval_fn'] == 'BeamBkw':
      eval_config['bkw_beam_width'] = trial.suggest_int('bkw_beam_width', 2, 2)
  if checkpoint is None:
    model_config = {
      'enc_num_embeddings': len(ds.stoi),
      'enc_embedding_dim': trial.suggest_int('enc_embedding_dim', 300, 300),
      'enc_padding_idx': ds.stoi['<pad>'],
      'dec_num_embeddings': len(ds.stoi),
      'dec_embedding_dim': trial.suggest_int('dec_embedding_dim', 300, 300),
      'dec_padding_idx': ds.stoi['<pad>'],
      'rnn_type': trial.suggest_categorical('rnn_type', ['RNN']),
      # 'rnn_type': trial.suggest_categorical('rnn_type', ['RNN']),
      'hidden_size': trial.suggest_int('hidden_size', 300, 300),
      'num_layers': trial.suggest_int('num_layers', 1, 1),
      'bidirectional': trial.suggest_categorical('bidirectional', [False]),
      # 'bidirectional': trial.suggest_categorical('bidirectional', [True]),
      'enc_dropout': trial.suggest_uniform('enc_dropout', 0.0, 0.0),
      'dec_dropout': trial.suggest_uniform('dec_dropout', 0.0, 0.0),
      'enc_rnn_dropout': trial.suggest_uniform('enc_rnn_dropout', 0.0, 0.0),
      'dec_rnn_dropout': trial.suggest_uniform('dec_rnn_dropout', 0.0, 0.0),
      'attn_type': trial.suggest_categorical('attn_type', ['soft_dist']),
      # 'attn_type': trial.suggest_categorical('attn_type', ['dot', 'dist', 'soft_dot', 'soft_dist', 'cos', 'soft_cos', 'none']),
      'out_hidden': trial.suggest_int('out_hidden', 600, 600),
      'pack': trial.suggest_categorical('pack', [True]),
    }
  else:
    model_config = checkpoint['model_config'] 
  # record every config value on the trial (prefixed by its config group)
  for k, v in model_config.items():
    trial.set_user_attr('model_' + k, v)
  for k, v in train_config.items():
    trial.set_user_attr('train_' + k, v)
  for k, v in eval_config.items():
    trial.set_user_attr('eval_' + k, v)
  # reproducibility
  seed = train_config['seed']
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  random.seed(seed)
  np.random.seed(seed)
  opt_fn = train_config['opt_fn']
  lr = train_config['lr']
  n_epoches = train_config['n_epoches']
  batch_size = train_config['batch_size']
  weight_decay = train_config['weight_decay']
  pin_memory = train_config['pin_memory']
  model = EncoderDecoder(**model_config).to(device)
  if checkpoint is not None:
    model.load_state_dict(checkpoint['model_state_dict'])
  pprint(model)
  n_params = sum(x.numel() for x in model.parameters() if x.requires_grad)
  trial.set_user_attr('n_params', n_params)
  if opt_fn not in ['Rprop'] and weight_decay is not None:
    opt = opt_fn_map[opt_fn](model.parameters(), lr=lr, 
                             weight_decay=weight_decay)
  else:
    opt = opt_fn_map[opt_fn](model.parameters(), lr=lr)
  loss_fn = torch.nn.NLLLoss(ignore_index=0)
  # 70% / 20% / 10% train / test / validation split
  train_len = int(0.7*len(ds))
  test_len = int(0.2*len(ds))
  val_len = len(ds) - train_len - test_len
  lens = [train_len, test_len, val_len]
  trial.set_user_attr('n_samples', len(ds))
  trial.set_user_attr('n_samples_train', train_len)
  trial.set_user_attr('n_samples_test', test_len)
  trial.set_user_attr('n_samples_val', val_len)
  train_ds, test_ds, val_ds = random_split(ds, lens)
  train_dl = DataLoader(train_ds, batch_size=batch_size, num_workers=1, 
                        shuffle=True, drop_last=False,
                        collate_fn=Collate(train_config['n_x1']),
                        pin_memory=pin_memory)
  test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers=1, 
                       shuffle=False, drop_last=False,
                       collate_fn=Collate(train_config['n_x1']),
                       pin_memory=pin_memory)
  val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=1, 
                      shuffle=False, drop_last=False,
                      collate_fn=Collate(train_config['n_x1']),
                      pin_memory=pin_memory)
  pprint(trial.user_attrs)
  pprint(trial.params)
  if train:
    pbar_epoch = tqdm(range(n_epoches))
    # Iterate the bar itself so the description set below belongs to the
    # bar that actually advances.
    for i in pbar_epoch:
      train_scores = run(True, device, model, opt, loss_fn, train_dl, ds, 
                          print_triples=3, 
                          metric=False,
                          pin_memory=pin_memory,
                          plot_attn=model_config['attn_type'] != 'none')
      train_loss = sum(train_scores['loss'])/len(train_scores['loss'])
      trial.report(train_loss, step=i+1)
      # validation
      val_scores = run(False, device, model, opt, loss_fn, val_dl, ds,  
                       print_triples=3, 
                       inference=None,
                       metric=False,
                       pin_memory=pin_memory,
                       plot_attn=False)
      val_loss = sum(val_scores['loss'])/len(val_scores['loss'])
      pbar_epoch.set_description(f'epoch: {i+1}, tloss: {train_loss:.3}, vloss: {val_loss:.3}')
      # torch.save({'epoch': i,
      #             'model_state_dict': model.state_dict(),
      #             'optimizer_state_dict': opt.state_dict(),
      #             'loss': loss,
      #             'model_config': model_config,
      #             'train_config': train_config,
      #             'eval_config': eval_config},
      #              fn)
      # try:
      #   copyfile(fn, '/content/drive/My Drive/' + fn)
      # except Exception as e:
      #   print(e)
  # test
  eval_fn = eval_fn_map[eval_config['eval_fn']](**eval_config)
  test_scores = run(False, device, model, opt, loss_fn, test_dl, ds, 
                    print_triples=1, 
                    inference=eval_fn,
                    metric=True,
                    pin_memory=pin_memory,
                    max_samples=eval_config['max_samples'],
                    plot_attn=False)
  scores = test_scores
  test_scores = {}
  for score, values in scores.items():
    test_scores[score] = sum(values)/len(values)
    trial.set_user_attr(score, test_scores[score])
  pprint(test_scores)
  # NOTE(review): the study in this notebook is created with
  # direction='maximize', which fits ROUGE-1 F1 but not a loss; use a
  # 'minimize' study when running with train=True.
  return val_loss if train else test_scores['rouge1_f1']

import os
fn = 'optuna_20.db'
# Resume the study from the backup on Google Drive when there is no local DB.
# On the very first run no backup exists either; optuna then creates a fresh
# DB (an unconditional copyfile would raise FileNotFoundError here).
if not os.path.exists(fn) and os.path.exists('/content/drive/My Drive/' + fn):
  copyfile('/content/drive/My Drive/' + fn, fn)
study = optuna.create_study(study_name='2_vs_3_100', 
                            direction='maximize', 
                            storage=f'sqlite:///{fn}', 
                            sampler=optuna.samplers.RandomSampler(0),  # optuna.samplers.GridSampler()
                            load_if_exists=True)

def dump_optuna(study, frozen_trial):
  """Optuna callback: back up the local study DB to Google Drive after each trial.

  Best effort: a failed copy is printed and does not stop the study.
  """
  fn = 'optuna_20.db'
  print(f'saving optuna {fn}')
  try:
    copyfile(fn, '/content/drive/My Drive/' + fn)
  except Exception as e:
    print(e)
  print('done')

study.optimize(objective, n_trials=2, callbacks=[dump_optuna])

In [0]:
# Force a full garbage-collection pass, e.g. to reclaim objects left over
# from finished trials; the cell displays the number of objects collected.
import gc
gc.collect()

In [0]:
# Greedy 200k
# {'relen': 0.9892607721833844,
#  'rouge1_f1': 0.38423491185005065,
#  'rouge1_precision': 0.4026889198348889,
#  'rouge1_recall': 0.37827062666600875,
#  'rouge2_f1': 0.16623936163781577,
#  'rouge2_precision': 0.17458141088353102,
#  'rouge2_recall': 0.16407297971294255,
#  'rougeL_f1': 0.3567234511838257,
#  'rougeL_precision': 0.3739773768993677,
#  'rougeL_recall': 0.35121314215560523,
#  'set_f1': 0.39064669071317015,
#  'set_precision': 0.4199063773953252,
#  'set_recall': 0.3771313731307635}

# try:
# except KeyboardInterrupt:
#   print(e)
# except Exception as e:
#   print(e)
# finally:
#   torch.save({'epoch': i,
#               'model_state_dict': model.state_dict(),
#               'optimizer_state_dict': opt.state_dict(),
#               'loss': loss,
#               'model_config': model_config,
#               'train_config': train_config},
#               'model_exception.pth')
#   try:
#     copyfile('model.pth', '/content/drive/My Drive/' + 'model.pth')
#   except Exception as e:
#     print(e)
# raise KeyboardInterrupt

In [0]:
from pprint import pprint
pprint(study.best_params)
pprint(study.best_value)
pprint(study.best_trial)
pprint(study.direction)
%reload_ext google.colab.data_table
df = study.trials_dataframe()
df.params_eval_fn.fillna(df.user_attrs_eval_eval_fn, inplace=True)
df

In [0]:
# Per-parameter slice plot of the objective value for every trial.
# Other optuna views, kept for reference:
#   optuna.visualization.plot_contour(study, params=['beam_width', 'beam_depth'])
#   optuna.visualization.plot_optimization_history(study)
#   optuna.visualization.plot_parallel_coordinate(study)  # BUG
#   optuna.visualization.plot_intermediate_values(study)
optuna.visualization.plot_slice(study=study)

In [0]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# max_vocab = 1000
# with open(vocab_path) as f:
#   vocab = json.load(f)
# top_words = dict([(w, c) for w, c in sorted(vocab.items(), key=lambda x: x[1],
#                                             reverse=True)][:max_vocab])
# # sns.distplot(list(map(len, vocab.keys())), kde=False)
# # plt.figure()
# # sns.distplot(list(map(len, top_words.keys())), kde=True)
# # plt.figure()
# sns.distplot(list(vocab.values()), kde=False, hist_kws={'log': True})
# sns.distplot(list(top_words.values()), kde=False, hist_kws={'log': True})
# ax = sns.distplot(list(vocab.values()), kde=False, hist=True, 
#              hist_kws={'log': True}, norm_hist=False)
# # ax.set_xscale('log')
# ax = sns.distplot(list(top_words.values()), kde=False, hist=True, 
#              hist_kws={'log': True}, norm_hist=False)
# # ax.set_xscale('log')
# plt.figure()
# plt.title('Sentences lengths')
# sns.kdeplot([len(y) for x in ds.samples for y in x[0]], 
#             label='texts', shade=True, bw=2)
# sns.kdeplot([len(y) for x in ds.samples for y in x[1]], 
#             label='titles', shade=True, bw=2)
# plt.legend()
# plt.figure()
# sns.jointplot(x='titles lengths', y='texts lengths', data=pd.DataFrame({
#     "texts lengths": [sum(map(len, x[0])) for x in ds.samples],
#     "titles lengths": [sum(map(len, x[1])) for x in ds.samples]}
# ), kind='kde', n_levels=10)
# sns.jointplot(x='user_attrs_rouge1_recall', y='user_attrs_rouge1_precision',
#               data=df, kind='kde', n_levels=10) 
# hue="sex"
# sns.catplot(x="params_beam_width", y="params_beam_depth", 
#             kind="violin", hue="params_eval_fn", split=True, bw=1., data=df)
# sns.violinplot(x=df.params_beam_depth, y=df.user_attrs_rouge1_f1);

# sns.catplot(x="params_beam_width", y="user_attrs_rougeL_f1", 
#             # hue="smoker",
#             col="params_eval_fn", aspect=.6,
#             kind="swarm", data=df);
# sns.pairplot(df, hue="value", vars=["params_beam_width", "params_beam_depth"], kind="kde")

# g = sns.PairGrid(df, vars=["params_beam_width", "params_beam_depth",
#                            "user_attrs_rougeL_f1", "user_attrs_rouge2_f1",
#                            "user_attrs_rouge1_f1"])
# g.map_upper(plt.scatter)
# g.map_lower(sns.kdeplot)
# g.map_diag(sns.kdeplot, lw=3, legend=False);

# sns.violinplot(x=df.params_beam_depth, y=df.user_attrs_rouge1_f1);

# sns.catplot(x="params_beam_width", 
#             y="params_beam_depth", 
#             hue="value",
#             col="params_eval_fn", 
#             aspect=.618,
#             kind="swarm", palette='Blues', data=df);
# sns.catplot(x="params_beam_width", 
#             y="user_attrs_rouge2_f1", 
#             hue="params_eval_fn",
#             col="params_beam_depth", 
#             aspect=.618,
#             kind="swarm", data=df);

# sns.lmplot(x="params_beam_width", y="user_attrs_rouge2_f1", 
#           #  hue="smoker", 
#            col="params_beam_depth", row="params_eval_fn", data=df);

# df.params_eval_fn.fillna(df.user_attrs_eval_params_eval_fn, inplace=True)

# Scatter each test score against the beam settings with a regression fit,
# coloured by the number of test batches evaluated (eval max_samples).
# NOTE(review): objective() logs "eval_max_input_size", not
# "eval_input_limit"; this column presumably comes from older trials in the
# study — confirm it exists in df.
sns.pairplot(
    df,
    x_vars=["params_beam_depth", "user_attrs_eval_input_limit"],
    y_vars=["user_attrs_rouge1_f1",
            "user_attrs_rouge2_f1",
            "user_attrs_rougeL_f1",
            "user_attrs_set_f1"],
    hue="user_attrs_eval_max_samples",
    kind="reg",
    aspect=1.,
    height=5,
);