In [1]:
import sys
sys.path.append("/workspace/tez")

from pathlib import Path
import torch
import copy
import tez
import os

from torch import nn
from transformers import AdamW, AutoConfig, AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
from torchcrf import CRF

In [2]:
def get_model_path_list(base_dir):
    """Collect the checkpoint files stored directly inside ``base_dir``.

    A directory entry is kept when it is a regular file and its stem (the file
    name without its last suffix) starts with ``"model"``; e.g.
    ``model_2.bin_epoch1`` has the stem ``model_2`` and is kept.
    Sub-directories are skipped, and the search is not recursive.

    Args:
        base_dir: Directory to scan (``str`` or ``os.PathLike``).

    Returns:
        list[pathlib.Path]: Matching paths in sorted order, so callers process
        checkpoints in a reproducible sequence. Empty when nothing matches.

    Raises:
        FileNotFoundError / NotADirectoryError: Propagated from
        ``Path.iterdir`` when ``base_dir`` is missing or not a directory.
    """
    # iterdir() yields entries in arbitrary order, hence the explicit sort.
    return sorted(
        entry
        for entry in Path(base_dir).iterdir()
        if entry.is_file() and entry.stem.startswith("model")
    )

In [3]:
class FeedbackModel(tez.Model):
    """Token-level tagger: a pretrained transformer encoder plus a decoding head.

    The head is selected by ``decoder``:

    * ``"softmax"`` - a linear layer followed by a softmax over the labels.
    * ``"crf"``     - the same linear layer, decoded with a ``torchcrf.CRF``.
    * ``"span"``    - two linear layers producing start and end logits.

    Args:
        model_name: Name or path given to ``AutoConfig`` / ``AutoModel``.
        num_labels: Output size of the linear head and number of CRF tags.
        dynamic_merge_layers: When truthy, the sequence representation is a
            learned, softmax-weighted mix of several encoder hidden states
            instead of the last hidden state only.
        decoder: ``"softmax"``, ``"crf"`` or ``"span"``.
        max_len: Stored on the instance; not used inside this class.
        span_num_labels: Output size of the start/end layers (``"span"`` only).
        merge_layers_num: Start index of the ``hidden_states`` slice that is
            merged when ``dynamic_merge_layers`` is truthy (``-4`` keeps the
            last four entries).
            NOTE(review): this attribute was read by ``forward`` but never
            defined, so the default is a reviewer choice - confirm it against
            the configuration the checkpoints were trained with.
    """

    def __init__(
        self,
        model_name,
        num_labels,
        dynamic_merge_layers,
        decoder="softmax",
        max_len=4096,
        span_num_labels=8,
        merge_layers_num=-4,
    ):
        super().__init__()
        self.max_len = max_len
        self.dynamic_merge_layers = dynamic_merge_layers
        self.merge_layers_num = merge_layers_num
        self.model_name = model_name
        self.num_labels = num_labels
        self.decoder = decoder

        hidden_dropout_prob: float = 0.1
        layer_norm_eps: float = 1e-7

        config = AutoConfig.from_pretrained(model_name)

        config.update(
            {
                "output_hidden_states": True,
                "hidden_dropout_prob": hidden_dropout_prob,
                "layer_norm_eps": layer_norm_eps,
                "add_pooling_layer": False,
                "num_labels": self.num_labels,
            }
        )

        self.transformer = AutoModel.from_pretrained(model_name, config=config)
        # forward() applies dropout to the sequence output, so the module has to
        # exist. Dropout owns no parameters, hence state_dict keys are unchanged.
        self.dropout = nn.Dropout(hidden_dropout_prob)

        if self.dynamic_merge_layers:
            # Scores one hidden-state vector with a single scalar; the scalars
            # are normalised across layers in forward().
            self.layer_logits = nn.Linear(config.hidden_size, 1)

        if self.decoder == "span":
            self.start_fc = nn.Linear(config.hidden_size, span_num_labels)
            self.end_fc = nn.Linear(config.hidden_size, span_num_labels)
        else:
            self.output = nn.Linear(config.hidden_size, self.num_labels)
            if self.decoder == "crf":
                self.crf = CRF(num_tags=num_labels, batch_first=True)

    def forward(self, ids, mask, token_type_ids=None, targets=None):
        """Encode the batch and run the configured decoding head.

        Args:
            ids: Input token ids, passed to the encoder as ``input_ids``.
            mask: Attention mask, passed to the encoder as ``attention_mask``
                and (for ``"crf"``) to ``CRF.decode``.
            token_type_ids: Optional segment ids; forwarded only when given.
            targets: Accepted for interface compatibility; not used here.

        Returns:
            dict with the keys

            * ``"preds"``  - softmax probabilities (``"softmax"``), the output
              of ``CRF.decode`` (``"crf"``) or ``None`` (``"span"``);
            * ``"logits"`` - the head logits, or ``(start_logits, end_logits)``
              for ``"span"``;
            * ``"loss"``   - always ``0`` (no loss is computed in this class);
            * ``"metric"`` - an empty dict.

        Raises:
            ValueError: If ``self.decoder`` is not a supported decoder name.
        """
        # Keyword arguments are required: the positional order of forward()
        # differs between architectures (XLNet's third positional parameter is
        # ``mems``, not ``token_type_ids``).
        encoder_kwargs = {
            "input_ids": ids,
            "attention_mask": mask,
            "output_hidden_states": bool(self.dynamic_merge_layers),
        }
        # ``is not None`` rather than truthiness: bool() of a multi-element
        # tensor raises.
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        transformer_out = self.transformer(**encoder_kwargs)

        # If the encoder returned a different sequence length than it was given,
        # extend the mask with zeros so the CRF ignores the extra positions.
        # last_hidden_state is used for the length because hidden_states is only
        # returned when dynamic merging is enabled.
        encoded_len = transformer_out.last_hidden_state.shape[1]
        if self.decoder == "crf" and encoded_len != ids.shape[1]:
            mask_add = mask.new_zeros((mask.shape[0], encoded_len - ids.shape[1]))
            mask = torch.cat((mask, mask_add), dim=-1)

        if self.dynamic_merge_layers:
            # Stack the selected hidden states on a new axis 2 (the layer axis).
            layers_output = torch.stack(
                transformer_out.hidden_states[self.merge_layers_num:], dim=2
            )
            layers_logits = self.layer_logits(layers_output)
            # Normalise across the layer axis. The trailing axis has size 1, so
            # a softmax there would give every layer a weight of exactly 1.0.
            layers_weights = torch.softmax(layers_logits, dim=2)
            sequence_output = (layers_weights * layers_output).sum(dim=2)
        else:
            sequence_output = transformer_out.last_hidden_state
        sequence_output = self.dropout(sequence_output)

        if self.decoder == "span":
            start_logits = self.start_fc(sequence_output)
            end_logits = self.end_fc(sequence_output)
            logits = (start_logits, end_logits)
            probs = None
        else:
            logits = self.output(sequence_output)
            if self.decoder == "softmax":
                probs = torch.softmax(logits, dim=-1)
            elif self.decoder == "crf":
                probs = self.crf.decode(emissions=logits, mask=mask.byte())
            else:
                raise ValueError(
                    f"unsupported decoder {self.decoder!r}; "
                    "expected one of: softmax, crf, span"
                )
        loss = 0

        return {
            "preds": probs,
            "logits": logits,
            "loss": loss,
            "metric": {}
        }

