In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer, AdamW
import json
import logging
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# ====== Hyperparameters ======
MODEL_NAME = "t5-base"  # could also try "t5-large"
BATCH_SIZE = 8
EPOCHS = 5
LEARNING_RATE = 5e-5
MAX_LENGTH = 1024  # token limit used by SQLDataset when tokenizing prompts and targets
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42  # makes DataLoader shuffling and dropout repeatable across runs

# Seed torch's global RNG (CPU and CUDA) before any model/DataLoader is created.
# Some GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)

In [None]:
# ====== Load the preprocessed datasets ======
def _read_json(path):
    """Parse a UTF-8 JSON file and return the decoded object."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


train_data = _read_json("datasets/processed_train.json")
val_data = _read_json("datasets/processed_dev.json")

In [None]:
# ====== Dataset definition ======
class SQLDataset(Dataset):
    """Map-style dataset that tokenizes text-to-SQL examples for T5 fine-tuning.

    Each element of ``data`` must be a mapping with ``"schema"``, ``"question"``
    and ``"sql_query"`` keys (the keys read in ``__getitem__``).

    Args:
        data: Indexable sequence of example dicts (e.g. a list loaded from JSON).
        tokenizer: Hugging Face tokenizer applied to both the prompt and the target.
        max_length: Token limit for the encoder prompt; also used for the target
            unless ``max_target_length`` is given.
        max_target_length: Optional separate token limit for the target SQL.
            ``None`` (default) reuses ``max_length``, matching the previous behaviour.
    """

    def __init__(self, data, tokenizer, max_length=MAX_LENGTH, max_target_length=None):
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        self.max_target_length = max_length if max_target_length is None else max_target_length

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example ``idx``.

        Returns:
            dict with 1-D tensors ``input_ids`` and ``attention_mask`` (length
            ``max_length``) and ``labels`` (length ``max_target_length``) whose
            padding positions are set to -100.
        """
        item = self.data[idx]
        # NOTE(review): the stock T5 SentencePiece tokenizer is believed to fold "\n"
        # into plain whitespace and to lack some characters common in SQL (e.g. "<");
        # verify that tokenizer.decode(...) round-trips real examples.
        input_text = f"Schema: {item['schema']}\nQuery: {item['question']}\nSQL:"

        input_encodings = self.tokenizer(
            input_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        target_encodings = self.tokenizer(
            item["sql_query"], max_length=self.max_target_length, padding="max_length", truncation=True, return_tensors="pt"
        )

        # squeeze(0) removes only the batch axis added by return_tensors="pt".
        labels = target_encodings["input_ids"].squeeze(0)
        # Padding must not contribute to the loss, otherwise training is dominated by
        # predicting pad tokens. Hugging Face seq2seq models ignore label value -100.
        labels = labels.masked_fill(target_encodings["attention_mask"].squeeze(0) == 0, -100)

        return {
            "input_ids": input_encodings["input_ids"].squeeze(0),
            "attention_mask": input_encodings["attention_mask"].squeeze(0),
            "labels": labels,
        }

In [None]:
# ====== Load the tokenizer and model ======
# Both come from the same pretrained checkpoint; the model is then moved to DEVICE.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model = model.to(DEVICE)

In [None]:
# ====== Build the DataLoaders ======
train_dataset, val_dataset = (SQLDataset(split, tokenizer) for split in (train_data, val_data))

# Only the training split is shuffled; validation keeps its original order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# ====== Optimizer ======
# transformers.AdamW is deprecated (and dropped in newer transformers releases —
# check against the installed version); torch.optim.AdamW is the maintained
# implementation. eps=1e-6 and weight_decay=0.0 reproduce transformers.AdamW's
# defaults (torch's own defaults are eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)

In [None]:
# ====== Per-epoch loss history ======
# Filled by the training loop with one mean loss per epoch for each split.
train_losses, val_losses = [], []

In [None]:
# ====== Model training ======
def _batch_to_device(batch):
    """Return the model inputs of a DataLoader batch moved to DEVICE."""
    return {key: batch[key].to(DEVICE) for key in ("input_ids", "attention_mask", "labels")}


def _mean_loss(total, num_batches):
    """Average a summed loss; NaN for an empty loader instead of ZeroDivisionError."""
    return total / num_batches if num_batches else float("nan")


for epoch in range(EPOCHS):
    # Training pass: dropout on, gradients tracked.
    model.train()
    total_train_loss = 0.0

    for batch in train_loader:
        optimizer.zero_grad()
        # Passing labels makes the model compute the seq2seq loss itself.
        loss = model(**_batch_to_device(batch)).loss
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    avg_train_loss = _mean_loss(total_train_loss, len(train_loader))
    train_losses.append(avg_train_loss)

    # Validation pass: dropout off, no gradients.
    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        for batch in val_loader:
            total_val_loss += model(**_batch_to_device(batch)).loss.item()

    avg_val_loss = _mean_loss(total_val_loss, len(val_loader))
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # Checkpoint model and tokenizer together after every epoch.
    checkpoint_dir = f"t5_text2sql_epoch{epoch+1}"
    model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)

In [None]:
# ====== Save the loss plot ======
# X values come from the recorded histories rather than EPOCHS, so the plot still
# renders if training was interrupted (val_losses may then be one entry shorter)
# or EPOCHS was edited after the run.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss", marker="o")
ax.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss", marker="s")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training vs Validation Loss")
ax.legend()
ax.grid()
fig.savefig("training_loss_plot.png")
plt.show()

In [None]:
# ====== Save the loss log as CSV ======
# Built from the recorded histories (not EPOCHS) so an interrupted or reconfigured
# run still yields a table; pd.Series aligns on index, so a missing trailing
# validation value becomes NaN instead of raising on unequal list lengths.
df = pd.DataFrame({
    "Train Loss": pd.Series(train_losses, dtype="float64"),
    "Validation Loss": pd.Series(val_losses, dtype="float64"),
})
df.insert(0, "Epoch", list(range(1, len(df) + 1)))
df.to_csv("training_log.csv", index=False)

print("✅ Обучение завершено! График сохранен как training_loss_plot.png 🎉")
