In [29]:
%matplotlib inline

In [30]:
import os
import torch

In [31]:
# GPU index this notebook was originally run on.
PREFERRED_GPU = 5

if torch.cuda.is_available():
    # A hard-coded index only exists on hosts with enough GPUs; on smaller
    # machines fall back to the first GPU instead of producing a device that
    # fails on first use.
    gpu_index = PREFERRED_GPU if PREFERRED_GPU < torch.cuda.device_count() else 0
    device = torch.device("cuda", gpu_index)
else:
    device = torch.device("cpu")
print(device)

cuda:5


In [32]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# load data

In [33]:
from transformer_translation.dataset import ParallelLanguageDataset
from torch.utils.data import DataLoader

In [34]:
data_path = r"/home/alex/data/nlp/agmir/transf_processed_data"
#data_path = 'transformer_translation/data/processed'

In [35]:
# Batching parameters handed to the dataset.
num_tokens = 2000
max_seq_length = 96

dataset = ParallelLanguageDataset(
    os.path.join(data_path, 'tags/set.pkl'),     # alternative source: 'en/train.pkl'
    os.path.join(data_path, 'reports/set.pkl'),  # alternative source: 'fr/train.pkl'
    num_tokens,
    max_seq_length,
)

# batch_size=1: presumably each dataset item is already a full batch -- TODO confirm
# against ParallelLanguageDataset.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [36]:
# Flatten both sides of the dataset and show the largest value found on each.
tk1 = [token for sequence in dataset.data_1 for token in sequence]
tk2 = [token for sequence in dataset.data_2 for token in sequence]
max(tk1), max(tk2)

(589, 1952)

# Test bag-of-tags (BoTags) generation

In [37]:
import pickle

# Raw tag strings, one entry per report.
# NOTE: pickle.load can execute arbitrary code -- only unpickle trusted files.
with open(os.path.join(data_path, 'tags/set_raw.pkl'), 'rb') as f:
    tags_raw = pickle.load(f)

In [38]:
# Tag vocabulary stored next to the tag data.
# NOTE: pickle.load can execute arbitrary code -- only unpickle trusted files.
with open(os.path.join(data_path, 'tags/voc.pkl'), 'rb') as f:
    tags_word2index = pickle.load(f)

In [39]:
from sklearn.feature_extraction.text import CountVectorizer

# Vectorizer restricted to the stored tag vocabulary; encode one sample to
# eyeball the resulting count vector.
countvec = CountVectorizer(vocabulary=list(tags_word2index))
countvec.transform([tags_raw[10]]).toarray()

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [40]:
len(tags_word2index.keys())

588

In [41]:
len(set([j for i in tags_raw for j in i.split()]))

587

# loading

In [42]:
def load_tags_reports(tags_path, reports_path, max_seq_length):
    """Load the pickled tags and reports and bucket report indices by length.

    Args:
        tags_path: path of the pickle holding the tags.
        reports_path: path of the pickle holding the reports.
        max_seq_length: longest report (by ``len``) that is kept.

    Returns:
        ``(tags, reports, data_lengths)`` where ``data_lengths`` maps a report
        length to the list of indices of reports having exactly that length.
        Empty reports and reports longer than ``max_seq_length`` are left out
        of ``data_lengths`` (they stay in ``reports``).
    """
    def _unpickle(path):
        # NOTE: pickle.load can execute arbitrary code -- trusted files only.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    tags = _unpickle(tags_path)
    reports = _unpickle(reports_path)

    data_lengths = {}
    for index, report in enumerate(reports):
        length = len(report)
        if not 0 < length <= max_seq_length:
            continue  # empty or too long: not usable for batching
        data_lengths.setdefault(length, []).append(index)
    return tags, reports, data_lengths

In [43]:
%%time
tags, reports, data_lengths = load_tags_reports(
    tags_path=os.path.join(data_path, 'tags/set_raw.pkl')
    ,reports_path=os.path.join(data_path, 'reports/set.pkl')
    ,max_seq_length=96)

CPU times: user 40 ms, sys: 4 ms, total: 44 ms
Wall time: 42.3 ms


# batch gen

In [44]:
import random

