In [2]:
import argparse
import re
from collections import Counter
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import pandas as pd
import random
from nltk.translate.bleu_score import corpus_bleu
from rouge_score import rouge_scorer

In [None]:
INPUT_CSV_PATH = "../3A2M_EXTENDED.csv"
OUTPUT_DIR = "./gpt2-ner2directions-optimized"
TOKENIZED_CACHE = "./tokenized_data"
TRAIN_FILE = "train_ner2dir.txt"
VAL_FILE = "val_ner2dir.txt"
BLOCK_SIZE = 512  # smaller blocks to reduce padding waste
NUM_PROC = 4      # CPU cores for preprocessing

In [4]:
df = pd.read_csv(INPUT_CSV_PATH)[:int(1e5)]
if "NER" not in df.columns or "directions" not in df.columns:
    raise ValueError("CSV must contain 'NER' and 'directions' columns.")
DF = df.dropna(subset=["NER", "directions"]).reset_index(drop=True)

def format_example_from_ner(ner_text: str, directions_text: str) -> str:
    parts = ["NER:"]
    for ent in ner_text.split(","):
        ent = ent.strip()
        if ent:
            parts.append(f"- {ent}")
    parts.append("")
    parts.append("Directions:")
    for idx, step in enumerate(directions_text.split("\n"), start=1):
        step = step.strip()
        if step:
            parts.append(f"{idx}. {step}")
    return "\n".join(parts) + "\n\n"

examples = [format_example_from_ner(row["NER"], row["directions"]) for _, row in DF.iterrows()]
random.seed(42)
random.shuffle(examples)
split_idx = int(0.9 * len(examples))
with open(TRAIN_FILE, "w", encoding="utf-8") as f:
    f.writelines(examples[:split_idx])
with open(VAL_FILE, "w", encoding="utf-8") as f:
    f.writelines(examples[split_idx:])

In [5]:
# -----------------------------
# Data utilities
# -----------------------------

def parse_examples(path):
    """
    Read examples from a file where each example is two paragraphs:
      NER:
      - entity1
      - entity2
      ...

      Directions:
      1. step one
      2. step two
      ...
    Returns list of (input_str, target_str)
    """
    examples = []
    with open(path, encoding='utf-8') as f:
        text = f.read().strip()
    for block in text.split("\n\nNER:"):
        block = block.strip()
        if not block:
            continue
        if not block.startswith("NER:"):
            block = "NER:" + "\n" + block
        parts = block.split("\n\nDirections:")
        if len(parts) != 2:
            continue
        ner_block, dir_block = parts
        # extract entities
        ents = re.findall(r"^-\s*(.+)$", ner_block, flags=re.MULTILINE)
        # extract steps
        steps = re.findall(r"^\d+\.\s*(.+)$", dir_block, flags=re.MULTILINE)
        input_str = " ".join(ents)
        target_str = " <sep> ".join(steps)
        examples.append((input_str, target_str))
    return examples

class Vocab:
    def __init__(self, tokens, min_freq=2, reserved_tokens=None):
        counter = Counter(tokens)
        self.reserved = ["<pad>", "<bos>", "<eos>", "<unk>"] + (reserved_tokens or [])
        # only keep tokens with freq >= min_freq
        self.tokens = [t for t, c in counter.items() if c >= min_freq]
        self.idx_to_token = self.reserved + self.tokens
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

    def __len__(self): return len(self.idx_to_token)
    def __getitem__(self, token): return self.token_to_idx.get(token, self.token_to_idx['<unk>'])
    def to_tokens(self, indices): return [self.idx_to_token[i] for i in indices]

