# ===========================================
# 3_training_demo.ipynb
# Prototypisches Training & erste Inferenz-Demos für AutoSummary
# ===========================================

# Zelle 1: Bibliotheken und Setup
import ast
import json
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Verwende Gerät:", DEVICE)

# Directory and file paths of the preprocessed data
PROCESSED_DATA_DIR = "../data/processed"
TRAIN_CSV, VAL_CSV, VOCAB_JSON = (
    os.path.join(PROCESSED_DATA_DIR, file_name)
    for file_name in ("train.csv", "val.csv", "vocab.json")
)

# Cell 2: load the data
# Assumption: each CSV file contains two columns:
#   "article_ids" (list of token IDs of the original text)
#   "summary_ids" (list of token IDs of the summary)
df_train, df_val = (pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, VAL_CSV))

with open(VOCAB_JSON, "r", encoding="utf-8") as f:
    word2id = json.load(f)
# Inverse mapping (ID -> token) of the vocabulary.
id2word = {token_id: token for token, token_id in word2id.items()}

vocab_size = len(word2id)
# Falls back to ID 0 when the vocabulary has no explicit "<PAD>" entry.
pad_id = word2id.get("<PAD>", 0)

print(f"Train Samples: {len(df_train)}, Val Samples: {len(df_val)}")
print("Vokabulargröße:", vocab_size)
# NOTE: `display` is provided by IPython/Jupyter; it is not a Python builtin.
display(df_train.head(2))

# Zelle 3: Dataset und DataLoader
class SummarizationDataset(Dataset):
    """
    Dataset for AutoSummary.

    Wraps a DataFrame that has one column of article token IDs and one column
    of summary token IDs. Each cell is either a real list of ints or the
    string representation of such a list (e.g. "[2, 45, 78, 3]"), which is
    what a CSV round trip produces.

    Args:
        df: Source DataFrame; it is copied, the caller's frame is not modified.
        article_col: Name of the column holding the article token IDs.
        summary_col: Name of the column holding the summary token IDs.

    Raises:
        ValueError / SyntaxError: If a string cell is not a valid Python
            literal (raised by ``ast.literal_eval``).
    """
    def __init__(self, df, article_col="article_ids", summary_col="summary_ids"):
        self.df = df.copy()
        # Turn the string representation into real lists.
        for col in (article_col, summary_col):
            self.df[col] = self.df[col].apply(self._parse_ids)
        self.articles = self.df[article_col].tolist()
        self.summaries = self.df[summary_col].tolist()

    @staticmethod
    def _parse_ids(value):
        """Parse a stringified ID list; non-string values pass through unchanged."""
        if isinstance(value, str):
            # literal_eval only accepts Python literals, so content coming
            # from the CSV file can never execute code (unlike eval()).
            return ast.literal_eval(value)
        return value

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (article, summary) pair at ``idx`` as two long tensors."""
        article_tensor = torch.tensor(self.articles[idx], dtype=torch.long)
        summary_tensor = torch.tensor(self.summaries[idx], dtype=torch.long)
        return article_tensor, summary_tensor

def collate_fn(batch, pad_id=0):
    """
    Collate (article, summary) tensor pairs into two padded batch tensors.

    Args:
        batch: Sequence of (article_tensor, summary_tensor) pairs.
        pad_id: Value used to right-pad shorter sequences.

    Returns:
        Tuple (articles, summaries), each of shape [Batch, max_len] where
        max_len is the longest sequence of the respective side in the batch.
    """
    article_seqs, summary_seqs = zip(*batch)
    padded_articles, padded_summaries = (
        pad_sequence(seqs, batch_first=True, padding_value=pad_id)
        for seqs in (article_seqs, summary_seqs)
    )
    return padded_articles, padded_summaries

BATCH_SIZE = 32

train_dataset = SummarizationDataset(df_train)
val_dataset   = SummarizationDataset(df_val)


def _collate_with_pad(batch):
    """Collate a batch, padding with the vocabulary's padding ID."""
    return collate_fn(batch, pad_id=pad_id)


# Only the training data is shuffled; validation keeps its original order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=_collate_with_pad,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=_collate_with_pad,
)

# Zelle 4: Seq2Seq-Modell (Encoder-Decoder) definieren

