# ===========================================
# 3_training_demo.ipynb
# ===========================================

# Zelle 1: Bibliotheken & Setup
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
import json

# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Gerät:", DEVICE)

# Directory (relative to the notebook) that holds the preprocessed CSV/vocab files.
PROCESSED_DATA_DIR = "../data/processed"

# ===========================================
# Zelle 2: Daten laden
# ===========================================
# Resolve the three input files inside the processed-data directory.
train_csv, val_csv, vocab_json = (
    os.path.join(PROCESSED_DATA_DIR, file_name)
    for file_name in ("train.csv", "val.csv", "vocab.json")
)

# Load the training and validation tables.
train_df, val_df = pd.read_csv(train_csv), pd.read_csv(val_csv)

# The vocabulary file is a JSON object that is loaded into word2id;
# id2word is the inverse lookup built from it.
with open(vocab_json, "r", encoding="utf-8") as vocab_file:
    word2id = json.load(vocab_file)
id2word = dict(zip(word2id.values(), word2id.keys()))

vocab_size = len(word2id)
print(f"Train Samples: {len(train_df)} | Val Samples: {len(val_df)}")
print("Vokabulargröße:", vocab_size)

# Assumption: train_df/val_df have the columns user_ids and bot_ids, each holding
# a list of ints serialised as a string such as "[2,45,78]".
# NOTE: `display` is provided by IPython/Jupyter; it is not defined in a plain script.
display(train_df.head(3))

# ===========================================
# Zelle 3: Dataset & DataLoader
# ===========================================
class ChatDataset(Dataset):
    """
    Dataset of (user, bot) token-id sequence pairs taken from a DataFrame.

    Each row is one user->bot pair. The two id columns may contain either
    Python lists or their string form (e.g. "[2, 45, 78]"); strings are
    parsed with ``ast.literal_eval``, which only accepts Python literals and
    therefore cannot execute code embedded in the CSV.

    Args:
        df: DataFrame holding the id columns. It is NOT modified; the dataset
            works on its own copy, available as ``self.df``.
        user_col: Name of the column with the user token ids.
        bot_col: Name of the column with the bot token ids.

    Raises:
        ValueError / SyntaxError: if a string cell is not a valid Python literal.
    """
    def __init__(self, df, user_col="user_ids", bot_col="bot_ids"):
        # Function-scope import: the file does not import `ast` at module level.
        from ast import literal_eval

        def parse_ids(cell):
            # Cells that are already parsed (non-strings) pass through unchanged.
            return literal_eval(cell) if isinstance(cell, str) else cell

        # Work on a copy so the caller's DataFrame keeps its original columns.
        self.df = df.copy()
        self.user_col = user_col
        self.bot_col = bot_col

        # Convert the strings (e.g. "[2, 45, 78]") into Python lists.
        self.df[self.user_col] = self.df[self.user_col].apply(parse_ids)
        self.df[self.bot_col] = self.df[self.bot_col].apply(parse_ids)

        # Plain lists give cheap positional access in __getitem__.
        self.user_samples = self.df[self.user_col].tolist()
        self.bot_samples = self.df[self.bot_col].tolist()

    def __len__(self):
        """Return the number of user->bot pairs."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the pair at position ``idx`` as two 1-D ``torch.long`` tensors (user, bot)."""
        user_ids = self.user_samples[idx]
        bot_ids = self.bot_samples[idx]

        # Input = user sequence, target = bot sequence.
        user_tensor = torch.tensor(user_ids, dtype=torch.long)
        bot_tensor = torch.tensor(bot_ids, dtype=torch.long)

        return user_tensor, bot_tensor


def collate_fn(batch, pad_id=0):
    """
    Collate a list of (user_tensor, bot_tensor) pairs into two padded batches.

    The user sequences are right-padded with ``pad_id`` to the longest user
    sequence in the batch, and the bot sequences to the longest bot sequence;
    the two results are padded independently and may differ in length.

    Returns:
        (user_padded, bot_padded), each with the batch dimension first.
    """
    from torch.nn.utils.rnn import pad_sequence

    def _pad(sequences):
        # batch_first=True -> result is [batch, longest_sequence]
        return pad_sequence(list(sequences), batch_first=True, padding_value=pad_id)

    # Transpose the list of pairs into one tuple per side.
    user_side, bot_side = zip(*batch)
    return _pad(user_side), _pad(bot_side)

BATCH_SIZE = 32

# Wrap both DataFrames in datasets that yield (user, bot) tensor pairs.
train_dataset = ChatDataset(train_df)
val_dataset = ChatDataset(val_df)