class Seq2SeqDataset(Dataset):
    def __init__(self, examples, src_vocab, tgt_vocab, max_src_len=50, max_tgt_len=100):
        self.examples = examples
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len

    def __len__(self): return len(self.examples)

    def __getitem__(self, idx):
        src, tgt = self.examples[idx]
        # tokenize
        src_tokens = src.split()[:self.max_src_len]
        tgt_tokens = tgt.split()[:self.max_tgt_len]
        # indices
        src_ids = [self.src_vocab[t] for t in src_tokens]
        # add BOS/EOS to target
        tgt_ids = [self.tgt_vocab['<bos>']] + [self.tgt_vocab[t] for t in tgt_tokens] + [self.tgt_vocab['<eos>']]
        return torch.tensor(src_ids, dtype=torch.long), torch.tensor(tgt_ids, dtype=torch.long)

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_lens = [len(x) for x in src_batch]
    tgt_lens = [len(x) for x in tgt_batch]
    src_padded = pad_sequence(src_batch, batch_first=True, padding_value=0)
    tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=0)
    return src_padded, torch.tensor(src_lens), tgt_padded, torch.tensor(tgt_lens)

In [6]:
# -----------------------------
# Model
# -----------------------------
class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hid_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0)
        self.lstm = nn.LSTM(emb_size, hid_size, num_layers, batch_first=True)

    def forward(self, x, lengths):
        embedded = self.embedding(x)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        outputs, (h, c) = self.lstm(packed)
        return h, c

class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hid_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0)
        self.lstm = nn.LSTM(emb_size, hid_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hid_size, vocab_size)

    def forward(self, y, h, c):
        emb = self.embedding(y)
        outputs, (h, c) = self.lstm(emb, (h, c))
        logits = self.fc(outputs)
        return logits, (h, c)

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_len, tgt, tgt_len=None):
        h, c = self.encoder(src, src_len)
        logits, _ = self.decoder(tgt[:, :-1], h, c)
        return logits

    def generate(self, src, src_len, max_len=100, bos_idx=1, eos_idx=2):
            """
            Greedy decoding: repeatedly feed back the last token.
            """
            self.eval()
            h, c = self.encoder(src, src_len)
            batch_size = src.size(0)
            # start with BOS
            generated = torch.full((batch_size, 1), bos_idx, dtype=torch.long, device=src.device)
            for _ in range(max_len):
                logits, (h, c) = self.decoder(generated[:, -1:], h, c)
                next_token = logits.argmax(-1)  # (batch, 1)
                generated = torch.cat([generated, next_token], dim=1)
                # if all sequences have emitted EOS, stop early
                if (next_token == eos_idx).all():
                    break
            return generated

In [7]:
# -----------------------------
# Training
# -----------------------------
def train(args):
    # parse data
    train_ex = parse_examples(args.train_file)
    val_ex = parse_examples(args.val_file)
    scorer = rouge_scorer.RougeScorer(['rouge1','rougeL'], use_stemmer=True)
    # build vocabs
    all_src_tok = []
    all_tgt_tok = []
    for s, t in train_ex:
        all_src_tok += s.split()
        all_tgt_tok += t.split()
    src_vocab = Vocab(all_src_tok, min_freq=2)
    tgt_vocab = Vocab(all_tgt_tok, min_freq=2)

    train_ds = Seq2SeqDataset(train_ex, src_vocab, tgt_vocab)
    val_ds   = Seq2SeqDataset(val_ex, src_vocab, tgt_vocab)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds, batch_size=args.batch_size, collate_fn=collate_fn)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = Encoder(len(src_vocab), args.emb_size, args.hid_size).to(device)
    decoder = Decoder(len(tgt_vocab), args.emb_size, args.hid_size).to(device)
    model = Seq2Seq(encoder, decoder).to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    scorer = rouge_scorer.RougeScorer(['rouge1','rouge2','rougeL'], use_stemmer=True)

    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss = 0
        for src, src_len, tgt, _ in tqdm(train_loader, desc=f"Train Epoch {epoch}"):
            src, src_len, tgt = src.to(device), src_len.to(device), tgt.to(device)
            optimizer.zero_grad()
            logits = model(src, src_len, tgt)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt[:,1:].reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch} Train Loss: {total_loss/len(train_loader):.4f}")

        model.eval()
        val_loss = 0
        refs, hyps = [], []
        with torch.no_grad():
            for src, src_len, tgt, _ in tqdm(val_loader, desc="Validate"):
                src, src_len, tgt = src.to(device), src_len.to(device), tgt.to(device)
                logits = model(src, src_len, tgt)
                val_loss += criterion(logits.reshape(-1, logits.size(-1)), tgt[:,1:].reshape(-1)).item()
                gen_ids = model.generate(src, src_len, max_len=args.max_tgt_len,
                                         bos_idx=tgt_vocab['<bos>'], eos_idx=tgt_vocab['<eos>'])
                for i in range(src.size(0)):
                    ref = [tgt_vocab.idx_to_token[id] for id in tgt[i,1:].tolist() if id not in (0, tgt_vocab['<eos>'])]
                    hyp = [tgt_vocab.idx_to_token[id] for id in gen_ids[i,1:].tolist() if id not in (0, tgt_vocab['<eos>'])]
                    refs.append([ref])
                    hyps.append(hyp)
        avg_val_loss = val_loss / len(val_loader)
        bleu = corpus_bleu(refs, hyps)
        rouge_scores = [scorer.score(' '.join(r[0]), ' '.join(h)) for r,h in zip(refs, hyps)]
        avg_rouge1 = sum(s['rouge1'].fmeasure for s in rouge_scores) / len(rouge_scores)
        avg_rouge2 = sum(s['rouge2'].fmeasure for s in rouge_scores) / len(rouge_scores)
        avg_rougeL = sum(s['rougeL'].fmeasure for s in rouge_scores) / len(rouge_scores)
        print(
            f"Epoch {epoch} Val Loss: {avg_val_loss:.4f} "
            f"| BLEU: {bleu:.4f} "
            f"| ROUGE-1: {avg_rouge1:.4f} "
            f"| ROUGE-2: {avg_rouge2:.4f} "
            f"| ROUGE-L: {avg_rougeL:.4f}"
        )

