<a href="https://colab.research.google.com/github/Eereenah/deep-steganography/blob/master/SSIM_Image_Steganography.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Imports

In [0]:
import functools

import argparse
import os
import shutil
import socket
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
#from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

In [0]:
import torchvision
import torchvision.transforms as transforms

In [0]:
import math
import random
from PIL import Image, ImageOps, ImageEnhance

In [0]:
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings

In [0]:
! pip install pytorch-msssim

In [0]:
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

### Data Preparation

In [0]:
! wget images.cocodataset.org/zips/val2017.zip


In [0]:
! unzip -qq *.zip

In [0]:
! rm *.zip

In [0]:
! wget images.cocodataset.org/zips/test2017.zip

In [0]:
! unzip -qq *.zip

In [0]:
! rm *.zip

In [0]:
! mkdir 'val2017/cover'
! mkdir 'val2017/secret'

In [0]:
! find 'val2017/' -maxdepth 1 -type f -printf "." | wc -c

In [0]:
! ls 'val2017/'* | head -480 | xargs -I{} mv {} 'val2017/cover/'

In [0]:
! ls 'val2017/'* | head -480| xargs -I{} mv {} 'val2017/secret/'

In [0]:
! mkdir 'train2017'
! mkdir 'train2017/cover'
! mkdir 'train2017/secret'

In [0]:
! find 'test2017/' -maxdepth 1 -type f -printf "." | wc -c

In [0]:
! ls 'test2017/'* | head -4800 | xargs -I{} mv {} 'train2017/cover/'

In [0]:
! ls 'test2017/'* | head -4800 | xargs -I{} mv {} 'train2017/secret/'

In [0]:
! mkdir 'test2017/cover'
! mkdir 'test2017/secret'

In [0]:
! ls 'test2017/'* | head -480 | xargs -I{} mv {} 'test2017/cover/'

In [0]:
! ls 'test2017/'* | head -4800 | xargs -I{} mv {} 'test2017/secret/'

In [0]:
# Mount Google Drive at /content/drive. The relative 'drive/My Drive/...' paths
# used below for logs, result images and checkpoints rely on this mount
# (assuming the working directory is /content, Colab's default — confirm).
from google.colab import drive
drive.mount('/content/drive')

### Hiding Network

In [0]:
class UnetGenerator(nn.Module):
    """U-Net hiding network built from nested ``UnetSkipConnectionBlock``s.

    ``main`` uses it with ``input_nc=6`` (cover and secret stacked on the
    channel axis) and ``output_nc=3`` (the container image).

    Args:
        input_nc: channels of the network input.
        output_nc: channels of the network output.
        num_downs: number of stride-2 down-sampling levels. Values below 5
            still yield 5 levels, because the fixed outer four plus the
            innermost block are always built. The input's spatial size should
            be divisible by ``2 ** num_downs`` so the skip concatenations line up.
        ngf: base channel width; deeper levels use multiples of it (up to 8x).
        norm_layer: normalisation layer class (or ``functools.partial``).
        use_dropout: add ``Dropout(0.5)`` in the extra ``ngf * 8`` levels.
        output_function: ``nn.Tanh`` or ``nn.Sigmoid`` for the final activation.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, output_function=nn.Sigmoid):
        super(UnetGenerator, self).__init__()
        # Assemble from the bottleneck outwards: every new block wraps the previous one.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)
        # Widths taper back down toward the input resolution.
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None,
                                            submodule=block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer,
                                             output_function=output_function)

    def forward(self, input):
        """Run the full U-Net on ``input`` and return the outermost block's output."""
        return self.model(input)

