In [None]:
import matplotlib.pyplot as plt
import os
import sys

import torch
import torch.nn as nn

import torchvision.transforms as transforms
from torch.utils.data import Dataset

from time import time

import numpy as np
from PIL import Image
import natsort

from sklearn.model_selection import train_test_split


In [None]:
# Physical GPU(s) to expose to this process: a single int, or a sized iterable of ints.
GPU = 2

# A non-int GPU spec means the caller asked for several devices.
multi_gpus = not isinstance(GPU, int)
gpu_list = ','.join(str(gpu_id) for gpu_id in GPU) if multi_gpus else str(GPU)

# Must be set before CUDA is first initialised for torch to honour it.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list

is_cuda_available = torch.cuda.is_available()

n_visible_gpus = torch.cuda.device_count()
if n_visible_gpus > 1:
    print("Let's use", n_visible_gpus, "GPUs!")


In [None]:
# Primary compute device: the first visible GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# One torch.cuda.device handle per visible GPU (an empty list on CPU-only machines).
available_gpus = [torch.cuda.device(gpu_idx) for gpu_idx in range(torch.cuda.device_count())]

In [None]:
# Display the device selected in the previous cell.
device

In [None]:
def count_parameters(model):
    """Return the total number of trainable scalars in ``model``.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total



# Model

In [None]:

def _init_block_weights(root):
    """Initialise, in place, every Conv/ConvTranspose/BatchNorm/Linear submodule of ``root``.

    Conv weights get Kaiming-normal init (fan_out, relu gain) and zero bias; BatchNorm gets
    weight 1 / bias 0; Linear gets N(0, 0.01) weights and zero bias.
    """
    for mod in root.modules():
        if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
            if mod.bias is not None:
                nn.init.constant_(mod.bias, 0)
        elif isinstance(mod, nn.BatchNorm2d):
            nn.init.constant_(mod.weight, 1)
            nn.init.constant_(mod.bias, 0)
        elif isinstance(mod, nn.Linear):
            nn.init.normal_(mod.weight, 0, 0.01)
            nn.init.constant_(mod.bias, 0)


class Conv2DBlock(nn.Module):
    """Conv2d -> LeakyReLU(0.2) -> Dropout(0.2) -> BatchNorm2d, with custom weight init.

    Arguments are forwarded positionally to ``nn.Conv2d``.
    """

    def __init__(self, in_c, out_c, kernel, stride, padding):
        super().__init__()
        layers = [
            nn.Conv2d(in_c, out_c, kernel, stride, padding),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.BatchNorm2d(out_c),
        ]
        self.block = nn.Sequential(*layers)
        self.init_weights()

    def init_weights(self):
        """Re-initialise all weights of this block (see ``_init_block_weights``)."""
        _init_block_weights(self)

    def forward(self, x):
        return self.block(x)


class TransConv2DBlock(nn.Module):
    """ConvTranspose2d -> LeakyReLU(0.2) -> Dropout(0.2) -> BatchNorm2d, with custom weight init.

    ``output_padding`` defaults to 1, which with stride 2 / kernel 3 / padding 1 exactly
    doubles the spatial size.
    """

    def __init__(self, in_c, out_c, kernel, stride, padding, output_padding=1):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_c, out_c, kernel, stride, padding, output_padding=output_padding),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.BatchNorm2d(out_c),
        ]
        self.block = nn.Sequential(*layers)
        self.init_weights()

    def init_weights(self):
        """Re-initialise all weights of this block (see ``_init_block_weights``)."""
        _init_block_weights(self)

    def forward(self, x):
        return self.block(x)


In [None]:
class AutoEncoder(nn.Module):
    """U-Net-style convolutional autoencoder that maps an image to a reconstruction in (0, 1).

    Encoder: an asymmetric (4, 6) conv (321 x 321 -> 320 x 320) followed by four stride-2
    stages (320 -> 160 -> 80 -> 40 -> 20). Decoder: four stride-2 transposed convs back up to
    320, each followed by concatenation with the same-resolution encoder map (skip connection)
    and a 3x3 fusion conv; a final asymmetric transposed conv restores 321 x 321, and a 3x3 conv
    projects to 3 channels before the sigmoid.

    Args:
        input_shape: number of input channels (3 for RGB), despite the name.
    """

    def __init__(self, input_shape):
        super().__init__()

        # Encoder (shapes: H x W x C for a 321 x 321 x 3 input).
        self.conv1 = Conv2DBlock(input_shape, 64, (4, 6), 1, (1, 2))  # 320 x 320 x 64
        self.conv2 = Conv2DBlock(64, 64, 3, 2, 1)                     # 160 x 160 x 64
        self.conv3 = Conv2DBlock(64, 128, 3, 2, 1)                    # 80 x 80 x 128
        self.conv4 = Conv2DBlock(128, 256, 3, 2, 1)                   # 40 x 40 x 256
        self.conv5 = Conv2DBlock(256, 512, 3, 2, 1)                   # 20 x 20 x 512

        # Decoder: upsample, concatenate the skip (doubling channels), fuse back down.
        # Attribute names/order are part of the checkpoint format - do not rename.
        self.deconv1 = TransConv2DBlock(512, 256, 3, 2, 1)            # 40 x 40 x 256
        self.conv6 = Conv2DBlock(512, 256, 3, 1, 1)                   # 40 x 40 x 256
        self.deconv2 = TransConv2DBlock(256, 128, 3, 2, 1)            # 80 x 80 x 128
        self.conv7 = Conv2DBlock(256, 128, 3, 1, 1)                   # 80 x 80 x 128
        self.deconv3 = TransConv2DBlock(128, 64, 3, 2, 1)             # 160 x 160 x 64
        self.conv8 = Conv2DBlock(128, 64, 3, 1, 1)                    # 160 x 160 x 64
        self.deconv4 = TransConv2DBlock(64, 64, 3, 2, 1)              # 320 x 320 x 64
        self.conv9 = Conv2DBlock(128, 64, 3, 1, 1)                    # 320 x 320 x 64
        self.deconv5 = TransConv2DBlock(64, 64, (4, 6), 1, (1, 2), output_padding=0)  # 321 x 321 x 64
        self.conv10 = nn.Conv2d(64, 3, 3, 1, 1)                       # 321 x 321 x 3

    def forward(self, x, predict=False):
        """Return ``sigmoid(reconstruction(x))``.

        ``predict`` is accepted for call-site compatibility but ignored: the sigmoid is
        always applied.
        """
        skip0 = self.conv1(x)
        skip1 = self.conv2(skip0)
        skip2 = self.conv3(skip1)
        skip3 = self.conv4(skip2)
        out = self.conv5(skip3)

        decoder_stages = (
            (self.deconv1, self.conv6, skip3),
            (self.deconv2, self.conv7, skip2),
            (self.deconv3, self.conv8, skip1),
            (self.deconv4, self.conv9, skip0),
        )
        for upsample, fuse, skip in decoder_stages:
            out = fuse(torch.cat((upsample(out), skip), 1))

        return torch.sigmoid(self.conv10(self.deconv5(out)))

In [None]:
class DenoiseLoss(nn.Module):
    """Mean-squared error between a reconstruction and its target over all elements."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss()

    def forward(self, x, y):
        # Both sides are flattened to 1-D, so only the element count (not the layout) must match.
        x_flat = torch.flatten(x)
        y_flat = torch.flatten(y)
        return self.criterion(x_flat, y_flat)


