In [1]:
import torch
import os
import io
import sys
import csv
import six
# from sklearn.feature_extraction.text import CountVectorizer
from collections import Counter
import itertools
import collections
import argparse

from torchtext.vocab import Vectors, Vocab
import gensim
from collections import defaultdict, OrderedDict

from model import IQN
from trainer import Trainer

In [2]:
def unicode_csv_reader(unicode_csv_data, **kwargs):
    """Yield parsed rows from a text CSV/TSV stream, allowing very large fields.

    Before reading, the ``csv`` module's field-size limit is raised to the
    largest of ``sys.maxsize``, ``sys.maxsize // 10``, ``sys.maxsize // 100``,
    ... that ``csv.field_size_limit`` accepts without ``OverflowError``. This
    avoids "field larger than field limit" errors on long token sequences.

    Args:
        unicode_csv_data: Iterable of text lines, e.g. a file opened in text mode.
        **kwargs: Forwarded unchanged to ``csv.reader`` (e.g. ``delimiter='\\t'``).

    Yields:
        list of str: The cells of one parsed row.

    Note:
        The field-size limit is global to the ``csv`` module, so it stays
        raised after this generator finishes.
    """
    max_int = sys.maxsize
    while True:
        try:
            csv.field_size_limit(max_int)
            break
        except OverflowError:
            # Shrink by a factor of 10 until the platform's C long accepts it.
            # Integer division keeps the value exact (float division does not
            # near sys.maxsize).
            max_int //= 10

    if sys.version_info[0] == 2:
        # Python 2's csv module cannot parse unicode: encode each line to
        # UTF-8 for parsing, then decode every cell back to unicode.
        encoded_lines = (line.encode('utf-8') for line in unicode_csv_data)
        for row in csv.reader(encoded_lines, **kwargs):
            yield [cell.decode('utf-8') for cell in row]
    else:
        for row in csv.reader(unicode_csv_data, **kwargs):
            yield row