In [0]:
class UnetSkipConnectionBlock(nn.Module):
    """One level of the U-Net: down-sample, run ``submodule``, up-sample, add a skip.

    Blocks nest: every non-innermost block wraps ``submodule`` (the next,
    deeper block). Non-outermost blocks return ``torch.cat([x, f(x)], 1)``.
    That is why a parent's up-convolution takes ``inner_nc * 2`` channels.

    Args:
        outer_nc: channels produced by this block's up path.
        inner_nc: channels produced by the down convolution.
        input_nc: channels of the block input; defaults to ``outer_nc``.
        submodule: the nested deeper block (``None`` for the innermost block).
        outermost: no normalisation and no skip concat; the up path ends in
            the output activation.
        innermost: bottleneck level without a submodule.
        norm_layer: normalisation class, or a ``functools.partial`` of one.
            The convolutions get a bias only when it is ``nn.InstanceNorm2d``.
        use_dropout: append ``Dropout(0.5)`` (middle blocks only).
        output_function: ``nn.Tanh`` selects a Tanh output; any other value
            selects Sigmoid (outermost block only).
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False,
                 innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 output_function=nn.Sigmoid):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # isinstance (not an exact type comparison) so partial subclasses are recognised too.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        # NOTE: in-place. It is the first layer of every non-outermost block,
        # so it overwrites ``x``. The skip concatenation in forward() therefore
        # carries LeakyReLU(x), not x. Kept unchanged so previously trained
        # checkpoints behave identically.
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # No norm and no skip here; this upconv keeps the default bias=True.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            final_activation = nn.Tanh() if output_function == nn.Tanh else nn.Sigmoid()
            up = [uprelu, upconv, final_activation]
            model = down + [submodule] + up
        elif innermost:
            # No submodule, so the upconv only sees this level's inner_nc channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            # The submodule returns cat([x, f(x)]) => inner_nc * 2 channels.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            model = down + [submodule] + up
            if use_dropout:
                model = model + [nn.Dropout(0.5)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return ``model(x)`` for the outermost block, else ``cat([x, model(x)], 1)``."""
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)

### Revealing Network

