# ===========================================
# 3_training_demo.ipynb
# Training von Q&A-Modellen (Reader) für AskMeNow
# ===========================================

# Zelle 1: Importiere benötigte Bibliotheken und Setup
import os
import json
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Trainiere auf Gerät:", DEVICE)

# Cell 2: define the file locations, then load the data frames and the vocabulary
PROCESSED_DATA_DIR = os.path.join("..", "data", "processed")
# train.csv is expected to provide the columns "question_ids" and "answer_ids".
TRAIN_CSV, VAL_CSV, VOCAB_JSON = (
    os.path.join(PROCESSED_DATA_DIR, file_name)
    for file_name in ("train.csv", "val.csv", "vocab.json")
)

# Training and validation data
df_train, df_val = pd.read_csv(TRAIN_CSV), pd.read_csv(VAL_CSV)

# Token -> id mapping read from disk, plus the inverse id -> token mapping
with open(VOCAB_JSON, "r", encoding="utf-8") as f:
    word2id = json.load(f)
id2word = {token_id: token for token, token_id in word2id.items()}

# Falls back to id 0 when the vocabulary has no "<PAD>" entry.
pad_id = word2id.get("<PAD>", 0)
vocab_size = len(word2id)
print(f"Train Samples: {len(df_train)} | Val Samples: {len(df_val)}")
print("Vokabulargröße:", vocab_size)
display(df_train.head(3))

# Zelle 3: Dataset-Klasse für Q&A (Fragen und Antworten)
class QADataset(Dataset):
    """
    Dataset of tokenised question/answer pairs.

    Expects a DataFrame with the columns:
      - "question_ids": token ids of the question, either a list or its
        string representation (e.g. "[2, 45, 78, 3]")
      - "answer_ids": token ids of the answer, in the same format
        (e.g. "[2, 56, 90, 3]")

    String cells are parsed with ``ast.literal_eval``, which only accepts
    Python literals. Unlike ``eval`` it cannot execute code embedded in the
    CSV content; a cell that is not a literal raises ``ValueError`` or
    ``SyntaxError`` at construction time.
    """

    def __init__(self, df, question_col="question_ids", answer_col="answer_ids"):
        """Copy ``df`` and convert both id columns into real Python lists.

        Args:
            df: DataFrame holding the two id columns; it is not modified.
            question_col: Name of the column with the question token ids.
            answer_col: Name of the column with the answer token ids.
        """
        self.df = df.copy()
        # Convert string representations into real lists; other values pass through.
        self.df[question_col] = self.df[question_col].apply(self._parse_ids)
        self.df[answer_col] = self.df[answer_col].apply(self._parse_ids)
        self.questions = self.df[question_col].tolist()
        self.answers = self.df[answer_col].tolist()

    @staticmethod
    def _parse_ids(cell):
        """Return ``cell`` parsed as a Python literal if it is a string, else unchanged."""
        # Local import so the class does not depend on a file-level import of ast.
        import ast

        return ast.literal_eval(cell) if isinstance(cell, str) else cell

    def __len__(self):
        """Return the number of question/answer pairs."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the pair at ``idx`` as two 1-D ``torch.long`` tensors (question, answer)."""
        question_tensor = torch.tensor(self.questions[idx], dtype=torch.long)
        answer_tensor = torch.tensor(self.answers[idx], dtype=torch.long)
        return question_tensor, answer_tensor

def collate_fn(batch, pad_id=0):
    """Pad a batch of (question, answer) tensor pairs to common lengths.

    Args:
        batch: Sequence of ``(question_tensor, answer_tensor)`` pairs.
        pad_id: Value used to fill the shorter sequences.

    Returns:
        Tuple ``(questions_padded, answers_padded)``, each batch-first and
        padded to the longest sequence of its own group.
    """

    def pad(sequences):
        return pad_sequence(sequences, batch_first=True, padding_value=pad_id)

    # Transpose the list of pairs into one group of questions and one of answers.
    question_seqs, answer_seqs = zip(*batch)
    return pad(question_seqs), pad(answer_seqs)

def create_qa_dataloader(csv_file, batch_size=32, shuffle=True, pad_id=0, question_col="question_ids", answer_col="answer_ids", num_workers=0):
    """Build a DataLoader over the question/answer pairs stored in a CSV file.

    Args:
        csv_file: Path of the CSV file to read.
        batch_size: Number of pairs per batch.
        shuffle: Passed through to ``DataLoader``.
        pad_id: Padding value used when batching sequences of unequal length.
        question_col: Name of the column holding the question token ids.
        answer_col: Name of the column holding the answer token ids.
        num_workers: Number of loader worker processes (0 loads in the main
            process, which is the previous behaviour).

    Returns:
        A ``DataLoader`` yielding ``(questions_padded, answers_padded)`` batches.
    """
    # Local import keeps this function independent of the file-level imports.
    from functools import partial

    dataset = QADataset(pd.read_csv(csv_file), question_col=question_col, answer_col=answer_col)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # partial instead of a lambda: a lambda cannot be pickled, which breaks
        # worker processes that are started with the "spawn" method.
        collate_fn=partial(collate_fn, pad_id=pad_id),
    )