In [3]:
class MyDataset(torch.utils.data.Dataset):
    """Tokenised (State, Next_State, Reward) rows of a TSV file, as tensors.

    Each non-empty row of the file is read positionally into the columns
    ``Date, Code, State, Next_State, Reward``:

    - ``State`` and ``Next_State`` are ``:``-separated token strings.
    - ``Reward`` is parsed as ``float``.
    - ``Date`` and ``Code`` are kept as raw strings.

    Items can only be fetched once a vocabulary exists. Either pass ``vocab``
    (e.g. the training split's vocab when building a test split) or call
    :meth:`build_vocab`.

    Args:
        path: Path of the tab-separated file; ``~`` is expanded.
        max_len: Length of every numericalised token sequence, including the
            ``<cls>`` and ``<eos>`` markers.
        vocab: Optional existing vocabulary (anything with a ``stoi``
            mapping). When given, the data is padded and numericalised
            immediately.
        specials: Extra special tokens reserved by :meth:`build_vocab`.
            ``None`` (the default) means no extra tokens.

    Raises:
        ValueError: If a non-empty row has fewer than five columns, or a
            ``Reward`` cell is not a valid float.
    """

    def __init__(self, path, max_len=64, vocab=None, specials=None):
        self.keys = ['Date', 'Code', 'State', 'Next_State', 'Reward']
        self.string_keys = ['State', 'Next_State']
        # Every key turned into a tensor by numericalize(): string keys become
        # LongTensors of vocab indices, 'Reward' becomes a FloatTensor.
        self.tensor_keys = ['Reward'] + self.string_keys
        self.data_list = {k: [] for k in self.keys}

        with io.open(os.path.expanduser(path), encoding="utf8") as f:
            reader = unicode_csv_reader(f, delimiter='\t')
            for row_no, line in enumerate(reader, start=1):
                if not line:
                    # Blank lines parse to [] and carry no data.
                    continue
                if len(line) < len(self.keys):
                    # zip() would silently drop the missing columns and leave
                    # the per-column lists with different lengths, misaligning
                    # every later example.
                    raise ValueError(
                        '{}: row {} has {} column(s), expected {}'.format(
                            path, row_no, len(line), len(self.keys)))
                for k, x in zip(self.keys, line):
                    if k in self.string_keys:
                        self.data_list[k].append(x.split(':'))
                    elif k in self.tensor_keys:
                        self.data_list[k].append(float(x))
                    else:
                        self.data_list[k].append(x)

        self.unk_token = '<unk>'
        self.pad_token = '<pad>'
        self.init_token = '<cls>'
        self.eos_token = '<eos>'
        self.max_len = max_len
        # Number of real tokens kept per sequence: max_len minus one slot for
        # each of <cls>/<eos> that is not None.
        self.fix_len = self.max_len + (self.init_token, self.eos_token).count(None) - 2
        # Copy, so a caller's list is never shared with (or mutated via) this instance.
        self.specials = list(specials) if specials is not None else []

        # Token counts for build_vocab() come from the 'State' column only.
        self.words = list(itertools.chain.from_iterable(self.data_list['State']))
        self.counter = Counter(self.words)

        self.vocab = None
        if vocab is not None:
            self.vocab = vocab
            self.padded_list = self.pad(self.data_list)
            self.tensor_list = self.numericalize(self.padded_list)

    def build_vocab(self, vectors, min_freq):
        """Build ``self.vocab`` from the 'State' token counts, then tensorise the data.

        The resulting vocabulary size determines the embedding-matrix size
        downstream. A checkpoint therefore only loads if the vocab is rebuilt
        exactly as during training (same data, specials and ``min_freq``).

        Args:
            vectors: Pre-trained word vectors passed through to ``Vocab``.
            min_freq: Minimum token count for inclusion in the vocabulary.
        """
        # Special tokens, de-duplicated while preserving order; listed first so
        # (presumably, per torchtext's legacy Vocab) <unk>=0 and <pad>=1.
        specials = list(OrderedDict.fromkeys(
            tok for tok in [self.unk_token, self.pad_token, self.init_token,
                            self.eos_token] + self.specials
            if tok is not None))
        self.vocab = Vocab(self.counter,
                           specials=specials,
                           vectors=vectors,
                           min_freq=min_freq)
        self.padded_list = self.pad(self.data_list)
        self.tensor_list = self.numericalize(self.padded_list)

    def pad(self, data):
        """Frame and pad every token sequence under a string key.

        Each sequence becomes
        ``[<cls>] + tokens[:fix_len] + [<eos>] + [<pad>] * (fix_len - len(tokens))``,
        i.e. exactly ``fix_len + 2`` tokens. Other keys are passed through as-is.

        Args:
            data: Mapping of column name to list of values, as in ``self.data_list``.

        Returns:
            dict: Same keys as ``self.keys`` with padded token lists.
        """
        padded = {k: [] for k in self.keys}
        for key, val in data.items():
            if key in self.string_keys:
                padded[key] = [
                    [self.init_token]
                    + list(x[:self.fix_len])
                    + [self.eos_token]
                    + [self.pad_token] * max(0, self.fix_len - len(x))
                    for x in val]
            else:
                padded[key] = val
        return padded

    def numericalize(self, padded):
        """Convert padded columns to tensors.

        - String keys become a CPU ``LongTensor`` with one row of vocab indices
          (looked up via ``self.vocab.stoi``) per example.
        - Other tensor keys become a CPU ``FloatTensor``.
        - Remaining keys (``Date``, ``Code``) stay Python lists.

        Args:
            padded: Output of :meth:`pad`.

        Returns:
            dict: Column name to tensor or list.
        """
        tensor = {k: [] for k in self.keys}
        for key, val in padded.items():
            if key in self.string_keys:
                # torch.LongTensor / torch.FloatTensor always allocate on the CPU.
                tensor[key] = torch.LongTensor(
                    [[self.vocab.stoi[tok] for tok in ex] for ex in val])
            elif key in self.tensor_keys:
                tensor[key] = torch.FloatTensor(val)
            else:
                tensor[key] = val
        return tensor

    def __len__(self):
        """Number of examples; only valid once a vocabulary has been set."""
        return len(self.tensor_list['State'])

    def __getitem__(self, i):
        """Return example ``i`` as a dict mapping each column name to its value."""
        return {key: self.tensor_list[key][i] for key in self.keys}

In [9]:
# All inputs for this experiment live under ../data/news.
NEWS_DATA_DIR = os.path.join('..', 'data', 'news')

train_ds = MyDataset(
    path=os.path.join(NEWS_DATA_DIR, 'text_train.tsv'),
    specials=['<company>', '<organization>', '<person>', '<location>']
)

japanese_fasttext_vectors = Vectors(name=os.path.join(NEWS_DATA_DIR, 'cc.ja.300.vec'))

# The embedding matrix has one row per vocab entry, so a saved checkpoint only
# loads if this vocab is rebuilt exactly as it was for training (same data,
# same specials, same min_freq).
train_ds.build_vocab(
    vectors=japanese_fasttext_vectors,
    min_freq=10)