In [0]:
class RevealNet(nn.Module):
    """Reveal network: a plain stack of 3x3 convolutions mapping a container back to a secret.

    Every convolution is 3x3 with stride 1 and padding 1, so the spatial size
    is preserved. ``main`` feeds it the 3-channel container images produced
    by the hiding network.

    Args:
        nc: channels of both input and output.
        nhf: base hidden width; the hidden widths run nhf, 2nhf, 4nhf, 2nhf, nhf.
        output_function: activation class instantiated after the last convolution.
    """

    def __init__(self, nc=3, nhf=64, output_function=nn.Sigmoid):
        super(RevealNet, self).__init__()
        widths = [nc, nhf, nhf * 2, nhf * 4, nhf * 2, nhf]
        layers = []
        # Conv -> BatchNorm -> ReLU for every consecutive pair of widths.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 3, 1, 1), nn.BatchNorm2d(c_out), nn.ReLU(True)]
        layers += [nn.Conv2d(nhf, nc, 3, 1, 1), output_function()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the revealed image for ``input``."""
        return self.main(input)

### Training

In [0]:
# Batch-level cadence used inside train().
opt_logFreq = 10   # print/append a log line every N training batches
opt_saveFreq = 5   # save a result-image grid every N training batches

# Optimisation settings.
opt_niter = 10     # number of epochs run by main()
opt_beta = 0.75    # weight of the reveal loss: err_sum = errH + opt_beta * errR

In [0]:
def weights_init(m):
    """Initialiser meant for ``module.apply(weights_init)``.

    Any module whose class name contains 'Conv' (including ConvTranspose)
    gets its weight drawn from N(0, 0.02); its bias is left untouched.
    Modules whose class name contains 'BatchNorm' get weight ~ N(1, 0.02) and
    a zero bias. All other modules are left unchanged.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [0]:
def train(train_loader, epoch, Hnet, Rnet, criterion):
    """Run one training epoch of the hiding (H) and reveal (R) networks.

    Each batch from ``train_loader`` is split in half. The first half are
    cover images and the second half secret images. H-net receives cover and
    secret concatenated on the channel axis and outputs a container image;
    R-net receives the container and reconstructs the secret. Both are trained
    jointly on ``errH + opt_beta * errR``, where each term is the negated mean
    per-image SSIM.

    Reads the module globals ``optimizerH``/``optimizerR`` (set in ``main``)
    and the ``opt_*`` settings. Appends logs to Parameters.txt on Drive and
    saves result grids via ``save_result_pic`` every ``opt_saveFreq`` batches.

    Args:
        train_loader: yields ``(images, labels)``; only ``images`` is used.
            In ``main`` these come from ``ToTensor()`` and lie in [0, 1].
        epoch: epoch index, used in logs and saved image names.
        Hnet: hiding network.
        Rnet: reveal network.
        criterion: unused (the MSE loss was replaced by SSIM); kept for
            interface compatibility.
    """
    # The ToTensor() inputs and the Sigmoid outputs of both nets lie in [0, 1].
    # SSIM's stabilising constants scale with data_range, so 255 here made SSIM
    # ~1 for any image pair and starved the networks of gradient.
    ssim_data_range = 1.0
    log_file = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Checkpoints/Checkpoints_SSIM_4800_10_pt2/Parameters.txt'
    pic_dir = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Train/Train_SSIM_4800_10_pt2/'

    batch_time = AverageMeter()
    data_time = AverageMeter()
    Hlosses = AverageMeter()  # record loss of H-net
    Rlosses = AverageMeter()  # record loss of R-net
    SumLosses = AverageMeter()  # record Hloss + β*Rloss

    # switch to train mode
    Hnet.train()
    Rnet.train()

    start_time = time.time()
    for i, data in enumerate(train_loader, 0):
        data_time.update(time.time() - start_time)

        all_pics = data[0]  # cover images and secret images stacked on the batch axis
        this_batch_size = all_pics.size(0) // 2  # true batch size of this step
        if this_batch_size == 0:
            # A trailing batch with a single image cannot form a cover/secret pair.
            start_time = time.time()
            continue

        Hnet.zero_grad()
        Rnet.zero_grad()

        # first half of images will become cover images, the rest are treated as secret images
        cover_img = all_pics[0:this_batch_size]
        secret_img = all_pics[this_batch_size:this_batch_size * 2]

        # H-net input: cover and secret concatenated on the channel axis
        concat_img = torch.cat([cover_img, secret_img], dim=1)

        container_img = Hnet(concat_img)
        errH = -torch.mean(ssim(container_img, cover_img, data_range=ssim_data_range,
                                size_average=False, nonnegative_ssim=True))
        Hlosses.update(errH.item(), this_batch_size)

        rev_secret_img = Rnet(container_img)
        errR = -torch.mean(ssim(rev_secret_img, secret_img, data_range=ssim_data_range,
                                size_average=False, nonnegative_ssim=True))
        Rlosses.update(errR.item(), this_batch_size)

        err_sum = errH + opt_beta * errR
        SumLosses.update(err_sum.item(), this_batch_size)

        err_sum.backward()
        optimizerH.step()
        optimizerR.step()

        batch_time.update(time.time() - start_time)
        start_time = time.time()

        log = '[%d/%d][%d/%d]\tLoss_H: %.4f Loss_R: %.4f Loss_sum: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
            epoch, opt_niter, i, len(train_loader),
            Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val)

        if i % opt_logFreq == 0:
            print(log)
            with open(log_file, 'a') as f:
                f.write(log + "\n")

        if i % opt_saveFreq == 0:
            save_result_pic(this_batch_size, cover_img, container_img.data, secret_img, rev_secret_img.data,
                            epoch, i, pic_dir)

    # epoch log
    epoch_log = "one epoch time is %.4f======================================================================" % (
        batch_time.sum) + "\n"
    epoch_log = epoch_log + "epoch learning rate: optimizerH_lr = %.8f      optimizerR_lr = %.8f" % (
        optimizerH.param_groups[0]['lr'], optimizerR.param_groups[0]['lr']) + "\n"
    epoch_log = epoch_log + "epoch_Hloss=%.6f\tepoch_Rloss=%.6f\tepoch_sumLoss=%.6f" % (
        Hlosses.avg, Rlosses.avg, SumLosses.avg)
    print(epoch_log)
    with open(log_file, 'a') as f:
        f.write(epoch_log + "\n")