# Build the loaders: shuffling is enabled for training and disabled for validation.
train_loader = create_qa_dataloader(
    TRAIN_CSV,
    batch_size=32,
    shuffle=True,
    pad_id=pad_id,
)
val_loader = create_qa_dataloader(
    VAL_CSV,
    batch_size=32,
    shuffle=False,
    pad_id=pad_id,
)

# Sanity check: fetch a single training batch and show the padded shapes.
print("Beispielhafter Batch:")
q_batch, a_batch = next(iter(train_loader))
print("Questions-Shape:", q_batch.shape, "Answers-Shape:", a_batch.shape)

# Zelle 4: Modell definieren (Seq2Seq-LSTM für Q&A)
class Encoder(nn.Module):
    """Embeds the source token ids and feeds them through a batch-first LSTM."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, src):
        """Return ``(outputs, hidden)`` exactly as produced by the LSTM.

        The embedding is [Batch, src_len, embed_dim] and the LSTM outputs are
        [Batch, src_len, hidden_dim].
        """
        return self.lstm(self.embedding(src))

class Decoder(nn.Module):
    """Embeds target token ids, runs the LSTM and projects onto the vocabulary."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, trg, hidden):
        """Return ``(predictions, hidden)`` for the given target ids and LSTM state.

        The embedding is [Batch, trg_len, embed_dim], the LSTM outputs are
        [Batch, trg_len, hidden_dim] and the predictions are
        [Batch, trg_len, vocab_size].
        """
        lstm_out, hidden = self.lstm(self.embedding(trg), hidden)
        return self.fc(lstm_out), hidden

class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper that decodes one step at a time.

    ``pad_idx`` is stored on the instance but is not used by ``forward``.
    """

    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg.size(1) - 1`` steps and return the per-step scores.

        Args:
            src: Source token ids, batch-first.
            trg: Target token ids, batch-first; column 0 is the first decoder input.
            teacher_forcing_ratio: Probability, drawn once per step for the whole
                batch, of feeding the ground-truth token instead of the
                decoder's own arg-max prediction.

        Returns:
            Tensor [Batch, trg_len, vocab_size]; position 0 is never written
            and therefore stays all zeros.
        """
        n_steps = trg.size(1)
        logits = torch.zeros(src.size(0), n_steps, self.decoder.fc.out_features).to(src.device)
        _, state = self.encoder(src)

        # First decoder input: the first token of trg (the <BOS> token), [Batch, 1]
        step_input = trg[:, 0].unsqueeze(1)
        for step in range(1, n_steps):
            step_logits, state = self.decoder(step_input, state)
            logits[:, step] = step_logits.squeeze(1)
            use_ground_truth = torch.rand(1).item() < teacher_forcing_ratio
            if use_ground_truth:
                step_input = trg[:, step].unsqueeze(1)
            else:
                step_input = step_logits.argmax(dim=2)  # [Batch, 1]
        return logits

# Encoder and decoder get the same hyper-parameters; the assembled model is
# moved to the selected device.
encoder_model = Encoder(
    vocab_size,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    pad_idx=pad_id,
)
decoder_model = Decoder(
    vocab_size,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    pad_idx=pad_id,
)
model = Seq2Seq(encoder_model, decoder_model, pad_id).to(DEVICE)
print(model)

# Cell 5: training setup
EPOCHS = 10
# Targets equal to pad_id are excluded from the loss via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train_one_epoch(model, loader, optimizer, criterion, device, teacher_forcing_ratio):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Callable as ``model(questions, answers, teacher_forcing_ratio=...)``.
        loader: Iterable of ``(questions, answers)`` batches; must support ``len``.
        optimizer: Optimizer that updates the model parameters.
        criterion: Loss taking flattened scores and flattened targets.
        device: Device the batches are moved to.
        teacher_forcing_ratio: Forwarded unchanged to the model.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0
    progress = tqdm(loader, desc="Training", leave=False)
    for src_batch, trg_batch in progress:
        src_batch = src_batch.to(device)
        trg_batch = trg_batch.to(device)

        optimizer.zero_grad()
        logits = model(src_batch, trg_batch, teacher_forcing_ratio=teacher_forcing_ratio)

        # Position 0 is skipped on both sides; the loss covers positions 1..end.
        flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
        flat_targets = trg_batch[:, 1:].reshape(-1)

        loss = criterion(flat_logits, flat_targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)

