In [1]:
from torchtext import data, datasets

BOS = 2
EOS = 3

class DataLoader():

    def __init__(self, train_fn, valid_fn, batch_size = 64, 
                                            device = -1, 
                                            max_vocab = 99999999, 
                                            max_length = 15, 
                                            fix_length = None, 
                                            use_bos = True, 
                                            use_eos = True, 
                                            shuffle = True
                                            ):
        super(DataLoader, self).__init__()

        self.text = data.Field(sequential = True, 
                                use_vocab = False, 
                                batch_first = True, 
                                include_lengths = True, 
                                fix_length = fix_length, 
                                init_token = '<BOS>' if use_bos else None, 
                                eos_token = '<EOS>' if use_eos else None
                                )

        train = LanguageModelDataset(path = train_fn, 
                                        fields = [('text', self.text)], 
                                        max_length = max_length
                                        )
        valid = LanguageModelDataset(path = valid_fn, 
                                        fields = [('text', self.text)], 
                                        max_length = max_length
                                        )

        self.train_iter = data.BucketIterator(train, 
                                                batch_size = batch_size, 
                                                device = device, 
                                                shuffle = shuffle, 
                                                sort_key=lambda x: -len(x.text), 
                                                sort_within_batch = True
                                                )
        self.valid_iter = data.BucketIterator(valid, 
                                                batch_size = batch_size, 
                                                device = device, 
                                                shuffle = False, 
                                                sort_key=lambda x: -len(x.text), 
                                                sort_within_batch = True
                                                )

        self.text.build_vocab(train, max_size = max_vocab)

class LanguageModelDataset(data.Dataset):
    """Defines a dataset for machine translation."""

    def __init__(self, path, fields, max_length=None, **kwargs):
        if not isinstance(fields[0], (tuple, list)):
            fields = [('text', fields[0])]

        examples = []
        with open(path) as f:
            for line in f:
                line = line.strip()
                if max_length and max_length < len(line.split()):
                    continue
                if line != '':
                    examples.append(data.Example.fromlist(
                        [line], fields))

        super(LanguageModelDataset, self).__init__(examples, fields, **kwargs)

In [None]:
fields = [('text', fields[0])]

examples = []
with open(path) as f:
    for line in f:
        line = line.strip()
        if max_length and max_length < len(line.split()):
            continue
        if line != '':
            examples.append(data.Example.fromlist(
                [line], fields))


In [None]:

if __name__ == '__main__':
    import sys
    loader = DataLoader(sys.argv[1], sys.argv[2])

    for batch_index, batch in enumerate(loader.train_iter):
        print(batch.text)

        if batch_index > 1:
            break

In [2]:
import sys

In [5]:
sys.argv

['ipykernel_launcher',
 '--ip=127.0.0.1',
 '--stdin=9013',
 '--control=9011',
 '--hb=9010',
 '--Session.signature_scheme="hmac-sha256"',
 '--Session.key=b"e0a701a3-242a-44e5-8bc5-267822e7030b"',
 '--shell=9012',
 '--transport="tcp"',
 '--iopub=9014',
 '--f=C:\\Users\\dwkim\\AppData\\Local\\Temp\\tmp-22192ZVup6YcxrNM8.json']

In [None]:
    loader = DataLoader(sys.argv[1], sys.argv[2])

    for batch_index, batch in enumerate(loader.train_iter):
        print(batch.text)

        if batch_index > 1:
            break

In [10]:
import glob
import os

In [14]:
path = 'data/bird/text/'
text_list = glob.glob(os.path.join(path, '*/*.txt'))


In [16]:
len(text_list)

11788

In [17]:
import tqdm

In [18]:
text_list = glob.glob(os.path.join(path, '*/*.txt'))

examples = []

for t in tqdm.tqdm(text_list):
    with open(t) as f:
        for line in f:
            line = line.strip()
            # if max_length and max_length < len(line.split()):
            #     continue
            if line != '':
                examples.append(line)
                # examples.append(data.Example.fromlist(
                #     [line], fields))

100%|██████████| 11788/11788 [00:02<00:00, 4602.62it/s]


In [22]:
examples

