In [5]:
import torch
import torch.nn as nn
from GPT2 import GPT2Model, GPT2Tokenizer
import os
# Blank out CUDA_VISIBLE_DEVICES so this process is meant to see no GPUs.
# NOTE(review): this runs after `import torch`; presumably it still takes
# effect because CUDA is initialised lazily -- confirm, or move it above the
# torch import.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# Device string later passed to `model.to(...)`; the trailing comment records
# the GPU alternative.
device = 'cpu' #'cuda'

def tokenize_input(inputStr, tokenizer, seq_length=1024):
    """Encode ``inputStr`` and right-pad it into a fixed-length batch of one.

    Args:
        inputStr: Text to encode.
        tokenizer: Object exposing ``encode(text)`` (returning a sequence of
            token ids) and an ``encoder`` mapping that contains ``'<pad>'``.
        seq_length: Total length of the padded sequence.

    Returns:
        A ``(tokens, lengths)`` pair. ``tokens`` is a ``torch.long`` tensor of
        shape ``(1, seq_length)``; ``lengths`` is a one-element list holding
        the number of non-padding tokens, which is also the index of the
        first padding position.
    """
    pad_id = tokenizer.encoder['<pad>']
    # The last 20 positions are always left as padding. Clamp at 0 so a small
    # seq_length cannot turn the bound into a negative (tail-trimming) slice.
    max_tokens = max(seq_length - 20, 0)
    tokens = list(tokenizer.encode(inputStr)[:max_tokens])
    token_length = len(tokens)
    tokens.extend([pad_id] * (seq_length - token_length))
    tokens = torch.tensor(tokens, dtype=torch.long)
    # Reshape with seq_length (not a hard-coded 1024) so the parameter is
    # actually honoured.
    return tokens.reshape(1, seq_length), [token_length]

# Build the tokenizer from the two files under GPT2/bpe/ (a JSON vocabulary
# and a `.model` file).
# NOTE(review): max_len=512 here, while tokenize_input pads to 1024 by
# default -- confirm what max_len controls in GPT2Tokenizer.
tokenizer = GPT2Tokenizer(
    'GPT2/bpe/vocab.json',
    'GPT2/bpe/chinese_vocab.model',
    max_len=512)
    
class GPT2classification(nn.Module):
    """GPT-2 backbone followed by an MLP head that emits three scores per input.

    For each sequence in the batch, the backbone output at one chosen position
    is fed through the MLP. The head's input width (30000) equals the
    backbone's ``vocab_size``, so the backbone presumably returns per-position
    vocabulary scores -- TODO confirm against ``GPT2Model``.
    """

    def __init__(self):
        super().__init__()

        # Shared by the backbone's vocabulary and the head's input width.
        vocab_size = 30000

        self.GPT2model = GPT2Model(
            vocab_size=vocab_size,
            layer_size=12,
            block_size=1024,
            embedding_dropout=0.0,
            embedding_size=768,
            num_attention_heads=12,
            attention_dropout=0.0,
            residual_dropout=0.0,
        )

        # Head: vocab_size -> 512 -> 256 -> 3, with ReLU between the layers.
        head_layers = [
            nn.Linear(vocab_size, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 3),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, x, length):
        """Score each sequence in ``x`` using one position per sequence.

        Args:
            x: Batch passed unchanged to the backbone.
            length: One index per batch row; row ``i`` is read at position
                ``length[i]`` along the second dimension of the backbone
                output.

        Returns:
            The MLP output for the stacked per-row vectors (one row per entry
            of ``length``).
        """
        backbone_out = self.GPT2model(x)
        # Pick one position per row and flatten it to a 1-D feature vector.
        picked = [
            backbone_out[row, pos].view(-1)
            for row, pos in enumerate(length)
        ]
        return self.mlp(torch.stack(picked))

# Load the saved object from disk; the second positional argument is
# torch.load's map_location, so tensors are placed on the CPU.
# NOTE(review): the result is used as a module directly (.eval(), .to()), so
# the file appears to hold a whole pickled model rather than a state_dict.
# Unpickling executes code from the file -- only load checkpoints you trust.
# Presumably GPT2classification must be defined before this line for the
# unpickling to resolve, and newer torch releases may need
# `weights_only=False` here -- confirm against the installed torch version.
model = torch.load("./models/financial_sentiment.pth", 'cpu')
# Switch to evaluation mode for inference.
model.eval()
# Move the model to the configured device.
model.to(device)

In [15]:
# Example input (roughly: "this stock will probably drop sharply").
inputStr = '这股票估计会大跌'  # the text you want to classify

tokens, token_length = tokenize_input(inputStr, tokenizer, seq_length=1024)

# Inference only, so skip autograd bookkeeping. The input is moved to the
# same device as the model so that changing `device` does not break the call.
with torch.no_grad():
    output = model(tokens.to(device), token_length)
    # Turn the three raw scores into probabilities along the class axis.
    output = torch.softmax(output, dim=1)

# Bring the result back to the CPU before converting to NumPy; the label
# string is kept exactly as in the original output.
print('negative neutral possitive:', output[0].detach().cpu().numpy())

negative neutral possitive: [0.74180526 0.1355453  0.12264954]
