In [1]:
import torch
from torch import nn
from torch.optim import Adam
from tqdm import tqdm
import os
import sys
import numpy as np

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

# Импорт модели UNet из файла models/UNet.py
from models.UNet import UNet
from utils.helpers import *

# Настройки
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 200  # Количество эпох
LEARNING_RATE = 1e-4  # Скорость обучения
SAVE_MODEL_PATH = f"{project_root}/saved_models"  # Папка для сохранения модели

# Убедимся, что папка для сохранения модели существует
os.makedirs(SAVE_MODEL_PATH, exist_ok=True)

# Инициализация модели, функции потерь и оптимизатора
model = UNet(in_channels=1, output_channels=3).to(DEVICE)
criterion = nn.L1Loss()  # L1 Loss
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
model.load_state_dict(torch.load('../saved_models/best_model.pth', map_location=DEVICE))

# Функция для тренировки одной эпохи
def train_one_epoch(loader, model, criterion, optimizer, device):
    model.train()
    epoch_loss = 0

    for sar, optical in tqdm(loader, desc="Training", leave=False):
        sar, optical = sar.to(device), optical.to(device)

        # Прямой проход
        output = model(sar)

        # Вычисление функции потерь
        loss = criterion(output, optical)
        epoch_loss += loss.item()

        # Обратный проход
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return epoch_loss / len(loader)

# Функция для валидации
def validate(loader, model, criterion, device):
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for sar, optical in tqdm(loader, desc="Validation", leave=False):
            sar, optical = sar.to(device), optical.to(device)

            # Прямой проход
            output = model(sar)

            # Вычисление функции потерь
            loss = criterion(output, optical)
            epoch_loss += loss.item()

    return epoch_loss / len(loader)

