# ===========================================
# 3_training_demo.ipynb
# ===========================================

# Zelle 1: Bibliotheken & Setup
import ast
import json
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Gerät:", DEVICE)

# Directory the CSV splits and the vocabulary file are read from.
PROCESSED_DATA_DIR = "../data/processed"

# ===========================================
# Zelle 2: Datensätze & Vokabular laden
# ===========================================
# Resolve the three dataset splits and the vocabulary file inside the
# processed-data directory.
train_csv, val_csv, test_csv, vocab_path = (
    os.path.join(PROCESSED_DATA_DIR, filename)
    for filename in ("train.csv", "val.csv", "test.csv", "vocab.json")
)

# Read each split into its own DataFrame.
train_df, val_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (train_csv, val_csv, test_csv)
)

# The vocabulary is stored as a JSON object; id2word is its inverse
# (value -> key) and is used later to turn generated ids back into tokens.
with open(vocab_path, "r", encoding="utf-8") as f:
    word2id = json.load(f)
id2word = dict(zip(word2id.values(), word2id.keys()))

vocab_size = len(word2id)
print(f"Train-Samples: {len(train_df)}, Val-Samples: {len(val_df)}, Test-Samples: {len(test_df)}")
print("Vokabulargröße:", vocab_size)

# ===========================================
# Zelle 3: Dataset & DataLoader
# ===========================================
class TextDataset(Dataset):
    """Next-token-prediction dataset built from one DataFrame column.

    Each cell of ``seq_col`` holds one token-id sequence, either as a
    Python-literal string such as ``'[2, 45, 7, 13]'`` (what a list looks
    like after a CSV round trip) or as an already materialised sequence.

    Args:
        df: DataFrame containing the sequences.
        seq_col: Name of the column holding the token ids.
        bos_id: Token id used as the input of the fallback sample.
        eos_id: Token id used as the target of the fallback sample.

    Raises:
        ValueError, SyntaxError: If a string cell is not a valid Python
            literal.
    """

    def __init__(self, df, seq_col="token_ids", bos_id=2, eos_id=3):
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.samples = [self._parse(cell) for cell in df[seq_col]]

    @staticmethod
    def _parse(cell):
        """Turn one cell into a token sequence without executing code."""
        if isinstance(cell, str):
            # literal_eval only accepts Python literals, so a crafted CSV
            # cell cannot run arbitrary code the way eval() would.
            return ast.literal_eval(cell)
        return cell

    def __len__(self):
        """Return the number of sequences."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(input, target)`` LongTensors for sample ``idx``.

        The target is the input shifted left by one position:
        input = tokens[:-1], target = tokens[1:].
        """
        tokens = self.samples[idx]
        if len(tokens) < 2:
            # Too short to form an (input, target) pair: substitute the
            # minimal sequence <BOS>, <EOS>.
            tokens = [self.bos_id, self.eos_id]
        inp = torch.tensor(tokens[:-1], dtype=torch.long)
        tgt = torch.tensor(tokens[1:], dtype=torch.long)
        return inp, tgt

def collate_fn(batch, pad_id=0):
    """Pad a batch of (input, target) tensor pairs to a common length.

    Args:
        batch: Iterable of ``(input, target)`` 1-D tensor pairs.
        pad_id: Value written into the padded positions.

    Returns:
        Tuple ``(inputs, targets)``, each padded with ``pad_id`` and laid
        out batch-first.
    """
    from torch.nn.utils.rnn import pad_sequence

    # Transpose the list of pairs into one group of inputs and one of targets.
    src_seqs, tgt_seqs = zip(*batch)
    padded = [
        pad_sequence(group, batch_first=True, padding_value=pad_id)
        for group in (src_seqs, tgt_seqs)
    ]
    return padded[0], padded[1]

BATCH_SIZE = 32

from torch.utils.data import DataLoader

# Training split: shuffled, so batch order is not fixed.
train_dataset = TextDataset(train_df)
train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

# Validation split: iterated in its stored order.
val_dataset = TextDataset(val_df)
val_loader = DataLoader(
    val_dataset,
    collate_fn=collate_fn,
    shuffle=False,
    batch_size=BATCH_SIZE,
)

# Test-Dataset kannst du für abschließende Auswertungen behalten
# oder später in Inferenz-Szenarien nutzen.

# ===========================================
# Zelle 4: LSTM-Modell definieren
# ===========================================
class LSTMTextModel(nn.Module):
    """Embedding -> LSTM -> linear projection onto the vocabulary.

    Args:
        vocab_size: Number of embedding rows and of output logits.
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        pad_idx: Token id passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=1, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for a batch-first tensor of token ids.

        ``hidden`` is handed to the LSTM unchanged; ``None`` lets the LSTM
        start from its default initial state.
        """
        embedded = self.embedding(x)
        lstm_out, next_hidden = self.lstm(embedded, hidden)
        return self.fc(lstm_out), next_hidden

# Modellinstanz
# Build the two-layer network and move it to the selected device.
# The padding index comes from the vocabulary and is assumed to be 0,
# which is also the default pad value used by collate_fn.
model = LSTMTextModel(
    vocab_size,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    pad_idx=word2id["<PAD>"],
).to(DEVICE)
print(model)

# ===========================================
# Zelle 5: Trainingssetup
# ===========================================
# Number of full passes over the training data.
EPOCHS = 3
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Padded target positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word2id["<PAD>"])

