# ===========================================
# 3_training_demo.ipynb
# Training & Evaluierung der E-Mail-Vervollständigung
# ===========================================

# Zelle 1: Bibliotheken und Setup
import ast
import json
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Trainiere auf Gerät:", DEVICE)

# Paths to the preprocessed data files (all inside PROCESSED_DATA_DIR).
PROCESSED_DATA_DIR = "../data/processed"
TRAIN_CSV, VAL_CSV, VOCAB_JSON = (
    os.path.join(PROCESSED_DATA_DIR, file_name)
    for file_name in ("train.csv", "val.csv", "vocab.json")
)

# Cell 2: Load the data splits and the vocabulary
df_train, df_val = pd.read_csv(TRAIN_CSV), pd.read_csv(VAL_CSV)

# The JSON file holds the token -> id mapping; the reverse mapping is derived from it.
with open(VOCAB_JSON, "r", encoding="utf-8") as vocab_file:
    word2id = json.load(vocab_file)
id2word = dict(zip(word2id.values(), word2id.keys()))

vocab_size = len(word2id)
# Fall back to id 0 when the vocabulary has no explicit padding token.
pad_id = word2id["<PAD>"] if "<PAD>" in word2id else 0
print(f"Train Samples: {len(df_train)} | Val Samples: {len(df_val)}")
print("Vokabulargröße:", vocab_size)
display(df_train.head(3))

# Cell 3: Dataset class
class MailDataset(Dataset):
    """Dataset of (input, target) token-id sequences for e-mail completion.

    Every row of ``df`` is one sample. The cells of ``input_col`` and
    ``target_col`` are either sequences of token ids or the string form of
    such a sequence (e.g. ``"[2, 17, 5, 3]"``, which is what a list becomes
    after a round trip through CSV).

    String cells are parsed with ``ast.literal_eval``. Unlike ``eval`` it only
    accepts Python literals, so a crafted CSV cell cannot execute code.

    Args:
        df: DataFrame holding the samples. It is copied, never modified.
        input_col: Column with the ids fed to the encoder.
        target_col: Column with the ids the decoder should produce. With the
            defaults both columns are the same, i.e. input and target are the
            identical body sequence.

    Raises:
        ValueError: If a string cell is not a valid Python literal.
    """

    def __init__(self, df, input_col="body_ids", target_col="body_ids"):
        self.df = df.copy()
        # dict.fromkeys removes the duplicate when both columns are the same,
        # so each column is parsed exactly once.
        for column in dict.fromkeys((input_col, target_col)):
            self.df[column] = self.df[column].apply(self._parse_ids)
        self.inputs = self.df[input_col].tolist()
        self.targets = self.df[target_col].tolist()

    @staticmethod
    def _parse_ids(cell):
        """Return ``cell`` unchanged unless it is a string; parse strings as literals."""
        if not isinstance(cell, str):
            return cell
        try:
            return ast.literal_eval(cell)
        except (ValueError, TypeError, SyntaxError) as err:
            raise ValueError(f"Cannot parse token ids from {cell!r}") from err

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as a pair of 1-D ``torch.long`` tensors."""
        inp_tensor = torch.tensor(self.inputs[idx], dtype=torch.long)
        tgt_tensor = torch.tensor(self.targets[idx], dtype=torch.long)
        return inp_tensor, tgt_tensor

def collate_fn(batch, pad_id=0):
    """Turn a list of (input, target) tensor pairs into two padded batch tensors.

    Args:
        batch: Sequence of ``(input_tensor, target_tensor)`` pairs of 1-D tensors.
        pad_id: Value used to fill the shorter sequences on the right.

    Returns:
        ``(inputs, targets)``, each of shape ``[batch, longest_sequence]``.
    """
    src_seqs, tgt_seqs = map(list, zip(*batch))
    pad_options = {"batch_first": True, "padding_value": pad_id}
    src_batch = pad_sequence(src_seqs, **pad_options)
    tgt_batch = pad_sequence(tgt_seqs, **pad_options)
    return src_batch, tgt_batch

BATCH_SIZE = 32

# Input and target are both taken from the "body_ids" column.
train_dataset = MailDataset(df_train, input_col="body_ids", target_col="body_ids")
val_dataset = MailDataset(df_val, input_col="body_ids", target_col="body_ids")


def _pad_batch(batch):
    """Collate a batch using the vocabulary's padding id (read when called)."""
    return collate_fn(batch, pad_id=pad_id)


