In [2]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils import data

In [3]:
class Dataset(data.Dataset):
    """Map-style dataset serving concatenated title/description rows.

    All inputs are passed as keyword arguments:

    title:     array indexed by sample along axis 0; ``title[idx]`` must be
               accepted by ``torch.from_numpy``.
    desc:      array indexed the same way as ``title``.
    test:      optional flag (default ``False``).  When truthy the dataset
               is label-free and ``__getitem__`` returns only the features.
    class_num: length of the target vector.  Required unless ``test``.
    label:     array indexed by sample whose rows hold class indices.
               Required unless ``test``.

    Raises ``KeyError`` if ``title``/``desc`` are missing, or if
    ``class_num``/``label`` are missing while ``test`` is falsy.
    """

    def __init__(self, **datas):
        # Normalise the flag to a real bool.  The previous ``is False``
        # identity checks treated falsy non-bool values (0,
        # numpy.bool_(False), ...) as test mode and silently dropped labels.
        self.test = bool(datas.get('test', False))

        self.title = datas['title']
        self.desc = datas['desc']

        # Labels are only needed (and only read) outside test mode.
        if not self.test:
            self.n_classes = datas['class_num']
            self.label = datas['label']

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        The features are ``title[idx]`` followed by ``desc[idx]``, joined
        with ``torch.cat``.  In test mode only that tensor is returned;
        otherwise a ``(features, target)`` tuple is returned, where
        ``target`` is a ``torch.zeros(n_classes)`` vector with 1 written at
        every index listed in ``label[idx]``.
        """
        features = torch.cat((
            torch.from_numpy(self.title[idx]),
            torch.from_numpy(self.desc[idx]),
        ))
        if self.test:
            return features

        # scatter_ needs an int64 index tensor, hence the .long() cast.
        label_idx = torch.from_numpy(self.label[idx]).long()
        target = torch.zeros(self.n_classes).scatter_(0, label_idx, 1)
        return features, target

    def __len__(self):
        """Number of samples, taken from the first axis of ``title``."""
        return self.title.shape[0]

In [4]:
# Read the embedding matrix and the training-split arrays from .npy files
# (contents inferred from the file names: title/description word indices
# and label indices).  The files are loaded in the order listed.
word_embed_mat, train_title_word, train_desc_word, train_label = map(np.load, (
    '../../data_preprocess/embed/word_embed_mat.npy',
    '../../data_preprocess/train/train_title_word_indices.npy',
    '../../data_preprocess/train/train_desc_word_indices.npy',
    '../../data_preprocess/train/train_label_indices.npy',
))

In [5]:
# Labelled training dataset (1999-way target vectors), served in shuffled
# mini-batches of 64 samples.
train_dataset = Dataset(
    class_num=1999,
    label=train_label,
    title=train_title_word,
    desc=train_desc_word,
)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

In [8]:
# Sanity check: show the feature and target shapes of the first training
# batch, then stop.  Written as a single-argument print(...) call so the
# cell runs and prints the same text on both Python 2 and Python 3.
for i, datas in enumerate(train_loader):
    print('{} {}'.format(datas[0].size(), datas[1].size()))
    break

torch.Size([64, 71]) torch.Size([64, 1999])


In [6]:
# Read the validation-split title and description arrays (word indices,
# judging by the file names) from .npy files, title first.
val_title_word, val_desc_word = map(np.load, (
    '../../data_preprocess/val/val_title_word_indices.npy',
    '../../data_preprocess/val/val_desc_word_indices.npy',
))

In [7]:
# Label-free validation dataset (test=True), served in order in
# mini-batches of 64 samples.
val_dataset = Dataset(
    title=val_title_word,
    desc=val_desc_word,
    test=True,
)
val_loader = data.DataLoader(dataset=val_dataset, batch_size=64)

In [8]:
# Sanity check: show the shape of the first validation batch, then stop.
# print(...) with a single argument behaves identically on Python 2 and 3.
for i, datas in enumerate(val_loader):
    print(datas.size())
    break

torch.Size([64, 71])


In [9]:
# Read the "train_all" arrays from .npy files (title/description word
# indices and label indices, judging by the file names), in the order listed.
train_title_word_all, train_desc_word_all, train_label_all = map(np.load, (
    '../../data_preprocess/train_all/train_title_word_indices_all.npy',
    '../../data_preprocess/train_all/train_desc_word_indices_all.npy',
    '../../data_preprocess/train_all/train_label_indices_all.npy',
))

In [10]:
# Labelled dataset over the "train_all" arrays (1999-way target vectors),
# served in mini-batches of 64 samples.
train_dataset_all = Dataset(
    class_num=1999,
    label=train_label_all,
    title=train_title_word_all,
    desc=train_desc_word_all,
)
# NOTE(review): unlike train_loader, no shuffle=True is passed here, so
# batches come in storage order; confirm that is intended.
train_loader_all = data.DataLoader(dataset=train_dataset_all, batch_size=64)

In [11]:
# Sanity check: show the feature and target shapes of the first batch from
# train_loader_all, then stop.  Written as a single-argument print(...)
# call so the cell runs and prints the same text on Python 2 and Python 3.
for i, datas in enumerate(train_loader_all):
    print('{} {}'.format(datas[0].size(), datas[1].size()))
    break

torch.Size([64, 71]) torch.Size([64, 1999])


In [12]:
# Read the test-split title and description arrays (word indices, judging
# by the file names) from .npy files, title first.
test_title_word, test_desc_word = map(np.load, (
    '../../data_preprocess/test/test_title_word_indices.npy',
    '../../data_preprocess/test/test_desc_word_indices.npy',
))

In [13]:
# Label-free test dataset (test=True), served in order in mini-batches of
# 64 samples.
test_dataset = Dataset(
    title=test_title_word,
    desc=test_desc_word,
    test=True,
)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=64)

In [14]:
# Sanity check: show the shape of the first test batch, then stop.
# print(...) with a single argument behaves identically on Python 2 and 3.
for i, datas in enumerate(test_loader):
    print(datas.size())
    break

torch.Size([64, 71])
