# ===========================================
# 3_training_demo.ipynb
# ===========================================

# Zelle 1: Bibliotheken & Setup
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import os
import json
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Verwende Gerät:", DEVICE)

# Directory the train/val/test CSV files and vocab.json are read from.
PROCESSED_DATA_DIR = "../data/processed"

# ===========================================
# Zelle 2: Daten laden
# ===========================================
train_path = os.path.join(PROCESSED_DATA_DIR, "train.csv")
val_path = os.path.join(PROCESSED_DATA_DIR, "val.csv")
test_path = os.path.join(PROCESSED_DATA_DIR, "test.csv")

vocab_path = os.path.join(PROCESSED_DATA_DIR, "vocab.json")

train_df = pd.read_csv(train_path)
val_df = pd.read_csv(val_path)
test_df = pd.read_csv(test_path)

# Vocabulary mapping: token string -> integer id.
with open(vocab_path, "r", encoding="utf-8") as f:
    word2id = json.load(f)

print(f"Train Samples: {len(train_df)}, Val Samples: {len(val_df)}, Test Samples: {len(test_df)}")
# vocab_size sizes the embedding and output layers, which need one row for every
# id up to the largest one. For contiguous ids 0..N-1 this equals len(word2id);
# unlike len(), it stays correct when the ids have gaps or do not start at 0
# (len() would then be too small and the embedding lookup would go out of range).
vocab_size = max(word2id.values(), default=-1) + 1
print("Vokabulargröße:", vocab_size)

# ===========================================
# Zelle 3: Datensatz & DataLoader erstellen
# ===========================================
# Wir definieren eine einfache Dataset-Klasse, um
# aus den CSV-Spalten "token_ids" Tensoren zu erzeugen.

