# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
import json
from tqdm import tqdm
import numpy as np
from collections import deque, defaultdict
import time

# train_dataset_file_path = "D:/db/news/THUCNews/train_3.json"
# test_dataset_file_path = "D:/db/news/THUCNews/test_3.json"
# Dataset and pretrained-model locations, relative to the working directory.
train_dataset_file_path = "./train_3.json"
test_dataset_file_path = "./test_3.json"
transformer_pretrained_path = "../hfl/chinese-roberta-wwm-ext"

# Training hyper-parameters.
batch_size = 256
yamma = 3  # exponent applied to (1 - p_t) in the focal loss (i.e. gamma)

# Every label that can appear in the source data, plus a label -> index lookup.
labels = ['体育', '娱乐', '家居', '彩票', '房产', '教育', '时尚', '时政', '星座', '游戏', '社会', '科技', '股票', '财经']
cls_id_dict = dict(zip(labels, range(len(labels))))

# The target labels the classifier predicts (one output per entry, in this order).
labels_parsed = ["时政", "财经", "房产", "科技"]


class TextPreprocessor:
    """Turn a single news title into fixed-length model inputs.

    Wraps the ``BertTokenizer`` loaded from ``transformer_pretrained_path``.
    """

    def __init__(self, max_length: int = 32):
        """Load the tokenizer.

        Args:
            max_length: Length every encoded title is padded / truncated to.
                Defaults to 32, the value that used to be hard-coded.
        """
        self.tokenizer = BertTokenizer.from_pretrained(transformer_pretrained_path)
        self.max_length = max_length

    def preprocess(self, input_title: str):
        """Encode one title.

        Args:
            input_title: Raw title text.

        Returns:
            dict mapping each field returned by the tokenizer to a 1-D tensor
            of length ``self.max_length``.
        """
        input_encoding = self.tokenizer(input_title,
                                        padding="max_length", truncation=True,
                                        max_length=self.max_length, return_tensors="pt")
        # With return_tensors="pt" a single string comes back with a leading
        # batch dimension of size 1. Drop only that dimension: a bare squeeze()
        # would also collapse the sequence dimension when max_length == 1.
        return {k: v.squeeze(0) for k, v in input_encoding.items()}


class NewsDataset(Dataset):
    """News-title dataset loaded from a JSON file.

    Each record is read as a mapping with ``"title"`` and ``"label"`` keys.
    Source labels are collapsed onto the entries of ``labels_parsed``; a label
    with no mapping produces an all-zero target vector.
    """

    # Source label -> collapsed target label. Anything absent here maps to the
    # string "None", which matches no entry of labels_parsed.
    _LABEL_MERGE = {
        "时政": "时政",
        "股票": "财经",
        "财经": "财经",
        "房产": "房产",
        "科技": "科技",
    }

    def __init__(self, train: bool):
        """Load the requested split and shuffle it in place.

        Args:
            train: Read ``train_dataset_file_path`` when True, otherwise
                ``test_dataset_file_path``.
        """
        self.train = train
        self.text_preprocessor = TextPreprocessor()
        if train:
            self.dataset_file_path = train_dataset_file_path
        else:
            self.dataset_file_path = test_dataset_file_path
        with open(self.dataset_file_path, encoding="utf-8") as fh:
            records = json.load(fh)
        self.dataset = records
        np.random.shuffle(self.dataset)
        self.dataset_length = len(self.dataset)
        split_name = 'train' if self.train else 'test'
        print(f"load {split_name} dataset size: {self.dataset_length}")

    def __getitem__(self, index):
        """Return ``(encoding, y)`` for the record at *index*.

        ``encoding`` is ``TextPreprocessor.preprocess`` applied to the title;
        ``y`` is an int numpy array aligned with ``labels_parsed`` holding 1 at
        the collapsed label's position (all zeros if the label is unmapped).
        """
        record = self.dataset[index]
        title = record["title"]
        target_name = self._LABEL_MERGE.get(record["label"], "None")
        one_hot = np.array([name == target_name for name in labels_parsed], dtype=int)
        return self.text_preprocessor.preprocess(title), one_hot

    def __len__(self):
        """Number of records loaded from the JSON file."""
        return self.dataset_length


