In [42]:
import math, random, argparse
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from seqeval.metrics import classification_report as seqeval_classification_report
from seqeval.metrics import f1_score
from sklearn.metrics import classification_report
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


In [2]:
# Special vocabulary entries: PAD is the filler token used to bring every
# sequence in a batch to the same length, UNK is the fallback for words that
# are not in the vocabulary.
PAD = "<PAD>"
UNK = "<UNK>"

In [3]:
@dataclass
class Config:
    """Hyper-parameters shared by the model, the data loaders and the training loop."""
    # model
    emb_dim: int = 100          # size of each word embedding vector
    hidden_dim: int = 128       # size of the LSTM hidden state
    dropout: float = 0.2        # dropout probability applied to the LSTM outputs
    # training
    batch_size: int = 8         # sentences per batch for every DataLoader
    lr: float = 1e-3            # learning rate handed to the optimizer
    epochs: int = 10            # number of passes over the training loader
    max_grad_norm: float = 5.0  # gradient-norm clipping threshold
    seed: int = 17              # seed for the RNGs and for the data splits

# Single shared configuration instance used by the rest of the notebook.
cfg = Config()

def set_seed(seed=17):
    """Seed Python's `random`, PyTorch's CPU generator and all CUDA generators.

    Args:
        seed: integer seed applied to every generator (defaults to 17).
    """
    # Same three seeders, applied in the same order, just driven by a loop.
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
set_seed(cfg.seed)  # seed every RNG with the configured seed before anything else is built

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
def read_conll(path):
    """
    Parse a CoNLL-style file into parallel token and tag sequences.

    Each non-blank line is split on whitespace; the first column is taken as
    the token and the last column as the tag (for a one-column line they are
    the same string). Blank or whitespace-only lines end the current sentence.

    Args:
        path: location of the UTF-8 encoded file.

    Returns:
        (list_of_token_lists, list_of_tag_lists), one inner list per sentence.
    """
    sents, tags = [], []
    pending = []  # (token, tag) pairs of the sentence currently being read

    def _flush():
        # Emit the buffered sentence, if any, and start a fresh buffer.
        if pending:
            sents.append([tok for tok, _ in pending])
            tags.append([lab for _, lab in pending])
            pending.clear()

    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            fields = raw.split()
            if fields:
                pending.append((fields[0], fields[-1]))
            else:
                _flush()

    # The file may end without a trailing blank line.
    _flush()
    return sents, tags

def build_word_vocab(sents, min_freq=1):
    """Build the word vocabulary from tokenised sentences.

    Args:
        sents: iterable of token lists.
        min_freq: minimum number of occurrences a word needs to be kept.

    Returns:
        (stoi, itos): word -> index dict and index -> word list. PAD and UNK
        take the first two slots; the remaining words follow in order of first
        appearance.
    """
    freq = Counter()
    for sent in sents:
        freq.update(sent)

    specials = (PAD, UNK)
    itos = list(specials)
    for word, count in freq.items():
        # Special tokens already hold the first two slots; never add them twice.
        if count >= min_freq and word not in specials:
            itos.append(word)

    stoi = {word: idx for idx, word in enumerate(itos)}
    return stoi, itos

def build_tag_vocab(tags):
    """Build the tag vocabulary from tag sequences.

    Args:
        tags: iterable of tag lists.

    Returns:
        (stoi, itos): tag -> index dict and the alphabetically sorted list of
        distinct tags.
    """
    distinct = set()
    for seq in tags:
        distinct.update(seq)
    itos = sorted(distinct)
    stoi = dict(zip(itos, range(len(itos))))
    return stoi, itos

In [5]:
class ConllDataset(Dataset):
    """Dataset of index-encoded sentences.

    Every item is a tuple (word_ids, tag_ids, length). Words missing from
    `w2i` are mapped to the UNK index; tags missing from `t2i` raise KeyError.
    """

    def __init__(self, sents, tags, w2i, t2i):
        self.instances = []
        for words, labels in zip(sents, tags):
            self.instances.append(self._encode(words, labels, w2i, t2i))

    @staticmethod
    def _encode(words, labels, w2i, t2i):
        # Turn one sentence and its tags into (word_ids, tag_ids, length).
        word_ids = [w2i.get(word, w2i[UNK]) for word in words]
        tag_ids = [t2i[label] for label in labels]
        return (word_ids, tag_ids, len(word_ids))

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        return self.instances[idx]

