In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np 
import pandas as pd
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from torchviz import make_dot

In [2]:
%matplotlib inline
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
dataset_path = "../scripts/new_data"
train_csv = "train_dataset.csv"
valid_csv = "validation_dataset.csv"

In [15]:
IMG_SIZE = 128
LEARNING_RATE  = 0.001
COVER_LOSS_WEIGHT = 1
SECRET_LOSS_WEIGHT = 1
DECODER_LOSS_WEIGHT = 1
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 1
EPOCHS = 1000

In [49]:
transformations = {
    'train_transforms':torchvision.transforms.Compose([torchvision.transforms.Resize((IMG_SIZE, IMG_SIZE)), torchvision.transforms.ToTensor()]),
}

class StegDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_csv, transforms):
        self.dataset = pd.read_csv(dataset_csv)
        self.dataset = self.dataset.reset_index(drop=True)            
        self.transforms = transforms
    
    def __getitem__(self, index):
        cover_image = self.dataset.iloc[index]["cover_image"]
        secret_image = self.dataset.iloc[index]["secret_image"]
        
        cover_image = Image.open(os.path.join(dataset_path, "training", cover_image))
        secret_image = Image.open(os.path.join(dataset_path, "validation", secret_image))
        
        transformed_cover_image = self.transforms(cover_image)
        transformed_secret_image = self.transforms(secret_image)
        
        return {
            "cover_image": transformed_cover_image,
            "secret_image": transformed_secret_image
        }
    
    def __len__(self):
        return len(self.dataset)

