In [1]:
import json
from transformers import AutoTokenizer,AutoModel
from torchmetrics import Accuracy
from torch import nn
from torch.utils.data import DataLoader

import torch
from utils import evaluate_recall,evaluate_f1,evaluate_precision,train_model,generate_collate_fn
from torchmetrics import Precision, Recall, F1Score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Data -------------------------------------------------------------------
with open('./data/train.json','r') as f:
    train_data=json.load(f)
# NOTE(review): this opens train.json a second time, so test_data is a copy of
# train_data. Point it at the held-out file if one exists (name to be confirmed).
with open('./data/train.json','r') as f:
    test_data=json.load(f)
with open('./data/label2index.json','r') as f:
    label2index=json.load(f)

# --- Metrics ----------------------------------------------------------------
# 19 classes, with class index 0 excluded from the scores via ignore_index.
precision = Precision(task="multiclass",num_classes=19,ignore_index=0)
recall = Recall(task="multiclass",num_classes=19,ignore_index=0)
f1 = F1Score(task="multiclass",num_classes=19,ignore_index=0)
device="cuda" if torch.cuda.is_available() else "cpu"
precision.to(device)
recall.to(device)
f1.to(device)

# --- Pretrained encoder -----------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
net = AutoModel.from_pretrained("bert-base-chinese")
net.to(device)

# NOTE(review): this optimizer only holds the parameters of `net`. Layers that
# are stacked on top of `net` later are not updated by it.
optimizer= torch.optim.Adam(net.parameters(),lr = 1e-5)
loss_fn = nn.CrossEntropyLoss()
# Each name must map to its own metric object; previously all three entries
# pointed at `precision`, so recall and f1 silently reported precision.
metrics_dict = {"precision":precision,
                "recall":recall,
                "f1":f1}

# --- Data loaders -----------------------------------------------------------
collate_fn=generate_collate_fn(tokenizer,label2index,device=device)
dl_train=DataLoader(train_data,batch_size=10,collate_fn=collate_fn)
dl_test=DataLoader(test_data,batch_size=10,collate_fn=collate_fn)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [1]:
"""
for name,params in net.named_parameters():
    params.requires_grad=False
"""

'\nfor name,params in net.named_parameters():\n    params.requires_grad=False\n'

In [4]:
class BertBilstm(nn.Module):
    """Per-token classifier: encoder -> bidirectional LSTM -> linear layer.

    Args:
        encoder: module called as ``encoder(**inputs)`` whose result exposes a
            ``last_hidden_state`` attribute. When omitted, the notebook-level
            ``net`` is used, so ``BertBilstm()`` behaves as before.
        num_labels: number of scores produced per token (default 19).
        hidden_size: width of the encoder output. When omitted it is read from
            ``encoder.config.hidden_size`` if that exists, otherwise 768.

    Raises:
        ValueError: if ``hidden_size`` is odd, because it is split evenly
            between the two LSTM directions.
    """

    def __init__(self, encoder=None, num_labels=19, hidden_size=None) -> None:
        # nn.Module.__init__ takes no arguments of its own, so nothing is forwarded.
        super().__init__()
        if encoder is None:
            encoder = net
        if hidden_size is None:
            hidden_size = getattr(getattr(encoder, "config", None), "hidden_size", 768)
        if hidden_size % 2 != 0:
            raise ValueError(f"hidden_size must be even, got {hidden_size}")

        self.encoder = encoder
        # Each direction gets half the width, so the concatenated BiLSTM output
        # has the same width as the encoder output.
        self.bilstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size // 2,
            batch_first=True,
            bidirectional=True,
        )
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, **inputs):
        """Pass ``inputs`` to the encoder and return one score vector per token."""
        hidden = self.encoder(**inputs).last_hidden_state
        # The final (h_n, c_n) state is not needed; only the per-step outputs are.
        output, _ = self.bilstm(hidden)
        return self.classifier(output)
    