e white and the wings are dark grey.',
 'a large bird with white crown and black eyebrow having large bill and black secondaries and primaries',
 'this bird is black and white in color, with a orange beak and a black eye ring.',
 'this is a white bird with an orange bill and black on the wingbars.',
 'this bird has white body, black wing, and long hooked bill.',
 'this bird has a white head, neck, belly, vent and tarsus, with black and grey feathers covering the rest of its body.',
 'this bird has wings that are grey and has a white belly',
 'this bird has a large white body with black wings and long beak',
 'this bird has wings that are black and has a white belly',
 'this bird has wings that are grey and has a white belly',
 'the medium sized bird has black wings and tails, white head and belly, and a yellow downward curved beak.',
 'this is a medium sized white bird with dark gray wings and tail and a yellow beak that curves down at the tip.',
 'the bird has a curved yellow bill, sm

In [1]:
import os
import pickle
import random
import numpy as np
import pandas as pd
import PIL
from PIL import Image
from collections import defaultdict
import torch
import torch.utils.data as data
from nltk.tokenize import RegexpTokenizer
import torchvision.transforms as transforms

In [10]:
class Text_Dataset(data.Dataset):
    def __init__(self, data_dir, split, words_num, print_shape=False):
        self.words_num = words_num
        self.data_dir = data_dir
        self.split = split
        # self.device = device

        self.filenames, self.captions, self.idx2word, self.word2idx, self.n_word \
            = self.load_text_data(data_dir, split)

    def load_text_data(self, data_dir, split):
        filepath = os.path.join(data_dir, 'captions.pickle')
        train_names = self.load_filenames(data_dir, 'train')
        if not data_dir.find('coco'):
            test_names = self.load_filenames(data_dir, 'test')
        else:
            test_names = self.load_filenames(data_dir, 'val')

        if not os.path.isfile(filepath):
            train_captions = self.load_captions(data_dir, train_names)
            test_captions = self.load_captions(data_dir, test_names)

            train_captions, test_captions, idx2word, word2idx, n_words = self.build_dictionary(train_captions, test_captions)
            with open(filepath, 'wb') as f:
                pickle.dump([train_captions, test_captions,
                             idx2word, word2idx], f, protocol=2)
                print('Save to: ', filepath)


        else:
            with open(filepath, 'rb') as f:
                x = pickle.load(f)
                train_captions, test_captions = x[0], x[1]
                idx2word, word2idx = x[2], x[3]
                del x
                n_words = len(idx2word)
                print('Load from: ', filepath)

        if split == 'train':
            filenames = train_names
            captions = train_captions
        
        else:
            filenames = test_names
            captions = test_captions
        
        return filenames, captions, idx2word, word2idx, n_words

    def build_dictionary(self, train_captions, test_captions):
        word_counts = defaultdict(float)
        captions = train_captions + test_captions
        for sent in captions:
            for word in sent:
                word_counts[word] += 1

        vocab = [w for w in word_counts if word_counts[w] >= 0]

        idx2word = {}
        idx2word[0] = '<end>'

        word2idx = {}
        word2idx['<end>'] = 0
        ix = 1
        for w in vocab:
            word2idx[w] = ix
            idx2word[ix] = w
            ix += 1

        # Add Begin of Sentence token
        # Remove this token when adapt this weights to the main model
        # word2idx['<bos>'] = ix
        # idx2word[ix] = '<bos>'

        train_captions_new = []
        for t in train_captions:
            rev = []
            for w in t:
                if w in word2idx:
                    rev.append(word2idx[w])
            rev.append(0)
            train_captions_new.append(rev)

        test_captions_new = []
        for t in test_captions:
            rev = []
            for w in t:
                if w in word2idx:
                    rev.append(word2idx[w])
            rev.append(0)
            test_captions_new.append(rev)

        return train_captions_new, test_captions_new, idx2word, word2idx, len(idx2word)


    def load_captions(self, data_dir, filenames):
        all_captions = []
        for i in range(len(filenames)):
            cap_path = '%s/text/%s.txt' % (data_dir, filenames[i])
            print("cap_path", cap_path)

            with open(cap_path, "r", encoding='utf-8') as f:
                print(f.read())
                print(type(f.read()))
                captions = f.read().split('\n')
                cnt = 0
                for cap in captions:
                    if len(cap) == 0:
                        continue
                    cap = cap.replace("\ufffd\ufffd", " ")
                    # picks out sequences of alphanumeric characters as tokens
                    # and drops everything else
                    tokenizer = RegexpTokenizer(r'\w+')
                    tokens = tokenizer.tokenize(cap.lower())
                    # print('tokens', tokens)
                    if len(tokens) == 0:
                        print('cap', cap)
                        continue

                    tokens_new = []
                    for t in tokens:
                        t = t.encode('ascii', 'ignore').decode('ascii')
                        if len(t) > 0:
                            tokens_new.append(t)
                    all_captions.append(tokens_new)
                    cnt += 1
            print('cap cnt', cnt)

        return all_captions

    def load_filenames(self, data_dir, split):
        filepath = '%s/filenames/%s/filenames.pickle' % (data_dir, split)
        print("filepath : ", filepath)

        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                filenames = pickle.load(f)
            print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
        else:
            filenames = []
        
        return filenames

    def get_caption(self, sent_idx):
        sent_caption = np.asarray(self.captions[sent_idx]).astype('int64')
        
        # Ignore this warning
        if (sent_caption == 0).sum() > 0:
            print('ERROR: do not need END (0) token', sent_caption)

        num_words = len(sent_caption)
        # pad with 0s (i.e., '<end>')
        x = np.zeros((self.words_num), dtype='int64')
        # x = np.zeros((self.words_num, 1), dtype='int64')

        x_len = num_words
        if num_words <= self.words_num:
            x[:num_words] = sent_caption

        else: # For LM pretraining 
            x[:] = sent_caption[:self.words_num]
            x_len = self.words_num

        return x, x_len
 
    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        data_dir = self.data_dir
        caps, cap_len = self.get_caption(idx)
        return caps, cap_len

