# ===========================================
# 3_training_demo.ipynb
# Training & Evaluierung von Stiltransfer-Modellen
# ===========================================

# Zelle 1: Bibliotheken und Setup
import os
import json
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Select the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Verwende Gerät:", DEVICE)

# Locations of the preprocessed training data, validation data and vocabulary.
PROCESSED_DATA_DIR = "../data/processed"
TRAIN_CSV, VAL_CSV, VOCAB_JSON = (
    os.path.join(PROCESSED_DATA_DIR, file_name)
    for file_name in ("train.csv", "val.csv", "vocab.json")
)

# Zelle 2: Daten und Vokabular laden
# Read the training split first, then the validation split.
df_train, df_val = pd.read_csv(TRAIN_CSV), pd.read_csv(VAL_CSV)

# The vocabulary file maps token -> id.
with open(VOCAB_JSON, "r", encoding="utf-8") as f:
    word2id = json.load(f)

# Reverse lookup (id -> token), used to turn generated ids back into text.
id2word = dict(zip(word2id.values(), word2id.keys()))

# Fall back to id 0 when the vocabulary has no explicit <PAD> entry.
pad_id = word2id.get("<PAD>", 0)
vocab_size = len(word2id)

print(f"Train Samples: {len(df_train)} | Val Samples: {len(df_val)}")
print("Vokabulargröße:", vocab_size)
# Show the first few rows (display is provided by the notebook environment).
display(df_train.head(3))

# Zelle 3: Dataset-Klasse definieren
class StyleShiftDataset(Dataset):
    """Paired token-id sequences for style transfer.

    Each row of ``df`` provides one training pair:

    - ``input_col``: token ids of the source (modern style) sentence
    - ``target_col``: token ids of the target style sentence

    A cell may hold either a real list of ids or its string form
    (e.g. ``"[1, 2, 3]"``, as produced when lists are written to CSV).
    String cells are parsed with ``ast.literal_eval``, which accepts only
    Python literals, so expressions embedded in a data file are rejected
    instead of being executed.

    Args:
        df: DataFrame holding the two id columns. It is copied, never mutated.
        input_col: Name of the source-id column.
        target_col: Name of the target-id column.

    Raises:
        ValueError, SyntaxError: If a string cell is not a valid Python literal.
    """

    def __init__(self, df, input_col="modern_ids", target_col="shakespeare_ids"):
        # Local import keeps this class self-contained within the file.
        import ast

        def _to_id_list(cell):
            # Only string cells need parsing; real lists pass through untouched.
            return ast.literal_eval(cell) if isinstance(cell, str) else cell

        self.df = df.copy()
        self.df[input_col] = self.df[input_col].apply(_to_id_list)
        self.df[target_col] = self.df[target_col].apply(_to_id_list)
        self.inputs = self.df[input_col].tolist()
        self.targets = self.df[target_col].tolist()

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two 1-D ``torch.long`` tensors."""
        inp = torch.tensor(self.inputs[idx], dtype=torch.long)
        tgt = torch.tensor(self.targets[idx], dtype=torch.long)
        return inp, tgt

def collate_fn(batch, pad_id=0):
    """Pad a list of (input, target) tensor pairs into two batch tensors.

    Args:
        batch: Sequence of ``(input_ids, target_ids)`` 1-D tensor pairs.
        pad_id: Value used to fill the shorter sequences.

    Returns:
        Tuple ``(inputs, targets)``; each is batch-first and padded to the
        longest sequence on its own side of the batch.
    """
    def _pad(sequences):
        return pad_sequence(sequences, batch_first=True, padding_value=pad_id)

    # Split the list of pairs into all inputs and all targets.
    source_seqs, target_seqs = zip(*batch)
    return _pad(source_seqs), _pad(target_seqs)

BATCH_SIZE = 32

train_dataset = StyleShiftDataset(df_train, input_col="modern_ids", target_col="shakespeare_ids")
val_dataset = StyleShiftDataset(df_val, input_col="modern_ids", target_col="shakespeare_ids")


def _collate_with_pad(batch):
    # Bind the vocabulary's padding id so the loaders can call a 1-arg collate.
    return collate_fn(batch, pad_id=pad_id)


# Only the training data is shuffled; validation keeps its file order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=_collate_with_pad,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=_collate_with_pad,
)

# Sanity check: pull one batch and show the padded tensor shapes.
print("Beispielhafter Batch:")
sample_inp, sample_tgt = next(iter(train_loader))
print("Input-Shape:", sample_inp.shape, "Target-Shape:", sample_tgt.shape)

# Zelle 4: Modell definieren (Seq2Seq-LSTM für Stiltransfer)
# Encoder: reads the source (modern style) token sequence.
class Encoder(nn.Module):
    """Embedding layer followed by a batch-first multi-layer LSTM."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, src):
        """Encode ``src`` of shape [batch, src_len].

        Returns:
            ``(outputs, hidden)`` exactly as produced by the LSTM, where
            ``hidden`` is the ``(h_n, c_n)`` state tuple.
        """
        token_vectors = self.embedding(src)  # [batch, src_len, embed_dim]
        return self.lstm(token_vectors)

