# Named Entity Recognition 
We will implement a bidirectional LSTM-CNN-CRF for sequence labeling, following [this paper by Xuezhe Ma and Ed Hovy](https://www.aclweb.org/anthology/P16-1101.pdf), on the CoNLL named entity recognition dataset.


## Imports + GPU


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

print(f'GPU available: {torch.cuda.is_available()}')

Sat Oct  4 02:11:23 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   54C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Download the Data

Run the following code to download the English part of the CoNLL 2003 dataset, the evaluation script and pre-filtered GloVe embeddings we are providing for this data.

In [3]:
#CoNLL 2003 data
!wget https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.train
!wget https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.testa
!wget https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.testb
!cat eng.train | awk '{print $1 "\t" $4}' > train
!cat eng.testa | awk '{print $1 "\t" $4}' > dev
!cat eng.testb | awk '{print $1 "\t" $4}' > test

#Evaluation Script
!wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl

#Pre-filtered GloVe embeddings
!wget https://raw.githubusercontent.com/aritter/aritter.github.io/master/files/glove.840B.300d.conll_filtered.txt

--2025-10-04 02:11:23--  https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.train
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3283420 (3.1M) [text/plain]
Saving to: ‘eng.train.6’


2025-10-04 02:11:24 (14.0 MB/s) - ‘eng.train.6’ saved [3283420/3283420]

--2025-10-04 02:11:24--  https://raw.githubusercontent.com/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/eng.testa
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 827443 (808K) [text/plain]
Saving to: ‘eng.testa.6’


2025

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## CoNLL Data Format

In [5]:
!head -n 20 train

-DOCSTART-	O
	
EU	I-ORG
rejects	O
German	I-MISC
call	O
to	O
boycott	O
British	I-MISC
lamb	O
.	O
	
Peter	I-PER
Blackburn	I-PER
	
BRUSSELS	I-LOC
1996-08-22	O
	
The	O
European	I-ORG


## Reading in the Data

In [6]:
#Read in the training data
def read_conll_format(filename):
    (words, tags, currentSent, currentTags) = ([],[],['-START-'],['START'])
    for line in open(filename).readlines():
        line = line.strip()
        #print(line)
        if line == "":
            currentSent.append('-END-')
            currentTags.append('END')
            words.append(currentSent)
            tags.append(currentTags)
            (currentSent, currentTags) = (['-START-'], ['START'])
        else:
            (word, tag) = line.split()
            currentSent.append(word)
            currentTags.append(tag)
    return (words, tags)

def sentences2char(sentences):
    return [[['start'] + [c for c in w] + ['end'] for w in l] for l in sentences]


(sentences_train, tags_train) = read_conll_format("train")
(sentences_dev, tags_dev)     = read_conll_format("dev")

print("The second sentence in train set:", sentences_train[2])
print("The NER label of the sentence:   ", tags_train[2])

sentencesChar = sentences2char(sentences_train)

print("The char repersentation of the sentence:", sentencesChar[2])

The second sentence in train set: ['-START-', 'Peter', 'Blackburn', '-END-']
The NER label of the sentence:    ['START', 'I-PER', 'I-PER', 'END']
The char repersentation of the sentence: [['start', '-', 'S', 'T', 'A', 'R', 'T', '-', 'end'], ['start', 'P', 'e', 't', 'e', 'r', 'end'], ['start', 'B', 'l', 'a', 'c', 'k', 'b', 'u', 'r', 'n', 'end'], ['start', '-', 'E', 'N', 'D', '-', 'end']]


In [7]:
#Read GloVe embeddings.
def read_GloVe(filename):
    embeddings = {}
    for line in open(filename).readlines():
        #print(line)
        fields = line.strip().split(" ")
        word = fields[0]
        embeddings[word] = [float(x) for x in fields[1:]]
    return embeddings

GloVe = read_GloVe("glove.840B.300d.conll_filtered.txt")

print("The GloVe word embedding of the word 'the':", GloVe["the"])
print("dimension of glove embedding:", len(GloVe["the"]))

The GloVe word embedding of the word 'the': [0.27204, -0.06203, -0.1884, 0.023225, -0.018158, 0.0067192, -0.13877, 0.17708, 0.17709, 2.5882, -0.35179, -0.17312, 0.43285, -0.10708, 0.15006, -0.19982, -0.19093, 1.1871, -0.16207, -0.23538, 0.003664, -0.19156, -0.085662, 0.039199, -0.066449, -0.04209, -0.19122, 0.011679, -0.37138, 0.21886, 0.0011423, 0.4319, -0.14205, 0.38059, 0.30654, 0.020167, -0.18316, -0.0065186, -0.0080549, -0.12063, 0.027507, 0.29839, -0.22896, -0.22882, 0.14671, -0.076301, -0.1268, -0.0066651, -0.052795, 0.14258, 0.1561, 0.05551, -0.16149, 0.09629, -0.076533, -0.049971, -0.010195, -0.047641, -0.16679, -0.2394, 0.0050141, -0.049175, 0.013338, 0.41923, -0.10104, 0.015111, -0.077706, -0.13471, 0.119, 0.10802, 0.21061, -0.051904, 0.18527, 0.17856, 0.041293, -0.014385, -0.082567, -0.035483, -0.076173, -0.045367, 0.089281, 0.33672, -0.22099, -0.0067275, 0.23983, -0.23147, -0.88592, 0.091297, -0.012123, 0.013233, -0.25799, -0.02972, 0.016754, 0.01369, 0.32377, 0.039546, 0.

## Mapping Tokens to Indices

In [8]:
#Create mappings between tokens and indices.

from collections import Counter
import random

#Will need this later to remove 50% of words that only appear once in the training data from the vocabulary (and don't have GloVe embeddings).
wordCounts = Counter([w for l in sentences_train for w in l])
charCounts = Counter([c for l in sentences_train for w in l for c in w])
singletons = set([w for (w,c) in wordCounts.items() if c == 1 and not w in GloVe.keys()])
charSingletons = set([w for (w,c) in charCounts.items() if c == 1])

#Build dictionaries to map from words, characters to indices and vice versa.
#Save first two words in the vocabulary for padding and "UNK" token.
word2i = {w:i+2 for i,w in enumerate(set([w for l in sentences_train for w in l] + list(GloVe.keys())))}
char2i = {w:i+2 for i,w in enumerate(set([c for l in sentencesChar for w in l for c in w]))}
i2word = {i:w for w,i in word2i.items()}
i2char = {i:w for w,i in char2i.items()}

vocab_size = max(word2i.values()) + 1
char_vocab_size = max(char2i.values()) + 1

#Tag dictionaries.
tag2i = {w:i for i,w in enumerate(set([t for l in tags_train for t in l]))}
i2tag = {i:t for t,i in tag2i.items()}

#When training, randomly replace singletons with UNK tokens sometimes to simulate situation at test time.
def getDictionaryRandomUnk(w, dictionary, train=False):
    if train and (w in singletons and random.random() > 0.5):
        return 1
    else:
        return dictionary.get(w, 1)

#Map a list of sentences from words to indices.
def sentences2indices(words, dictionary, train=False):
    #1.0 => UNK
    return [[getDictionaryRandomUnk(w,dictionary, train=train) for w in l] for l in words]

#Map a list of sentences containing to indices (character indices)
def sentences2indicesChar(chars, dictionary):
    #1.0 => UNK
    return [[[dictionary.get(c,1) for c in w] for w in l] for l in chars]

#Indices
X       = sentences2indices(sentences_train, word2i, train=True)
X_char  = sentences2indicesChar(sentencesChar, char2i)
Y       = sentences2indices(tags_train, tag2i)

print("vocab size:", vocab_size)
print("char vocab size:", char_vocab_size)
print()

print("index of word 'the':", word2i["the"])
print("word of index 253:", i2word[253])
print()

#Print out some examples of what the dev inputs will look like
for i in range(10):
    print(" ".join([i2word.get(w,'UNK') for w in X[i]]))

vocab size: 29148
char vocab size: 88

index of word 'the': 11270
word of index 253: Refining

-START- -DOCSTART- -END-
-START- EU rejects German call to boycott British lamb . -END-
-START- Peter Blackburn -END-
-START- BRUSSELS 1996-08-22 -END-
-START- The European Commission said on Thursday it disagreed with German advice to consumers to shun British lamb until scientists determine whether mad cow disease can be transmitted to sheep . -END-
-START- Germany 's representative to the European Union 's veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer . -END-
-START- " We do n't support any such recommendation because we do n't see any grounds for it , " the Commission 's chief spokesman Nikolaus van der Pas told a news briefing . -END-
-START- He said further scientific study was required and if it was found that action was needed it should be taken by the European Union . -E

## Padding and Batching

In [9]:
#Pad inputs to max sequence length (for batching)
def prepare_input(X_list):
    X_padded = torch.nn.utils.rnn.pad_sequence([torch.as_tensor(l) for l in X_list], batch_first=True).type(torch.LongTensor) # padding the sequences with 0
    X_mask   = torch.nn.utils.rnn.pad_sequence([torch.as_tensor([1.0] * len(l)) for l in X_list], batch_first=True).type(torch.FloatTensor) # consisting of 0 and 1, 0 for padded positions, 1 for non-padded positions
    return (X_padded, X_mask)

#Maximum word length (for character representations)
MAX_CLEN = 32

def prepare_input_char(X_list):
    MAX_SLEN = max([len(l) for l in X_list])
    X_padded  = [l + [[]]*(MAX_SLEN-len(l))  for l in X_list]
    X_padded  = [[w[0:MAX_CLEN] for w in l] for l in X_padded]
    X_padded  = [[w + [1]*(MAX_CLEN-len(w)) for w in l] for l in X_padded]
    return torch.as_tensor(X_padded).type(torch.LongTensor)

#Pad outputs using one-hot encoding
def prepare_output_onehot(Y_list, NUM_TAGS=max(tag2i.values())+1):
    Y_onehot = [torch.zeros(len(l), NUM_TAGS) for l in Y_list]
    for i in range(len(Y_list)):
        for j in range(len(Y_list[i])):
            Y_onehot[i][j,Y_list[i][j]] = 1.0
    Y_padded = torch.nn.utils.rnn.pad_sequence(Y_onehot, batch_first=True).type(torch.FloatTensor)
    return Y_padded

print("max slen:", max([len(x) for x in X_char]))

(X_padded, X_mask) = prepare_input(X)
X_padded_char      = prepare_input_char(X_char)
Y_onehot           = prepare_output_onehot(Y)

print("X_padded:", X_padded.shape)
print("X_mask:", X_mask.shape)
print("X_padded_char:", X_padded_char.shape)
print("Y_onehot:", Y_onehot.shape)

max slen: 115
X_padded: torch.Size([14987, 115])
X_mask: torch.Size([14987, 115])
X_padded_char: torch.Size([14987, 115, 32])
Y_onehot: torch.Size([14987, 115, 10])


## Basic LSTM Tagger

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicLSTMtagger(nn.Module):
    def __init__(self, DIM_EMB=10, DIM_HID=10):
        super(BasicLSTMtagger, self).__init__()
        NUM_TAGS = max(tag2i.values())+1

        (self.DIM_EMB, self.NUM_TAGS) = (DIM_EMB, NUM_TAGS)
        vocab_size = max(word2i.values()) + 1
        self.embedding   = nn.Embedding(num_embeddings=vocab_size,
                                        embedding_dim=DIM_EMB,
                                        padding_idx=0)
        self.lstm        = nn.LSTM(input_size=DIM_EMB,
                                   hidden_size=DIM_HID,
                                   num_layers=1,
                                   batch_first=True,
                                   bidirectional=True)
        self.linear      = nn.Linear(2*DIM_HID, NUM_TAGS)
        self.logsoftmax  = nn.LogSoftmax(dim=2)

    def forward(self, X, train=False):

        X = X.long()
        emb = self.embedding(X)                        # (B, T, DIM_EMB)
        lstm_out, _ = self.lstm(emb)                   # (B, T, DIM_HID)
        logits = self.linear(lstm_out)                 # (B, T, NUM_TAGS)
        log_probs = self.logsoftmax(logits)            # (B, T, NUM_TAGS)
        return log_probs

    def init_glove(self, GloVe):
        with torch.no_grad():
            W = self.embedding.weight
            # keep PAD row zeroed
            if W.size(0) > 0:
                W[0].zero_()
            # (optional) you could also zero UNK row (index 1) if desired:
            # if W.size(0) > 1: W[1].zero_()

            for w, idx in word2i.items():
                vec = GloVe.get(w)
                if vec is None:
                    continue
                v = torch.as_tensor(vec, dtype=W.dtype, device=W.device)
                # copy up to DIM_EMB in case vectors are longer
                W[idx, :self.DIM_EMB] = v[:self.DIM_EMB]


    def inference(self, sentences):
        X, X_mask       = prepare_input(sentences2indices(sentences, word2i))
        pred = self.forward(X.cuda()).argmax(dim=2)
        return [[i2tag[pred[i,j].item()] for j in range(len(sentences[i]))] for i in range(len(sentences))]

    def print_predictions(self, words, tags):
        Y_pred = self.inference(words)
        for i in range(len(words)):
            print("----------------------------")
            print(" ".join([f"{words[i][j]}/{Y_pred[i][j]}/{tags[i][j]}" for j in range(len(words[i]))]))
            print("Predicted:\t", Y_pred[i])
            print("Gold:\t\t", tags[i])

    def write_predictions(self, sentences, outFile):
        fOut = open(outFile, 'w')
        for s in sentences:
            y = self.inference([s])[0]
            #print("\n".join(y[1:len(y)-1]))
            fOut.write("\n".join(y[1:len(y)-1]))  #Skip start and end tokens
            fOut.write("\n\n")

#The following code will initialize a model and test that your forward computation runs without errors.
lstm_test   = BasicLSTMtagger(DIM_HID=7, DIM_EMB=300)
lstm_output = lstm_test.forward(prepare_input(X[0:5])[0])
Y_onehot    = prepare_output_onehot(Y[0:5])

#Check the shape of the lstm_output and one-hot label tensors.
print("lstm output shape:", lstm_output.shape)
print("Y onehot shape:", Y_onehot.shape)

lstm output shape: torch.Size([5, 32, 10])
Y onehot shape: torch.Size([5, 32, 10])


In [11]:
#Read in the data

(sentences_dev, tags_dev)     = read_conll_format('dev')
(sentences_train, tags_train) = read_conll_format('train')
(sentences_test, tags_test)   = read_conll_format('test')

## Train your Model (10 points)

Next, implement the function below to train your basic BiLSTM tagger.  See [torch.nn.lstm](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html).  Make sure to save your predictions on the test set (`test_pred_lstm.txt`) for submission to GradeScope. Feel free to change number of epochs, optimizer, learning rate and batch size.

In [12]:
#Training

from random import sample
from tqdm import tqdm
import os
import subprocess
import random

def shuffle_sentences(sentences, tags):
    shuffled_sentences = []
    shuffled_tags      = []
    indices = list(range(len(sentences)))
    random.shuffle(indices)
    for i in indices:
        shuffled_sentences.append(sentences[i])
        shuffled_tags.append(tags[i])
    return (shuffled_sentences, shuffled_tags)


def train_basic_lstm(sentences, tags, lstm):
    lr = 1e-3
    nEpochs = 5
    batchSize = 32
    optimizer = torch.optim.Adam(lstm.parameters(), lr=lr)


    for epoch in range(nEpochs):
        totalLoss = 0.0

        (sentences_shuffled, tags_shuffled) = shuffle_sentences(sentences, tags)
        for batch in tqdm(range(0, len(sentences), batchSize), leave=False):
            lstm.train()
            batch_sents = sentences_shuffled[batch:batch+batchSize]
            batch_tags  = tags_shuffled[batch:batch+batchSize]

            # indices, padding & masks
            X_batch, X_mask = prepare_input(
                sentences2indices(batch_sents, word2i, train=True)
            )
            Y_idx     = sentences2indices(batch_tags, tag2i)
            Y_onehot  = prepare_output_onehot(Y_idx, NUM_TAGS=lstm.NUM_TAGS)

            # forward
            log_probs = lstm.forward(X_batch.cuda(), train=True)   # (B,T,L)

            Y_ids_pad = prepare_input(Y_idx)[0].to(X_mask.device)   # (B,T) tag ids, padded with 0s
            valid = (Y_ids_pad != tag2i['START']) & (Y_ids_pad != tag2i['END'])
            Xm = X_mask.cuda()
            mask = (Xm * valid.float().cuda())

            nll  = -(Y_onehot.cuda() * log_probs).sum(dim=2)  # (B,T)
            loss = (nll * mask).sum() / mask.sum()

            # backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lstm.parameters(), max_norm=5.0)
            optimizer.step()

            totalLoss += loss.item()

        print(f"loss on epoch {epoch} = {totalLoss}")
        lstm.write_predictions(sentences_dev, 'dev_pred')   #Performance on dev set
        print('conlleval:')
        print(subprocess.Popen('paste dev dev_pred | perl conlleval.pl -d "\t"', shell=True, stdout=subprocess.PIPE,stderr=subprocess.STDOUT).communicate()[0].decode('UTF-8'))

        if epoch % 10 == 0:
            s = sample(range(len(sentences_dev)), 5)
            lstm.print_predictions([sentences_dev[i] for i in s], [tags_dev[i] for i in s])


lstm = BasicLSTMtagger(DIM_HID=500, DIM_EMB=300).cuda()
lstm.init_glove(GloVe)
train_basic_lstm(sentences_train, tags_train, lstm)



loss on epoch 0 = 73.39761619269848
conlleval:
processed 51578 tokens with 5942 phrases; found: 5911 phrases; correct: 5227.
accuracy:  98.07%; precision:  88.43%; recall:  87.97%; FB1:  88.20
              LOC: precision:  93.16%; recall:  91.94%; FB1:  92.55  1813
             MISC: precision:  81.56%; recall:  80.59%; FB1:  81.07  911
              ORG: precision:  80.81%; recall:  83.82%; FB1:  82.28  1391
              PER: precision:  93.04%; recall:  90.72%; FB1:  91.86  1796

----------------------------
-START-/O/START Among/O/O the/O/O crowd/O/O on/O/O Friday/O/O were/O/O Olympic/I-MISC/I-MISC 100/O/O metres/O/O champions/O/O going/O/O back/O/O to/O/O 1948/O/O ./O/O -END-/O/END
Predicted:	 ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
Gold:		 ['START', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'END']
----------------------------
-START-/O/START Statistics/I-ORG/I-ORG Canada/I-LOC/I-ORG on/O/O F



loss on epoch 1 = 18.73706996953115
conlleval:
processed 51578 tokens with 5942 phrases; found: 5942 phrases; correct: 5322.
accuracy:  98.36%; precision:  89.57%; recall:  89.57%; FB1:  89.57
              LOC: precision:  94.49%; recall:  91.45%; FB1:  92.95  1778
             MISC: precision:  85.76%; recall:  83.62%; FB1:  84.68  899
              ORG: precision:  81.03%; recall:  85.98%; FB1:  83.43  1423
              PER: precision:  93.27%; recall:  93.27%; FB1:  93.27  1842





loss on epoch 2 = 8.131664399872534
conlleval:
processed 51578 tokens with 5942 phrases; found: 5938 phrases; correct: 5361.
accuracy:  98.43%; precision:  90.28%; recall:  90.22%; FB1:  90.25
              LOC: precision:  94.27%; recall:  94.88%; FB1:  94.57  1849
             MISC: precision:  84.95%; recall:  84.49%; FB1:  84.72  917
              ORG: precision:  85.04%; recall:  83.52%; FB1:  84.27  1317
              PER: precision:  92.67%; recall:  93.32%; FB1:  92.99  1855





loss on epoch 3 = 3.9350833047938067
conlleval:
processed 51578 tokens with 5942 phrases; found: 5972 phrases; correct: 5394.
accuracy:  98.50%; precision:  90.32%; recall:  90.78%; FB1:  90.55
              LOC: precision:  94.82%; recall:  94.61%; FB1:  94.71  1833
             MISC: precision:  84.64%; recall:  84.27%; FB1:  84.46  918
              ORG: precision:  83.71%; recall:  86.20%; FB1:  84.94  1381
              PER: precision:  93.64%; recall:  93.54%; FB1:  93.59  1840





loss on epoch 4 = 1.9633699299884029
conlleval:
processed 51578 tokens with 5942 phrases; found: 5994 phrases; correct: 5396.
accuracy:  98.44%; precision:  90.02%; recall:  90.81%; FB1:  90.42
              LOC: precision:  93.79%; recall:  94.50%; FB1:  94.14  1851
             MISC: precision:  83.10%; recall:  85.36%; FB1:  84.22  947
              ORG: precision:  85.14%; recall:  85.01%; FB1:  85.07  1339
              PER: precision:  93.32%; recall:  94.08%; FB1:  93.70  1857



In [13]:
#Evaluation on test data
lstm.write_predictions(sentences_test, 'test_pred_lstm.txt')
!wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
!paste test test_pred_lstm.txt | perl conlleval.pl -d "\t"

--2025-10-04 02:13:15--  https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12754 (12K) [text/plain]
Saving to: ‘conlleval.pl.23’


2025-10-04 02:13:15 (5.41 MB/s) - ‘conlleval.pl.23’ saved [12754/12754]

processed 46666 tokens with 5648 phrases; found: 5662 phrases; correct: 4814.
accuracy:  97.22%; precision:  85.02%; recall:  85.23%; FB1:  85.13
              LOC: precision:  87.17%; recall:  89.21%; FB1:  88.18  1707
             MISC: precision:  69.05%; recall:  75.64%; FB1:  72.20  769
              ORG: precision:  82.33%; recall:  80.25%; FB1:  81.28  1619
              PER: precision:  93.30%; recall:  90.41%; FB1:  91.83  1567


## Initialization with GloVe Embeddings

## Character Embeddings

In [14]:
class CharLSTMtagger(BasicLSTMtagger):
    def __init__(self, DIM_EMB=10, DIM_CHAR_EMB=30, DIM_HID=10):
        super(CharLSTMtagger, self).__init__(DIM_EMB=DIM_EMB, DIM_HID=DIM_HID)
        NUM_TAGS = max(tag2i.values())+1

        (self.DIM_EMB, self.NUM_TAGS) = (DIM_EMB, NUM_TAGS)

        char_vocab_size = max(char2i.values()) + 1
        self.char_emb   = nn.Embedding(char_vocab_size, DIM_CHAR_EMB, padding_idx=1)

        self.char_conv  = nn.Conv1d(in_channels=DIM_CHAR_EMB, out_channels=30, kernel_size=3)

        self.dropout    = nn.Dropout(p=0.5)

        self.lstm       = nn.LSTM(input_size=DIM_EMB + 30,
                                  hidden_size=DIM_HID,
                                  num_layers=1,
                                  batch_first=True,
                                  bidirectional=True)
        self.linear     = nn.Linear(2 * DIM_HID, NUM_TAGS)

    def forward(self, X, X_char, train=False):
        X = X.long()
        word_emb = self.embedding(X)                           # (B, T, DIM_EMB)

        B, T, CLEN = X_char.size()
        Xc = X_char.view(B * T, CLEN).long()                   # (B*T, CLEN)
        ce = self.char_emb(Xc)                                 # (B*T, CLEN, Cdim)
        ce = ce.permute(0, 2, 1)                               # (B*T, Cdim, CLEN) for Conv1d
        cf = F.relu(self.char_conv(ce))                        # (B*T, 30, L')
        cp = F.max_pool1d(cf, kernel_size=cf.size(2)).squeeze(2)  # (B*T, 30)
        char_repr = cp.view(B, T, -1)                          # (B, T, 30)

        feats = torch.cat([word_emb, char_repr], dim=2)        # (B, T, DIM_EMB+30)
        feats = self.dropout(feats)

        lstm_out, _ = self.lstm(feats)                         # (B, T, 2*H)
        lstm_out = self.dropout(lstm_out)

        logits = self.linear(lstm_out)                         # (B, T, NUM_TAGS)
        log_probs = self.logsoftmax(logits)                    # (B, T, NUM_TAGS)
        return log_probs

    def sentences2input_tensors(self, sentences):
        (X, X_mask)   = prepare_input(sentences2indices(sentences, word2i))
        X_char        = prepare_input_char(sentences2indicesChar(sentences, char2i))
        return (X, X_mask, X_char)

    def inference(self, sentences):
        (X, X_mask, X_char) = self.sentences2input_tensors(sentences)
        pred = self.forward(X.cuda(), X_char.cuda()).argmax(dim=2)
        return [[i2tag[pred[i,j].item()] for j in range(len(sentences[i]))] for i in range(len(sentences))]

    def print_predictions(self, words, tags):
        Y_pred = self.inference(words)
        for i in range(len(words)):
            print("----------------------------")
            print(" ".join([f"{words[i][j]}/{Y_pred[i][j]}/{tags[i][j]}" for j in range(len(words[i]))]))
            print("Predicted:\t", Y_pred[i])
            print("Gold:\t\t", tags[i])

char_lstm_test = CharLSTMtagger(DIM_HID=7, DIM_EMB=300)
lstm_output    = char_lstm_test.forward(prepare_input(X[0:5])[0], prepare_input_char(X_char[0:5]))
Y_onehot       = prepare_output_onehot(Y[0:5])

print("lstm output shape:", lstm_output.shape)
print("Y onehot shape:", Y_onehot.shape)

lstm output shape: torch.Size([5, 32, 10])
Y onehot shape: torch.Size([5, 32, 10])


In [15]:
def train_char_lstm(sentences, tags, lstm):
    lr = 5e-4
    nEpochs = 15
    batchSize = 32
    optimizer = torch.optim.Adam(lstm.parameters(), lr=lr)


    for epoch in range(nEpochs):
        totalLoss = 0.0

        (sentences_shuffled, tags_shuffled) = shuffle_sentences(sentences, tags)
        for batch in tqdm(range(0, len(sentences), batchSize), leave=False):
            lstm.train()
            batch_sents = sentences_shuffled[batch:batch+batchSize]
            batch_tags  = tags_shuffled[batch:batch+batchSize]

            X_batch, X_mask = prepare_input(
                sentences2indices(batch_sents, word2i, train=True)
            )
            X_char_batch = prepare_input_char(
                sentences2indicesChar(batch_sents, char2i)
            )
            Y_idx    = sentences2indices(batch_tags, tag2i)
            Y_onehot = prepare_output_onehot(Y_idx, NUM_TAGS=lstm.NUM_TAGS)

            log_probs = lstm.forward(X_batch.cuda(), X_char_batch.cuda(), train=True)  # (B,T,L)

            nll = -(Y_onehot.cuda() * log_probs).sum(dim=2)   # (B,T)
            Xm  = X_mask.cuda()

            Y_pad_ids = prepare_input(Y_idx)[0].cuda()        # (B,T)
            if 'START' in tag2i and 'END' in tag2i:
                valid = (Y_pad_ids != tag2i['START']) & (Y_pad_ids != tag2i['END'])
                mask = Xm * valid.float()
            else:
                mask = Xm

            loss = (nll * mask).sum() / mask.sum()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lstm.parameters(), 5.0)
            optimizer.step()

            totalLoss += loss.item()

        print(f"loss on epoch {epoch} = {totalLoss}")
        lstm.write_predictions(sentences_dev, 'dev_pred')   #Performance on dev set
        print('conlleval:')
        print(subprocess.Popen('paste dev dev_pred | perl conlleval.pl -d "\t"', shell=True, stdout=subprocess.PIPE,stderr=subprocess.STDOUT).communicate()[0].decode('UTF-8'))

        if epoch % 10 == 0:
            s = sample(range(len(sentences_dev)), 5)
            lstm.print_predictions([sentences_dev[i] for i in s], [tags_dev[i] for i in s])

char_lstm = CharLSTMtagger(DIM_HID=500, DIM_EMB=300).cuda()
char_lstm.init_glove(GloVe)
train_char_lstm(sentences_train, tags_train, char_lstm)



loss on epoch 0 = 125.15589151531458
conlleval:
processed 51578 tokens with 5942 phrases; found: 6071 phrases; correct: 4553.
accuracy:  95.92%; precision:  75.00%; recall:  76.62%; FB1:  75.80
              LOC: precision:  82.65%; recall:  82.96%; FB1:  82.80  1844
             MISC: precision:  69.10%; recall:  67.68%; FB1:  68.38  903
              ORG: precision:  60.06%; recall:  64.35%; FB1:  62.13  1437
              PER: precision:  81.72%; recall:  83.71%; FB1:  82.70  1887

----------------------------
-START-/O/START DNIB/I-LOC/I-ORG issued/O/O a/O/O 275/O/O million/O/O Norwegian/I-MISC/I-MISC crown/O/O bond/O/O ,/O/O which/O/O was/O/O pre-placed/O/O with/O/O a/O/O European/I-MISC/I-MISC institution/O/O ./O/O -END-/O/END
Predicted:	 ['O', 'I-LOC', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O']
Gold:		 ['START', 'I-ORG', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'END']
-----------



loss on epoch 1 = 52.07320407964289
conlleval:
processed 51578 tokens with 5942 phrases; found: 6055 phrases; correct: 4873.
accuracy:  97.00%; precision:  80.48%; recall:  82.01%; FB1:  81.24
              LOC: precision:  86.32%; recall:  85.90%; FB1:  86.11  1828
             MISC: precision:  73.46%; recall:  73.86%; FB1:  73.66  927
              ORG: precision:  66.86%; recall:  75.69%; FB1:  71.00  1518
              PER: precision:  89.73%; recall:  86.81%; FB1:  88.25  1782





loss on epoch 2 = 35.00297735724598
conlleval:
processed 51578 tokens with 5942 phrases; found: 6021 phrases; correct: 5112.
accuracy:  97.72%; precision:  84.90%; recall:  86.03%; FB1:  85.46
              LOC: precision:  88.30%; recall:  92.87%; FB1:  90.53  1932
             MISC: precision:  78.11%; recall:  78.20%; FB1:  78.16  923
              ORG: precision:  77.65%; recall:  76.96%; FB1:  77.30  1329
              PER: precision:  89.98%; recall:  89.74%; FB1:  89.86  1837





loss on epoch 3 = 25.5188364693895
conlleval:
processed 51578 tokens with 5942 phrases; found: 5979 phrases; correct: 5109.
accuracy:  97.81%; precision:  85.45%; recall:  85.98%; FB1:  85.71
              LOC: precision:  92.53%; recall:  86.34%; FB1:  89.33  1714
             MISC: precision:  80.46%; recall:  79.07%; FB1:  79.76  906
              ORG: precision:  73.56%; recall:  82.18%; FB1:  77.63  1498
              PER: precision:  90.92%; recall:  91.86%; FB1:  91.39  1861





loss on epoch 4 = 19.417307747527957
conlleval:
processed 51578 tokens with 5942 phrases; found: 5998 phrases; correct: 5213.
accuracy:  98.04%; precision:  86.91%; recall:  87.73%; FB1:  87.32
              LOC: precision:  93.40%; recall:  89.38%; FB1:  91.35  1758
             MISC: precision:  81.71%; recall:  80.91%; FB1:  81.31  913
              ORG: precision:  78.89%; recall:  80.24%; FB1:  79.56  1364
              PER: precision:  89.10%; recall:  94.95%; FB1:  91.93  1963





loss on epoch 5 = 15.37983898865059
conlleval:
processed 51578 tokens with 5942 phrases; found: 5986 phrases; correct: 5286.
accuracy:  98.30%; precision:  88.31%; recall:  88.96%; FB1:  88.63
              LOC: precision:  93.63%; recall:  91.94%; FB1:  92.78  1804
             MISC: precision:  82.95%; recall:  81.78%; FB1:  82.36  909
              ORG: precision:  79.93%; recall:  82.25%; FB1:  81.07  1380
              PER: precision:  91.92%; recall:  94.46%; FB1:  93.17  1893





loss on epoch 6 = 12.910344891017303
conlleval:
processed 51578 tokens with 5942 phrases; found: 6061 phrases; correct: 5351.
accuracy:  98.36%; precision:  88.29%; recall:  90.05%; FB1:  89.16
              LOC: precision:  92.77%; recall:  92.87%; FB1:  92.82  1839
             MISC: precision:  81.79%; recall:  83.30%; FB1:  82.54  939
              ORG: precision:  80.61%; recall:  85.23%; FB1:  82.86  1418
              PER: precision:  92.98%; recall:  94.14%; FB1:  93.55  1865





loss on epoch 7 = 10.706088378326967
conlleval:
processed 51578 tokens with 5942 phrases; found: 5996 phrases; correct: 5324.
accuracy:  98.39%; precision:  88.79%; recall:  89.60%; FB1:  89.19
              LOC: precision:  94.04%; recall:  91.94%; FB1:  92.98  1796
             MISC: precision:  83.26%; recall:  82.54%; FB1:  82.90  914
              ORG: precision:  80.28%; recall:  85.01%; FB1:  82.58  1420
              PER: precision:  92.93%; recall:  94.14%; FB1:  93.53  1866





loss on epoch 8 = 8.6190308642108
conlleval:
processed 51578 tokens with 5942 phrases; found: 6015 phrases; correct: 5354.
accuracy:  98.41%; precision:  89.01%; recall:  90.10%; FB1:  89.55
              LOC: precision:  93.83%; recall:  93.58%; FB1:  93.70  1832
             MISC: precision:  83.68%; recall:  82.86%; FB1:  83.27  913
              ORG: precision:  81.74%; recall:  83.45%; FB1:  82.58  1369
              PER: precision:  92.16%; recall:  95.11%; FB1:  93.61  1901





loss on epoch 9 = 7.4008088011760265
conlleval:
processed 51578 tokens with 5942 phrases; found: 6056 phrases; correct: 5382.
accuracy:  98.47%; precision:  88.87%; recall:  90.58%; FB1:  89.71
              LOC: precision:  92.85%; recall:  93.96%; FB1:  93.40  1859
             MISC: precision:  83.15%; recall:  84.06%; FB1:  83.60  932
              ORG: precision:  82.63%; recall:  85.16%; FB1:  83.88  1382
              PER: precision:  92.35%; recall:  94.41%; FB1:  93.37  1883





loss on epoch 10 = 6.220509472768754
conlleval:
processed 51578 tokens with 5942 phrases; found: 6015 phrases; correct: 5351.
accuracy:  98.45%; precision:  88.96%; recall:  90.05%; FB1:  89.50
              LOC: precision:  94.17%; recall:  93.20%; FB1:  93.68  1818
             MISC: precision:  80.62%; recall:  83.95%; FB1:  82.25  960
              ORG: precision:  81.99%; recall:  85.23%; FB1:  83.58  1394
              PER: precision:  93.43%; recall:  93.49%; FB1:  93.46  1843

----------------------------
-START-/O/START HERZLIYA/I-LOC/I-LOC ,/O/O Israel/I-LOC/I-LOC 1996-08-31/O/O -END-/O/END
Predicted:	 ['O', 'I-LOC', 'O', 'I-LOC', 'O', 'O']
Gold:		 ['START', 'I-LOC', 'O', 'I-LOC', 'O', 'END']
----------------------------
-START-/I-ORG/START Orrell/I-ORG/I-ORG 13/O/O Bath/I-ORG/I-ORG 56/O/O -END-/I-ORG/END
Predicted:	 ['I-ORG', 'I-ORG', 'O', 'I-ORG', 'O', 'I-ORG']
Gold:		 ['START', 'I-ORG', 'O', 'I-ORG', 'O', 'END']
----------------------------
-START-/O/START Standings/O/O (/



loss on epoch 11 = 5.3882121496717446
conlleval:
processed 51578 tokens with 5942 phrases; found: 5995 phrases; correct: 5393.
accuracy:  98.53%; precision:  89.96%; recall:  90.76%; FB1:  90.36
              LOC: precision:  94.47%; recall:  93.96%; FB1:  94.21  1827
             MISC: precision:  83.46%; recall:  84.82%; FB1:  84.13  937
              ORG: precision:  83.44%; recall:  84.94%; FB1:  84.18  1365
              PER: precision:  93.57%; recall:  94.79%; FB1:  94.17  1866





loss on epoch 12 = 4.838785940344678
conlleval:
processed 51578 tokens with 5942 phrases; found: 5985 phrases; correct: 5359.
accuracy:  98.43%; precision:  89.54%; recall:  90.19%; FB1:  89.86
              LOC: precision:  94.46%; recall:  92.81%; FB1:  93.63  1805
             MISC: precision:  80.14%; recall:  85.79%; FB1:  82.87  987
              ORG: precision:  85.74%; recall:  82.03%; FB1:  83.84  1283
              PER: precision:  92.30%; recall:  95.71%; FB1:  93.98  1910





loss on epoch 13 = 3.903855000215117
conlleval:
processed 51578 tokens with 5942 phrases; found: 6038 phrases; correct: 5385.
accuracy:  98.47%; precision:  89.19%; recall:  90.63%; FB1:  89.90
              LOC: precision:  92.78%; recall:  94.45%; FB1:  93.61  1870
             MISC: precision:  82.38%; recall:  84.16%; FB1:  83.26  942
              ORG: precision:  83.52%; recall:  85.38%; FB1:  84.44  1371
              PER: precision:  93.21%; recall:  93.87%; FB1:  93.54  1855





loss on epoch 14 = 3.576671508286381
conlleval:
processed 51578 tokens with 5942 phrases; found: 5951 phrases; correct: 5333.
accuracy:  98.42%; precision:  89.62%; recall:  89.75%; FB1:  89.68
              LOC: precision:  93.40%; recall:  94.01%; FB1:  93.71  1849
             MISC: precision:  80.96%; recall:  83.95%; FB1:  82.43  956
              ORG: precision:  84.38%; recall:  83.37%; FB1:  83.87  1325
              PER: precision:  94.12%; recall:  93.05%; FB1:  93.58  1821



In [16]:
#Evaluation on test set
char_lstm.write_predictions(sentences_test, 'test_pred_cnn_lstm.txt')
!wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
!paste test test_pred_cnn_lstm.txt | perl conlleval.pl -d "\t"

--2025-10-04 02:19:25--  https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12754 (12K) [text/plain]
Saving to: ‘conlleval.pl.24’


2025-10-04 02:19:25 (5.55 MB/s) - ‘conlleval.pl.24’ saved [12754/12754]

processed 46666 tokens with 5648 phrases; found: 5647 phrases; correct: 4818.
accuracy:  97.36%; precision:  85.32%; recall:  85.30%; FB1:  85.31
              LOC: precision:  88.34%; recall:  90.89%; FB1:  89.60  1716
             MISC: precision:  70.17%; recall:  77.07%; FB1:  73.46  771
              ORG: precision:  81.28%; recall:  79.23%; FB1:  80.24  1619
              PER: precision:  93.77%; recall:  89.36%; FB1:  91.51  1541


## Conditional Random Fields 

In [17]:
import torch.nn.functional as F

class LSTM_CRFtagger(CharLSTMtagger):
    def __init__(self, DIM_EMB=10, DIM_CHAR_EMB=30, DIM_HID=10):
        super(LSTM_CRFtagger, self).__init__(DIM_EMB=DIM_EMB, DIM_HID=DIM_HID, DIM_CHAR_EMB=DIM_CHAR_EMB)
        self.transitions = nn.Parameter(torch.empty(self.NUM_TAGS, self.NUM_TAGS))
        nn.init.xavier_uniform_(self.transitions)

    def gold_score(self, lstm_scores, Y):
        if not torch.is_tensor(Y):
            Y = torch.tensor(Y, device=lstm_scores.device, dtype=torch.long)
        T = Y.size(0)

        # emissions along gold path
        emit = lstm_scores[torch.arange(T, device=lstm_scores.device), Y].sum()

        # transitions along gold path
        if T > 1:
            prev = Y[:-1]
            curr = Y[1:]
            trans = self.transitions[prev, curr].sum()
        else:
            trans = lstm_scores.new_tensor(0.0)

        return emit + trans


    #Forward algorithm for a single sentence
    #Efficiency will eventually be important here.  We recommend you start by
    #training on a single batch and make sure your code can memorize the
    #training data.  Then you can go back and re-write the inner loop using
    #tensor operations to speed things up.
    def forward_algorithm(self, lstm_scores, sLen):

        scores = lstm_scores[:sLen]           # (T, L)
        T, L = scores.size()

        # alpha_0 = emissions at t=0
        alpha = scores[0]                     # (L,)

        # DP over time: alpha_t(j) = logsum_i [ alpha_{t-1}(i) + A[i,j] ] + emit_t(j)
        for t in range(1, T):
            emit_t = scores[t].unsqueeze(0)   # (1, L)
            # (L,1) + (L,L) -> (L,L); then add emit_t (1,L) -> broadcast to (L,L)
            m = alpha.unsqueeze(1) + self.transitions
            alpha = torch.logsumexp(m, dim=0) + emit_t.squeeze(0)  # (L,)

        # log Z = logsumexp over last alpha
        return torch.logsumexp(alpha, dim=0)

    def conditional_log_likelihood(self, sentences, tags, train=True):
        (X, X_mask, X_char) = self.sentences2input_tensors(sentences)
        emissions = self.forward(X.cuda(), X_char.cuda(), train=train)  # (B, T_max, L)

        total_nll = emissions.new_tensor(0.0)
        for b in range(len(sentences)):
            sLen = len(sentences[b])
            # gold tag ids for this sentence
            Y_ids = torch.tensor([tag2i[t] for t in tags[b]], device=emissions.device, dtype=torch.long)

            # per-sentence scores
            gold = self.gold_score(emissions[b, :sLen, :], Y_ids)
            Z    = self.forward_algorithm(emissions[b], sLen)

            # negative log-likelihood
            total_nll = total_nll + (Z - gold)

        # average per sentence (you could also normalize by tokens)
        loss = total_nll / len(sentences)
        return loss


    def viterbi(self, lstm_scores, sLen):
        scores = lstm_scores[:sLen]        # (T, L)
        T, L = scores.size()

        delta = scores[0]                  # (L,)
        backp = torch.zeros(T, L, dtype=torch.long, device=scores.device)

        for t in range(1, T):
            # prev->curr: (L,1) + (L,L) -> (L,L)
            m = delta.unsqueeze(1) + self.transitions
            max_prev, arg_prev = torch.max(m, dim=0)   # (L,), (L,)
            delta = max_prev + scores[t]               # (L,)
            backp[t] = arg_prev

        best_last_score, best_last_tag = torch.max(delta, dim=0)

        # backtrack
        path = [best_last_tag.item()]
        for t in range(T-1, 0, -1):
            best_last_tag = backp[t, best_last_tag]
            path.append(best_last_tag.item())
        path.reverse()

        return torch.tensor(path, device=scores.device, dtype=torch.long), best_last_score

    #Computes Viterbi sequences on a batch of data.
    def viterbi_batch(self, sentences):
        viterbiSeqs = []
        (X, X_mask, X_char) = self.sentences2input_tensors(sentences)
        lstm_scores = self.forward(X.cuda(), X_char.cuda(), train=False)
        for s in range(len(sentences)):
            (viterbiSeq, ll) = self.viterbi(lstm_scores[s], len(sentences[s]))
            viterbiSeqs.append(viterbiSeq)
        return viterbiSeqs


    def forward(self, X, X_char, train=False):
        X = X.long()
        word_emb = self.embedding(X)                      # (B, T, DIM_EMB)

        B, T, CLEN = X_char.size()
        Xc = X_char.view(B * T, CLEN).long()             # (B*T, CLEN)
        ce = self.char_emb(Xc)                           # (B*T, CLEN, Cdim)
        ce = ce.permute(0, 2, 1)                         # (B*T, Cdim, CLEN)
        cf = F.relu(self.char_conv(ce))                  # (B*T, C_out, L')
        cp = F.max_pool1d(cf, kernel_size=cf.size(2)).squeeze(2)  # (B*T, C_out)
        char_repr = cp.view(B, T, -1)                    # (B, T, C_out)

        feats = torch.cat([word_emb, char_repr], dim=2)  # (B, T, DIM_EMB+C_out)
        feats = self.dropout(feats)

        lstm_out, _ = self.lstm(feats)                   # (B, T, 2*H)
        lstm_out = self.dropout(lstm_out)

        emissions = self.linear(lstm_out)                # (B, T, NUM_TAGS)
        return emissions


    def print_predictions(self, words, tags):
        Y_pred = self.inference(words)
        for i in range(len(words)):
            print("----------------------------")
            print(" ".join([f"{words[i][j]}/{Y_pred[i][j]}/{tags[i][j]}" for j in range(len(words[i]))]))
            print("Predicted:\t", [Y_pred[i][j] for j in range(len(words[i]))])
            print("Gold:\t\t", tags[i])

    #Need to use Viterbi this time.
    def inference(self, sentences, viterbi=True):
        pred = self.viterbi_batch(sentences)
        return [[i2tag[pred[i][j].item()] for j in range(len(sentences[i]))] for i in range(len(sentences))]

lstm_crf = LSTM_CRFtagger(DIM_EMB=300).cuda()
# print(lstm_crf.conditional_log_likelihood(sentences_dev[0:5], tags_dev[0:5]))

In [18]:
# This is a cell for debugging, feel free to change it as you like
print(lstm_crf.conditional_log_likelihood(sentences_dev[0:5], tags_dev[0:5]))

tensor(43.7689, device='cuda:0', grad_fn=<DivBackward0>)


In [19]:
#CharLSTM-CRF Training

from tqdm.auto import tqdm


#Get CoNLL evaluation script
os.system('wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl')

def train_crf_lstm(sentences, tags, lstm):

    lr = 5e-4
    nEpochs = 15
    batchSize = 32
    optimizer = torch.optim.Adam(lstm.parameters(), lr=lr)
    for epoch in range(nEpochs):
        totalLoss = 0.0
        lstm.train()

        #Shuffle the sentences
        (sentences_shuffled, tags_shuffled) = shuffle_sentences(sentences, tags)
        for batch in tqdm(range(0, len(sentences), batchSize), leave=False):
        
            batch_sents = sentences_shuffled[batch:batch+batchSize]
            batch_tags  = tags_shuffled[batch:batch+batchSize]

            loss = lstm.conditional_log_likelihood(batch_sents, batch_tags, train=True)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lstm.parameters(), 5.0)
            optimizer.step()

            totalLoss += loss.item()
            

        print(f"loss on epoch {epoch} = {totalLoss}")
        lstm.write_predictions(sentences_dev, 'dev_pred')   #Performance on dev set
        print('conlleval:')
        print(subprocess.Popen('paste dev dev_pred | perl conlleval.pl -d "\t"', shell=True, stdout=subprocess.PIPE,stderr=subprocess.STDOUT).communicate()[0].decode('UTF-8'))

        if epoch % 5 == 0:
            lstm.eval()
            s = random.sample(range(50), 5)
            lstm.print_predictions([sentences_train[i] for i in s], [tags_train[i] for i in s])   #Print predictions on train data (useful for debugging)

crf_lstm = LSTM_CRFtagger(DIM_HID=500, DIM_EMB=300, DIM_CHAR_EMB=30).cuda()
crf_lstm.init_glove(GloVe)
train_crf_lstm(sentences_train, tags_train, crf_lstm)             #Train on the full dataset
# train_crf_lstm(sentences_train[0:50], tags_train[0:50], crf_lstm)   #Train only the first batch (use this during development/debugging)

  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 0 = 1617.3575000166893
conlleval:
processed 51578 tokens with 5942 phrases; found: 6060 phrases; correct: 4646.
accuracy:  96.12%; precision:  76.67%; recall:  78.19%; FB1:  77.42
              LOC: precision:  83.69%; recall:  84.92%; FB1:  84.30  1864
             MISC: precision:  72.07%; recall:  66.05%; FB1:  68.93  845
              ORG: precision:  65.11%; recall:  67.64%; FB1:  66.35  1393
              PER: precision:  80.18%; recall:  85.23%; FB1:  82.63  1958

----------------------------
-START-/START/START State/O/O media/O/O quoted/O/O China/I-LOC/I-LOC 's/O/O top/O/O negotiator/O/O with/O/O Taipei/I-LOC/I-LOC ,/O/O Tang/I-PER/I-PER Shubei/I-PER/I-PER ,/O/O as/O/O telling/O/O a/O/O visiting/O/O group/O/O from/O/O Taiwan/I-LOC/I-LOC on/O/O Wednesday/O/O that/O/O it/O/O was/O/O time/O/O for/O/O the/O/O rivals/O/O to/O/O hold/O/O political/O/O talks/O/O ./O/O -END-/END/END
Predicted:	 ['START', 'O', 'O', 'O', 'I-LOC', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-PER', 

  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 1 = 641.5234677791595
conlleval:
processed 51578 tokens with 5942 phrases; found: 5870 phrases; correct: 4978.
accuracy:  97.35%; precision:  84.80%; recall:  83.78%; FB1:  84.29
              LOC: precision:  88.77%; recall:  89.93%; FB1:  89.35  1861
             MISC: precision:  81.50%; recall:  72.13%; FB1:  76.52  816
              ORG: precision:  74.23%; recall:  77.78%; FB1:  75.97  1405
              PER: precision:  90.49%; recall:  87.84%; FB1:  89.15  1788



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 2 = 406.0256187915802
conlleval:
processed 51578 tokens with 5942 phrases; found: 5880 phrases; correct: 5070.
accuracy:  97.66%; precision:  86.22%; recall:  85.32%; FB1:  85.77
              LOC: precision:  90.85%; recall:  89.22%; FB1:  90.03  1804
             MISC: precision:  80.09%; recall:  78.52%; FB1:  79.30  904
              ORG: precision:  78.37%; recall:  76.21%; FB1:  77.28  1304
              PER: precision:  90.20%; recall:  91.48%; FB1:  90.84  1868



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 3 = 293.2581206560135
conlleval:
processed 51578 tokens with 5942 phrases; found: 5926 phrases; correct: 5242.
accuracy:  98.11%; precision:  88.46%; recall:  88.22%; FB1:  88.34
              LOC: precision:  92.20%; recall:  91.34%; FB1:  91.77  1820
             MISC: precision:  84.38%; recall:  80.26%; FB1:  82.27  877
              ORG: precision:  80.54%; recall:  84.86%; FB1:  82.64  1413
              PER: precision:  92.84%; recall:  91.53%; FB1:  92.18  1816



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 4 = 222.05209976434708
conlleval:
processed 51578 tokens with 5942 phrases; found: 5897 phrases; correct: 5265.
accuracy:  98.20%; precision:  89.28%; recall:  88.61%; FB1:  88.94
              LOC: precision:  94.71%; recall:  90.64%; FB1:  92.63  1758
             MISC: precision:  81.60%; recall:  81.78%; FB1:  81.69  924
              ORG: precision:  82.33%; recall:  84.79%; FB1:  83.54  1381
              PER: precision:  93.18%; recall:  92.78%; FB1:  92.98  1834



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 5 = 178.63294922932982
conlleval:
processed 51578 tokens with 5942 phrases; found: 5933 phrases; correct: 5289.
accuracy:  98.21%; precision:  89.15%; recall:  89.01%; FB1:  89.08
              LOC: precision:  94.46%; recall:  91.94%; FB1:  93.19  1788
             MISC: precision:  82.42%; recall:  81.89%; FB1:  82.15  916
              ORG: precision:  83.50%; recall:  83.37%; FB1:  83.43  1339
              PER: precision:  91.38%; recall:  93.76%; FB1:  92.55  1890

----------------------------
-START-/START/START He/O/O said/O/O further/O/O scientific/O/O study/O/O was/O/O required/O/O and/O/O if/O/O it/O/O was/O/O found/O/O that/O/O action/O/O was/O/O needed/O/O it/O/O should/O/O be/O/O taken/O/O by/O/O the/O/O European/I-ORG/I-ORG Union/I-ORG/I-ORG ./O/O -END-/END/END
Predicted:	 ['START', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-ORG', 'I-ORG', 'O', 'END']
Gold:		 ['START', 'O', 'O', 'O', 'O',

  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 6 = 144.50791764259338
conlleval:
processed 51578 tokens with 5942 phrases; found: 5963 phrases; correct: 5334.
accuracy:  98.31%; precision:  89.45%; recall:  89.77%; FB1:  89.61
              LOC: precision:  93.06%; recall:  94.18%; FB1:  93.61  1859
             MISC: precision:  84.61%; recall:  82.86%; FB1:  83.73  903
              ORG: precision:  84.72%; recall:  83.52%; FB1:  84.12  1322
              PER: precision:  91.54%; recall:  93.38%; FB1:  92.45  1879



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 7 = 113.77006647363305
conlleval:
processed 51578 tokens with 5942 phrases; found: 5941 phrases; correct: 5338.
accuracy:  98.35%; precision:  89.85%; recall:  89.84%; FB1:  89.84
              LOC: precision:  93.45%; recall:  92.43%; FB1:  92.94  1817
             MISC: precision:  86.18%; recall:  81.13%; FB1:  83.58  868
              ORG: precision:  82.80%; recall:  86.50%; FB1:  84.61  1401
              PER: precision:  93.37%; recall:  94.03%; FB1:  93.70  1855



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 8 = 96.19471228867769
conlleval:
processed 51578 tokens with 5942 phrases; found: 5930 phrases; correct: 5373.
accuracy:  98.43%; precision:  90.61%; recall:  90.42%; FB1:  90.52
              LOC: precision:  93.86%; recall:  93.96%; FB1:  93.91  1839
             MISC: precision:  87.80%; recall:  83.51%; FB1:  85.60  877
              ORG: precision:  85.04%; recall:  85.23%; FB1:  85.14  1344
              PER: precision:  92.73%; recall:  94.14%; FB1:  93.43  1870



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 9 = 76.89269664883614
conlleval:
processed 51578 tokens with 5942 phrases; found: 5958 phrases; correct: 5394.
accuracy:  98.44%; precision:  90.53%; recall:  90.78%; FB1:  90.66
              LOC: precision:  94.35%; recall:  92.71%; FB1:  93.52  1805
             MISC: precision:  85.36%; recall:  84.71%; FB1:  85.03  915
              ORG: precision:  84.49%; recall:  88.14%; FB1:  86.28  1399
              PER: precision:  93.96%; recall:  93.81%; FB1:  93.89  1839



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 10 = 65.9182491376996
conlleval:
processed 51578 tokens with 5942 phrases; found: 5940 phrases; correct: 5367.
accuracy:  98.42%; precision:  90.35%; recall:  90.32%; FB1:  90.34
              LOC: precision:  93.61%; recall:  94.07%; FB1:  93.84  1846
             MISC: precision:  87.47%; recall:  83.30%; FB1:  85.33  878
              ORG: precision:  83.83%; recall:  85.83%; FB1:  84.82  1373
              PER: precision:  93.33%; recall:  93.38%; FB1:  93.35  1843

----------------------------
-START-/START/START China/I-LOC/I-LOC on/O/O Thursday/O/O accused/O/O Taipei/I-LOC/I-LOC of/O/O spoiling/O/O the/O/O atmosphere/O/O for/O/O a/O/O resumption/O/O of/O/O talks/O/O across/O/O the/O/O Taiwan/I-LOC/I-LOC Strait/I-LOC/I-LOC with/O/O a/O/O visit/O/O to/O/O Ukraine/I-LOC/I-LOC by/O/O Taiwanese/I-MISC/I-MISC Vice/O/O President/O/O Lien/I-PER/I-PER Chan/I-PER/I-PER this/O/O week/O/O that/O/O infuriated/O/O Beijing/I-LOC/I-LOC ./O/O -END-/END/END
Predicted:	 ['START', 'I-

  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 11 = 55.25656037032604
conlleval:
processed 51578 tokens with 5942 phrases; found: 5899 phrases; correct: 5365.
accuracy:  98.47%; precision:  90.95%; recall:  90.29%; FB1:  90.62
              LOC: precision:  94.57%; recall:  93.79%; FB1:  94.18  1822
             MISC: precision:  85.50%; recall:  85.03%; FB1:  85.26  917
              ORG: precision:  85.47%; recall:  84.64%; FB1:  85.05  1328
              PER: precision:  94.05%; recall:  93.54%; FB1:  93.79  1832



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 12 = 48.49529451131821
conlleval:
processed 51578 tokens with 5942 phrases; found: 5932 phrases; correct: 5378.
accuracy:  98.46%; precision:  90.66%; recall:  90.51%; FB1:  90.58
              LOC: precision:  94.60%; recall:  93.47%; FB1:  94.03  1815
             MISC: precision:  86.25%; recall:  84.38%; FB1:  85.31  902
              ORG: precision:  85.09%; recall:  85.98%; FB1:  85.53  1355
              PER: precision:  93.01%; recall:  93.92%; FB1:  93.46  1860



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 13 = 42.6387715190649
conlleval:
processed 51578 tokens with 5942 phrases; found: 5943 phrases; correct: 5385.
accuracy:  98.47%; precision:  90.61%; recall:  90.63%; FB1:  90.62
              LOC: precision:  95.26%; recall:  92.92%; FB1:  94.08  1792
             MISC: precision:  85.65%; recall:  86.12%; FB1:  85.88  927
              ORG: precision:  83.21%; recall:  87.25%; FB1:  85.18  1406
              PER: precision:  94.28%; recall:  93.05%; FB1:  93.66  1818



  0%|          | 0/469 [00:00<?, ?it/s]

loss on epoch 14 = 34.22291642986238
conlleval:
processed 51578 tokens with 5942 phrases; found: 5912 phrases; correct: 5366.
accuracy:  98.44%; precision:  90.76%; recall:  90.31%; FB1:  90.53
              LOC: precision:  94.61%; recall:  93.58%; FB1:  94.09  1817
             MISC: precision:  86.04%; recall:  84.92%; FB1:  85.48  910
              ORG: precision:  83.77%; recall:  85.83%; FB1:  84.79  1374
              PER: precision:  94.59%; recall:  93.00%; FB1:  93.79  1811



In [20]:
crf_lstm.eval()
crf_lstm.write_predictions(sentences_test, 'test_pred_cnn_lstm_crf.txt')
!wget https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
!paste test test_pred_cnn_lstm_crf.txt | perl conlleval.pl -d "\t"

--2025-10-04 02:50:05--  https://raw.githubusercontent.com/aritter/twitter_nlp/master/data/annotated/wnut16/conlleval.pl
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12754 (12K) [text/plain]
Saving to: ‘conlleval.pl.26’


2025-10-04 02:50:05 (5.58 MB/s) - ‘conlleval.pl.26’ saved [12754/12754]

processed 46666 tokens with 5648 phrases; found: 5581 phrases; correct: 5004.
accuracy:  97.79%; precision:  89.66%; recall:  88.60%; FB1:  89.13
              LOC: precision:  91.05%; recall:  92.15%; FB1:  91.60  1688
             MISC: precision:  77.09%; recall:  76.21%; FB1:  76.65  694
              ORG: precision:  86.31%; recall:  86.94%; FB1:  86.62  1673
              PER: precision:  97.51%; recall:  92.02%; FB1:  94.69  1526


## Gradescope Submission Instructions

This is the end—congratulations! 🎉  

Follow the steps below carefully to prepare and submit your work in [Gradescope](https://www.gradescope.com/courses/1086056):

1. **Rename your notebook**  
   Save this notebook as `CS7650_p2_GTusername.ipynb`
   (replace `GTusername` with your actual GT username).  
Before submission, please:
- Remove extraneous/debugging cells and print statements.  
- Clear all outputs, then use **Runtime → Run all** to regenerate them.  
- Optionally, add comments in your code to explain your reasoning (recommended for grading clarity).

2. **Download source files**  
- `File → Download → Download .py`  
- `File → Download → Download .ipynb`

3. **Generate a PDF of your notebook outputs**  
It is important that training and evaluation logs (loss/accuracy over time) are fully visible.  
- **Do NOT rely on the LaTeX exporter (`--to pdf`)** – it will break on remote images and undefined macros.  
- Instead, use the Chromium-based web PDF exporter:  
  ```bash
  pip install "nbconvert[webpdf]" playwright
  python -m playwright install chromium
  jupyter nbconvert --to webpdf --allow-chromium-download CS7650_p2_GTusername.ipynb
  ```
- This avoids the common `\pandocbounded` / LaTeX error and ensures all outputs render.

4. **(Optional) Download model predictions from Colab**  
Use the left folder pane under *Files* to download:  
- `test_pred_lstm.txt`  
- `test_pred_cnn_lstm.txt`  
- `test_pred_cnn_lstm_crf.txt`  

5. **Upload your submission to Gradescope**  
Upload the following files (5 or 6 total if you include the optional CRF predictions):  
- `CS7650_p2_GTusername.ipynb`  
- `CS7650_p2_GTusername.py`  
- `CS7650_p2_GTusername.pdf`  
- `test_pred_lstm.txt`  
- `test_pred_cnn_lstm.txt`  
- `test_pred_cnn_lstm_crf.txt` *(optional)*

---

**Important notes:**
- Your implementation must meet the stated accuracy thresholds to receive full credit.  
- File names **must** match exactly as above.  
- You may submit multiple times before the deadline; choose your preferred submission under **Submission History** in Gradescope.