# Trainer

In [None]:
class Trainer():
    """Owns the AutoEncoder, its Adam optimizer and the DenoiseLoss, and runs train/eval loops.

    Loaders given to :meth:`eval` and :meth:`train` must yield ``(names, x_batch, y_batch)``
    triples; ``x_batch`` is fed to the model and ``y_batch`` is the reconstruction target.
    """

    def __init__(self, device, use_mult_gpu=False):
        """
        Args:
            device: torch device the model and every batch are moved to.
            use_mult_gpu: wrap the model in ``nn.DataParallel`` (its state_dict keys then
                carry a ``module.`` prefix).
        """
        self.device = device
        self.model = AutoEncoder(3).to(device)  # 3 = RGB input channels
        if use_mult_gpu:
            self.model = nn.DataParallel(self.model)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.loss_criterion = DenoiseLoss()

    def step(self, x, y):
        """Run one optimisation step on ``(x, y)``; return the loss tensor (still graph-attached)."""
        self.model.zero_grad()

        y_rec = self.model(x)
        loss_batch = self.loss_criterion(y_rec, y)

        loss_batch.backward()
        self.optimizer.step()

        return loss_batch

    def eval(self, data_loader):
        """Return the mean per-batch loss over ``data_loader`` as a Python float.

        Leaves the model in eval mode; :meth:`train` switches back at each epoch start.
        Returns ``nan`` for an empty loader instead of dividing by zero.
        """
        self.model.eval()

        total = 0.
        n_batches = 0
        with torch.no_grad():
            for _, x_batch, y_batch in data_loader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)
                y_rec = self.model(x_batch)
                # .item() keeps the accumulator a host float (safe to store in numpy arrays)
                total += self.loss_criterion(y_rec, y_batch).item()
                n_batches += 1

        return total / n_batches if n_batches else float('nan')

    def train(self, tr_loader, ts_loader, batch_size=32, num_epochs=10, step=20,
              ckpt_path='./model_UNet.ckpt'):
        """Train for ``num_epochs`` epochs, checkpointing whenever the ``ts_loader`` loss improves.

        Args:
            tr_loader: training loader (the model is in train mode while iterating it).
            ts_loader: loader evaluated after every epoch; drives checkpoint selection.
            batch_size: kept for backward compatibility; no longer used (the step count
                comes from ``len(tr_loader)``, which includes a final partial batch).
            num_epochs: number of epochs to run.
            step: print a progress line every ``step`` batches.
            ckpt_path: where the best (lowest ``ts_loader`` loss) state_dict is saved.

        Returns:
            ``(losses, tr_losses, te_losses)`` numpy arrays truncated to the number of fully
            completed epochs (fewer than ``num_epochs`` after a KeyboardInterrupt).
            ``losses`` is the mean training-batch loss per epoch, ``te_losses`` the
            ``ts_loader`` loss. NOTE(review): ``tr_losses`` is never filled and stays zero.
        """
        losses = np.zeros(num_epochs)
        tr_losses = np.zeros(num_epochs)
        te_losses = np.zeros(num_epochs)

        best_loss = np.inf  # np.Inf alias was removed in NumPy 2.0
        epochs_done = 0

        total_steps = len(tr_loader)  # batches per epoch, including a final partial one
        print("[INFO] Starting training phase...")
        start = time()

        try:
            for epoch in range(num_epochs):
                self.model.train()
                for i, (_, x_batch, y_batch) in enumerate(tr_loader, start=1):
                    x_batch = x_batch.to(self.device)
                    y_batch = y_batch.to(self.device)

                    ### Train autoencoder ###
                    loss = self.step(x_batch, y_batch).item()

                    # Incremental mean over the i batches seen so far in this epoch.
                    losses[epoch] += (loss - losses[epoch]) / i

                    if i % step == 0:
                        sys.stdout.write(
                            '\nEpoch [{:>3}/{}] | Step [{:>3}/{}]| loss: {:.4f} |'
                            .format(epoch + 1, num_epochs, i, total_steps, loss))
                        sys.stdout.flush()

                te_losses[epoch] = self.eval(ts_loader)
                if te_losses[epoch] < best_loss:
                    best_loss = te_losses[epoch]
                    # Save the model checkpoints
                    torch.save(self.model.state_dict(), ckpt_path)
                sys.stdout.write(
                    '\nEND Epoch [{:>3}/{}] | Train loss: {:.4f} | Test loss: {:.4f} '
                    .format(epoch + 1, num_epochs, losses[epoch], te_losses[epoch]))
                sys.stdout.flush()
                epochs_done = epoch + 1

        except KeyboardInterrupt:
            print('\n')
            print('-' * 89)
            print('[INFO] Exiting from training early')
        print(f'\n[INFO] Training phase... Elapsed time: {(time() - start):.0f} seconds\n')
        return losses[:epochs_done], tr_losses[:epochs_done], te_losses[:epochs_done]

