In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers as ts
from torch.utils.data import Dataset,DataLoader
import numpy as np
from simple_bert import RoBERTa_CRF

In [2]:
def load_data(path,train=True):
    """Lazily yield sentences from a CoNLL-style tagging file.

    Reads ``<path>/train.conll`` when ``train`` is true, otherwise
    ``<path>/dev.conll``. Each non-blank line holds a token followed by its
    tag (whitespace-separated); blank lines separate sentences.

    Args:
        path: directory containing the ``.conll`` files.
        train: select the training (True) or development (False) split.

    Yields:
        dict with ``'text'`` (list of tokens) and ``'label'`` (list of tags)
        for one sentence.

    Raises:
        ValueError: if a non-blank line has fewer than two columns.
    """
    name='train.conll' if train else 'dev.conll'
    texts=[]
    labels=[]
    # The corpus is Chinese: do not depend on the platform's default encoding.
    with open(path+'/'+name,encoding='utf-8') as file:
        for lineno,line in enumerate(file,1):
            parts=line.split()
            if not parts:
                # Blank or whitespace-only line: sentence boundary.
                if texts:
                    yield {'text':texts,'label':labels}
                    texts=[]
                    labels=[]
                continue
            if len(parts)<2:
                raise ValueError(f'{path}/{name}:{lineno}: expected "<token> <tag>", got {line!r}')
            texts.append(parts[0])
            labels.append(parts[1])
        # Last sentence when the file does not end with a blank line.
        if texts:
            yield {'text':texts,'label':labels}

def get_entities(text,label):
    """Group per-character BIOES-style tags into entity spans.

    Args:
        text: sequence of characters/tokens.
        label: tags aligned with ``text``, e.g. ``'B-prov'``, ``'I-prov'``,
            ``'E-prov'``, ``'S-poi'`` or ``'O'``.

    Returns:
        list of dicts, one per span, each with ``'text'`` (the span's tokens
        concatenated) and ``'entities'`` (the type suffix of every tag in the
        span, one per token).

    'B'/'S' start a new span, 'I'/'E' extend the open one, and 'O' closes the
    open span. 'O' tokens themselves are not added to any span.
    """
    entities=[]
    current=None
    for token,tag in zip(text,label):
        prefix=tag[0]
        if prefix in 'BOS' and current:
            entities.append(current)
            current=None
        if prefix in 'BS' or (prefix in 'IE' and current is None):
            # A stray 'I'/'E' with no open span (malformed tagging) starts a
            # new span instead of failing with a KeyError.
            current={'text':token,'entities':[tag[2:]]}
        elif prefix in 'IE':
            current['text']+=token
            current['entities'].append(tag[2:])
    if current:
        entities.append(current)
    return entities

def makedata(data):
    """Turn CoNLL sentence records into parallel sentence/label lists.

    For every record in ``data`` (dicts with ``'text'`` and ``'label'``), the
    spans returned by ``get_entities`` are concatenated into one sentence
    string and one per-character list of entity types.

    NOTE(review): ``get_entities`` does not put 'O'-tagged characters into
    any span, so those characters are missing from the training sentences
    while raw test addresses keep them. Confirm this is intended.

    Returns:
        dict with ``'text'``: list of one-element lists ``[sentence]`` and
        ``'label'``: list of per-character entity-type lists.
    """
    sentences=[]
    labels=[]
    for record in data:
        spans=get_entities(record['text'],record['label'])
        sentences.append([''.join(span['text'] for span in spans)])
        labels.append([kind for span in spans for kind in span['entities']])
    return {'text':sentences,'label':labels}

# Parse both CoNLL splits and flatten them into sentence/label pairs.
train_data=makedata(load_data('./datasets/'))
val_data=makedata(load_data('./datasets/',train=False))
# Peek at the first training sentence and its per-character labels.
train_data['text'][0],train_data['label'][0]

