In [None]:
# Notebook setup: imports, project configuration and the TensorBoard logger.
import os
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.Pix2Pix_youtube import Pix2PixGAN
# NOTE(review): star import. Presumably this provides `train_loader` / `test_loader`,
# which are used further down but are not defined anywhere else in this notebook.
# TODO confirm and import them explicitly.
from utils.Dataset import *
from utils.lossTracker import save_losses, load_losses
import matplotlib.gridspec as gridspec
from utils.ConfigLoader import ConfigLoader
# Project configuration, read below via config.get(section, key).
config = ConfigLoader()
import torchvision

# NOTE(review): torch.profiler is not used anywhere in this notebook.
import torch.profiler
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()  # TensorBoard logger (default log directory)

def train(model, train_loader, device, epoch=None, end_epoch=None):
    """Run one training epoch and return metrics averaged over all batches.

    Args:
        model: Pix2PixGAN-like object exposing ``generator``, ``discriminator``,
            ``optimizer_G``, ``optimizer_D`` and
            ``train_step(real_A, real_B) -> (loss_G, loss_D, ssim, psnr)``.
        train_loader: iterable of ``(real_A, real_B)`` tensor pairs.
        device: device each batch is moved to before ``train_step``.
        epoch: zero-based epoch index shown in the progress bar. When omitted,
            the notebook-level ``epoch`` variable is used (if it exists).
        end_epoch: total epoch count shown in the progress bar. When omitted,
            the notebook-level ``end_epoch`` variable is used (if it exists).

    Returns:
        tuple: ``(mean_loss_G, mean_loss_D, mean_ssim, mean_psnr, lr_G, lr_D)``.
        The four means are 0-d tensors averaged over every batch of the epoch
        (NaN for an empty loader). The learning rates are read from the first
        param group of each optimizer after the epoch.
    """
    # validate() puts both networks in eval mode. Switch them back so that
    # dropout / normalisation layers behave as in training. Assuming these are
    # nn.Modules, .train() is idempotent, so this is harmless if train_step
    # already does it.
    model.generator.train()
    model.discriminator.train()

    # The progress-bar label falls back to the notebook-level loop variables,
    # which is what this function used to read implicitly.
    if epoch is None:
        epoch = globals().get("epoch")
    if end_epoch is None:
        end_epoch = globals().get("end_epoch")
    desc = f"Epoch {epoch+1}/{end_epoch} Train" if epoch is not None else "Train"

    loss_G_history, loss_D_history, ssim_history, psnr_history = [], [], [], []
    with tqdm(train_loader, desc=desc, leave=False) as pbar:
        for real_A, real_B in pbar:
            loss_G, loss_D, ssim, psnr = model.train_step(real_A.to(device), real_B.to(device))
            pbar.set_postfix({
                "Loss D": loss_D,
                "Loss G": loss_G,
                "LR D": model.optimizer_D.param_groups[0]['lr'],
                "LR G": model.optimizer_G.param_groups[0]['lr'],
            })
            loss_G_history.append(loss_G)
            loss_D_history.append(loss_D)
            ssim_history.append(ssim)
            psnr_history.append(psnr)

    # Average every metric over all batches of the epoch.
    return (
        torch.mean(torch.tensor(loss_G_history)),
        torch.mean(torch.tensor(loss_D_history)),
        torch.mean(torch.tensor(ssim_history)),
        torch.mean(torch.tensor(psnr_history)),
        model.optimizer_G.param_groups[0]['lr'],
        model.optimizer_D.param_groups[0]['lr'],
    )

