In [1]:
import torch
from torch import nn
from torch.optim import Adam
from tqdm import tqdm
import os
import sys
import numpy as np

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

# Импорт модели UNet из файла models/UNet.py
from models.UNet import UNet
from utils.helpers import *

# Настройки
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 200  # Количество эпох
LEARNING_RATE = 1e-5  # Скорость обучения
SAVE_MODEL_PATH = f"{project_root}/saved_models"  # Папка для сохранения модели

# Убедимся, что папка для сохранения модели существует
os.makedirs(SAVE_MODEL_PATH, exist_ok=True)

# Инициализация модели, функции потерь и оптимизатора
model = UNet(in_channels=1, output_channels=3).to(DEVICE)
criterion = nn.L1Loss()  # L1 Loss
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# Функция для тренировки одной эпохи
def train_one_epoch(loader, model, criterion, optimizer, device):
    model.train()
    epoch_loss = 0

    for sar, optical in tqdm(loader, desc="Training", leave=False):
        sar, optical = sar.to(device), optical.to(device)

        # Прямой проход
        output = model(sar)

        # Вычисление функции потерь
        loss = criterion(output, optical)
        epoch_loss += loss.item()

        # Обратный проход
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return epoch_loss / len(loader)

# Функция для валидации
def validate(loader, model, criterion, device):
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for sar, optical in tqdm(loader, desc="Validation", leave=False):
            sar, optical = sar.to(device), optical.to(device)

            # Прямой проход
            output = model(sar)

            # Вычисление функции потерь
            loss = criterion(output, optical)
            epoch_loss += loss.item()

    return epoch_loss / len(loader)

