In [2]:
from transformers import AutoTokenizer

# LED tokenizer for the arXiv checkpoint. The augmentation helpers in this notebook
# read its `vocab_size` and `eos_token_id`, and it is used to decode samples.
# NOTE(review): the pickled token ids (also for pubmed) are presumably produced by
# this same tokenizer — confirm against the preprocessing step.
tokenizer = AutoTokenizer.from_pretrained("allenai/led-large-16384-arxiv")

In [3]:
import os
import pickle

import torch
from tqdm import tqdm

In [3]:
dataset = "arxiv"
data_type = "test"

# Same layout that aug() reads from: ../_FRIDGE/{dataset}/{data_type}_filtered.pickle.
# pickle.load can execute arbitrary code — only load locally produced files.
with open(file=f"../_FRIDGE/{dataset}/{data_type}_filtered.pickle", mode='rb') as f:
    arxiv_test = pickle.load(f)

In [4]:
from copy import deepcopy

# Keep an independent copy of the first 30 records for interactive inspection, then
# drop the full split so its memory can be reclaimed (assuming nothing else holds it).
sample = deepcopy(arxiv_test[:30])
del arxiv_test

In [5]:
# Peek at two records: dicts with 'article' and 'abstract' token-id tensors.
sample[:2]

[{'article': tensor([   0, 1990,   59,  ...,  438,  479,    2]),
  'abstract': tensor([    0,   627,   765,   111,  1385, 27185,  2192,     9,     5,  1230,
           3778, 27840,   443, 22798,    31, 26713,  4193, 37613,     7, 16874,
          24761, 26873,    32,  3373,   479,    13,   209,   414,  1437, 50118,
              5, 46764,  3693,  1966,  8711,  2430, 22792,    13,     5,   675,
          24414,     9,    59,   787,  1178, 40051,   288,   360,  2156,    53,
              5,   476,  8576,  1966,  8711,    10, 27697,  1233,  4996,    11,
             42,    86, 22455,   479,  1437, 50118,    10,    92,  5448,     9,
              5,  9726,     9,    41, 23930,   111,  1683,    11,  8576,    16,
           1850,     8,    24,    16,  2305,    14,     5, 18918,    12,  1208,
            675, 24414,    16,    10, 46336,     9,     5, 27185,  2192,    31,
              5, 22455,     9,   787,  1178, 40051,   134,    68, 27779,   360,
            479,  1437,  1437,  1437,     5

In [109]:
# NOTE(review): executed out of order (In [109]); relies on `sample` from an earlier cell.
# `seq` is also the input of the commented-out demo calls next to the augmentation helpers.
seq = sample[1]["abstract"]
tokenizer.decode(seq)

'<s>we study the detectability of circular polarization in a stochastic gravitational wave background from various sources such as supermassive black hole binaries, cosmic strings, and inflation in the early universe with pulsar timing arrays. \n we calculate generalized overlap reduction functions for the circularly polarized stochastic gravitational wave background. \n we find that the circular polarization can not be detected for an isotropic background. however, there is a chance to observe the circular polarization for an anisotropic gravitational wave background. \n we also show how to separate polarized gravitational waves from unpolarized gravitational waves.</s>'

In [41]:
# Special-token ids of this tokenizer ([0, 2, 3, 1, 50264] in the output below).
# The augmentation helpers draw replacement ids starting at 4, which excludes ids 0-3.
tokenizer.all_special_ids

[0, 2, 3, 1, 50264]

In [4]:
def pad_pure_rand(summ):
    """Insert a repeated segment of random token ids just before the final token.

    A segment of ``len_pad`` random ids (len_pad in [2, 19]) is drawn once and
    repeated ``num_pad`` times (num_pad in [3, 120 // len_pad - 1]), so the inserted
    padding is always shorter than 120 tokens.

    Args:
        summ: 1-D tensor of token ids whose last token is assumed to be EOS; that
            token is dropped and ``tokenizer.eos_token_id`` is re-appended after
            the padding.

    Returns:
        A new 1-D tensor ``summ[:-1] + segment * num_pad + [eos]`` with summ's dtype.
    """
    len_pad = torch.randint(2, 20, (1,)).item()
    num_pad = torch.randint(3, 120 // len_pad, (1,)).item()

    # randint's upper bound is exclusive, so ids come from [4, vocab_size - 2];
    # the lower bound 4 skips special ids 0-3. NOTE(review): this presumably also
    # excludes <mask> (50264) if vocab_size is 50265 — confirm.
    max_id = tokenizer.vocab_size - 1
    # The segment is len_pad tokens long. It used to be num_pad tokens, which left
    # len_pad unused and made the padding num_pad**2 tokens long.
    segment = torch.randint(4, max_id, (len_pad,), dtype=summ.dtype)

    eos = torch.tensor([tokenizer.eos_token_id], dtype=summ.dtype)
    return torch.cat((summ[:-1], segment.repeat(num_pad), eos))

#tokenizer.decode(pad_pure_rand(seq))

In [5]:
def pad_dup_back(summ):
    """Repeat the summary's own tail several times just before the final token.

    The ``len_pad - 1`` tokens preceding the last token (len_pad in [2, 19]) are
    repeated ``num_pad`` times (num_pad in [3, 120 // len_pad - 1]).

    Args:
        summ: 1-D tensor of token ids. The last token is assumed to be EOS; it is
            dropped and ``tokenizer.eos_token_id`` is re-appended. Index 0 is assumed
            to be BOS (id 0 in the decoded samples) and is never copied into the tail.

    Returns:
        A new 1-D tensor ``summ[:-1] + tail * num_pad + [eos]`` with summ's dtype.
    """
    len_pad = torch.randint(2, 20, (1,)).item()
    num_pad = torch.randint(3, 120 // len_pad, (1,)).item()

    # Clamp the start to 1. Otherwise summaries with <= len_pad tokens would
    # duplicate the leading BOS token. This mirrors noise(), which never touches
    # index 0. For longer summaries this equals summ[-len_pad:-1].
    start = max(1, summ.shape[0] - len_pad)
    tail = summ[start:-1]

    eos = torch.tensor([tokenizer.eos_token_id], dtype=summ.dtype)
    return torch.cat((summ[:-1], tail.repeat(num_pad), eos))

#tokenizer.decode(pad_dup_back(seq))

In [6]:
def noise(summ):
    """Return a copy of ``summ`` with randomly chosen interior tokens overwritten.

    Between 3 and len - 1 positions are drawn with replacement (so fewer distinct
    tokens may change) from the interior 1..len-2. The first and last tokens are
    never modified. Replacement ids come from [4, vocab_size - 3].

    Args:
        summ: 1-D tensor of token ids; not modified.

    Returns:
        A new tensor of the same shape. Sequences shorter than 3 tokens have no
        interior and are returned as an unchanged copy.
    """
    ret = summ.clone().detach()
    len_ret = ret.shape[0]

    # No interior position to corrupt. Previously randint raised here (low >= high).
    if len_ret < 3:
        return ret

    # Count in [3, len_ret - 1] for len_ret >= 4, identical to the original draw.
    # The min() keeps low < high when len_ret == 3.
    num_noise = torch.randint(min(len_ret - 1, 3), len_ret, (num_noise_shape := (1,))).item()

    max_id = tokenizer.vocab_size - 1
    # Values are drawn before positions, matching the original statement's evaluation
    # order (the right-hand side is evaluated before the subscript). This keeps the
    # RNG stream unchanged.
    values = torch.randint(4, max_id - 1, (num_noise,))
    positions = torch.randint(1, len_ret - 1, (num_noise,))
    ret[positions] = values

    return ret

#tokenizer.decode(noise(seq))

In [7]:
def randmod(summ):
    """Corrupt ``summ`` with one augmentation chosen at random.

    One draw ``roll`` from {0, ..., 9} selects the augmentation:
      * roll in {0, 1} -> pad_pure_rand (20%)
      * roll in {2, 3} -> pad_dup_back  (20%)
      * roll >= 4      -> noise         (60%)

    Returns whatever the selected augmentation returns.
    """
    roll = torch.randint(0, 10, (1,))[0]

    if roll >= 4:
        return noise(summ)
    if roll >= 2:
        return pad_dup_back(summ)
    return pad_pure_rand(summ)

#tokenizer.decode(randmod(seq))

In [8]:
def aug(dataset, data_type, root="../_FRIDGE", seed=None):
    """Attach a randomly corrupted abstract to every record of a split and save it.

    Reads ``{root}/{dataset}/{data_type}_filtered.pickle`` (a sequence of dicts with
    an "abstract" entry), stores ``randmod(record["abstract"])`` under the new key
    "noised" in each record, and pickles the result to
    ``{root}/_aug/{dataset}_{data_type}_aug.pickle``.

    Args:
        dataset: dataset folder name, e.g. "arxiv" or "pubmed".
        data_type: split name, e.g. "train", "validation" or "test".
        root: base data directory; the default is the previously hard-coded path.
        seed: if given, torch's global RNG is seeded first so the augmentation is
            reproducible. None (default) leaves the RNG untouched.

    Returns:
        None. The output file is written as a side effect.
    """
    src = f"{root}/{dataset}/{data_type}_filtered.pickle"
    dst_dir = f"{root}/_aug"
    dst = f"{dst_dir}/{dataset}_{data_type}_aug.pickle"

    if seed is not None:
        torch.manual_seed(seed)

    # pickle.load can execute arbitrary code — only use on locally produced files.
    with open(file=src, mode="rb") as f:
        records = pickle.load(f)

    for record in tqdm(records, desc=f"{dataset}/{data_type}"):
        record["noised"] = randmod(record["abstract"])

    # The output directory was previously assumed to exist; create it on a fresh checkout.
    os.makedirs(dst_dir, exist_ok=True)
    with open(file=dst, mode="wb") as f:
        pickle.dump(records, f)

In [23]:
# Augment the PubMed test split. `dataset` / `data_type` stay as globals because the
# next cell reuses them to reload the file written by aug().
dataset, data_type = "pubmed", "test"
aug(dataset, data_type)

100%|██████████| 6588/6588 [00:00<00:00, 24777.06it/s]


In [24]:
# Reload the augmented split written by aug() to spot-check the added "noised" field.
# pickle.load can execute arbitrary code — only load locally produced files.
with open(file=f"../_FRIDGE/_aug/{dataset}_{data_type}_aug.pickle", mode='rb') as f:
    result = pickle.load(f)

In [26]:
# First augmented record; the displayed output is truncated before the "noised" entry.
result[0]

{'article': tensor([    0,   260, 43537,  ..., 11048,   479,     2]),
 'abstract': tensor([    0,  6762,    15,     5,  8819,     9,  6882,    11,  2221,  9554,
           128,    29,  2199,    36,   181,   417,  4839,    34,    57, 20428,
          1135,    63, 21087,    11,   823,   654,   207,     9,  1484,     8,
            63,  2430,   913,    15,  1318,     9,   301,   479,  1437, 50118,
           986,   690,    33,  1581,    14, 14913, 33844, 21303,  5298, 29210,
         14526,   819,    11,   181,   417,  1484, 25606,   959,  2156,     7,
          1248,  2156,   117,   892,    34,  2024,  1118,   181,   417,  1484,
            19,     8,   396,  6882,     7, 10154,     5,   913,     9,  6882,
            15, 14526, 29210,  2963,    11,   181,   417,   479,  1437, 50118,
            42,   892,  1118, 14526,   819,   420,   654,   181,   417,  3597,
            19,     8,   396,  6882,    36,   601,   181,  6106,  2744, 25606,
          2357,   181,  6106,  4839,  2156,    54

In [28]:
# Augment the remaining PubMed splits (written to ../_FRIDGE/_aug/pubmed_{split}_aug.pickle).
aug("pubmed", "train")
aug("pubmed", "validation")

100%|██████████| 118681/118681 [00:06<00:00, 18876.52it/s]
100%|██████████| 6573/6573 [00:00<00:00, 25194.53it/s]


In [29]:
# Augment the arXiv evaluation splits (written to ../_FRIDGE/_aug/arxiv_{split}_aug.pickle).
aug("arxiv", "test")
aug("arxiv", "validation")

100%|██████████| 5699/5699 [00:00<00:00, 25137.17it/s]
100%|██████████| 5677/5677 [00:00<00:00, 25058.79it/s]


In [9]:
# Augment the arXiv training split.
# NOTE(review): execution counts are non-sequential here (In [9] after In [29]); re-run
# all cells top to bottom to regenerate every split consistently.
aug("arxiv", "train")

100%|██████████| 174216/174216 [00:11<00:00, 15402.70it/s]
