In [None]:
!pip3 install optuna
!pip3 install transformers

In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import optuna

In [2]:
import transformers
from transformers import BertTokenizer
# Load the tokenizer for the pretrained 'bert-base-chinese' checkpoint.
# NOTE(review): in the cells visible here, CommentBertEncode builds its own
# tokenizer in __init__, so this module-level `token` looks unused — confirm
# before removing.
token = BertTokenizer.from_pretrained('bert-base-chinese')

In [3]:
# Directory that is scanned for the input '.csv' files to process.
textpath = '/openbayes/input/input0/text'
# Directory that receives the results (and log.txt); any file name already
# present here is treated as done and skipped by get_list.
outpath = '/openbayes/input/input0/text_encoded'
def get_list(datapath, outpath):
    """Work out which CSV files in ``datapath`` still have to be processed.

    A file counts as done when an entry with the same name already exists in
    ``outpath``.

    :param datapath: directory containing the input files
    :param outpath: directory containing the already produced outputs
    :return: ``(tohandle_list, donelist)`` where ``tohandle_list`` holds the
        ``.csv`` names from ``datapath`` that are missing from ``outpath`` (in
        ``os.listdir`` order) and ``donelist`` is the raw listing of
        ``outpath``; ``False`` when either directory cannot be listed
    """
    try:
        datalist = os.listdir(datapath)
        donelist = os.listdir(outpath)
    except (OSError, TypeError, ValueError):
        # Only listing failures map to the False sentinel; KeyboardInterrupt,
        # SystemExit and genuine programming errors are no longer swallowed.
        return False

    # Set membership keeps the filter linear instead of len(data) * len(done).
    done = set(donelist)
    tohandle_list = [x for x in datalist if x not in done and x[-4:] == '.csv']
    return tohandle_list, donelist

In [4]:
# get_list returns False (not a tuple) when either directory cannot be listed.
# Check for that explicitly so the failure is reported clearly instead of as an
# opaque "cannot unpack non-iterable bool" TypeError.
_listing = get_list(textpath, outpath)
if _listing is False:
    raise OSError(f"cannot list {textpath!r} and/or {outpath!r}")
tohandle_list, donelist = _listing

In [5]:
# Show how many files are still pending and how many are already done.
print(*map(len, (tohandle_list, donelist)))

80

In [7]:
import torch
from transformers import BertModel