@torch.no_grad()
def validate(model, val_loader, device, epoch=None, end_epoch=None):
    """Evaluate the model on the validation loader and return averaged metrics.

    Args:
        model: Pix2PixGAN-like object exposing ``generator``, ``discriminator`` and
            ``val_step(real_A, real_B) -> (loss_G, loss_D, ssim, psnr)``.
        val_loader: iterable of ``(real_A, real_B)`` tensor pairs.
        device: device each batch is moved to before ``val_step``, as in training.
        epoch: zero-based epoch index shown in the progress bar. When omitted,
            the notebook-level ``epoch`` variable is used (if it exists).
        end_epoch: total epoch count shown in the progress bar. When omitted,
            the notebook-level ``end_epoch`` variable is used (if it exists).

    Returns:
        tuple: ``(mean_loss_G, mean_loss_D, mean_ssim, mean_psnr)`` as 0-d tensors
        averaged over every validation batch (NaN for an empty loader).
    """
    # Put both networks into evaluation mode.
    model.generator.eval()
    model.discriminator.eval()

    if epoch is None:
        epoch = globals().get("epoch")
    if end_epoch is None:
        end_epoch = globals().get("end_epoch")
    desc = f"Epoch {epoch+1}/{end_epoch} Validation" if epoch is not None else "Validation"

    loss_G_history, loss_D_history, ssim_history, psnr_history = [], [], [], []
    with tqdm(val_loader, desc=desc, leave=False) as pbar:
        for real_A, real_B in pbar:
            loss_G, loss_D, ssim, psnr = model.val_step(real_A.to(device), real_B.to(device))

            loss_G_history.append(loss_G)
            loss_D_history.append(loss_D)
            ssim_history.append(ssim)
            psnr_history.append(psnr)

            pbar.set_postfix({
                "Val Loss G": loss_G,
                "Val Loss D": loss_D,
                "Val SSIM": ssim,
                "Val PSNR": psnr
            })

    # Average every metric over all validation batches.
    return (
        torch.mean(torch.tensor(loss_G_history)),
        torch.mean(torch.tensor(loss_D_history)),
        torch.mean(torch.tensor(ssim_history)),
        torch.mean(torch.tensor(psnr_history)),
    )

def _plot_loss_curve(ax, losses, label, title, color):
    """Plot one per-epoch loss history on ``ax`` with the x axis starting at epoch 1."""
    ax.plot(range(1, len(losses) + 1), losses, label=label, color=color)
    ax.set_title(title)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.grid()
    ax.legend()