# Only the training data is shuffled; validation keeps its order.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=_pad_batch
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=_pad_batch
)

# Sanity check: pull one batch and show its dimensions.
print("Beispielhafter Batch:")
sample_inp, sample_tgt = next(iter(train_loader))
print("Input-Shape:", sample_inp.shape, "Target-Shape:", sample_tgt.shape)

# Cell 4: Model definition (seq2seq LSTM for e-mail completion)
class Encoder(nn.Module):
    """Embeds the source tokens and encodes them with a multi-layer LSTM.

    Right-padded batches are packed by their true length before entering the
    LSTM. Without packing, the LSTM keeps updating its state on the trailing
    padding positions, so the final state handed to the decoder would depend
    on how much padding a batch happens to add (and would differ from
    unpadded inference).

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each token embedding.
        hidden_dim: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        pad_idx: Id of the padding token (``None`` disables length handling).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)

    def forward(self, src):
        """Encode ``src``.

        Args:
            src: Token ids, ``[batch, src_len]``.

        Returns:
            ``(outputs, hidden)``: ``outputs`` is ``[batch, src_len, hidden_dim]``
            (zeros at trailing padding positions) and ``hidden`` is the LSTM's
            ``(h_n, c_n)`` tuple taken at each sequence's last real token.
        """
        embedded = self.embedding(src)  # [batch, src_len, embed_dim]
        pad_idx = self.embedding.padding_idx

        # Nothing to pack for unbatched input, empty tensors, or when no
        # padding id is configured: run the LSTM on the raw embeddings.
        if pad_idx is None or src.dim() != 2 or src.numel() == 0:
            return self.lstm(embedded)

        seq_len = src.size(1)
        # True length = 1-based position of the last non-pad token. Counting
        # that way keeps any pad id that appears inside a sequence.
        positions = torch.arange(1, seq_len + 1, device=src.device)
        lengths = ((src != pad_idx) * positions).max(dim=1).values
        # Packing rejects zero-length rows; treat an all-padding row as length 1.
        lengths = lengths.clamp(min=1)

        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_outputs, hidden = self.lstm(packed)
        # total_length keeps the output width equal to src_len for every batch.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=seq_len
        )
        return outputs, hidden

class Decoder(nn.Module):
    """LSTM decoder that maps target tokens plus a recurrent state to vocabulary logits.

    Args:
        vocab_size: Size of the embedding table and of the output layer.
        embed_dim: Size of each token embedding.
        hidden_dim: Hidden size of the LSTM (input size of the output layer).
        num_layers: Number of stacked LSTM layers.
        pad_idx: Id of the padding token, passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, trg, hidden):
        """Run the decoder over ``trg`` starting from ``hidden``.

        Args:
            trg: Token ids, ``[batch, trg_len]``.
            hidden: LSTM state to start from.

        Returns:
            ``(logits, new_hidden)`` with ``logits`` of shape
            ``[batch, trg_len, vocab_size]``.
        """
        lstm_out, new_hidden = self.lstm(self.embedding(trg), hidden)
        logits = self.fc(lstm_out)
        return logits, new_hidden

