In [1]:
import os
import torch
import numpy as np 
import torchvision
import pandas as pd
from PIL import Image
import torch.nn as nn
from torchviz import make_dot
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Hyper-parameters and paths shared by the cells below.
lr = 0.001  # learning rate handed to both Adam optimizers
epochs = 1000  # NOTE(review): the visible training call passes a literal epoch count instead of this value
img_size = 128  # NOTE(review): not referenced in the visible cells; presumably the image side length - confirm
train_batch_size = 16  # batch size of the training DataLoader
valid_batch_size = 1  # batch size of the validation DataLoader
cover_loss_weight = 1  # first constructor argument of EncoderLoss (cover term weight)
secret_loss_weight = 1  # second constructor argument of EncoderLoss (secret term weight)
decoder_weight_loss = 1  # weight applied to the MSE inside DecoderLoss
dataset_path = "./new_data"  # root folder joined with the CSV names and the image sub-folders

In [4]:
# Image -> tensor conversion; StegDataset applies it to both the cover and the secret image.
transforms = torchvision.transforms.ToTensor()

In [5]:
class StegDataset(torch.utils.data.Dataset):
    """Dataset of (cover, secret) image pairs listed in a CSV manifest.

    The CSV must provide ``cover_image`` and ``secret_image`` columns whose
    values are file names relative to ``<root_dir>/<folder_type>``.

    Args:
        dataset_csv: Path of the CSV manifest. When its *file name* contains
            "train" the images are read from the "training" sub-folder,
            otherwise from "validation".
        transforms: Callable applied to every loaded PIL image.
        root_dir: Optional dataset root. When omitted, the module-level
            ``dataset_path`` is looked up each time an item is read.
    """

    def __init__(self, dataset_csv, transforms, root_dir=None):
        self.dataset = pd.read_csv(dataset_csv).reset_index(drop=True)
        self.transforms = transforms
        self.root_dir = root_dir
        # Look only at the file name: a directory that happens to contain
        # "train" must not turn the validation manifest into a training one.
        csv_name = os.path.basename(dataset_csv)
        self.folder_type = "training" if "train" in csv_name else "validation"

    def _load_rgb(self, file_name):
        """Open ``file_name`` from the split folder and return it as an RGB image."""
        root = dataset_path if self.root_dir is None else self.root_dir
        image_path = os.path.join(root, self.folder_type, file_name)
        # The networks hard-code 3 input channels, so grayscale / RGBA files
        # are normalised here; the context manager releases the file handle.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return a dict with the transformed cover and secret images of row ``index``."""
        row = self.dataset.iloc[index]
        return {
            "cover_image": self.transforms(self._load_rgb(row["cover_image"])),
            "secret_image": self.transforms(self._load_rgb(row["secret_image"])),
        }

    def __len__(self):
        """Number of image pairs listed in the manifest."""
        return len(self.dataset)

In [6]:
class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 convolution followed by a pointwise 1x1 convolution.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the pointwise stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # groups == in_channels gives every input channel its own 3x3 filter;
        # stride 1 with padding 1 leaves the spatial size unchanged.
        self.depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=3, stride=1, padding=1, groups=in_channels,
        )
        # The 1x1 convolution mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise stage."""
        per_channel = self.depthwise(x)
        return self.pointwise(per_channel)

