In [None]:
# Mount Google Drive at /content/drive; the dataset and results paths used
# later in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import time
import functools
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
from torch.autograd import Variable
from torch.utils.data import DataLoader

# Compute device for the notebook: CUDA when a GPU is visible, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class UnetGenerator(nn.Module):
    """U-Net generator built by nesting ``UnetSkipConnectionBlock`` modules.

    The network is assembled from the innermost (bottleneck) block outwards:
    one innermost block, ``num_downs - 5`` intermediate ``ngf * 8`` blocks,
    three blocks that step the width down (``ngf*8 -> ngf*4 -> ngf*2 -> ngf``)
    and finally the outermost block. Each block contributes one stride-2
    convolution, so at least five downsampling steps are always built, even
    when ``num_downs`` is smaller than 5.

    Args:
        input_nc: Number of channels of the input tensor.
        output_nc: Number of channels of the generated tensor.
        num_downs: Total number of nested blocks (downsampling steps).
        ngf: Number of filters produced by the outermost convolution.
        norm_layer: Normalisation layer factory passed to every block.
        use_dropout: Enables dropout in the intermediate ``ngf * 8`` blocks
            only.
        output_function: Activation class for the final output. It is
            forwarded to the outermost block, which selects the final
            activation from it.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, output_function=nn.Sigmoid):
        super(UnetGenerator, self).__init__()
        # Bottleneck block: no submodule inside it.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, innermost=True)
        # Extra full-width blocks; range() is empty when num_downs <= 5.
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        # Gradually reduce the number of filters from ngf * 8 down to ngf.
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        # Outermost block. output_function used to be dropped here, so the
        # caller's choice of final activation never reached the block.
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, norm_layer=norm_layer,
                                             output_function=output_function)

        self.model = unet_block

    def forward(self, input):
        """Run ``input`` through the nested U-Net and return its output."""
        return self.model(input)

class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample -> submodule -> upsample, plus a skip link.

    Three layouts are built depending on the flags:

    * ``outermost``: ``downconv -> submodule -> ReLU -> upconv -> activation``;
      ``forward`` returns the result directly (no skip concatenation).
    * ``innermost``: ``LeakyReLU -> downconv -> ReLU -> upconv -> norm``
      (there is no submodule).
    * otherwise: ``LeakyReLU -> downconv -> norm -> submodule -> ReLU ->
      upconv -> norm`` with an optional trailing ``Dropout(0.5)``.

    Non-outermost blocks return ``cat([x, model(x)], dim=1)``, which is why
    the up-convolutions of enclosing blocks take ``inner_nc * 2`` channels.

    Args:
        outer_nc: Channels produced by this block's up-convolution.
        inner_nc: Channels produced by this block's down-convolution.
        input_nc: Channels of the block input; defaults to ``outer_nc``.
        submodule: The nested block placed between the down and up paths.
        outermost: Build the outermost layout.
        innermost: Build the innermost layout.
        norm_layer: Normalisation layer factory (class or ``functools.partial``).
        use_dropout: Append ``nn.Dropout(0.5)`` (intermediate layout only).
        output_function: Final activation for the outermost layout. May be an
            ``nn.Module`` class / zero-argument factory (e.g. ``nn.Tanh``), an
            already constructed module, or ``None`` for the default Sigmoid.

    Raises:
        TypeError: If ``output_function`` is neither ``None``, a module
            instance nor callable (outermost layout only).
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, output_function=nn.Sigmoid):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Convolution bias is enabled only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        # Layers are created in the same order as before so that seeded
        # initialisation stays reproducible.
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            # Previously anything other than nn.Tanh silently became Sigmoid.
            activation = self._build_output_activation(output_function)
            model = [downconv, submodule, uprelu, upconv, activation]
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            model = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            model = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    @staticmethod
    def _build_output_activation(output_function):
        """Turn ``output_function`` into a module instance for the output layer."""
        if output_function is None:
            return nn.Sigmoid()
        if isinstance(output_function, nn.Module):
            return output_function
        if callable(output_function):
            return output_function()
        raise TypeError(
            'output_function must be an nn.Module class, a module instance or None, got %r'
            % (output_function,))

    def forward(self, x):
        """Apply the block; non-outermost blocks concatenate the input as a skip."""
        if self.outermost:
            return self.model(x)
        # NOTE(review): the leading LeakyReLU is constructed with inplace=True,
        # so x appears to be already activated when it is concatenated here.
        return torch.cat([x, self.model(x)], 1)

In [None]:
class RevealNet(nn.Module):
    """Fully convolutional network that maps a 3-channel image to a 3-channel image.

    Five ``Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU`` stages
    widen and then narrow the feature maps (3 -> 64 -> 128 -> 256 -> 128 -> 64),
    followed by a final 3x3 convolution back to 3 channels and a Sigmoid.
    Spatial size is preserved because every convolution uses stride 1 with
    padding 1.
    """

    def __init__(self):
        super(RevealNet, self).__init__()
        channel_plan = [3, 64, 128, 256, 128, 64]
        stages = []
        for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.Conv2d(in_ch, out_ch, 3, 1, 1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(True))
        # Output head: project back to 3 channels and squash with Sigmoid.
        stages.append(nn.Conv2d(channel_plan[-1], 3, 3, 1, 1))
        stages.append(nn.Sigmoid())
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Return the network output for ``input``."""
        return self.main(input)