class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper that decodes step by step with optional teacher forcing.

    Args:
        encoder: Module returning ``(outputs, hidden)`` for a source batch.
        decoder: Module mapping ``(tokens, hidden)`` to ``(logits, hidden)``;
            its ``fc.out_features`` defines the vocabulary size.
        pad_idx: Id of the padding token (stored, not used in ``forward``).
    """

    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Return logits of shape ``[batch, trg_len, vocab_size]``.

        Position 0 of the result is never written and stays all zeros: the
        first target token is only used as the initial decoder input.

        Args:
            src: Source token ids, ``[batch, src_len]``.
            trg: Target token ids, ``[batch, trg_len]``.
            teacher_forcing_ratio: Probability, drawn once per time step for
                the whole batch, of feeding the ground-truth token instead of
                the model's own prediction as the next decoder input.
        """
        n_samples, n_steps = src.size(0), trg.size(1)
        n_classes = self.decoder.fc.out_features
        logits = torch.zeros(n_samples, n_steps, n_classes, device=src.device)

        # Only the encoder's final state is used; its per-step outputs are dropped.
        _, state = self.encoder(src)

        # The first target token seeds the decoder.
        step_input = trg[:, 0].unsqueeze(1)  # [batch, 1]
        for t in range(1, n_steps):
            step_logits, state = self.decoder(step_input, state)
            logits[:, t] = step_logits.squeeze(1)
            use_ground_truth = torch.rand(1).item() < teacher_forcing_ratio
            if use_ground_truth:
                step_input = trg[:, t].unsqueeze(1)
            else:
                step_input = step_logits.argmax(dim=2)  # greedy prediction, [batch, 1]

        return logits

# Encoder and decoder are built with identical hyperparameters.
lstm_config = dict(embed_dim=128, hidden_dim=256, num_layers=2, pad_idx=pad_id)
encoder = Encoder(vocab_size, **lstm_config)
decoder = Decoder(vocab_size, **lstm_config)
model = Seq2Seq(encoder, decoder, pad_id).to(DEVICE)
print(model)

# Cell 5: Training setup
EPOCHS = 10
# Target positions holding the padding id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train_one_epoch(model, loader, optimizer, criterion, device, teacher_forcing_ratio):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    Args:
        model: Called as ``model(src, trg, teacher_forcing_ratio=...)``; must
            return logits of shape ``[batch, trg_len, vocab_size]``.
        loader: Iterable of ``(src, trg)`` batches that also supports ``len()``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss taking ``(flat_logits, flat_targets)``.
        device: Device the batches are moved to.
        teacher_forcing_ratio: Forwarded unchanged to the model.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(loader, desc="Training", leave=False)
    for src, trg in progress:
        src = src.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()
        logits = model(src, trg, teacher_forcing_ratio=teacher_forcing_ratio)

        # Step 0 is dropped on both sides: it is the decoder's start token,
        # which the Seq2Seq model in this file never predicts.
        flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
        flat_targets = trg[:, 1:].reshape(-1)

        batch_loss = criterion(flat_logits, flat_targets)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    return running_loss / len(loader)

