# ===========================================
# 3_training_demo.ipynb
# Training mit Feedback-basiertem Fine-Tuning (LSTM/Transformer)
# ===========================================

# Zelle 1: Importiere benötigte Bibliotheken und Setup
import ast
import json
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Train on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Trainiere auf Gerät:", DEVICE)

# Input paths: preprocessed train/validation CSVs and the vocabulary JSON,
# all resolved relative to the current working directory.
PROCESSED_DATA_DIR = "../data/processed"
TRAIN_CSV = os.path.join(PROCESSED_DATA_DIR, "train.csv")
VAL_CSV = os.path.join(PROCESSED_DATA_DIR, "val.csv")
VOCAB_JSON = os.path.join(PROCESSED_DATA_DIR, "vocab.json")

# Cell 2: load the data and the vocabulary
# The two frames are only used for the sample counts printed below.
df_train = pd.read_csv(TRAIN_CSV)
df_val = pd.read_csv(VAL_CSV)
with open(VOCAB_JSON, "r", encoding="utf-8") as f:
    word2id = json.load(f)
# Reverse lookup (id -> token) for turning generated ids back into text.
id2word = {v: k for k, v in word2id.items()}

# NOTE(review): using len(word2id) as the embedding size assumes the ids are
# contiguous and start at 0 - confirm against the preprocessing step.
vocab_size = len(word2id)
# Padding id; falls back to 0 when the vocabulary has no "<PAD>" entry.
pad_id = word2id.get("<PAD>", 0)
print(f"Train Samples: {len(df_train)} | Val Samples: {len(df_val)}")
print("Vokabulargröße:", vocab_size)

# Zelle 3: Definiere Dataset-Klasse inklusive Feedback-Label
class FeedbackDataset(Dataset):
    """Dataset of tokenized texts paired with a user-feedback label.

    Args:
        df: DataFrame with one sample per row. It is copied, so the caller's
            frame is left untouched.
        text_col: Name of the column holding the token-id sequence, either as
            a real sequence or as its literal string form such as
            ``"[2, 45, 78, 3]"``. Defaults to ``"token_ids"``.
        label_col: Name of the column holding the feedback label. Values are
            cast to ``int`` (the training code in this file treats ``1`` as
            positive feedback and weights everything else lower).

    Each item is a tuple ``(input_ids, target_ids, label)``. Input and target
    are the same ``torch.long`` tensor, so the model is trained to reproduce
    or complete the given text.

    Raises:
        ValueError, SyntaxError: If a string cell in ``text_col`` is not a
            valid Python literal.
    """

    def __init__(self, df, text_col="token_ids", label_col="feedback_label"):
        self.df = df.copy()
        self.df[text_col] = self.df[text_col].apply(self._parse_token_ids)
        self.df[label_col] = self.df[label_col].astype(int)
        self.texts = self.df[text_col].tolist()
        self.labels = self.df[label_col].tolist()

    @staticmethod
    def _parse_token_ids(value):
        """Convert a CSV cell into a token-id sequence without executing code."""
        if isinstance(value, str):
            # literal_eval only accepts Python literals; unlike eval() it will
            # not run arbitrary expressions embedded in the data file.
            return ast.literal_eval(value)
        return value

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text_ids = self.texts[idx]
        label = self.labels[idx]
        # The same sequence serves as input and target (completion training).
        text_tensor = torch.tensor(text_ids, dtype=torch.long)
        return text_tensor, text_tensor, label  # (input, target, feedback)

def feedback_collate_fn(batch, pad_id=0):
    """Collate ``(input, target, label)`` samples into padded batch tensors.

    Args:
        batch: Sequence of ``(input_tensor, target_tensor, label)`` tuples.
        pad_id: Value used to right-pad shorter sequences.

    Returns:
        ``(inputs, targets, labels)`` where inputs and targets are batch-first
        padded tensors and labels is a float tensor (float so it can be used
        for loss weighting).
    """
    inputs, targets, feedback = zip(*batch)

    def _pad(sequences):
        return pad_sequence(sequences, batch_first=True, padding_value=pad_id)

    padded_inputs = _pad(inputs)
    padded_targets = _pad(targets)
    return padded_inputs, padded_targets, torch.tensor(feedback, dtype=torch.float)

