In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2 
import matplotlib.pyplot as plt
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from glob import glob 
from torch.autograd import Variable
import random
import csv
from tqdm.notebook import tqdm

In [None]:
# Select the compute device: the first CUDA GPU when one is visible, else the CPU.
# Both paths use float32 as the working tensor dtype.
dataType = torch.float32
if not torch.cuda.is_available():
  device = torch.device("cpu")
  print("Running on the CPU")
else:
  device = torch.device("cuda:0")
  print("Running on the GPU")
  # Release cached blocks left over from earlier runs in this kernel.
  torch.cuda.empty_cache()

class AverageMeter(object):
    """Running tracker of the latest value and its weighted mean.

    Attributes:
        val:   value passed to the most recent ``update`` call.
        sum:   running sum of ``val * n`` over all updates since ``reset``.
        count: running sum of the weights ``n``.
        avg:   ``sum / count`` as of the latest update (0 before any update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all history and return every statistic to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count

In [None]:
class Train_M(Dataset):
  """Training *message* images (the images to be hidden).

  Each item is one image read with ``cv2.imread`` and returned as a float32
  tensor, channels-first, with values scaled from [0, 255] to [0, 1]. Channel
  order is OpenCV's BGR; it is left as-is so inputs stay consistent with how
  the existing checkpoints were presumably trained.

  Args:
    root: directory containing the message images (every entry is used).
  """
  def __init__(self, root='../input/steganography-final-dataset/Steganography_New_Dataset/Train_M'):
    # glob() returns entries in arbitrary order; sort so index -> file is reproducible.
    self.imdir_message = sorted(glob(os.path.join(root, '*')))

  def __len__(self):
    return len(self.imdir_message)

  def __getitem__(self, idx):
    imdir_message = self.imdir_message[idx]
    message = cv2.imread(imdir_message)
    if message is None:
      # cv2.imread reports missing/unreadable files by returning None instead of
      # raising; fail here with the path rather than with a TypeError below.
      raise IOError('Could not read image: {}'.format(imdir_message))
    return torch.tensor(message / 255, dtype=torch.float32).permute(2, 0, 1)
    

In [None]:
class Train_C(Dataset):
  """Training *cover* images (the images the message is hidden inside).

  Each item is one image read with ``cv2.imread`` and returned as a float32
  tensor, channels-first, with values scaled from [0, 255] to [0, 1]. Channel
  order is OpenCV's BGR; it is left as-is so inputs stay consistent with how
  the existing checkpoints were presumably trained.

  Args:
    root: directory containing the cover images (every entry is used).
  """
  def __init__(self, root='../input/steganography-final-dataset/Steganography_New_Dataset/Train_C'):
    # glob() returns entries in arbitrary order; sort so index -> file is reproducible.
    self.imdir_cover = sorted(glob(os.path.join(root, '*')))

  def __len__(self):
    return len(self.imdir_cover)

  def __getitem__(self, idx):
    imdir_cover = self.imdir_cover[idx]
    cover = cv2.imread(imdir_cover)
    if cover is None:
      # cv2.imread reports missing/unreadable files by returning None instead of
      # raising; fail here with the path rather than with a TypeError below.
      raise IOError('Could not read image: {}'.format(imdir_cover))
    return torch.tensor(cover / 255, dtype=torch.float32).permute(2, 0, 1)
   

In [None]:
class Valid(Dataset):
  """Validation split: yields ``{'message': tensor, 'cover': tensor}`` pairs.

  Message and cover files are each sorted by path and paired by position, so
  the pairing is reproducible across runs. Each image is read with
  ``cv2.imread`` (BGR channel order), scaled to [0, 1] as float32 and returned
  channels-first.

  Args:
    message_root: directory holding the message images.
    cover_root: directory holding the cover images.
  """
  def __init__(self,
               message_root='../input/steganography-final-dataset/Steganography_New_Dataset/Valid_M',
               cover_root='../input/steganography-final-dataset/Steganography_New_Dataset/Valid_C'):
    # glob() returns entries in arbitrary order; sort so index i pairs the same files every run.
    self.imdir_message = sorted(glob(os.path.join(message_root, '*')))
    self.imdir_cover = sorted(glob(os.path.join(cover_root, '*')))

  def __len__(self):
    # Only positions present in both lists form a complete pair; counting messages
    # alone would raise IndexError whenever there are fewer covers than messages.
    return min(len(self.imdir_message), len(self.imdir_cover))

  @staticmethod
  def _load(path):
    """Read one image as a float32 channels-first tensor in [0, 1]; raise IOError if unreadable."""
    image = cv2.imread(path)
    if image is None:
      # cv2.imread reports missing/unreadable files by returning None rather than raising.
      raise IOError('Could not read image: {}'.format(path))
    return torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)

  def __getitem__(self, idx):
    return {'message': self._load(self.imdir_message[idx]),
            'cover': self._load(self.imdir_cover[idx])}


In [None]:
class Test(Dataset):
  """Test split: yields ``{'message': tensor, 'cover': tensor}`` pairs.

  Message and cover files are each sorted by path and paired by position, so
  the pairing is reproducible across runs. Each image is read with
  ``cv2.imread`` (BGR channel order), scaled to [0, 1] as float32 and returned
  channels-first.

  Args:
    message_root: directory holding the message images.
    cover_root: directory holding the cover images.
  """
  def __init__(self,
               message_root='../input/steganography-final-dataset/Steganography_New_Dataset/Test_M',
               cover_root='../input/steganography-final-dataset/Steganography_New_Dataset/Test_C'):
    # glob() returns entries in arbitrary order; sort so index i pairs the same files every run.
    self.imdir_message = sorted(glob(os.path.join(message_root, '*')))
    self.imdir_cover = sorted(glob(os.path.join(cover_root, '*')))

  def __len__(self):
    # Only positions present in both lists form a complete pair; counting messages
    # alone would raise IndexError whenever there are fewer covers than messages.
    return min(len(self.imdir_message), len(self.imdir_cover))

  @staticmethod
  def _load(path):
    """Read one image as a float32 channels-first tensor in [0, 1]; raise IOError if unreadable."""
    image = cv2.imread(path)
    if image is None:
      # cv2.imread reports missing/unreadable files by returning None rather than raising.
      raise IOError('Could not read image: {}'.format(path))
    return torch.tensor(image / 255, dtype=torch.float32).permute(2, 0, 1)

  def __getitem__(self, idx):
    return {'message': self._load(self.imdir_message[idx]),
            'cover': self._load(self.imdir_cover[idx])}


In [None]:
class conv_Block(nn.Module):
  """Conv2d -> BatchNorm2d -> ReLU.

  Any extra keyword arguments (kernel_size, stride, padding, ...) are forwarded
  unchanged to ``nn.Conv2d``; the output is therefore non-negative.
  """
  def __init__(self, in_channels, out_channels, **kwargs):
    super(conv_Block, self).__init__()
    # Attribute names become state_dict keys of saved checkpoints - keep them stable.
    self.relu = nn.ReLU()
    self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    self.batchNorm = nn.BatchNorm2d(out_channels)

  def forward(self, x):
    return self.relu(self.batchNorm(self.conv(x)))


class deconv_Block(nn.Module):
  """ConvTranspose2d -> BatchNorm2d -> ReLU.

  Any extra keyword arguments (kernel_size, stride, padding, ...) are forwarded
  unchanged to ``nn.ConvTranspose2d``; the output is therefore non-negative.
  """
  def __init__(self, in_channels, out_channels, **kwargs):
    super(deconv_Block, self).__init__()
    # Attribute names become state_dict keys of saved checkpoints - keep them stable.
    self.relu = nn.ReLU()
    self.deconv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
    self.batchNorm = nn.BatchNorm2d(out_channels)

  def forward(self, x):
    return self.relu(self.batchNorm(self.deconv(x)))



class Inception_block(nn.Module):
  """Inception-style module: four parallel conv_Block branches concatenated on dim 1.

  Output channels = out_11 + out_33 + out_55 + maxPool_11. Every branch uses
  stride 1 with padding matched to its kernel, so height and width are preserved.
  """
  def __init__(self, in_channels, out_11, red_33, out_33, red_55, out_55, maxPool_11):
    super(Inception_block, self).__init__()
    # 1x1 projection.
    self.branch1 = conv_Block(in_channels, out_11, kernel_size=1)
    # 1x1 channel reduction, then 3x3.
    self.branch2 = nn.Sequential(
        conv_Block(in_channels, red_33, kernel_size=1),
        conv_Block(red_33, out_33, kernel_size=3, padding=1),
    )
    # 1x1 channel reduction, then 5x5.
    self.branch3 = nn.Sequential(
        conv_Block(in_channels, red_55, kernel_size=1),
        conv_Block(red_55, out_55, kernel_size=5, padding=2),
    )
    # 3x3 max-pool (stride 1), then 1x1 projection.
    self.branch4 = nn.Sequential(
        nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
        conv_Block(in_channels, maxPool_11, kernel_size=1),
    )

  def forward(self, x):
    branches = (self.branch1, self.branch2, self.branch3, self.branch4)
    return torch.cat([branch(x) for branch in branches], 1)

In [None]:
class prep_Net(nn.Module):
  """Message preparation network.

  Takes a 3-channel message batch and returns 50-channel features at twice the
  input height and width (DCB21 is a kernel-2, stride-2 transposed conv).
  """
  def __init__(self):
    super(prep_Net, self).__init__()
    conv3 = dict(kernel_size=(3, 3), stride=1, padding=1)
    # Branch widths 15 + 35 + 35 + 15 = 100 output channels, matching the input width.
    inception = dict(in_channels=100, out_11=15, red_33=20, out_33=35, red_55=20, out_55=35, maxPool_11=15)

    # Registration order fixes parameters() order (used by optimizer checkpoints)
    # and attribute names fix state_dict keys - both must stay as they are.
    self.CB31 = conv_Block(in_channels=3, out_channels=100, **conv3)
    self.CB32 = conv_Block(in_channels=100, out_channels=100, **conv3)
    self.CB33 = conv_Block(in_channels=100, out_channels=100, **conv3)

    self.In1 = Inception_block(**inception)
    self.In2 = Inception_block(**inception)
    self.In3 = Inception_block(**inception)
    self.In4 = Inception_block(**inception)

    self.CB34 = conv_Block(in_channels=100, out_channels=100, **conv3)

    self.DCB21 = deconv_Block(in_channels=100, out_channels=50, kernel_size=(2, 2), stride=2, padding=0)
    self.CB35 = conv_Block(in_channels=50, out_channels=50, **conv3)

  def forward(self, message):
    pipeline = (self.CB31, self.CB32, self.CB33,
                self.In1, self.In2, self.In3, self.In4,
                self.CB34, self.DCB21, self.CB35)
    for stage in pipeline:
      message = stage(message)
    return message

In [None]:

class hide_Net(nn.Module):
  """Hiding network: fuses cover-image features with prepared message features.

  forward(cover, message):
    cover   -- 3-channel image batch, lifted to 50 channels by two 3x3 conv blocks.
    message -- prepared message features; must have 50 channels and the same
               height/width as ``cover`` so the channel concatenation yields the
               100 channels CB31 expects.
  Returns a 3-channel tensor at the cover's height/width. The last layer is a
  conv_Block (Conv -> BatchNorm -> ReLU), so the output is non-negative.
  """
  def __init__(self):
    super(hide_Net, self).__init__()
    conv3 = dict(kernel_size=(3, 3), stride=1, padding=1)
    # Branch widths 15 + 38 + 35 + 15 = 103 output channels.
    inception = dict(out_11=15, red_33=20, out_33=38, red_55=20, out_55=35, maxPool_11=15)

    # Registration order fixes parameters() order (used by optimizer checkpoints)
    # and attribute names fix state_dict keys - both must stay as they are.
    self.CB31_cover = conv_Block(in_channels=3, out_channels=50, **conv3)
    self.CB32_cover = conv_Block(in_channels=50, out_channels=50, **conv3)

    self.CB31 = conv_Block(in_channels=100, out_channels=100, **conv3)
    self.CB32 = conv_Block(in_channels=100, out_channels=100, **conv3)

    self.In1 = Inception_block(in_channels=100, **inception)
    self.In2 = Inception_block(in_channels=103, **inception)
    self.In3 = Inception_block(in_channels=103, **inception)

    self.CB33 = conv_Block(in_channels=103, out_channels=103, **conv3)

    self.In4 = Inception_block(in_channels=103, **inception)
    self.In5 = Inception_block(in_channels=103, **inception)

    self.CB34 = conv_Block(in_channels=103, out_channels=103, **conv3)
    self.CB35 = conv_Block(in_channels=103, out_channels=103, **conv3)
    self.CB36 = conv_Block(in_channels=103, out_channels=103, **conv3)

    self.CB11 = conv_Block(in_channels=103, out_channels=3, kernel_size=(1, 1), stride=1, padding=0)

  def forward(self, cover, message):
    cover_features = self.CB32_cover(self.CB31_cover(cover))
    encoded = torch.cat([cover_features, message], 1)
    trunk = (self.CB31, self.CB32,
             self.In1, self.In2, self.In3,
             self.CB33,
             self.In4, self.In5,
             self.CB34, self.CB35, self.CB36,
             self.CB11)
    for layer in trunk:
      encoded = layer(encoded)
    return encoded

In [None]:
class rev_Net(nn.Module):
  """Reveal network: recovers the hidden message from a 3-channel stego image.

  A full-resolution trunk feeds four stride-2 paths (two conv blocks, a
  max-pool and an average-pool) whose outputs are concatenated, so the
  returned 3-channel tensor has half the input height and width (rounded
  down). The last layer is a conv_Block, so the output is non-negative.
  """
  def __init__(self):
    super(rev_Net, self).__init__()
    conv3 = dict(kernel_size=(3, 3), stride=1, padding=1)

    # Registration order fixes parameters() order (used by optimizer checkpoints)
    # and attribute names fix state_dict keys - both must stay as they are.
    self.CB31 = conv_Block(in_channels=3, out_channels=100, **conv3)
    self.CB32 = conv_Block(in_channels=100, out_channels=100, **conv3)
    # Branch widths 15 + 35 + 35 + 15 = 100 output channels.
    self.In1 = Inception_block(in_channels=100, out_11=15, red_33=15, out_33=35, red_55=15, out_55=35, maxPool_11=15)
    self.In2 = Inception_block(in_channels=100, out_11=15, red_33=15, out_33=35, red_55=15, out_55=35, maxPool_11=15)
    self.CB33 = conv_Block(in_channels=100, out_channels=100, **conv3)
    self.CB11 = conv_Block(in_channels=100, out_channels=50, **conv3)

    # Stride-2 downsampling paths.
    self.CB41 = conv_Block(in_channels=100, out_channels=100, kernel_size=(4, 4), stride=2, padding=1)
    self.CB21 = conv_Block(in_channels=100, out_channels=100, kernel_size=(2, 2), stride=2, padding=0)
    self.MP1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    self.AP1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    # In3 takes the 100 + 100 + 50 + 50 = 300 concatenated channels;
    # branch widths 25 + 50 + 50 + 25 = 150 output channels.
    self.In3 = Inception_block(in_channels=300, out_11=25, red_33=25, out_33=50, red_55=25, out_55=50, maxPool_11=25)
    self.In4 = Inception_block(in_channels=150, out_11=25, red_33=25, out_33=50, red_55=25, out_55=50, maxPool_11=25)

    self.CB34 = conv_Block(in_channels=150, out_channels=150, **conv3)
    self.CB35 = conv_Block(in_channels=150, out_channels=150, **conv3)
    self.CB36 = conv_Block(in_channels=150, out_channels=150, **conv3)
    self.CB12 = conv_Block(in_channels=150, out_channels=3, kernel_size=(1, 1), stride=1, padding=0)

  def forward(self, cover):
    # Full-resolution trunk (100 channels) and its refinement.
    trunk = self.In2(self.In1(self.CB32(self.CB31(cover))))
    refined = self.CB33(trunk)
    narrowed = self.CB11(refined)  # 50 channels

    downsampled = [self.CB41(trunk), self.CB21(trunk), self.MP1(narrowed), self.AP1(narrowed)]
    decoded = torch.cat(downsampled, 1)

    for layer in (self.In3, self.In4, self.CB34, self.CB35, self.CB36, self.CB12):
      decoded = layer(decoded)
    return decoded

In [None]:
class Steganography(nn.Module):
  """End-to-end hiding model: prep_Net -> hide_Net -> rev_Net.

  ``forward(cover, message)`` returns ``(encoded, decoded)``: the stego image
  produced by hide_Net from the cover and the prepared message, and the message
  recovered from that stego image by rev_Net.
  """
  def __init__(self):
    super(Steganography, self).__init__()
    self.pNet = prep_Net()
    self.hNet = hide_Net()
    self.rNet = rev_Net()

  def save(self, optim, epoch, loss=0, path=None):
    """Write a checkpoint containing model and optimizer state.

    Args:
      optim: optimizer whose ``state_dict`` is stored alongside the weights.
      epoch: epoch number stored under ``'epoch'``; also names the default file.
      loss: value stored under ``'loss'`` (defaults to 0, the previous placeholder).
      path: destination file; defaults to ``./stg{epoch}.pth``.

    Side effect: switches the model to eval mode, as this method always has.
    """
    self.eval()
    if path is None:
      path = './stg{}.pth'.format(epoch)
    torch.save({
        'epoch': epoch,
        'model_state_dict': self.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'loss': loss,
        }, path)

  def load(self, optim, path):
    """Restore model and optimizer state from a checkpoint written by ``save``.

    Tensors are mapped onto the device this model currently lives on, so a
    checkpoint saved from a GPU can be loaded on a CPU-only machine (a plain
    ``torch.load`` tries to recreate CUDA tensors and fails there).
    Also sets ``self.epoch`` and ``self.loss`` from the checkpoint.
    """
    first_param = next(self.parameters(), None)
    map_location = first_param.device if first_param is not None else 'cpu'
    checkpoint = torch.load(path, map_location=map_location)
    self.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    self.epoch = checkpoint['epoch']
    self.loss = checkpoint['loss']

  def forward(self, cover, message):
    message = self.pNet(message)
    encoded = self.hNet(cover, message)
    decoded = self.rNet(encoded)
    return encoded, decoded

In [None]:
# Training messages and covers come from two independently shuffled loaders;
# presumably they are consumed in lockstep (zip) by the training loop, which
# pairs random messages with random covers each step.
train_M = Train_M()
train_loader_M = DataLoader(train_M, batch_size=20, shuffle=True, num_workers=2)

train_C = Train_C()
train_loader_C = DataLoader(train_C, batch_size=20, shuffle=True, num_workers=2)

# Paired message/cover splits for monitoring and final evaluation.
valid_data = Valid()
valid_loader = DataLoader(valid_data, batch_size=10, shuffle=True, num_workers=2)

test_data = Test()
test_loader = DataLoader(test_data, batch_size=10, shuffle=True, num_workers=2)

In [None]:
# Model, objective and optimisation schedule.
stg = Steganography().to(device)
# One MSE criterion scores both the cover term (encoded vs cover) and the
# message term (decoded vs message).
loss = nn.MSELoss()
optimizer = torch.optim.AdamW(
    stg.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.01,
    amsgrad=False,
)
# Halve the learning rate every 2 scheduler steps (stepped once per epoch in the training loop).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
lrs = []  # learning-rate history, one entry appended per epoch

In [None]:
# Resume from the epoch-15 checkpoint: restores model weights and optimizer state.
CHECKPOINT_PATH = '../input/dsp-project-weights-2021/stg15.pth'
stg.load(optimizer, CHECKPOINT_PATH)


In [None]:
# n=1
# TC_Loss=[]
# VC_Loss=[]
# TM_Loss=[]
# VM_Loss=[]

# losses_TC = AverageMeter()
# losses_VC = AverageMeter()
# losses_TM = AverageMeter()
# losses_VM = AverageMeter()


# x=[]
# ep=0
# for epoch in range(15):
    
#     i=0
#     iterable = enumerate(zip(train_loader_M,train_loader_C))
#     progress = None
#     progress = tqdm(iterable, desc='Train', total=len(train_loader_M))
#     iterable = progress

#     for i,(batch_M,batch_C) in iterable:
#         stg.train()
#         i=i+1
#         encoded,decoded=stg(batch_C.to(device),batch_M.to(device))
#         LSC=loss(batch_C.to(device),encoded)
#         LSM=loss(batch_M.to(device),decoded)

#         LS=(1-0.7*(LSC<LSM))*LSC+LSM

#         LS.backward()
#         optimizer.step()
#         optimizer.zero_grad()
        
#         valid = iter(valid_loader).next()
#         stg.eval()
#         with torch.no_grad():
#             encodedV,decodedV=stg(valid["cover"].to(device),valid["message"].to(device))
#             VLSC=loss(valid["cover"].to(device),encodedV)
#             VLSM=loss(valid["message"].to(device),decodedV)
            
#         if i%100==0:
#             print('Epoch:'+str(n)+'  Batch:'+str(i)+'  TCover:'+str(LSC.cpu().detach().numpy())+'  VCover:'+str(VLSC.cpu().detach().numpy())+'  TMessage:'+str(LSM.cpu().detach().numpy())+'  VMessage:'+str(VLSM.cpu().detach().numpy()))
        
#         losses_TC.update(LSC)
#         losses_TM.update(LSM)
#         losses_VC.update(VLSC)
#         losses_VM.update(VLSM)
        
#         if progress is not None:
#             progress.set_postfix_str('TC: {loss_TC:0.6f},VC: {loss_VC:0.6f}, TM: {loss_TM:0.6f},VM: {loss_VM:0.6f}'.format(
#                 loss_TC=losses_TC.avg,
#                 loss_VC=losses_VC.avg,
#                 loss_TM=losses_TM.avg,
#                 loss_VM=losses_VM.avg
#             ))
        
#         TC_Loss.append((losses_TC.avg).cpu().detach().numpy())
#         TM_Loss.append((losses_TM.avg).cpu().detach().numpy())
#         VC_Loss.append((losses_VC.avg).cpu().detach().numpy())
#         VM_Loss.append((losses_VM.avg).cpu().detach().numpy())
#         ep=ep+1
#         x.append(ep)
#     n=n+1
#     lrs.append(optimizer.param_groups[0]["lr"])
#     scheduler.step()
  

In [None]:
# stg.save(optimizer,15)

In [None]:
# TCLoss = open('../input/dsp-project-weights-2021/TC_Loss.csv', 'w')
# writer = csv.writer(TCLoss)
# writer.writerow(np.ndarray.tolist(np.asarray(TC_Loss)))
# TCLoss.close()

# TMLoss = open('../input/dsp-project-weights-2021/TM_Loss.csv', 'w')
# writer = csv.writer(TMLoss)
# writer.writerow(np.ndarray.tolist(np.asarray(TM_Loss)))
# TMLoss.close()

# VCLoss = open('../input/dsp-project-weights-2021/VC_Loss.csv', 'w')
# writer = csv.writer(VCLoss)
# writer.writerow(np.ndarray.tolist(np.asarray(VC_Loss)))
# VCLoss.close()

# VMLoss = open('../input/dsp-project-weights-2021/VM_Loss.csv', 'w')
# writer = csv.writer(VMLoss)
# writer.writerow(np.ndarray.tolist(np.asarray(VM_Loss)))
# VMLoss.close()

# xx = open('xx.csv', 'w')
# writer = csv.writer(xx)
# writer.writerow(np.ndarray.tolist(np.asarray(x)))
# xx.close()


In [None]:
# plt.plot(x,TC_Loss, label = "TC_Loss")
# plt.plot(x,VC_Loss, label = "VC_Loss")
# plt.legend()
# plt.show()

In [None]:
# plt.plot(x,TM_Loss, label = "TM_Loss")
# plt.plot(x,VM_Loss, label = "VM_Loss")
# plt.legend()
# plt.show()

In [None]:
# Evaluate one held-out test batch and visualise its first example.
# eval() makes BatchNorm use its running statistics instead of per-batch ones
# (and stops it updating them); torch.no_grad() alone does neither.
stg.eval()

# DataLoader iterators no longer provide a .next() method in recent PyTorch;
# use the builtin next(). Everything below is labelled "test", so read the test split.
test = next(iter(test_loader))
cover_batch = test["cover"].to(device)
message_batch = test["message"].to(device)
with torch.no_grad():
    encodedT, decodedT = stg(cover_batch, message_batch)
    TLSC = loss(cover_batch, encodedT)
    TLSM = loss(message_batch, decodedT)

print('TC_Loss:'+str(TLSC.cpu().numpy())+' TM_Loss:'+str(TLSM.cpu().numpy()))


def to_display_image(batch, index=0):
    """Return image ``index`` of a channels-first batch as an HxWx3 RGB float array in [0, 1].

    The datasets read images with cv2.imread, which yields BGR channel order, so
    the channel axis is reversed for matplotlib (which expects RGB). Network
    outputs end in a ReLU and are not bounded above, so values are clipped to
    imshow's valid float range instead of being scaled and cast to int.
    """
    image = batch[index].detach().cpu().permute(1, 2, 0).numpy()
    return np.clip(image[..., ::-1], 0.0, 1.0)


cover2 = to_display_image(test["cover"])
enc2 = to_display_image(encodedT)
message2 = to_display_image(test["message"])
dec2 = to_display_image(decodedT)

f, axarr = plt.subplots(1, 4, figsize=(16, 4))
panels = [(cover2, "Cover"), (enc2, "Encoded (stego)"), (message2, "Message"), (dec2, "Decoded message")]
for ax, (image, title) in zip(axarr, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
plt.show()