def pad_batch(batch, pad_idx):
    """Collate (word_ids, tag_ids, length) items into right-padded tensors.

    Args:
        batch: list of (word_ids, tag_ids, length) tuples.
        pad_idx: word index used to fill the padded positions.

    Returns:
        Tuple of four long tensors (words, tags, lengths, mask). Padded tag
        positions hold -100; mask is 1 on real tokens and 0 on padding.
    """
    width = max(len(tokens) for tokens, _, _ in batch)

    padded_words, padded_tags, lengths, masks = [], [], [], []
    for tokens, labels, n_tokens in batch:
        fill = width - len(tokens)
        padded_words.append(tokens + [pad_idx] * fill)
        padded_tags.append(labels + [-100] * fill)
        lengths.append(n_tokens)
        masks.append([1] * n_tokens + [0] * fill)

    def _as_long(rows):
        # Every returned tensor uses the same integer dtype.
        return torch.tensor(rows, dtype=torch.long)

    return (_as_long(padded_words), _as_long(padded_tags),
            _as_long(lengths), _as_long(masks))


In [6]:
class LSTMTagger(nn.Module):
    """Single-layer unidirectional LSTM tagger.

    Pipeline: embedding -> LSTM -> dropout -> linear layer producing one logit
    per tag for every time step.

    Args:
        vocab_size: number of rows in the embedding table.
        tag_size: number of output tags.
        emb_dim: embedding dimensionality.
        hidden_dim: LSTM hidden-state size.
        pad_idx: index of the padding token (its embedding stays at zero).
        dropout: dropout probability applied to the LSTM outputs.
    """

    def __init__(self, vocab_size, tag_size, emb_dim=100, hidden_dim=128, pad_idx=0, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            emb_dim, hidden_dim,
            num_layers=1,
            bidirectional=False,
            batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, tag_size)

    def forward(self, x, lengths=None):
        """Compute tag logits for a right-padded batch.

        Args:
            x: (B, T) tensor of word indices.
            lengths: true length of every row (each must be >= 1). When given,
                the batch is packed so the LSTM never steps over padding and
                padded positions do not depend on the filler content. When
                None, the LSTM runs over the full padded width.

        Returns:
            (B, T, C) tensor of logits, where C is the number of tags.
        """
        emb = self.emb(x)  # (B, T, E)
        if lengths is None:
            out, _ = self.lstm(emb)
        else:
            # pack_padded_sequence wants the lengths as a CPU int64 tensor.
            cpu_lengths = torch.as_tensor(lengths, dtype=torch.long).cpu()
            packed = pack_padded_sequence(
                emb, cpu_lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed)
            # total_length keeps the output as wide as the input even if the
            # longest row in this batch is shorter than T.
            out, _ = pad_packed_sequence(
                packed_out, batch_first=True, total_length=x.size(1)
            )
        out = self.dropout(out)
        return self.fc(out)



In [7]:
# Training & evaluation helpers

def ids_to_tags(id_seqs, tag_itos, mask):
    """Convert tag-id sequences to tag strings, skipping masked-out positions.

    Args:
        id_seqs: list of tag-id lists.
        tag_itos: index -> tag string lookup.
        mask: list of 0/1 lists aligned with `id_seqs`; only positions whose
            mask value equals 1 are kept.

    Returns:
        List of tag-string lists, one per input sequence.
    """
    return [
        [tag_itos[tag_id] for tag_id, flag in zip(ids, flags) if flag == 1]
        for ids, flags in zip(id_seqs, mask)
    ]

@torch.no_grad()
def evaluate(model, loader, tag_itos, device):
    """Score `model` on every batch of `loader`.

    Args:
        model: tagger called as model(x, lens), returning per-token logits.
        loader: iterable yielding (x, y, lens, mask) batches.
        tag_itos: index -> tag string lookup.
        device: device the batch tensors are moved to.

    Returns:
        (f1, (gold_sequences, predicted_sequences)) where the sequences are
        lists of tag strings restricted to unmasked positions.
    """
    model.eval()
    gold_all, pred_all = [], []
    for batch in loader:
        x, y, lens, mask = (tensor.to(device) for tensor in batch)
        pred = model(x, lens).argmax(-1)
        keep = mask.tolist()
        # Padded gold positions hold -100; clamp them to a valid index, the
        # mask removes them right afterwards.
        gold_ids = [row.clamp_min(0).tolist() for row in y]
        pred_ids = [row.tolist() for row in pred]
        gold_all += ids_to_tags(gold_ids, tag_itos, keep)
        pred_all += ids_to_tags(pred_ids, tag_itos, keep)
    score = f1_score(gold_all, pred_all)
    return score, (gold_all, pred_all)