class TextDataset(Dataset):
    """Next-token-prediction dataset built from a DataFrame column of token ids.

    Each cell of ``df[seq_col]`` holds a list of token ids, either as a string
    literal such as ``"[1, 45, 23, 99]"`` (what ``pd.read_csv`` yields) or as an
    already-parsed sequence. Sample ``i`` is ``(tokens[:-1], tokens[1:])`` as
    ``torch.long`` tensors.
    """

    def __init__(self, df, seq_col="token_ids"):
        import ast  # local import keeps this cell self-contained

        # ast.literal_eval only accepts Python literals, so a crafted CSV cell
        # cannot execute code the way eval() would (it raises ValueError instead).
        self.samples = [
            list(ast.literal_eval(cell) if isinstance(cell, str) else cell)
            for cell in df[seq_col]
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(input_seq, target_seq)`` for next-token prediction.

        The input drops the last token and the target drops the first, so
        ``target_seq[t]`` is the token that follows ``input_seq[t]``.
        """
        token_ids = self.samples[idx]
        input_seq = torch.tensor(token_ids[:-1], dtype=torch.long)
        target_seq = torch.tensor(token_ids[1:], dtype=torch.long)
        return input_seq, target_seq

# Dataset-Instanzen
# One dataset per split used for training/validation (test_df is loaded but not wrapped here).
train_dataset, val_dataset = TextDataset(train_df), TextDataset(val_df)

# Um Batches zu erstellen, brauchen wir eine Collate-Funktion:
# -> Sequenzen unterschiedlicher Länge werden per Padding auf eine einheitliche Länge gebracht.
#   Hier sehr einfach gehalten (Padding auf max. Sequenzlänge im Batch).
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch, pad_id=0):
    """Right-pad a batch of ``(input_seq, target_seq)`` pairs to a common length.

    Args:
        batch: list of ``(input_seq, target_seq)`` 1-D tensor pairs.
        pad_id: value written into positions past the end of a shorter sequence.

    Returns:
        Tuple ``(inputs, targets)``; each is a tensor of shape
        ``[len(batch), longest sequence in the batch]``.
    """
    def _pad(sequences):
        return pad_sequence(sequences, batch_first=True, padding_value=pad_id)

    inputs, targets = zip(*batch)
    return _pad(inputs), _pad(targets)

# DataLoader
BATCH_SIZE = 32
# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    collate_fn=collate_fn,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Sanity check: pull one padded training batch and show its shapes.
print("Beispielhafter Batch (Train):")
sample_input, sample_target = next(iter(train_loader))
print("Input-Shape:", sample_input.shape, "Target-Shape:", sample_target.shape)

# ===========================================
# Zelle 4: LSTM-Modell definieren
# ===========================================
class LSTMTextModel(nn.Module):
    """Token-level language model: embedding -> (stacked) LSTM -> vocabulary projection.

    Token id 0 is the embedding's ``padding_idx``: its vector starts at zero
    and is excluded from gradient updates.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=1):
        super().__init__()
        # The attribute names (embedding / lstm / fc) become the state_dict keys;
        # keep them stable so saved checkpoints still load.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=0
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Map token ids ``[Batch, SeqLen]`` to ``(logits, hidden)``.

        ``logits`` has shape ``[Batch, SeqLen, vocab_size]``; ``hidden`` is the
        LSTM's ``(h, c)`` state, which can be passed back in to continue a sequence.
        """
        lstm_out, new_hidden = self.lstm(self.embedding(x), hidden)
        return self.fc(lstm_out), new_hidden

# Modell-Instanz
# Two stacked LSTM layers; Module.to() moves the parameters in place and returns the module.
model = LSTMTextModel(
    vocab_size=vocab_size,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
).to(DEVICE)
print(model)

# ===========================================
# Zelle 5: Trainingssetup (Loss, Optimizer)
# ===========================================
# <PAD> (id 0) is ignored as a target, so padded positions add nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# ===========================================
# Zelle 6: Trainingsloop
# ===========================================
EPOCHS = 3

def train_one_epoch(model, loader, *, loss_fn=None, opt=None, device=None):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    Args:
        model: module whose forward returns ``(logits, hidden)`` with logits
            shaped ``[Batch, SeqLen, vocab]``.
        loader: iterable of ``(input_ids, target_ids)`` batches.
        loss_fn: loss module; defaults to the notebook-level ``criterion``.
        opt: optimizer; defaults to the notebook-level ``optimizer``. That one
            was built on the global ``model``'s parameters, so pass ``opt``
            when training any other model, otherwise its weights never change.
        device: target device; defaults to the notebook-level ``DEVICE``.

    Returns:
        Average loss over the batches that contain at least one non-ignored
        target, or ``nan`` if there were none (e.g. an empty loader).
    """
    if loss_fn is None:
        loss_fn = criterion
    if opt is None:
        opt = optimizer
    if device is None:
        device = DEVICE
    # With mean reduction, a batch whose targets are all ``ignore_index`` has a
    # loss of 0/0 = NaN, which would turn the whole epoch average into NaN.
    # Such batches carry no training signal, so they are skipped.
    ignore_index = getattr(loss_fn, "ignore_index", None)

    model.train()
    total_loss = 0.0
    counted = 0
    for batch_input, batch_target in tqdm(loader, desc="Training", leave=False):
        batch_input = batch_input.to(device)
        flat_target = batch_target.to(device).reshape(-1)  # [Batch*SeqLen]
        if ignore_index is not None and not (flat_target != ignore_index).any():
            continue

        logits, _ = model(batch_input)
        # [Batch, SeqLen, vocab] -> [Batch*SeqLen, vocab] for CrossEntropyLoss;
        # the vocab size is read from the logits instead of a global.
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), flat_target)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        counted += 1

    return total_loss / counted if counted else float("nan")

def evaluate(model, loader, *, loss_fn=None, device=None):
    """Return the mean per-batch loss of ``model`` over ``loader`` without updating it.

    Args:
        model: module whose forward returns ``(logits, hidden)`` with logits
            shaped ``[Batch, SeqLen, vocab]``.
        loader: iterable of ``(input_ids, target_ids)`` batches.
        loss_fn: loss module; defaults to the notebook-level ``criterion``.
        device: target device; defaults to the notebook-level ``DEVICE``.

    Returns:
        Average loss over the batches that contain at least one non-ignored
        target, or ``nan`` if there were none (e.g. an empty loader).
    """
    if loss_fn is None:
        loss_fn = criterion
    if device is None:
        device = DEVICE
    # With mean reduction, a batch whose targets are all ``ignore_index`` has a
    # loss of 0/0 = NaN, which would poison the average; such batches are skipped.
    ignore_index = getattr(loss_fn, "ignore_index", None)

    model.eval()
    total_loss = 0.0
    counted = 0
    with torch.no_grad():
        for batch_input, batch_target in loader:
            batch_input = batch_input.to(device)
            flat_target = batch_target.to(device).reshape(-1)  # [Batch*SeqLen]
            if ignore_index is not None and not (flat_target != ignore_index).any():
                continue

            logits, _ = model(batch_input)
            # The vocab size is read from the logits instead of a global.
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), flat_target)
            total_loss += loss.item()
            counted += 1

    return total_loss / counted if counted else float("nan")

