In [1]:
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
'''
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import torch.nn.functional as F
import collections
from d2l import torch as d2l
dtype = torch.FloatTensor
device = torch.device("mps")

In [2]:
# 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.


In [3]:
tokens = [sentence.split() for sentence in sentences]

In [4]:
def get_tokens(tokens):
    tokens =  [token for line in tokens for token in line]
    return tokens

In [5]:
tokens = get_tokens(tokens)

In [6]:
def count_corpus(tokens):
    return collections.Counter(tokens)

In [7]:
class Vocab:
    def __init__(self,tokens,min_freq=0):
        self.id2tokens = []
        self.token2ids = {}
        for token,freq in count_corpus(tokens).items():
            if freq<min_freq:
                break
            if token not in self.token2ids:
                self.id2tokens.append(token)
                self.token2ids[token] = len(self.id2tokens)
    def __len__(self):
        return len(self.id2tokens)
    def __getitem__(self,tokens):
        return [self.token2ids.get(token,self.unk) for token in tokens]
    @property
    def unk(self):  # 未知词元的索引为0
        return 0 
    @property
    def token_freqs(self):
        return self._token_freqs

In [8]:
vocab = Vocab(tokens)

In [9]:
tokens = [sentence.split() for sentence in sentences]
features = [vocab[token] for token in tokens]
features, targets = torch.LongTensor(features), torch.LongTensor(labels)
dataset = torch.utils.data.TensorDataset(features,targets)
dataloader =  torch.utils.data.DataLoader(dataset, 3, True)

In [10]:
class TextCNN(nn.Module):
    def __init__(self, vocab_size, embed_size, kernel_sizes, num_channels,
                 **kwargs):
        super(TextCNN, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # 这个嵌入层不需要训练
        self.constant_embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.decoder = nn.Linear(sum(num_channels), 2)
        # 最大时间汇聚层没有参数，因此可以共享此实例
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.relu = nn.ReLU()
        # 创建多个一维卷积层
        self.convs = nn.ModuleList()
        for c, k in zip(num_channels, kernel_sizes):
            self.convs.append(nn.Conv1d(2 * embed_size, c, k))

    def forward(self, inputs):
        # 沿着向量维度将两个嵌入层连结起来，
        # 每个嵌入层的输出形状都是（批量大小，词元数量，词元向量维度）连结起来
        embeddings = torch.cat((
            self.embedding(inputs), self.constant_embedding(inputs)), dim=2)
        # 根据一维卷积层的输入格式，重新排列张量，以便通道作为第2维
        embeddings = embeddings.permute(0, 2, 1)
        # 每个一维卷积层在最大时间汇聚层合并后，获得的张量形状是（批量大小，通道数，1）
        # 删除最后一个维度并沿通道维度连结
        encoding = torch.cat([
            torch.squeeze(self.relu(self.pool(conv(embeddings))), dim=-1)
            for conv in self.convs], dim=1)
        outputs = self.decoder(self.dropout(encoding))
        return outputs

In [11]:
embed_size, kernel_sizes, nums_channels = 100, [2, 2, 2], [1, 1, 1]
devices = d2l.try_all_gpus()
net = TextCNN(16, embed_size, kernel_sizes, nums_channels)

def init_weights(m):
    if type(m) in (nn.Linear, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights);

In [12]:
device = torch.device('mps')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.1)
# # Training
for epoch in range(20):
    for batch_x, batch_y in dataloader:
        print(batch_x.shape)
#         break
        pred = net(batch_x)
        loss = criterion(pred, batch_y)
        #     if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

torch.Size([3, 3])
Epoch: 0001 loss = 1.539051
torch.Size([3, 3])


IndexError: index out of range in self