def gen_batches_tags_reports(num_tokens, data_lengths):
    """Group sentence indices into batches of at most ``num_tokens`` tokens.

    Sentences are taken length bucket by length bucket, shortest first. A
    batch that is left with spare room is topped up with the first chunk of
    the next bucket when the combined token count still fits the budget.

    Args:
        num_tokens: token budget per batch (sentence length * sentence count).
        data_lengths: dict mapping a sentence length to the list of indices of
            sentences with that length. The lists are shuffled IN PLACE.

    Returns:
        A list of batches, each a list of sentence indices.

    Raises:
        ValueError: if a non-empty bucket has a length larger than
            ``num_tokens``. Such a sentence can never fit in a batch, and the
            loop previously spun forever appending empty batches.
    """
    # Validate before shuffling so invalid input leaves data_lengths untouched.
    for length, indices in data_lengths.items():
        if indices and length > num_tokens:
            raise ValueError(
                "sentence length {} exceeds the per-batch budget of {} tokens".format(
                    length, num_tokens))

    # Shuffle all the indices (within each length bucket).
    for indices in data_lengths.values():
        random.shuffle(indices)

    batches = []
    # Infinity guarantees the very first chunk opens a new batch.
    prev_tokens_in_batch = float("inf")
    for length in sorted(data_lengths):
        if length <= 0:
            # Buckets of non-positive length carry no tokens and were never batched.
            continue
        remaining = data_lengths[length]
        per_batch = num_tokens // length  # >= 1 thanks to the validation above

        while remaining:
            chunk = remaining[:per_batch]
            tokens_in_batch = length * len(chunk)

            # Combine with previous batch?
            if tokens_in_batch + prev_tokens_in_batch <= num_tokens:
                batches[-1].extend(chunk)
                prev_tokens_in_batch += tokens_in_batch
            else:
                batches.append(chunk)
                prev_tokens_in_batch = tokens_in_batch
            remaining = remaining[len(chunk):]
    return batches

In [45]:
# Build batches with a budget of 2000 tokens each.
batches = gen_batches_tags_reports(num_tokens=2000, data_lengths=data_lengths)

In [46]:
# Print the length of every report in batch 97. Lengths may differ within one
# batch because gen_batches_tags_reports merges the first chunk of a length
# group into the previous batch when both fit in the token budget.
for r in batches[97]:
    print(len(reports[r]))

94
94
94
94
94
94
94
94
94
95
95
95
95
95
95
95


# getitem

In [47]:
from sklearn.feature_extraction.text import CountVectorizer
# Sort the vocabulary: iterating a set of strings yields a different order in
# each interpreter run (string hash randomisation), which would permute the
# bag-of-words columns between runs. The sorted list fixes the column order.
tag_vocab = sorted(set(j for i in tags_raw for j in i.split()))
countvec = CountVectorizer(vocabulary=tag_vocab)

In [76]:
def getitem_tags(idx, data, batches, countvec):
    """Vectorize the tag strings of batch ``idx``.

    Args:
        idx: position of the batch inside ``batches``.
        data: sequence of tag strings, indexed by sentence index.
        batches: list of batches, each a list of sentence indices.
        countvec: object with a ``transform(list_of_str)`` method, e.g. a
            scikit-learn ``CountVectorizer``.

    Returns:
        ``(features, masks)``: ``features`` is whatever ``countvec.transform``
        returns for the batch; ``masks`` is a boolean array of shape
        ``(batch_size, 1)`` filled with ``False``.
    """
    # Imported here because the notebook-level ``import numpy as np`` sits in a
    # later cell than the first call to this function, so a fresh
    # "Restart & Run All" would otherwise fail with NameError.
    import numpy as np

    batch = [data[i] for i in batches[idx]]

    # np.zeros keeps the (n, 1) boolean shape even for an empty batch.
    masks = np.zeros((len(batch), 1), dtype=bool)
    return countvec.transform(batch), masks

In [80]:
# Bag-of-tags features and masks for batch 98.
batch_bot, batch_bot_masks = getitem_tags(
    idx=98,
    data=tags_raw,
    batches=batches,
    countvec=countvec,
)

In [50]:
import numpy as np

