In [None]:
import copy
import os
import time

import torch
import tqdm
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer
from transformers import get_linear_schedule_with_warmup

from divide_dataset import mk_dataset_paths
from utils import setup_seed
from utils.compute.metrics import count_metrics_binary_classification
from utils.dir import mk_dir
from utils.early_stopping import LossEarlyStopping
from utils.finish import finish_train
from utils.mk_data_loaders import mk_data_loaders_single_funcs
from utils.records import train_epoch_record
from utils.time import datetime_now_str

In [None]:
# ---- Reproducibility: fix the random seed before anything stochastic is created ----
setup_seed(2023)

# ---- Model / optimisation hyper-parameters ----
num_labels = 2                 # size of the classification head
hidden_dropout_prob = 0.3
learning_rate = 1e-5
weight_decay = 1e-2            # applied to every parameter except biases / LayerNorm weights
batch_size = 8
epochs = 100

# ---- Logging / checkpointing cadence (all measured in epochs) ----
save_interval = 10             # how often per-epoch weights and checkpoints are written
log_interval = 1               # forwarded to train_epoch_record
step_size = 10                 # how often the current learning rate is logged to TensorBoard
gamma = 0.1                    # NOTE(review): not referenced by the cells shown here

# ---- Data ----
# TODO(review): machine-specific absolute path; make it configurable before sharing the notebook.
dataset_dir_path = r"F:\sunhj\Multi-Modality-SNPCorrelateImage\data\divide\20230724190047"
max_len = 510                  # forwarded to the data-loader factory as `snp_number`
dataset_in_memory = True
label_data_id_field_name = None      # falsy -> not forwarded to the data-loader factory
label_data_label_field_name = None   # falsy -> not forwarded to the data-loader factory

# ---- Early stopping (monitors the validation loss) ----
use_early_stopping = True
early_stopping_step = 7        # patience
early_stopping_delta = 0

# ---- Device: first CUDA GPU when available, otherwise CPU ----
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Local copy of the bert-base-uncased checkpoint (Windows-style relative path).
UNCASED = r'.\work_dirs\bert\bert-base-uncased'

# Request attentions and hidden states up front instead of disabling them and
# re-enabling them right afterwards; the resulting config is the same.
config = BertConfig.from_pretrained(
    UNCASED,
    num_labels=num_labels,
    output_attentions=True,
    output_hidden_states=True,
    hidden_dropout_prob=hidden_dropout_prob,
    author="diklios",
)

# nn.Module.to() moves the parameters in place and returns the module itself.
net = BertForSequenceClassification.from_pretrained(UNCASED, config=config).to(device)

# The tokenizer is loaded from the checkpoint's vocabulary file, lower-casing input text.
tokenizer = BertTokenizer.from_pretrained(os.path.join(UNCASED, 'vocab.txt'), config=config, do_lower_case=True)

In [None]:
# NOTE: the dataset directory must be produced with the helpers in the `dataset` module;
# otherwise the file and folder paths used by the data loaders below must be edited by hand.
data_paths = mk_dataset_paths(dataset_dir_path)
data_loaders_func = mk_data_loaders_single_funcs['BertSNPNet']

# Arguments that are always passed to the data-loader factory.
data_loaders_func_kwargs = dict(
    data_paths=data_paths,
    batch_size=batch_size,
    tokenizer=tokenizer,
    snp_number=max_len,
)

# Optional arguments are forwarded only when the corresponding setting is truthy.
if dataset_in_memory:
    data_loaders_func_kwargs.update(in_memory=dataset_in_memory, persistent_workers=True)
if label_data_id_field_name:
    data_loaders_func_kwargs.update(label_data_id_field_name=label_data_id_field_name)
if label_data_label_field_name:
    data_loaders_func_kwargs.update(label_data_label_field_name=label_data_label_field_name)

# Indexed by phase name ('train', 'valid') in the training loop.
data_loaders = data_loaders_func(**data_loaders_func_kwargs)

In [None]:
# Early stopping is driven by the validation-phase loss in the training loop.
loss_early_stopping = LossEarlyStopping(patience=early_stopping_step, delta=early_stopping_delta)

# AdamW with two parameter groups: parameters whose name contains one of the
# `no_decay` markers (biases, LayerNorm weights) get no weight decay; all others
# use `weight_decay`.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [param for name, param in net.named_parameters()
                   if all(marker not in name for marker in no_decay)],
        'weight_decay': weight_decay,
    },
    {
        'params': [param for name, param in net.named_parameters()
                   if any(marker in name for marker in no_decay)],
        'weight_decay': 0.0,
    },
]
optimizer = optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-8)

# Linear learning-rate schedule without warmup, spanning every training step of every epoch
# (scheduler.step() is called once per training batch).
total_steps = epochs * len(data_loaders['train'])
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,  # same default as HF run_glue.py
                                            num_training_steps=total_steps)

