In [None]:
# CELL 1: Intuition 
'''
graph TD
  A[Original Text] --> B(AMR Parsing)
  B --> C{AMR Graph}
  C --> D[Graph Linearization]
  D --> E[Graph Tokenization]
  E --> F[BERT Embedding]
  F --> G[Model Input]
  '''

In [None]:
# CELL 2: AS2SP Model 
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import amrlib

# SELECT COMPUTE DEVICE
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# DATA LOADING
def load_data(path):
    """Read a summarization CSV and return its two text columns as lists.

    Args:
        path: Location of a CSV file that has ``article`` and ``highlights``
            columns.

    Returns:
        A ``(articles, highlights)`` tuple of Python lists in row order.
    """
    frame = pd.read_csv(path)
    # Pull both columns the same way; order matters for the returned tuple.
    articles, highlights = (
        frame[column].tolist() for column in ("article", "highlights")
    )
    return articles, highlights

# Load the train / validation / test splits, in that order.
# REPLACE THESE PATHS WITH YOUR OWN.
(
    (train_articles, train_highlights),
    (val_articles, val_highlights),
    (test_articles, test_highlights),
) = (
    load_data(csv_path)
    for csv_path in (
        "/home/masih/Downloads/Telegram Desktop/data/train.csv",
        "/home/masih/Downloads/Telegram Desktop/data/validation.csv",
        "/home/masih/Downloads/Telegram Desktop/data/test.csv",
    )
)

# AMR PARSING & PREPROCESSING
# The sentence-to-graph parser is loaded on the CPU regardless of `device`.
stog = amrlib.load_stog_model(device='cpu')

def parse_amr(articles):
    """Parse texts into AMR graph strings with the module-level ``stog`` model.

    Args:
        articles: Texts handed to ``stog.parse_sents`` in a single call.

    Returns:
        A list with one entry per parser result, where any falsy result
        (e.g. ``None``) is replaced by the empty string.
    """
    # NOTE(review): whole articles are passed to `parse_sents`; presumably the
    # parser expects single sentences - confirm against the amrlib docs.
    print("Parsing AMR graphs...")
    parsed = stog.parse_sents(articles)
    cleaned = []
    for graph in parsed:
        cleaned.append(graph or "")
    return cleaned

# Process all splits (train, validation, test - parsed in that order)
train_graphs, val_graphs, test_graphs = (
    parse_amr(split_articles)
    for split_articles in (train_articles, val_articles, test_articles)
)

# TOKENIZATION & VOCAB 
class Vocabulary:
    """Bidirectional token <-> index mapping with four reserved special tokens.

    Indices 0-3 are always ``<pad>``, ``<unk>``, ``<sos>`` and ``<eos>``;
    regular words are appended after them with contiguous indices.
    """

    def __init__(self):
        self.word2idx = {"<pad>": 0, "<unk>": 1, "<sos>": 2, "<eos>": 3}
        self.idx2word = {0: "<pad>", 1: "<unk>", 2: "<sos>", 3: "<eos>"}

    def build_vocab(self, texts, max_size=2000):
        """Add the most frequent whitespace-separated tokens of ``texts``.

        Args:
            texts: Iterable of strings; each one is split on whitespace.
            max_size: Upper bound on the TOTAL number of entries, counting the
                reserved special tokens. Every assigned index is therefore
                strictly below ``max_size``, so an embedding table with
                ``max_size`` rows can index any token id produced here.
        """
        word_counts = Counter(word for text in texts for word in text.split())

        for word, _ in word_counts.most_common():
            if len(self.word2idx) >= max_size:
                break
            # Never re-map an existing entry: a corpus word that happens to be
            # spelled like a special token must keep the reserved index.
            if word in self.word2idx:
                continue
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word
            
# Size limit handed to `build_vocab` as `max_size`.
VOCAB_SIZE = 2000
vocab = Vocabulary()
# Graph strings and reference summaries share a single vocabulary.
vocab.build_vocab([*train_graphs, *train_highlights], max_size=VOCAB_SIZE)