def create_feedback_dataloader(csv_file, batch_size=32, shuffle=True, pad_id=0, text_col="token_ids", label_col="feedback_label"):
    """Build a DataLoader over the feedback samples stored in ``csv_file``.

    Args:
        csv_file: Path of the CSV file to read.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.
        pad_id: Padding value handed to ``feedback_collate_fn``.
        text_col: Column holding the token-id sequences.
        label_col: Column holding the feedback labels.

    Returns:
        A DataLoader yielding ``(inputs, targets, labels)`` batches.
    """
    frame = pd.read_csv(csv_file)
    samples = FeedbackDataset(frame, text_col=text_col, label_col=label_col)

    def _collate(items):
        # Bind the requested padding id for this loader.
        return feedback_collate_fn(items, pad_id=pad_id)

    return DataLoader(
        samples,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=_collate,
    )

# Build both loaders from the CSV files; only the training loader shuffles.
train_loader = create_feedback_dataloader(TRAIN_CSV, batch_size=32, shuffle=True, pad_id=pad_id)
val_loader = create_feedback_dataloader(VAL_CSV, batch_size=32, shuffle=False, pad_id=pad_id)

# Pull a single training batch to sanity-check the tensor shapes.
print("Beispielhafter Batch:")
batch_in, batch_out, batch_labels = next(iter(train_loader))
print("Input-Shape:", batch_in.shape, "Labels-Shape:", batch_labels.shape)

# Zelle 4: Modell definieren (Seq2Seq-LSTM)
class Encoder(nn.Module):
    """Embed a batch of token ids and run it through a multi-layer LSTM.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        pad_idx: Token id passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, src):
        """Encode ``src`` (batch-first token ids).

        Returns:
            The LSTM result ``(outputs, state)``: the per-step outputs and the
            final ``(h_n, c_n)`` state.
        """
        token_vectors = self.embedding(src)
        return self.lstm(token_vectors)

class Decoder(nn.Module):
    """LSTM decoder that maps token ids plus a recurrent state to vocabulary logits.

    Args:
        vocab_size: Number of rows in the embedding table and number of output logits.
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        pad_idx: Token id passed to the embedding as ``padding_idx``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, trg, hidden):
        """Run the decoder on ``trg`` starting from the recurrent state ``hidden``.

        Returns:
            ``(predictions, hidden)``: vocabulary logits for every input
            position and the updated LSTM state.
        """
        token_vectors = self.embedding(trg)
        lstm_out, next_state = self.lstm(token_vectors, hidden)
        return self.fc(lstm_out), next_state

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with optional teacher forcing.

    Args:
        encoder: Module whose forward returns ``(outputs, state)``.
        decoder: Module called as ``decoder(tokens, state)`` that returns
            ``(logits, state)`` and exposes ``fc.out_features``.
        pad_idx: Padding token id; stored on the instance.
    """

    def __init__(self, encoder, decoder, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg.size(1)`` steps and return the logits for every step.

        The decoder starts from the first target token. At each later step one
        random draw decides, for the whole batch, whether the next input is the
        ground-truth token or the decoder's own argmax prediction.

        Returns:
            Tensor of shape ``(batch, trg_len, vocab)``. Position 0 is never
            written and stays zero.
        """
        n_steps = trg.size(1)
        logits = torch.zeros(
            src.size(0),
            n_steps,
            self.decoder.fc.out_features,
            device=src.device,
        )

        # Only the encoder's final state is handed to the decoder.
        _, state = self.encoder(src)
        step_input = trg[:, 0].unsqueeze(1)

        for step in range(1, n_steps):
            step_logits, state = self.decoder(step_input, state)
            logits[:, step] = step_logits.squeeze(1)
            use_ground_truth = torch.rand(1).item() < teacher_forcing_ratio
            if use_ground_truth:
                step_input = trg[:, step].unsqueeze(1)
            else:
                step_input = step_logits.argmax(dim=2)

        return logits