def train():
    """Run one training epoch and print per-label metrics.

    Uses module-level globals created in the ``__main__`` section: ``model``,
    ``optimizer``, ``scaler``, ``device``, ``dataloader_train`` and the current
    ``epoch``. Optimizes the mean focal loss (exponent ``yamma``) inside
    ``torch.cuda.amp.autocast``. The progress bar shows the mean of the last
    100 batch losses. At the end it prints precision / recall / F1 and the
    positive count for each label in ``labels_parsed``, using a 0.5 sigmoid
    threshold.
    """

    def _count_per_label(mask):
        # Sum a boolean (..., num_labels) mask over every leading dimension,
        # returning one host-side count per label.
        return mask.reshape(-1, mask.shape[-1]).sum(dim=0).cpu().numpy()

    time.sleep(0.2)  # presumably lets earlier console output flush before tqdm draws
    model.train()
    train_loss = deque([], maxlen=100)
    num_labels = len(labels_parsed)
    # Per-label confusion counts accumulated over the epoch.
    tp_total = np.zeros(num_labels, dtype=np.int64)
    fp_total = np.zeros(num_labels, dtype=np.int64)
    fn_total = np.zeros(num_labels, dtype=np.int64)
    pbar = tqdm(dataloader_train)
    pbar.set_description("train epoch {}".format(epoch))
    for input_encoding, y_target in pbar:
        optimizer.zero_grad()
        input_encoding = {k: v.to(device) for k, v in input_encoding.items()}
        y_target = y_target.to(device)
        with torch.cuda.amp.autocast():
            y_predict = model(**input_encoding)[0]

            bce_loss = F.binary_cross_entropy_with_logits(y_predict, y_target.float(), reduction='none')
            # Focal loss: exp(-bce) is the probability given to the true
            # outcome (p_t); (1 - p_t) ** yamma down-weights easy examples.
            focal_loss = torch.pow(1 - torch.exp(-bce_loss), yamma) * bce_loss
            loss = torch.mean(focal_loss)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss.append(loss.item())

        # Metric bookkeeping needs no gradients. All labels are counted at
        # once: three transfers per batch instead of one .item() per label
        # per count.
        with torch.no_grad():
            y_pred_label = torch.sigmoid(y_predict) > 0.5
            y_true_label = y_target == 1
            tp_total += _count_per_label(y_pred_label & y_true_label)
            fp_total += _count_per_label(y_pred_label & ~y_true_label)
            fn_total += _count_per_label(~y_pred_label & y_true_label)

        pbar.set_postfix_str("loss={}".format(np.mean(train_loss)))

    for i, label_str in enumerate(labels_parsed):
        tp, fp, fn = int(tp_total[i]), int(fp_total[i]), int(fn_total[i])
        nums = tp + fn  # number of positive samples for this label
        precision = tp / (tp + fp + 1e-5)
        recall = tp / (tp + fn + 1e-5)
        f1 = (2 * precision * recall) / (precision + recall + 1e-5)
        print(f"label {label_str}: precision={precision}, recall={recall}, f1={f1}, nums={nums}")


def test():
    """Evaluate the model on the test split and print per-label metrics.

    Uses module-level globals created in the ``__main__`` section: ``model``,
    ``device``, ``dataloader_test`` and the current ``epoch``. The progress bar
    shows the running mean focal loss (exponent ``yamma``) over all batches so
    far. At the end it prints precision / recall / F1 and the positive count
    for each label in ``labels_parsed``, using a 0.5 sigmoid threshold. Runs
    in full precision (no autocast), as before.
    """

    def _count_per_label(mask):
        # Sum a boolean (..., num_labels) mask over every leading dimension,
        # returning one host-side count per label.
        return mask.reshape(-1, mask.shape[-1]).sum(dim=0).cpu().numpy()

    time.sleep(0.2)  # presumably lets earlier console output flush before tqdm draws
    model.eval()
    test_loss = []
    num_labels = len(labels_parsed)
    # Per-label confusion counts accumulated over the whole test set.
    tp_total = np.zeros(num_labels, dtype=np.int64)
    fp_total = np.zeros(num_labels, dtype=np.int64)
    fn_total = np.zeros(num_labels, dtype=np.int64)
    pbar = tqdm(dataloader_test)
    pbar.set_description("test epoch {}".format(epoch))
    # Evaluation needs no gradients. Without no_grad, every forward pass would
    # build and keep a full autograd graph, wasting memory and time.
    with torch.no_grad():
        for input_encoding, y_target in pbar:
            input_encoding = {k: v.to(device) for k, v in input_encoding.items()}
            y_target = y_target.to(device)
            y_predict = model(**input_encoding)[0]

            bce_loss = F.binary_cross_entropy_with_logits(y_predict, y_target.float(), reduction='none')
            # Focal loss: exp(-bce) is p_t; (1 - p_t) ** yamma down-weights easy examples.
            focal_loss = torch.pow(1 - torch.exp(-bce_loss), yamma) * bce_loss
            loss = torch.mean(focal_loss)
            test_loss.append(loss.item())

            y_pred_label = torch.sigmoid(y_predict) > 0.5
            y_true_label = y_target == 1
            tp_total += _count_per_label(y_pred_label & y_true_label)
            fp_total += _count_per_label(y_pred_label & ~y_true_label)
            fn_total += _count_per_label(~y_pred_label & y_true_label)

            pbar.set_postfix_str(f"loss={np.mean(test_loss)}")

    for i, label_str in enumerate(labels_parsed):
        tp, fp, fn = int(tp_total[i]), int(fp_total[i]), int(fn_total[i])
        nums = tp + fn  # number of positive samples for this label
        precision = tp / (tp + fp + 1e-5)
        recall = tp / (tp + fn + 1e-5)
        f1 = (2 * precision * recall) / (precision + recall + 1e-5)
        print(f"label {label_str}: precision={precision}, recall={recall}, f1={f1}, nums={nums}")


if __name__ == '__main__':
    # Script entry point. Everything bound here (model, optimizer, scaler,
    # device, dataloaders, epoch) stays module-global because train() and
    # test() read these names directly.
    import os  # only the entry point needs it

    # torch.save does not create missing directories. Create the checkpoint
    # directory up front instead of failing after the first (long) epoch.
    checkpoint_dir = "./model_6"
    os.makedirs(checkpoint_dir, exist_ok=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Pinned host memory only helps host->GPU copies, so request it only on CUDA.
    dataset_train = NewsDataset(train=True)
    dataloader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, num_workers=12, pin_memory=use_cuda)
    dataset_test = NewsDataset(train=False)
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False, num_workers=12, pin_memory=use_cuda)

    model = BertForSequenceClassification.from_pretrained(transformer_pretrained_path, num_labels=4)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    # Loss scaling only applies with CUDA mixed precision. Disabling it
    # explicitly on CPU makes step() a plain optimizer.step() without warnings.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    for epoch in range(100):
        train()
        test()
        torch.save(model.state_dict(), f"{checkpoint_dir}/model_{epoch}.pth")