In [0]:
def validation(val_loader, epoch, Hnet, Rnet, criterion):
    """Evaluate H-net and R-net on ``val_loader`` without updating them.

    Each batch is split into cover and secret halves, exactly as in ``train``.
    The losses are negated mean per-image SSIM. A result grid is saved every
    50 batches, and the summary line is printed and appended to Parameters.txt.

    Args:
        val_loader: yields ``(images, labels)``; only ``images`` is used
            (values in [0, 1] via ``ToTensor()`` in ``main``).
        epoch: epoch index, used in logs and saved image names.
        Hnet: hiding network.
        Rnet: reveal network.
        criterion: unused; kept for interface compatibility.

    Returns:
        ``(val_hloss, val_rloss, val_sumloss)``, where
        ``val_sumloss = val_hloss + opt_beta * val_rloss``.
    """
    # Images and network outputs are in [0, 1]; see train() for why not 255.
    ssim_data_range = 1.0
    # Same Parameters.txt as train()/test(); the original pointed at a "_p2"
    # directory that no other cell writes checkpoints or logs to.
    log_file = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Checkpoints/Checkpoints_SSIM_4800_10_pt2/Parameters.txt'
    # NOTE(review): train pics go to "..._pt2" but validation pics to "..._p2";
    # confirm which Drive folder actually exists.
    pic_dir = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Validation/Validation_SSIM_4800_10_p2/'

    print(
        "#################################################### validation begin ########################################################")
    start_time = time.time()
    Hnet.eval()
    Rnet.eval()
    Hlosses = AverageMeter()
    Rlosses = AverageMeter()
    # Variable(volatile=True) has been a no-op since PyTorch 0.4, so autograd
    # graphs were still built here; no_grad() actually disables them.
    with torch.no_grad():
        for i, data in enumerate(val_loader, 0):
            all_pics = data[0]
            this_batch_size = all_pics.size(0) // 2
            if this_batch_size == 0:
                continue  # a lone trailing image cannot form a cover/secret pair

            cover_img = all_pics[0:this_batch_size]
            secret_img = all_pics[this_batch_size:this_batch_size * 2]
            concat_img = torch.cat([cover_img, secret_img], dim=1)

            container_img = Hnet(concat_img)
            errH = -torch.mean(ssim(container_img, cover_img, data_range=ssim_data_range,
                                    size_average=False, nonnegative_ssim=True))
            Hlosses.update(errH.item(), this_batch_size)

            rev_secret_img = Rnet(container_img)
            errR = -torch.mean(ssim(rev_secret_img, secret_img, data_range=ssim_data_range,
                                    size_average=False, nonnegative_ssim=True))
            Rlosses.update(errR.item(), this_batch_size)

            if i % 50 == 0:
                save_result_pic(this_batch_size, cover_img, container_img, secret_img, rev_secret_img,
                                epoch, i, pic_dir)

    val_hloss = Hlosses.avg
    val_rloss = Rlosses.avg
    val_sumloss = val_hloss + opt_beta * val_rloss

    val_time = time.time() - start_time
    val_log = "validation[%d] val_Hloss = %.6f\t val_Rloss = %.6f\t val_Sumloss = %.6f\t validation time=%.2f" % (
        epoch, val_hloss, val_rloss, val_sumloss, val_time)
    print(val_log)
    with open(log_file, 'a') as f:
        f.write(val_log + "\n")

    print(
        "#################################################### validation end ########################################################")
    return val_hloss, val_rloss, val_sumloss