def train_loop(model, train_dl, dev_dl, tag_itos, cfg, device):
    """Train `model`, evaluate on dev after every epoch and keep the best weights.

    Args:
        model: tagger called as model(x, lens), returning (B, T, C) logits.
        train_dl: loader yielding (x, y, lens, mask) training batches.
        dev_dl: loader handed to `evaluate` after each epoch.
        tag_itos: index -> tag string lookup used for scoring.
        cfg: object providing `lr`, `epochs` and `max_grad_norm`.
        device: device the batch tensors are moved to.

    Returns:
        dict with per-epoch lists under "train_loss" (mean batch loss) and
        "dev_f1". On return the model holds the weights of the epoch with the
        highest dev F1.
    """
    # -100 marks padded tag positions, so they contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
    optim = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    best_f1, best_state = -1, None
    history = {"train_loss": [], "dev_f1": []}

    for ep in range(1, cfg.epochs+1):
        model.train()
        running = 0.0
        pbar = tqdm(train_dl, desc=f"Epoch {ep}")
        # Count steps ourselves: pbar.n is only refreshed when tqdm redraws and
        # does not yet include the current batch, so dividing by it skews the
        # displayed running average.
        for step, (x, y, lens, _mask) in enumerate(pbar, start=1):
            x, y, lens = x.to(device), y.to(device), lens.to(device)
            optim.zero_grad()
            logits = model(x, lens)  # (B, T, C)
            loss = criterion(logits.reshape(-1, logits.size(-1)),
                            y.reshape(-1))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg.max_grad_norm)
            optim.step()
            running += loss.item()
            pbar.set_postfix(loss=f"{running / step:.4f}")
        history["train_loss"].append(running / max(1, len(train_dl)))

        dev_f1, _ = evaluate(model, dev_dl, tag_itos, device)
        history["dev_f1"].append(dev_f1)
        print(f"[Dev] F1 = {dev_f1:.4f}")
        if dev_f1 > best_f1:
            best_f1 = dev_f1
            # clone() is required: on a CPU model .cpu() returns the very same
            # storage, so without a copy the snapshot would keep tracking the
            # live parameters and end up equal to the last epoch.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        torch.cuda.empty_cache()
    if best_state:
        model.load_state_dict(best_state)
    return history


In [None]:
from sklearn.model_selection import train_test_split

conll_path = "../data/mex6/conll_1000.conll"
all_sents, all_tags = read_conll(conll_path)

# Keep only non-empty sentences, with tokens and tags still paired up.
pairs = [(s, t) for s, t in zip(all_sents, all_tags) if len(s) > 0]
if pairs:
    all_sents, all_tags = (list(column) for column in zip(*pairs))
else:
    all_sents, all_tags = [], []

print("Total sentences:", len(all_sents))

# 70% of the sentences go to training; the held-out 30% is then halved into
# dev and test. Both splits reuse the configured seed.
train_s, temp_s, train_t, temp_t = train_test_split(
    all_sents,
    all_tags,
    test_size=0.3,
    random_state=cfg.seed,
)
dev_s, test_s, dev_t, test_t = train_test_split(
    temp_s,
    temp_t,
    test_size=0.5,
    random_state=cfg.seed,
)

# Truncate or split long sequences into chunks of max_len tokens
def truncate_dataset(sents, tags, max_len=100):
    """Split every sentence into consecutive chunks of at most `max_len` tokens.

    Nothing is discarded: a sentence longer than `max_len` becomes several
    entries, and its tags are sliced at the same offsets as its tokens. Empty
    sentences produce no entry.

    Args:
        sents: list of token lists.
        tags: list of tag lists aligned with `sents`.
        max_len: maximum chunk length; must be at least 1.

    Returns:
        (new_sents, new_tags) with the chunked sequences.

    Raises:
        ValueError: if `max_len` is smaller than 1 (chunking could never finish).
    """
    if max_len < 1:
        raise ValueError(f"max_len must be a positive integer, got {max_len!r}")

    new_sents, new_tags = [], []
    for s, t in zip(sents, tags):
        # Chunk boundaries are driven by the token sequence; tags follow them.
        for start in range(0, len(s), max_len):
            new_sents.append(s[start:start + max_len])
            new_tags.append(t[start:start + max_len])
    return new_sents, new_tags

# Apply to all splits
train_s, train_t = truncate_dataset(train_s, train_t, max_len=100)
dev_s, dev_t     = truncate_dataset(dev_s, dev_t, max_len=100)
test_s, test_t   = truncate_dataset(test_s, test_t, max_len=100)