class MLP(nn.Module):
    """Fully connected network: Linear/ReLU pairs followed by a final Linear.

    The layers live in ``self.model`` (an ``nn.Sequential``) in the order
    Linear, ReLU, Linear, ReLU, ..., Linear, so state-dict keys are
    ``model.0.*``, ``model.2.*`` and so on.
    """

    def __init__(self, input_size, hidden_layer_sizes, output_size):
        """Build the layer stack.

        :param input_size: number of input features
        :param hidden_layer_sizes: widths of the hidden layers, in order; may
            be empty, in which case the network is a single Linear layer from
            ``input_size`` to ``output_size``
        :param output_size: number of output units (raw scores, no activation)
        """
        super().__init__()
        # Initialiser applied to Linear weights by initialize_weights().
        self.initializer = nn.init.kaiming_normal_

        # Walk consecutive (fan_in, fan_out) pairs; with no hidden layers the
        # loop body never runs and only the output projection is created.
        widths = [input_size, *hidden_layer_sizes]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw outputs."""
        return self.model(x)

    def initialize_weights(self):
        """Re-initialise every Linear layer: ``self.initializer`` on the
        weights, zeros on the biases.

        This is not invoked by ``__init__``; call it explicitly when the
        default ``nn.Linear`` initialisation should be replaced.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                self.initializer(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

class CommentBertEncode(object):
    """Score every text row of one CSV file with BERT plus a trained MLP head.

    For the file ``stockname`` inside ``textpath`` the pipeline is:
    read the ``text`` column -> tokenize with 'bert-base-chinese' -> take the
    BERT output at sequence position 0 -> feed it to the MLP restored from
    'best_model_state_dict.pth' -> softmax -> write one row of probabilities
    per input text to ``outpath/stockname``. Progress and failures are logged
    to ``outpath/log.txt``.
    """

    def __init__(self, textpath, outpath, stockname, num=500):
        """
        :param textpath: directory containing the input CSV
        :param outpath: directory receiving the result file and log.txt
        :param stockname: file name of the CSV (also used for the output file)
        :param num: number of texts pushed through the models per batch
        """
        self.outpath = outpath
        self.textpath = textpath
        self.stockname = stockname
        self.logger = None
        self.token = BertTokenizer.from_pretrained('bert-base-chinese')
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.num = num

    def logger_config(self, log_path, logging_name):
        """Return a logger that appends INFO-and-above records to ``log_path``.

        :param log_path: path of the log file
        :param logging_name: name handed to ``logging.getLogger``; it is also
            written into every record
        :return: the configured ``logging.Logger``
        """
        logger = logging.getLogger(logging_name)
        # First filter: the logger itself lets everything from DEBUG through.
        logger.setLevel(level=logging.DEBUG)
        # getLogger returns the same object for the same name, so the file
        # handler is attached only once to avoid duplicated log lines.
        if not logger.handlers:
            # Second filter: the file handler keeps INFO and above.
            handler = logging.FileHandler(log_path, encoding='UTF-8')
            handler.setLevel(logging.INFO)
            formatter = logging.Formatter('%(message)s - %(levelname)s - %(asctime)s - %(name)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def get_text(self):
        """Read ``textpath/stockname`` as CSV and return its ``text`` column as a list.

        Raises whatever ``pandas.read_csv`` / the column lookup raise when the
        file is missing, unparsable or has no ``text`` column.
        """
        text_df = pd.read_csv(os.path.join(self.textpath, self.stockname))
        return text_df['text'].tolist()

    def collate_fn(self, data):
        """Tokenize a list of texts into fixed-length (100 token) tensors.

        :param data: list of texts to encode
        :return: ``(input_ids, attention_mask, token_type_ids)`` tensors;
            ``attention_mask`` is 0 on padding positions and 1 elsewhere
        """
        datas = self.token.batch_encode_plus(batch_text_or_text_pairs=data,
                                             truncation=True,
                                             padding='max_length',
                                             max_length=100,
                                             return_tensors='pt',
                                             return_length=True)

        input_ids = datas['input_ids']
        attention_mask = datas['attention_mask']
        token_type_ids = datas['token_type_ids']

        return input_ids, attention_mask, token_type_ids

    def encode_text(self, input_ids, attention_mask, token_type_ids):
        """Turn tokenized texts into per-class probabilities.

        :param input_ids: token ids as returned by :meth:`collate_fn`
        :param attention_mask: padding mask as returned by :meth:`collate_fn`
        :param token_type_ids: segment ids as returned by :meth:`collate_fn`
        :return: list with one ``[p0, p1]`` softmax row per input text (empty
            list for empty input); its shape is also printed
        """
        num = self.num
        device = self.device
        input_size = 768                      # width of the BERT output vector fed to the MLP
        hidden_layer_sizes = [188, 158, 79]   # must match the saved checkpoint
        output_size = 2                       # two output classes

        # Restore the trained classifier head. map_location lets a checkpoint
        # saved from a GPU run load on a CPU-only machine as well.
        new_model = MLP(input_size, hidden_layer_sizes, output_size)
        new_model.load_state_dict(torch.load('best_model_state_dict.pth', map_location=device))
        new_model.to(device)
        new_model.eval()

        # Load the frozen BERT encoder once per call; building it inside the
        # batch loop re-read the full checkpoint for every single batch.
        pretrained = BertModel.from_pretrained('bert-base-chinese')
        for param in pretrained.parameters():
            param.requires_grad_(False)
        pretrained.to(device)
        pretrained.eval()

        prob = []
        # Pure inference: no autograd bookkeeping for either model.
        with torch.no_grad():
            for i in range(0, len(input_ids), num):
                # Slicing past the end is safe, so the last (short) batch
                # needs no special case.
                input_ids1 = input_ids[i:i + num].to(device)
                attention_mask1 = attention_mask[i:i + num].to(device)
                token_type_ids1 = token_type_ids[i:i + num].to(device)

                out = pretrained(input_ids=input_ids1,
                                 attention_mask=attention_mask1,
                                 token_type_ids=token_type_ids1)
                # Hidden state at sequence position 0, kept on the device and
                # passed straight to the classifier head.
                features = out.last_hidden_state[:, 0]
                predictions = new_model(features)
                prob.extend(F.softmax(predictions, dim=1).cpu().tolist())

                # Drop the batch tensors before the next slice is moved over.
                del input_ids1, attention_mask1, token_type_ids1, out, features, predictions

        del pretrained, new_model
        if device == 'cuda':
            # Hand cached blocks back so the next file starts from a clean pool.
            torch.cuda.empty_cache()

        print(np.shape(prob))
        return prob

    def save_data(self, out):
        """Write ``out`` as comma-separated text to ``outpath/stockname``.

        :param out: 2-D array-like of numbers
        :return: ``True`` on success, ``False`` when writing failed (the cause
            is logged when a logger is configured)
        """
        try:
            np.savetxt(os.path.join(self.outpath, self.stockname), out, delimiter=",")
            return True
        except Exception:
            if self.logger is not None:
                self.logger.exception(f"{self.stockname}保存失败!")
            return False

    def RunStart(self):
        """Run the whole pipeline for this file, logging each stage.

        Any failing stage is logged (with traceback) and ends the run for this
        file without raising, so a surrounding batch loop can continue with
        the next file. KeyboardInterrupt/SystemExit are not intercepted.
        """
        self.logger = self.logger_config(log_path=os.path.join(self.outpath, 'log.txt'), logging_name='hpz')

        try:
            text = self.get_text()
        except Exception:
            self.logger.exception(f"{self.stockname}读取失败!")
            return
        if not text:
            self.logger.error(f"{self.stockname}读取失败!")
            return
        self.logger.info(f"{self.stockname}开始处理!")

        try:
            input_ids, attention_mask, token_type_ids = self.collate_fn(text)
        except Exception:
            self.logger.exception(f"{self.stockname}初编码失败!")
            return
        self.logger.info(f"{self.stockname}初编码完成!")

        try:
            out = self.encode_text(input_ids, attention_mask, token_type_ids)
        except Exception:
            self.logger.exception(f"{self.stockname}编码失败!")
            return
        self.logger.info(f"{self.stockname}编码完成!")

        if self.save_data(out):
            self.logger.info(f"{self.stockname}处理完成!")
        else:
            self.logger.error(f"{self.stockname}处理失败!")

        # logging.shutdown() is meant for application exit only; flushing this
        # logger's handlers is enough to get the records onto disk.
        for handler in self.logger.handlers:
            handler.flush()
        return

In [8]:
# Process every pending file in turn, 2800 texts per model batch. The encoder
# object is a temporary, so it is released as soon as RunStart() returns.
for stockname in tqdm(tohandle_list):
    CommentBertEncode(textpath, outpath, stockname, num=2800).RunStart()

  0%|          | 0/80 [00:00<?, ?it/s]

(206331, 2)


  1%|▏         | 1/80 [04:02<5:19:35, 242.72s/it]

(385843, 2)


  2%|▎         | 2/80 [11:31<7:52:55, 363.78s/it]

(116237, 2)


  4%|▍         | 3/80 [13:52<5:36:22, 262.11s/it]

(372223, 2)


  5%|▌         | 4/80 [20:48<6:49:13, 323.07s/it]

(439621, 2)


  6%|▋         | 5/80 [29:03<8:01:14, 385.00s/it]

(165102, 2)


  8%|▊         | 6/80 [32:15<6:33:56, 319.42s/it]

(106554, 2)


  9%|▉         | 7/80 [34:16<5:09:27, 254.35s/it]

(134945, 2)


 10%|█         | 8/80 [36:50<4:27:03, 222.55s/it]

(416861, 2)


 11%|█▏        | 9/80 [44:34<5:52:28, 297.86s/it]

(408890, 2)


 12%|█▎        | 10/80 [52:11<6:44:56, 347.09s/it]

(633387, 2)


 14%|█▍        | 11/80 [1:04:02<8:47:12, 458.44s/it]

(369209, 2)


 15%|█▌        | 12/80 [1:10:55<8:23:55, 444.64s/it]

(280677, 2)


 18%|█▊        | 14/80 [1:17:08<5:30:00, 300.00s/it]

(42675, 2)
(149327, 2)


 19%|█▉        | 15/80 [1:19:58<4:42:26, 260.71s/it]

(138365, 2)


 20%|██        | 16/80 [1:22:41<4:06:41, 231.28s/it]

(99715, 2)


 21%|██▏       | 17/80 [1:24:38<3:26:52, 197.03s/it]

(216892, 2)


 22%|██▎       | 18/80 [1:28:45<3:38:55, 211.86s/it]

(86041, 2)


 24%|██▍       | 19/80 [1:30:23<3:00:33, 177.60s/it]

(392470, 2)


 25%|██▌       | 20/80 [1:38:00<4:21:29, 261.49s/it]

(165060, 2)


 26%|██▋       | 21/80 [1:41:11<3:56:22, 240.38s/it]

(181074, 2)


 28%|██▊       | 22/80 [1:44:34<3:41:30, 229.15s/it]

(158455, 2)


 29%|██▉       | 23/80 [1:47:31<3:22:48, 213.48s/it]

(184909, 2)


 30%|███       | 24/80 [1:51:00<3:18:09, 212.31s/it]

(154999, 2)


 31%|███▏      | 25/80 [1:53:56<3:04:37, 201.41s/it]

(262476, 2)


 32%|███▎      | 26/80 [1:58:51<3:26:34, 229.52s/it]

(76108, 2)


 34%|███▍      | 27/80 [2:00:34<2:49:06, 191.44s/it]

(393848, 2)


 35%|███▌      | 28/80 [2:08:10<3:54:46, 270.90s/it]

(119767, 2)


 36%|███▋      | 29/80 [2:10:29<3:16:37, 231.33s/it]

(79896, 2)


 38%|███▊      | 30/80 [2:12:03<2:38:22, 190.04s/it]

(255575, 2)


 39%|███▉      | 31/80 [2:16:58<3:00:54, 221.51s/it]

(762238, 2)


 40%|████      | 32/80 [2:31:57<5:39:43, 424.65s/it]

(88906, 2)


 41%|████▏     | 33/80 [2:33:42<4:17:31, 328.77s/it]

(87220, 2)


 42%|████▎     | 34/80 [2:35:30<3:21:26, 262.75s/it]

(137416, 2)


 44%|████▍     | 35/80 [2:38:07<2:53:10, 230.91s/it]

(156976, 2)


 45%|████▌     | 36/80 [2:41:07<2:38:10, 215.70s/it]

(219522, 2)


 46%|████▋     | 37/80 [2:45:14<2:41:17, 225.07s/it]

(152449, 2)


 48%|████▊     | 38/80 [2:48:05<2:26:07, 208.75s/it]

(60780, 2)


 49%|████▉     | 39/80 [2:49:17<1:54:42, 167.87s/it]

(92077, 2)


 50%|█████     | 40/80 [2:51:03<1:39:25, 149.15s/it]

(175606, 2)


 51%|█████▏    | 41/80 [2:54:20<1:46:24, 163.70s/it]

(99934, 2)


 52%|█████▎    | 42/80 [2:56:15<1:34:23, 149.03s/it]

(128826, 2)


 54%|█████▍    | 43/80 [2:58:45<1:32:08, 149.41s/it]

(94643, 2)


 55%|█████▌    | 44/80 [3:00:33<1:22:04, 136.80s/it]

(80008, 2)


 56%|█████▋    | 45/80 [3:02:06<1:12:15, 123.86s/it]

(338522, 2)


 57%|█████▊    | 46/80 [3:09:01<1:59:39, 211.16s/it]

(97941, 2)


 59%|█████▉    | 47/80 [3:10:52<1:39:34, 181.04s/it]

(262911, 2)


 60%|██████    | 48/80 [3:15:56<1:56:14, 217.95s/it]

(289716, 2)


 61%|██████▏   | 49/80 [3:21:29<2:10:24, 252.42s/it]

(151030, 2)


 62%|██████▎   | 50/80 [3:24:21<1:54:08, 228.29s/it]

(151459, 2)


 64%|██████▍   | 51/80 [3:27:14<1:42:15, 211.57s/it]

(647776, 2)


 65%|██████▌   | 52/80 [3:39:45<2:54:22, 373.68s/it]

(621710, 2)


 66%|██████▋   | 53/80 [3:51:35<3:33:30, 474.45s/it]

(211615, 2)


 68%|██████▊   | 54/80 [3:55:35<2:55:08, 404.18s/it]

(115705, 2)


 69%|██████▉   | 55/80 [3:57:53<2:15:04, 324.16s/it]

(111848, 2)


 70%|███████   | 56/80 [4:00:04<1:46:27, 266.16s/it]

(95169, 2)


 71%|███████▏  | 57/80 [4:01:52<1:23:53, 218.85s/it]

(240126, 2)


 72%|███████▎  | 58/80 [4:06:22<1:25:51, 234.17s/it]

(130912, 2)


 74%|███████▍  | 59/80 [4:08:51<1:13:02, 208.69s/it]

(83293, 2)


 75%|███████▌  | 60/80 [4:10:31<58:39, 175.96s/it]  

(108419, 2)


 76%|███████▋  | 61/80 [4:12:48<52:00, 164.22s/it]

(181737, 2)


 78%|███████▊  | 62/80 [4:16:27<54:14, 180.78s/it]

(187054, 2)


 79%|███████▉  | 63/80 [4:20:12<55:00, 194.16s/it]

(150260, 2)


 80%|████████  | 64/80 [4:23:13<50:43, 190.23s/it]

(103757, 2)


 81%|████████▏ | 65/80 [4:25:25<43:10, 172.71s/it]

(115211, 2)


 82%|████████▎ | 66/80 [4:27:48<38:10, 163.64s/it]

(64278, 2)


 84%|████████▍ | 67/80 [4:29:05<29:48, 137.59s/it]

(159496, 2)


 85%|████████▌ | 68/80 [4:32:17<30:49, 154.13s/it]

(92141, 2)


 86%|████████▋ | 69/80 [4:34:13<26:09, 142.67s/it]

(246029, 2)


 88%|████████▊ | 70/80 [4:39:18<31:54, 191.45s/it]

(164589, 2)


 89%|████████▉ | 71/80 [4:42:42<29:15, 195.04s/it]

(619102, 2)


 90%|█████████ | 72/80 [4:56:18<50:50, 381.26s/it]

(177380, 2)


 92%|█████████▎| 74/80 [5:00:54<25:02, 250.48s/it]

(54526, 2)
(151700, 2)


 95%|█████████▌| 76/80 [5:05:15<12:17, 184.45s/it]

(65470, 2)
(98794, 2)


 96%|█████████▋| 77/80 [5:07:11<08:11, 163.96s/it]

(138572, 2)


 99%|█████████▉| 79/80 [5:11:05<02:16, 136.42s/it]

(62842, 2)
(71140, 2)


100%|██████████| 80/80 [5:12:30<00:00, 234.38s/it]