# Data Loading

In [None]:
# Dataset root; the data-loading cell reads clean images from WORK_AREA/clean and the
# malicious versions from one subdirectory per LSB attack.
WORK_AREA = './data/BSD70'

In [None]:
batch_size = 32  # used by every DataLoader below and passed to Trainer.train
epochs = 150     # number of training epochs passed to Trainer.train

print("[INFO] loading dataset...")

# ToTensor: PIL image -> float C x H x W tensor (8-bit images are scaled to [0, 1])
transform = transforms.Compose([transforms.ToTensor()])

class CustomDataset(Dataset):
    """Pairs each malicious (stego) image with its clean original.

    ``main_dir`` must contain a ``clean/`` subdirectory and an ``attack_type`` subdirectory
    holding files with the same names.

    Args:
        main_dir: dataset root directory.
        train: list of image file names for this split (the name is historical; the class is
            used for train, val and test splits alike).
        transform: callable applied to each RGB PIL image.
        attack_type: subdirectory of ``main_dir`` holding the malicious images, e.g.
            ``'lsb_classic/sequential'``. Required despite the ``None`` default.

    Raises:
        ValueError: if ``attack_type`` is None.

    Items are ``(file_name, malicious_tensor, clean_tensor)``.
    """

    def __init__(self, main_dir, train, transform, attack_type=None):
        if attack_type is None:
            # Fail fast: otherwise __getitem__ dies later inside os.path.join with a TypeError.
            raise ValueError('attack_type is required (e.g. "lsb_classic/sequential")')
        self.main_dir = main_dir
        self.list_imgs = train
        self.transform = transform
        self.attack_type = attack_type

    def __len__(self):
        return len(self.list_imgs)

    def _load(self, subdir, file_name):
        """Load ``main_dir/subdir/file_name`` as RGB, close the file, and apply the transform."""
        path = os.path.join(self.main_dir, subdir, file_name)
        # convert() fully loads the pixels, so the handle can be closed before transforming
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

    def __getitem__(self, idx):
        file_name = self.list_imgs[idx]
        tensor_clean_image = self._load('clean', file_name)
        tensor_mal_image = self._load(self.attack_type, file_name)
        return file_name, tensor_mal_image, tensor_clean_image


