In [27]:
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
'''
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import torch.nn.functional as F
import collections
dtype = torch.FloatTensor
device = torch.device("mps")

In [28]:
# 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.


In [29]:
tokens = [sentence.split() for sentence in sentences]

In [30]:
def get_tokens(tokens):
    tokens =  [token for line in tokens for token in line]
    return tokens

In [31]:
tokens = get_tokens(tokens)

In [32]:
def count_corpus(tokens):
    return collections.Counter(tokens)

In [33]:
class Vocab:
    def __init__(self,tokens,min_freq=0):
        self.id2tokens = []
        self.token2ids = {}
        for token,freq in count_corpus(tokens).items():
            if freq<min_freq:
                break
            if token not in self.token2ids:
                self.id2tokens.append(token)
                self.token2ids[token] = len(self.id2tokens)
    def __len__(self):
        return len(self.id2tokens)
    def __getitem__(self,tokens):
        return [self.token2ids.get(token,self.unk) for token in tokens]
    @property
    def unk(self):  # 未知词元的索引为0
        return 0 
    @property
    def token_freqs(self):
        return self._token_freqs

In [34]:
vocab = Vocab(tokens)

In [35]:
tokens = [sentence.split() for sentence in sentences]
features = [vocab[token] for token in tokens]
features, targets = torch.LongTensor(features), torch.LongTensor(labels)
dataset = torch.utils.data.TensorDataset(features,targets)
dataloader =  torch.utils.data.DataLoader(dataset, 3, True)

In [36]:
for x,y in dataloader:
    print(x)
    print(y)

tensor([[ 4,  5,  6],
        [ 7,  8,  9],
        [11, 12, 13]])
tensor([1, 1, 0])
tensor([[ 1, 10,  3],
        [14, 15, 16],
        [ 1,  2,  3]])
tensor([0, 0, 1])


In [37]:
class TextCNN(torch.nn.Module):
    def __init__(self,vocab_size, embedding_size):
        super(TextCNN, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_size)
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1,3,kernel_size=(2,embedding_size)),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((2, 1)),
            torch.nn.Flatten()
        )
        self.fc = torch.nn.Linear(3,2)
    
    def forward(self,X):
        batch_size = X.shape[0]
        embedding = self.embedding(X)
        embedding = embedding.unsqueeze(1)
        conv = self.conv(embedding)
        print(conv.shape)
#         conv = conv.flatten()
#         conv = conv.view(batch_size, -1)
#         print(conv.shape)
        return self.fc(conv)
    

In [38]:
conv = torch.nn.Sequential(
            torch.nn.Conv2d(1,3,kernel_size=(2,2)),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d((2, 1)),
        )
X = torch.randn(1,1,3,2)
for blk in conv:
    X = blk(X)
    print(blk.__class__.__name__,X.shape)

Conv2d torch.Size([1, 3, 2, 1])
ReLU torch.Size([1, 3, 2, 1])
MaxPool2d torch.Size([1, 3, 1, 1])


In [39]:
device = torch.device('mps')
model = TextCNN(16,2).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.1)
# # Training
for epoch in range(20):
    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        pred = model(batch_x)
        loss = criterion(pred, batch_y)
        #     if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

torch.Size([3, 3])
Epoch: 0001 loss = 0.862486
torch.Size([3, 3])
Epoch: 0001 loss = 0.731477
torch.Size([3, 3])
Epoch: 0002 loss = 0.732072
torch.Size([3, 3])
Epoch: 0002 loss = 0.481238
torch.Size([3, 3])
Epoch: 0003 loss = 0.536991
torch.Size([3, 3])
Epoch: 0003 loss = 0.443261
torch.Size([3, 3])
Epoch: 0004 loss = 0.408995
torch.Size([3, 3])
Epoch: 0004 loss = 0.337583
torch.Size([3, 3])
Epoch: 0005 loss = 0.291860
torch.Size([3, 3])
Epoch: 0005 loss = 0.132564
torch.Size([3, 3])
Epoch: 0006 loss = 0.160632
torch.Size([3, 3])
Epoch: 0006 loss = 0.061674
torch.Size([3, 3])
Epoch: 0007 loss = 0.012779
torch.Size([3, 3])
Epoch: 0007 loss = 0.124808
torch.Size([3, 3])
Epoch: 0008 loss = 0.084493
torch.Size([3, 3])
Epoch: 0008 loss = 0.003014
torch.Size([3, 3])
Epoch: 0009 loss = 0.001890
torch.Size([3, 3])
Epoch: 0009 loss = 0.011953
torch.Size([3, 3])
Epoch: 0010 loss = 0.004757
torch.Size([3, 3])
Epoch: 0010 loss = 0.002494
torch.Size([3, 3])
Epoch: 0011 loss = 0.001838
torch.Size([3