def evaluate(model, loader, criterion, device):
    """Compute the mean per-batch loss on ``loader`` without updating the model.

    The model is put into eval mode and called with
    ``teacher_forcing_ratio=0.0``, so it is always asked to decode from its own
    predictions. Gradients are disabled.

    Args:
        model: Called as ``model(src, trg, teacher_forcing_ratio=0.0)``.
        loader: Iterable of ``(src, trg)`` batches that also supports ``len()``.
        criterion: Loss taking ``(flat_logits, flat_targets)``.
        device: Device the batches are moved to.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for batch_src, batch_trg in loader:
            batch_src = batch_src.to(device)
            batch_trg = batch_trg.to(device)
            logits = model(batch_src, batch_trg, teacher_forcing_ratio=0.0)
            n_classes = logits.shape[-1]
            # Skip step 0 (start token) in both predictions and targets.
            step_logits = logits[:, 1:].reshape(-1, n_classes)
            step_targets = batch_trg[:, 1:].reshape(-1)
            loss_sum += criterion(step_logits, step_targets).item()
    return loss_sum / len(loader)

# Cell 6: Training and validation loop
best_val_loss = float("inf")
for epoch in range(1, EPOCHS + 1):
    print(f"\n=== Epoche {epoch}/{EPOCHS} ===")
    train_loss = train_one_epoch(
        model, train_loader, optimizer, criterion, DEVICE, teacher_forcing_ratio=0.5
    )
    val_loss = evaluate(model, val_loader, criterion, DEVICE)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # A checkpoint is written only when the validation loss strictly improves.
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    os.makedirs("models", exist_ok=True)
    # Epoch and loss are part of the file name, so earlier checkpoints are kept.
    checkpoint_path = os.path.join("models", f"mailassist_epoch{epoch}_valloss{val_loss:.4f}.pt")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"** Neues bestes Modell gespeichert: {checkpoint_path}")

print("\nTraining abgeschlossen.")

# Cell 7: Inference demo – e-mail completion
def generate_completion(model, input_ids, word2id, id2word, max_length=20, device="cpu"):
    """Greedily decode a completion for a tokenised e-mail prefix.

    The encoder and decoder are taken from ``model`` itself, so the function
    works for whichever model instance is passed in.

    Args:
        model: Seq2seq model exposing ``encoder`` and ``decoder`` sub-modules.
        input_ids: List of token ids representing the beginning of the e-mail.
        word2id: Mapping from token to id; supplies the ``<BOS>`` / ``<EOS>``
            ids (defaults 2 and 3 when the tokens are missing).
        id2word: Inverse mapping from id to token.
        max_length: Maximum number of generated tokens (not counting the
            initial ``<BOS>``).
        device: Device the tensors are created on, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        The generated tokens joined by single spaces. The string starts with
        the ``<BOS>`` token; ``<EOS>`` is not included. Ids missing from
        ``id2word`` are rendered as ``"<UNK>"``.
    """
    model.eval()
    bos_id = word2id.get("<BOS>", 2)
    eos_id = word2id.get("<EOS>", 3)  # looked up once instead of on every step
    generated_ids = [bos_id]

    # The encoder pass is inside no_grad as well: inference needs no autograd graph.
    with torch.no_grad():
        src_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
        _, hidden = model.encoder(src_tensor)

        # Decoding starts from the <BOS> token and feeds back each prediction.
        input_dec = torch.tensor([[bos_id]], dtype=torch.long, device=device)
        for _ in range(max_length):
            output, hidden = model.decoder(input_dec, hidden)
            next_id = output.argmax(dim=2).item()
            if next_id == eos_id:
                break
            generated_ids.append(next_id)
            input_dec = torch.tensor([[next_id]], dtype=torch.long, device=device)

    return " ".join(id2word.get(token_id, "<UNK>") for token_id in generated_ids)

# Example inference: an unfinished e-mail body serves as the prompt.
example_input = "sehr geehrte damen und herren, ich möchte sie darüber informieren"
# Plain whitespace split for the demo; in practice the same tokenisation as
# in preprocessing should be used.
example_tokens = example_input.split()
# Tokens missing from the vocabulary map to the <UNK> id (1 if that is missing too).
unk_id = word2id.get("<UNK>", 1)
input_ids = [word2id.get(tok, unk_id) for tok in example_tokens]

completion = generate_completion(
    model, input_ids, word2id, id2word, max_length=20, device=DEVICE
)
print("\nBeispiel für E-Mail-Vervollständigung:")
print("Input:", example_input)
print("Vervollständigung:", completion)

# Cell 8: Conclusion & outlook
# NOTE: the text below is static; it is not derived from the metrics computed above.
print("""
Fazit:
- Das Seq2Seq-Modell für die E-Mail-Vervollständigung wurde erfolgreich trainiert.
- Die Evaluierung zeigt einen abnehmenden Validierungs-Loss.
- Erste Inferenz-Demos liefern sinnvolle Vorschläge für E-Mail-Vervollständigungen.
Nächste Schritte:
- Training auf einem größeren Datensatz.
- Experimentiere mit verschiedenen Hyperparametern und Sampling-Methoden.
- Optional: Erweiterung um Transformer-Modelle für noch natürlichere Textvorschläge.
""")
