In [1]:
import matplotlib.pyplot as plt
import torch
from torch import nn
import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from torchvision.utils import save_image
from tqdm import tqdm
from torchvision.models import vgg19

In [2]:
class DiscriminatorBlock(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int=4,
            stride: int= 2
    ):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                bias=False,
                padding_mode='reflect'
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.block(x)

In [3]:
class Discriminator(nn.Module):
    def __init__(
            self,
            in_channels: int=3,
            kernel_size: int=4,
            stride: int=2,
            padding: int=1
    ):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels = in_channels*2,
                out_channels = 64,
                kernel_size = kernel_size,
                stride = stride,
                padding = padding,
                padding_mode = 'reflect'
            )
        )
        self.conv1 = DiscriminatorBlock(64, 128, kernel_size, 2)
        self.conv2 = DiscriminatorBlock(128, 256, kernel_size, 2)
        self.conv3 = DiscriminatorBlock(256, 512, kernel_size, 1)
        self.final = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=kernel_size, stride=1,
                               padding=1, padding_mode='reflect')

        self.model = nn.Sequential(
            self.initial, self.conv1, self.conv2, self.conv3, self.final
        )

    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        return self.model(x)

In [4]:
def test_discriminator():
    a = Discriminator()
    img = torch.ones(10, 3, 256, 256)
    with torch.no_grad():
        print(a.forward(img, img).shape)

In [5]:
test_discriminator()

torch.Size([10, 1, 26, 26])


In [6]:
class GeneratorBlockDown(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int
    ):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode='reflect',
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    def forward(self, x):
        return self.block(x)

In [7]:
class GeneratorBlockUp(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            dropout: bool
    ):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5) if dropout else nn.Identity()
        )
    def forward(self, x):
        return self.block(x)

In [8]:
class Generator(nn.Module):
    def __init__(self, in_channels: int=3, features: int=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode='reflect'
            ),
            nn.LeakyReLU(0.2)
        )
        self.down1 = GeneratorBlockDown(in_channels=features, out_channels=features*2)
        self.down2 = GeneratorBlockDown(in_channels=features*2, out_channels=features*4)
        self.down3 = GeneratorBlockDown(in_channels=features*4, out_channels=features*8)
        self.down4 = GeneratorBlockDown(in_channels=features*8, out_channels=features*8)
        self.down5 = GeneratorBlockDown(in_channels=features*8, out_channels=features*8)
        self.down6 = GeneratorBlockDown(in_channels=features*8, out_channels=features*8)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=features*8,
                      out_channels=features*8,
                      kernel_size=(4,4),
                      stride=(2,2),
                      padding=(1,1),
                      padding_mode="reflect"),
            nn.ReLU()
        )

        self.up1 = GeneratorBlockUp(in_channels=features*8, out_channels=features*8, dropout=True)
        self.up2 = GeneratorBlockUp(in_channels=features*8*2, out_channels=features*8, dropout=True)
        self.up3 = GeneratorBlockUp(in_channels=features*8*2, out_channels=features*8, dropout=True)
        self.up4 = GeneratorBlockUp(in_channels=features*8*2, out_channels=features*8, dropout=False)
        self.up5 = GeneratorBlockUp(in_channels=features*8*2, out_channels=features*4, dropout=False)
        self.up6 = GeneratorBlockUp(in_channels=features*4*2, out_channels=features*2, dropout=False)
        self.up7 = GeneratorBlockUp(in_channels=features*4, out_channels=features, dropout=False)

        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=features*2,
                out_channels=in_channels,
                kernel_size=(4,4),
                stride=(2,2),
                padding=(1,1)
            ),
            nn.Tanh()
        )
    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        up1 = self.up1(bottleneck)
        # print(d7.shape, up1.shape)
        up2 = self.up2(torch.cat([d7, up1], dim=1))
        up3 = self.up3(torch.cat([d6, up2], dim=1))
        up4 = self.up4(torch.cat([d5, up3], dim=1))
        up5 = self.up5(torch.cat([d4, up4], dim=1))
        up6 = self.up6(torch.cat([d3, up5], dim=1))
        up7 = self.up7(torch.cat([d2, up6], dim=1))
        return self.final_up(torch.cat([d1, up7], dim=1))

In [9]:
def test_generator():
    test_imgs = torch.ones(10, 3, 256, 256)
    gen = Generator(3, 64)
    print(gen(test_imgs).shape)

In [10]:
test_generator()

torch.Size([10, 3, 256, 256])


