In [1]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-5371946b-bc7a-8404-76f2-94eebc42c1b7)


In [2]:
!tree -d ../../images

[01;34m../../images/png[0m
├── [01;34mCollection1[0m
│   └── [01;34mBAYC[0m
└── [01;34mCollection2[0m
    └── [01;34mEAPES[0m

4 directories


In [3]:
!pip3 install numpy -q
!pip3 install torch -q
!pip3 install torchvision -q
!pip3 install matplotlib -q
!pip3 install tqdm -q
!pip3 install ipywidgets -q
!pip3 install opencv-python -q

In [5]:
import os
import torch
import random
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import pytorch_ssim
from torchvision.models.vgg import vgg16
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import h5py

In [None]:
# Reproducibility: draw one seed per run and hand it to both Python's `random`
# module and PyTorch. The value is printed so a run can be repeated later by
# assigning it directly (e.g. manualSeed = 42) instead of drawing a new one.
manualSeed = random.randint(1, 10000)
torch.manual_seed(manualSeed)
random.seed(manualSeed)
print("Random Seed: ", manualSeed)

In [None]:
class Read_dataset_h5(data.Dataset):
    """Map-style dataset over the ``input`` and ``label`` datasets of an HDF5 file.

    Item ``i`` is the pair ``(input[i], label[i])``, each converted to a float
    tensor. Both datasets are indexed with four axes, so they must be 4-D.
    """

    def __init__(self, file_path):
        """Open ``file_path`` read-only and bind its ``input`` / ``label`` datasets.

        Args:
            file_path: path of the HDF5 file to read.

        Raises:
            KeyError: if the file has no ``input`` or no ``label`` dataset.
        """
        super(Read_dataset_h5, self).__init__()
        # Open explicitly read-only: this class never writes to the file, and
        # the behavior no longer depends on the library's default open mode.
        hf = h5py.File(file_path, 'r')
        for name in ('input', 'label'):
            # `.get()` would silently hand back None for a missing dataset and
            # only fail later in __len__/__getitem__; fail here with a clear message.
            if name not in hf:
                hf.close()
                raise KeyError("dataset '{}' not found in '{}'".format(name, file_path))
        self.input = hf['input']
        self.label = hf['label']

    def __getitem__(self, index):
        """Return sample ``index`` as an ``(input, label)`` pair of float tensors."""
        sample = torch.from_numpy(self.input[index, :, :, :]).float()
        target = torch.from_numpy(self.label[index, :, :, :]).float()
        return sample, target

    def __len__(self):
        """Number of samples, taken from the first axis of ``input``."""
        return self.input.shape[0]

In [None]:
class ResidualBlock(nn.Module):
    """Conv -> BatchNorm -> PReLU -> Conv -> BatchNorm branch with an identity skip.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so the
    output has the same shape as the input.
    """

    def __init__(self, channels):
        """Build the branch for feature maps with ``channels`` channels."""
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch applied to ``x``."""
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return x + branch

In [6]:
class G_Net(nn.Module):
    """Generator: 3-channel image in, 3-channel image at 4x the resolution out.

    Layout: a 9x9 input convolution, five residual blocks plus a conv/BN stage
    wrapped by a long skip connection, two PixelShuffle(2) upsampling stages,
    and a 9x9 output convolution followed by Tanh rescaled to [0, 1].
    """

    def __init__(self):
        super().__init__()
        width = 64  # feature channels used throughout the trunk
        scale = 2   # upscaling factor of each PixelShuffle stage

        def upsample_stage():
            # The convolution multiplies the channels by scale**2 and
            # PixelShuffle rearranges them into a scale-times larger map.
            return nn.Sequential(
                nn.Conv2d(width, width * scale ** 2, kernel_size=3, padding=1),
                nn.PixelShuffle(scale),
                nn.PReLU(),
            )

        self.input = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=9, padding=4),
            nn.PReLU(),
        )
        self.ResidualBlock1 = ResidualBlock(width)
        self.ResidualBlock2 = ResidualBlock(width)
        self.ResidualBlock3 = ResidualBlock(width)
        self.ResidualBlock4 = ResidualBlock(width)
        self.ResidualBlock5 = ResidualBlock(width)
        self.output_residual = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, padding=1),
            nn.BatchNorm2d(width),
        )
        self.pixel_shuffle = upsample_stage()
        self.pixel_shuffle2 = upsample_stage()
        self.output = nn.Sequential(
            nn.Conv2d(width, 3, kernel_size=9, padding=4),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the generator on ``x`` and return an image with values in [0, 1]."""
        shallow = self.input(x)

        # Residual trunk, applied in order.
        deep = shallow
        for block in (self.ResidualBlock1, self.ResidualBlock2, self.ResidualBlock3,
                      self.ResidualBlock4, self.ResidualBlock5):
            deep = block(deep)
        deep = self.output_residual(deep)

        # Long skip connection around the trunk, then the two upsampling stages.
        upscaled = self.pixel_shuffle2(self.pixel_shuffle(deep + shallow))

        # Tanh yields values in [-1, 1]; shift and scale them into [0, 1].
        return (self.output(upscaled) + 1) / 2

In [None]:
class D_Net(nn.Module):
    """Discriminator: maps a batch of 3-channel images to one probability each.

    Eight 3x3 convolutions (every second one with stride 2) widen the features
    from 64 to 512 channels; the result is globally average-pooled to 1x1 and
    passed through two 1x1 convolutions and a sigmoid.
    """

    def __init__(self):
        super(D_Net, self).__init__()
        self.Net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            # Global pooling makes the head independent of the input resolution.
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a 1-D tensor with one sigmoid score per image in the batch ``x``."""
        x = self.Net(x)
        # The head ends with one channel at 1x1 resolution, so flattening leaves
        # exactly one value per sample. Unlike squeeze(), this keeps the batch
        # axis when the batch holds a single image, which the loss functions
        # rely on (they read the batch size from dimension 0).
        return x.flatten()

In [None]:
class GeneratorLoss(nn.Module):
    """Generator objective: pixel MSE + 6e-3 * VGG-feature MSE + 1e-3 * adversarial BCE."""

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        # Frozen feature extractor built from the first 31 layers of the
        # pretrained VGG16 `features` stack; it is only used to compare images.
        # NOTE(review): `pretrained=True` is the legacy torchvision flag; newer
        # releases prefer a `weights=` argument — confirm against the installed version.
        vgg = vgg16(pretrained=True)
        vgg_loss = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in vgg_loss.parameters():
            param.requires_grad = False
        self.vgg_loss = vgg_loss
        self.mse_loss = nn.MSELoss()
        self.cross_entropy = nn.BCELoss()

    def forward(self, fake_rate, SR, HR):
        """Compute the combined generator loss.

        Args:
            fake_rate: discriminator probabilities for the generated images.
            SR: generated (super-resolved) images.
            HR: ground-truth high-resolution images.

        Returns:
            Scalar tensor ``MSE + 6e-3 * VGG + 1e-3 * adversarial``.
        """
        # MSE Loss
        MSE_loss = self.mse_loss(SR, HR)
        # VGG Loss
        VGG_loss = self.mse_loss(self.vgg_loss(SR), self.vgg_loss(HR))
        # Adversarial Loss: the generator wants its images to be rated as real (1).
        # ones_like places the target on the same device (and with the same shape
        # and dtype) as the scores, so the loss works on CPU and on any GPU.
        real_target = torch.ones_like(fake_rate)
        Adversarial_loss = self.cross_entropy(fake_rate, real_target)
        return MSE_loss + 6e-3 * VGG_loss + 1e-3 * Adversarial_loss

In [None]:
class DiscriminatorLoss(nn.Module):
    """Discriminator objective: BCE of fake scores against 0 plus BCE of real scores against 1."""

    def __init__(self):
        super(DiscriminatorLoss, self).__init__()
        self.cross_entropy = nn.BCELoss()

    def forward(self, fake_rate, real_rate):
        """Return the summed binary cross-entropy for generated and real images.

        Args:
            fake_rate: discriminator probabilities for generated images.
            real_rate: discriminator probabilities for real images.
        """
        # Targets are created with zeros_like / ones_like so they always live on
        # the same device as the scores (CPU or any GPU) and match their shape.
        # Fake_img Correct Rate
        Fake_img_CR = self.cross_entropy(fake_rate, torch.zeros_like(fake_rate))
        # Real_img Correct Rate
        Real_img_CR = self.cross_entropy(real_rate, torch.ones_like(real_rate))
        return Fake_img_CR + Real_img_CR

In [None]:
# ---------------------------------------------------------------------------
# SRGAN training driver: read the options, load the data, build (or restore)
# the networks, then alternate discriminator and generator updates and write
# one checkpoint per epoch.
# ---------------------------------------------------------------------------
epoch = 1  # first epoch to run; replaced when resuming from a checkpoint
# NOTE(review): `parser` is not defined in any cell visible here. It has to
# exist (providing the cuda, gpus, threads, batchSize, pretrained and nEpochs
# options) before this cell runs — confirm where it comes from.
opt = parser.parse_args() # opt < parser
print(opt)

print("===> Setting GPU")
cuda = opt.cuda
if cuda:
    print("=> use gpu id: '{}'".format(opt.gpus))
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus # set gpu
    if not torch.cuda.is_available():
        raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

opt.seed = random.randint(1, 10000)
print("Random Seed: ", opt.seed)
torch.manual_seed(opt.seed) # set seed
if cuda:
    torch.cuda.manual_seed(opt.seed)

cudnn.benchmark = True # find optimal algorithms for hardware

print("===> Loading datasets")
train_set = Read_dataset_h5("data/train.h5")
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads,
    batch_size=opt.batchSize, shuffle=True) # read to DataLoader

print("===> Building model")
# The names G_Net / D_Net are rebound from the classes to instances here.
# Resolve the class first so that re-running this cell builds fresh networks
# instead of trying to call the instances left over from the previous run.
G_Net = (G_Net if isinstance(G_Net, type) else type(G_Net))()
D_Net = (D_Net if isinstance(D_Net, type) else type(D_Net))()
G_Loss = GeneratorLoss()
D_Loss = DiscriminatorLoss()

# optionally copy weights from a checkpoint
checkpoint = None
if opt.pretrained:
    if os.path.isfile(opt.pretrained):
        print("=> loading model '{}'".format(opt.pretrained))
        # Load onto the CPU first so a checkpoint written on a GPU can also be
        # restored on a machine without one; the networks are moved below.
        checkpoint = torch.load(opt.pretrained, map_location="cpu")
        G_Net.load_state_dict(checkpoint['G_Net_state_dict'])
        D_Net.load_state_dict(checkpoint['D_Net_state_dict'])
        epoch = checkpoint['epoch'] + 1 # load model
    else:
        print("=> no model found at '{}'".format(opt.pretrained))

if cuda:
    G_Net = G_Net.cuda()
    D_Net = D_Net.cuda()
    G_Loss = G_Loss.cuda()
    D_Loss = D_Loss.cuda() # set model&loss for use gpu

print("===> Setting Optimizer")
# Optimizers are created after the networks are on their final device.
G_optim = optim.Adam(G_Net.parameters())
D_optim = optim.Adam(D_Net.parameters())

if checkpoint is not None:
    G_optim.load_state_dict(checkpoint['G_optim_state_dict'])
    D_optim.load_state_dict(checkpoint['D_optim_state_dict'])
    print("===> Setting Pretrained Optimizer")

print("=> start epoch '{}'".format(epoch))
print("===> Training")
for epoch_ in range(epoch, opt.nEpochs + 1):
    print("===>  Start epoch {} #################################################################".format(epoch_))
    G_Net.train()
    D_Net.train()
    for iteration, (lr_batch, hr_batch) in enumerate(training_data_loader):
        # Both batches are divided by 255 before use.
        HR = hr_batch / 255
        LR = lr_batch / 255
        # Follow the `cuda` option (the same switch that placed the networks)
        # so data and models always end up on the same device.
        if cuda:
            HR = HR.cuda()
            LR = LR.cuda()
        fake_img = G_Net(LR)

        # Train Discriminator model
        # The generated batch is detached: this step must update only the
        # discriminator and must not consume the generator's autograd graph.
        D_Net.zero_grad()
        real_rate = D_Net(HR)
        fake_rate = D_Net(fake_img.detach())
        d_loss = D_Loss(fake_rate, real_rate)
        d_loss.backward()
        D_optim.step()

        # Train Generator model
        # Score the generated batch again with the just-updated discriminator;
        # the graph used for d_loss has already been freed by its backward().
        G_Net.zero_grad()
        fake_rate = D_Net(fake_img)
        g_loss = G_Loss(fake_rate, fake_img, HR)
        g_loss.backward()
        G_optim.step()

        # Print the losses (and SSIM) every 10 iterations.
        if iteration % 10 == 0:
            with torch.no_grad():  # monitoring only, no gradients needed
                ssim_value = pytorch_ssim.ssim(HR, fake_img)
            print("===> Epoch[[{}]({}/{})]: D_Loss : {:.10f}, G_Loss : {:.10f}, SSIM : {:.10f}".format(epoch_, iteration, len(training_data_loader), d_loss.item(), g_loss.item(), ssim_value))

    # Save networks and optimizers so training can resume from this epoch.
    os.makedirs("checkpoint/", exist_ok=True)
    model_out_path = "checkpoint/" + "SRGAN_Adam_epoch_{}.tar".format(epoch_)
    torch.save({
            'epoch': epoch_,
            'G_Net_state_dict': G_Net.state_dict(),
            'D_Net_state_dict': D_Net.state_dict(),
            'G_optim_state_dict': G_optim.state_dict(),
            'D_optim_state_dict': D_optim.state_dict()
            }, model_out_path)
    print("Checkpoint has been saved to the {}".format(model_out_path))