In [4]:
# Build the tagger on the xlnet-base-cased encoder with 15 labels and a CRF
# decoding head; only the last hidden state is used (no dynamic layer merging).
model = FeedbackModel(
    model_name="xlnet-base-cased",
    num_labels=15,
    dynamic_merge_layers=False,
    decoder="crf",
)

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetModel: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
def swa(model, model_dir):
    """Average the checkpoints in ``model_dir`` (SWA) and save the result.

    SWA (stochastic weight averaging) keeps a running mean of the weights of
    several checkpoints; it is normally applied to checkpoints taken once
    training has become stable.

    Every path returned by ``get_model_path_list(model_dir)`` is loaded into
    ``model``, and its parameters are folded into an equal-weight running mean
    held by a deep copy of ``model``. The averaged ``state_dict`` is written to
    ``<model_dir>/checkpoint-100000/model.bin``.

    Notes:
        * ``model`` is used as the loading buffer, so on return it holds the
          weights of the last checkpoint that was read.
        * Only ``named_parameters()`` are averaged; buffers of the returned
          model keep the values ``model`` had when this function was called.

    Args:
        model: Module whose ``state_dict`` layout matches the checkpoints.
        model_dir: Directory that contains the checkpoint files.

    Returns:
        The deep-copied module carrying the averaged parameters.

    Raises:
        FileNotFoundError: If ``model_dir`` contains no checkpoint, since
            saving the untouched input weights as an "SWA" model would be
            silently wrong.
    """
    model_path_list = get_model_path_list(model_dir)
    if not model_path_list:
        raise FileNotFoundError(
            f"No checkpoint whose name starts with 'model' found in {model_dir}"
        )

    swa_model = copy.deepcopy(model)

    with torch.no_grad():
        for swa_n, _ckpt in enumerate(model_path_list):
            print(f'Load model from {_ckpt}')
            # NOTE(review): torch.load unpickles the file - only use it on
            # checkpoints from a trusted source.
            model.load_state_dict(torch.load(_ckpt, map_location=torch.device('cpu')))
            tmp_para_dict = dict(model.named_parameters())

            # Running mean: new_mean = x_k / k + old_mean * (1 - 1 / k).
            alpha = 1. / (swa_n + 1.)

            for name, para in swa_model.named_parameters():
                source = tmp_para_dict[name]
                if swa_n == 0:
                    # Plain copy for the first checkpoint: multiplying the
                    # initial weights by 0 would turn inf/NaN entries into NaN.
                    para.copy_(source)
                else:
                    para.copy_(source * alpha + para * (1. - alpha))

    # use 100000 to represent swa to avoid clash
    swa_model_dir = os.path.join(model_dir, 'checkpoint-100000')
    os.makedirs(swa_model_dir, exist_ok=True)

    print(f'Save swa model in: {swa_model_dir}')

    swa_model_path = os.path.join(swa_model_dir, 'model.bin')

    torch.save(swa_model.state_dict(), swa_model_path)

    return swa_model


In [8]:
# Average every checkpoint found in the run directory and keep the averaged model.
swa_model = swa(model=model, model_dir="/workspace/data-0225-xlnet-base-swa")

Load model from /workspace/data-0225-xlnet-base-swa/model_2.bin_epoch1
Load model from /workspace/data-0225-xlnet-base-swa/model_2.bin_epoch2
Load model from /workspace/data-0225-xlnet-base-swa/model_2.bin_epoch5
Save swa model in: /workspace/data-0225-xlnet-base-swa/checkpoint-100000