In [7]:
class MultiDepthwiseBlock(nn.Module):
    """Stack of ``count`` [DepthwiseSeparableConv -> BatchNorm2d -> ReLU] units.

    Only the first unit maps ``in_channels`` to ``out_channels``; every later
    unit maps ``out_channels`` to ``out_channels``. This keeps the stack valid
    for ``count > 1`` even when the two widths differ (previously each unit was
    built with ``in_channels`` inputs, which only worked when both widths were
    equal or ``count`` was 1).

    Args:
        in_channels: Channels of the tensor fed to the block.
        out_channels: Channels produced by every unit of the block.
        count: Number of conv/norm/activation units to stack.
    """

    def __init__(self, in_channels, out_channels, count):
        super().__init__()

        layers = []
        current_channels = in_channels
        for _ in range(count):
            layers.append(DepthwiseSeparableConv(current_channels, out_channels))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            # After the first unit the running width is out_channels.
            current_channels = out_channels

        self.sequential_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through all stacked units in order."""
        return self.sequential_layers(x)

In [8]:
class PrepNet(nn.Module):
    """Preparation network: turns a 3-channel secret image into 150 feature channels.

    Every convolution in the stack uses stride 1 with size-preserving padding,
    so the spatial size of the input is kept.
    """

    def __init__(self):
        super().__init__()

        # Stem: plain 3x3 convolution lifting the image to 32 channels.
        stem = [
            nn.Conv2d(3, 32, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        ]
        # (in_channels, out_channels, count) for each depthwise stage, in order.
        stage_plan = [
            (32, 32, 3), (32, 64, 1),
            (64, 64, 3), (64, 128, 1),
            (128, 128, 3), (128, 256, 1),
            (256, 256, 3), (256, 150, 1),
        ]
        stages = [MultiDepthwiseBlock(c_in, c_out, reps) for c_in, c_out, reps in stage_plan]

        self.conv_layers = nn.Sequential(*stem, *stages)

    def forward(self, secret_image):
        """Return the feature map computed from ``secret_image``."""
        return self.conv_layers(secret_image)

In [9]:
class HidingNet(nn.Module):
    """Hiding network: fuses secret features with the cover image into a 3-channel output.

    The two inputs are concatenated along the channel axis and must add up to
    153 channels, the width expected by the first convolution. The last stage
    ends with BatchNorm2d + ReLU, so the output is non-negative.
    """

    def __init__(self):
        super().__init__()

        # Stem: plain 3x3 convolution reducing the 153 stacked channels to 32.
        stem = [
            nn.Conv2d(153, 32, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        ]
        # (in_channels, out_channels, count) for each depthwise stage, in order.
        stage_plan = [
            (32, 32, 3), (32, 64, 1),
            (64, 64, 3), (64, 128, 1),
            (128, 128, 3), (128, 256, 1),
            (256, 256, 3), (256, 3, 1),
        ]
        stages = [MultiDepthwiseBlock(c_in, c_out, reps) for c_in, c_out, reps in stage_plan]

        self.conv_layers = nn.Sequential(*stem, *stages)

    def forward(self, secret_image, cover_image):
        """Concatenate ``secret_image`` before ``cover_image`` on dim 1 and run the stack."""
        stacked = torch.cat((secret_image, cover_image), dim=1)
        return self.conv_layers(stacked)

In [10]:
class RevealNet(nn.Module):
    """Reveal network: maps a 3-channel stego image to a 3-channel prediction.

    The last stage ends with BatchNorm2d + ReLU, so the output is non-negative.
    """

    def __init__(self):
        super().__init__()

        # Stem: plain 3x3 convolution lifting the image to 32 channels.
        stem = [
            nn.Conv2d(3, 32, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        ]
        # (in_channels, out_channels, count) for each depthwise stage, in order.
        stage_plan = [
            (32, 32, 3), (32, 64, 1),
            (64, 64, 3), (64, 128, 1),
            (128, 128, 3), (128, 256, 1),
            (256, 256, 3), (256, 3, 1),
        ]
        stages = [MultiDepthwiseBlock(c_in, c_out, reps) for c_in, c_out, reps in stage_plan]

        self.conv_layers = nn.Sequential(*stem, *stages)

    def forward(self, stego_image):
        """Return the image predicted from ``stego_image``."""
        return self.conv_layers(stego_image)

In [11]:
class EncoderModel(nn.Module):
    """Encoder: feeds the secret through ``prepNet``, then hides it with ``hidingNet``.

    Args:
        prepNet: Module called with the secret image.
        hidingNet: Module called with ``(prepNet output, cover image)``.
    """

    def __init__(self, prepNet, hidingNet):
        super().__init__()
        self.prepNet = prepNet
        self.hidingNet = hidingNet

    def forward(self, cover_image, secret_image):
        """Return the stego image built from ``cover_image`` and ``secret_image``."""
        # Note the argument order: the prepared secret goes first, the cover second.
        return self.hidingNet(self.prepNet(secret_image), cover_image)

In [12]:
class DecoderModel(nn.Module):
    """Decoder: recovers the secret image from a stego image via ``revealNet``.

    Args:
        revealNet: Module called with the stego image.
    """

    def __init__(self, revealNet):
        super().__init__()
        self.revealNet = revealNet

    def forward(self, stego_image, secret_image=None):
        """Return the secret image predicted from ``stego_image``.

        ``secret_image`` is never read. It is kept, now optional, so existing
        two-argument calls keep working while the one-argument calls made by
        the training loop no longer raise ``TypeError``.
        """
        predicted_secret_image = self.revealNet(stego_image)
        return predicted_secret_image

In [13]:
class EncoderLoss(nn.Module):
    """Weighted sum of the cover and secret mean-squared errors.

    Args:
        cover_weight: Factor applied to the cover MSE.
        secret_weight: Factor applied to the secret MSE.
    """

    def __init__(self, cover_weight, secret_weight):
        super().__init__()
        self.cover_weight = cover_weight
        self.secret_weight = secret_weight

    def forward(self, cover_image, predicted_cover_image, secret_image, predicted_secret_image):
        """Return ``cover_weight * MSE(cover pair) + secret_weight * MSE(secret pair)``."""
        cover_term = F.mse_loss(cover_image, predicted_cover_image)
        secret_term = F.mse_loss(secret_image, predicted_secret_image)
        return self.cover_weight * cover_term + self.secret_weight * secret_term

In [14]:
class DecoderLoss(nn.Module):
    """Weighted mean-squared error between the true and the revealed secret image.

    Args:
        decoder_loss_weight: Factor applied to the MSE.
    """

    def __init__(self, decoder_loss_weight):
        super().__init__()
        self.decoder_loss_weight = decoder_loss_weight

    def forward(self, predicted_secret_image, secret_image):
        """Return ``decoder_loss_weight * MSE(secret_image, predicted_secret_image)``."""
        return self.decoder_loss_weight * F.mse_loss(secret_image, predicted_secret_image)

In [15]:
# Instantiate the three sub-networks.
prep_net = PrepNet()
hiding_net = HidingNet()
reveal_net = RevealNet()

# Encoder = preparation + hiding networks; decoder = reveal network.
encoder_model = EncoderModel(prep_net, hiding_net)
decoder_model = DecoderModel(reveal_net)

# Move both models to the selected device. As the last expression of the cell,
# the resulting tuple of models is echoed as the cell output.
encoder_model.to(device), decoder_model.to(device)

(EncoderModel(
   (prepNet): PrepNet(
     (conv_layers): Sequential(
       (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (2): ReLU()
       (3): MultiDepthwiseBlock(
         (sequential_layers): Sequential(
           (0): DepthwiseSeparableConv(
             (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
             (pointwise): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
           )
           (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
           (2): ReLU()
           (3): DepthwiseSeparableConv(
             (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
             (pointwise): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
           )
           (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_

In [16]:
# CSV manifests live directly under the dataset root.
training_csv_path = os.path.join(dataset_path, "train_dataset.csv")
validation_csv_path = os.path.join(dataset_path, "validation_dataset.csv")

# Both splits share the same per-image transform.
training_dataset = StegDataset(training_csv_path, transforms)
valid_dataset = StegDataset(validation_csv_path, transforms)

# Both loaders shuffle, drop the final incomplete batch and load in the main
# process (num_workers=0); only the batch size differs.
train_data_loader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

valid_data_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=valid_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)

In [17]:
# Loss modules configured with the weights from the hyper-parameter cell.
encoder_loss_func = EncoderLoss(cover_loss_weight, secret_loss_weight)
decoder_loss_func = DecoderLoss(decoder_weight_loss)

# One Adam optimizer per model so the encoder and decoder can be stepped independently.
encoder_optimizer = torch.optim.Adam(encoder_model.parameters(), lr=lr)
decoder_optimizer = torch.optim.Adam(decoder_model.parameters(), lr=lr)

In [18]:
def _set_requires_grad(model, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = flag


def training(encoder_model,
             decoder_model,
             encoder_loss_func,
             decoder_loss_func,
             encoder_optimizer,
             decoder_optimizer,
             train_loader,
             epochs,
             print_every=50):
    """Alternately train the encoder and the decoder on ``train_loader``.

    For every batch the encoder is updated first (decoder frozen), then the
    decoder is updated on a stego image produced without gradient tracking
    (encoder frozen).

    Args:
        encoder_model: Module called as ``encoder_model(cover, secret)``.
        decoder_model: Module called as ``decoder_model(stego, secret)``.
        encoder_loss_func: Callable taking ``(cover_image, predicted_cover_image,
            secret_image, predicted_secret_image)``.
        decoder_loss_func: Callable taking ``(predicted_secret_image, secret_image)``.
        encoder_optimizer: Optimizer over the encoder parameters.
        decoder_optimizer: Optimizer over the decoder parameters.
        train_loader: Iterable of dicts with "cover_image" and "secret_image" tensors.
        epochs: Number of passes over ``train_loader``.
        print_every: Print the epoch losses whenever ``epoch % print_every == 0``.

    Returns:
        ``(encoder_model, decoder_model, encoder_loss_list, decoder_loss_list)``;
        each list holds one float per epoch, the mean batch loss of that epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batch during an epoch.
    """
    # Follow the encoder's own device rather than a notebook-level global.
    first_param = next(encoder_model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Make sure BatchNorm layers are in training mode even if .eval() ran earlier.
    encoder_model.train()
    decoder_model.train()

    encoder_loss_list = []
    decoder_loss_list = []

    try:
        for epoch in range(epochs):
            encoder_loss_sum = 0.0
            decoder_loss_sum = 0.0
            batch_count = 0

            for batch in train_loader:
                cover_image = batch["cover_image"].to(target_device)
                secret_image = batch["secret_image"].to(target_device)

                # Phase 1: train the encoder while the decoder is frozen.
                _set_requires_grad(encoder_model, True)
                _set_requires_grad(decoder_model, False)

                stego_image = encoder_model(cover_image, secret_image)
                predicted_secret_image = decoder_model(stego_image, secret_image)
                # Argument order follows EncoderLoss.forward: the stego image is
                # the predicted cover, so each weight lands on the right term.
                encoder_loss = encoder_loss_func(cover_image,
                                                 stego_image,
                                                 secret_image,
                                                 predicted_secret_image)
                encoder_optimizer.zero_grad()
                encoder_loss.backward()
                encoder_optimizer.step()

                # Phase 2: train the decoder while the encoder is frozen.
                _set_requires_grad(encoder_model, False)
                _set_requires_grad(decoder_model, True)

                with torch.no_grad():
                    stego_image = encoder_model(cover_image, secret_image)

                predicted_secret_image = decoder_model(stego_image, secret_image)
                decoder_loss = decoder_loss_func(predicted_secret_image, secret_image)
                decoder_optimizer.zero_grad()
                decoder_loss.backward()
                decoder_optimizer.step()

                encoder_loss_sum += encoder_loss.item()
                decoder_loss_sum += decoder_loss.item()
                batch_count += 1

            if batch_count == 0:
                raise ValueError("train_loader produced no batches")

            # Record the epoch mean instead of whatever the last batch produced.
            encoder_loss_list.append(encoder_loss_sum / batch_count)
            decoder_loss_list.append(decoder_loss_sum / batch_count)

            if epoch % print_every == 0:
                print("encoder loss {} | decoder loss {}".format(encoder_loss_list[-1],
                                                                 decoder_loss_list[-1]))
    finally:
        # Do not hand back a model whose parameters were left frozen by phase 2.
        _set_requires_grad(encoder_model, True)
        _set_requires_grad(decoder_model, True)

    return encoder_model, decoder_model, encoder_loss_list, decoder_loss_list

In [None]:
# Short run of the alternating training loop: 2 epochs, losses printed when
# the epoch index is a multiple of 50.
enc_model, dec_model, enc_loss_list, dec_loss_list = training(
    encoder_model,
    decoder_model,
    encoder_loss_func,
    decoder_loss_func,
    encoder_optimizer,
    decoder_optimizer,
    train_loader=train_data_loader,
    epochs=2,
    print_every=50,
)

In [79]:
# plt.plot(enc_loss_list)
# plt.show()

In [80]:
# plt.plot(dec_loss_list)
# plt.show()

In [81]:
# encoder_model.eval()
# decoder_model.eval()

In [89]:
# val_data = next(iter(valid_data_loader))
# cover_image = val_data.get("cover_image")
# secret_image = val_data.get("secret_image")

# with torch.no_grad():
#     stego_image = encoder_model(cover_image, secret_image)
#     predicted_secret_image = decoder_model(stego_image)

#     enc_loss = encoder_loss_func(predicted_secret_image, secret_image, stego_image, cover_image)
#     dec_loss = decoder_loss_func(predicted_secret_image, secret_image)
    
#     print("encoder loss => ", enc_loss)
#     print("decoder loss => ", dec_loss)
    
#     plt.imshow(stego_image), plt.imshow(predicted_secret_image)