In [50]:
class PrepNet(nn.Module):
    def __init__(self):
        super().__init__()   
        # first inception block
        self.conv1 = nn.Conv2d(in_channels=3,  out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)

        # second inception block
        self.conv6 = nn.Conv2d(in_channels=3,   out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv7 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv8 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv9 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv10 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(4, 4), stride=1, padding="same")

        # third inception block
        self.conv11 = nn.Conv2d(in_channels=3,  out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv12 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv13 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv14 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv15 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
    
    def forward(self, secret_image):
        # first inception block (3x3)
        x1 = F.relu(self.conv1(secret_image))
        x1 = F.relu(self.conv2(x1))
        x1 = F.relu(self.conv3(x1))
        x1 = F.relu(self.conv4(x1))
        x1 = F.relu(self.conv5(x1))

        # second inception block (4x4)
        x2 = F.relu(self.conv6(secret_image))
        x2 = F.relu(self.conv7(x2))
        x2 = F.relu(self.conv8(x2))
        x2 = F.relu(self.conv9(x2))
        x2 = F.relu(self.conv10(x2))

        #  third inception block (5x5)
        x3 = F.relu(self.conv11(secret_image))
        x3 = F.relu(self.conv12(x3))
        x3 = F.relu(self.conv13(x3))
        x3 = F.relu(self.conv14(x3))
        x3 = F.relu(self.conv15(x3))
        
        final_concat_image = torch.cat([x1, x2, x3], dim=1)
        return final_concat_image

In [51]:
class HidingNet(nn.Module):
    def __init__(self):
        super().__init__()
       # first inception block
        self.conv1 = nn.Conv2d(in_channels=153, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(3, 3), stride=1, padding=1)

        # second inception block
        self.conv6 = nn.Conv2d(in_channels=153, out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv7 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv8 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv9 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv10 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(4, 4), stride=1, padding="same")

        # third inception block
        self.conv11 = nn.Conv2d(in_channels=153, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv12 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv13 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv14 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv15 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(5, 5), stride=1, padding=2)

        self.final_layer = nn.Conv2d(in_channels=150, out_channels=3, kernel_size=(3,3), stride=1, padding=1)
                
    def forward(self, secret_image, cover_image):
        concatenated_secrets = torch.cat([cover_image, secret_image], dim=1)

        # first inception block (3x3)
        x1 = F.relu(self.conv1(concatenated_secrets))
        x1 = F.relu(self.conv2(x1))
        x1 = F.relu(self.conv3(x1))
        x1 = F.relu(self.conv4(x1))
        x1 = F.relu(self.conv5(x1))

        # second inception block (4x4)
        x2 = F.relu(self.conv6(concatenated_secrets))
        x2 = F.relu(self.conv7(x2))
        x2 = F.relu(self.conv8(x2))
        x2 = F.relu(self.conv9(x2))
        x2 = F.relu(self.conv10(x2))

        #  third inception block (5x5)
        x3 = F.relu(self.conv11(concatenated_secrets))
        x3 = F.relu(self.conv12(x3))
        x3 = F.relu(self.conv13(x3))
        x3 = F.relu(self.conv14(x3))
        x3 = F.relu(self.conv15(x3))

        # stego image 
        stego_image = F.relu(self.final_layer(torch.cat([x1, x2, x3], dim=1)))
        return stego_image
        

In [52]:
class RevealNet(nn.Module):
    def __init__(self):
        super().__init__()
         # first inception block
        self.conv1 = nn.Conv2d(in_channels=3,  out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(3, 3), stride=1, padding=1)

        # second inception block
        self.conv6 = nn.Conv2d(in_channels=3,   out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv7 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv8 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv9 = nn.Conv2d(in_channels=50,  out_channels=50, kernel_size=(4, 4), stride=1, padding="same")
        self.conv10 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(4, 4), stride=1, padding="same")

        # third inception block
        self.conv11 = nn.Conv2d(in_channels=3,  out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv12 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv13 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv14 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        self.conv15 = nn.Conv2d(in_channels=50, out_channels=50, kernel_size=(5, 5), stride=1, padding=2)
        
        self.final_layer = nn.Conv2d(in_channels=150, out_channels=3, kernel_size=(3,3), stride=1, padding=1)    
    
    def forward(self, stego_image):
        # first inception block (3x3)
        x1 = F.relu(self.conv1(stego_image))
        x1 = F.relu(self.conv2(x1))
        x1 = F.relu(self.conv3(x1))
        x1 = F.relu(self.conv4(x1))
        x1 = F.relu(self.conv5(x1))

        # second inception block (4x4)
        x2 = F.relu(self.conv6(stego_image))
        x2 = F.relu(self.conv7(x2))
        x2 = F.relu(self.conv8(x2))
        x2 = F.relu(self.conv9(x2))
        x2 = F.relu(self.conv10(x2))

        #  third inception block (5x5)
        x3 = F.relu(self.conv11(stego_image))
        x3 = F.relu(self.conv12(x3))
        x3 = F.relu(self.conv13(x3))
        x3 = F.relu(self.conv14(x3))
        x3 = F.relu(self.conv15(x3))

        recovered_img = F.relu(self.final_layer(torch.cat([x1, x2, x3], dim=1)))        
        return recovered_img

In [53]:
class EncoderModel(nn.Module):
    def __init__(self, prepNet, hidingNet):
        super().__init__()
        self.prepNet = prepNet
        self.hidingNet = hidingNet
    
    def forward(self, cover_image, secret_image):
        encoded_image = self.prepNet(secret_image)
        stego_image = self.hidingNet(encoded_image, cover_image)
        return stego_image

In [54]:
class DecoderModel(nn.Module):
    def __init__(self, revealNet):
        super().__init__()
        self.revealNet = revealNet
    
    def forward(self, stego_image, secret_image):
        predicted_secret_image = self.revealNet(stego_image)
        return predicted_secret_image

In [55]:
class EncoderLoss(nn.Module):
    def __init__(self, cover_weight, secret_weight):
        super().__init__()
        self.cover_weight = cover_weight
        self.secret_weight = secret_weight
        
    def forward(self, predicted_secret_image, secret_image, predicted_cover_image, cover_image):
        cover_loss = self.cover_weight * F.mse_loss(predicted_cover_image, cover_image)
        secret_loss = self.secret_weight * F.mse_loss(predicted_secret_image, secret_image) 
        return cover_loss + secret_loss

class DecoderLoss(nn.Module):
    def __init__(self, decoder_loss_weight):
        super().__init__()
        self.decoder_loss_weight = decoder_loss_weight
    
    def forward(self, predicted_secret_image, secret_image):
        reveal_img = self.decoder_loss_weight * F.mse_loss(predicted_secret_image, secret_image)
        return reveal_img

In [56]:
prep_net = PrepNet()
hiding_net = HidingNet()
reveal_net = RevealNet()

encoder_model = EncoderModel(prep_net, hiding_net)
decoder_model = DecoderModel(reveal_net)

encoder_model.to(device), decoder_model.to(device)

(EncoderModel(
   (prepNet): PrepNet(
     (conv1): Conv2d(3, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv2): Conv2d(50, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv3): Conv2d(50, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv4): Conv2d(50, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv5): Conv2d(50, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv6): Conv2d(3, 50, kernel_size=(4, 4), stride=(1, 1), padding=same)
     (conv7): Conv2d(50, 50, kernel_size=(4, 4), stride=(1, 1), padding=same)
     (conv8): Conv2d(50, 50, kernel_size=(4, 4), stride=(1, 1), padding=same)
     (conv9): Conv2d(50, 50, kernel_size=(4, 4), stride=(1, 1), padding=same)
     (conv10): Conv2d(50, 50, kernel_size=(4, 4), stride=(1, 1), padding=same)
     (conv11): Conv2d(3, 50, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
     (conv12): Conv2d(50, 50, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
     (conv13

In [70]:
training_csv_path = os.path.join(dataset_path, "train_dataset.csv")
validation_csv_path = os.path.join(dataset_path, "validation_dataset.csv")

training_dataset = StegDataset(training_csv_path, transformations["train_transforms"])
valid_dataset = StegDataset(validation_csv_path, transformations["train_transforms"])

train_data_loader = torch.utils.data.DataLoader(training_dataset, 
                                                batch_size=TRAIN_BATCH_SIZE, 
                                                shuffle=True,
                                                drop_last=True,
                                                num_workers=0
                                               )

valid_data_loader = torch.utils.data.DataLoader(valid_dataset, 
                                                batch_size=VALID_BATCH_SIZE, 
                                                shuffle=True,
                                                drop_last=True,
                                                num_workers=0
                                               )

In [71]:
encoder_loss_func = EncoderLoss(SECRET_LOSS_WEIGHT, COVER_LOSS_WEIGHT)
decoder_loss_func = DecoderLoss(DECODER_LOSS_WEIGHT)

encoder_optimizer = torch.optim.Adam(encoder_model.parameters(), lr=LEARNING_RATE)
decoder_optimizer = torch.optim.Adam(decoder_model.parameters(), lr=LEARNING_RATE)

In [72]:
def training(encoder_model,
             decoder_model,
             encoder_loss_func, 
             decoder_loss_func,
             encoder_optimizer,
             decoder_optimizer,
             train_loader, 
             epochs,
             print_every=50):
    
    encoder_loss_list = []
    decoder_loss_list = []

    for epoch in range(epochs):
        for batch in train_loader:
            cover_image = batch["cover_image"].to(device)
            secret_image = batch["secret_image"].to(device)

            # Phase 1: Train encoder
            for param in encoder_model.parameters():
                param.requires_grad = True
            for param in decoder_model.parameters():
                param.requires_grad = False
                
            stego_image = encoder_model(cover_image, secret_image)  
            predicted_secret_image = decoder_model(stego_image)
            encoder_loss = encoder_loss_func(predicted_secret_image, secret_image, stego_image, cover_image)
            encoder_optimizer.zero_grad()
            encoder_loss.backward()
            encoder_optimizer.step()

            # Phase 2: Train Decoder
            for param in encoder_model.parameters():
                param.requires_grad = False
            for param in decoder_model.parameters():
                param.requires_grad = True

            with torch.no_grad():
                stego_image = encoder_model(cover_image, secret_image)

            predicted_secret_image = decoder_model(stego_image)
            decoder_loss = decoder_loss_func(predicted_secret_image, secret_image)
            decoder_optimizer.zero_grad()
            decoder_loss.backwards()
            decoder_optimizer.step()
        
        encoder_loss_list.append(encoder_loss.item())
        decoder_loss_list.append(decoder_loss.item())
        
        if epoch % print_every == 0:
            print("encoder loss {} | decoder loss {}".format(encoder_loss.item(), decoder_loss.item()))
        
    return encoder_model, decoder_model, encoder_loss_list, decoder_loss_list

In [78]:
# enc_model, dec_model, enc_loss_list, dec_loss_list = training(encoder_model, 
#                                                               decoder_model,
#                                                               encoder_loss_func,
#                                                               decoder_loss_func,
#                                                               encoder_optimizer,
#                                                               decoder_optimizer,
#                                                               train_data_loader, 
#                                                               EPOCHS,
#                                                               50)

In [79]:
# plt.plot(enc_loss_list)
# plt.show()

In [80]:
# plt.plot(dec_loss_list)
# plt.show()

In [81]:
# encoder_model.eval()
# decoder_model.eval()

In [89]:
# val_data = next(iter(valid_data_loader))
# cover_image = val_data.get("cover_image")
# secret_image = val_data.get("secret_image")

# with torch.no_grad():
#     stego_image = encoder_model(cover_image, secret_image)
#     predicted_secret_image = decoder_model(stego_image)

#     enc_loss = encoder_loss_func(predicted_secret_image, secret_image, stego_image, cover_image)
#     dec_loss = decoder_loss_funct(predicted_secret_image, secret_image)
    
#     print("encoder loss => ", enc_loss)
#     print("decoder loss => ", dec_loss)
    
#     plt.imshow(stego_image), plt.imshow(predicted_secret_image)