def _collate_with_vocab_pad(samples):
    """Pad a batch with the vocabulary's <PAD> id (0 if the vocab has none)."""
    return collate_fn(samples, pad_id=word2id.get("<PAD>", 0))


# Training batches are reshuffled every epoch; validation keeps file order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=_collate_with_vocab_pad,
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=_collate_with_vocab_pad,
)

# ===========================================
# Zelle 4: Modell definieren (z. B. LSTM)
# ===========================================
class LSTMChatModel(nn.Module):
    """
    Simple LSTM model for chat: embedding -> LSTM -> linear projection.

    It consumes a batch of user token ids and emits one vocabulary-sized
    logit vector for every input position.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=1, pad_idx=0, dropout=0.0):
        super().__init__()
        # Dropout is only passed on for stacked LSTMs; a single layer gets 0.0.
        lstm_dropout = dropout if num_layers > 1 else 0.0

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, user_seq, hidden=None):
        """
        Args:
            user_seq: token ids, [batch, seq_len].
            hidden: optional initial LSTM state; None lets the LSTM start from zeros.

        Returns:
            (logits, hidden): logits are [batch, seq_len, vocab_size]; hidden is
            the LSTM state after the last step.
        """
        embedded = self.embedding(user_seq)                 # [batch, seq_len, embed_dim]
        lstm_out, next_hidden = self.lstm(embedded, hidden)  # [batch, seq_len, hidden_dim]
        return self.fc(lstm_out), next_hidden

# Build a two-layer LSTM chat model sized to the vocabulary ...
model = LSTMChatModel(
    vocab_size=vocab_size, embed_dim=128, hidden_dim=256,
    num_layers=2, pad_idx=word2id.get("<PAD>", 0), dropout=0.1,
)
# ... and move its parameters to the selected device.
model = model.to(DEVICE)

print(model)

# ===========================================
# Cell 5: Training setup
# ===========================================
# Positions whose target is the <PAD> id do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word2id.get("<PAD>", 0))
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
EPOCHS = 3

def train_one_epoch(model, loader):
    """
    Run one optimisation pass over ``loader`` and return the mean batch loss.

    Uses the module-level ``criterion``, ``optimizer`` and ``DEVICE``.

    The model emits one logit vector per *user* position, while the loader
    pads bot sequences independently, so the two time dimensions can differ.
    The targets are therefore aligned to the logits first: longer bot
    sequences are truncated, shorter ones are filled with
    ``criterion.ignore_index`` so the filler positions add nothing to the loss.
    Without this step the loss call raises a batch-size mismatch whenever the
    two padded lengths differ.

    Args:
        model: callable returning ``(logits, hidden)`` with logits shaped
            [batch, seq_len, vocab_size].
        loader: iterable of ``(user_batch, bot_batch)`` tensor pairs.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.train()
    total_loss = 0
    for user_batch, bot_batch in tqdm(loader, desc="Train", leave=False):
        user_batch = user_batch.to(DEVICE)
        bot_batch = bot_batch.to(DEVICE)

        logits, _ = model(user_batch)  # [batch, user_seq_len, vocab_size]

        # Simplest demo approach: predict the bot sequence position by position
        # over the user sequence length. (Better: concat user_seq + shifted bot_seq.)
        seq_len = logits.size(1)
        if bot_batch.size(1) != seq_len:
            aligned = bot_batch.new_full((bot_batch.size(0), seq_len), criterion.ignore_index)
            keep = min(seq_len, bot_batch.size(1))
            aligned[:, :keep] = bot_batch[:, :keep]
            bot_batch = aligned

        flat_logits = logits.reshape(-1, logits.size(-1))  # [batch*seq_len, vocab_size]
        flat_targets = bot_batch.reshape(-1)               # [batch*seq_len]

        loss = criterion(flat_logits, flat_targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader):
    """
    Compute the mean batch loss of ``model`` over ``loader`` without gradients.

    Uses the module-level ``criterion`` and ``DEVICE``.

    The model emits one logit vector per *user* position, while the loader
    pads bot sequences independently, so the targets are first aligned to the
    logits' time dimension: longer bot sequences are truncated, shorter ones
    are filled with ``criterion.ignore_index`` (ignored by the loss). Without
    this step the loss call raises a batch-size mismatch whenever the two
    padded lengths differ.

    Args:
        model: callable returning ``(logits, hidden)`` with logits shaped
            [batch, seq_len, vocab_size].
        loader: iterable of ``(user_batch, bot_batch)`` tensor pairs.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.
    """
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for user_batch, bot_batch in loader:
            user_batch = user_batch.to(DEVICE)
            bot_batch = bot_batch.to(DEVICE)

            logits, _ = model(user_batch)  # [batch, user_seq_len, vocab_size]

            # Align the targets with the number of predicted positions.
            seq_len = logits.size(1)
            if bot_batch.size(1) != seq_len:
                aligned = bot_batch.new_full((bot_batch.size(0), seq_len), criterion.ignore_index)
                keep = min(seq_len, bot_batch.size(1))
                aligned[:, :keep] = bot_batch[:, :keep]
                bot_batch = aligned

            # reshape (not view) also handles non-contiguous tensors.
            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_targets = bot_batch.reshape(-1)

            loss = criterion(flat_logits, flat_targets)
            total_loss += loss.item()
    return total_loss / len(loader)