(['浙江杭州市江干区九堡镇三村村一区'],
 ['prov',
  'prov',
  'city',
  'city',
  'city',
  'district',
  'district',
  'district',
  'town',
  'town',
  'town',
  'community',
  'community',
  'community',
  'poi',
  'poi'])

In [3]:
def label2int(labels):
    """Build the label vocabulary: 'O' first, then every other label in order of first appearance.

    Despite its name this returns the list of label *names*; a label's
    integer id is its position in the returned list.
    """
    flat=[tag for sentence_labels in labels for tag in sentence_labels]
    # dict preserves insertion order, so this de-duplicates while keeping
    # the first-seen order.
    return list(dict.fromkeys(['O']+flat))
# Label vocabulary over the training and validation splits combined.
ldict=label2int([*train_data['label'],*val_data['label']])
ldict

['O',
 'prov',
 'city',
 'district',
 'town',
 'community',
 'poi',
 'road',
 'roadno',
 'subpoi',
 'devzone',
 'houseno',
 'intersection',
 'assist',
 'cellno',
 'floorno',
 'distance',
 'village_group']

In [4]:
class Addr(Dataset):
    """Character-level NER dataset (train/validation) for address parsing.

    Every sentence is split into single characters and tokenized with
    ``is_split_into_words=True``, padded to the longest sentence. Each item
    is a dict of ``input_ids``, ``attention_mask`` and ``labels`` LongTensors,
    where ``labels`` holds one label id per token position.

    Args:
        data: dict with ``'text'`` (list of one-element lists holding the
            sentence string) and ``'label'`` (list of per-character label
            names), e.g. the output of ``makedata``.
        tokenizer: Hugging Face tokenizer.
        ldict: list of label names; a label's id is its index in this list.
            Id 0 is also used for [CLS]/[SEP] and padding positions.

    Raises:
        ValueError: if a label name is not in ``ldict``.
    """
    def __init__(self,data,tokenizer,ldict):
        self.text=[list(t[0]) for t in data['text']]
        self.encodings=tokenizer(self.text,is_split_into_words=True,padding=True)
        # Name -> id lookup (first occurrence wins, like ldict.index) instead
        # of a linear ldict.index() scan for every character.
        label_ids={}
        for idx,name in enumerate(ldict):
            label_ids.setdefault(name,idx)

        def to_id(name):
            try:
                return label_ids[name]
            except KeyError:
                # Keep the ValueError that ldict.index() used to raise.
                raise ValueError(f'{name!r} is not in ldict') from None

        self.labels=[
            self._align(i,[to_id(name) for name in names])
            for i,names in enumerate(data['label'])
        ]

    def _align(self,i,char_label_ids):
        """Return one label id per token of sample ``i`` (0 for special/pad tokens)."""
        seq_len=len(self.encodings['input_ids'][i])
        try:
            word_ids=self.encodings.word_ids(i)
        except (AttributeError,ValueError):
            word_ids=None  # slow tokenizer or plain dict: no alignment info
        if word_ids is None:
            # Assume exactly one token per character, right after [CLS].
            label=[0]+char_label_ids
            return label+[0]*(seq_len-len(label))
        # Fast tokenizer: map each token back to the character it came from,
        # so characters producing zero (or several) tokens do not shift labels.
        return [
            0 if w is None or w>=len(char_label_ids) else char_label_ids[w]
            for w in word_ids
        ]

    def __getitem__(self, idx):
        input_ids = torch.LongTensor(self.encodings['input_ids'][idx])
        attention_mask = torch.LongTensor(self.encodings['attention_mask'][idx])
        labels = torch.LongTensor(self.labels[idx])
        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

    def __len__(self):
        return len(self.text)

In [5]:
# Tokenizer of the local Chinese RoBERTa-wwm-ext checkpoint.
tokenizer=ts.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path='../../model/chinese-roberta-wwm-ext/',
)