# Encoder: Liest den Input-Text (Artikel)
class Encoder(nn.Module):
    """
    Embeds the source token IDs and runs them through a multi-layer LSTM.

    Args:
        vocab_size: Number of rows of the embedding table.
        embed_dim: Size of each embedding vector.
        hidden_dim: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        pad_idx: Token ID passed to the embedding as ``padding_idx``.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, src):
        """
        Args:
            src: [Batch, src_len] token IDs.
        Returns:
            (outputs, hidden): outputs is [Batch, src_len, hidden_dim];
            hidden is the LSTM's final (h_n, c_n) state tuple.
        """
        token_vectors = self.embedding(src)  # [Batch, src_len, embed_dim]
        # NOTE: the whole padded sequence is fed to the LSTM, pad positions included.
        return self.lstm(token_vectors)

# Decoder: Generiert die Zusammenfassung Token für Token
class Decoder(nn.Module):
    """
    Embeds target token IDs, runs an LSTM and projects onto the vocabulary.

    Args:
        vocab_size: Size of the embedding table and of the output projection.
        embed_dim: Size of each embedding vector.
        hidden_dim: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        pad_idx: Token ID passed to the embedding as ``padding_idx``.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, trg, hidden):
        """
        Args:
            trg: [Batch, trg_len] token IDs fed to the decoder.
            hidden: LSTM state tuple (h, c) to start from.
        Returns:
            (predictions, hidden): predictions is [Batch, trg_len, vocab_size];
            hidden is the updated LSTM state tuple.
        """
        token_vectors = self.embedding(trg)  # [Batch, trg_len, embed_dim]
        lstm_out, next_hidden = self.lstm(token_vectors, hidden)
        return self.fc(lstm_out), next_hidden