# DATASET & DATALOADER 
class SummaryDataset(Dataset):
    """Pairs of (source graph token ids, target summary token ids).

    Both texts are tokenized by whitespace and mapped through ``vocab.word2idx``
    once, at construction time. Targets are wrapped in ``<sos>`` / ``<eos>``.
    """

    def __init__(self, graph_strings, highlights, vocab):
        """Pre-compute id sequences for every graph string and highlight.

        Args:
            graph_strings: Source strings (one per example).
            highlights: Target summary strings, aligned with ``graph_strings``.
            vocab: Object exposing a ``word2idx`` dict with the special tokens.
        """
        self.graphs = [self.text_to_ids(gs, vocab) for gs in graph_strings]
        self.highlights = [self.text_to_ids(s, vocab, add_special=True) for s in highlights]

    def text_to_ids(self, text, vocab, add_special=False):
        """Map whitespace tokens of ``text`` to ids; unknown words map to ``<unk>``.

        Args:
            text: String to tokenize.
            vocab: Object exposing a ``word2idx`` dict.
            add_special: When true, prepend ``<sos>`` and append ``<eos>``.

        Returns:
            A list of integer token ids (possibly empty).
        """
        # Look the <unk> id up instead of hard-coding it; 1 is the fallback.
        unk_idx = vocab.word2idx.get("<unk>", 1)
        ids = [vocab.word2idx.get(word, unk_idx) for word in text.split()]
        if add_special:
            ids = [vocab.word2idx["<sos>"]] + ids + [vocab.word2idx["<eos>"]]
        return ids

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(source_ids, target_ids)`` as 1-D ``torch.long`` tensors."""
        # An explicit dtype matters for empty sequences: `torch.tensor([])`
        # would otherwise default to float32, which embedding lookups reject.
        return (
            torch.tensor(self.graphs[idx], dtype=torch.long),
            torch.tensor(self.highlights[idx], dtype=torch.long),
        )

def collate_fn(batch):
    """Pad a list of ``(source, target)`` tensor pairs into two batch tensors.

    Args:
        batch: Sequence of ``(src_ids, trg_ids)`` tensor pairs.

    Returns:
        ``(srcs, trgs)`` where each is padded with 0 to the longest sequence in
        the batch and laid out batch-first.
    """
    sources, targets = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    # batch_first=True yields (batch, time) directly, no transpose needed.
    padded_sources = pad(sources, batch_first=True, padding_value=0)
    padded_targets = pad(targets, batch_first=True, padding_value=0)
    return padded_sources, padded_targets

# CREATE DATASETS AND DATALOADERS
train_dataset, val_dataset, test_dataset = (
    SummaryDataset(graphs, summaries, vocab)
    for graphs, summaries in (
        (train_graphs, train_highlights),
        (val_graphs, val_highlights),
        (test_graphs, test_highlights),
    )
)

# Only the training loader shuffles; the evaluation loaders keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
val_loader, test_loader = (
    DataLoader(split, batch_size=2, collate_fn=collate_fn)
    for split in (val_dataset, test_dataset)
)

# MODEL ARCHITECTURE 
class AS2SP(nn.Module):
    """LSTM encoder-decoder with additive attention over source tokens.

    A bidirectional LSTM encodes the source ids; its final hidden/cell states
    are projected to initialise a unidirectional LSTM decoder that is run over
    the provided target ids. Attention weights and a generation gate
    (``p_gen``) are computed and returned, but the vocabulary logits are
    produced from the decoder states alone.

    Args:
        vocab_size: Number of rows in both embedding tables and number of
            output logits.
        emb_dim: Embedding width for encoder and decoder inputs.
        enc_hidden: Hidden size of each encoder direction.
        dec_hidden: Hidden size of the decoder (also the attention width).
        dropout: Dropout probability applied to both embeddings.
        pad_idx: Source token id excluded from attention; ``None`` disables
            the masking.
    """

    def __init__(self, vocab_size, emb_dim=128, enc_hidden=64, dec_hidden=256,
                 dropout=0.2, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        enc_out_dim = enc_hidden * 2  # forward + backward directions

        self.enc_embed = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, enc_hidden,
                               num_layers=1,
                               bidirectional=True,
                               batch_first=True)
        self.hidden_proj = nn.Linear(enc_out_dim, dec_hidden)
        self.cell_proj = nn.Linear(enc_out_dim, dec_hidden)
        self.dec_embed = nn.Embedding(vocab_size, emb_dim)
        self.decoder = nn.LSTM(emb_dim, dec_hidden, num_layers=1, batch_first=True)
        self.W_h = nn.Linear(enc_out_dim, dec_hidden)
        self.W_s = nn.Linear(dec_hidden, dec_hidden)
        self.v = nn.Linear(dec_hidden, 1)
        # Gate input is the concatenation [context, decoder state, decoder embedding].
        self.p_gen = nn.Linear(enc_out_dim + dec_hidden + emb_dim, 1)
        self.fc = nn.Linear(dec_hidden, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src_graph, trg_text):
        """Run the encoder and a teacher-forced decoder pass.

        Args:
            src_graph: Long tensor of source ids, shape ``(batch, src_len)``.
            trg_text: Long tensor of decoder input ids, shape
                ``(batch, trg_len)``.

        Returns:
            ``(output, attn_weights, p_gen)`` with shapes
            ``(batch, trg_len, vocab_size)``, ``(batch, trg_len, src_len)`` and
            ``(batch, trg_len, 1)``.
        """
        enc_embedded = self.dropout(self.enc_embed(src_graph))
        enc_out, (h_n, c_n) = self.encoder(enc_embedded)

        # Join the final forward/backward states into one vector per example.
        h_n = torch.cat([h_n[0], h_n[1]], dim=-1)
        c_n = torch.cat([c_n[0], c_n[1]], dim=-1)

        decoder_hidden = self.hidden_proj(h_n).unsqueeze(0)
        decoder_cell = self.cell_proj(c_n).unsqueeze(0)

        dec_embedded = self.dropout(self.dec_embed(trg_text))
        dec_out, _ = self.decoder(dec_embedded, (decoder_hidden, decoder_cell))

        # Broadcast to (batch, src_len, trg_len, dec_hidden) for additive scoring.
        enc_proj = self.W_h(enc_out).unsqueeze(2)
        dec_proj = self.W_s(dec_out).unsqueeze(1)

        attn_energy = torch.tanh(enc_proj + dec_proj)
        attn_scores = self.v(attn_energy).squeeze(-1)  # (batch, src_len, trg_len)

        if self.pad_idx is not None:
            # Keep padded source positions out of the attention distribution.
            # A large finite value (not -inf) keeps an all-padding row from
            # producing NaNs in the softmax.
            pad_mask = (src_graph == self.pad_idx).unsqueeze(-1)
            attn_scores = attn_scores.masked_fill(
                pad_mask, torch.finfo(attn_scores.dtype).min
            )

        attn_weights = F.softmax(attn_scores, dim=1)  # normalise over src_len
        attn_weights = attn_weights.permute(0, 2, 1)
        context = torch.bmm(attn_weights, enc_out)

        p_gen_input = torch.cat([context, dec_out, dec_embedded], dim=-1)
        p_gen = torch.sigmoid(self.p_gen(p_gen_input))

        output = self.fc(dec_out)
        return output, attn_weights, p_gen

# TRAINING SETUP
# Size the model from the vocabulary that was actually built, so every token id
# the datasets can emit has an embedding row and an output logit.
model = AS2SP(len(vocab.idx2word)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # index 0 is <pad>

# TRAINING LOOP
for epoch in range(3):
    model.train()
    total_loss = 0
    for batch_idx, (src, trg) in enumerate(train_loader):
        src, trg = src.to(device), trg.to(device)

        # Teacher forcing: the decoder reads trg[:, :-1] and is scored against
        # trg[:, 1:]. The source is encoded in full - only the target shifts.
        outputs, _, _ = model(src, trg[:, :-1])
        loss = criterion(outputs.reshape(-1, outputs.size(-1)),
                         trg[:, 1:].reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {loss.item():.4f}")

    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    print(f"Epoch {epoch+1} Average Loss: {total_loss/max(len(train_loader), 1):.4f}")

print("Training completed!")

# GENERATION FUNCTION 
def generate_summary(model, graph_string, vocab, max_len=20):
    """Greedy-decode a summary for one linearized graph string.

    Args:
        model: Module exposing ``enc_embed``, ``encoder``, ``hidden_proj``,
            ``cell_proj``, ``dec_embed``, ``decoder`` and ``fc``.
        graph_string: Source text; tokenized by whitespace.
        vocab: Object exposing ``word2idx`` and ``idx2word`` dicts.
        max_len: Maximum number of tokens to emit.

    Returns:
        The generated tokens joined by single spaces. Decoding stops early at
        ``<eos>``. An input with no tokens yields ``""``.
    """
    model.eval()
    unk_idx = vocab.word2idx.get("<unk>", 1)
    sos_idx = vocab.word2idx["<sos>"]
    eos_idx = vocab.word2idx["<eos>"]

    tokenized = [vocab.word2idx.get(word, unk_idx) for word in graph_string.split()]
    if not tokenized:
        # Nothing to encode: an empty id list would otherwise become a float
        # tensor and fail inside the embedding lookup.
        return ""

    # Place inputs on the model's own device rather than relying on a global.
    model_device = next(model.parameters()).device
    src = torch.tensor([tokenized], dtype=torch.long, device=model_device)
    decoder_input = torch.tensor([[sos_idx]], dtype=torch.long, device=model_device)
    summary = []

    with torch.no_grad():
        enc_embedded = model.enc_embed(src)
        enc_out, (h_n, c_n) = model.encoder(enc_embedded)

        # Join the final forward/backward encoder states to seed the decoder.
        h_n = torch.cat([h_n[0], h_n[1]], dim=-1)
        c_n = torch.cat([c_n[0], c_n[1]], dim=-1)
        decoder_hidden = model.hidden_proj(h_n).unsqueeze(0)
        decoder_cell = model.cell_proj(c_n).unsqueeze(0)

        for _ in range(max_len):
            # One decoder step: feed the previous token, carry the LSTM state.
            dec_embedded = model.dec_embed(decoder_input)
            dec_out, (decoder_hidden, decoder_cell) = model.decoder(
                dec_embedded, (decoder_hidden, decoder_cell)
            )

            output = model.fc(dec_out)
            next_token = output.argmax(-1)[:, -1].item()

            if next_token == eos_idx:
                break

            summary.append(vocab.idx2word.get(next_token, "<unk>"))
            decoder_input = torch.tensor(
                [[next_token]], dtype=torch.long, device=model_device
            )

    return " ".join(summary)

# GENERATE SUMMARIES FOR TEST SET
# Walk the aligned test graphs, articles and reference summaries together.
print("\nGenerated Summaries for Test Set:")
for input_graph, article, reference in zip(test_graphs, test_articles, test_highlights):
    generated = generate_summary(model, input_graph, vocab)
    print(f"Original Article: {article}")
    print(f"Generated Summary: {generated}")
    print(f"Reference Summary: {reference}\n{'-'*50}")