# ===========================================
# Zelle 6: Training starten
# ===========================================
# Track the lowest validation loss seen so far; every improvement is checkpointed.
best_val_loss = float("inf")
for epoch in range(1, EPOCHS + 1):
    print(f"\n=== Epoche {epoch}/{EPOCHS} ===")
    train_loss = train_one_epoch(model, train_loader)
    val_loss = evaluate(model, val_loader)
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

    # Skip the checkpoint unless validation strictly improved
    # (written as `not <` so a NaN loss is never treated as an improvement).
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    # File name carries the epoch and the loss, so earlier checkpoints are kept.
    checkpoint_path = f"chatpal_best_epoch{epoch}_valloss{val_loss:.4f}.pt"
    torch.save(model.state_dict(), checkpoint_path)
    print("** Neues bestes Modell gespeichert")

# ===========================================
# Zelle 7: Kurzer Inferenz-Test (simple)
# ===========================================
def sample_token(logits, temperature=1.0):
    """
    Pick a token id from a logit tensor.

    The selection is greedy: the index of the largest (temperature-scaled)
    logit is returned as a Python int. The temperature is clamped to at
    least 1e-8 before dividing, which avoids division by zero.
    """
    divisor = max(temperature, 1e-8)
    best_index = torch.argmax(logits / divisor)
    return best_index.item()

def respond(model, user_tokens, max_len=20):
    """
    Very simplified "generation" used only to illustrate the idea.

    The user tokens are fed through the model once, and for each of the first
    ``min(max_len, len(user_tokens))`` input positions the greedy token from
    ``sample_token`` is collected. No bot token is fed back into the model;
    a real chat setup would need a decoder or a concatenated
    user+bot sequence (e.g. generating after a <BOS_BOT> marker).

    Args:
        model: callable returning ``(logits, hidden)``.
        user_tokens: list of token ids.
        max_len: upper bound on the number of returned ids.

    Returns:
        List of predicted token ids, one per considered input position.
    """
    model.eval()
    # Add a batch dimension of size 1 and move the ids to the model's device.
    batch = torch.tensor([user_tokens], dtype=torch.long).to(DEVICE)
    with torch.no_grad():
        step_logits, _ = model(batch)  # [1, seq_len, vocab_size]
        n_steps = min(max_len, step_logits.shape[1])
        return [
            sample_token(step_logits[0, step, :], temperature=1.0)
            for step in range(n_steps)
        ]

test_user_text = "hello"
# Whitespace-tokenise the text; words missing from the vocab map to <UNK> (id 1 as fallback).
test_user_tokens = [
    word2id.get(word, word2id.get("<UNK>", 1))
    for word in test_user_text.split()
]
# Wrap the user turn in its begin/end markers:
# <BOS_USER> + user_tokens + <EOS_USER>
BOS_USER_ID = word2id.get("<BOS_USER>", 2)
EOS_USER_ID = word2id.get("<EOS_USER>", 3)
test_user_tokens = [BOS_USER_ID, *test_user_tokens, EOS_USER_ID]

# Predict ids and map them back to words; ids missing from id2word print as <UNK>.
gen_bot_ids = respond(model, test_user_tokens, max_len=10)
gen_bot_tokens = [id2word.get(token_id, "<UNK>") for token_id in gen_bot_ids]
print("\nBeispiel Inferenz: ")
print("User:", test_user_text)
print("Bot :", " ".join(gen_bot_tokens))

# ===========================================
# Zelle 8: Fazit
# ===========================================
# Print the notebook's closing summary (German, user-facing text):
# a simple LSTM was trained, and the inference procedure is still very simplified.
print("""
In diesem Notebook haben wir ein einfaches LSTM-Modell trainiert, 
das bei Eingabe von user_tokens ein Bot-Vorhersage-Logits liefert.
Das Inferenzverfahren ist noch sehr vereinfacht; 
in einem richtigen Chat-Szenario würde man 
auch Bot-Token schrittweise decodieren 
und den gesamten Kontext (User + Bot Historie) beachten.
""")