# Define train, val, test
#
# Each split is paired with all six LSB embedding schemes below; this tuple's order fixes the
# order of datasets inside the ConcatDatasets and of the unpacked per-attack names.
_ATTACK_TYPES = (
    'lsb_classic/interleaving_rows',
    'lsb_classic/interleaving_squares',
    'lsb_classic/sequential',
    'lsb_oceanlotus/interleaving_rows',
    'lsb_oceanlotus/interleaving_squares',
    'lsb_oceanlotus/sequential',
)


def _attack_datasets(file_names):
    """Return one CustomDataset per entry of _ATTACK_TYPES, all over the same ``file_names``."""
    return [CustomDataset(WORK_AREA, file_names, transform=transform, attack_type=attack)
            for attack in _ATTACK_TYPES]


clean_data = natsort.natsorted(os.listdir(os.path.join(WORK_AREA, 'clean')))

# test = 33% of the corpus; val is then carved from the remainder with an absolute size of
# int(33%) of the *whole* corpus, so train / val / test end up roughly a third each.
train, test = train_test_split(clean_data, test_size=0.33, random_state=0)
train, val = train_test_split(train, test_size=int(0.33 * len(clean_data)), random_state=0)

# Training data: every attack mixed into one shuffled loader.
_train_sets = _attack_datasets(train)
(train_set_classic_rows, train_set_classic_squares, train_set_classic_sequential,
 train_set_oceanlotus_rows, train_set_oceanlotus_squares, train_set_oceanlotus_sequential) = _train_sets
train_set = torch.utils.data.ConcatDataset(_train_sets)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

# Validation data: every attack mixed, unshuffled.
_val_sets = _attack_datasets(val)
(val_set_classic_rows, val_set_classic_squares, val_set_classic_sequential,
 val_set_oceanlotus_rows, val_set_oceanlotus_squares, val_set_oceanlotus_sequential) = _val_sets
val_set = torch.utils.data.ConcatDataset(_val_sets)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size)

# Test data: kept per attack so each scheme can be evaluated separately.
(test_set_classic_rows, test_set_classic_squares, test_set_classic_sequential,
 test_set_oceanlotus_rows, test_set_oceanlotus_squares,
 test_set_oceanlotus_sequential) = _attack_datasets(test)

test_set_classic_rows_loader = torch.utils.data.DataLoader(test_set_classic_rows, batch_size=batch_size)
test_set_classic_squares_loader = torch.utils.data.DataLoader(test_set_classic_squares, batch_size=batch_size)
test_set_classic_sequential_loader = torch.utils.data.DataLoader(test_set_classic_sequential, batch_size=batch_size)
test_set_oceanlotus_rows_loader = torch.utils.data.DataLoader(test_set_oceanlotus_rows, batch_size=batch_size)
test_set_oceanlotus_squares_loader = torch.utils.data.DataLoader(test_set_oceanlotus_squares, batch_size=batch_size)
test_set_oceanlotus_sequential_loader = torch.utils.data.DataLoader(test_set_oceanlotus_sequential, batch_size=batch_size)

print("[INFO] done")



# Model learning


In [None]:
# Honour the GPU configuration cell: DataParallel is used when GPU was given as a list of ids
# (previously `multi_gpus` was computed there but never used).
model_trainer = Trainer(device, use_mult_gpu=multi_gpus)

In [None]:
# Train on the mixed-attack training split; the validation split selects the saved checkpoint.
losses_cum = model_trainer.train(
    tr_loader=train_loader,
    ts_loader=val_loader,
    batch_size=batch_size,
    num_epochs=epochs,
)

In [None]:
# Save the model checkpoints
# Keep the final-epoch weights under a separate name. Trainer.train already wrote the
# lowest-validation-loss weights to ./model_UNet.ckpt, and the evaluation cells below reload
# that file; saving to the same path here would silently replace them with the last epoch.
torch.save(model_trainer.model.state_dict(), './model_UNet_last.ckpt')

In [None]:
losses, tr_losses, te_losses = losses_cum

num_epochs = len(losses)
epoch_axis = range(1, num_epochs + 1)

fig, ax = plt.subplots(figsize=(15, 10))
ax.plot(epoch_axis, losses, label='Training Loss')
# te_losses comes from val_loader (passed as `ts_loader`), so it is a validation curve.
ax.plot(epoch_axis, te_losses, label='Validation Loss')
ax.set_xlim(0, num_epochs + 1)
ax.set(xlabel='Epoch', ylabel='MSE loss', title='UNet sanitizer: loss per epoch')
ax.legend()
plt.show()



In [None]:
from sklearn.metrics import mean_squared_error as MSE
from sklearn.metrics import mean_absolute_error as MAE
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM

In [None]:
# Evaluate the weights in ./model_UNet.ckpt on the test split, lsb_classic/sequential attack.
the_model = model_trainer.model