# ===========================================
# Zelle 6: Trainingsschleife
# ===========================================
def train_one_epoch(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Uses the module-level ``criterion``, ``optimizer`` and ``DEVICE``.

    Args:
        model: Network returning ``(logits, hidden)`` for a batch of inputs.
        loader: Iterable yielding ``(input, target)`` batches.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    running_loss = 0
    progress = tqdm(loader, desc="Train", leave=False)
    for batch_inp, batch_tgt in progress:
        batch_inp, batch_tgt = batch_inp.to(DEVICE), batch_tgt.to(DEVICE)

        scores, _ = model(batch_inp)
        # Flatten every position into one row so the criterion sees
        # (positions, classes) scores against (positions,) targets.
        flat_scores = scores.view(-1, scores.size(-1))
        flat_tgt = batch_tgt.view(-1)
        batch_loss = criterion(flat_scores, flat_tgt)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(loader)

def evaluate(model, loader):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    Runs in eval mode with gradient tracking disabled and uses the
    module-level ``criterion`` and ``DEVICE``.

    Args:
        model: Network returning ``(logits, hidden)`` for a batch of inputs.
        loader: Iterable yielding ``(input, target)`` batches.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``loader`` contains no batches.
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for inp, tgt in loader:
            inp, tgt = inp.to(DEVICE), tgt.to(DEVICE)
            logits, _ = model(inp)
            # Take the class count from the logits themselves rather than
            # from a module-level vocabulary size, so the function works
            # for any model passed in (train_one_epoch does the same).
            logits = logits.reshape(-1, logits.size(-1))
            tgt = tgt.reshape(-1)
            loss = criterion(logits, tgt)
            total_loss += loss.item()
    return total_loss / len(loader)

# Train for EPOCHS passes and write a checkpoint whenever the validation
# loss drops below the best value seen so far.
best_val_loss = float("inf")
for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = evaluate(model, val_loader)
    print(f"Epoch [{epoch+1}/{EPOCHS}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # No improvement this epoch: nothing to save.
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    torch.save(model.state_dict(), f"lstm_best_model_epoch{epoch+1}.pt")
    print("** Neues bestes Modell gespeichert.")

# ===========================================
# Zelle 7: Einfache Generierung (Greedy)
# ===========================================
def generate_text(model, prompt_tokens, max_len=50, eos_id=None):
    """Greedily extend ``prompt_tokens`` with the model's most likely tokens.

    The prompt is fed once; afterwards only the newly chosen token is fed
    together with the carried hidden state, so every token passes through
    the recurrent state exactly once.

    Args:
        model: Module whose call ``model(inp, hidden)`` returns
            ``(logits, hidden)`` with logits indexed as
            ``[batch, position, vocab]``.
        prompt_tokens: Non-empty sequence of token ids to start from. It is
            not modified.
        max_len: Maximum number of tokens to append.
        eos_id: Token id that stops generation once produced. ``None``
            looks up ``"<EOS>"`` in the module-level ``word2id`` (falling
            back to -1, which never matches).

    Returns:
        List containing the prompt ids followed by the generated ids,
        including the end-of-sequence id if it was produced.

    Raises:
        ValueError: If ``prompt_tokens`` is empty.
    """
    if len(prompt_tokens) == 0:
        raise ValueError("prompt_tokens must contain at least one token id")
    if eos_id is None:
        eos_id = word2id.get("<EOS>", -1)

    model.eval()
    # Create inputs on the device the model lives on; fall back to the
    # module-level DEVICE for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)

    tokens = list(prompt_tokens)
    inp = torch.tensor([tokens], dtype=torch.long, device=device)
    hidden = None

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for _ in range(max_len):
            logits, hidden = model(inp, hidden)
            # Only the prediction at the last position matters.
            next_id = int(torch.argmax(logits[0, -1, :]).item())
            tokens.append(next_id)
            if next_id == eos_id:
                break
            # The hidden state already summarises everything fed so far,
            # so the next step receives just the new token.
            inp = torch.tensor([[next_id]], dtype=torch.long, device=device)
    return tokens

# Start from the begin-of-sequence token alone; tokens missing from the
# vocabulary map to the <UNK> id (1 if the vocabulary has no <UNK> entry).
prompt_text = ["<BOS>"]
prompt_ids = [
    word2id.get(token, word2id.get("<UNK>", 1))
    for token in prompt_text
]

generated_ids = generate_text(model, prompt_ids, max_len=20)
# Map ids back to tokens; ids without a vocabulary entry are shown as <UNK>.
generated_tokens = [id2word.get(token_id, "<UNK>") for token_id in generated_ids]

print("Generierter Text (Greedy):", " ".join(generated_tokens), sep="\n")

# ===========================================
# Zelle 8: Fazit & Ausblick
# ===========================================
# Closing cell: print a German summary of what the notebook did, followed by
# suggested next steps (more epochs and tuning, sampling instead of greedy
# decoding, moving to transformer architectures).
print("""
In diesem Notebook wurde ein einfaches LSTM-Modell trainiert,
um Geschichten fortzusetzen. Mit mehr Epochen und Daten kann die
Textqualität gesteigert werden.

Nächste Schritte:
- Mehr Epochen & Hyperparameter-Tuning
- Top-k oder Temperature-Sampling (anstatt Greedy)
- Upgrade auf Transformer-Architekturen, z. B. GPT-2
""")
