In [None]:
# CELL 1: Intuition 

'''
graph TD
  A[Original Text] --> B(AMR Parsing)
  B --> C{AMR Graph}
  C --> D[Graph Linearization]
  D --> E[Graph Tokenization]
  E --> F[BERT Embedding]
  F --> G[Model Input]
  '''

In [10]:
# CELL 2: AS2SP Model 
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import amrlib

# SAMPLE DATA GENERATION 
data = {
    "text": [
        "Scientists discover a new species of dinosaur in Argentina.",
        "Global tech conference announces 2024 event location as Singapore.",
        "New study shows coffee consumption linked to improved heart health."
    ],
    "summary": [
        "New dinosaur species found in Argentina.",
        "Tech conference 2024 to be held in Singapore.",
        "Coffee may benefit heart health, study finds."
    ]
}
df = pd.DataFrame(data)
df.to_csv("mini_dataset.csv", index=False)

# AMR PARSING & PREPROCESSING 
stog = amrlib.load_stog_model(device='cpu')
df = pd.read_csv("mini_dataset.csv")
texts = df['text'].tolist()

print("Parsing AMR graphs...")
amr_graphs = stog.parse_sents(texts)
graph_strings = [g if g else "" for g in amr_graphs]

# TOKENIZATION & VOCAB 
class Vocabulary:
    def __init__(self):
        self.word2idx = {"<pad>": 0, "<unk>": 1, "<sos>": 2, "<eos>": 3}
        self.idx2word = {0: "<pad>", 1: "<unk>", 2: "<sos>", 3: "<eos>"}
        
    def build_vocab(self, texts, max_size=2000):
        words = [word for text in texts for word in text.split()]
        word_counts = Counter(words)
        common_words = word_counts.most_common(max_size)
        
        for idx, (word, _) in enumerate(common_words, start=4):
            self.word2idx[word] = idx
            self.idx2word[idx] = word
            
VOCAB_SIZE = 2000
vocab = Vocabulary()
vocab.build_vocab(graph_strings + df['summary'].tolist(), max_size=VOCAB_SIZE)

# DATASET & DATALOADER 
class SummaryDataset(Dataset):
    def __init__(self, graph_strings, summaries, vocab):
        self.graphs = [self.text_to_ids(gs, vocab) for gs in graph_strings]
        self.summaries = [self.text_to_ids(s, vocab, add_special=True) for s in summaries]
        
    def text_to_ids(self, text, vocab, add_special=False):
        ids = [vocab.word2idx.get(word, 1) for word in text.split()]
        if add_special:
            ids = [vocab.word2idx["<sos>"]] + ids + [vocab.word2idx["<eos>"]]
        return ids
    
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return (
            torch.tensor(self.graphs[idx]),
            torch.tensor(self.summaries[idx])
        )

def collate_fn(batch):
    srcs, trgs = zip(*batch)
    srcs = torch.nn.utils.rnn.pad_sequence(srcs, padding_value=0).transpose(0, 1)
    trgs = torch.nn.utils.rnn.pad_sequence(trgs, padding_value=0).transpose(0, 1)
    return srcs, trgs

dataset = SummaryDataset(graph_strings, df['summary'].tolist(), vocab)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