def evaluate(model, loader, criterion, device):
    """Return the mean batch loss over ``loader`` without updating the model.

    The model is switched to eval mode and called with
    ``teacher_forcing_ratio=0.0``; gradients are disabled for the whole pass.

    Args:
        model: Callable as ``model(questions, answers, teacher_forcing_ratio=...)``.
        loader: Iterable of ``(questions, answers)`` batches; must support ``len``.
        criterion: Loss taking flattened scores and flattened targets.
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for src_batch, trg_batch in loader:
            src_batch = src_batch.to(device)
            trg_batch = trg_batch.to(device)
            logits = model(src_batch, trg_batch, teacher_forcing_ratio=0.0)
            # Position 0 is skipped on both sides; the loss covers positions 1..end.
            flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
            flat_targets = trg_batch[:, 1:].reshape(-1)
            batch_losses.append(criterion(flat_logits, flat_targets).item())
    return sum(batch_losses) / len(loader)

# Cell 6: training and validation loop
best_val_loss = float("inf")
for epoch in range(1, EPOCHS + 1):
    print(f"\n=== Epoche {epoch}/{EPOCHS} ===")
    train_loss = train_one_epoch(
        model,
        train_loader,
        optimizer,
        criterion,
        DEVICE,
        teacher_forcing_ratio=0.5,
    )
    val_loss = evaluate(model, val_loader, criterion, DEVICE)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # A checkpoint is written only when the validation loss strictly improves.
    if not val_loss < best_val_loss:
        continue
    best_val_loss = val_loss
    checkpoint_path = os.path.join("models", f"askmenow_epoch{epoch}_valloss{val_loss:.4f}.pt")
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)
    print(f"** Neues bestes Modell gespeichert: {checkpoint_path}")

print("\nTraining abgeschlossen.")

# Zelle 7: Inferenz-Demo – Generierung einer Antwort
def generate_answer(model, question_ids, word2id, id2word, max_length=20, device="cpu"):
    """Greedily decode an answer for a tokenised question.

    The encoder and decoder are taken from ``model`` itself (``model.encoder``
    and ``model.decoder``), so the function works for whichever model is
    passed in instead of depending on module-level globals.

    Args:
        model: Model exposing ``encoder`` and ``decoder`` sub-modules.
        question_ids: List of token ids of the question.
        word2id: Token -> id mapping; "<BOS>" and "<EOS>" fall back to the
            ids 2 and 3 when they are missing.
        id2word: Id -> token mapping; unknown ids are rendered as "<UNK>".
        max_length: Maximum number of decoding steps.
        device: Device on which the input tensors are created.

    Returns:
        The generated tokens joined by single spaces. The string starts with
        the BOS token; the EOS token is not included.
    """
    model.eval()
    bos_id = word2id.get("<BOS>", 2)
    eos_id = word2id.get("<EOS>", 3)  # looked up once instead of on every step
    generated_ids = [bos_id]
    # The encoder pass runs under no_grad as well, so inference builds no autograd graph.
    with torch.no_grad():
        src_tensor = torch.tensor([question_ids], dtype=torch.long, device=device)
        _, hidden = model.encoder(src_tensor)
        # Initial decoder input: the <BOS> token
        input_dec = torch.tensor([[bos_id]], dtype=torch.long, device=device)
        for _ in range(max_length):
            output, hidden = model.decoder(input_dec, hidden)
            next_id = output.argmax(dim=2).item()
            if next_id == eos_id:
                break
            generated_ids.append(next_id)
            input_dec = torch.tensor([[next_id]], dtype=torch.long, device=device)
    return " ".join(id2word.get(i, "<UNK>") for i in generated_ids)

# Example inference
example_question = "Wie funktioniert ein neuronales Netz?"
# Tokenisation by plain whitespace splitting, so punctuation stays attached to
# the neighbouring word (in practice, use the same logic as the preprocessing).
example_tokens = example_question.strip().lower().split()
# Tokens missing from the vocabulary map to the "<UNK>" id (1 if that entry is absent too).
question_ids = [
    word2id.get(tok, word2id.get("<UNK>", 1))
    for tok in example_tokens
]

generated_answer = generate_answer(
    model,
    question_ids,
    word2id,
    id2word,
    max_length=20,
    device=DEVICE,
)
print("\nBeispiel generierte Antwort:")
print("Frage:", example_question)
print("Antwort:", generated_answer)

# Cell 8: summary and outlook
# NOTE: the text below is a fixed literal. It is printed as-is and is not
# derived from the losses or the generated answer computed earlier in this file.
print("""
Fazit:
- Das Q&A-Modell (Reader) wurde erfolgreich auf den FAQ-Daten trainiert.
- Die Evaluierung zeigt einen kontinuierlich sinkenden Validierungs-Loss.
- Erste Inferenz-Demos liefern sinnvolle Antworten auf gestellte Fragen.
Nächste Schritte:
- Training auf größeren und diversifizierten Datensätzen.
- Erweiterung des Modells um Transformer-basierte Ansätze.
- Integration von Retrieval-Komponenten zur Verbesserung der Antwortgenauigkeit.
""")