the_model.load_state_dict(torch.load('./model_UNet.ckpt'))

outdir = os.path.join(WORK_AREA, 'sanitized/test/lsb_classic/sequential')

the_model.eval()
printout = True
if printout:
    # image.save() below raises FileNotFoundError when the output directory is missing
    os.makedirs(outdir, exist_ok=True)

# Distances are sums over batches of each batch's mean absolute error (not averaged);
# MSE/PSNR totals are divided by the batch count `cnt` when printed. PSNR is computed on a
# whole batch at once, not per image.
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0   # y_clean vs x_rec  (image preservation)
total_mse_se = total_psnr_se = 0   # x_mal vs x_rec    (secret elimination)
total_mse_no = total_psnr_no = 0   # x_mal vs y_clean  (baseline, no sanitization)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for name, x, y in test_set_classic_sequential_loader:
        bs = y.size(0)
        # Only a device copy of x goes through the model; metrics are computed on the CPU.
        y_rec = the_model(x.to(device), predict=True).cpu()

        # One row per image: (bs, C*H*W)
        y_flat, x_flat, rec_flat = y.view(bs, -1), x.view(bs, -1), y_rec.view(bs, -1)

        clean_distances += torch.mean(torch.abs(y_flat - rec_flat))
        mal_distances += torch.mean(torch.abs(y_flat - x_flat))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                to_pil(y_rec[i]).save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference .w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSRN: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSRN: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSRN: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the weights in ./model_UNet.ckpt on the test split, lsb_classic/interleaving_rows.
the_model = model_trainer.model

the_model.load_state_dict(torch.load('./model_UNet.ckpt'))

outdir = os.path.join(WORK_AREA, 'sanitized/test/lsb_classic/interleaving_rows')

the_model.eval()
printout = False
if printout:
    # image.save() below raises FileNotFoundError when the output directory is missing
    os.makedirs(outdir, exist_ok=True)

# Distances are sums over batches of each batch's mean absolute error (not averaged);
# MSE/PSNR totals are divided by the batch count `cnt` when printed. PSNR is computed on a
# whole batch at once, not per image.
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0   # y_clean vs x_rec  (image preservation)
total_mse_se = total_psnr_se = 0   # x_mal vs x_rec    (secret elimination)
total_mse_no = total_psnr_no = 0   # x_mal vs y_clean  (baseline, no sanitization)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for name, x, y in test_set_classic_rows_loader:
        bs = y.size(0)
        # Only a device copy of x goes through the model; metrics are computed on the CPU.
        y_rec = the_model(x.to(device), predict=True).cpu()

        # One row per image: (bs, C*H*W)
        y_flat, x_flat, rec_flat = y.view(bs, -1), x.view(bs, -1), y_rec.view(bs, -1)

        clean_distances += torch.mean(torch.abs(y_flat - rec_flat))
        mal_distances += torch.mean(torch.abs(y_flat - x_flat))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                to_pil(y_rec[i]).save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference .w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSRN: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSRN: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSRN: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the weights in ./model_UNet.ckpt on the train split, lsb_classic/interleaving_rows.
the_model = model_trainer.model

the_model.load_state_dict(torch.load('./model_UNet.ckpt'))

outdir = os.path.join(WORK_AREA, 'sanitized/train/lsb_classic/interleaving_rows')

# Unshuffled loader over this single attack, so outputs can be written per file name.
train_set_classic_rows_loader = torch.utils.data.DataLoader(train_set_classic_rows, batch_size=batch_size)


the_model.eval()
printout = True
if printout:
    # image.save() below raises FileNotFoundError when the output directory is missing
    os.makedirs(outdir, exist_ok=True)

# Distances are sums over batches of each batch's mean absolute error (not averaged);
# MSE/PSNR totals are divided by the batch count `cnt` when printed. PSNR is computed on a
# whole batch at once, not per image.
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0   # y_clean vs x_rec  (image preservation)
total_mse_se = total_psnr_se = 0   # x_mal vs x_rec    (secret elimination)
total_mse_no = total_psnr_no = 0   # x_mal vs y_clean  (baseline, no sanitization)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for name, x, y in train_set_classic_rows_loader:
        bs = y.size(0)
        # Only a device copy of x goes through the model; metrics are computed on the CPU.
        y_rec = the_model(x.to(device), predict=True).cpu()

        # One row per image: (bs, C*H*W)
        y_flat, x_flat, rec_flat = y.view(bs, -1), x.view(bs, -1), y_rec.view(bs, -1)

        clean_distances += torch.mean(torch.abs(y_flat - rec_flat))
        mal_distances += torch.mean(torch.abs(y_flat - x_flat))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                to_pil(y_rec[i]).save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference .w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSRN: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSRN: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSRN: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the weights in ./model_UNet.ckpt on the val split, lsb_classic/interleaving_rows.