# Encoder and decoder share hidden_dim and num_layers because the Seq2Seq
# wrapper feeds the encoder's final LSTM state directly into the decoder.
encoder = Encoder(vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, pad_idx=pad_id)
decoder = Decoder(vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, pad_idx=pad_id)
model = Seq2Seq(encoder, decoder, pad_id).to(DEVICE)
print(model)

# Cell 5: training setup
# reduction="none" keeps one loss value per token so that train_one_epoch can
# sum them per example and apply the feedback weights; padding targets are ignored.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="none")
optimizer = optim.Adam(model.parameters(), lr=1e-3)
EPOCHS = 5

def train_one_epoch(model, loader, optimizer, criterion, device, teacher_forcing_ratio):
    """Run one training pass over ``loader`` with feedback-weighted losses.

    The token losses of each example are summed, and that sum is multiplied by
    1.0 when the feedback label equals 1 and by 0.5 otherwise. The batch loss
    is the mean of these weighted per-example sums.

    Args:
        model: Seq2Seq model called as ``model(src, trg, teacher_forcing_ratio=...)``.
        loader: Iterable of ``(src, trg, feedback)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss returning one value per token (no reduction).
        device: Device the batch tensors are moved to.
        teacher_forcing_ratio: Forwarded to the model.

    Returns:
        The average of the per-batch weighted losses.
    """
    model.train()
    running_loss = 0.0
    for src, trg, feedback in tqdm(loader, desc="Training", leave=False):
        src = src.to(device)
        trg = trg.to(device)
        feedback = feedback.to(device)  # one label per example
        n_examples = src.size(0)

        optimizer.zero_grad()
        logits = model(src, trg, teacher_forcing_ratio=teacher_forcing_ratio)

        # Drop the first position (<BOS>) and flatten for the criterion.
        flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
        flat_targets = trg[:, 1:].reshape(-1)
        token_losses = criterion(flat_logits, flat_targets)

        # Back to one row per example, then sum over the time axis.
        example_losses = token_losses.view(n_examples, -1).sum(dim=1)

        # Positive feedback keeps full weight; anything else is down-weighted.
        full_weight = torch.tensor(1.0, device=device)
        reduced_weight = torch.tensor(0.5, device=device)
        example_weights = torch.where(feedback == 1, full_weight, reduced_weight)

        batch_loss = (example_losses * example_weights).mean()
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    return running_loss / len(loader)

def evaluate(model, loader, criterion, device):
    """Return the mean loss per non-ignored target token over ``loader``.

    The model is run without teacher forcing and without gradient tracking.
    With a per-token criterion (``reduction="none"``), the losses are summed
    over every position whose target differs from ``criterion.ignore_index``
    and divided by the number of such positions across the whole loader.
    Padding therefore no longer dilutes the result, and batches of different
    size are weighted by their token count.

    Args:
        model: Seq2Seq model called as ``model(src, trg, teacher_forcing_ratio=0.0)``.
        loader: Iterable of ``(src, trg, feedback)`` batches; feedback is unused.
        criterion: Loss function. If it already returns a scalar, the scalar
            batch losses are averaged over the batches instead.
        device: Device the batch tensors are moved to.

    Returns:
        The average loss as a float, or ``nan`` when the loader yields no
        batches or no non-ignored target tokens.
    """
    model.eval()
    # -100 is the PyTorch default for losses that support ignore_index.
    ignore_index = getattr(criterion, "ignore_index", -100)
    loss_sum = 0.0
    token_count = 0
    scalar_losses = []
    with torch.no_grad():
        for src, trg, _ in loader:
            src, trg = src.to(device), trg.to(device)
            outputs = model(src, trg, teacher_forcing_ratio=0.0)
            # Drop the first position (<BOS>) and flatten for the criterion.
            logits = outputs[:, 1:].reshape(-1, outputs.shape[-1])
            targets = trg[:, 1:].reshape(-1)
            loss = criterion(logits, targets)
            if loss.dim() == 0:
                # The criterion reduced the batch itself; nothing to re-weight.
                scalar_losses.append(loss.item())
                continue
            valid = targets != ignore_index
            loss_sum += loss[valid].sum().item()
            token_count += int(valid.sum().item())
    if scalar_losses:
        return sum(scalar_losses) / len(scalar_losses)
    if token_count == 0:
        return float("nan")
    return loss_sum / token_count

