In [1]:
# (Test Cell 1) 安装必要包并挂载 Drive
!pip install -q pytorch-crf seqeval

from google.colab import drive
drive.mount('/content/drive')

import os, pickle, torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_DIR = "/content/drive/MyDrive/CRFmodel"   # <-- 根据你的路径修改


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
Mounted at /content/drive


In [3]:
def load_pickle(filename):
    """Unpickle ``filename`` from SAVE_DIR and return the stored object.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes, including when unpickling raises.

    NOTE(review): pickle can execute arbitrary code while loading. Only use
    this on files you produced yourself.
    """
    with open(os.path.join(SAVE_DIR, filename), "rb") as fh:
        return pickle.load(fh)


# Lookup tables saved next to the model checkpoint.
word2id = load_pickle("word2id.pkl")
id2label = load_pickle("id2label.pkl")
label2id = load_pickle("label2id.pkl")

print("Vocab size:", len(word2id))
print("Num labels:", len(id2label))


Vocab size: 21011
Num labels: 9


In [15]:
# Read the saved state dict from Drive, remapping its tensors onto `device`.
checkpoint_path = os.path.join(SAVE_DIR, "bilstm_crf_ner_model.pt")
state = torch.load(checkpoint_path, map_location=device)

# emb_dim is the size of the second axis of the embedding matrix.
embedding_weight = state['embedding.weight']
emb_dim = embedding_weight.shape[1]

# hidden2tag.weight has shape [tagset, hidden_dim * 2], so the hidden size
# is half of the layer's input width.
tagset, fc_in = state['hidden2tag.weight'].shape
hidden_dim = fc_in // 2

# Report the recovered hyper-parameters.
for caption, value in (
    ("Inferred emb_dim =", emb_dim),
    ("Inferred hidden_dim =", hidden_dim),
    ("Tagset size =", tagset),
):
    print(caption, value)


Inferred emb_dim = 100
Inferred hidden_dim = 256
Tagset size = 9


In [16]:
class BiLSTM_CRF(nn.Module):
    """Sequence tagger: embedding -> bidirectional LSTM -> linear layer -> CRF.

    The attribute names (``embedding``, ``lstm``, ``hidden2tag``, ``crf``) are
    the key prefixes of the state dict, so they must stay as they are for a
    saved checkpoint to load.
    """

    def __init__(self, vocab_size, tagset_size, emb_dim, hidden_dim):
        """Build the layers.

        Args:
            vocab_size: number of rows in the embedding table.
            tagset_size: number of output tags.
            emb_dim: embedding vector size.
            hidden_dim: LSTM hidden size per direction.
        """
        super().__init__()
        # Index 0 is the padding id.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated, hence 2 * hidden_dim.
        self.hidden2tag = nn.Linear(2 * hidden_dim, tagset_size)
        self.crf = CRF(tagset_size, batch_first=True)

    def forward(self, x, tags=None, mask=None):
        """Return the training loss when ``tags`` is given, else decoded tags.

        Args:
            x: token-id tensor, batch first.
            tags: optional gold tag ids; when provided the negated CRF score
                (``reduction="mean"``) is returned.
            mask: optional mask passed through to the CRF unchanged.

        Returns:
            The negated output of ``self.crf(...)`` if ``tags`` is not None,
            otherwise whatever ``self.crf.decode(...)`` returns.
        """
        hidden_states, _ = self.lstm(self.embedding(x))
        emissions = self.hidden2tag(hidden_states)

        # Inference path: no gold tags, so decode the best tag sequence.
        if tags is None:
            return self.crf.decode(emissions, mask=mask)

        # Training path: negate the CRF score to obtain a loss.
        score = self.crf(emissions, tags, mask=mask, reduction="mean")
        return -score


In [17]:
# Rebuild the network with the sizes recovered from the checkpoint and the
# loaded lookup tables, then move it onto the selected device.
model = BiLSTM_CRF(len(word2id), len(id2label), emb_dim, hidden_dim)
model = model.to(device)

# Restore the trained weights and switch to evaluation mode.
model.load_state_dict(state)
model.eval()

print("Model loaded successfully!")




Model loaded successfully!


In [18]:
MAX_LEN = 128  # default fixed sequence length fed to the model


def encode(tokens, max_len=None, vocab=None):
    """Turn a token list into a padded id tensor and its boolean mask.

    Tokens are lower-cased before lookup. Unknown tokens map to the id of
    ``"<UNK>"`` (or 1 if the vocabulary has no such entry). The sequence is
    truncated to ``max_len`` and right-padded with id 0.

    Args:
        tokens: iterable of token strings.
        max_len: target length; defaults to the module-level ``MAX_LEN``
            (looked up at call time).
        vocab: token -> id mapping; defaults to the module-level ``word2id``.

    Returns:
        ``(ids, mask)``: a ``torch.long`` tensor and a ``torch.bool`` tensor,
        both of shape ``(1, max_len)``. ``mask`` is True on real tokens and
        False on padding.

    Raises:
        ValueError: if ``max_len`` is negative.
    """
    if max_len is None:
        max_len = MAX_LEN
    if vocab is None:
        vocab = word2id
    if max_len < 0:
        raise ValueError("max_len must be non-negative")

    # Resolve the fallback id once instead of once per token.
    unk_id = vocab.get("<UNK>", 1)
    ids = [vocab.get(w.lower(), unk_id) for w in tokens][:max_len]

    n_real = len(ids)
    pad_len = max_len - n_real
    ids = ids + [0] * pad_len
    mask = [True] * n_real + [False] * pad_len

    # Wrap in an outer list to add the batch dimension (batch size 1).
    return (
        torch.tensor([ids], dtype=torch.long),
        torch.tensor([mask], dtype=torch.bool),
    )


In [19]:
def predict(tokens):
    """Return the predicted label string for each token.

    Args:
        tokens: sequence of token strings (must support ``len``).

    Returns:
        A list of labels looked up in ``id2label``. It is empty for empty
        input. It is shorter than ``tokens`` when ``encode`` truncated the
        input, in which case a warning is emitted.
    """
    import warnings

    # An empty sentence yields an all-False mask, which the CRF layer rejects
    # (pytorch-crf requires the first timestep of every sequence to be on).
    if len(tokens) == 0:
        return []

    x, mask = encode(tokens)
    x, mask = x.to(device), mask.to(device)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        # The batch holds one sentence; take its decoded tag-id sequence.
        pred_ids = model(x, mask=mask)[0]

    labels = [id2label[i] for i in pred_ids[:len(tokens)]]

    # Make truncation visible instead of silently returning fewer labels.
    if len(labels) < len(tokens):
        warnings.warn(
            f"predict() labelled only the first {len(labels)} of "
            f"{len(tokens)} tokens; the remainder was truncated by encode()."
        )
    return labels


In [20]:
def extract_entities(tokens, labels):
    """Group BIO-tagged tokens into ``(text, type)`` entity tuples.

    Rules:
      * ``B-X`` closes any open entity and starts a new one of type ``X``.
      * ``I-X`` extends the open entity only if that entity has type ``X``.
      * Anything else closes the open entity and is not part of any entity.
        This covers ``O``, an ``I-`` tag with no open entity, and an ``I-``
        tag whose type differs from the open entity.

    Args:
        tokens: sequence of token strings.
        labels: sequence of BIO label strings aligned with ``tokens``.
            Pairs are formed with ``zip``, so extra items in the longer
            sequence are ignored.

    Returns:
        List of ``(entity_text, entity_type)`` tuples in order of appearance.
        ``entity_text`` is the entity's tokens joined by single spaces.
    """
    entities = []
    cur = []
    ent_type = None

    def flush():
        # Emit the entity being built (if any) and reset the accumulator.
        nonlocal cur, ent_type
        if cur:
            entities.append((" ".join(cur), ent_type))
        cur, ent_type = [], None

    for w, t in zip(tokens, labels):
        if t.startswith("B-"):
            flush()
            cur, ent_type = [w], t[2:]
        elif t.startswith("I-") and cur and t[2:] == ent_type:
            # Continuation is valid only when the type matches; otherwise
            # e.g. "B-LOC I-PER" would wrongly merge into one LOC entity.
            cur.append(w)
        else:
            flush()

    # Close an entity that runs to the end of the sentence.
    flush()
    return entities


In [27]:
# Demo sentence, tokenised on whitespace.
sentence = "I study at City University of Hong Kong with John Smith ."
tokens = sentence.split()

# Tag every token, then group the tags into entity spans.
labels = predict(tokens)
entities = extract_entities(tokens, labels)

for caption, value in (
    ("Tokens:", tokens),
    ("Labels:", labels),
    ("Entities:", entities),
):
    print(caption, value)


Tokens: ['I', 'study', 'at', 'City', 'University', 'of', 'Hong', 'Kong', 'with', 'John', 'Smith', '.']
Labels: ['O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O', 'B-PER', 'I-PER', 'O']
Entities: [('Hong Kong', 'LOC'), ('John Smith', 'PER')]


In [22]:
# Demo sentence, tokenised on whitespace.
sentence = "I studied at the University of Cambridge ."
tokens = sentence.split()

# Tag every token, then group the tags into entity spans.
labels = predict(tokens)
entities = extract_entities(tokens, labels)

for caption, value in (
    ("Sentence:", " ".join(tokens)),
    ("Labels:  ", labels),
    ("Entities:", entities),
):
    print(caption, value)


Sentence: I studied at the University of Cambridge .
Labels:   ['O', 'O', 'O', 'O', 'B-ORG', 'O', 'B-LOC', 'O']
Entities: [('University', 'ORG'), ('Cambridge', 'LOC')]


In [23]:
# Demo sentence, tokenised on whitespace.
sentence = "He graduated from Oxford University last year ."
tokens = sentence.split()

# Tag every token, then group the tags into entity spans.
labels = predict(tokens)
entities = extract_entities(tokens, labels)

for caption, value in (
    ("Sentence:", " ".join(tokens)),
    ("Labels:  ", labels),
    ("Entities:", entities),
):
    print(caption, value)


Sentence: He graduated from Oxford University last year .
Labels:   ['O', 'O', 'O', 'B-ORG', 'I-ORG', 'O', 'O', 'O']
Entities: [('Oxford University', 'ORG')]


In [24]:
# Demo sentence, tokenised on whitespace.
sentence = "She visited Tsinghua University in Beijing ."
tokens = sentence.split()

# Tag every token, then group the tags into entity spans.
labels = predict(tokens)
entities = extract_entities(tokens, labels)

for caption, value in (
    ("Sentence:", " ".join(tokens)),
    ("Labels:  ", labels),
    ("Entities:", entities),
):
    print(caption, value)


Sentence: She visited Tsinghua University in Beijing .
Labels:   ['O', 'O', 'B-ORG', 'I-ORG', 'O', 'B-LOC', 'O']
Entities: [('Tsinghua University', 'ORG'), ('Beijing', 'LOC')]


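Observation: the model tags "Tsinghua University" and "Oxford University" as ORG, yet in the first example it misses "City University of Hong Kong" as an organisation and only tags "Hong Kong" as LOC. Likewise, "University of Cambridge" is split into "University" (ORG) and "Cambridge" (LOC). In these examples the "X University" names are tagged correctly while the "University of <place>" names are not. (Original remark, roughly: "Come on — it can recognise Tsinghua, so why not CityU?")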