# Основной цикл обучения
def train_model(train_loader, test_loader, model, criterion, optimizer, epochs, device):
    for epoch in range(epochs):
        print(f"Epoch [{epoch + 1}/{epochs}]")

        train_loss = train_one_epoch(train_loader, model, criterion, optimizer, device)
        val_loss = validate(test_loader, model, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Сохраняем модель после каждой эпохи
        model_save_path = os.path.join(SAVE_MODEL_PATH, f"unet_epoch_{epoch + 1}.pth")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved to {model_save_path}")

def train_model_with_early_stopping(train_loader, test_loader, model, criterion, optimizer, epochs, device, patience=10):
    best_val_loss = float(0.23165545669160312)  # Изначально лучшее значение — бесконечность
    save_path = "../saved_models/best_model.pth"  # Путь для сохранения лучшей модели
    no_improvement_epochs = 0  # Счётчик эпох без улучшения
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        print(f"Epoch [{epoch + 1}/{epochs}]")

        # Тренировка
        train_loss = train_one_epoch(train_loader, model, criterion, optimizer, device)
        # Валидация
        val_loss = validate(test_loader, model, criterion, device)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        np.savez('losses.npz', train_losses=train_losses, val_losses=val_losses)

        print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Проверяем, улучшилась ли валидационная потеря
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            no_improvement_epochs = 0
            print(f"Validation loss improved. Model saved to {save_path}")
        else:
            no_improvement_epochs += 1
            print(f"No improvement for {no_improvement_epochs} epoch(s).")

        # Проверяем условие ранней остановки
        if no_improvement_epochs >= patience:
            print(f"Early stopping triggered after {patience} epochs without improvement. Last epoch: {epoch + 1}")
            break

    print(f"Training complete. Best model saved at: {save_path}")


# Запуск тренировки
if __name__ == "__main__":
    train_model_with_early_stopping(train_loader, test_loader, model, criterion, optimizer, EPOCHS, DEVICE)


  model.load_state_dict(torch.load('../saved_models/best_model.pth', map_location=DEVICE))


Epoch [1/200]


                                                             

Train Loss: 0.1711, Validation Loss: 0.2255
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [2/200]


                                                             

Train Loss: 0.1664, Validation Loss: 0.2220
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [3/200]


                                                             

Train Loss: 0.1619, Validation Loss: 0.2240
No improvement for 1 epoch(s).
Epoch [4/200]


                                                             

Train Loss: 0.1623, Validation Loss: 0.2258
No improvement for 2 epoch(s).
Epoch [5/200]


                                                             

Train Loss: 0.1524, Validation Loss: 0.2189
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [6/200]


                                                             

Train Loss: 0.1476, Validation Loss: 0.2188
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [7/200]


                                                             

Train Loss: 0.1434, Validation Loss: 0.2178
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [8/200]


                                                             

Train Loss: 0.1379, Validation Loss: 0.2165
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [9/200]


                                                             

Train Loss: 0.1361, Validation Loss: 0.2146
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [10/200]


                                                             

Train Loss: 0.1446, Validation Loss: 0.2189
No improvement for 1 epoch(s).
Epoch [11/200]


                                                             

Train Loss: 0.1292, Validation Loss: 0.2148
No improvement for 2 epoch(s).
Epoch [12/200]


                                                             

Train Loss: 0.1260, Validation Loss: 0.2149
No improvement for 3 epoch(s).
Epoch [13/200]


                                                             

Train Loss: 0.1234, Validation Loss: 0.2182
No improvement for 4 epoch(s).
Epoch [14/200]


                                                             

Train Loss: 0.1228, Validation Loss: 0.2260
No improvement for 5 epoch(s).
Epoch [15/200]


                                                             

Train Loss: 0.1210, Validation Loss: 0.2118
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [16/200]


                                                             

Train Loss: 0.1170, Validation Loss: 0.2150
No improvement for 1 epoch(s).
Epoch [17/200]


                                                             

Train Loss: 0.1152, Validation Loss: 0.2138
No improvement for 2 epoch(s).
Epoch [18/200]


                                                             

Train Loss: 0.1142, Validation Loss: 0.2141
No improvement for 3 epoch(s).
Epoch [19/200]


                                                             

Train Loss: 0.1109, Validation Loss: 0.2114
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [20/200]


                                                             

Train Loss: 0.1109, Validation Loss: 0.2120
No improvement for 1 epoch(s).
Epoch [21/200]


                                                             

Train Loss: 0.1075, Validation Loss: 0.2126
No improvement for 2 epoch(s).
Epoch [22/200]


                                                             

Train Loss: 0.1069, Validation Loss: 0.2121
No improvement for 3 epoch(s).
Epoch [23/200]


                                                             

Train Loss: 0.1047, Validation Loss: 0.2116
No improvement for 4 epoch(s).
Epoch [24/200]


                                                             

Train Loss: 0.1040, Validation Loss: 0.2149
No improvement for 5 epoch(s).
Epoch [25/200]


                                                             

Train Loss: 0.1042, Validation Loss: 0.2108
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [26/200]


                                                             

Train Loss: 0.1051, Validation Loss: 0.2117
No improvement for 1 epoch(s).
Epoch [27/200]


                                                             

Train Loss: 0.0995, Validation Loss: 0.2118
No improvement for 2 epoch(s).
Epoch [28/200]


                                                             

Train Loss: 0.0975, Validation Loss: 0.2122
No improvement for 3 epoch(s).
Epoch [29/200]


                                                             

Train Loss: 0.0964, Validation Loss: 0.2105
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [30/200]


                                                             

Train Loss: 0.0962, Validation Loss: 0.2096
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [31/200]


                                                             

Train Loss: 0.0958, Validation Loss: 0.2106
No improvement for 1 epoch(s).
Epoch [32/200]


                                                             

Train Loss: 0.0954, Validation Loss: 0.2117
No improvement for 2 epoch(s).
Epoch [33/200]


                                                             

Train Loss: 0.0931, Validation Loss: 0.2098
No improvement for 3 epoch(s).
Epoch [34/200]


                                                             

Train Loss: 0.0921, Validation Loss: 0.2109
No improvement for 4 epoch(s).
Epoch [35/200]


                                                             

Train Loss: 0.0912, Validation Loss: 0.2116
No improvement for 5 epoch(s).
Epoch [36/200]


                                                             

Train Loss: 0.0920, Validation Loss: 0.2108
No improvement for 6 epoch(s).
Epoch [37/200]


                                                             

Train Loss: 0.0898, Validation Loss: 0.2101
No improvement for 7 epoch(s).
Epoch [38/200]


                                                             

Train Loss: 0.0902, Validation Loss: 0.2095
Validation loss improved. Model saved to ../saved_models/best_model.pth
Epoch [39/200]


                                                             

Train Loss: 0.0886, Validation Loss: 0.2112
No improvement for 1 epoch(s).
Epoch [40/200]


                                                             

Train Loss: 0.0880, Validation Loss: 0.2103
No improvement for 2 epoch(s).
Epoch [41/200]


                                                             

Train Loss: 0.0870, Validation Loss: 0.2101
No improvement for 3 epoch(s).
Epoch [42/200]


                                                             

Train Loss: 0.0870, Validation Loss: 0.2111
No improvement for 4 epoch(s).
Epoch [43/200]


                                                             

Train Loss: 0.0856, Validation Loss: 0.2122
No improvement for 5 epoch(s).
Epoch [44/200]


                                                             

Train Loss: 0.0869, Validation Loss: 0.2103
No improvement for 6 epoch(s).
Epoch [45/200]


                                                             

Train Loss: 0.0852, Validation Loss: 0.2121
No improvement for 7 epoch(s).
Epoch [46/200]


                                                             

Train Loss: 0.0850, Validation Loss: 0.2104
No improvement for 8 epoch(s).
Epoch [47/200]


                                                             

Train Loss: 0.0839, Validation Loss: 0.2106
No improvement for 9 epoch(s).
Epoch [48/200]


                                                             

Train Loss: 0.0840, Validation Loss: 0.2109
No improvement for 10 epoch(s).
Early stopping triggered after 10 epochs without improvement. Last epoch: 48
Training complete. Best model saved at: ../saved_models/best_model.pth


