In [1]:
!nvidia-smi

Sat Feb 26 13:20:07 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| 44%   40C    P8    28W / 350W |      0MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import torch
from torch import nn
import os
import cv2 as cv
import os.path as path
from torch.nn import functional as F
from torch.utils.data import Dataset
from torchvision import transforms
import time
from torch.autograd import Variable
from torch import autograd

In [3]:
# Run name; it doubles as the directory that holds this run's checkpoints.
model_name = 'wgan-content-loss'
model_dir = model_name

In [4]:
class ImageDataset(Dataset):
    """Paired single-channel image dataset for image-to-image training.

    The i-th file of ``origin_root`` (network input) is paired with the i-th file
    of ``train_root`` (target). Both listings are sorted so that pairing is
    deterministic -- ``os.listdir`` returns entries in arbitrary order.

    Args:
        origin_root: directory containing the input images.
        train_root: directory containing the target images.
        image_size: (height, width) each loaded image is reshaped to. Defaults to
            (678, 384), the size the models in this notebook are built for.

    Raises:
        ValueError: if the two directories hold a different number of entries,
            since files are paired by position.
    """

    def __init__(self, origin_root, train_root, image_size=(678, 384)):
        super().__init__()
        self.origin_root = origin_root
        self.train_root = train_root
        self.image_size = tuple(image_size)
        # Sort so that input i and target i refer to the same sample on every run.
        self.origin_files = sorted(os.listdir(self.origin_root))
        self.train_files = sorted(os.listdir(self.train_root))
        if len(self.origin_files) != len(self.train_files):
            raise ValueError(
                f'{origin_root} has {len(self.origin_files)} files but '
                f'{train_root} has {len(self.train_files)}; images are paired by index'
            )

    def __len__(self):
        return len(self.origin_files)

    def _load(self, root, name):
        """Read one image file and return its first channel as a (1, H, W) float tensor."""
        file_path = path.join(root, name)
        image = cv.imread(file_path)
        # cv.imread reports unreadable or missing files by returning None instead of raising.
        if image is None:
            raise IOError(f'Could not read image: {file_path}')
        # ToTensor yields (C, H, W) scaled to [0, 1]; [0] keeps only the first channel.
        # NOTE(review): reshape assumes each image already has image_size pixels laid out
        # as H x W -- an image of a different orientation with the same pixel count would be
        # silently scrambled rather than rejected.
        return transforms.ToTensor()(image)[0].reshape((1, *self.image_size))

    def __getitem__(self, index):
        """Return the (input, target) tensor pair at ``index``."""
        return (self._load(self.origin_root, self.origin_files[index]),
                self._load(self.train_root, self.train_files[index]))

