In [None]:
# CELL 1: Intuition 

'''
graph TD
  A[Original Text] --> B(AMR Parsing)
  B --> C{AMR Graph}
  C --> D[Graph Linearization]
  D --> E[Graph Tokenization]
  E --> F[BERT Embedding]
  F --> G[Model Input]
  '''

In [34]:
# CELL 2: AS2SP Model 
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import amrlib

# SAMPLE DATA GENERATION 
data = {
    "text": [
        "Scientists discover a new species of dinosaur in Argentina.",
        "Global tech conference announces 2024 event location as Singapore.",
        "New study shows coffee consumption linked to improved heart health."
    ],
    "summary": [
        "New dinosaur species found in Argentina.",
        "Tech conference 2024 to be held in Singapore.",
        "Coffee may benefit heart health, study finds."
    ]
}
df = pd.DataFrame(data)
df.to_csv("mini_dataset.csv", index=False)

# AMR PARSING & PREPROCESSING 
stog = amrlib.load_stog_model(device='cpu')
df = pd.read_csv("mini_dataset.csv")
texts = df['text'].tolist()

print("Parsing AMR graphs...")
amr_graphs = stog.parse_sents(texts)
graph_strings = [g if g else "" for g in amr_graphs]

# TOKENIZATION & VOCAB 
class Vocabulary:
    def __init__(self):
        self.word2idx = {"<pad>": 0, "<unk>": 1}
        self.idx2word = {0: "<pad>", 1: "<unk>"}
        
    def build_vocab(self, texts, max_size=2000):
        words = [word for text in texts for word in text.split()]
        word_counts = Counter(words)
        common_words = word_counts.most_common(max_size)
        
        for idx, (word, _) in enumerate(common_words, start=2):
            self.word2idx[word] = idx
            self.idx2word[idx] = word
            
VOCAB_SIZE = 2000
vocab = Vocabulary()
vocab.build_vocab(texts + df['summary'].tolist(), max_size=VOCAB_SIZE)

# DATASET & DATALOADER 
class SummaryDataset(Dataset):
    def __init__(self, graph_strings, summaries, vocab):
        self.graphs = [self.text_to_ids(gs, vocab) for gs in graph_strings]
        self.summaries = [self.text_to_ids(s, vocab) for s in summaries]
        
    def text_to_ids(self, text, vocab):
        return [vocab.word2idx.get(word, 1) for word in text.split()]
    
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return (
            torch.tensor(self.graphs[idx]),
            torch.tensor(self.summaries[idx])
        )

def collate_fn(batch):
    srcs, trgs = zip(*batch)
    srcs = torch.nn.utils.rnn.pad_sequence(srcs, padding_value=0).transpose(0, 1)
    trgs = torch.nn.utils.rnn.pad_sequence(trgs, padding_value=0).transpose(0, 1)
    return srcs, trgs

dataset = SummaryDataset(graph_strings, df['summary'].tolist(), vocab)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

# MODEL ARCHITECTURE 
class AS2SP(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Encoder components
        self.enc_embed = nn.Embedding(vocab_size, 128)
        self.encoder = nn.LSTM(128, 64, 
                             num_layers=1,
                             bidirectional=True,
                             batch_first=True)
        
        # Bridge layers
        self.hidden_proj = nn.Linear(64 * 2, 256)
        self.cell_proj = nn.Linear(64 * 2, 256)
        
        # Decoder components
        self.dec_embed = nn.Embedding(vocab_size, 128)
        self.decoder = nn.LSTM(128, 256, num_layers=1, batch_first=True)
        
        # Attention mechanism
        self.W_h = nn.Linear(64 * 2, 256)
        self.W_s = nn.Linear(256, 256)
        self.v = nn.Linear(256, 1)
        
        # Pointer-Generator 
        self.p_gen = nn.Linear(128 + 256 + 128, 1)  
        self.fc = nn.Linear(256, vocab_size)
        self.dropout = nn.Dropout(0.2)


    def forward(self, src_graph, trg_text):
        # Encoder forward
        enc_embedded = self.dropout(self.enc_embed(src_graph))
        enc_out, (h_n, c_n) = self.encoder(enc_embedded)
        
        # Process bidirectional states
        h_n = torch.cat([h_n[0], h_n[1]], dim=-1)
        c_n = torch.cat([c_n[0], c_n[1]], dim=-1)
        
        # Project to decoder dimensions
        decoder_hidden = self.hidden_proj(h_n).unsqueeze(0)
        decoder_cell = self.cell_proj(c_n).unsqueeze(0)
        
        # Decoder forward
        dec_embedded = self.dropout(self.dec_embed(trg_text))
        dec_out, _ = self.decoder(dec_embedded, (decoder_hidden, decoder_cell))
        
        # Attention 
        enc_proj = self.W_h(enc_out).unsqueeze(2) 
        dec_proj = self.W_s(dec_out).unsqueeze(1)  
        
        # Compute attention scores
        attn_energy = torch.tanh(enc_proj + dec_proj)
        attn_scores = self.v(attn_energy).squeeze(-1)  
        attn_weights = F.softmax(attn_scores, dim=1)
        
        # Transpose attention weights
        attn_weights = attn_weights.permute(0, 2, 1) 
        
        # Context vector calculation
        context = torch.bmm(attn_weights, enc_out)  
        
        # Pointer-generator
        p_gen_input = torch.cat([
            context,
            dec_out,
            dec_embedded
        ], dim=-1)
        p_gen = torch.sigmoid(self.p_gen(p_gen_input))
        
        # Final output
        output = self.fc(dec_out)
        return output, attn_weights, p_gen
# TRAINING SETUP 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AS2SP(VOCAB_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0)

# TRAINING LOOP 
for epoch in range(3):
    total_loss = 0
    for batch_idx, (src, trg) in enumerate(dataloader):
        src, trg = src.to(device), trg.to(device)
        
        # Forward 
        outputs, _, _ = model(src[:, :-1], trg[:, :-1])
        loss = criterion(outputs.reshape(-1, VOCAB_SIZE), 
                        trg[:, 1:].reshape(-1))
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 1 == 0:
            print(f"Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item():.4f}")
    
    print(f"Epoch {epoch+1} Average Loss: {total_loss/len(dataloader):.4f}")

print("Training completed!")

Parsing AMR graphs...
Epoch: 1, Batch: 0, Loss: 7.6050
Epoch: 1, Batch: 1, Loss: 7.6442
Epoch 1 Average Loss: 7.6246
Epoch: 2, Batch: 0, Loss: 7.3074
Epoch: 2, Batch: 1, Loss: 7.4173
Epoch 2 Average Loss: 7.3623
Epoch: 3, Batch: 0, Loss: 7.0976
Epoch: 3, Batch: 1, Loss: 7.1774
Epoch 3 Average Loss: 7.1375
Training completed!


In [None]:
# CELL 3: Verification Checklist 

'''
# AMR Processing:

Text → AMR graphs using amrlib

Graph linearization to penman format

Graph tokenization as model input

# Model Architecture:

2-layer bidirectional LSTM encoder

1-layer LSTM decoder

Attention with coverage mechanism

Pointer-generator network

Dropout (0.3) and gradient clipping (2.0)

# Training Parameters:

Batch size 64/32

Learning rate 0.001

Adam optimizer

15 epochs
'''