the_model = model_trainer.model

the_model.load_state_dict(torch.load('./model_UNet.ckpt'))

outdir = os.path.join(WORK_AREA, 'sanitized/val/lsb_classic/interleaving_rows')

# Unshuffled loader over this single attack, so outputs can be written per file name.
val_set_classic_rows_loader = torch.utils.data.DataLoader(val_set_classic_rows, batch_size=batch_size)


the_model.eval()
printout = True
if printout:
    # image.save() below raises FileNotFoundError when the output directory is missing
    os.makedirs(outdir, exist_ok=True)

# Distances are sums over batches of each batch's mean absolute error (not averaged);
# MSE/PSNR totals are divided by the batch count `cnt` when printed. PSNR is computed on a
# whole batch at once, not per image.
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0   # y_clean vs x_rec  (image preservation)
total_mse_se = total_psnr_se = 0   # x_mal vs x_rec    (secret elimination)
total_mse_no = total_psnr_no = 0   # x_mal vs y_clean  (baseline, no sanitization)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for name, x, y in val_set_classic_rows_loader:
        bs = y.size(0)
        # Only a device copy of x goes through the model; metrics are computed on the CPU.
        y_rec = the_model(x.to(device), predict=True).cpu()

        # One row per image: (bs, C*H*W)
        y_flat, x_flat, rec_flat = y.view(bs, -1), x.view(bs, -1), y_rec.view(bs, -1)

        clean_distances += torch.mean(torch.abs(y_flat - rec_flat))
        mal_distances += torch.mean(torch.abs(y_flat - x_flat))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                to_pil(y_rec[i]).save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference .w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSRN: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSRN: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSRN: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the weights in ./model_UNet.ckpt on the test split, lsb_classic/interleaving_squares.
the_model = model_trainer.model

the_model.load_state_dict(torch.load('./model_UNet.ckpt'))

outdir = os.path.join(WORK_AREA, 'sanitized/test/lsb_classic/interleaving_squares')

the_model.eval()
printout = True
if printout:
    # image.save() below raises FileNotFoundError when the output directory is missing
    os.makedirs(outdir, exist_ok=True)

# Distances are sums over batches of each batch's mean absolute error (not averaged);
# MSE/PSNR totals are divided by the batch count `cnt` when printed. PSNR is computed on a
# whole batch at once, not per image.
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0   # y_clean vs x_rec  (image preservation)
total_mse_se = total_psnr_se = 0   # x_mal vs x_rec    (secret elimination)
total_mse_no = total_psnr_no = 0   # x_mal vs y_clean  (baseline, no sanitization)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for name, x, y in test_set_classic_squares_loader:
        bs = y.size(0)
        # Only a device copy of x goes through the model; metrics are computed on the CPU.
        y_rec = the_model(x.to(device), predict=True).cpu()

        # One row per image: (bs, C*H*W)
        y_flat, x_flat, rec_flat = y.view(bs, -1), x.view(bs, -1), y_rec.view(bs, -1)

        clean_distances += torch.mean(torch.abs(y_flat - rec_flat))
        mal_distances += torch.mean(torch.abs(y_flat - x_flat))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                to_pil(y_rec[i]).save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference .w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSRN: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSRN: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSRN: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the weights in ./model_UNet.ckpt on the train split, lsb_classic/interleaving_squares.
the_model = model_trainer.model

the_model.load_state_dict(torch.load('./model_UNet.ckpt'))

outdir = os.path.join(WORK_AREA, 'sanitized/train/lsb_classic/interleaving_squares')

# Unshuffled loader over this single attack, so outputs can be written per file name.
train_set_classic_squares_loader = torch.utils.data.DataLoader(train_set_classic_squares, batch_size=batch_size)


the_model.eval()
printout = True
if printout:
    # image.save() below raises FileNotFoundError when the output directory is missing
    os.makedirs(outdir, exist_ok=True)

# Distances are sums over batches of each batch's mean absolute error (not averaged);
# MSE/PSNR totals are divided by the batch count `cnt` when printed. PSNR is computed on a
# whole batch at once, not per image.
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0   # y_clean vs x_rec  (image preservation)
total_mse_se = total_psnr_se = 0   # x_mal vs x_rec    (secret elimination)
total_mse_no = total_psnr_no = 0   # x_mal vs y_clean  (baseline, no sanitization)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for name, x, y in train_set_classic_squares_loader:
        bs = y.size(0)
        # Only a device copy of x goes through the model; metrics are computed on the CPU.
        y_rec = the_model(x.to(device), predict=True).cpu()

        # One row per image: (bs, C*H*W)
        y_flat, x_flat, rec_flat = y.view(bs, -1), x.view(bs, -1), y_rec.view(bs, -1)

        clean_distances += torch.mean(torch.abs(y_flat - rec_flat))
        mal_distances += torch.mean(torch.abs(y_flat - x_flat))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                to_pil(y_rec[i]).save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference .w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSRN: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSRN: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSRN: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the weights in ./model_UNet.ckpt on the val split, lsb_classic/interleaving_squares.