# Decoder: produces the token sequence in the target style (e.g. Shakespeare).
class Decoder(nn.Module):
    """Embedding, batch-first LSTM and a linear projection onto the vocabulary."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, trg, hidden):
        """Run the decoder over ``trg`` of shape [batch, trg_len].

        Args:
            trg: Token ids fed to the decoder.
            hidden: LSTM state to start from.

        Returns:
            ``(predictions, hidden)`` where ``predictions`` has shape
            [batch, trg_len, vocab_size] and ``hidden`` is the updated state.
        """
        token_vectors = self.embedding(trg)  # [batch, trg_len, embed_dim]
        lstm_out, next_hidden = self.lstm(token_vectors, hidden)
        return self.fc(lstm_out), next_hidden

# Seq2Seq: ties the encoder and the decoder together.
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with teacher forcing."""

    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg`` conditioned on ``src``.

        Args:
            src: Source token ids, [batch, src_len].
            trg: Target token ids, [batch, trg_len]; column 0 is fed as the
                first decoder input.
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the gold token instead of the model's
                own prediction as the next input.

        Returns:
            Tensor [batch, trg_len, vocab_size]. Position 0 is never written
            and stays all zeros.
        """
        n_batch, n_steps = src.size(0), trg.size(1)
        n_classes = self.decoder.fc.out_features
        scores = torch.zeros(n_batch, n_steps, n_classes, device=src.device)

        # Only the encoder's final state is passed on to the decoder.
        _, state = self.encoder(src)

        step_input = trg[:, :1]  # [batch, 1]
        for step in range(1, n_steps):
            step_scores, state = self.decoder(step_input, state)
            scores[:, step] = step_scores.squeeze(1)

            use_gold = torch.rand(1).item() < teacher_forcing_ratio
            if use_gold:
                step_input = trg[:, step].unsqueeze(1)
            else:
                step_input = step_scores.argmax(2)  # [batch, 1]

        return scores

# Model hyperparameters and construction.
EMBED_DIM, HIDDEN_DIM, NUM_LAYERS = 128, 256, 2

# Encoder and decoder are built with identical sizes so the encoder's final
# LSTM state can be handed directly to the decoder.
_shared_sizes = (vocab_size, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, pad_id)
encoder = Encoder(*_shared_sizes)
decoder = Decoder(*_shared_sizes)

model = Seq2Seq(encoder, decoder, pad_id)
model = model.to(DEVICE)
print(model)

# Zelle 5: Trainings-Setup
EPOCHS = 5
# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_one_epoch(model, loader, optimizer, criterion, device, teacher_forcing_ratio):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Called as ``model(src, trg, teacher_forcing_ratio=...)``.
        loader: Iterable of ``(src, trg)`` batches; must support ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to flattened scores and flattened targets.
        device: Device the batches are moved to.
        teacher_forcing_ratio: Forwarded unchanged to the model.

    Returns:
        Mean of the per-batch losses.
    """
    model.train()
    epoch_loss = 0
    progress = tqdm(loader, desc="Training", leave=False)
    for src_batch, trg_batch in progress:
        src_batch = src_batch.to(device)
        trg_batch = trg_batch.to(device)

        optimizer.zero_grad()
        scores = model(src_batch, trg_batch, teacher_forcing_ratio=teacher_forcing_ratio)

        # Position 0 is skipped on both sides: the model never writes a
        # prediction there.
        n_classes = scores.shape[-1]
        flat_scores = scores[:, 1:].reshape(-1, n_classes)
        flat_gold = trg_batch[:, 1:].reshape(-1)

        batch_loss = criterion(flat_scores, flat_gold)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
    return epoch_loss / len(loader)

def evaluate(model, loader, criterion, device):
    """Compute the mean per-batch loss of ``model`` over ``loader``.

    The model is switched to eval mode and decoded without teacher forcing
    (``teacher_forcing_ratio=0.0``); no gradients are tracked.

    Args:
        model: Called as ``model(src, trg, teacher_forcing_ratio=0.0)``.
        loader: Iterable of ``(src, trg)`` batches; must support ``len()``.
        criterion: Loss applied to flattened scores and flattened targets.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses.
    """
    def _batch_loss(src_batch, trg_batch):
        src_batch = src_batch.to(device)
        trg_batch = trg_batch.to(device)
        scores = model(src_batch, trg_batch, teacher_forcing_ratio=0.0)
        # Position 0 is skipped on both sides, as in training.
        flat_scores = scores[:, 1:].reshape(-1, scores.shape[-1])
        flat_gold = trg_batch[:, 1:].reshape(-1)
        return criterion(flat_scores, flat_gold).item()

    model.eval()
    running_loss = 0
    with torch.no_grad():
        for src_batch, trg_batch in loader:
            running_loss += _batch_loss(src_batch, trg_batch)
    return running_loss / len(loader)