In [0]:
def test(test_loader, epoch, Hnet, Rnet, criterion):
    """Run H-net and R-net over ``test_loader`` and save a result grid for every batch.

    Each batch is split into cover and secret halves, as in ``train``. The
    losses are negated mean per-image SSIM. The summary line is printed and
    appended to Parameters.txt.

    Args:
        test_loader: yields ``(images, labels)``; only ``images`` is used
            (values in [0, 1] via ``ToTensor()`` in ``main``).
        epoch: index used in the log line and saved image names.
        Hnet: hiding network.
        Rnet: reveal network.
        criterion: unused; kept for interface compatibility.

    Returns:
        ``(hloss, rloss, sumloss)``, where ``sumloss = hloss + opt_beta * rloss``.
    """
    # Images and network outputs are in [0, 1]; see train() for why not 255.
    ssim_data_range = 1.0
    log_file = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Checkpoints/Checkpoints_SSIM_4800_10_pt2/Parameters.txt'
    pic_dir = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Test/Test_SSIM_4800_10/'

    print(
        "#################################################### test begin ########################################################")
    start_time = time.time()
    Hnet.eval()
    Rnet.eval()
    Hlosses = AverageMeter()  # record the Hloss over the test set
    Rlosses = AverageMeter()  # record the Rloss over the test set
    # volatile=True is a no-op since PyTorch 0.4; no_grad() really skips graph building.
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            all_pics = data[0]  # cover images and secret images stacked on the batch axis
            this_batch_size = all_pics.size(0) // 2
            if this_batch_size == 0:
                continue  # a lone trailing image cannot form a cover/secret pair

            # first half of images will become cover images, the rest are treated as secret images
            cover_img = all_pics[0:this_batch_size]
            secret_img = all_pics[this_batch_size:this_batch_size * 2]

            # concat cover and secret on the channel axis as the H-net input
            concat_img = torch.cat([cover_img, secret_img], dim=1)

            container_img = Hnet(concat_img)
            errH = -torch.mean(ssim(container_img, cover_img, data_range=ssim_data_range,
                                    size_average=False, nonnegative_ssim=True))
            Hlosses.update(errH.item(), this_batch_size)

            rev_secret_img = Rnet(container_img)
            errR = -torch.mean(ssim(rev_secret_img, secret_img, data_range=ssim_data_range,
                                    size_average=False, nonnegative_ssim=True))
            Rlosses.update(errR.item(), this_batch_size)

            save_result_pic(this_batch_size, cover_img, container_img, secret_img, rev_secret_img,
                            epoch, i, pic_dir)

    test_hloss = Hlosses.avg
    test_rloss = Rlosses.avg
    test_sumloss = test_hloss + opt_beta * test_rloss

    test_time = time.time() - start_time
    # Labelled "test" (it was mislabelled "validation", which made the shared log ambiguous).
    test_log = "test[%d] test_Hloss = %.6f\t test_Rloss = %.6f\t test_Sumloss = %.6f\t test time=%.2f" % (
        epoch, test_hloss, test_rloss, test_sumloss, test_time)
    print(test_log)
    with open(log_file, 'a') as f:
        f.write(test_log + "\n")

    print(
        "#################################################### test end ########################################################")
    return test_hloss, test_rloss, test_sumloss

In [0]:
def save_result_pic(this_batch_size, originalLabelv, ContainerImg, secretLabelv, RevSecImg, epoch, i, save_path):
    """Save a 4-row comparison grid as ``<save_path>/ResultPics_epochEEE_batchBBBB.png``.

    The rows, top to bottom, are cover images, container images, secret images
    and revealed secret images. Each row holds ``this_batch_size`` images. The
    grid is min-max normalised by ``vutils.save_image(normalize=True)``.

    Args:
        this_batch_size: number of images per row.
        originalLabelv: cover images.
        ContainerImg: container images.
        secretLabelv: secret images.
        RevSecImg: revealed secret images.
        epoch: epoch index used in the file name.
        i: batch index used in the file name.
        save_path: destination directory (must already exist).
    """
    # Take the first ``this_batch_size`` images without touching the caller's
    # tensors. The previous in-place ``resize_`` mutated them and, for anything
    # other than 3x256x256 images, reinterpreted or garbage-filled memory
    # instead of resizing.
    rows = [t.detach()[:this_batch_size]
            for t in (originalLabelv, ContainerImg, secretLabelv, RevSecImg)]
    resultImg = torch.cat(rows, 0)
    resultImgName = '%s/ResultPics_epoch%03d_batch%04d.png' % (save_path, epoch, i)
    vutils.save_image(resultImg, resultImgName, nrow=this_batch_size, padding=1, normalize=True)