the_model = model_trainer.model

the_model.load_state_dict(torch.load('./model_UNet.ckpt'))

outdir = os.path.join(WORK_AREA, 'sanitized/val/lsb_classic/interleaving_squares')

# Unshuffled loader over this single attack, so outputs can be written per file name.
val_set_classic_squares_loader = torch.utils.data.DataLoader(val_set_classic_squares, batch_size=batch_size)


the_model.eval()
printout = True
if printout:
    # image.save() below raises FileNotFoundError when the output directory is missing
    os.makedirs(outdir, exist_ok=True)

# Distances are sums over batches of each batch's mean absolute error (not averaged);
# MSE/PSNR totals are divided by the batch count `cnt` when printed. PSNR is computed on a
# whole batch at once, not per image.
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0   # y_clean vs x_rec  (image preservation)
total_mse_se = total_psnr_se = 0   # x_mal vs x_rec    (secret elimination)
total_mse_no = total_psnr_no = 0   # x_mal vs y_clean  (baseline, no sanitization)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for name, x, y in val_set_classic_squares_loader:
        bs = y.size(0)
        # Only a device copy of x goes through the model; metrics are computed on the CPU.
        y_rec = the_model(x.to(device), predict=True).cpu()

        # One row per image: (bs, C*H*W)
        y_flat, x_flat, rec_flat = y.view(bs, -1), x.view(bs, -1), y_rec.view(bs, -1)

        clean_distances += torch.mean(torch.abs(y_flat - rec_flat))
        mal_distances += torch.mean(torch.abs(y_flat - x_flat))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                to_pil(y_rec[i]).save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference .w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSRN: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSRN: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSRN: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the weights in ./model_UNet.ckpt on the train split, lsb_classic/sequential attack.
the_model = model_trainer.model

the_model.load_state_dict(torch.load('./model_UNet.ckpt'))

outdir = os.path.join(WORK_AREA, 'sanitized/train/lsb_classic/sequential')

# Unshuffled loader over this single attack, so outputs can be written per file name.
train_set_classic_sequential_loader = torch.utils.data.DataLoader(train_set_classic_sequential, batch_size=batch_size)


the_model.eval()
printout = True
if printout:
    # image.save() below raises FileNotFoundError when the output directory is missing
    os.makedirs(outdir, exist_ok=True)

# Distances are sums over batches of each batch's mean absolute error (not averaged);
# MSE/PSNR totals are divided by the batch count `cnt` when printed. PSNR is computed on a
# whole batch at once, not per image.
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0   # y_clean vs x_rec  (image preservation)
total_mse_se = total_psnr_se = 0   # x_mal vs x_rec    (secret elimination)
total_mse_no = total_psnr_no = 0   # x_mal vs y_clean  (baseline, no sanitization)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for name, x, y in train_set_classic_sequential_loader:
        bs = y.size(0)
        # Only a device copy of x goes through the model; metrics are computed on the CPU.
        y_rec = the_model(x.to(device), predict=True).cpu()

        # One row per image: (bs, C*H*W)
        y_flat, x_flat, rec_flat = y.view(bs, -1), x.view(bs, -1), y_rec.view(bs, -1)

        clean_distances += torch.mean(torch.abs(y_flat - rec_flat))
        mal_distances += torch.mean(torch.abs(y_flat - x_flat))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                to_pil(y_rec[i]).save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference .w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSRN: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSRN: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSRN: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the val split of lsb_classic/sequential and
# optionally save the sanitized images.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/val/lsb_classic/sequential')

val_set_classic_sequential_loader = torch.utils.data.DataLoader(val_set_classic_sequential, batch_size=batch_size)


the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in val_set_classic_sequential_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the test split of lsb_oceanlotus/interleaving_rows
# and optionally save the sanitized images. The loader is built in an earlier cell.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/test/lsb_oceanlotus/interleaving_rows')

the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in test_set_oceanlotus_rows_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the train split of lsb_oceanlotus/interleaving_rows
# and optionally save the sanitized images.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/train/lsb_oceanlotus/interleaving_rows')

train_set_oceanlotus_rows_loader = torch.utils.data.DataLoader(train_set_oceanlotus_rows, batch_size=batch_size)


the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in train_set_oceanlotus_rows_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the val split of lsb_oceanlotus/interleaving_rows
# and optionally save the sanitized images.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/val/lsb_oceanlotus/interleaving_rows')

val_set_oceanlotus_rows_loader = torch.utils.data.DataLoader(val_set_oceanlotus_rows, batch_size=batch_size)