In [5]:
bertbilstm=BertBilstm()
# Put the BiLSTM and classifier on the same device as the encoder and the
# batches (a no-op for anything that is already there).
bertbilstm.to(device)
# The earlier `optimizer` was created from net.parameters() before this model
# existed, so it would update the encoder only and leave the BiLSTM and
# classifier at their initial values. Rebuild it over every parameter of the
# full model, keeping the same optimizer type and learning rate.
optimizer= torch.optim.Adam(bertbilstm.parameters(),lr = 1e-5)

In [6]:
# Train the full model. The monitored quantity is "val_f1", so validation must
# not reuse the training loader: scoring on dl_train makes the monitored value
# (and any early stopping driven by it) measure fit to the training data.
# NOTE(review): confirm dl_test is built from held-out samples.
dfhistory = train_model(
    bertbilstm,
    optimizer,
    loss_fn,
    metrics_dict,
    train_data=dl_train,
    val_data=dl_test,
    epochs=1,
    patience=5,
    monitor="val_f1",
    mode="max",
)


Epoch 1 / 1

  0%|          | 19/20039 [00:30<8:37:52,  1.55s/it, train_f1=0.0232, train_loss=3, train_precision=0.0232, train_recall=0.0232]      

  0%|          | 21/20039 [00:33<7:54:31,  1.42s/it, train_f1=0.0231, train_loss=2.98, train_precision=0.0231, train_recall=0.0231]   

In [None]:
# Inverse lookup of label2index: index -> label.
# Iterating a dict directly yields only its keys, so the (key, value) pairs
# have to come from .items().
index2label={
    value:key
    for key,value in label2index.items()
}


def _build_entity(sample_id, offset_mapping, text, start, end, flag):
    """Build the entity dict for tokens start..end (inclusive) opened by tag `flag`."""
    # Character span runs from the first token's start offset to the last
    # token's end offset; int() also accepts scalar tensors.
    char_start = int(offset_mapping[start][0])
    char_end = int(offset_mapping[end][1])
    return {
        "sample_id": sample_id,
        "span": text[char_start:char_end],
        "type": index2label[(flag + 1) // 2],
    }


def parse_fn(preds,sample_id,offset_mapping,text):
    """Decode per-token tag scores of one sample into a list of entities.

    Tag convention used by the decoding: an odd tag index opens an entity, the
    next even index (opening tag + 1) continues it, and any other tag closes it.
    The entity type is looked up as ``index2label[(opening_tag + 1) // 2]``.

    Args:
        preds: tensor of tag scores; argmax over the last dimension gives the
            tag per token. Assumed to cover a single sample (no batch
            dimension) -- TODO confirm against the caller.
        sample_id: identifier copied into every returned entity.
        offset_mapping: per-token ``(char_start, char_end)`` pairs into ``text``.
        text: the original string the offsets refer to.

    Returns:
        A list of dicts with keys ``"sample_id"``, ``"span"`` (the entity text)
        and ``"type"``, in order of appearance.
    """
    # Work on plain Python ints: tensor elements hash by identity and therefore
    # cannot be used to look up keys in index2label.
    tags = torch.argmax(preds, dim=-1).tolist()
    start, end, flag = 0, 0, 0  # flag holds the opening tag, 0 = no open entity
    Entitys = []
    for i, index in enumerate(tags):
        if flag != 0 and index == flag + 1:
            # Continuation tag of the currently open entity.
            end = i
            continue
        if flag != 0:
            # Any other tag closes the open entity.
            Entitys.append(_build_entity(sample_id, offset_mapping, text, start, end, flag))
            flag = 0
        if index % 2 == 1:
            # Odd tag: a new entity starts at this token.
            start, end, flag = i, i, index

    # An entity that reaches the last token has no following tag to close it.
    if flag != 0:
        Entitys.append(_build_entity(sample_id, offset_mapping, text, start, end, flag))

    return Entitys

            