In [0]:
class AverageMeter(object):
    """Track the latest value and a count-weighted running average.

    Attributes (all reset to 0 by ``reset``):
        val: the most recent value passed to ``update``.
        sum: running total of ``val * n``.
        count: running total of ``n``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times (e.g. a batch mean over ``n`` samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 before any real sample used to raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

In [0]:
def _image_folder(root):
    """Return an ImageFolder over ``root`` whose images are resized to 256x256 tensors in [0, 1]."""
    return torchvision.datasets.ImageFolder(
        root=root,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize([256, 256]),
            torchvision.transforms.ToTensor(),
        ]),
    )


def main():
    """Fine-tune H-net and R-net from saved checkpoints, validate every epoch, then test.

    Steps:
      1. Load the epoch-4 checkpoints from ``Checkpoints_SSIM_4800_10``.
      2. Train for ``opt_niter`` epochs on ``train2017/``.
      3. After each epoch, validate on ``val2017/`` and step both
         ReduceLROnPlateau schedulers.
      4. Save both networks to ``Checkpoints_SSIM_4800_10_pt2`` whenever the
         validation sum-loss improves.
      5. Run ``test`` on ``test2017/`` with the networks as they are after
         the last epoch.

    Side effects: sets the module globals ``optimizerH`` and ``optimizerR``
    (read by ``train``), plus ``schedulerH``, ``schedulerR``, ``val_loader``
    and ``smallestLoss``.
    """
    global optimizerH, optimizerR, schedulerH, schedulerR, val_loader, smallestLoss

    load_dir = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Checkpoints/Checkpoints_SSIM_4800_10/'
    save_dir = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Checkpoints/Checkpoints_SSIM_4800_10_pt2/'
    test_pic_dir = 'drive/My Drive/Steganography/Models/UNet-Pytorch/Test/Test_SSIM_4800_10/'

    Hnet = UnetGenerator(input_nc=6, output_nc=3, num_downs=7, output_function=nn.Sigmoid)
    Hnet.apply(weights_init)

    Rnet = RevealNet(output_function=nn.Sigmoid)
    Rnet.apply(weights_init)

    # Nothing in this notebook moves the networks to a GPU, so load the
    # checkpoint tensors onto the CPU. This also lets GPU-saved checkpoints
    # load on a CPU-only runtime.
    Hnet.load_state_dict(torch.load(load_dir + "netH_epoch_4,sumloss=-1.749747,Hloss=-0.999883.pth",
                                    map_location='cpu'))
    Rnet.load_state_dict(torch.load(load_dir + "netR_epoch_4,sumloss=-1.749747,Rloss=-0.999818.pth",
                                    map_location='cpu'))

    criterion = nn.MSELoss()  # passed along but unused: the losses are SSIM-based
    optimizerH = optim.Adam(Hnet.parameters(), lr=0.001, betas=(0.5, 0.999))
    schedulerH = ReduceLROnPlateau(optimizerH, mode='min', factor=0.2, patience=5, verbose=True)

    optimizerR = optim.Adam(Rnet.parameters(), lr=0.001, betas=(0.5, 0.999))
    schedulerR = ReduceLROnPlateau(optimizerR, mode='min', factor=0.2, patience=8, verbose=True)

    # Each batch is later split in half into cover and secret images.
    train_loader = DataLoader(_image_folder('train2017/'), batch_size=32,
                              shuffle=True, num_workers=8)
    val_loader = DataLoader(_image_folder('val2017/'), batch_size=2,
                            shuffle=False, num_workers=8)

    smallestLoss = 10000  # the SSIM losses are negative, so any real loss beats this
    for epoch in range(opt_niter):
        train(train_loader, epoch, Hnet=Hnet, Rnet=Rnet, criterion=criterion)

        val_hloss, val_rloss, val_sumloss = validation(val_loader, epoch, Hnet=Hnet, Rnet=Rnet, criterion=criterion)

        # H's LR follows the joint objective; R's follows its own loss.
        schedulerH.step(val_sumloss)
        schedulerR.step(val_rloss)

        if val_sumloss < smallestLoss:
            smallestLoss = val_sumloss
            torch.save(Hnet.state_dict(),
                       '%s/netH_epoch_%d,sumloss=%.6f,Hloss=%.6f.pth' % (
                           save_dir, epoch, val_sumloss, val_hloss))
            torch.save(Rnet.state_dict(),
                       '%s/netR_epoch_%d,sumloss=%.6f,Rloss=%.6f.pth' % (
                           save_dir, epoch, val_sumloss, val_rloss))

    test_loader = DataLoader(_image_folder('test2017/'), batch_size=32,
                             shuffle=False, num_workers=8)
    test(test_loader, 0, Hnet=Hnet, Rnet=Rnet, criterion=criterion)
    print("##################   test is completed, the result pics are saved in %s   ######################"
          % test_pic_dir)

In [0]:
# Run fine-tuning with per-epoch validation, followed by the final test pass.
main()