In [6]:
# Encode both splits with the same tokenizer and label vocabulary.
train_datasets,val_datasets=(Addr(split,tokenizer,ldict) for split in (train_data,val_data))
# Spot-check one fixed validation sample.
val_datasets[1273]

{'input_ids': tensor([ 101, 3851, 3736, 4689, 2123, 3797, 2356, 3851, 3736, 4689, 2123, 3797,
         2356, 6969, 2336, 1277, 7674, 1298, 6125, 6887, 1921, 4997, 1298, 6662,
          121,  121,  121,  121, 1384, 4384, 4413, 7213, 3805, 1814,  121,  121,
         2231, 7824, 5440, 1767, 7674, 1298, 3862, 7623, 1324,  102]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 'labels': tensor([ 0,  1,  1,  1,  2,  2,  2,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,
          4,  4,  7,  7,  7,  7,  8,  8,  8,  8,  8,  6,  6,  6,  6,  6, 15, 15,
         15,  9,  9,  9,  9,  9,  9,  9,  9,  0])}

In [7]:
def load_test_data(path):
    """Read the unlabelled test set as lists of characters.

    Reads ``<path>/final_test.txt``, where each line is
    ``<id>\\x01<address>``, and returns one list of characters per line.
    Blank lines are skipped.

    NOTE(review): ``split()[0]`` keeps only the text before the first
    whitespace, so an address containing spaces is truncated. Confirm this
    is intended.
    """
    texts=[]
    # The addresses are Chinese: do not depend on the platform's default encoding.
    with open(path+'/final_test.txt',encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue  # e.g. a trailing empty line
            fields=line.split('\x01')
            texts.append(list(fields[1].split()[0]))
    return texts

class AddrTest(Dataset):
    """Unlabelled test-set dataset of character-split addresses.

    All sentences are tokenized once (``is_split_into_words=True``, padded
    to the longest). Items are dicts of ``input_ids`` and ``attention_mask``
    LongTensors.
    """
    def __init__(self,data,tokenizer):
        self.text=data
        self.encodings=tokenizer(data,is_split_into_words=True,padding=True)

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        return {
            field: torch.LongTensor(self.encodings[field][idx])
            for field in ('input_ids','attention_mask')
        }
    
test_data=load_test_data('./datasets/')
test_datasets=AddrTest(data=test_data,tokenizer=tokenizer)
# First encoded test sample, as a sanity check.
test_datasets[0]

{'input_ids': tensor([ 101, 3308, 7345, 1277, 2207, 1068, 1266, 7027,  121,  121,  121,  118,
          121, 1384,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [50]:
# Inspect the backbone configuration of the local checkpoint.
bertcfig=ts.BertConfig.from_pretrained(
    pretrained_model_name_or_path='../../model/chinese-roberta-wwm-ext/',
)
bertcfig

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "directionality": "bidi",
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.34.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

In [51]:
# NOTE(review): `bert` is not referenced anywhere else in the visible
# notebook; loading it only costs time and memory. Confirm it can be dropped.
bert=ts.BertModel.from_pretrained(
    pretrained_model_name_or_path='../../model/chinese-roberta-wwm-ext/',
    config=bertcfig,
)

In [22]:
# RoBERTa + CRF tagger with one output per label in the vocabulary.
mt=RoBERTa_CRF(
    bert_path='../../model/chinese-roberta-wwm-ext/',
    num_labels=len(ldict),
)

In [29]:
tdl=DataLoader(train_datasets,batch_size=2)
# Smoke test: push the first three batches through the model, then stop.
for step,batch in zip(range(3),tdl):
    out,loss=mt(batch)

In [3]:
def load_test_data(path):
    """Read the unlabelled test set as lists of characters.

    Reads ``<path>/final_test.txt``, where each line is
    ``<id>\\x01<address>``, and returns one list of characters per line.
    Blank lines are skipped.

    NOTE(review): ``split()[0]`` keeps only the text before the first
    whitespace, so an address containing spaces is truncated. Confirm this
    is intended.
    """
    texts=[]
    # The addresses are Chinese: do not depend on the platform's default encoding.
    with open(path+'/final_test.txt',encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue  # e.g. a trailing empty line
            fields=line.split('\x01')
            texts.append(list(fields[1].split()[0]))
    return texts

class AddrTest(Dataset):
    """Unlabelled test-set dataset of character-split addresses.

    Args:
        data: list of sentences, each a list of characters.
        tokenizer: Hugging Face tokenizer; all sentences are tokenized once
            (``is_split_into_words=True``), padded to the longest one.
        device: optional device the returned tensors are moved to. ``None``
            (default) leaves them on the CPU, so ``AddrTest(data, tokenizer)``
            works as with the earlier two-argument definition.

    Items are dicts of ``input_ids``, ``attention_mask`` and all-zero
    ``labels`` LongTensors. The dummy labels are presumably there because the
    model's forward expects a ``labels`` key (TODO: confirm against RoBERTa_CRF).
    """
    def __init__(self,data,tokenizer,device=None):
        self.text=data
        self.encodings=tokenizer(data,is_split_into_words=True,padding=True)
        self.device=device

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        input_ids = torch.LongTensor(self.encodings['input_ids'][idx])
        attention_mask = torch.LongTensor(self.encodings['attention_mask'][idx])
        labels = torch.zeros_like(input_ids)
        item = {'input_ids': input_ids, 'attention_mask': attention_mask,'labels':labels}
        if self.device is not None:
            # Moving to a GPU here may fail inside DataLoader worker processes;
            # keep num_workers=0 when a CUDA device is given.
            item = {key: tensor.to(self.device) for key, tensor in item.items()}
        return item

In [4]:
# Label vocabulary for inference. The order must match the one the model
# was trained with, since a label's id is its index here; index 0 ('O') is
# also the id used for special and padding tokens.
ldict = [
    'O',
    'prov', 'city', 'district', 'town', 'community',
    'poi', 'road', 'roadno', 'subpoi', 'devzone', 'houseno',
    'intersection', 'assist', 'cellno', 'floorno', 'distance',
    'village_group',
]

In [5]:
test_data=load_test_data('./datasets/')
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer=ts.AutoTokenizer.from_pretrained('../../model/chinese-roberta-wwm-ext/')
test_datasets=AddrTest(test_data,tokenizer,device)
# batch_size=1 and no shuffling: batch `step` corresponds to test_data[step].
test_dl=DataLoader(test_datasets,batch_size=1)

In [6]:
model=RoBERTa_CRF(bert_path='../../model/chinese-roberta-wwm-ext/',num_labels=len(ldict))
# map_location='cpu' lets a checkpoint saved from a GPU load on any machine;
# the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
state_dict=torch.load('./roberta_crf.pth',map_location='cpu')
model.load_state_dict(state_dict=state_dict)
model=model.to(device)
# Inference only: disable dropout (assumes RoBERTa_CRF is an nn.Module).
model.eval()

Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../../model/chinese-roberta-wwm-ext/ and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
def create_line(idx,data,preds,label_names=None):
    """Format one submission line: ``<idx>\\u0001<text>\\u0001<tags>``.

    Args:
        idx: sample id written in the first field.
        data: the sentence, as a string or a sequence of characters.
        preds: predicted label ids for the token sequence. The first and last
            entries are assumed to be [CLS] and [SEP] and are discarded.
        label_names: id -> label-name list; defaults to the global ``ldict``.

    Returns:
        The formatted line (without a trailing newline). A run of consecutive
        characters with the same non-zero label becomes ``B-x (I-x ...) E-x``,
        or ``B-x`` alone for a single character. Label-0 characters are
        written as the bare name ``label_names[0]``. Tags are space-separated.
    """
    import itertools

    if label_names is None:
        label_names=ldict
    line=str(idx)+'\u0001'+''.join(data)+'\u0001'
    # Drop [CLS]/[SEP] *before* grouping. Grouping first would also drop the
    # first/last real character whenever it shares a label with [CLS]/[SEP].
    body=list(preds)[1:-1]
    tags=[]
    for label_id,run in itertools.groupby(body):
        length=len(list(run))
        name=label_names[label_id]
        if label_id==0:
            tags.extend([name]*length)
        elif length==1:
            tags.append('B-'+name)
        else:
            tags.extend(['B-'+name]+['I-'+name]*(length-2)+['E-'+name])
    return line+' '.join(tags)

In [11]:
# Predict every test sample and write the submission file.
model.eval()  # inference: disable dropout so predictions are deterministic
# 'w' (not 'r'): this cell writes the file.
with open('./Akiyan_addr_parsing_runid.txt','w',encoding='utf-8') as file:
    for step,batch in enumerate(test_dl):
        with torch.no_grad():
            out,_=model(batch)
        # test_dl uses batch_size=1 without shuffling, so `step` indexes
        # test_data. Use the raw characters rather than tokenizer.decode(),
        # which renders unknown characters as '[UNK]'.
        line=create_line(step+1,test_data[step],out[0])
        if step<3:
            print(line)  # preview only the first few lines
        file.write(line+'\n')

0朝阳区小关北里000-0号B-district I-district E-district B-community I-community I-community E-community B-houseno I-houseno I-houseno I-houseno I-houseno E-houseno
1朝阳区惠新东街00号B-district I-district E-district B-road I-road I-road E-road B-roadno I-roadno E-roadno
2朝阳区南磨房路与西大望路交口东南角B-district I-district E-district B-road I-road I-road I-road I-road I-road I-road I-road E-road B-intersection E-intersection B-assist I-assist E-assist


In [2]:
# Load the raw prediction lines (trailing newlines kept) for post-processing.
result=[]
with open('./pred.txt',encoding='utf-8') as file:
    result.extend(file)

In [None]:
# NOTE: this cell previously held the bare name `Akiyan_addr_parsing_runid`,
# which is never defined and raised NameError when run. It presumably refers
# to the submission file ./Akiyan_addr_parsing_runid.txt used by the
# prediction cell.

In [3]:
# Inspect the first raw prediction line (format: id \x01 text \x01 tags).
result[0]

'1\x01朝阳区小关北里000-0号\x01 B-district I-district E-district B-community I-community I-community E-community B-houseno I-houseno I-houseno I-houseno I-houseno E-houseno\n'

In [16]:
# NOTE(review): manually overwrites the first prediction line with a
# hard-coded value. The value equals the result[0] output displayed by the
# previous cell, so as written this is a no-op. Confirm whether the manual
# patch is still needed.
result[0]='1\x01朝阳区小关北里000-0号\x01 B-district I-district E-district B-community I-community I-community E-community B-houseno I-houseno I-houseno I-houseno I-houseno E-houseno\n'

In [3]:
# Rewrite every prediction line so it carries exactly one tag per character,
# and save the result to ./new.txt.
def _fit_tags(text,tag_field):
    """Split ``tag_field`` into tags, padded with 'O' or truncated to ``len(text)``."""
    tags=tag_field.split()
    if len(tags)<len(text):
        tags+=['O']*(len(text)-len(tags))
    return tags[:len(text)]

with open('./new.txt','w',encoding='utf-8') as f2:
    for line in result:
        fields=line.rstrip('\n').split('\u0001')
        if len(fields)<2:
            continue  # blank or malformed line
        idx,text=fields[0],fields[1]
        tag_field=fields[2] if len(fields)>2 else ''
        if len(tag_field.split())>len(text):
            # More tags than characters: report the id; extra tags are dropped
            # (previously a Python list repr was written to the file here).
            print(idx)
        # split()/join also removes the leading space the old s[1:] assumed.
        print(idx,text,' '.join(_fit_tags(text,tag_field)),sep='\u0001',file=f2)