In [None]:
# --- Notebook usage cell ---

from types import SimpleNamespace

# 1) define your paths & hyper-params
args = SimpleNamespace(
    train_file="train_ner2dir.txt",
    val_file="val_ner2dir.txt",
    epochs=10,
    batch_size=32,
    emb_size=128,
    hid_size=256,
    lr=1e-3,
    max_tgt_len=100
)

# 2) kick off training
train(args)

Train Epoch 1: 100%|██████████| 2813/2813 [03:11<00:00, 14.67it/s]


Epoch 1 Train Loss: 4.8347


Validate: 100%|██████████| 313/313 [00:32<00:00,  9.59it/s]


Epoch 1 Val Loss: 3.9977 | BLEU: 0.0114 | ROUGE-1: 0.1783 | ROUGE-2: 0.0323 | ROUGE-L: 0.1346


Train Epoch 2: 100%|██████████| 2813/2813 [03:11<00:00, 14.71it/s]


Epoch 2 Train Loss: 3.7613


Validate: 100%|██████████| 313/313 [00:32<00:00,  9.70it/s]


Epoch 2 Val Loss: 3.6215 | BLEU: 0.0223 | ROUGE-1: 0.2589 | ROUGE-2: 0.0601 | ROUGE-L: 0.1851


Train Epoch 3: 100%|██████████| 2813/2813 [03:11<00:00, 14.70it/s]


Epoch 3 Train Loss: 3.4406


Validate: 100%|██████████| 313/313 [00:32<00:00,  9.67it/s]


Epoch 3 Val Loss: 3.4709 | BLEU: 0.0252 | ROUGE-1: 0.2678 | ROUGE-2: 0.0612 | ROUGE-L: 0.1863


Train Epoch 4: 100%|██████████| 2813/2813 [03:11<00:00, 14.71it/s]


Epoch 4 Train Loss: 3.2563


Validate: 100%|██████████| 313/313 [00:32<00:00,  9.67it/s]


Epoch 4 Val Loss: 3.3972 | BLEU: 0.0281 | ROUGE-1: 0.2871 | ROUGE-2: 0.0709 | ROUGE-L: 0.1995


Train Epoch 5: 100%|██████████| 2813/2813 [03:11<00:00, 14.67it/s]


Epoch 5 Train Loss: 3.1288


Validate: 100%|██████████| 313/313 [00:32<00:00,  9.62it/s]


Epoch 5 Val Loss: 3.3610 | BLEU: 0.0291 | ROUGE-1: 0.2968 | ROUGE-2: 0.0738 | ROUGE-L: 0.2062