In [11]:
hr_dataset = Text_Dataset(data_dir='data/bird/',
                              split='train',
                              words_num=15,
                              print_shape=False)
# (self, data_dir, split, words_num, print_shape=False)

filepath :  data/bird//filenames/train/filenames.pickle
Load filenames from: data/bird//filenames/train/filenames.pickle (8855)
filepath :  data/bird//filenames/val/filenames.pickle
Load filenames from: data/bird//filenames/val/filenames.pickle (2933)
cap_path data/bird//text/002.Laysan_Albatross/Laysan_Albatross_0002_1027.txt
a bird with a very long wing span and a long pointed beak.
the long-beaked bird has a white body with long brown wings.
this is a white bird with brown wings and a large pointy beak.
this large bird has long bill, a white breast, belly & head and a black back & wings.
bird has an extremely long wingspan with a darker top and white belly and head.
this bird has wings that are brown and has a white belly
this bird has extended wings and a white head and body.
this bird is white and brown in color, with a long curved beak.
this white and grey bird has an enormous wing span.
this bird has wings that are brown and has a white body

<class 'str'>


AttributeError: 'str' object has no attribute 'decode'

In [46]:
hr_dataset.captions

[]

In [52]:
hr_dataloader = torch.utils.data.DataLoader(dataset=hr_dataset,
                                            batch_size=32,
                                            drop_last=True,
                                            shuffle=True,
                                            num_workers=0)

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [31]:
iter(hr_dataloader).next()

IndexError: list index out of range

In [None]:
# class Text_Dataset(data.Dataset):
#     def __init__(self, data_dir, split, device,
#                  words_num, transform=None, print_shape=False):

#         self.words_num = words_num
#         self.data_dir = data_dir
#         self.split = split
#         self.embeddings_num = captions_per_image
#         self.device = device

#         self.captions, self.idx2word, self.word2idx, self.n_word = \
#             self.load_text_data(data_dir, split)

#     def load_text_data(self, data_dir, split):
#         filepath = os.path.join(data_dir, 'captions.pickle')

#         if not os.path.isfile(filepath):
#             captions = self.load_captions(data_dir)

