In [1]:
!nvidia-smi

Thu Feb 24 19:34:41 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| 30%   30C    P8    27W / 350W |      0MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import torch
from torch import nn
import os
import cv2 as cv
import os.path as path
from torch.nn import functional as F
from torch.utils.data import Dataset
from torchvision import transforms
import time

In [4]:
class ImageDataset(Dataset):
    """Paired (input, label) image dataset read from two directories.

    The i-th file of ``origin_root`` is paired with the i-th file of
    ``train_root``. Both listings are sorted, so pairing relies on the two
    directories sorting into corresponding order (e.g. identical file names).

    Each item is a pair of float tensors of shape (1, 678, 384). Each tensor is
    the first channel of the image as returned by ``cv.imread`` and scaled to
    [0, 1] by ``ToTensor``.
    """

    def __init__(self, origin_root, train_root):
        super().__init__()
        self.origin_root = origin_root
        self.train_root = train_root
        # os.listdir returns entries in arbitrary order; sort both listings so
        # index i refers to corresponding input/label files.
        self.origin_files = sorted(os.listdir(self.origin_root))
        self.train_files = sorted(os.listdir(self.train_root))

    def __len__(self):
        return len(self.origin_files)

    @staticmethod
    def _load_first_channel(file_path):
        """Load ``file_path`` and return its first channel as a (1, 678, 384) tensor.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (imread returns None).
        """
        image = cv.imread(file_path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        # ToTensor turns an H x W x C uint8 array into C x H x W floats in [0, 1];
        # [0] keeps only the first channel.
        # NOTE(review): reshape (not transpose) assumes the image is 678 rows x
        # 384 columns; any other layout with 678*384 pixels is silently scrambled.
        return transforms.ToTensor()(image)[0].reshape((1, 678, 384))

    def __getitem__(self, index):
        origin_data = self._load_first_channel(
            path.join(self.origin_root, self.origin_files[index]))
        train_data = self._load_first_channel(
            path.join(self.train_root, self.train_files[index]))
        return origin_data, train_data

In [5]:
class ResNetBlock(nn.Module):
    """Residual block: relu(conv2(relu(conv1(x))) + shortcut(x)).

    conv1 is a 3x3 conv with the given stride (stride 2 halves H and W), and
    conv2 is a 3x3 stride-1 conv. The shortcut is the identity when the output
    has the same shape as the input. Otherwise the shortcut is a 1x1 projection
    conv with the same stride.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        stride: stride of conv1 and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Always created (even when the shortcut is the identity) so that
        # state_dict keys stay compatible with existing checkpoints.
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        # A projection is needed whenever channels OR spatial size change;
        # checking only channels would crash on stride-2 blocks with equal channels.
        self.use_proj = in_channels != out_channels or stride != 1

    def forward(self, x):
        residual = self.proj(x) if self.use_proj else x
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out += residual
        return F.relu(out)


class Generator(nn.Module):
    """Image-to-image generator that predicts a correction added to its input.

    Encoder: a 7x7 conv, then a strided 3x3 conv, then four ResNetBlocks (two of
    them strided). Decoder: three 2x transposed convs and a final 7x7 transposed
    conv back to one channel, followed by tanh. ``forward`` returns
    ``tanh(...) + inputs``.

    The paddings are chosen for a (1) x 678 x 384 input; the shape comments
    below trace that size through the network.
    """

    def __init__(self):
        super().__init__()

        def norm_relu(channels):
            # Affine InstanceNorm followed by an in-place ReLU.
            return [nn.InstanceNorm2d(channels, affine=True), nn.ReLU(True)]

        def upsample(in_ch, out_ch):
            # k=3, s=2, p=1, output_padding=1 maps H -> 2H and W -> 2W.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
                *norm_relu(out_ch),
            ]

        layers = [
            # input is (1) x 678 x 384
            nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3),
            *norm_relu(64),
            # (64) x 678 x 384
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            *norm_relu(128),
            # (128) x 339 x 192
            ResNetBlock(128, 128),
            ResNetBlock(128, 256, stride=2),
            # (256) x 170 x 96
            ResNetBlock(256, 256),
            ResNetBlock(256, 512, stride=2),
            # (512) x 85 x 48
            *upsample(512, 256),
            # (256) x 170 x 96
            *upsample(256, 128),
            # (128) x 340 x 192
            *upsample(128, 64),
            # (64) x 680 x 384; asymmetric padding trims 2 rows back to 678
            nn.ConvTranspose2d(64, 1, kernel_size=7, stride=1, padding=(4, 3)),
            nn.Tanh(),
            # output is (1) x 678 x 384
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, inputs):
        correction = self.model(inputs)
        return correction + inputs


    
class StandardDiscriminator(nn.Module):
    """Scores a single-channel image with one probability per sample.

    The network has four strided Conv -> InstanceNorm -> LeakyReLU stages. A
    3x3 conv then reduces the map to one channel, and Flatten -> Linear ->
    Sigmoid turns it into an (N, 1) output in [0, 1].

    For the default (1) x 678 x 384 input the feature maps are:
    (64) x 339 x 192 -> (128) x 169 x 95 -> (256) x 84 x 47 -> (512) x 41 x 23
    -> (1) x 41 x 23 -> Linear(41 * 23, 1).

    Args:
        input_size: (height, width) of the images this discriminator will be
            fed. It fixes the in_features of the final Linear layer, so it must
            match the real input size. Defaults to (678, 384).

    Raises:
        ValueError: if input_size is too small to survive the strided convs.
    """

    # (in_channels, out_channels, kernel_size, stride, padding) per downsampling stage.
    _STAGES = (
        (1, 64, 7, 2, 3),
        (64, 128, 3, 2, 0),
        (128, 256, 3, 2, 0),
        (256, 512, 3, 2, 0),
    )

    def __init__(self, input_size=(678, 384)):
        super(StandardDiscriminator, self).__init__()
        feat_h, feat_w = self._feature_map_size(input_size)

        layers = []
        for in_ch, out_ch, kernel, stride, padding in self._STAGES:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=padding),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [
            # 3x3 conv with padding 1 keeps H x W and collapses to one channel.
            nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1),
            nn.Flatten(),
            nn.Linear(feat_h * feat_w, 1),
            nn.Sigmoid(),
        ]
        # The layer order fixes the state_dict keys (model.0, model.3, ...);
        # keep it stable so saved checkpoints still load.
        self.model = nn.Sequential(*layers)

    @classmethod
    def _feature_map_size(cls, input_size):
        """Return (height, width) of the feature map after all strided stages."""
        height, width = input_size
        for _, _, kernel, stride, padding in cls._STAGES:
            # Standard Conv2d output-size formula (dilation 1).
            height = (height + 2 * padding - kernel) // stride + 1
            width = (width + 2 * padding - kernel) // stride + 1
            if height < 1 or width < 1:
                raise ValueError(
                    f'input_size {tuple(input_size)} is too small for StandardDiscriminator')
        return height, width

    def forward(self, x):
        x = self.model(x)
        return x


In [6]:
# Checkpoint paths: StandarnGAN loads them at construction when both exist
# and overwrites them at the end of every training iteration.
generator_model_name, discriminator_model_name = (
    'sd-gan/standard-generatorg5d1.pth',
    'sd-gan/standard-discriminatorg5d1.pth',
)
# Offset added to the logged epoch numbers and periodic checkpoint names
# (set it to the number of epochs already trained when resuming).
start_point = 0


class StandarnGAN(nn.Module):
    """Owns a Generator/StandardDiscriminator pair and trains them adversarially.

    On construction it:
    - builds both networks on CUDA if available (else CPU);
    - creates one Adam optimizer per network;
    - wraps ``ImageDataset`` in an endlessly cycling batch stream;
    - loads checkpoints from ``generator_model_name`` /
      ``discriminator_model_name`` when both exist, and otherwise applies
      ``weights_init``.

    Losses are nn.BCELoss on the discriminator's sigmoid outputs.

    Call ``train()`` with no arguments to run the training loop.
    ``train(mode)`` / ``eval()`` keep their usual nn.Module meaning.
    """

    def __init__(self):
        super(StandarnGAN, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.G = Generator().to(self.device)
        self.D = StandardDiscriminator().to(self.device)
        self.batch_size = 8
        # Number of outer iterations (logged as "Epoch"). Each one is
        # critic_iter discriminator steps then gen_iters generator steps, each
        # on a fresh batch -- not a full pass over the dataset.
        self.generator_iter = 500
        self.critic_iter = 1
        self.gen_iters = 5
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=1e-4, betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=1e-4, betas=(0.5, 0.999))

        dataset = ImageDataset(origin_root='dataset/train/images',
                               train_root='dataset/train/labels')
        self.data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        self.data = self.get_infinite_batches()

        self.loss_func = nn.BCELoss()

        self.init_model()

    def train(self, mode=None):
        """Run the training loop, or set train/eval mode when ``mode`` is given.

        nn.Module.eval() calls ``self.train(False)``, and parent modules call
        ``train(mode)`` on their children. Overriding ``train`` with only a
        training loop would make ``eval()`` start training. An explicit
        ``mode`` is therefore forwarded to nn.Module, and ``train()`` with no
        arguments still runs the loop as before.

        Returns:
            ``self`` when ``mode`` is given (nn.Module contract), else None.
        """
        if mode is not None:
            return super().train(mode)
        self.fit()

    def fit(self):
        """Train for ``generator_iter`` iterations, saving checkpoints after each."""
        for epoch in range(self.generator_iter):
            print(f'Epoch: {epoch + 1 + start_point} / {self.generator_iter + start_point} ============================')
            begin = time.time()

            d_loss, d_out_real, d_out_fake = self._train_discriminator()
            g_loss, g_out_fake = self._train_generator()

            print(f'D loss: {d_loss.item()}, G loss: {g_loss.item()}, cost time: {time.time() - begin}')
            # The logged "d_loss_real" etc. are mean discriminator OUTPUTS
            # (probabilities), not losses; labels kept for log compatibility.
            print(f'd_loss_real: {d_out_real.mean().item()}, d_loss_fake: {d_out_fake.mean().item()}, g_loss_fake: {g_out_fake.mean().item()}')
            self.save_model()

            if (epoch + 1) % 50 == 0:
                # Periodic snapshots; save_model() above already created the directory.
                torch.save(self.G.state_dict(), f'sd-gan/sd-generatorg5d1-{epoch + 1 + start_point}.pth')
                torch.save(self.D.state_dict(), f'sd-gan/sd-discriminatorg5d1-{epoch + 1 + start_point}.pth')

    def _train_discriminator(self):
        """Run ``critic_iter`` D updates; return the last (loss, D(real), D(fake))."""
        # Train the discriminator.
        self.D.requires_grad_(True)
        d_loss = d_out_real = d_out_fake = None
        for _ in range(self.critic_iter):
            self.D.zero_grad()
            inputs, reals = next(self.data)

            d_out_real = self.D(reals)
            loss_real = self.loss_func(d_out_real, torch.ones_like(d_out_real))

            # The D update must not backpropagate into G; no_grad skips building
            # G's graph (previously those gradients were computed and discarded).
            with torch.no_grad():
                fakes = self.G(inputs)
            d_out_fake = self.D(fakes)
            loss_fake = self.loss_func(d_out_fake, torch.zeros_like(d_out_fake))

            d_loss = (loss_real + loss_fake) / 2
            d_loss.backward()
            self.d_optimizer.step()
        return d_loss, d_out_real, d_out_fake

    def _train_generator(self):
        """Run ``gen_iters`` G updates against a frozen D; return the last (loss, D(fake))."""
        # Freeze the discriminator so the generator loss only produces gradients for G.
        self.D.requires_grad_(False)
        g_loss = g_out_fake = None
        for _ in range(self.gen_iters):
            self.G.zero_grad()
            inputs, _unused_reals = next(self.data)
            fakes = self.G(inputs)
            g_out_fake = self.D(fakes)
            g_loss = self.loss_func(g_out_fake, torch.ones_like(g_out_fake))
            g_loss.backward()
            self.g_optimizer.step()
        return g_loss, g_out_fake

    def save_model(self):
        """Write both state_dicts to their checkpoint paths, creating directories if needed."""
        for file_name in (generator_model_name, discriminator_model_name):
            directory = os.path.dirname(file_name)
            if directory:
                os.makedirs(directory, exist_ok=True)
        torch.save(self.G.state_dict(), generator_model_name)
        torch.save(self.D.state_dict(), discriminator_model_name)
        print(f'Models save to {generator_model_name} & {discriminator_model_name} ')

    def init_model(self):
        """Load both checkpoints if both exist; otherwise apply ``weights_init``."""
        if os.path.exists(generator_model_name) and os.path.exists(discriminator_model_name):
            # map_location lets checkpoints saved on GPU load on a CPU-only machine.
            # NOTE: torch.load unpickles; only load trusted checkpoint files.
            self.G.load_state_dict(torch.load(generator_model_name, map_location=self.device))
            self.D.load_state_dict(torch.load(discriminator_model_name, map_location=self.device))
            print(f'Models load from {generator_model_name} & {discriminator_model_name}')
        else:
            print('No trained_models found, init new trained_models')
            self.G.apply(self.weights_init)
            self.D.apply(self.weights_init)

    def weights_init(self, m):
        """Xavier-normal weights and zero bias for modules whose class name contains 'Conv'.

        Other modules (InstanceNorm, Linear, ...) are left untouched.
        """
        if 'Conv' not in m.__class__.__name__:
            return
        weight = getattr(m, 'weight', None)
        if weight is None:
            return
        nn.init.xavier_normal_(weight.data)
        # Convs built with bias=False have m.bias set to None.
        if m.bias is not None:
            m.bias.data.fill_(0)

    def get_infinite_batches(self):
        """Yield (inputs, reals) batches on ``self.device``, cycling the loader forever.

        Raises:
            RuntimeError: if a full pass over the loader yields no batches
                (otherwise this generator would spin forever).
        """
        while True:
            produced = False
            for inputs, reals in self.data_loader:
                produced = True
                yield inputs.to(self.device), reals.to(self.device)
            if not produced:
                raise RuntimeError('data_loader yielded no batches; is the dataset empty?')


In [None]:
# Build the GAN: this constructs G/D and the data loader, and loads the
# checkpoints named by generator_model_name / discriminator_model_name when
# both exist (otherwise initializes fresh weights).
# Calling train() with no arguments runs StandarnGAN's own training loop. It is
# long-running and saves checkpoints after every iteration.
model = StandarnGAN()
model.train()