In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [4]:
class STModel(nn.Module):
    
    def __init__(self, emd_dim, vocab_size, classes_num, hidden_num, layer_num, embeddings = None, padding_idx = 0,dropout=0.5):
        super(STModel, self).__init__()
        self.encoder = nn.Embedding(vocab_size, emd_dim, padding_idx= padding_idx, _weight = embeddings)
        self.drop = nn.Dropout(dropout)
        self.lstm = nn.LSTM(emd_dim, hidden_num, layer_num, dropout= dropout,
                            bidirectional= False)
        self.fc = nn.Linear(emd_dim, classes_num)
        self.init_weights()
        self.sigm = nn.Sigmoid()
        self.hidden_num = hidden_num
        
    def forward(self, x, hidden):
        batch_size = x.size(0)
        emb = self.drop(self.encoder(x))
        lstm_out, hidden = self.lstm(emb, hidden)
        lstm_out = self.drop(lstm_out) #[seq_len, batch, num_directions * hidden_size];
        fc_input = lstm_out[-1]
        fc_out = self.fc(lstm_out)
        sig_out = self.sigm(fc_out)
        pred = sig_out.view(batch_size, -1)
        return sig_out
    
    def init_weights(self):
        initrange = 0.1
        #self.encoder.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)
        
    def init_hidden(self, batch_size):
        weight = next(self.parameters())
        return (weight.new_zeros(self.layer_num, batch_size, self.hidden_num),
                    weight.new_zeros(self.layer_num, batch_size, self.hidden_num))
        

In [15]:
emd_dim = 300
vocab_size = 100
classes_num = 3
hidden_num = 100
layer_num = 2
dropout = 0.5
net = STModel(emd_dim, vocab_size, classes_num, hidden_num, layer_num, dropout=0.5)

In [16]:
print(net)

STModel(
  (encoder): Embedding(100, 300)
  (drop): Dropout(p=0.5, inplace=False)
  (lstm): LSTM(300, 100, num_layers=2, dropout=0.5)
  (fc): Linear(in_features=300, out_features=3, bias=True)
  (sigm): Sigmoid()
)


In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [19]:
import torch.optim as optim

SAVE_PATH = 'STMmodel.pt'
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

clip = 5
epochs = 5


cpu


In [None]:
for epoch in range(epochs):
    running_loss = 0.0
    batch_count = 0
    for i, train_data in enumerate(trainloader):
        train_input, train_laebl = train_data
        # forward
        hidden = model.init_hidden(batch_size) # 有问题
        model.zero_grad()
        outputs = model(train_input, hidden)
        loss = criterion(outputs, labels)
        # backward
        loss.backward()
        # optimize
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
            torch.save(net.state_dict(), PATH)
            with open(SAVE_PATH, 'wb') as f:
                torch.save(model, f)
print('finished training')

In [1]:
import pandas as pd

In [3]:
data_train = pd.read_csv('data/train_weibo_clean.csv')


In [6]:
data_train

Unnamed: 0.1,Unnamed: 0,微博id,微博发布时间,发布人账号,微博中文内容,微博图片,微博视频,情感倾向
0,0,4456072029125500,01月01日 23:50,存曦1988,写在年末冬初孩子流感的第五天，我们仍然没有忘记热情拥抱这2020年的第一天。带着一丝迷信，早...,['https://ww2.sinaimg.cn/orj360/005VnA1zly1gah...,[],0
1,1,4456074167480980,01月01日 23:58,LunaKrys,开年大模型…累到以为自己发烧了腰疼膝盖疼腿疼胳膊疼脖子疼#Luna的Krystallife#?,[],[],-1
2,2,4456054253264520,01月01日 22:39,小王爷学辩论o_O,邱晨这就是我爹，爹，发烧快好，毕竟美好的假期拿来养病不太好，假期还是要好好享受快乐，爹，新...,['https://ww2.sinaimg.cn/thumb150/006ymYXKgy1g...,[],1
3,3,4456061509126470,01月01日 23:08,芩鎟,新年的第一天感冒又发烧的也太衰了但是我要想着明天一定会好的?,['https://ww2.sinaimg.cn/orj360/005FL9LZgy1gah...,[],1
4,4,4455979322528190,01月01日 17:42,changlwj,问：我们意念里有坏的想法了，天神就会给记下来，那如果有好的想法也会被记下来吗？答：那当然了。...,[],[],1
...,...,...,...,...,...,...,...,...
99908,99995,4473033438259880,02月17日 19:08,中国教育新闻网,#抗击新型肺炎第一线#【@中国计量大学研制新冠病毒蛋白标准样品】记者从中国计量大学获悉，新型...,['https://ww1.sinaimg.cn/orj360/682cebefly1gbz...,[],0
99909,99996,4472969222714290,02月17日 14:53,fuzhuoting,1、类RaTG13病毒（一种从云南蝙蝠身上分离出来的冠状病毒）可能是2019-nCoV的源头...,[],[],0
99910,99997,4473035904435920,02月17日 19:18,蝌蚪五线谱,#微博辟谣#没有证据表明，吃大蒜、漱口水、涂抹芝麻油、生理盐水洗鼻子等手段可以防止感染新型冠...,['https://ww4.sinaimg.cn/orj360/6d2cc4e6ly1gbz...,[],0
99911,99998,4472950743017610,02月17日 13:40,医库,【新冠疫情最受关注的十一篇英文核心期刊论文全解析】本文整理了关于新型冠状病毒最受关注的十一篇...,[],[],1


In [2]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/13/33/ffb67897a6985a7b7d8e5e7878c3628678f553634bd3836404fef06ef19b/transformers-2.5.1-py3-none-any.whl (499kB)
[K     |████████████████████████████████| 501kB 42kB/s eta 0:00:011
[?25hCollecting tokenizers==0.5.2 (from transformers)
[?25l  Downloading https://files.pythonhosted.org/packages/d6/e3/5e49e9a83fb605aaa34a1c1173e607302fecae529428c28696fb18f1c2c9/tokenizers-0.5.2-cp37-cp37m-manylinux1_x86_64.whl (5.6MB)
[K     |████████████████████████████████| 5.6MB 11kB/s eta 0:00:0163
[?25hCollecting sentencepiece (from transformers)
[?25l  Downloading https://files.pythonhosted.org/packages/11/e0/1264990c559fb945cfb6664742001608e1ed8359eeec6722830ae085062b/sentencepiece-0.1.85-cp37-cp37m-manylinux1_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 11kB/s eta 0:00:016
Collecting sacremoses (from transformers)
[?25l  Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d6305