for epoch in range(EPOCHS):
    # One optimisation pass over the training split, then a no-grad validation pass.
    train_loss, val_loss = train_one_epoch(model, train_loader), evaluate(model, val_loader)
    print(
        f"Epoch {epoch + 1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}"
    )

# ===========================================
# Zelle 7: Kurzer Inferenz-Test
# ===========================================
# Wir probieren eine einfache Greedy-Generierung aus,
# um zu sehen, wie das Modell sich verhält (Frühstadium).
# Inverse vocabulary (id -> token); on duplicate ids the last entry wins.
id2word = dict(zip(word2id.values(), word2id.keys()))

def generate_text(model, start_tokens, max_len=20, *, eos_id=None, device=None):
    """Greedily extend ``start_tokens`` by up to ``max_len`` token ids.

    The prompt is run through the model once. After that, only the newest token
    is fed back together with the carried recurrent state. The original version
    re-fed the whole prefix *and* the state, so the context was processed twice.

    Args:
        model: module called as ``model(input_ids, hidden) -> (logits, hidden)``
            with logits shaped ``[1, SeqLen, vocab]``.
        start_tokens: non-empty sequence of prompt token ids (not modified).
        max_len: maximum number of tokens to append.
        eos_id: generation stops right after this id is produced. Defaults to
            ``word2id.get("<EOS>", -1)``; -1 never matches, i.e. no early stop.
        device: device for the input tensors; defaults to ``DEVICE``.

    Returns:
        A new list: the prompt ids followed by the generated ids.

    Raises:
        ValueError: if ``start_tokens`` is empty.
    """
    generated = list(start_tokens)  # copy, so the caller's list is left untouched
    if not generated:
        raise ValueError("start_tokens must contain at least one token id")
    if eos_id is None:
        eos_id = word2id.get("<EOS>", -1)
    if device is None:
        device = DEVICE

    model.eval()
    step_input = torch.tensor([generated], dtype=torch.long, device=device)
    hidden = None
    with torch.no_grad():  # inference only: don't build an autograd graph
        for _ in range(max_len):
            logits, hidden = model(step_input, hidden)
            next_id = torch.argmax(logits[0, -1, :]).item()  # greedy pick at last position
            generated.append(next_id)
            if next_id == eos_id:
                break
            # `hidden` already summarises everything before next_id, so only the
            # new token is fed in the next step.
            step_input = torch.tensor([[next_id]], dtype=torch.long, device=device)
    return generated

# Beispiel-Starttokens (BOS oder ein Wort)
start_prompt = "<BOS>"
# If the prompt token is missing from the vocabulary, use the vocabulary's own
# <UNK> id instead of assuming it is 1 (1 remains only the last-resort fallback).
prompt_id = word2id.get(start_prompt, word2id.get("<UNK>", 1))
gen_ids = generate_text(model, [prompt_id], max_len=10)

# Ids without a vocabulary entry are rendered as <UNK> instead of raising KeyError.
generated_words = [id2word.get(i, "<UNK>") for i in gen_ids]
print("Generated Sequence:", " ".join(generated_words))

# ===========================================
# Zelle 8: Speichern des Modells
# ===========================================
# In echten Projekten speichert man das Modell oft in einem separaten Ordner (z. B. 'models/')
model_path = "../models/lstm_model.pt"
# Create the target folder (no error if it already exists), then save only the weights.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)
print("Modell gespeichert unter:", model_path)

# ===========================================
# Zelle 9: Zusammenfassung & Ausblick
# ===========================================
# Closing notes for the reader; the text is printed verbatim.
print(
    """
In diesem Notebook haben wir ein einfaches LSTM-Sprachmodell auf unseren Daten trainiert.
Wir sahen erste Trainingsfortschritte anhand des Losses und konnten ein paar Tokens generieren.

Nächste Schritte könnten sein:
- Training über mehr Epochen
- Hyperparameter-Tuning (Embed-Dim, Hidden-Dim, LR etc.)
- Fortgeschrittene Generierung (Top-k-Sampling, Temperature)
- Umgang mit sehr langen Sequenzen (Truncation, BPTT-Splitting)
"""
)