# Cell 6: training and validation loop
# NOTE(review): train_loss is a feedback-weighted per-example sum while
# val_loss is a token-level average, so the two printed values are on
# different scales.
best_val_loss = float("inf")
for epoch in range(1, EPOCHS + 1):
    print(f"\n=== Epoche {epoch}/{EPOCHS} ===")
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE, teacher_forcing_ratio=0.5)
    val_loss = evaluate(model, val_loader, criterion, DEVICE)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    # Save a checkpoint only when the validation loss improves on the best so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Epoch and loss are part of the file name, so earlier checkpoints are kept.
        checkpoint_path = os.path.join("models", f"feedback_ai_epoch{epoch}_valloss{val_loss:.4f}.pt")
        os.makedirs("models", exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)
        print(f"** Neues bestes Modell gespeichert: {checkpoint_path}")

print("\nTraining abgeschlossen.")

# Zelle 7: Inferenz-Demo – Generierung eines Feedback-gesteuerten Textes
def generate_feedback_text(model, input_ids, word2id, id2word, max_length=20, device="cpu"):
    """Greedily decode a completion for ``input_ids`` with the given model.

    The input is encoded with ``model.encoder``; decoding then starts from the
    ``<BOS>`` token and repeatedly feeds the argmax prediction back into
    ``model.decoder`` until ``<EOS>`` is predicted or ``max_length`` tokens
    have been produced.

    Args:
        model: Model exposing ``encoder`` and ``decoder`` submodules. These are
            used instead of any module-level encoder/decoder, so the function
            works for whichever model is passed in.
        input_ids: Token ids of the prompt.
        word2id: Token-to-id mapping; ``<BOS>`` and ``<EOS>`` fall back to 2
            and 3 when missing.
        id2word: Id-to-token mapping; unknown ids are rendered as ``<UNK>``.
        max_length: Maximum number of decoding steps.
        device: Device on which the tensors are created.

    Returns:
        The generated tokens joined by spaces, starting with the ``<BOS>``
        token and excluding ``<EOS>``.
    """
    model.eval()
    bos_id = word2id.get("<BOS>", 2)
    eos_id = word2id.get("<EOS>", 3)
    generated_ids = [bos_id]
    # Inference only: keep both the encoder and decoder passes out of autograd.
    with torch.no_grad():
        src_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
        _, hidden = model.encoder(src_tensor)
        input_dec = torch.tensor([[bos_id]], dtype=torch.long, device=device)
        for _ in range(max_length):
            output, hidden = model.decoder(input_dec, hidden)
            next_id = output.argmax(dim=2).item()
            if next_id == eos_id:
                break
            generated_ids.append(next_id)
            input_dec = torch.tensor([[next_id]], dtype=torch.long, device=device)
    return " ".join(id2word.get(i, "<UNK>") for i in generated_ids)

# Example inference:
example_input = "sehr geehrte damen und herren, ich möchte sie darüber informieren"
# Tokenization: plain whitespace split (in practice, use the same tokenization as in preprocessing)
example_tokens = example_input.strip().lower().split()
# Tokens missing from the vocabulary map to the "<UNK>" id, or to 1 if there is no such entry.
input_ids = [word2id.get(token, word2id.get("<UNK>", 1)) for token in example_tokens]

completion = generate_feedback_text(model, input_ids, word2id, id2word, max_length=20, device=DEVICE)
print("\nBeispiel generierter Textvorschlag:")
print("Input:", example_input)
print("Vervollständigung:", completion)

# Cell 8: conclusion and outlook (static text; nothing here is computed from the run)
print("""
Fazit:
- Das Modell wurde mit Feedback-basiertem Fine-Tuning trainiert.
- Positive Feedback-Beispiele wurden höher gewichtet als negative.
- Erste Inferenz-Demos zeigen, dass das Modell bevorzugt Texte generiert, die positives Feedback erhalten.
Nächste Schritte:
- Training auf einem größeren Datensatz.
- Erweiterung auf Transformer-basierte Ansätze (z.B. GPT-2/T5) mit RLHF.
- Analyse der Korrelation zwischen bestimmten Textmustern und dem Nutzerfeedback.
""")