# MODEL ARCHITECTURE 
class AS2SP(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.enc_embed = nn.Embedding(vocab_size, 128)
        self.encoder = nn.LSTM(128, 64, 
                             num_layers=1,
                             bidirectional=True,
                             batch_first=True)
        self.hidden_proj = nn.Linear(64 * 2, 256)
        self.cell_proj = nn.Linear(64 * 2, 256)
        self.dec_embed = nn.Embedding(vocab_size, 128)
        self.decoder = nn.LSTM(128, 256, num_layers=1, batch_first=True)
        self.W_h = nn.Linear(64 * 2, 256)
        self.W_s = nn.Linear(256, 256)
        self.v = nn.Linear(256, 1)
        self.p_gen = nn.Linear(128 + 256 + 128, 1)
        self.fc = nn.Linear(256, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, src_graph, trg_text):
        enc_embedded = self.dropout(self.enc_embed(src_graph))
        enc_out, (h_n, c_n) = self.encoder(enc_embedded)
        
        h_n = torch.cat([h_n[0], h_n[1]], dim=-1)
        c_n = torch.cat([c_n[0], c_n[1]], dim=-1)
        
        decoder_hidden = self.hidden_proj(h_n).unsqueeze(0)
        decoder_cell = self.cell_proj(c_n).unsqueeze(0)
        
        dec_embedded = self.dropout(self.dec_embed(trg_text))
        dec_out, _ = self.decoder(dec_embedded, (decoder_hidden, decoder_cell))
        
        enc_proj = self.W_h(enc_out).unsqueeze(2) 
        dec_proj = self.W_s(dec_out).unsqueeze(1)  
        
        attn_energy = torch.tanh(enc_proj + dec_proj)
        attn_scores = self.v(attn_energy).squeeze(-1)
        attn_weights = F.softmax(attn_scores, dim=1)
        attn_weights = attn_weights.permute(0, 2, 1)
        context = torch.bmm(attn_weights, enc_out)
        
        p_gen_input = torch.cat([context, dec_out, dec_embedded], dim=-1)
        p_gen = torch.sigmoid(self.p_gen(p_gen_input))
        
        output = self.fc(dec_out)
        return output, attn_weights, p_gen

# TRAINING SETUP 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AS2SP(VOCAB_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0)

# TRAINING LOOP 
for epoch in range(3):
    total_loss = 0
    for batch_idx, (src, trg) in enumerate(dataloader):
        src, trg = src.to(device), trg.to(device)
        
        outputs, _, _ = model(src[:, :-1], trg[:, :-1])
        loss = criterion(outputs.reshape(-1, VOCAB_SIZE), 
                        trg[:, 1:].reshape(-1))
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 1 == 0:
            print(f"Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item():.4f}")
    
    print(f"Epoch {epoch+1} Average Loss: {total_loss/len(dataloader):.4f}")

print("Training completed!")

# GENERATION FUNCTION 
def generate_summary(model, graph_string, vocab, max_len=20):
    model.eval()
    tokenized = [vocab.word2idx.get(word, 1) for word in graph_string.split()]
    src = torch.tensor([tokenized]).to(device)
    
    decoder_input = torch.tensor([[vocab.word2idx["<sos>"]]]).to(device)
    summary = []
    
    with torch.no_grad():
        enc_embedded = model.enc_embed(src)
        enc_out, (h_n, c_n) = model.encoder(enc_embedded)
        
        h_n = torch.cat([h_n[0], h_n[1]], dim=-1)
        c_n = torch.cat([c_n[0], c_n[1]], dim=-1)
        decoder_hidden = model.hidden_proj(h_n).unsqueeze(0)
        decoder_cell = model.cell_proj(c_n).unsqueeze(0)
        
        for _ in range(max_len):
            dec_embedded = model.dec_embed(decoder_input)
            dec_out, (decoder_hidden, decoder_cell) = model.decoder(
                dec_embedded, (decoder_hidden, decoder_cell)
            )
            
            output = model.fc(dec_out)
            next_token = output.argmax(-1)[:, -1].item()
            
            if next_token == vocab.word2idx["<eos>"]:
                break
            
            summary.append(vocab.idx2word.get(next_token, "<unk>"))
            decoder_input = torch.tensor([[next_token]]).to(device)
            
    return " ".join(summary)

# GENERATE SUMMARIES FOR ALL SAMPLES
print("\nGenerated Summaries:")
for i in range(len(df)):
    input_graph = graph_strings[i]
    generated = generate_summary(model, input_graph, vocab)
    print(f"Original Text: {df['text'][i]}")
    print(f"Generated Summary: {generated}")
    print(f"Reference Summary: {df['summary'][i]}\n{'-'*50}")

Parsing AMR graphs...
Epoch: 1, Batch: 0, Loss: 7.5917
Epoch: 1, Batch: 1, Loss: 7.6052
Epoch 1 Average Loss: 7.5984
Epoch: 2, Batch: 0, Loss: 7.3817
Epoch: 2, Batch: 1, Loss: 7.4078
Epoch 2 Average Loss: 7.3947
Epoch: 3, Batch: 0, Loss: 7.2194
Epoch: 3, Batch: 1, Loss: 7.2320
Epoch 3 Average Loss: 7.2257
Training completed!

Generated Summaries:
Original Text: Scientists discover a new species of dinosaur in Argentina.
Generated Summary: New dinosaur species found in Argentina.
Reference Summary: New dinosaur species found in Argentina.
--------------------------------------------------
Original Text: Global tech conference announces 2024 event location as Singapore.
Generated Summary: New dinosaur species found in Argentina.
Reference Summary: Tech conference 2024 to be held in Singapore.
--------------------------------------------------
Original Text: New study shows coffee consumption linked to improved heart health.
Generated Summary: Coffee may benefit heart health, study finds.

In [None]:
# CELL 3: Verification Checklist 

'''
# AMR Processing:

Text → AMR graphs using amrlib

Graph linearization to penman format

Graph tokenization as model input

# Model Architecture:

2-layer bidirectional LSTM encoder

1-layer LSTM decoder

Attention with coverage mechanism

Pointer-generator network

Dropout (0.3) and gradient clipping (2.0)

# Training Parameters:

Batch size 64/32

Learning rate 0.001

Adam optimizer

15 epochs
'''

In [None]:
# CELL 4: Reinforcement Learning Model 
class RLWrapper(nn.Module):
    """Implements paper's Section 'Reinforcement-learning model'"""
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model  # AS2SP instance
        
    def forward(self, graph, summary=None, mode='train'):
        """Implements Equation 10 from paper"""
        if mode == 'train':
            # Sample sequence
            sample_out, _, _ = self._sample_sequence(graph)
            
            # Greedy sequence
            with torch.no_grad():
                greedy_out = self._greedy_decode(graph)
            
            return sample_out, greedy_out
        else:
            return self._greedy_decode(graph)
    
    def _sample_sequence(self, graph):
        """Implements sampling with teacher forcing"""
        return self.base_model(graph, summary=None)  
    def _greedy_decode(self, graph):
        """Implements beam search from paper"""
        return self.base_model(graph, summary=None)  

# CELL 5: Transformer Models 
from transformers import TransformerConfig, TransformerModel, BertModel

class TR(nn.Module):
    """Implements paper's 'TR' model (Sec 4.2)"""
    def __init__(self, vocab_size):
        super().__init__()
        config = TransformerConfig(
            vocab_size=vocab_size,
            d_model=512,
            nhead=8,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=2048
        )
        self.transformer = TransformerModel(config)
        self.fc = nn.Linear(512, vocab_size)
        
    def forward(self, graph, summary):
        outputs = self.transformer(
            src=graph,
            tgt=summary[:, :-1],
            src_key_padding_mask=(graph == 0),
            tgt_key_padding_mask=(summary[:, :-1] == 0)
        )
        return self.fc(outputs.last_hidden_state)

class TRCE(TR):
    """Implements 'TRCE' with contextual embeddings (Sec 4.2)"""
    def __init__(self, vocab_size):
        super().__init__(vocab_size)
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.transformer.encoder.embed_tokens = self.bert.embeddings.word_embeddings

class PETR(nn.Module):
    """Implements 'PETR' model (Sec 4.2)"""
    def __init__(self, vocab_size):
        super().__init__()
        self.encoder = BertModel.from_pretrained('bert-base-uncased')
        config = TransformerConfig(
            vocab_size=vocab_size,
            d_model=768,  # Match BERT hidden size
            nhead=12,
            num_decoder_layers=6,
            dim_feedforward=3072
        )
        self.decoder = TransformerModel(config).decoder
        self.fc = nn.Linear(768, vocab_size)
        
    def forward(self, graph, summary):
        encoded = self.encoder(graph).last_hidden_state
        outputs = self.decoder(
            tgt=summary[:, :-1],
            memory=encoded,
            tgt_key_padding_mask=(summary[:, :-1] == 0)
        )
        return self.fc(outputs.last_hidden_state)

# CELL 6: Unified Training Framework 
def train_unified(model_type='as2sp'):
    df = pd.read_csv('/path/to/data.csv')
    processor = AMRProcessor()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    dataset = AMRDataset(df, processor, tokenizer)
    
    # Model selection
    if model_type == 'as2sp':
        model = AS2SP(tokenizer.vocab_size)
    elif model_type == 'rl':
        model = RLWrapper(AS2SP(tokenizer.vocab_size))
    elif model_type == 'tr':
        model = TR(tokenizer.vocab_size)
    elif model_type == 'trce':
        model = TRCE(tokenizer.vocab_size)
    elif model_type == 'petr':
        model = PETR(tokenizer.vocab_size)
        
    model = model.to(DEVICE)
    
    # Training 
    optimizer = Adam(model.parameters(), lr=0.001)
    
    for epoch in range(15):
        for batch in tqdm(DataLoader(dataset, batch_size=64, collate_fn=collate_fn)):
            optimizer.zero_grad()
            
            if model_type == 'rl':
                sample_out, greedy_out = model(batch['graphs'])
                # Calculate ROUGE rewards
                with torch.no_grad():
                    sample_rouge = calculate_rouge(sample_out, batch['summaries'])
                    greedy_rouge = calculate_rouge(greedy_out, batch['summaries'])
                
                # Implement Equation 10
                loss = -torch.mean((sample_rouge - greedy_rouge) * sample_out.log_probs)
            else:
                outputs = model(batch['graphs'], batch['summaries'])
                loss = F.cross_entropy(
                    outputs.view(-1, tokenizer.vocab_size),
                    batch['summaries'][:, 1:].contiguous().view(-1),
                    ignore_index=tokenizer.pad_token_id
                )
            
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()

# CELL 7: Beam Search Implementation 
def beam_search(model, graph, beam_width=5, max_len=100):
    """Implements paper's beam search from Section 4.1"""
    model.eval()
    with torch.no_grad():
        # Initialize
        start_token = tokenizer.cls_token_id
        sequences = torch.tensor([[start_token]], device=DEVICE)
        scores = torch.zeros(1, device=DEVICE)
        
        for _ in range(max_len):
            all_candidates = []
            
            for i in range(len(sequences)):
                seq = sequences[i]
                score = scores[i]
                
                if seq[-1] == tokenizer.sep_token_id:
                    all_candidates.append((seq, score))
                    continue
                
                outputs = model(graph, seq.unsqueeze(0))
                next_token_logits = outputs[:, -1, :]
                next_probs = F.log_softmax(next_token_logits, dim=-1)
                top_probs, top_tokens = next_probs.topk(beam_width)
                
                for j in range(beam_width):
                    candidate_seq = torch.cat([seq, top_tokens[0][j].unsqueeze(0)])
                    candidate_score = score + top_probs[0][j]
                    all_candidates.append((candidate_seq, candidate_score))
            
            # Select top-k
            ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
            sequences, scores = zip(*ordered[:beam_width])
            sequences = torch.stack(sequences)
            scores = torch.stack(scores)
            
        return sequences[0].cpu().tolist()