# Zelle 6: Trainings- und Validierungsschleife
best_val_loss = float("inf")
for epoch in range(1, EPOCHS + 1):
    print(f"\n=== Epoche {epoch}/{EPOCHS} ===")
    train_loss = train_one_epoch(
        model, train_loader, optimizer, criterion, DEVICE, teacher_forcing_ratio=0.5
    )
    val_loss = evaluate(model, val_loader, criterion, DEVICE)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Checkpoint only on a strict improvement of the validation loss.
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    # The epoch number is part of the file name, so every improvement writes
    # a new checkpoint file instead of overwriting the previous one.
    model_path = f"styleshift_best_epoch{epoch}.pt"
    torch.save(model.state_dict(), model_path)
    print(f"** Neues bestes Modell gespeichert: {model_path}")

# Zelle 7: Inferenz-Demo – Stiltransfer
def generate_styled_text(model, input_text, word2id, id2word, max_length=50):
    """Greedily decode a styled version of ``input_text`` using ``model``.

    The text is lower-cased and split on whitespace, mapped to ids, encoded
    with ``model.encoder`` and then decoded one token at a time with
    ``model.decoder``, always taking the highest-scoring token.

    The encoder and decoder are taken from the ``model`` argument and tensors
    are created on the model's own device, so the function works for any
    model passed in rather than depending on module-level globals.

    Args:
        model: Seq2seq model exposing ``encoder`` and ``decoder`` submodules.
            It is left in eval mode afterwards.
        input_text: Source sentence. Tokenization is a plain whitespace split;
            it should match whatever preprocessing produced the vocabulary.
        word2id: Token -> id mapping. ``<UNK>``, ``<BOS>`` and ``<EOS>`` fall
            back to ids 1, 2 and 3 when absent.
        id2word: Id -> token mapping used to render the result.
        max_length: Maximum number of tokens to generate.

    Returns:
        The generated tokens joined by single spaces. The leading ``<BOS>``
        token is included; generation stops before ``<EOS>`` is emitted.
    """
    # Resolve the special-token ids once instead of on every decoding step.
    unk_id = word2id.get("<UNK>", 1)
    bos_id = word2id.get("<BOS>", 2)
    eos_id = word2id.get("<EOS>", 3)

    # Create tensors on the device the model's parameters live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    tokens = input_text.lower().split()
    input_ids = [word2id.get(token, unk_id) for token in tokens]
    src = torch.tensor([input_ids], dtype=torch.long, device=device)
    input_dec = torch.tensor([[bos_id]], dtype=torch.long, device=device)

    generated_ids = [bos_id]
    model.eval()
    # Encoding happens inside no_grad as well, so no autograd graph is built
    # during inference.
    with torch.no_grad():
        _, hidden = model.encoder(src)
        for _ in range(max_length):
            output, hidden = model.decoder(input_dec, hidden)
            next_id = output.argmax(dim=2).item()
            if next_id == eos_id:
                break
            generated_ids.append(next_id)
            input_dec = torch.tensor([[next_id]], dtype=torch.long, device=device)

    return " ".join(id2word.get(i, "<UNK>") for i in generated_ids)

# Inference example: run one sentence through the trained model.
example_input = "ich freue mich auf den sommer"
styled_output = generate_styled_text(
    model,
    example_input,
    word2id,
    id2word,
    max_length=20,
)
print("\nBeispiel für Stiltransfer:")
for label, text in (("Modern:", example_input), ("Stilisierte Version:", styled_output)):
    print(label, text)

# Zelle 8: Fazit & Ausblick
# Closing summary and outlook for the notebook. The text is static: it is
# printed as written and is not derived from the losses measured during
# training.
print("""
Fazit:
- Das Modell wurde erfolgreich trainiert und zeigt einen sinkenden Validierungs-Loss.
- Erste Inferenz-Demos demonstrieren den Stiltransfer von modern zu einem anderen Stil.
Nächste Schritte:
- Mehr Daten und Epochen für bessere Ergebnisse.
- Erweiterung des Modells um Transformer-Ansätze (z.B. GPT-2/T5) für noch realistischere Stiltransfers.
- Feinabstimmung der Hyperparameter und Erprobung von Sampling-Methoden.
""")