# The model is configured with pad_idx=1 (argparse cell); fail fast if the
# vocabulary ever assigns <pad> a different index.
assert train_ds.vocab.stoi[train_ds.pad_token] == 1, 'unexpected <pad> index'

# The test split reuses the training vocabulary so token indices agree across splits.
test_ds = MyDataset(
    path=os.path.join(NEWS_DATA_DIR, 'text_test.tsv'),
    vocab=train_ds.vocab
)

In [11]:
# Settings shared by both loaders; only shuffling differs between splits.
# NOTE: batch_size duplicates --batch_size from the argparse cell; keep them in sync.
loader_kwargs = dict(batch_size=32, num_workers=2)

# Training batches are drawn in a new random order every epoch.
train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, **loader_kwargs)

# Evaluation batches follow file order.
test_dl = torch.utils.data.DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [12]:
parser = argparse.ArgumentParser(description=None)

# NOTE(review): the options below look inherited from a gym/Atari visualisation
# script; kept because other code (e.g. Trainer) may read them -- TODO confirm.
parser.add_argument('-e', '--env', default='PongNoFrameskip-v4', type=str, help='gym environment')
parser.add_argument('-d', '--density', default=1, type=int, help='density of grid of gaussian blurs')
parser.add_argument('-r', '--radius', default=5, type=int, help='radius of gaussian blur')
parser.add_argument('-f', '--num_frames', default=100, type=int, help='number of frames in movie')
parser.add_argument('-i', '--first_frame', default=150, type=int, help='index of first frame')
parser.add_argument('-dpi', '--resolution', default=75, type=int, help='resolution (dpi)')
parser.add_argument('-s', '--save_dir', default='./movies/', type=str,
                    help='dir to save agent logs and checkpoints')
parser.add_argument('-p', '--prefix', default='default', type=str, help='prefix to help make video name unique')
parser.add_argument('-c', '--checkpoint', default='*.tar', type=str,
                    help='checkpoint name (in case there is more than one')
# A real on/off flag: `type=bool` would turn any non-empty string, including
# "False", into True.
parser.add_argument('-o', '--overfit_mode', action='store_true',
                    help='analyze an overfit environment (see paper)')

# text parameter
parser.add_argument('--max_length', type=int, default=64)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--target_update_freq', type=int, default=100)
parser.add_argument('--evaluation_freq', type=int, default=10)
parser.add_argument('--network_save_freq', type=int, default=100)
parser.add_argument('--num_actions', type=int, default=1)

parser.add_argument('--min_freq', type=int, default=10)
parser.add_argument('--embedding_dim', type=int, default=300)
parser.add_argument('--n_filters', type=int, default=50)
# One int per filter width, e.g. `--filter_sizes 3 4 5` (type=list would split
# a single string into characters).
parser.add_argument('--filter_sizes', type=int, nargs='+', default=[3, 4, 5])
# Index of <pad> in the vocabulary.
parser.add_argument('--pad_idx', type=int, default=1)
parser.add_argument('--gamma', type=float, default=0.0)
parser.add_argument('--learning_rate', type=float, default=2.5e-5)
parser.add_argument('--round', type=float, default=0)

parser.add_argument('--num_quantile', type=int, default=32)
parser.add_argument('--rnn', action='store_true',
                    help='use the recurrent text encoder instead of the default')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--device', type=str, default=device)

# Parse defaults only (no CLI inside a notebook).
args = parser.parse_args(args=[])

In [13]:
# Device for the standalone model below (a plain string, unlike the
# torch.device stored in args.device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Embedding weights and size come from the training vocabulary.
text_vectors = train_ds.vocab.vectors
vocab_size = len(text_vectors)

model = IQN(
    text_vectors, vocab_size, args.embedding_dim, args.n_filters,
    args.filter_sizes, args.pad_idx,
    n_actions=args.num_actions,
    n_quant=args.num_quantile,
    rnn=args.rnn,
).to(device)

# NOTE(review): `model` is not handed to Trainer, which presumably builds its
# own IQN from (args, text_vectors, vocab_size) -- confirm whether `model` is
# used later.
# NOTE(review): load_model() fails with an embedding size mismatch (checkpoint
# 2638 rows vs 1493 here), so the checkpoint was trained with a different
# vocabulary. Rebuild the vocab exactly as at training time, or save the
# training vocab with the checkpoint and reuse it here.
trainer = Trainer(args, text_vectors, vocab_size, train_dl)
trainer.load_model()
trainer.model.eval()

RuntimeError: Error(s) in loading state_dict for IQN:
	size mismatch for embedding.weight: copying a param with shape torch.Size([2638, 300]) from checkpoint, the shape in current model is torch.Size([1493, 300]).