#             captions, idx2word, word2idx, n_words = \
#                 self.build_dictionary(captions)
#             with open(filepath, 'wb') as f:
#                 pickle.dump([captions, idx2word, word2idx], f, protocol=2)
#                 if self.print_shape==True:
#                     print('Save to: ', filepath)

#         else:
#             with open(filepath, 'rb') as f:
#                 x = pickle.load(f)
#                 captions, idx2word, word2idx = x[0], x[1], x[2]
#                 del x
#                 n_words = len(idx2word)
#                 if self.print_shape == True:
#                     print('Load from: ', filepath)

#         return captions, idx2word, word2idx, n_words

#     def build_dictionary(self, captions):
#         word_counts = defaultdict(float)
#         for sent in captions:
#             for word in sent:
#                 word_counts[word] += 1

#         vocab = [w for w in word_counts if word_counts[w] >= 0]

#         idx2word = {}
#         idx2word[0] = '<end>'
#         word2idx = {}
#         word2idx['<end>'] = 0
#         ix = 1
#         for w in vocab:
#             word2idx[w] = ix
#             idx2word[ix] = w
#             ix += 1

#         captions_new = []
#         for t in captions:
#             rev = []
#             for w in t:
#                 if w in word2idx:
#                     rev.append(word2idx[w])
#             rev.append(0)
#             captions_new.append(rev)

#         return captions_new, idx2word, word2idx, len(idx2word)


#     def load_captions(self, data_dir):
#         text_list = glob.glob(os.path.join(data_dir, '*/*.txt'))
#         all_captions = []
#         for cap_path in tqdm.tqdm(text_list):
#             # cap_path = '%s/text/%s.txt' % (data_dir, filenames[i])
#             with open(cap_path, "r") as f:
#                 captions = f.read().decode('utf8').split('\n')
#                 cnt = 0
#                 for cap in captions:
#                     if len(cap) == 0:
#                         continue
#                     cap = cap.replace("\ufffd\ufffd", " ")
#                     # picks out sequences of alphanumeric characters as tokens
#                     # and drops everything else
#                     tokenizer = RegexpTokenizer(r'\w+')
#                     tokens = tokenizer.tokenize(cap.lower())
#                     # print('tokens', tokens)
#                     if len(tokens) == 0:
#                         if self.print_shape == True:
#                             print('cap', cap)
#                         continue

#                     tokens_new = []
#                     for t in tokens:
#                         t = t.encode('ascii', 'ignore').decode('ascii')
#                         if len(t) > 0:
#                             tokens_new.append(t)
#                     all_captions.append(tokens_new)
#                     cnt += 1
#                     if cnt == self.embeddings_num:
#                         break
#                 if cnt < self.embeddings_num:
#                     print('ERROR: the captions for %s less than %d'
#                           % (filenames[i], cnt))
#         return all_captions

#     def get_caption(self, sent_idx):
#         sent_caption = np.asarray(self.captions[sent_idx]).astype('int64')

#         if (sent_caption == 0).sum() > 0:
#             print('ERROR: do not need END (0) token', sent_caption)

#         num_words = len(sent_caption)
#         # pad with 0s (i.e., '<end>')
#         x = np.zeros((self.words_num), dtype='int64')
#         # x = np.zeros((self.words_num, 1), dtype='int64')

#         x_len = num_words
#         if num_words <= self.words_num:
#             x[:num_words] = sent_caption
#         else:
#             ix = list(np.arange(num_words))  # 1, 2, 3,..., maxNum
#             np.random.shuffle(ix)
#             ix = ix[:self.words_num]
#             ix = np.sort(ix)
#             x[:] = sent_caption[ix]
#             x_len = self.words_num
#         return x, x_len

#     def __len__(self):
#         return len(self.filenames)

#     def __getitem__(self, idx):

#         while True:
#             wrong_idx = random.randint(0, self.__len__())
#             if idx != wrong_idx:
#                 break

#         data_dir = self.data_dir

#         sent_ix = random.randint(0, self.embeddings_num)
#         new_sent_ix = idx * self.embeddings_num + sent_ix
#         caps, cap_len = self.get_caption(new_sent_ix)

#         return caps, cap_len