# Vocabularies come from the training split only.
# NOTE(review): build_word_vocab is called with its default min_freq=1, so every
# training word is in the vocabulary and the UNK index never occurs in a
# training input; unseen dev/test words then hit an untrained UNK embedding.
w2i, i2w = build_word_vocab(train_s)
t2i, i2t = build_tag_vocab(train_t)

collate = lambda b: pad_batch(b, w2i[PAD])
tr_ds = ConllDataset(train_s, train_t, w2i, t2i)
dv_ds = ConllDataset(dev_s, dev_t, w2i, t2i)
te_ds = ConllDataset(test_s, test_t, w2i, t2i)

# Only the training loader is shuffled.
tr_dl = DataLoader(tr_ds, batch_size=cfg.batch_size, shuffle=True,  collate_fn=collate)
dv_dl = DataLoader(dv_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate)
te_dl = DataLoader(te_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate)

model = LSTMTagger(len(i2w), len(i2t), cfg.emb_dim, cfg.hidden_dim, pad_idx=w2i[PAD], dropout=cfg.dropout).to(device)
history = train_loop(model, tr_dl, dv_dl, i2t, cfg, device)

# Final test
test_f1, (y_true, y_pred) = evaluate(model, te_dl, i2t, device)

print(f"[Test] F1 = {test_f1:.4f}")
# y_true / y_pred are lists of tag sequences, which is the input format of
# seqeval's report. The bare name `classification_report` is sklearn's and
# only accepts flat label lists, so the seqeval version is called explicitly.
print(seqeval_classification_report(y_true, y_pred, digits=4))

Total sentences: 346535


Epoch 1: 100%|██████████| 31637/31637 [09:15<00:00, 56.97it/s, loss=0.0973] 


[Dev] F1 = 0.0341


Epoch 2: 100%|██████████| 31637/31637 [09:17<00:00, 56.79it/s, loss=0.0911]


[Dev] F1 = 0.0735


Epoch 3: 100%|██████████| 31637/31637 [09:20<00:00, 56.46it/s, loss=0.0878]


[Dev] F1 = 0.0595


Epoch 4: 100%|██████████| 31637/31637 [09:19<00:00, 56.53it/s, loss=0.0844]


[Dev] F1 = 0.0726


Epoch 5: 100%|██████████| 31637/31637 [09:18<00:00, 56.69it/s, loss=0.0802]


[Dev] F1 = 0.0662


Epoch 6: 100%|██████████| 31637/31637 [09:25<00:00, 55.93it/s, loss=0.0762]


[Dev] F1 = 0.0543


Epoch 7: 100%|██████████| 31637/31637 [09:03<00:00, 58.26it/s, loss=0.0725]


[Dev] F1 = 0.0520


Epoch 8: 100%|██████████| 31637/31637 [08:44<00:00, 60.32it/s, loss=0.0695]


[Dev] F1 = 0.0524


Epoch 9: 100%|██████████| 31637/31637 [08:45<00:00, 60.17it/s, loss=0.0667]


[Dev] F1 = 0.0439


Epoch 10: 100%|██████████| 31637/31637 [08:42<00:00, 60.56it/s, loss=0.0649] 


[Dev] F1 = 0.0478
[Test] F1 = 0.0630
              precision    recall  f1-score   support

           _     0.2643    0.0358    0.0630      3884

   micro avg     0.2643    0.0358    0.0630      3884
   macro avg     0.2643    0.0358    0.0630      3884
weighted avg     0.2643    0.0358    0.0630      3884



In [44]:
# Create the output directory first so that saving does not fail on a fresh
# checkout where ../outputs/mex6 does not exist yet.
weights_path = Path("../outputs/mex6/lstm_model_weights.pth")
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), weights_path)

In [43]:
# Token-level report: flatten the per-sentence tag sequences into one long
# list each, so every token counts as a single classification decision.
flat_true = []
for _sent in y_true:
    flat_true.extend(_sent)

flat_pred = []
for _sent in y_pred:
    flat_pred.extend(_sent)

labels = ["B", "I", "O"]
print(classification_report(flat_true, flat_pred, labels=labels, digits=4))

              precision    recall  f1-score   support

           B     0.5761    0.0360    0.0678      3884
           I     0.7172    0.0183    0.0358     30706
           O     0.9849    0.9999    0.9923   2209703

    accuracy                         0.9848   2244293
   macro avg     0.7594    0.3514    0.3653   2244293
weighted avg     0.9805    0.9848    0.9776   2244293