the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in val_set_oceanlotus_rows_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the test split of lsb_oceanlotus/interleaving_squares
# and optionally save the sanitized images. The loader is built in an earlier cell.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/test/lsb_oceanlotus/interleaving_squares')

the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in test_set_oceanlotus_squares_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the train split of lsb_oceanlotus/interleaving_squares
# and optionally save the sanitized images.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/train/lsb_oceanlotus/interleaving_squares')

train_set_oceanlotus_squares_loader = torch.utils.data.DataLoader(train_set_oceanlotus_squares, batch_size=batch_size)


the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in train_set_oceanlotus_squares_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the val split of lsb_oceanlotus/interleaving_squares
# and optionally save the sanitized images.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/val/lsb_oceanlotus/interleaving_squares')

val_set_oceanlotus_squares_loader = torch.utils.data.DataLoader(val_set_oceanlotus_squares, batch_size=batch_size)


the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in val_set_oceanlotus_squares_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the test split of lsb_oceanlotus/sequential and
# optionally save the sanitized images. The loader is built in an earlier cell.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/test/lsb_oceanlotus/sequential')

the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in test_set_oceanlotus_sequential_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the train split of lsb_oceanlotus/sequential and
# optionally save the sanitized images.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/train/lsb_oceanlotus/sequential')

train_set_oceanlotus_sequential_loader = torch.utils.data.DataLoader(train_set_oceanlotus_sequential, batch_size=batch_size)


the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in train_set_oceanlotus_sequential_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')

In [None]:
# Evaluate the UNet sanitizer on the val split of lsb_oceanlotus/sequential and
# optionally save the sanitized images.
the_model = model_trainer.model

# map_location lets a checkpoint saved on GPU load on whatever `device` resolved
# to (CPU when CUDA is unavailable).
the_model.load_state_dict(torch.load('./model_UNet.ckpt', map_location=device))

outdir = os.path.join(WORK_AREA, 'sanitized/val/lsb_oceanlotus/sequential')

val_set_oceanlotus_sequential_loader = torch.utils.data.DataLoader(val_set_oceanlotus_sequential, batch_size=batch_size)


the_model.eval()
printout=True
if printout:
    # image.save() below fails with FileNotFoundError if the folder is missing.
    os.makedirs(outdir, exist_ok=True)

# Accumulator suffixes: none -> y_clean vs reconstruction (image preservation),
# _se -> x_mal vs reconstruction (secret elimination),
# _no -> x_mal vs y_clean (baseline, no sanitization).
mal_distances = clean_distances = 0.
total_mse = total_psnr = cnt = 0
total_mse_se = total_psnr_se = 0
total_mse_no = total_psnr_no = 0

with torch.no_grad():
    for name, x, y in val_set_oceanlotus_sequential_loader:
        bs = y.size(0)
        x = x.to(device)
        y_rec = the_model(x, predict=True).detach().cpu()
        x = x.cpu()

        # Flatten each image once; reshape (unlike view) also accepts
        # non-contiguous tensors.
        x_flat = x.reshape(bs, -1)
        y_flat = y.reshape(bs, -1)
        rec_flat = y_rec.reshape(bs, -1)

        # Mean absolute difference per image (reduce dim -1), summed over the batch.
        clean_distances += torch.sum(torch.mean(torch.abs(y_flat - rec_flat), -1))
        mal_distances += torch.sum(torch.mean(torch.abs(y_flat - x_flat), -1))

        total_mse += MSE(y_flat, rec_flat)
        total_psnr += PSNR(y_flat.numpy(), rec_flat.numpy())

        total_mse_se += MSE(x_flat, rec_flat)
        total_psnr_se += PSNR(x_flat.numpy(), rec_flat.numpy())

        total_mse_no += MSE(x_flat, y_flat)
        total_psnr_no += PSNR(x_flat.numpy(), y_flat.numpy())

        cnt += 1

        if printout:
            for i in range(bs):
                # ToPILImage scales float tensors by 255 and casts to uint8, so
                # values outside [0, 1] would wrap around; clamp first.
                image = transforms.ToPILImage()(y_rec[i].clamp(0, 1))
                image.save(os.path.join(outdir, 'UNet_sanitized_' + name[i]))


# MSE/PSNR totals are averaged over batches (cnt), not over individual images.
print(f'Difference w.r.t. malicious:\t{mal_distances}\nDifference w.r.t. sanitized:\t{clean_distances}')
print('Original --> x_mal w.r.t. y_clean')
print(f'MSE: {(total_mse_no/cnt).item()}, PSNR: {(total_psnr_no/cnt).item()}')

print('Image preservation --> y_clean w.r.t. x_rec')
print(f'MSE: {(total_mse/cnt).item()}, PSNR: {(total_psnr/cnt).item()}')

print('Secret elimination --> x_mal w.r.t. x_rec')
print(f'MSE: {(total_mse_se/cnt).item()}, PSNR: {(total_psnr_se/cnt).item()}')