In [11]:
class MyDataset(Dataset):
    def __init__(self, folder_path, val=False):
        super().__init__()
        self.dir = folder_path
        self.all_files = os.listdir(folder_path)
        self.both_transform = A.Compose(
            [
                A.Resize(width=256, height=256),
                A.HorizontalFlip(p=0.5 if not val else 0)
            ],
            additional_targets={
                'image0': 'image'
            }
        )
        self.transform_in = A.Compose(
            transforms=[
                        A.Normalize(mean=[0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
                        ToTensorV2()]
        )
        self.transform_out = A.Compose(
            transforms=[
                 A.Normalize(mean=[0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
                 ToTensorV2()
            ]
        )
    def __len__(self):
        return len(self.all_files)
    def __getitem__(self, idx):
        img_path = os.path.join(self.dir, self.all_files[idx])
        image = np.array(Image.open(img_path))

        in_image = image[:, :256, :]
        out_image = image[:, 256:, :]

        augmentations = self.both_transform(image=in_image, image0=out_image)
        in_image, out_image = augmentations['image'], augmentations['image0']
        in_image = self.transform_in(image=in_image)['image']
        out_image = self.transform_out(image=out_image)['image']
        return in_image, out_image

In [12]:
class VGGLoss:
    def __init__(self):
        self.vgg = vgg19(pretrained=True).features[:35].eval().to(device)
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.mse = nn.MSELoss()

    def forward(self, x, y):
        real_f = self.vgg(x)
        gen_f = self.vgg(y)
        return self.mse(real_f, gen_f)

In [13]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [14]:
def save_example(generator, img_path='pr_data/val/cat.4038.jpg', res_name=''):
    image = np.array(Image.open(img_path))
    in_image = image[:, :256, :]
    out_image = image[:, 256:, :]

    Image.fromarray(in_image).save(f'examples_bad/{res_name+"_"}bad.jpg')
    Image.fromarray(out_image).save(f'examples_orig/{res_name+"_"}orig.jpg')

    transform = A.Compose(
            transforms=[
                A.Resize(256, 256),
                A.Normalize(mean=[0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
                ToTensorV2()
            ]
        )
    in_img = transform(image=in_image)['image'].unsqueeze(0)
    with torch.no_grad():
        image_nn = generator(in_img.to(device)).detach().cpu()

    save_image(((image_nn*0.5)+0.5).abs(), f'examples_nn/{res_name+"_"}nn.jpg')

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = Generator().to(device)
save_example(generator)

In [16]:
generator = Generator().to(device)

In [17]:
generator.eval()
save_example(generator)

In [18]:
def save_model_checkpoint(model, optimizer, epoch, filename):
    print('saving checkpoint')

    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, filename+f'_{epoch}.pth')

In [19]:
def load_checkpoint(model, optimizer, lr, file):

    global curr_epoch
    print('load checkpoint')
    checkpoint = torch.load(file, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    curr_epoch = checkpoint['epoch']
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [20]:
batch_size = 8

In [21]:
train_dataset = MyDataset(folder_path='pr_data/train', val=False)
test_dataset = MyDataset(folder_path='pr_data/val', val=True)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

Обучим ган

In [22]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [23]:
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [24]:
preload = False
curr_epoch = 0
if preload:
        load_checkpoint(generator, gen_optimizer, 2e-4, 'save_models/pix2pix_gen_10.pth')
        load_checkpoint(discriminator, disc_optimizer, 2e-4, 'save_models/pix2pix_disc_10.pth')


In [25]:
generator.eval()
discriminator.eval()
save_example(generator)

In [26]:
num_epochs=200
bce = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

In [29]:
for epoch in range(curr_epoch+1, num_epochs):
    print(f'epoch_{epoch}')
    discriminator.train(True)
    generator.train(True)
    train_loss = 0
    test_loss = 0

    for bad_image, image in tqdm(train_dataloader, leave=True):
        ### Train discriminator
        image_fake = generator(bad_image.to(device))
        disc_fake = discriminator(image_fake.detach(), bad_image.to(device))
        disc_real = discriminator(image.to(device), bad_image.to(device))

        discriminator_loss = 0.5 * (bce(disc_real, torch.ones_like(disc_real)).to(device) +\
                             bce(disc_fake, torch.zeros_like(disc_fake).to(device)))
        discriminator.zero_grad()
        discriminator_loss.backward()
        disc_optimizer.step()


        ### Train Generator
        disc_fake = discriminator(image_fake, bad_image.to(device))
        cotent_loss = l1_loss.forward(image_fake, image.to(device))
        adversarial_loss = bce(disc_fake, torch.ones_like(disc_fake).to(device))
        generator_loss = 100 * cotent_loss + adversarial_loss

        generator.zero_grad()
        generator_loss.backward()
        gen_optimizer.step()

    if epoch % 1 == 0:
        with torch.no_grad():
            generator.eval()
            save_example(generator, res_name=f'butterfly_gan{epoch}')

    if epoch % 5 == 0:
        save_model_checkpoint(generator, gen_optimizer, epoch, 'save_models/pix2pix_gen')
        save_model_checkpoint(discriminator, disc_optimizer, epoch, 'save_models/pix2pix_disc')




epoch_1


100%|██████████| 1000/1000 [07:50<00:00,  2.13it/s]


epoch_2


100%|██████████| 1000/1000 [07:33<00:00,  2.20it/s]


epoch_3


100%|██████████| 1000/1000 [07:36<00:00,  2.19it/s]


epoch_4


100%|██████████| 1000/1000 [07:30<00:00,  2.22it/s]


epoch_5


100%|██████████| 1000/1000 [07:30<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_6


100%|██████████| 1000/1000 [07:30<00:00,  2.22it/s]


epoch_7


100%|██████████| 1000/1000 [07:30<00:00,  2.22it/s]


epoch_8


100%|██████████| 1000/1000 [07:30<00:00,  2.22it/s]


epoch_9


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_10


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_11


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_12


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_13


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_14


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_15


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_16


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_17


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_18


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_19


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_20


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_21


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_22


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_23


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_24


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_25


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_26


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_27


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_28


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_29


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_30


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_31


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_32


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_33


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_34


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_35


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


saving checkpoint
saving checkpoint
epoch_36


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


epoch_37


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_38


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_39


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_40


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


saving checkpoint
saving checkpoint
epoch_41


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


epoch_42


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


epoch_43


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


epoch_44


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


epoch_45


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


saving checkpoint
saving checkpoint
epoch_46


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_47


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_48


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_49


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_50


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_51


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_52


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_53


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_54


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


epoch_55


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_56


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_57


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_58


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_59


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


epoch_60


100%|██████████| 1000/1000 [07:29<00:00,  2.23it/s]


saving checkpoint
saving checkpoint
epoch_61


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_62


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_63


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_64


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_65


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_66


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_67


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_68


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_69


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_70


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_71


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_72


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_73


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_74


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_75


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_76


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_77


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_78


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_79


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_80


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_81


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_82


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_83


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_84


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_85


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_86


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_87


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_88


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_89


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_90


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


saving checkpoint
saving checkpoint
epoch_91


100%|██████████| 1000/1000 [07:29<00:00,  2.22it/s]


epoch_92


 92%|█████████▏| 918/1000 [06:54<00:37,  2.21it/s]


KeyboardInterrupt: 

In [36]:
save_example(generator,img_path='maps/val/1.jpg', res_name=f'butterfly_gan{epoch}')

In [28]:
            save_example(generator, res_name=f'butterfly_gan{epoch}')

In [34]:
discriminator

Discriminator(
  (initial): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
  )
  (conv1): DiscriminatorBlock(
    (block): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (conv2): DiscriminatorBlock(
    (block): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (conv3): DiscriminatorBlock(
    (block): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(neg

In [42]:
a = torch.load('save_models/pix2pix_gen_5.pth',  map_location=device)

In [43]:
a['state_dict']

OrderedDict([('initial_down.0.weight',
              tensor([[[[-0.0033, -0.1345,  0.0477, -0.0463],
                        [ 0.0957, -0.0897, -0.1139, -0.0269],
                        [-0.0705, -0.0390,  0.0255,  0.0397],
                        [-0.0860, -0.0269,  0.0663, -0.0710]],
              
                       [[ 0.1287,  0.0253, -0.0176,  0.1138],
                        [ 0.0641, -0.0190,  0.1613,  0.0374],
                        [-0.0987, -0.0024,  0.0264,  0.1194],
                        [-0.1053,  0.0696,  0.0535, -0.0433]],
              
                       [[ 0.0251, -0.1045, -0.0655, -0.0768],
                        [-0.0209, -0.0176, -0.0387,  0.0522],
                        [ 0.0186,  0.1250, -0.0332, -0.1034],
                        [ 0.0545,  0.0691, -0.0209,  0.1121]]],
              
              
                      [[[-0.0391, -0.0585,  0.0976, -0.0806],
                        [ 0.1550,  0.1090, -0.0311, -0.0029],
                        [ 0.0

In [44]:
generator.load_state_dict(a['state_dict'])

<All keys matched successfully>

In [30]:
save_model_checkpoint(generator, gen_optimizer, 12343, 'save_models/pix2pix_disc_5.pth')
save_model_checkpoint(discriminator, disc_optimizer, 12343, 'save_models/pix2pix_disc_5.pth')

saving checkpoint
saving checkpoint


In [31]:
load_checkpoint(generator, gen_optimizer, 2e-4, 'save_models/pix2pix_gen_10.pth')
load_checkpoint(discriminator, disc_optimizer, 2e-4, 'save_models/pix2pix_disc_10.pth')

load checkpoint
load checkpoint


In [30]:
# generator.train()
for i in os.listdir('pr_data/val')[:10]:
    save_example(generator, img_path=f'pr_data/val/{i}', res_name=f'{i.split(".")[0]+i.split(".")[1]}')

In [46]:
generator2 = Generator().to(device)

In [49]:
load_checkpoint(generator2, gen_optimizer, 2e-4, 'save_models/pix2pix_gen_85.pth')

load checkpoint


In [50]:
generator2.eval()
for i in os.listdir('pr_data/val')[:10]:
    save_example(generator2, img_path=f'pr_data/val/{i}', res_name=f'{i.split(".")[0]+i.split(".")[1]}')