# NOTE(review): not referenced by the training loop, which uses the loss the model
# computes itself when `labels` are passed; kept so the name stays defined.
criterion = nn.CrossEntropyLoss()

In [None]:
# Output layout for this run: <records_dir>/<timestamp>/{logs, wts, checkpoints}
records_dir = r'.\work_dirs\bert'
train_dir_prefix = datetime_now_str()
log_dir = os.path.join(records_dir, train_dir_prefix, 'logs')
mk_dir(log_dir)
wts_dir = os.path.join(records_dir, train_dir_prefix, 'wts')
mk_dir(wts_dir)
checkpoints_dir = os.path.join(records_dir, train_dir_prefix, 'checkpoints')
mk_dir(checkpoints_dir)
writer = SummaryWriter(log_dir=log_dir)
# Best-model checkpoint, passed to train_epoch_record during training.
best_model_checkpoint_path = os.path.join(checkpoints_dir, 'best_model_checkpoint.pth')
# Best-model weights, passed to finish_train at the end. This must be a separate file:
# it previously pointed at best_model_checkpoint_path, so anything finish_train writes
# here would overwrite the best checkpoint (and nothing landed in wts_dir).
best_model_wts_path = os.path.join(wts_dir, 'best_model_wts.pth')

In [None]:
# Training loop: every epoch does one pass over the training set, then one over the
# validation set. Per-phase metrics are recorded by train_epoch_record, which also
# returns the updated best F1 score and best weights.
best_model_wts = copy.deepcopy(net.state_dict())
best_f1 = 0
since = time.time()

for epoch in range(epochs):
    for phase in ['train', 'valid']:
        is_train = phase == 'train'
        # train(True) enables dropout; train(False) is equivalent to eval().
        net.train(is_train)

        data_loader_iter = tqdm.tqdm(data_loaders[phase])
        y_true, y_pred, y_score = [], [], []
        running_loss = 0.0
        num_samples = 0
        for inputs, labels in data_loader_iter:
            input_ids = inputs[0].to(device)
            attention_mask = inputs[1].to(device)
            labels = labels.to(device)
            # Labels carry a trailing class dimension; its argmax is the class index.
            y_true += torch.max(labels, dim=-1)[1].int().reshape(-1).tolist()

            # Gradients are only tracked during training.
            with torch.set_grad_enabled(is_train):
                outputs = net(input_ids,
                              token_type_ids=None,
                              attention_mask=attention_mask,
                              labels=labels)
            # With labels supplied, the model output starts with (loss, logits, ...).
            loss = outputs[0]

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                # Clip the gradient norm to 1.0 to guard against exploding gradients.
                # This must happen between backward() and step(); clipping after
                # step() has no effect on the update.
                torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            # `loss` is a per-batch mean; weight it by the batch size so that the
            # epoch loss below is a true per-sample mean.
            n_in_batch = input_ids.size(0)
            running_loss += loss.item() * n_in_batch
            num_samples += n_in_batch

            logits = outputs[1].detach()
            # Predict class 1 unless logit[0] is strictly larger (same tie-break as before).
            y_pred += (logits[:, 1] >= logits[:, 0]).long().cpu().tolist()
            # Score = softmax probability of the positive class (label 1), as expected by
            # ranking metrics such as ROC-AUC. The predicted class's raw logit is not a
            # positive-class score.
            # NOTE(review): assumes count_metrics_binary_classification treats y_score
            # as the positive-class score — confirm.
            y_score += torch.softmax(logits.float(), dim=-1)[:, 1].cpu().tolist()

        # Mean loss per sample seen in this phase (guarding against an empty loader).
        epoch_loss = running_loss / max(num_samples, 1)
        all_metrics = count_metrics_binary_classification(y_true, y_pred, y_score)
        best_f1, best_model_wts = train_epoch_record(
            epoch_loss, all_metrics, net, optimizer, epoch, epochs, phase,
            writer, log_interval, best_f1, best_model_wts, best_model_checkpoint_path, since)
        # Early stopping only looks at the validation loss.
        if use_early_stopping and not is_train:
            loss_early_stopping(epoch_loss)

    if epoch % step_size == 0:
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
    if epoch % save_interval == 0:
        torch.save(net.state_dict(), os.path.join(wts_dir, f'epoch_{epoch}_model_wts.pth'))
        torch.save({
            'epoch': epoch,
            'model': net.state_dict(),
            'best_f1': best_f1,
            'optimizer': optimizer.state_dict()
        }, os.path.join(checkpoints_dir, f'epoch_{epoch}_model_checkpoints.pth'))
    if use_early_stopping and loss_early_stopping.early_stop:
        break
finish_train(device, net, data_loaders, writer, best_f1, best_model_wts, best_model_wts_path, since)