# Seq2Seq-Modell: Kombiniert Encoder und Decoder
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time."""

    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """
        Args:
            src: [Batch, src_len] (article)
            trg: [Batch, trg_len] (summary, with <BOS> as first token)
            teacher_forcing_ratio: Probability, drawn anew for every step,
                that the ground-truth token is used as the next decoder input
                instead of the model's own prediction.
        Returns:
            outputs: [Batch, trg_len, vocab_size]; position 0 stays all-zero
            because no prediction is made for the <BOS> slot.
        """
        num_steps = trg.size(1)
        logits = torch.zeros(
            src.size(0),
            num_steps,
            self.decoder.fc.out_features,
            device=src.device,
        )

        # Only the encoder's final state is used; its per-step outputs are discarded.
        _, state = self.encoder(src)

        # The first decoder input is the first target token (<BOS>).
        step_input = trg[:, 0].unsqueeze(1)  # [Batch, 1]

        for step in range(1, num_steps):
            step_logits, state = self.decoder(step_input, state)
            logits[:, step] = step_logits.squeeze(1)

            # One random draw per step decides for the whole batch.
            if torch.rand(1).item() < teacher_forcing_ratio:
                step_input = trg[:, step].unsqueeze(1)
            else:
                step_input = step_logits.argmax(2)  # [Batch, 1]

        return logits

# Model initialisation
EMBED_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2

encoder = Encoder(
    vocab_size=vocab_size,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    pad_idx=pad_id,
)
decoder = Decoder(
    vocab_size=vocab_size,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    pad_idx=pad_id,
)
model = Seq2Seq(encoder, decoder, pad_id)
model = model.to(DEVICE)
print(model)

# Cell 5: training setup
EPOCHS = 3
# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train_one_epoch(model, loader, optimizer, criterion, device, teacher_forcing_ratio=0.5):
    """
    Train the model for one pass over ``loader``.

    Args:
        model: Called as ``model(src, trg, teacher_forcing_ratio)``.
        loader: Iterable of (src, trg) batches; must support ``len()``.
        optimizer: Optimizer that updates the model's parameters.
        criterion: Loss taking ([N, vocab_size] logits, [N] target IDs).
        device: Device the batches are moved to.
        teacher_forcing_ratio: Forwarded to the model unchanged.

    Returns:
        Mean of the per-batch losses.
    """
    model.train()
    batch_losses = []
    for src, trg in tqdm(loader, desc="Training", leave=False):
        src = src.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()
        # Forward pass: trg holds the full target text (starting with <BOS>).
        logits = model(src, trg, teacher_forcing_ratio)

        # Skip position 0 (<BOS>) and flatten batch and time for the loss:
        # [Batch, trg_len - 1, vocab_size] -> [Batch * (trg_len - 1), vocab_size]
        flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
        flat_targets = trg[:, 1:].reshape(-1)

        loss = criterion(flat_logits, flat_targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)

def evaluate(model, loader, criterion, device):
    """
    Compute the mean per-batch loss over ``loader`` without gradient tracking.

    The model is switched to eval mode and decodes without teacher forcing
    (``teacher_forcing_ratio=0.0``).

    Args:
        model: Called as ``model(src, trg, teacher_forcing_ratio=0.0)``.
        loader: Iterable of (src, trg) batches; must support ``len()``.
        criterion: Loss taking ([N, vocab_size] logits, [N] target IDs).
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses.
    """
    def _batch_loss(src, trg):
        # Loss of a single batch, ignoring the <BOS> position.
        src, trg = src.to(device), trg.to(device)
        logits = model(src, trg, teacher_forcing_ratio=0.0)
        flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
        flat_targets = trg[:, 1:].reshape(-1)
        return criterion(flat_logits, flat_targets).item()

    model.eval()
    with torch.no_grad():
        loss_sum = sum(_batch_loss(src, trg) for src, trg in loader)
    return loss_sum / len(loader)

# Cell 6: training loop
best_val_loss = float("inf")
for epoch in range(1, EPOCHS + 1):
    print(f"\n=== Epoche {epoch}/{EPOCHS} ===")
    train_loss = train_one_epoch(
        model,
        train_loader,
        optimizer,
        criterion,
        DEVICE,
        teacher_forcing_ratio=0.5,
    )
    val_loss = evaluate(model, val_loader, criterion, DEVICE)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Only epochs that improve the validation loss produce a checkpoint.
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    # Save the best model so far (adjust the path as needed); every improving
    # epoch writes its own file.
    checkpoint_path = f"autosummary_best_epoch{epoch}.pt"
    torch.save(model.state_dict(), checkpoint_path)
    print("** Bestes Modell gespeichert.")

# Zelle 7: Inferenz-Demo – Zusammenfassung generieren
def generate_summary(model, article_ids, word2id, id2word, max_length=50):
    """
    Generate a summary for one article with greedy decoding.

    The encoder, decoder and device are taken from ``model`` itself, so the
    function works for whichever model is passed in rather than for
    module-level globals.

    Args:
        model: Model exposing ``encoder`` and ``decoder`` sub-modules.
        article_ids: List of token IDs of the article.
        word2id: Vocabulary mapping; used to look up "<BOS>" (default 2)
            and "<EOS>" (default 3).
        id2word: Unused; kept so existing callers keep working.
        max_length: Maximum number of decoding steps.

    Returns:
        List of generated token IDs, starting with the <BOS> ID. The <EOS>
        token is not included.
    """
    model.eval()

    # Run on the device the model's weights live on; a parameter-free model
    # falls back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    bos_id = word2id.get("<BOS>", 2)
    eos_id = word2id.get("<EOS>", 3)  # looked up once instead of every step

    src = torch.tensor([article_ids], dtype=torch.long, device=device)
    generated_ids = [bos_id]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Encoder: process the article; only its final state seeds the decoder.
        _, hidden = model.encoder(src)

        # Decoder start: the first input token is <BOS>.
        input_dec = torch.tensor([[bos_id]], dtype=torch.long, device=device)

        for _ in range(max_length):
            output, hidden = model.decoder(input_dec, hidden)
            # output: [1, 1, vocab_size]
            next_id = output.argmax(2).item()
            if next_id == eos_id:
                break
            generated_ids.append(next_id)
            input_dec = torch.tensor([[next_id]], dtype=torch.long, device=device)

    return generated_ids

# Example inference
# Pick one article from the validation set
sample_article_ids = df_val["article_ids"].iloc[0]
# Convert the article (string) into a list if necessary. ast.literal_eval only
# accepts Python literals, so CSV content cannot execute code (unlike eval()).
if isinstance(sample_article_ids, str):
    sample_article_ids = ast.literal_eval(sample_article_ids)

gen_ids = generate_summary(model, sample_article_ids, word2id, id2word, max_length=50)
# IDs missing from the vocabulary are rendered as "<UNK>".
gen_summary = " ".join(id2word.get(i, "<UNK>") for i in gen_ids)
print("\nGenerierte Zusammenfassung:")
print(gen_summary)

# Cell 8: outlook & conclusion
# Prints a closing summary of the notebook and suggested next steps; the
# message is user-facing German text.
print("""
In diesem Notebook haben wir:
- Einen Prototypen für ein Seq2Seq-Modell (Encoder-Decoder-LSTM) zum Erzeugen von Zusammenfassungen trainiert.
- Den Trainingsprozess und die Evaluierung demonstriert.
- Eine erste Inferenz-Demo durchgeführt, bei der ein Artikel zusammengefasst wurde.

Nächste Schritte:
- Mehr Daten und Epochen zum besseren Training nutzen.
- Experimentiere mit Teacher Forcing, Sampling-Methoden und Hyperparametern.
- Erweitere den Ansatz um Transformer-Modelle (z.B. T5 oder BART) für qualitativ hochwertigere Zusammenfassungen.
""")