In [None]:
def weights_init(m):
    """Initialise convolution and batch-norm parameters; meant for ``module.apply``.

    Modules are matched by class name:

    * names containing ``'Conv'``: weight drawn from N(0, 0.02);
    * names containing ``'BatchNorm'``: weight drawn from N(1, 0.02) and bias
      set to 0.

    Modules that match by name but have no ``weight``/``bias`` parameter
    (for example ``BatchNorm2d(affine=False)`` or a container class whose name
    happens to contain ``'Conv'``) are left untouched instead of raising.

    Args:
        m: The module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
def print_log(log_info, log_path, console=True):
    """Emit a log message to the console and/or append it to a log file.

    Args:
        log_info: The message to log.
        log_path: File the message is appended to (one message per line).
            Pass ``None`` or an empty string to skip file logging. Missing
            parent directories are created.
        console: When true, the message is also printed to stdout.
    """
    if console:
        print(log_info)
    if log_path:
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Append so that successive calls build up one log file.
        with open(log_path, 'a') as log_file:
            log_file.write(str(log_info) + '\n')

In [None]:
def save_result_pic(this_batch_size, originalLabelv, ContainerImg, secretLabelv, RevSecImg, epoch, i, save_path):
    """Save a comparison grid of cover, container, secret and revealed images.

    The four batches are stacked in that order and written with
    ``nrow=this_batch_size``, so each batch occupies one row of the grid.
    The file is named ``ResultImage_epoch<epoch>_batch<i>.png`` inside
    ``save_path``.

    Unlike the earlier in-place ``resize_`` approach, the input tensors are
    not modified and their own channel/height/width are used instead of a
    hard-coded 3x256x256 shape.

    Args:
        this_batch_size: Number of images taken from each tensor.
        originalLabelv: Cover images, batch-first.
        ContainerImg: Container images produced by the hiding network.
        secretLabelv: Secret images.
        RevSecImg: Secret images recovered by the reveal network.
        epoch: Epoch index used in the file name.
        i: Batch index used in the file name.
        save_path: Directory the PNG is written to; created if missing.
    """
    def _frames(images):
        # Detached view of the first `this_batch_size` images; no mutation.
        return images.detach()[:this_batch_size]

    # Row order in the grid: cover, container, secret, revealed secret.
    resultImg = torch.cat(
        [_frames(originalLabelv), _frames(ContainerImg),
         _frames(secretLabelv), _frames(RevSecImg)], 0)

    os.makedirs(save_path, exist_ok=True)
    resultImgName = '%s/ResultImage_epoch%03d_batch%04d.png' % (save_path, epoch, i)
    vutils.save_image(resultImg, resultImgName, nrow=this_batch_size, padding=1, normalize=True)

In [None]:
class AverageMeter(object):
    """Tracks the most recent value and the running weighted average of a metric.

    Attributes (plain instance fields):
        val:   last value passed to ``update``.
        sum:   running sum of ``val * n``.
        count: running sum of the weights ``n``.
        avg:   ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average.

        A total weight of zero (e.g. the first update has ``n=0``) leaves the
        average at 0 instead of raising ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0