In [5]:
class ResNetBlock(nn.Module):
    """Two 3x3 convolutions with a residual (skip) connection and a final ReLU.

    The first convolution may downsample via ``stride``. Whenever the main path
    changes the tensor shape (channel count or spatial size), the skip path goes
    through a 1x1 projection convolution with the same stride so the two can be added.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        stride: stride of the first 3x3 convolution and of the projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Created unconditionally (even when unused) so state_dict keys stay stable for
        # checkpoints loaded with load_state_dict.
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        residual = x
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        # Compare full shapes, not only channels: a stride-2 block that keeps its channel
        # count also needs the projection, or the addition below fails.
        if residual.shape != out.shape:
            residual = self.proj(residual)
        return F.relu(out + residual)


class Generator(nn.Module):
    """Encoder / residual-trunk / decoder generator with a global skip connection.

    Maps a (1) x 678 x 384 tensor to a tensor of the same shape: the network output
    (passed through Tanh) is added to the input, so the model predicts a correction.
    """

    def __init__(self):
        super().__init__()

        def norm_relu(channels):
            # Affine instance norm followed by an in-place ReLU.
            return [nn.InstanceNorm2d(channels, affine=True), nn.ReLU(True)]

        layers = []
        # Stem: (1) x 678 x 384 -> (64) x 678 x 384 -> (128) x 339 x 192
        layers += [nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3), *norm_relu(64)]
        layers += [nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), *norm_relu(128)]
        # Residual trunk: (128) x 339 x 192 -> (512) x 85 x 48
        layers += [
            ResNetBlock(128, 128),
            ResNetBlock(128, 256, stride=2),
            ResNetBlock(256, 256),
            ResNetBlock(256, 512, stride=2),
        ]
        # Decoder, each step doubling H and W:
        # (512) x 85 x 48 -> (256) x 170 x 96 -> (128) x 340 x 192 -> (64) x 680 x 384
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                *norm_relu(c_out),
            ]
        # Head: the asymmetric padding trims 680 rows back to 678 -> (1) x 678 x 384
        layers += [nn.ConvTranspose2d(64, 1, kernel_size=7, stride=1, padding=(4, 3)), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, inputs):
        # Global residual: add the predicted correction to the input image.
        correction = self.model(inputs)
        return correction + inputs

In [6]:
class Discriminator(nn.Module):
    """Convolutional WGAN critic that returns a spatial map of scores.

    For a (1) x 678 x 384 input the output is (1) x 19 x 10; callers reduce it with
    ``.mean()``. Non-affine instance norm follows the first four convolutions, and
    every convolution (including the last) is followed by LeakyReLU(0.2).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def act():
            return nn.LeakyReLU(0.2, inplace=True)

        layers = [
            # (1) x 678 x 384 -> (64) x 339 x 192
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), nn.InstanceNorm2d(64), act(),
        ]
        # Unpadded stride-2 3x3 convolutions:
        # (64) x 339 x 192 -> (128) x 169 x 95 -> (256) x 84 x 47 -> (512) x 41 x 23
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2), nn.InstanceNorm2d(c_out), act()]
        layers += [
            # (512) x 41 x 23 -> (1024) x 21 x 12
            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), act(),
            # (1024) x 21 x 12 -> (1) x 19 x 10
            nn.Conv2d(1024, 1, kernel_size=3, stride=1), act(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [7]:
# Rolling checkpoint paths (rewritten after every generator step), plus the offset
# added to epoch numbers in logs and snapshot names when resuming an earlier run.
generator_model_name, discriminator_model_name = (
    f'{model_dir}/{model_name}-{role}.pth' for role in ('generator', 'discriminator')
)
start_point = 0

class WGAN(object):
    """WGAN-GP trainer with an additional L1 content loss on the generator.

    Builds the generator/critic pair and their Adam optimizers. Inputs are read from
    ``dataset/train/images`` and paired with targets from ``dataset/train/labels``,
    served as an infinite batch stream. Training resumes from the checkpoints named by
    the module-level ``generator_model_name`` / ``discriminator_model_name`` when both
    exist.
    """

    def __init__(self):
        super(WGAN, self).__init__()
        print("WGAN With Content Loss Init Model.")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.G = Generator().to(self.device)
        self.D = Discriminator().to(self.device)
        self.batch_size = 16
        self.generator_iters = 3000  # number of generator updates performed by train()
        self.critic_iter = 5  # critic updates per generator update
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=1e-4, betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=1e-4, betas=(0.5, 0.999))
        self.lambda_term = 10  # gradient-penalty weight

        self.loss_function = nn.L1Loss()

        # No longer used by train(); kept so code that reads these attributes still works.
        self.one = torch.tensor(1, dtype=torch.float).to(self.device)
        self.mone = (self.one * -1).to(self.device)

        dataset = ImageDataset(origin_root='dataset/train/images',
                               train_root='dataset/train/labels')
        self.data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        self.data = self.get_infinite_batches()

        self.init_model()

    def init_model(self):
        """Load both checkpoints if both files exist; otherwise initialise weights from scratch."""
        if os.path.exists(generator_model_name) and os.path.exists(discriminator_model_name):
            # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only host.
            self.G.load_state_dict(torch.load(generator_model_name, map_location=self.device))
            self.D.load_state_dict(torch.load(discriminator_model_name, map_location=self.device))
            print(f'Models load from {generator_model_name} & {discriminator_model_name}')
        else:
            print('No trained_models found, init new trained_models')
            self.G.apply(self.weights_init)
            self.D.apply(self.weights_init)

    def weights_init(self, m):
        """Xavier-normal weights and zero bias for any module whose class name contains 'Conv'."""
        if 'Conv' in m.__class__.__name__:
            nn.init.xavier_normal_(m.weight.data)
            # Convolutions built with bias=False have no bias tensor.
            if m.bias is not None:
                m.bias.data.fill_(0)

    def train(self):
        """Run ``generator_iters`` rounds of critic and generator updates.

        Each round runs ``critic_iter`` critic updates, then one generator update. It then
        logs the losses and overwrites the rolling checkpoints. Numbered snapshots are also
        written: G every 100 rounds, D every 500 rounds.
        """
        for epoch in range(self.generator_iters):
            begin = time.time()
            print(f'Epoch: {epoch + 1 + start_point} / {self.generator_iters + start_point} ============================')

            d_loss, Wasserstein_D = self._train_critic()
            g_loss, l1_loss = self._train_generator()

            # The L1 term is logged x100 (the generator objective weights it x1000).
            print(f'D loss: {d_loss.item()}, G loss: {g_loss.item()}, l1 loss: {100 * l1_loss.item()}, Wasserstein_D: {Wasserstein_D.item()}, cost time: {time.time() - begin}')
            self.save_model()

            if (epoch + 1) % 100 == 0:
                torch.save(self.G.state_dict(), f'{model_dir}/{model_name}-generator-{epoch + start_point + 1}.pth')
            if (epoch + 1) % 500 == 0:
                torch.save(self.D.state_dict(), f'{model_dir}/{model_name}-discriminator-{epoch + start_point + 1}.pth')

    def _train_critic(self):
        """Run ``critic_iter`` WGAN-GP critic updates; return the last (d_loss, Wasserstein_D)."""
        # Unfreeze the critic (it is frozen during the generator step).
        for p in self.D.parameters():
            p.requires_grad = True

        # Tensors (not ints) so .item() works even if critic_iter is 0.
        d_loss = torch.zeros((), device=self.device)
        Wasserstein_D = torch.zeros((), device=self.device)
        for _ in range(self.critic_iter):
            self.D.zero_grad()
            inputs, reals = next(self.data)

            # The critic update must not backpropagate into G.
            with torch.no_grad():
                fakes = self.G(inputs)

            d_loss_real = self.D(reals).mean()
            d_loss_fake = self.D(fakes).mean()
            # Works for any batch size, so partial batches no longer need to be skipped.
            gradient_penalty = self.calculate_gradient_penalty(reals, fakes)

            # Critic minimises E[D(fake)] - E[D(real)] + GP. One backward gives the same
            # gradients as separate backward(+1) / backward(-1) / backward() calls.
            d_loss = d_loss_fake - d_loss_real + gradient_penalty
            d_loss.backward()
            Wasserstein_D = d_loss_real - d_loss_fake
            self.d_optimizer.step()
        return d_loss.detach(), Wasserstein_D.detach()

    def _train_generator(self):
        """Run one generator update on (adversarial + 1000 * L1) loss; return (g_loss, l1_loss)."""
        # Freeze the critic so the generator step does not accumulate gradients into D.
        for p in self.D.parameters():
            p.requires_grad = False

        self.G.zero_grad()
        inputs, reals = next(self.data)
        fakes = self.G(inputs)
        g_loss = self.D(fakes).mean()
        g_cost = -g_loss

        # L1 content loss between generated and target images.
        l1_loss = self.loss_function(fakes, reals).mean()

        l = 1000 * l1_loss + g_cost
        l.backward()

        self.g_optimizer.step()
        return g_loss.detach(), l1_loss.detach()

    def get_infinite_batches(self):
        """Yield (inputs, reals) batches moved to ``self.device``, cycling the loader forever."""
        while True:
            for inputs, reals in self.data_loader:
                yield inputs.to(self.device), reals.to(self.device)

    def calculate_gradient_penalty(self, real_images, fake_images):
        """WGAN-GP penalty ``lambda * E[(||grad_x D(x_hat)||_2 - 1)^2]``.

        ``x_hat`` are random per-sample interpolates between real and fake images. The
        gradient norm is taken per sample over all non-batch dimensions.

        Args:
            real_images: batch of real images, batch dimension first.
            fake_images: batch of generated images with the same shape.

        Returns:
            Scalar penalty tensor, differentiable w.r.t. the critic parameters.
        """
        real_images = real_images.detach()
        fake_images = fake_images.detach()
        batch_size = real_images.size(0)  # the actual batch, which may be a partial last batch

        # One mixing coefficient per sample, broadcast over the remaining dimensions.
        eta = torch.rand(batch_size, *([1] * (real_images.dim() - 1)), device=real_images.device)
        interpolated = (eta * real_images + (1 - eta) * fake_images).requires_grad_(True)

        prob_interpolated = self.D(interpolated)

        gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                  grad_outputs=torch.ones_like(prob_interpolated),
                                  create_graph=True, retain_graph=True)[0]

        # Flatten so the norm is per sample, not per pixel along the channel axis.
        grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
        return ((grad_norm - 1) ** 2).mean() * self.lambda_term

    def save_model(self):
        """Overwrite the rolling generator/critic checkpoints, creating their directory if needed."""
        for ckpt in (generator_model_name, discriminator_model_name):
            ckpt_dir = os.path.dirname(ckpt)
            if ckpt_dir:
                os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(self.G.state_dict(), generator_model_name)
        torch.save(self.D.state_dict(), discriminator_model_name)
        print(f'Models save to {generator_model_name} & {discriminator_model_name}')

In [None]:
# Build the trainer (loads the saved checkpoints if both exist, otherwise initialises
# fresh weights) and run the full training schedule; this cell blocks until every
# generator iteration has finished.
model = WGAN()
model.train()