# Основной цикл обучения
def train_model(train_loader, test_loader, model, criterion, optimizer, epochs, device):
    for epoch in range(epochs):
        print(f"Epoch [{epoch + 1}/{epochs}]")

        train_loss = train_one_epoch(train_loader, model, criterion, optimizer, device)
        val_loss = validate(test_loader, model, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Сохраняем модель после каждой эпохи
        model_save_path = os.path.join(SAVE_MODEL_PATH, f"unet_epoch_{epoch + 1}.pth")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")

def train_model_with_early_stopping(train_loader, test_loader, model, criterion, optimizer, epochs, device, patience=25):
    best_val_loss = float("inf")  # Изначально лучшее значение — бесконечность
    save_path = "../saved_models/best_model.pth"  # Путь для сохранения лучшей модели
    model.load_state_dict(torch.load(save_path, map_location=DEVICE))
    no_improvement_epochs = 0  # Счётчик эпох без улучшения
    train_losses = []
    val_losses = []

    for epoch in range(37, epochs):
        print(f"Epoch [{epoch + 1}/{epochs}]")

        # Тренировка
        train_loss = train_one_epoch(train_loader, model, criterion, optimizer, device)
        # Валидация
        val_loss = validate(test_loader, model, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        np.savez('losses.npz', train_losses=train_losses, val_losses=val_losses)

        print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Проверяем, улучшилась ли валидационная потеря
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            no_improvement_epochs = 0
            print(f"Validation loss improved. Model saved to {save_path}")
        else:
            no_improvement_epochs += 1
            print(f"No improvement for {no_improvement_epochs} epoch(s).")

        # Проверяем условие ранней остановки
        if no_improvement_epochs >= patience:
            print(f"Early stopping triggered after {patience} epochs without improvement. Last epoch: {epoch + 1}")
            break

    print(f"Training complete. Best model saved at: {save_path}")


# Запуск тренировки
if __name__ == "__main__":
    train_model_with_early_stopping(train_loader, test_loader, model, criterion, optimizer, EPOCHS, DEVICE)


  model.load_state_dict(torch.load(save_path, map_location=DEVICE))


Epoch [38/200]


                                                           

Train Loss: 0.2632, Validation Loss: 0.3114
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [39/200]


                                                           

Train Loss: 0.2616, Validation Loss: 0.3121
No improvement for 1 epoch(s).
Epoch [40/200]


                                                           

Train Loss: 0.2620, Validation Loss: 0.3057
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [41/200]


                                                           

Train Loss: 0.2644, Validation Loss: 0.3139
No improvement for 1 epoch(s).
Epoch [42/200]


                                                           

Train Loss: 0.2612, Validation Loss: 0.3021
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [43/200]


                                                           

Train Loss: 0.2601, Validation Loss: 0.3070
No improvement for 1 epoch(s).
Epoch [44/200]


                                                           

Train Loss: 0.2619, Validation Loss: 0.3014
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [45/200]


                                                           

Train Loss: 0.2608, Validation Loss: 0.3045
No improvement for 1 epoch(s).
Epoch [46/200]


                                                           

Train Loss: 0.2612, Validation Loss: 0.3049
No improvement for 2 epoch(s).
Epoch [47/200]


                                                           

Train Loss: 0.2598, Validation Loss: 0.2998
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [48/200]


                                                           

Train Loss: 0.2624, Validation Loss: 0.3088
No improvement for 1 epoch(s).
Epoch [49/200]


                                                           

Train Loss: 0.2592, Validation Loss: 0.3078
No improvement for 2 epoch(s).
Epoch [50/200]


                                                           

Train Loss: 0.2586, Validation Loss: 0.3060
No improvement for 3 epoch(s).
Epoch [51/200]


                                                           

Train Loss: 0.2598, Validation Loss: 0.3039
No improvement for 4 epoch(s).
Epoch [52/200]


                                                           

Train Loss: 0.2621, Validation Loss: 0.3094
No improvement for 5 epoch(s).
Epoch [53/200]


                                                           

Train Loss: 0.2591, Validation Loss: 0.3132
No improvement for 6 epoch(s).
Epoch [54/200]


                                                           

Train Loss: 0.2556, Validation Loss: 0.3126
No improvement for 7 epoch(s).
Epoch [55/200]


                                                           

Train Loss: 0.2611, Validation Loss: 0.3058
No improvement for 8 epoch(s).
Epoch [56/200]


                                                           

Train Loss: 0.2583, Validation Loss: 0.3086
No improvement for 9 epoch(s).
Epoch [57/200]


                                                           

Train Loss: 0.2615, Validation Loss: 0.3011
No improvement for 10 epoch(s).
Epoch [58/200]


                                                           

Train Loss: 0.2577, Validation Loss: 0.3687
No improvement for 11 epoch(s).
Epoch [59/200]


                                                           

Train Loss: 0.2567, Validation Loss: 0.2979
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [60/200]


                                                           

Train Loss: 0.2592, Validation Loss: 0.3187
No improvement for 1 epoch(s).
Epoch [61/200]


                                                           

Train Loss: 0.2563, Validation Loss: 0.3024
No improvement for 2 epoch(s).
Epoch [62/200]


                                                           

Train Loss: 0.2565, Validation Loss: 0.3179
No improvement for 3 epoch(s).
Epoch [63/200]


                                                           

Train Loss: 0.2595, Validation Loss: 0.3173
No improvement for 4 epoch(s).
Epoch [64/200]


                                                           

Train Loss: 0.2565, Validation Loss: 0.3035
No improvement for 5 epoch(s).
Epoch [65/200]


                                                           

Train Loss: 0.2582, Validation Loss: 0.3016
No improvement for 6 epoch(s).
Epoch [66/200]


                                                           

Train Loss: 0.2556, Validation Loss: 0.3024
No improvement for 7 epoch(s).
Epoch [67/200]


                                                           

Train Loss: 0.2574, Validation Loss: 0.3098
No improvement for 8 epoch(s).
Epoch [68/200]


                                                           

Train Loss: 0.2557, Validation Loss: 0.2975
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [69/200]


                                                           

Train Loss: 0.2563, Validation Loss: 0.3068
No improvement for 1 epoch(s).
Epoch [70/200]


                                                           

Train Loss: 0.2550, Validation Loss: 0.3047
No improvement for 2 epoch(s).
Epoch [71/200]


                                                           

Train Loss: 0.2547, Validation Loss: 0.3092
No improvement for 3 epoch(s).
Epoch [72/200]


                                                           

Train Loss: 0.2570, Validation Loss: 0.2940
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [73/200]


                                                           

Train Loss: 0.2562, Validation Loss: 0.2948
No improvement for 1 epoch(s).
Epoch [74/200]


                                                           

Train Loss: 0.2549, Validation Loss: 0.3018
No improvement for 2 epoch(s).
Epoch [75/200]


                                                           

Train Loss: 0.2556, Validation Loss: 0.2954
No improvement for 3 epoch(s).
Epoch [76/200]


                                                           

Train Loss: 0.2538, Validation Loss: 0.3119
No improvement for 4 epoch(s).
Epoch [77/200]


                                                           

Train Loss: 0.2519, Validation Loss: 0.3044
No improvement for 5 epoch(s).
Epoch [78/200]


                                                           

Train Loss: 0.2554, Validation Loss: 0.3060
No improvement for 6 epoch(s).
Epoch [79/200]


                                                           

Train Loss: 0.2530, Validation Loss: 0.3051
No improvement for 7 epoch(s).
Epoch [80/200]


                                                           

Train Loss: 0.2542, Validation Loss: 0.3102
No improvement for 8 epoch(s).
Epoch [81/200]


                                                           

Train Loss: 0.2553, Validation Loss: 0.3066
No improvement for 9 epoch(s).
Epoch [82/200]


                                                           

Train Loss: 0.2554, Validation Loss: 0.3140
No improvement for 10 epoch(s).
Epoch [83/200]


                                                           

Train Loss: 0.2535, Validation Loss: 0.3079
No improvement for 11 epoch(s).
Epoch [84/200]


                                                           

Train Loss: 0.2513, Validation Loss: 0.3132
No improvement for 12 epoch(s).
Epoch [85/200]


                                                           

Train Loss: 0.2521, Validation Loss: 0.3120
No improvement for 13 epoch(s).
Epoch [86/200]


                                                           

Train Loss: 0.2532, Validation Loss: 0.3106
No improvement for 14 epoch(s).
Epoch [87/200]


                                                           

Train Loss: 0.2530, Validation Loss: 0.2975
No improvement for 15 epoch(s).
Epoch [88/200]


                                                           

Train Loss: 0.2537, Validation Loss: 0.3184
No improvement for 16 epoch(s).
Epoch [89/200]


                                                           

Train Loss: 0.2544, Validation Loss: 0.2991
No improvement for 17 epoch(s).
Epoch [90/200]


                                                           

Train Loss: 0.2547, Validation Loss: 0.2980
No improvement for 18 epoch(s).
Epoch [91/200]


                                                           

Train Loss: 0.2512, Validation Loss: 0.3416
No improvement for 19 epoch(s).
Epoch [92/200]


                                                           

Train Loss: 0.2512, Validation Loss: 0.3021
No improvement for 20 epoch(s).
Epoch [93/200]


                                                           

Train Loss: 0.2515, Validation Loss: 0.2921
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [94/200]


                                                           

Train Loss: 0.2495, Validation Loss: 0.2965
No improvement for 1 epoch(s).
Epoch [95/200]


                                                           

Train Loss: 0.2538, Validation Loss: 0.3006
No improvement for 2 epoch(s).
Epoch [96/200]


                                                           

Train Loss: 0.2503, Validation Loss: 0.3061
No improvement for 3 epoch(s).
Epoch [97/200]


                                                          

KeyboardInterrupt: 