def _show_image(ax, image, title, cmap=None):
    """Show a numpy image on ``ax`` after rescaling with ``x * 0.5 + 0.5``.

    The rescale assumes values in [-1, 1] (TODO confirm against the dataset
    normalisation). ``image`` is (H, W) for grayscale or (H, W, C) for colour.
    """
    scaled = image * 0.5 + 0.5
    if scaled.ndim == 3:
        # imshow clips out-of-range RGB floats anyway. Clipping here suppresses
        # matplotlib's per-image "Clipping input data" warning.
        scaled = scaled.clip(0, 1)
    ax.imshow(scaled, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')


@torch.no_grad()
def save_training_images(model, epoch, train_loss_G, train_loss_D, val_loss_G, val_loss_D, save_dir, train_fixed_sar, train_fixed_optical, val_fixed_sar, val_fixed_optical, device=None):
    """Save a progress figure and return image grids for TensorBoard.

    The figure is written to ``<save_dir>/epoch_<epoch+1>_images.png``. The top
    two rows hold the four loss histories. The rows below hold up to five
    SAR / generated / target triplets for the fixed train samples (left half)
    and validation samples (right half).

    Args:
        model: object with a callable ``generator`` module and ``optimizer_G`` /
            ``optimizer_D``, whose learning rates appear in the title.
        epoch: zero-based epoch index, shown and used in the file name as ``epoch + 1``.
        train_loss_G, train_loss_D, val_loss_G, val_loss_D: per-epoch loss histories.
        save_dir: output directory, created if missing.
        train_fixed_sar, val_fixed_sar: SAR input batches. They are tiled to 3
            channels and concatenated with the RGB tensors, so they are presumably
            (N, 1, H, W). TODO confirm.
        train_fixed_optical, val_fixed_optical: matching 3-channel target batches.
        device: device to run the generator on. Defaults to the device of the
            generator's first parameter rather than a notebook-level global.

    Returns:
        tuple: ``(train_grid, val_grid)`` from ``torchvision.utils.make_grid``.
        Each grid stacks the SAR, generated and target images, 5 images per row.
    """
    os.makedirs(save_dir, exist_ok=True)

    if device is None:
        device = next(model.generator.parameters()).device

    # Generate outputs for the fixed samples.
    train_generated = model.generator(train_fixed_sar.to(device))
    val_generated = model.generator(val_fixed_sar.to(device))

    # Move everything to the CPU for plotting. SAR is tiled to 3 channels so it
    # can be concatenated with the RGB tensors in the grids.
    train_fixed_sar = train_fixed_sar.cpu().repeat(1, 3, 1, 1)
    train_generated = train_generated.cpu()
    train_fixed_optical = train_fixed_optical.cpu()
    val_fixed_sar = val_fixed_sar.cpu().repeat(1, 3, 1, 1)
    val_generated = val_generated.cpu()
    val_fixed_optical = val_fixed_optical.cpu()

    # Grids returned to the caller for TensorBoard.
    train_grid = torchvision.utils.make_grid(
        torch.cat([train_fixed_sar, train_generated, train_fixed_optical], dim=0),
        nrow=5,
        normalize=True,
    )
    val_grid = torchvision.utils.make_grid(
        torch.cat([val_fixed_sar, val_generated, val_fixed_optical], dim=0),
        nrow=5,
        normalize=True,
    )

    fig = plt.figure(figsize=(30, 40))
    gs = gridspec.GridSpec(7, 6, figure=fig)

    # Label each learning rate with the optimizer it actually comes from.
    lr_G = model.optimizer_G.param_groups[0]['lr']
    lr_D = model.optimizer_D.param_groups[0]['lr']
    fig.suptitle(f"Epoch: {epoch+1}, G lr: {lr_G}, D lr: {lr_D}", fontsize=16)

    # Loss curves: train on the left, validation on the right.
    _plot_loss_curve(fig.add_subplot(gs[0, :3]), train_loss_G,
                     "Train Generator Loss", "Train Generator Loss", "#3b82f6")
    _plot_loss_curve(fig.add_subplot(gs[1, :3]), train_loss_D,
                     "Train Discriminator Loss", "Train Discriminator Loss", "#ef4444")
    _plot_loss_curve(fig.add_subplot(gs[0, 3:]), val_loss_G,
                     "Val Generator Loss", "Validation Generator Loss", "#22c55e")
    _plot_loss_curve(fig.add_subplot(gs[1, 3:]), val_loss_D,
                     "Val Discriminator Loss", "Validation Discriminator Loss", "#f59e0b")

    # Image rows. The grid has room for 5; never index past the available samples.
    num_rows = min(5, train_fixed_sar.size(0), val_fixed_sar.size(0))
    for i in range(num_rows):
        panels = (
            (train_fixed_sar[i, 0].detach().numpy(), f"Train SAR Image {i+1}", 'gray'),
            (train_generated[i].permute(1, 2, 0).detach().numpy(), f"Train Generated Image {i+1}", None),
            (train_fixed_optical[i].permute(1, 2, 0).detach().numpy(), f"Train Target Image {i+1}", None),
            (val_fixed_sar[i, 0].detach().numpy(), f"Val SAR Image {i+1}", 'gray'),
            (val_generated[i].permute(1, 2, 0).detach().numpy(), f"Val Generated Image {i+1}", None),
            (val_fixed_optical[i].permute(1, 2, 0).detach().numpy(), f"Val Target Image {i+1}", None),
        )
        for col, (image, title, cmap) in enumerate(panels):
            _show_image(fig.add_subplot(gs[2 + i, col]), image, title, cmap)

    fig.tight_layout()

    save_path = os.path.join(save_dir, f"epoch_{epoch+1}_images.png")
    fig.savefig(save_path)
    plt.close(fig)
    return train_grid, val_grid


# Compute device: CUDA if present, else Apple MPS, else CPU.
# torch.backends.mps.is_available() is the documented MPS check. The
# torch.mps.is_available() spelling may be missing in older PyTorch releases and
# would then raise AttributeError on machines without CUDA.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f'Using {device}')
# Allow TF32 for CUDA matmuls and cuDNN convolutions (faster, slightly less precise).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Fixed samples for the periodic progress figures: the first 5 (SAR, optical)
# pairs of the first train batch and of the first test batch. They are captured
# once here, so every figure in this run shows the same samples.
# NOTE(review): train_loader / test_loader are not defined in this notebook.
# Presumably they come from `from utils.Dataset import *`.
train_iterator = iter(train_loader)
fixed_batch = next(train_iterator)
train_fixed_sar, train_fixed_optical = fixed_batch[0][:5], fixed_batch[1][:5]

test_iterator = iter(test_loader)
fixed_batch = next(test_iterator)
val_fixed_sar, val_fixed_optical = fixed_batch[0][:5], fixed_batch[1][:5]

# Build the model.
model = Pix2PixGAN(device)

# Checkpoint resumed when the config flag model.load_model is truthy.
CHECKPOINT_NAME = 'checkpoint_epoch_100'

# Start from empty histories unconditionally, so that resuming without saved
# losses does not leave these names undefined for the training loop.
start_epoch = 0
train_G_losses, train_D_losses = [], []
val_G_losses, val_D_losses = [], []

if config.get('model', 'load_model'):
    start_epoch = model.load_state(CHECKPOINT_NAME, device)
    losses_dict = load_losses()
    if losses_dict:
        train_G_losses = list(losses_dict['train_G_losses'])
        train_D_losses = list(losses_dict['train_D_losses'])
        val_G_losses = list(losses_dict['val_G_losses'])
        val_D_losses = list(losses_dict['val_D_losses'])
    else:
        print("No saved losses found; loss histories start empty.")

# Exclusive upper bound of the training loop.
end_epoch = config.get('model', 'end_epoch')
torch.backends.cudnn.benchmark = True

# How often (in epochs) loss histories and progress images are written.
LOSS_SAVE_EVERY = 50
IMAGE_SAVE_EVERY = 20

# Training loop. `epoch` and `end_epoch` stay module-level because train() and
# validate() use them for their progress-bar labels.
for epoch in range(start_epoch, end_epoch):
    train_loss_G, train_loss_D, train_ssim, train_psnr, lr_G, lr_D = train(model, train_loader, device)
    val_loss_G, val_loss_D, val_ssim, val_psnr = validate(model, test_loader, device)

    model.step_schedulers(val_loss_G, val_loss_D)

    # memory_allocated = torch.cuda.memory_allocated(device) / (1024 ** 2)  # in MB
    # memory_reserved = torch.cuda.memory_reserved(device) / (1024 ** 2)   # in MB

    train_G_losses.append(train_loss_G)
    train_D_losses.append(train_loss_D)
    val_G_losses.append(val_loss_G)
    val_D_losses.append(val_loss_D)

    if writer:
        step = epoch + 1
        writer.add_scalar("Train/Loss_G", train_loss_G.item(), step)
        writer.add_scalar("Train/Loss_D", train_loss_D.item(), step)
        writer.add_scalar("Train/PSNR", train_psnr.item(), step)
        writer.add_scalar("Train/SSIM", train_ssim.item(), step)
        writer.add_scalar("Train/Learning_Rate_G", lr_G, step)
        writer.add_scalar("Train/Learning_Rate_D", lr_D, step)

        writer.add_scalar("Val/Loss_G", val_loss_G.item(), step)
        writer.add_scalar("Val/Loss_D", val_loss_D.item(), step)
        writer.add_scalar("Val/PSNR", val_psnr.item(), step)
        writer.add_scalar("Val/SSIM", val_ssim.item(), step)

        # Weight histograms
        # for name, param in model.generator.named_parameters():
        #     writer.add_histogram(f'Generator/{name}', param, step)

        # Memory usage in TensorBoard
        # writer.add_scalar("Performance/Memory_Allocated_MB", memory_allocated, global_step=epoch)
        # writer.add_scalar("Performance/Memory_Reserved_MB", memory_reserved, global_step=epoch)

    # Persist loss histories.
    if (epoch + 1) % LOSS_SAVE_EVERY == 0:
        # NOTE(review): model checkpointing is commented out, so only the loss
        # histories are persisted. Re-enable before long runs if that is not intended.
        # model.save_state(epoch, save_dir=config.get('paths', 'model_save_dir'))
        save_losses(
            train_G_losses=train_G_losses,
            train_D_losses=train_D_losses,
            val_G_losses=val_G_losses,
            val_D_losses=val_D_losses
        )

    # Save progress images and log them to TensorBoard.
    if (epoch + 1) % IMAGE_SAVE_EVERY == 0:
        train_grid, val_grid = save_training_images(
            model,
            epoch,
            train_G_losses,
            train_D_losses,
            val_G_losses,
            val_D_losses,
            config.get('paths', 'image_save_dir'),
            train_fixed_sar, train_fixed_optical,
            val_fixed_sar, val_fixed_optical
        )
        writer.add_image('Train/Train_Images', train_grid, global_step=epoch+1)
        writer.add_image('Val/Val_Images', val_grid, global_step=epoch+1)

# SummaryWriter buffers events and only flushes periodically. Flush now so the
# final epochs reach disk before anything else runs (e.g. an automatic shutdown).
writer.flush()

  check_for_updates()


Using cuda


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


Setting up [LPIPS] perceptual loss: trunk [squeeze], v[0.1], spatial [off]
Loading model from: c:\Users\tiruu\AppData\Local\Programs\Python\Python39\lib\site-packages\lpips\weights\v0.1\squeeze.pth


Epoch 5/400 Train:  36%|███▌      | 65/182 [00:20<00:36,  3.25it/s, Loss D=1.09, Loss G=27.8, LR D=0.000394, LR G=9.84e-5]                   

In [None]:


import os
import sys

# Power the machine off 60 seconds after this cell runs (on Windows, cancel with
# `shutdown /a`). The `/s /t 60` flags are Windows syntax, so skip on other
# platforms instead of running a malformed command.
if sys.platform == "win32":
    os.system("shutdown /s /t 60")
else:
    print(f"Auto-shutdown skipped: unsupported platform {sys.platform!r}")

0