In [None]:
def train(train_loader, epoch, Hnet, Rnet, criterion, epochs):
    """Train the hiding and reveal networks jointly for one epoch.

    Every batch from ``train_loader`` is split in half: the first half are the
    cover images, the second half the secret images. The hiding network gets
    the two concatenated along the channel axis and produces a container
    image; the reveal network tries to recover the secret from the container.
    The optimised loss is ``errH + 0.75 * errR``.

    Besides its arguments this function reads the module-level names
    ``optimizerH``, ``optimizerR``, ``resultPicFrequency`` and ``trainpics``,
    which must be defined before it is called.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches; labels
            are ignored.
        epoch: Zero-based epoch index (used for logging and file names).
        Hnet: Hiding network.
        Rnet: Reveal network.
        criterion: Loss applied to (container, cover) and (revealed, secret).
        epochs: Total number of epochs, used for logging only.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    Hlosses = AverageMeter()
    Rlosses = AverageMeter()
    SumLosses = AverageMeter()

    Hnet.train()
    Rnet.train()

    # Put the inputs on the device the hiding network lives on, rather than
    # assuming CUDA whenever it is available.
    net_device = next(Hnet.parameters()).device

    start_time = time.time()
    for i, batch_data in enumerate(train_loader):
        data_time.update(time.time() - start_time)

        all_pics, _ = batch_data
        this_batch_size = all_pics.size(0) // 2
        if this_batch_size == 0:
            # A single-image batch cannot be split into a cover/secret pair.
            start_time = time.time()
            continue

        Hnet.zero_grad()
        Rnet.zero_grad()

        # First half -> cover images, second half -> secret images. With an
        # odd batch the last image is left unused.
        cover_img = all_pics[:this_batch_size].to(net_device)
        secret_img = all_pics[this_batch_size:this_batch_size * 2].to(net_device)
        concat_img = torch.cat([cover_img, secret_img], dim=1)

        container_img = Hnet(concat_img)
        errH = criterion(container_img, cover_img)
        Hlosses.update(errH.item(), this_batch_size)

        rev_secret_img = Rnet(container_img)
        errR = criterion(rev_secret_img, secret_img)
        Rlosses.update(errR.item(), this_batch_size)

        # The reveal loss is weighted by 0.75 in the joint objective.
        err_sum = errH + 0.75 * errR
        SumLosses.update(err_sum.item(), this_batch_size)

        # One backward pass through both networks, then step both optimisers.
        err_sum.backward()
        optimizerH.step()
        optimizerR.step()

        batch_time.update(time.time() - start_time)
        start_time = time.time()

        if i % 5 == 0:
            print('[%d/%d][%d/%d]\tHnet Loss: %.4f\tRnet Loss: %.4f\tOverall loss: %.4f \t datatime: %.4f \t batchtime: %.4f' % (
                epoch + 1, epochs, i, len(train_loader),
                Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val))

        if i % resultPicFrequency == 0:
            save_result_pic(this_batch_size, cover_img, container_img.detach(), secret_img,
                            rev_secret_img.detach(), epoch, i, trainpics)

    epoch_log = "----------------------------  one epoch time is %.4f  ----------------------------" % (
        batch_time.sum) + "\n"
    epoch_log = epoch_log + "Learning rates: Hnet = %.8f      Rnet = %.8f" % (
        optimizerH.param_groups[0]['lr'], optimizerR.param_groups[0]['lr']) + "\n"
    epoch_log = epoch_log + "Average Hnet Loss for the epoch=%.6f\t Average Rnet Loss for the epoch=%.6f\t Total Average loss for the epoch=%.6f" % (
        Hlosses.avg, Rlosses.avg, SumLosses.avg)
    print(epoch_log)

In [None]:
# Dataset location on the mounted Google Drive.
DATA_DIR = '/content/drive/My Drive'
# Joining with '' only appends a trailing separator: the ImageFolder root is
# DATA_DIR itself, so each sub-folder of the Drive root is treated as a class.
traindir = os.path.join(DATA_DIR, '')
train_dataset = ImageFolder(
        traindir,
        transforms.Compose([
            transforms.Resize([256, 256]),  
            transforms.ToTensor(),
           ]))

print(len(train_dataset))
# train() saves a comparison image every `resultPicFrequency` batches.
resultPicFrequency = 50
# Folder that train() writes the comparison images into.
# NOTE(review): this folder sits inside DATA_DIR, so on a re-run ImageFolder may
# pick up previously saved result PNGs as training images — confirm the layout.
trainpics = '/content/drive/My Drive/results'

49110


In [None]:
# Hiding network: takes cover + secret stacked on the channel axis (6 channels)
# and produces a 3-channel container image.
Hnet = UnetGenerator(input_nc=6, output_nc=3, num_downs=7, output_function=nn.Sigmoid)
# Use the notebook-wide `device` instead of an unconditional .cuda(), which
# raises on a runtime without a GPU.
Hnet = Hnet.to(device)
Hnet.apply(weights_init)

# Reveal network: recovers the 3-channel secret from the container image.
Rnet = RevealNet()
Rnet = Rnet.to(device)
Rnet.apply(weights_init)

# Mean squared error is used for both the container and the revealed secret.
criterion = nn.MSELoss().to(device)

In [None]:
# One Adam optimiser per network; train() reads these two module-level names.
optimizerH = optim.Adam(Hnet.parameters(), lr=0.001, betas=(0.5, 0.999))
optimizerR = optim.Adam(Rnet.parameters(), lr=0.001, betas=(0.5, 0.999))
num_epochs = 3
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=8)

# Make sure the checkpoint folder exists before a whole epoch is spent training.
checkpoint_dir = '/content/drive/My Drive/results'
os.makedirs(checkpoint_dir, exist_ok=True)

print("------------------------- training is beginning -------------------------")
for epoch in range(num_epochs):
    print("Epoch - ", epoch + 1)
    train(train_loader, epoch, Hnet=Hnet, Rnet=Rnet, criterion=criterion, epochs=num_epochs)
    # Hiding-network checkpoint (file name unchanged).
    torch.save(Hnet.state_dict(), os.path.join(checkpoint_dir, 'epoch{}.pth'.format(epoch+1)))
    # Also checkpoint the reveal network: the two are trained jointly, so an
    # interrupted session would otherwise leave Hnet weights with no matching Rnet.
    torch.save(Rnet.state_dict(), os.path.join(checkpoint_dir, 'Rnet_epoch{}.pth'.format(epoch+1)))

------------------------- training is beginning -------------------------
Epoch -  1
[1/3][0/1535]	Hnet Loss: 0.0843	Rnet Loss: 0.0894	Overall loss: 0.1513 	 datatime: 13.2054 	 batchtime: 13.6508
[1/3][5/1535]	Hnet Loss: 0.0360	Rnet Loss: 0.0333	Overall loss: 0.0610 	 datatime: 0.0019 	 batchtime: 0.9849
[1/3][10/1535]	Hnet Loss: 0.0185	Rnet Loss: 0.0180	Overall loss: 0.0321 	 datatime: 0.0001 	 batchtime: 0.9908
[1/3][15/1535]	Hnet Loss: 0.0193	Rnet Loss: 0.0207	Overall loss: 0.0348 	 datatime: 0.0001 	 batchtime: 0.9869
[1/3][20/1535]	Hnet Loss: 0.0260	Rnet Loss: 0.0258	Overall loss: 0.0453 	 datatime: 0.0001 	 batchtime: 0.9888
[1/3][25/1535]	Hnet Loss: 0.0181	Rnet Loss: 0.0250	Overall loss: 0.0369 	 datatime: 0.0001 	 batchtime: 0.9929
[1/3][30/1535]	Hnet Loss: 0.0135	Rnet Loss: 0.0295	Overall loss: 0.0356 	 datatime: 0.0001 	 batchtime: 0.9932
[1/3][35/1535]	Hnet Loss: 0.0160	Rnet Loss: 0.0175	Overall loss: 0.0291 	 datatime: 0.0037 	 batchtime: 0.9979
[1/3][40/1535]	Hnet Loss: 0

In [None]:
# Persist the final weights (state_dict only) of both networks to the results
# folder on Google Drive.
torch.save(Hnet.state_dict(), '/content/drive/My Drive/results/Hnet-Final.pth')
torch.save(Rnet.state_dict(), '/content/drive/My Drive/results/Rnet-Final.pth')