In [1]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """
    Convolutional discriminator that scores a (source, target) image pair.

    The two images are concatenated along the channel axis and passed through
    a stack of 4x4 convolutions. The last layer is a Sigmoid, so every value
    in the returned map lies in [0, 1].

    Args:
        in_channels (int): Channel count of the concatenated input. Defaults
            to 4 (1 source channel + 3 target channels).

    Example:
        discriminator = Discriminator()
        output = discriminator(src_image, target_image)
    """

    def __init__(self, in_channels=4):  # 1 source channel + 3 target channels
        super().__init__()

        def conv_bn_lrelu(c_in, c_out, stride):
            # Conv(4x4, padding 1) -> BatchNorm -> LeakyReLU(0.2)
            return [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # C64: the first layer has no BatchNorm.
        stack = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # C128 (stride 2), C256 (stride 2), C512 (stride 1).
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            stack.extend(conv_bn_lrelu(c_in, c_out, stride))

        # C1: collapse to a single-channel score map, then squash with Sigmoid.
        stack.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        stack.append(nn.Sigmoid())

        self.model = nn.Sequential(*stack)

    def forward(self, src, target):
        """
        Concatenate ``src`` and ``target`` on dim 1 and run the conv stack.

        Returns:
            The single-channel output of ``self.model`` (values in [0, 1]).
        """
        pair = torch.cat((src, target), dim=1)
        return self.model(pair)

if __name__ == '__main__':
    # Smoke test: push one random 1-channel source and one random 3-channel
    # target through a fresh discriminator and print the resulting shape.
    D = Discriminator()
    x, target = torch.randn(1, 1, 256, 256), torch.randn(1, 3, 256, 256)
    print(f'{x.shape} + {target.shape} -> {D(x, target).shape}')  # torch.Size([1, 1, 30, 30])


torch.Size([1, 1, 256, 256]) + torch.Size([1, 3, 256, 256]) -> torch.Size([1, 1, 30, 30])


In [2]:
class UNetGenerator(nn.Module):
    """
    U-Net style encoder/decoder generator with skip connections.

    Seven stride-2 encoder stages plus a stride-2 bottleneck shrink the input;
    eight stride-2 transposed convolutions grow it back. Every decoder stage
    after the first receives the previous decoder output concatenated (on the
    channel axis) with the encoder output of matching depth. The final
    activation is tanh.

    Args:
        input_channels (int): Channels of the input image. Defaults to 1.
        output_channels (int): Channels of the generated image. Defaults to 3.
    """

    def __init__(self, input_channels=1, output_channels=3):
        super().__init__()

        def down(c_in, c_out, use_bn=True):
            # Conv(4x4, stride 2) -> LeakyReLU(0.2) [-> BatchNorm]
            conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            act = nn.LeakyReLU(0.2, inplace=True)
            tail = [nn.BatchNorm2d(c_out)] if use_bn else []
            return nn.Sequential(conv, act, *tail)

        def up(c_in, c_out, dropout=0):
            # ConvTranspose(4x4, stride 2) -> ReLU -> BatchNorm [-> Dropout]
            deconv = nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            act = nn.ReLU(inplace=True)
            norm = nn.BatchNorm2d(c_out)
            tail = [nn.Dropout(dropout)] if dropout != 0 else []
            return nn.Sequential(deconv, act, norm, *tail)

        def bottom(c_in, c_out):
            # Innermost stage: Conv(4x4, stride 2) -> ReLU, no normalisation.
            conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            return nn.Sequential(conv, nn.ReLU(inplace=True))

        # Encoder (the first stage skips BatchNorm).
        self.enc1 = down(input_channels, 64, use_bn=False)
        self.enc2 = down(64, 128)
        self.enc3 = down(128, 256)
        self.enc4 = down(256, 512)
        self.enc5 = down(512, 512)
        self.enc6 = down(512, 512)
        self.enc7 = down(512, 512)

        self.bottleneck = bottom(512, 512)

        # Decoder. From dec2 onward the input width is doubled because the
        # matching encoder output is concatenated in forward().
        self.dec1 = up(512, 512, dropout=0.5)
        self.dec2 = up(1024, 512, dropout=0.5)
        self.dec3 = up(1024, 512, dropout=0.5)
        self.dec4 = up(1024, 512)
        self.dec5 = up(1024, 256)
        self.dec6 = up(512, 128)
        self.dec7 = up(256, 64)
        self.final = nn.ConvTranspose2d(128, output_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Run the encoder, bottleneck and skip-connected decoder; return tanh output."""
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4,
                    self.enc5, self.enc6, self.enc7)
        decoders = (self.dec2, self.dec3, self.dec4,
                    self.dec5, self.dec6, self.dec7)

        # Encoder pass, remembering every stage output for the skip connections.
        skips = []
        h = x
        for stage in encoders:
            h = stage(h)
            skips.append(h)

        h = self.dec1(self.bottleneck(h))

        # dec2..dec7 consume the encoder outputs e7..e2, deepest first.
        for stage, skip in zip(decoders, reversed(skips[1:])):
            h = stage(torch.cat([h, skip], dim=1))

        # The last transposed conv is paired with the shallowest encoder output.
        return torch.tanh(self.final(torch.cat([h, skips[0]], dim=1)))

if __name__ == '__main__':
    # Smoke test: a random 1-channel 256x256 image goes in, and the shape of
    # the generator output is printed next to the input shape.
    G = UNetGenerator()
    x = torch.randn(1, 1, 256, 256)
    print(
        f'{x.shape} -> {G(x).shape}'
    )  # torch.Size([1, 3, 256, 256])


torch.Size([1, 1, 256, 256]) -> torch.Size([1, 3, 256, 256])


In [3]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn as nn

class Pix2PixGAN(nn.Module):
    """
    Bundles the generator, the discriminator, their optimizers, LR schedulers
    and loss functions, and exposes a single-batch training step.

    Args:
        device: Device both networks are moved to; ``train_step`` also moves
            its inputs there.
        lambda_l1 (float): Weight of the L1 reconstruction term in the
            generator loss. Defaults to 100.
    """

    def __init__(self, device, lambda_l1=100):
        super().__init__()
        self.device = device
        self.lambda_l1 = lambda_l1
        self.generator = UNetGenerator().to(self.device)
        self.discriminator = Discriminator().to(self.device)

        self.optimizer_G = torch.optim.AdamW(self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.AdamW(self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # Both schedulers halve the LR after 10 step() calls without improvement.
        self.scheduler_G = ReduceLROnPlateau(self.optimizer_G, mode='min', factor=0.5, patience=10)
        self.scheduler_D = ReduceLROnPlateau(self.optimizer_D, mode='min', factor=0.5, patience=10)

        # Discriminator already ends in nn.Sigmoid, so its output is a
        # probability. BCEWithLogitsLoss would apply a second sigmoid and
        # squeeze predictions into (0.5, 0.73); plain BCELoss is the matching
        # criterion for probability inputs.
        self.criterion_GAN = nn.BCELoss()
        self.criterion_L1 = nn.L1Loss()

    def train_step(self, real_A, real_B):
        """
        Run one discriminator update followed by one generator update.

        Args:
            real_A: Batch of source images (generator input).
            real_B: Batch of target images paired with ``real_A``.

        Returns:
            tuple[float, float]: ``(loss_D, loss_G)`` for this batch.
        """
        real_A, real_B = real_A.to(self.device), real_B.to(self.device)

        # ---- Discriminator update ----
        self.optimizer_D.zero_grad()
        fake_B = self.generator(real_A)

        output_real = self.discriminator(real_A, real_B)
        # detach(): the discriminator loss must not back-propagate into the generator.
        output_fake = self.discriminator(real_A, fake_B.detach())

        target_real = torch.ones_like(output_real)
        target_fake = torch.zeros_like(output_fake)

        loss_D_real = self.criterion_GAN(output_real, target_real)
        loss_D_fake = self.criterion_GAN(output_fake, target_fake)
        loss_D = (loss_D_real + loss_D_fake) / 2
        loss_D.backward()
        self.optimizer_D.step()

        # ---- Generator update ----
        # Gradients from this backward pass also land on the discriminator
        # parameters; they are discarded by optimizer_D.zero_grad() at the
        # start of the next call.
        self.optimizer_G.zero_grad()
        loss_G_GAN = self.criterion_GAN(self.discriminator(real_A, fake_B), target_real)
        loss_G_L1 = self.criterion_L1(fake_B, real_B) * self.lambda_l1
        loss_G = loss_G_GAN + loss_G_L1
        loss_G.backward()
        self.optimizer_G.step()

        return loss_D.item(), loss_G.item()

    def step_schedulers(self, loss_G, loss_D):
        """Feed the given generator/discriminator losses to their plateau schedulers."""
        self.scheduler_G.step(loss_G)
        self.scheduler_D.step(loss_D)


In [None]:
from utils.helpers import *
from utils.lossTracker import save_losses, load_losses
from utils.checkpointLogic import save_checkpoint, load_checkpoint
from tqdm.notebook import tqdm

# Pick the best available backend: CUDA, then Apple MPS, then CPU.
# torch.backends.mps.is_available() is the long-standing availability check;
# torch.mps.is_available may be missing on older PyTorch releases.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f'Using {device}')

# Release cached allocator memory on the selected accelerator, if any.
if device.type == "cuda":
    torch.cuda.empty_cache()
elif device.type == "mps":
    torch.mps.empty_cache()

model = Pix2PixGAN(device)

# The training loop reaches the wrapped module through `model.module`.
model = torch.nn.DataParallel(model).to(device)


load_model = False
if load_model:
    # NOTE(review): load_checkpoint is a project helper not shown here; it is
    # assumed to restore the weights into `model` and return the epoch to
    # resume from - confirm.
    start_epoch, loss_G, loss_D = load_checkpoint('checkpoint_epoch_400', model, device)
    losses_dict = load_losses()
    if losses_dict:
        # Continue the saved loss history instead of discarding it.
        g_loss = losses_dict['g_loss']
        d_loss = losses_dict['d_loss']
    else:
        g_loss = []
        d_loss = []
else:
    start_epoch = 0
    g_loss = []
    d_loss = []

Using cuda


In [None]:
end_epoch = 1500

# Let cuDNN auto-tune convolution algorithms.
torch.backends.cudnn.benchmark = True

# NOTE(review): train_loader is not defined in the cells shown; presumably it
# comes from `from utils.helpers import *` - confirm.
for epoch in range(start_epoch, end_epoch):
    sum_loss_D, sum_loss_G, num_batches = 0.0, 0.0, 0

    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{end_epoch}", leave=False) as pbar:
        for real_A, real_B in pbar:
            loss_D, loss_G = model.module.train_step(real_A, real_B)
            sum_loss_D += loss_D
            sum_loss_G += loss_G
            num_batches += 1
            pbar.set_postfix({
                "Loss D": loss_D,
                "Loss G": loss_G,
                "LR D": model.module.optimizer_D.param_groups[0]['lr'],
                "LR G": model.module.optimizer_G.param_groups[0]['lr']
            })

    if num_batches == 0:
        # Without this guard the code below would fail with a NameError or
        # silently reuse the previous epoch's values.
        raise RuntimeError("train_loader yielded no batches; nothing to train on")

    # Use the epoch mean rather than the last mini-batch loss: a single batch
    # is a noisy signal for plateau detection and for the recorded history.
    epoch_loss_D = sum_loss_D / num_batches
    epoch_loss_G = sum_loss_G / num_batches

    # Learning-rate adaptation
    model.module.step_schedulers(epoch_loss_G, epoch_loss_D)

    g_loss.append(epoch_loss_G)
    d_loss.append(epoch_loss_D)

    # Persist a checkpoint and the loss history every 100 epochs.
    if (epoch + 1) % 100 == 0:
        save_checkpoint(epoch, model, epoch_loss_G, epoch_loss_D)
        save_losses(g_loss=g_loss, d_loss=d_loss)


Epoch 1/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 2/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 3/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 4/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 5/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 6/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 7/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 8/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 9/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 10/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 11/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 12/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 13/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 14/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 15/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 16/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 17/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 18/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 19/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 20/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 21/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 22/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 23/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 24/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 25/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 26/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 27/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 28/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 29/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 30/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 31/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 32/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 33/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 34/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 35/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 36/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 37/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 38/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 39/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 40/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 41/1500:   0%|          | 0/182 [00:00<?, ?it/s]

Epoch 42/1500:   0%|          | 0/182 [00:00<?, ?it/s]