def getitem_report(idx, data, batches):
    """Build the padded token array and padding mask for batch ``idx``.

    Every report is wrapped as ``[2] + tokens + [3]`` (presumably start/end
    markers -- confirm against the vocabulary) and right-padded with ``0`` up
    to the longest wrapped report in the batch.

    Args:
        idx: position of the batch inside ``batches``.
        data: sequence of token lists, indexed by sentence index.
        batches: list of batches, each a list of sentence indices.

    Returns:
        ``(tokens, masks)`` as numpy arrays of identical shape; ``masks`` is
        ``True`` exactly at the padded positions.
    """
    rows = [[2] + data[i] + [3] for i in batches[idx]]
    width = max((len(row) for row in rows), default=0)

    padded_rows = []
    masks = []
    for row in rows:
        pad = width - len(row)
        masks.append([False] * len(row) + [True] * pad)
        padded_rows.append(row + [0] * pad)

    return np.array(padded_rows), np.array(masks)

In [52]:
# Padded report tokens and padding masks for batch 98.
batch_tk, batch_masks = getitem_report(idx=98, data=reports, batches=batches)

In [67]:
batch_masks

array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False],
       [False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        Fals

In [58]:
batch_tk.shape

(6, 98)

# bring all together

In [96]:
from torch.utils.data import Dataset
from sklearn.feature_extraction.text import CountVectorizer

In [108]:
def create_countvec(tags):
    """Create a CountVectorizer whose vocabulary is every token found in ``tags``.

    Args:
        tags: iterable of whitespace-separated tag strings.

    Returns:
        A ``CountVectorizer`` with a fixed, alphabetically sorted vocabulary.
    """
    # Sorted, not list(set(...)): iterating a set of strings yields a different
    # order in each interpreter run (hash randomisation), which would permute
    # the bag-of-words columns between runs.
    tag_vocab = sorted({token for tag_string in tags for token in tag_string.split()})
    return CountVectorizer(vocabulary=tag_vocab)

In [110]:
class TagReportDataset(Dataset):
    """Dataset whose items are whole pre-built batches of tags and reports.

    Item ``idx`` corresponds to ``self.batches[idx]``: the tags of that batch
    as a dense bag-of-words array plus their masks, and the reports of that
    batch as a padded token array plus its padding mask.
    """

    def __init__(self, tags_path, reports_path, num_tokens, max_seq_length):
        """Load the data, fit the tag vocabulary and build the first batches."""
        self.num_tokens = num_tokens

        # Raw data plus the report indices bucketed by length.
        loaded = load_tags_reports(tags_path, reports_path, max_seq_length)
        self.tags, self.reports, self.data_lengths = loaded

        # Vectorizer over the tag vocabulary.
        self.countvec = create_countvec(self.tags)

        # Initial batch assignment.
        self.batches = gen_batches_tags_reports(self.num_tokens, self.data_lengths)

    def __len__(self):
        """Number of batches."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Return ``(src_bow, src_bow_masks, tgt, tgt_mask)`` for batch ``idx``."""
        src_bow, src_bow_masks = getitem_tags(idx, self.tags, self.batches, self.countvec)
        tgt, tgt_mask = getitem_report(idx, self.reports, self.batches)

        # The vectorizer output is densified before being returned.
        return src_bow.toarray(), src_bow_masks, tgt, tgt_mask

    def shuffle_batches(self):
        """Regenerate the batches from the stored length buckets."""
        self.batches = gen_batches_tags_reports(self.num_tokens, self.data_lengths)

In [1]:
from transformer_translation.dataset import TagReportDataset

In [11]:
# Batching parameters handed to the dataset.
num_tokens = 2000
max_seq_length = 96

dataset = TagReportDataset(
    os.path.join(data_path, 'tags/set_raw.pkl'),
    os.path.join(data_path, 'reports/set.pkl'),
    num_tokens,
    max_seq_length,
)

# batch_size=1 because every dataset item is already a complete batch.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [13]:
len(dataset)

99

In [15]:
# Fetch the last batch (len(dataset) is 99). The imported TagReportDataset
# yields three arrays per item here.
src_bow, tgt, tgt_mask = dataset[98]

In [16]:
src_bow.shape, tgt.shape, tgt_mask.shape

((6, 587), (6, 98), (6, 98))