Library imports

In [None]:
from PIL import Image
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt


# Parquet shards of the Hugging Face "stanfordnlp/imdb" dataset; only the
# training split is loaded into `df`.
splits = {
    'train': 'plain_text/train-00000-of-00001.parquet',
    'test': 'plain_text/test-00000-of-00001.parquet',
    'unsupervised': 'plain_text/unsupervised-00000-of-00001.parquet',
}
df = pd.read_parquet("hf://datasets/stanfordnlp/imdb/" + splits["train"])

Character substitutions for algorithm 2

In [None]:
subs = {
    'a': '4', 'b': '6', 'i': '1', 'l': '1', 'o': '0', 's': '5', 'z': '2', 'e': '3', 'g': '9'
}


def format_text(df, image_size, letter_size):
    """Add a ``cut_text`` column holding the algorithm-2 form of ``df['text']``.

    Each review is lower-cased, the letters listed in ``subs`` are swapped for
    their look-alike digits, every character outside the algorithm-2 alphabet
    is dropped, and the result is truncated to three characters per letter
    cell (one per RGB channel).

    Args:
        df: DataFrame with a ``text`` column. It is modified in place.
        image_size: (width, height) of the target image in pixels.
        letter_size: (width, height) of one character cell in pixels.

    Returns:
        The same ``df`` object, now carrying a ``cut_text`` column.
    """
    table = str.maketrans(subs)
    alphabet = 'cdfhjkmnprtuvwxy 0123456789'
    cells = (image_size[0] // letter_size[0]) * (image_size[1] // letter_size[1])
    limit = cells * 3

    def _encode(raw):
        kept = (ch for ch in raw.lower().translate(table) if ch in alphabet)
        return ''.join(kept)[:limit]

    df['cut_text'] = df['text'].apply(_encode)
    return df

# Build the algorithm-2 cut_text column for a 64x64 image of 8x8 cells.
df = format_text(df, image_size=(64, 64), letter_size=(8, 8))

Text filtering for algorithms 1 and 3

In [None]:
def create_cut_text_column(df, image_size=(64, 64), letter_size=(8, 8), allowed_characters=None):
    """Add a ``cut_text`` column with the filtered, length-capped ``df['text']``.

    Used by algorithms 1 and 3. Each review is lower-cased, characters not in
    ``allowed_characters`` are removed, and the result is truncated to three
    characters per letter cell (one per RGB channel).

    Args:
        df: DataFrame with a ``text`` column. It is modified in place.
        image_size: (width, height) of the target image in pixels.
        letter_size: (width, height) of one character cell in pixels.
        allowed_characters: Characters to keep. When None, code points
            ``0 .. min(255, number_of_cells) - 1`` are used.

    Returns:
        The same ``df`` object, now carrying a ``cut_text`` column.
    """
    cells = (image_size[0] // letter_size[0]) * (image_size[1] // letter_size[1])
    if allowed_characters is None:
        # Integer cell count: true division previously handed range() a float
        # and raised TypeError.
        # NOTE(review): for 64x64 / 8x8 this is chr(0)..chr(63), which excludes
        # every letter -- confirm this default is intended.
        allowed_characters = ''.join(chr(i) for i in range(min(255, cells)))
    allowed = set(allowed_characters)
    max_length = cells * 3

    def cut_text(text):
        kept = ''.join(char for char in text.lower() if char in allowed)
        return kept[:max_length]

    df['cut_text'] = df['text'].apply(cut_text)
    return df

# NOTE(review): this section defines create_cut_text_column, but this line
# calls format_text again (identical to the earlier call), so cut_text still
# holds the algorithm-2 substitution. Confirm whether
# create_cut_text_column(df, (64, 64), (8, 8)) was intended here.
df = format_text(df,(64,64),(8,8))

Generic ASCII encoder and decoder

In [None]:
from PIL import Image

def text_to_rgb_image_ascii(text, image_size=(64, 64), letter_size=(8, 8), allowed_chars=None):
    """Paint ``text`` as one solid-colour cell per character.

    A character's code is split across the channels: bits 7-5 stay in place in
    red, bits 4-2 are shifted into the top of green, and bits 1-0 into the top
    of blue. Digits ('0'-'9', codes 48-57) are painted plain white instead.
    NOTE(review): white decodes to code 255 with the matching decoder, so
    digits are not recoverable -- confirm this is intended.

    Args:
        text: Text to embed; characters outside ``allowed_chars`` are dropped.
        image_size: (width, height) of the output image.
        letter_size: (width, height) of one character cell.
        allowed_chars: Characters to keep; defaults to code points 0-255.

    Returns:
        ``(image, text)``, where ``text`` is the filtered string truncated or
        right-padded with spaces to exactly the number of cells.
    """
    image = Image.new('RGB', image_size, (255, 255, 255))
    pixels = image.load()

    alphabet = allowed_chars if allowed_chars is not None else ''.join(map(chr, range(256)))
    cell_w, cell_h = letter_size
    per_row = image_size[0] // cell_w
    capacity = per_row * (image_size[1] // cell_h)

    kept = ''.join(c for c in text if c in alphabet)
    text = kept[:capacity] if len(kept) > capacity else kept.ljust(capacity)

    for index, symbol in enumerate(text):
        code = ord(symbol)
        if 48 <= code <= 57:
            colour = (255, 255, 255)
        else:
            colour = (code & 0xE0, (code & 0x1C) << 3, (code & 0x03) << 6)
        top = (index // per_row) * cell_h
        left = (index % per_row) * cell_w
        for y in range(top, top + cell_h):
            for x in range(left, left + cell_w):
                pixels[x, y] = colour

    return image, text

def rgb_image_to_text_ascii(image, letter_size=(8, 8), allowed_chars=None):
    """Recover text painted by the ASCII colour encoder.

    The pixels of each cell are averaged (integer division). The character
    code is then reassembled from the top bits of red (bits 7-5), green
    (bits 4-2) and blue (bits 1-0). Decoded characters not in
    ``allowed_chars`` are skipped.

    Args:
        image: Object exposing ``size`` and ``getpixel((x, y)) -> (r, g, b)``,
            e.g. an RGB PIL image.
        letter_size: (width, height) of one character cell.
        allowed_chars: Characters to keep; defaults to code points 0-255.

    Returns:
        The decoded string, one character per kept cell in row-major order.
    """
    if allowed_chars is None:
        allowed_chars = ''.join(map(chr, range(256)))

    cell_w, cell_h = letter_size
    width, height = image.size
    per_row = width // cell_w
    n_cells = per_row * (height // cell_h)
    area = cell_w * cell_h

    out = []
    for cell in range(n_cells):
        top = (cell // per_row) * cell_h
        left = (cell % per_row) * cell_w
        red = green = blue = 0
        for y in range(top, top + cell_h):
            for x in range(left, left + cell_w):
                r, g, b = image.getpixel((x, y))
                red += r
                green += g
                blue += b
        red, green, blue = red // area, green // area, blue // area
        # The masks keep the code within 0..255.
        code = (red & 0xE0) | ((green >> 3) & 0x1C) | ((blue >> 6) & 0x03)
        symbol = chr(code)
        if symbol in allowed_chars:
            out.append(symbol)
    return ''.join(out)

# Round-trip smoke test for the ASCII colour encoding.
test_text = 'Hello, World! This is a test for encoding ASCII text into an image.'

test_image, filtered_text = text_to_rgb_image_ascii(test_text, image_size=(64, 64), letter_size=(8, 8))
decoded_text = rgb_image_to_text_ascii(test_image, letter_size=(8, 8))

print("Original Text:", filtered_text, "\nDecoded Text:", decoded_text, sep="\n")


Encoding and decoding: method 1


In [None]:
from PIL import Image

def text_to_rgb_image(text, image_size=(64, 64), letter_size=(8, 8)):
    """Encode text as grey levels spread over the R, G and B channels (method 1).

    Each of the 37 symbols in ``'abcdefghijklmnopqrstuvwxyz 0123456789'`` maps
    to an evenly spaced grey level in 0..255. The cell grid is used once per
    channel: the first ``cells`` characters go into red, the next ``cells``
    into green and the rest into blue, where ``cells`` is the number of letter
    cells in the image.

    Args:
        text: Text to embed. It is lower-cased and characters outside the
            alphabet are dropped.
        image_size: (width, height) of the output image.
        letter_size: (width, height) of one character cell.

    Returns:
        ``(image, text)``, where ``text`` is the filtered, truncated string
        that was actually encoded.
    """
    image = Image.new('RGB', image_size, (255, 255, 255))
    pixels = image.load()

    characters = 'abcdefghijklmnopqrstuvwxyz 0123456789'
    char_to_gray = {char: int((i / (len(characters) - 1)) * 255) for i, char in enumerate(characters)}
    # NOTE(review): '9' maps to 255, the same value as the white background,
    # so a decoder cannot tell a '9' cell from an empty one.

    cols = image_size[0] // letter_size[0]
    chars_per_channel = cols * (image_size[1] // letter_size[1])
    max_chars = chars_per_channel * 3

    text = ''.join(char for char in text.lower() if char in characters)[:max_chars]

    for i, char in enumerate(text):
        grayscale_value = char_to_gray[char]
        channel = i // chars_per_channel  # 0 = R, 1 = G, 2 = B
        # Wrap within one channel's grid. This was hard-coded as `i % 64`,
        # which is only correct when the image has exactly 64 cells.
        ic = i % chars_per_channel
        row = (ic // cols) * letter_size[1]
        col = (ic % cols) * letter_size[0]

        for y in range(row, row + letter_size[1]):
            for x in range(col, col + letter_size[0]):
                rgb = list(pixels[x, y])
                rgb[channel] = grayscale_value
                pixels[x, y] = tuple(rgb)

    return image, text



def rgb_image_to_text(image, image_size=(64, 64), letter_size=(8, 8)):
    """Decode an image produced by the method-1 encoder.

    Cells are read channel by channel: all red cells, then green, then blue.
    Each cell's channel values are averaged and rounded, then mapped to the
    symbol whose grey level is nearest. Cells whose average is exactly 255
    are treated as empty and skipped.

    Args:
        image: Object whose ``load()`` returns a pixel accessor indexable as
            ``pixels[x, y] -> (r, g, b)``, e.g. an RGB PIL image.
        image_size: (width, height) used to lay out the cell grid.
        letter_size: (width, height) of one character cell.

    Returns:
        The decoded string.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyz 0123456789'
    span = len(alphabet) - 1
    levels = {}
    for pos, symbol in enumerate(alphabet):
        levels[int((pos / span) * 255)] = symbol
    ordered = sorted(levels)

    pixels = image.load()

    lw, lh = letter_size
    n_cols = image_size[0] // lw
    n_cells = n_cols * (image_size[1] // lh)
    area = lw * lh

    pieces = []
    for channel in range(3):  # 0 = R, 1 = G, 2 = B
        for cell in range(n_cells):
            top = (cell // n_cols) * lh
            left = (cell % n_cols) * lw
            total = 0
            for y in range(top, top + lh):
                for x in range(left, left + lw):
                    r, g, b = pixels[x, y]
                    total += (r, g, b)[channel]
            level = int(round(total / area))
            if level == 255:
                continue  # white background: no character here
            nearest = min(ordered, key=lambda v: abs(v - level))
            pieces.append(levels[nearest])

    return ''.join(pieces)

Encoding and decoding: method 2


In [None]:
from PIL import Image

def text_to_rgb_image_mapped(text, image_size=(64, 64), letter_size=(8, 8)):
    """Encode text with three grey-level maps, one character per cell (method 2).

    The alphabet ``'cdfhjkmnprtuvwxy 0123456789RB'`` is split into maps of 10,
    10 and 9 symbols. Each map spreads its symbols evenly over 0..255, and a
    cell is coloured according to the map of its character:

    * map 1: red = grey level, blue = 226 (green untouched)
    * map 2: green = grey level, blue = 255 (red untouched)
    * map 3: blue = grey level (red and green untouched)

    Args:
        text: Text to embed. It is lower-cased and characters outside the
            alphabet are dropped.
        image_size: (width, height) of the output image.
        letter_size: (width, height) of one character cell.

    Returns:
        ``(image, filtered_text)``, where ``filtered_text`` is the truncated
        text that was encoded.
    """
    image = Image.new('RGB', image_size, (255, 255, 255))
    pixels = image.load()
    alphabet = list('cdfhjkmnprtuvwxy 0123456789RB')

    # symbol -> (map index, grey level)
    encoding = {}
    for map_index, group in enumerate((alphabet[:10], alphabet[10:20], alphabet[20:])):
        last = len(group) - 1
        for k, symbol in enumerate(group):
            encoding.setdefault(symbol, (map_index, int((k / last) * 255)))

    cell_w, cell_h = letter_size
    per_row = image_size[0] // cell_w
    capacity = per_row * (image_size[1] // cell_h)

    # NOTE(review): filtering happens after lower(), so the upper-case 'R' and
    # 'B' of the third map can never be encoded -- confirm intent.
    filtered_text = ''.join(c for c in text.lower() if c in alphabet)[:capacity]

    for index, symbol in enumerate(filtered_text):
        map_index, level = encoding[symbol]
        top = (index // per_row) * cell_h
        left = (index % per_row) * cell_w
        for y in range(top, top + cell_h):
            for x in range(left, left + cell_w):
                r, g, b = pixels[x, y]
                if map_index == 0:
                    pixels[x, y] = (level, g, 226)
                elif map_index == 1:
                    pixels[x, y] = (r, level, 255)
                else:
                    pixels[x, y] = (r, g, level)

    return image, filtered_text

def rgb_image_to_text(image, image_size=(64, 64), letter_size=(8, 8)):
    """Decode a three-pass grey-level image into text.

    The cell grid is read three times (red channel, then green, then blue).
    Each cell's integer-averaged channel value is mapped to the nearest grey
    level of the 27-symbol alphabet ``'cdfhjkmnprtuvwxy 0123456789'``.

    NOTE(review): this sequential R/G/B layout does not match the one-cell-per-
    character layout of ``text_to_rgb_image_mapped`` in the same cell -- confirm
    which encoder it pairs with. This definition also shadows the earlier
    ``rgb_image_to_text``.

    Args:
        image: Object whose ``load()`` returns pixels indexable as
            ``pixels[x, y] -> (r, g, b)``.
        image_size: (width, height) of the cell grid.
        letter_size: (width, height) of one character cell.

    Returns:
        The decoded string: three characters per cell (one per channel pass).
    """
    characters = 'cdfhjkmnprtuvwxy 0123456789'
    gray_to_char = {int((i / (len(characters) - 1)) * 255): char for i, char in enumerate(characters)}
    pixels = image.load()
    cols = image_size[0] // letter_size[0]
    chars_per_channel = cols * (image_size[1] // letter_size[1])
    max_chars = chars_per_channel * 3
    total_pixels = letter_size[0] * letter_size[1]

    decoded = []
    for i in range(max_chars):
        channel = i // chars_per_channel  # 0 = R, 1 = G, 2 = B
        # Position within one channel's grid (was hard-coded as `i % 64`,
        # which read outside the image for grids that aren't 64 cells).
        ic = i % chars_per_channel
        row_start = (ic // cols) * letter_size[1]
        col_start = (ic % cols) * letter_size[0]
        channel_sum = 0
        for y in range(row_start, row_start + letter_size[1]):
            for x in range(col_start, col_start + letter_size[0]):
                channel_sum += pixels[x, y][channel]
        grayscale_value = channel_sum // total_pixels
        closest = min(gray_to_char, key=lambda g: abs(g - grayscale_value))
        decoded.append(gray_to_char[closest])
    return ''.join(decoded)

In [None]:
# Keep the first 12,800 reviews (same size as the cover-image subset) and
# build their cut_text column. .copy() makes the slice an independent frame,
# so adding a column does not raise SettingWithCopyWarning.
# The original call omitted the required image_size/letter_size arguments;
# (64, 64) / (8, 8) matches the decoder defaults used for evaluation.
df = df.iloc[:12800].copy()
df = create_cut_text_column(df, (64, 64), (8, 8))

In [None]:
Image data preprocessing

In [None]:
# 70 / 10 / 20 split of the 12,800-sample subset into train / val / test.
sub_train_images = 12800
train_size, val_size = (int(fraction * sub_train_images) for fraction in (0.7, 0.1))
test_size = sub_train_images - (train_size + val_size)

In [None]:
class TextImageDataset(Dataset):
    """Dataset of (secret image, text) pairs read from a DataFrame.

    ``df`` must have an ``image`` column and a ``cut_text`` column. Its rows
    are shuffled once at construction and re-indexed from 0.
    """

    def __init__(self, df, transform=None):
        shuffled = df.sample(frac=1)
        self.df = shuffled.reset_index(drop=True)
        self.transform = transform
        # Stored for reference; not read by __getitem__.
        self.image_size = (64, 64)
        self.letter_size = (4, 4)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image, text = row['image'], row['cut_text']
        # Only PIL images are transformed; any other value is returned as stored.
        if isinstance(image, Image.Image) and self.transform:
            image = self.transform(image)
        return image, text


In [None]:

# Secret-image pipeline: resize to 64x64, convert to a tensor and scale to
# [-1, 1] to match the Tanh outputs of the model. `transforms` was previously
# only imported in a later cell.
from torchvision import transforms

image_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# TODO: train_text_df / val_text_df / test_text_df are not created in any
# visible cell -- build them (e.g. from `df`) before running this cell.
text_image_dataset_train = TextImageDataset(train_text_df, transform=image_transforms)
text_image_dataset_val = TextImageDataset(val_text_df, transform=image_transforms)
text_image_dataset_test = TextImageDataset(test_text_df, transform=image_transforms)

batch_size = 64
# Each dataset is already shuffled once at construction; only the training
# loader reshuffles per epoch, so validation/test batches stay stable
# (matching the cover-image loaders).
text_image_train_dataloader = DataLoader(text_image_dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
text_image_val_dataloader = DataLoader(text_image_dataset_val, batch_size=batch_size, shuffle=False, num_workers=4)
text_image_test_dataloader = DataLoader(text_image_dataset_test, batch_size=batch_size, shuffle=False, num_workers=4)



In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms


cropH, cropW = 64, 64
batchSize = 64
# NOTE(review): `nikhilshingadiya_tinyimagenet200_path` is not defined in any
# visible cell (presumably the path of a downloaded dataset) -- define it first.
dataDir = f"{nikhilshingadiya_tinyimagenet200_path}/tiny-imagenet-200"
trainDir = f"{dataDir}/train"
sub_size = 12800  # Select a subset of images
SEED = 42  # fixes both the subset choice and the train/val/test split

data_transforms = transforms.Compose([
    transforms.RandomCrop((cropH, cropW), pad_if_needed=True),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_images = datasets.ImageFolder(trainDir, transform=data_transforms)

subset_rng = np.random.default_rng(SEED)
indices_train = subset_rng.choice(len(train_images), size=sub_size, replace=False)
sub_train_images = torch.utils.data.Subset(train_images, indices_train)

# 70/10/20 split derived from sub_size here, so the lengths always sum to the
# subset size (previously they depended on a constant from another cell).
train_size = int(0.7 * sub_size)
val_size = int(0.1 * sub_size)
test_size = sub_size - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    sub_train_images,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=batchSize, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batchSize, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batchSize, shuffle=False)

In [None]:
Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CBAMBlock(nn.Module):
    """Channel-then-spatial attention block (CBAM-style).

    Channel gate: global average pooling, a 1x1-conv bottleneck and a sigmoid.
    Spatial gate: channel-wise mean and max maps, a 7x7 conv and a sigmoid.
    The output has the same shape as the input.

    Args:
        channels: Number of input/output channels.
        reduction: Bottleneck ratio for the channel gate. The hidden width is
            ``max(1, channels // reduction)``, so channel counts below
            ``reduction`` no longer produce a zero-width layer.
    """

    def __init__(self, channels, reduction=16):
        super(CBAMBlock, self).__init__()
        hidden = max(1, channels // reduction)
        # Channel Attention
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.Sigmoid()
        )
        # Spatial Attention: input is the stacked (mean, max) maps -> 2 channels.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Channel Attention
        x = x * self.channel_attention(x)
        # Spatial Attention
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        sa = self.spatial_attention(torch.cat([avg_out, max_out], dim=1))
        return x * sa


class ResidualCBAMBlock(nn.Module):
    """Two 3x3 conv+BatchNorm layers, CBAM attention, identity skip, then ReLU.

    Channel count and spatial size are preserved (3x3 kernels with padding 1).
    """

    def __init__(self, channels):
        super(ResidualCBAMBlock, self).__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.conv = nn.Sequential(*layers)
        self.cbam = CBAMBlock(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.cbam(self.conv(x))
        out += x  # identity shortcut
        return self.relu(out)


class Encoder(nn.Module):
    """Hide a secret image inside a cover image.

    ``forward(input_S, input_C)`` concatenates the two inputs along dim 1.
    They must total 6 channels, presumably 3 + 3 RGB. The result passes
    through a 7x7 stem (6 -> 64 channels, BatchNorm, ReLU), seven 64-channel
    residual CBAM blocks, and a 7x7 projection to 3 channels with Tanh, so the
    output lies in [-1, 1].
    """

    def __init__(self):
        super(Encoder, self).__init__()
        width = 64
        # Initial convolution
        self.conv_in = nn.Sequential(
            nn.Conv2d(6, width, kernel_size=7, padding=3),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )
        # Residual CBAM blocks
        self.res_blocks = nn.Sequential(*[ResidualCBAMBlock(width) for _ in range(7)])
        # Output convolution
        self.conv_out = nn.Sequential(
            nn.Conv2d(width, 3, kernel_size=7, padding=3),
            nn.Tanh(),
        )

    def forward(self, input_S, input_C):
        stacked = torch.cat((input_S, input_C), dim=1)
        return self.conv_out(self.res_blocks(self.conv_in(stacked)))


class Decoder(nn.Module):
    """Recover the secret image from a container image.

    It uses a 7x7 stem (3 -> 64 channels, BatchNorm, ReLU), seven 64-channel
    residual CBAM blocks, and a 7x7 projection to 3 channels with Tanh (output
    in [-1, 1]). Padding keeps the spatial size unchanged.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        width = 64
        # Initial convolution
        self.conv_in = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=7, padding=3),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )
        # Residual CBAM blocks
        self.res_blocks = nn.Sequential(*[ResidualCBAMBlock(width) for _ in range(7)])
        # Output convolution
        self.conv_out = nn.Sequential(
            nn.Conv2d(width, 3, kernel_size=7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        for stage in (self.conv_in, self.res_blocks, self.conv_out):
            x = stage(x)
        return x


class Model(nn.Module):
    """Encoder/decoder pair: hide a secret in a cover, then recover it."""

    def __init__(self):
        super(Model, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, input_S, input_C):
        """Return ``(container, recovered_secret)``."""
        container = self.encoder(input_S, input_C)
        return container, self.decoder(container)

    def pixel_errors(self, input_S, input_C):
        """Return the RMS pixel error of the container vs. the cover and of the
        recovered secret vs. the secret, as Python floats (no gradients)."""
        with torch.no_grad():
            container, recovered = self.forward(input_S, input_C)
            rmse_cover = torch.sqrt(torch.mean(torch.abs(container - input_C) ** 2)).item()
            rmse_secret = torch.sqrt(torch.mean(torch.abs(recovered - input_S) ** 2)).item()
        return rmse_cover, rmse_secret


In [None]:
Model training

In [None]:
def train(model, train_loader, val_loader, text_image_train_loader, text_image_val_loader, num_epochs=1, beta=1.0):
    """Jointly train the encoder/decoder on paired (cover, secret) batches.

    Each step pairs one batch of cover images with one batch of secret images
    and minimises ``beta * MSE(secret, recovered) + MSE(cover, container)``
    using Adam with default settings. An epoch stops at the shorter of the two
    loaders. If paired batches differ in size (e.g. a final partial batch),
    both are truncated to the smaller size so the encoder's ``torch.cat``
    succeeds.

    Args:
        model: ``Model`` instance; moved to CUDA when available.
        train_loader / val_loader: yield ``(cover_images, labels)``.
        text_image_train_loader / text_image_val_loader: yield
            ``(secret_images, texts)``.
        num_epochs: Number of passes over the data.
        beta: Weight of the secret-recovery loss. It was hard-coded inside the
            training loop, which left it undefined during validation when the
            training loader produced no batches.

    Returns:
        ``(model, (train_total, train_cover, train_secret),
        (val_total, val_cover, val_secret))`` with one mean per epoch in each
        list (NaN for an epoch with no batches).
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print("device : ", device)

    model.to(device)

    S_mseloss = torch.nn.MSELoss()
    C_mseloss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters())

    def _paired_batches(image_loader, text_loader):
        # zip stops at the shorter loader; mismatched batch sizes are truncated.
        for (cover_images, _), (image_secrets, _) in zip(image_loader, text_loader):
            n = min(cover_images.size(0), image_secrets.size(0))
            yield cover_images[:n].to(device), image_secrets[:n].to(device)

    def _losses(cover_images, image_secrets):
        # Returns (weighted total, cover MSE, secret MSE).
        output_C, output_S = model(image_secrets, cover_images)
        ssLoss = S_mseloss(image_secrets, output_S)
        ccLoss = C_mseloss(cover_images, output_C)
        return beta * ssLoss + ccLoss, ccLoss, ssLoss

    def _mean(values):
        # NaN instead of NumPy's "mean of empty slice" warning.
        return float(np.mean(values)) if values else float('nan')

    loss_all_total, c_loss_total, s_loss_total = [], [], []
    val_loss_all_total, val_c_loss_total, val_s_loss_total = [], [], []

    for epoch in range(num_epochs):
        model.train()
        loss_all, c_loss, s_loss = [], [], []

        n_train = min(len(train_loader), len(text_image_train_loader))
        t = tqdm(_paired_batches(train_loader, text_image_train_loader), total=n_train,
                 desc=f"Epoch {epoch+1} [Training]")
        for cover_images, image_secrets in t:
            optimizer.zero_grad()
            loss, ccLoss, ssLoss = _losses(cover_images, image_secrets)
            loss.backward()
            optimizer.step()

            loss_all.append(loss.item())
            c_loss.append(ccLoss.item())
            s_loss.append(ssLoss.item())
            t.set_description(f"Epoch {epoch+1} [Training] | Loss: {_mean(loss_all):.4f}")

        loss_all_total.append(_mean(loss_all))
        c_loss_total.append(_mean(c_loss))
        s_loss_total.append(_mean(s_loss))

        model.eval()
        val_loss_all, val_c_loss, val_s_loss = [], [], []

        n_val = min(len(val_loader), len(text_image_val_loader))
        v = tqdm(_paired_batches(val_loader, text_image_val_loader), total=n_val,
                 desc=f"Epoch {epoch+1} [Validation]")
        with torch.no_grad():
            for cover_images, image_secrets in v:
                val_loss, ccLoss, ssLoss = _losses(cover_images, image_secrets)
                val_loss_all.append(val_loss.item())
                val_c_loss.append(ccLoss.item())
                val_s_loss.append(ssLoss.item())
                v.set_description(f"Epoch {epoch+1} [Validation] | Loss: {_mean(val_loss_all):.4f}")

        val_loss_all_total.append(_mean(val_loss_all))
        val_c_loss_total.append(_mean(val_c_loss))
        val_s_loss_total.append(_mean(val_s_loss))

        print(f"[Epoch {epoch+1}] Training Loss: {_mean(loss_all)} | Cover Loss: {_mean(c_loss)} | Secret Loss: {_mean(s_loss)}")
        print(f"[Epoch {epoch+1}] Validation Loss: {_mean(val_loss_all)} | Cover Loss: {_mean(val_c_loss)} | Secret Loss: {_mean(val_s_loss)}")

    return model, (loss_all_total, c_loss_total, s_loss_total), (val_loss_all_total, val_c_loss_total, val_s_loss_total)

# Build a fresh model and train it for 16 epochs; the next cell plots the losses.
model = Model()
trained_model, train_history, val_history = train(
    model,
    train_loader,
    val_loader,
    text_image_train_dataloader,
    text_image_val_dataloader,
    num_epochs=16,
)


In [None]:
import matplotlib.pyplot as plt


# Per-epoch training vs. validation losses (total, cover, secret).
train_total_loss, train_cover_loss, train_secret_loss = train_history
val_total_loss, val_cover_loss, val_secret_loss = val_history

# The x axis comes from the recorded history instead of a hard-coded
# range(1, 17), so it stays correct when num_epochs changes.
epochs = range(1, len(train_total_loss) + 1)

panels = [
    ('Total', 'Total Loss History', train_total_loss, val_total_loss),
    ('Cover', 'Cover Image Loss History', train_cover_loss, val_cover_loss),
    ('Secret', 'Secret Image Loss History', train_secret_loss, val_secret_loss),
]

fig, axes = plt.subplots(3, 1, figsize=(15, 10))
for ax, (label, title, train_curve, val_curve) in zip(axes, panels):
    ax.plot(epochs, train_curve, 'g-', label=f'Train {label} Loss')
    ax.plot(epochs, val_curve, 'b-', label=f'Val {label} Loss')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()

fig.tight_layout()
plt.show()


In [None]:
Demonstrating results

In [None]:
from PIL import Image, ImageDraw, ImageFont
import torch
import matplotlib.pyplot as plt


def display_results(image_dataloader, text_image_dataloader, model, num_examples=5):
    """Show secrets, recovered secrets, covers and containers for one batch.

    The first batch is taken from each loader, and both batches are truncated
    to ``n = min(num_examples, batch sizes)``. The model runs on them without
    gradients, and the results are drawn as a 4 x ``n`` grid. Images are
    assumed to be normalised to [-1, 1]. They are mapped back with
    ``(x + 1) / 2`` and clipped to [0, 1] so imshow does not warn.

    Args:
        image_dataloader: Yields ``(cover_images, labels)`` batches.
        text_image_dataloader: Yields ``(secret_images, texts)`` batches.
        model: Trained ``Model``. Inputs are moved to the device of its
            parameters, so a CPU model also works on a CUDA machine.
        num_examples: Number of columns to show (capped at the batch size).
    """
    model.eval()
    cover_images, _ = next(iter(image_dataloader))
    image_secrets, _ = next(iter(text_image_dataloader))

    n = min(num_examples, cover_images.size(0), image_secrets.size(0))
    cover_images, image_secrets = cover_images[:n], image_secrets[:n]

    device = next(model.parameters()).device
    with torch.no_grad():
        output_covers, output_secrets = model(image_secrets.to(device), cover_images.to(device))

    rows = [
        (image_secrets, 'Secret {}'),
        (output_secrets, 'Restored secret {}'),
        (cover_images, 'Cover {}'),
        (output_covers, 'Cover with hidden image{}'),
    ]
    fig, axs = plt.subplots(4, n, figsize=(15, 12), squeeze=False)
    for r, (batch, title) in enumerate(rows):
        batch = batch.detach().cpu().numpy()
        for i in range(n):
            img = np.clip((batch[i].transpose(1, 2, 0) + 1) / 2, 0.0, 1.0)
            axs[r, i].imshow(img)
            axs[r, i].set_title(title.format(i + 1))
            axs[r, i].axis('off')
    plt.show()

# Show examples using the model returned by the training cell (`AEmodel` is not
# defined in any visible cell).
display_results(val_loader, text_image_val_dataloader, trained_model)


In [None]:
Demos

In [None]:
Fast calculation of test

In [None]:
from PIL import Image
from spellchecker import SpellChecker
import string

def rgb_image_to_text_with_average_fixed(image, letter_size=(8, 8), spell_checker=None):
    """Decode a base-6 RGB cell image into text plus a spell-corrected copy.

    Decoding works per cell:

    * The cell's pixels are averaged with integer division.
    * Each channel is snapped to the nearest of 0, 51, 102, 153, 204, 255
      (states 0-5).
    * The cell's position is ``r + 6 * g + 36 * b``. Positions 0-127 decode to
      the ASCII character with that code; larger positions decode to a space.

    Each maximal run of alphabetic characters is then passed to
    ``spell_checker.correction``. A suggestion of the same length replaces the
    word, except that the original letter (and its case) is kept wherever it
    matches the suggestion case-insensitively. Other characters are copied
    unchanged.

    Args:
        image: Object exposing ``size`` and ``getpixel((x, y)) -> (r, g, b)``.
        letter_size: (width, height) of one character cell.
        spell_checker: Object with ``correction(word)`` returning a suggestion
            or None. Defaults to a new ``SpellChecker()``.

    Returns:
        ``(decoded_string, corrected_text)``. The original returned the decoded
        string twice and crashed on its undefined ``corrected_text``/``spell``.
    """
    characters_list = [chr(i) for i in range(128)]
    value_to_state = {0: 0, 51: 1, 102: 2, 153: 3, 204: 4, 255: 5}
    possible_values = list(value_to_state.keys())
    spell = spell_checker if spell_checker is not None else SpellChecker()

    def round_to_nearest(value):
        return min(possible_values, key=lambda x: abs(x - value))

    def base6_rgb_to_position(r, g, b):
        # Six levels per channel: a 3-digit base-6 number, red least significant.
        r_state = value_to_state[round_to_nearest(r)]
        g_state = value_to_state[round_to_nearest(g)]
        b_state = value_to_state[round_to_nearest(b)]
        return r_state + (g_state * 6) + (b_state * 36)

    def average_rgb_in_block(x_start, y_start, width, height):
        # Integer mean of each channel over one cell.
        r_total, g_total, b_total = 0, 0, 0
        for y in range(y_start, y_start + height):
            for x in range(x_start, x_start + width):
                r, g, b = image.getpixel((x, y))
                r_total += r
                g_total += g
                b_total += b
        pixel_count = width * height
        return r_total // pixel_count, g_total // pixel_count, b_total // pixel_count

    def correct_word(word):
        suggestion = spell.correction(word)
        if suggestion is None or len(suggestion) != len(word):
            return word
        return ''.join(
            original if original.lower() == proposed.lower() else proposed
            for original, proposed in zip(word, suggestion)
        )

    image_width, image_height = image.size
    num_chars_per_row = image_width // letter_size[0]
    num_rows = image_height // letter_size[1]
    max_chars = num_chars_per_row * num_rows

    decoded_text = []
    for i in range(max_chars):
        row = (i // num_chars_per_row) * letter_size[1]
        col = (i % num_chars_per_row) * letter_size[0]
        r_avg, g_avg, b_avg = average_rgb_in_block(col, row, letter_size[0], letter_size[1])
        position = base6_rgb_to_position(r_avg, g_avg, b_avg)
        decoded_text.append(characters_list[position] if position < len(characters_list) else ' ')
    decoded_string = ''.join(decoded_text)

    corrected_parts = []
    word = ''
    for char in decoded_string:
        if char.isalpha():
            word += char
            continue
        if word:
            corrected_parts.append(correct_word(word))
            word = ''
        corrected_parts.append(char)
    if word:
        corrected_parts.append(correct_word(word))

    return decoded_string, ''.join(corrected_parts)



Evaluation with the DataLoader

In [None]:

def confusion_matric(restored_texts, ground_truth_texts):
    """Compute character accuracy, exact-match accuracy and a confusion matrix.

    Texts are paired by position. Each restored text is truncated to its ground
    truth's length and compared character by character. Every ground-truth
    character counts towards the total, so characters missing from a too-short
    restored text score as errors. Previously they were ignored, which
    inflated the accuracy.

    Args:
        restored_texts: Decoded strings.
        ground_truth_texts: Reference strings (a sized sequence).

    Returns:
        ``(char_accuracy, exact_match_accuracy, confusion_matrix)``.

        * The two accuracies are percentages, or 0 when there is nothing to
          compare.
        * ``confusion_matrix[i, j]`` counts ground-truth symbol ``i`` decoded
          as ``j`` over the 64-symbol alphabet (letters, digits, space, '.').
          Pairs involving other characters are not recorded.
    """
    characters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .'
    char_to_idx = {char: idx for idx, char in enumerate(characters)}
    confusion_matrix = np.zeros((len(characters), len(characters)), dtype=int)

    total_chars = 0
    correct_chars = 0
    exact_matches = 0

    for restored, ground_truth in zip(restored_texts, ground_truth_texts):
        restored = restored[:len(ground_truth)]
        total_chars += len(ground_truth)
        for r_char, gt_char in zip(restored, ground_truth):
            if r_char == gt_char:
                correct_chars += 1
            if gt_char in char_to_idx and r_char in char_to_idx:
                confusion_matrix[char_to_idx[gt_char], char_to_idx[r_char]] += 1

        if restored == ground_truth:
            exact_matches += 1

    char_accuracy = (correct_chars / total_chars) * 100 if total_chars > 0 else 0
    exact_match_accuracy = (exact_matches / len(ground_truth_texts)) * 100 if len(ground_truth_texts) > 0 else 0

    return char_accuracy, exact_match_accuracy, confusion_matrix

In [None]:
from PIL import Image
from spellchecker import SpellChecker
import string

def _shared_spell_checker():
    """Return one module-level SpellChecker, created on first use.

    ``rgb_image_to_text`` is called once per decoded image. Building a
    SpellChecker presumably loads its word list every time, so one instance
    is cached on this function and reused.
    """
    checker = getattr(_shared_spell_checker, 'instance', None)
    if checker is None:
        checker = SpellChecker()
        _shared_spell_checker.instance = checker
    return checker


def rgb_image_to_text(image, letter_size=(8, 8), N=6, spell_checker=None):
    """Decode a base-N RGB cell image into text plus a spell-corrected copy.

    Decoding works per cell:

    * The cell's pixels are averaged with integer division.
    * Each channel is snapped to the nearest of N levels
      ``0, step, ..., (N-1)*step`` with ``step = 255 // (N-1)``.
    * The cell's position is ``r + N*g + N*N*b``, which decodes to ``chr(position)``.

    Each maximal run of alphabetic characters is then passed to
    ``spell_checker.correction``. A suggestion of the same length replaces the
    word, except that the original letter (and its case) is kept wherever it
    matches the suggestion case-insensitively. The previous ``islower()``
    comparison kept every lowercase letter unchanged, so lowercase words were
    never corrected.

    Args:
        image: Object exposing ``size`` and ``getpixel((x, y)) -> (r, g, b)``.
        letter_size: (width, height) of one character cell.
        N: Number of levels per channel.
        spell_checker: Object with ``correction(word)`` returning a suggestion
            or None. Defaults to a shared ``SpellChecker()`` instance.

    Returns:
        ``(decoded_string, corrected_text)``.
    """
    max_positions = N ** 3
    characters_list = [chr(i) for i in range(max_positions)]
    spell = spell_checker if spell_checker is not None else _shared_spell_checker()

    step = 255 // (N - 1) if N > 1 else 255
    possible_values = [i * step for i in range(N)]
    value_to_state = {value: i for i, value in enumerate(possible_values)}

    def round_to_nearest(value):
        return min(possible_values, key=lambda x: abs(x - value))

    def baseN_rgb_to_position(r, g, b):
        r_state = value_to_state[round_to_nearest(r)]
        g_state = value_to_state[round_to_nearest(g)]
        b_state = value_to_state[round_to_nearest(b)]
        return r_state + (g_state * N) + (b_state * N * N)

    def average_rgb_in_block(x_start, y_start, width, height):
        # Integer mean of each channel over one cell.
        r_total, g_total, b_total = 0, 0, 0
        for y in range(y_start, y_start + height):
            for x in range(x_start, x_start + width):
                r, g, b = image.getpixel((x, y))
                r_total += r
                g_total += g
                b_total += b
        pixel_count = width * height
        return r_total // pixel_count, g_total // pixel_count, b_total // pixel_count

    def correct_word(word):
        suggestion = spell.correction(word)
        if suggestion is None or len(suggestion) != len(word):
            return word
        return ''.join(
            original if original.lower() == proposed.lower() else proposed
            for original, proposed in zip(word, suggestion)
        )

    image_width, image_height = image.size
    num_chars_per_row = image_width // letter_size[0]
    num_rows = image_height // letter_size[1]
    max_chars = num_chars_per_row * num_rows

    decoded_chars = []
    for i in range(max_chars):
        row = (i // num_chars_per_row) * letter_size[1]
        col = (i % num_chars_per_row) * letter_size[0]
        r_avg, g_avg, b_avg = average_rgb_in_block(col, row, letter_size[0], letter_size[1])
        position = baseN_rgb_to_position(r_avg, g_avg, b_avg)
        decoded_chars.append(characters_list[position] if position < len(characters_list) else ' ')
    decoded_string = ''.join(decoded_chars)

    corrected_parts = []
    word = ''
    for character in decoded_string:
        if character.isalpha():
            word += character
            continue
        if word:
            corrected_parts.append(correct_word(word))
            word = ''
        corrected_parts.append(character)
    if word:
        corrected_parts.append(correct_word(word))

    return decoded_string, ''.join(corrected_parts)

In [None]:
from collections import Counter

def _count_char_errors(predicted, target):
    """Count character errors between a reconstructed string and its ground truth.

    Every position where the strings differ counts as one error, and so does
    every character by which their lengths differ: a missing or extra character
    is an error too (a plain ``zip`` would silently ignore it).
    """
    mismatches = sum(1 for a, b in zip(predicted, target) if a != b)
    return mismatches + abs(len(predicted) - len(target))


def _plot_confusion_matrix(confusion_matrix, characters, title):
    """Draw a row-normalised character confusion matrix as a seaborn heatmap.

    Rows whose sum is 0 (ground-truth character never seen) are left as NaN
    without triggering a 0/0 division warning; seaborn renders NaN cells blank.
    """
    counts = np.asarray(confusion_matrix, dtype=float)
    row_sums = counts.sum(axis=1, keepdims=True)
    normalized = np.divide(counts, row_sums,
                           out=np.full_like(counts, np.nan),
                           where=row_sums != 0)

    plt.figure(figsize=(24, 20))
    sns.heatmap(normalized, annot=True, fmt=".2f", cmap="YlGnBu",
                xticklabels=characters, yticklabels=characters)
    plt.title(title)
    plt.xlabel("Restored Characters")
    plt.ylabel("Ground Truth Characters")
    plt.show()


def _plot_error_distribution(error_counts, label):
    """Bar-plot the percentage of reconstructions having each number of errors.

    ``error_counts`` is a flat list with one integer error count per sample;
    nothing is drawn when it is empty.
    """
    if not error_counts:
        return
    # Counter -> (error_count, frequency) pairs, sorted by error count.
    error_values, frequency_values = zip(*sorted(Counter(error_counts).items()))
    total_reconstructions = sum(frequency_values)
    error_percentages = [freq / total_reconstructions * 100 for freq in frequency_values]

    plt.figure(figsize=(10, 6))
    plt.bar(error_values, error_percentages, color='skyblue')
    plt.xlabel("Number of Errors")
    plt.ylabel(f"Percentage of Reconstructions in {label}(%)")
    plt.title(f"Distribution of Errors in Text Reconstructions in {label}")
    plt.xticks(range(max(error_values) + 1))
    plt.show()


def evaluate_text_restoration_accuracy(image_dataloader, text_image_dataloader, model,
                                       characters='abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .',
                                       save_images=True):
    """Evaluate a model that hides secret text images inside cover images.

    Cover and secret batches are paired one-to-one (``zip`` stops at the shorter
    loader). The model's revealed secret images are decoded back to text with
    ``rgb_image_to_text``, and the result is compared with the ground-truth
    strings. The function prints character-level and exact-match accuracy for
    both text variants, plus the average MSE and SSIM between output and input
    covers. It also plots confusion matrices and error-count histograms.

    Args:
        image_dataloader: yields ``(cover_images, _)`` batches.
        text_image_dataloader: yields ``(secret_images, texts)`` batches;
            ``texts[i]`` is the ground-truth string for ``secret_images[i]``.
        model: called as ``model(secret_images, cover_images)`` and expected to
            return ``(output_covers, output_secrets)``.
        characters: tick labels for the confusion-matrix heatmaps. NOTE(review):
            must match the row/column order produced by
            ``calculate_confusion_matrix_and_accuracy`` (defined elsewhere) -- confirm.
        save_images: when True (the original behaviour), each revealed secret is
            written to ``image{i}.png`` in the working directory; these files
            are overwritten by every batch.

    Raises:
        ValueError: if the loaders yield no batches.
    """
    model.eval()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    restored_texts = []
    decoded_texts = []
    ground_truth_texts = []
    error_counts_restored = []
    error_counts_decoded = []

    total_mse_loss = 0.0
    total_ssim_score = 0.0
    num_batches = 0

    for (cover_images, _), (secret_images, texts) in zip(image_dataloader, text_image_dataloader):
        secret_images = secret_images.to(device)
        cover_images = cover_images.to(device)

        with torch.no_grad():
            output_covers, output_secrets = model(secret_images, cover_images)
            output_covers = output_covers.to(device)
            total_mse_loss += F.mse_loss(output_covers, cover_images, reduction='mean').item()
            total_ssim_score += ssim(output_covers, cover_images, data_range=1.0).item()
        num_batches += 1

        output_secrets = output_secrets.cpu().numpy()

        for i in range(output_secrets.shape[0]):
            target_text = texts[i]

            # CHW -> HWC and [-1, 1] -> [0, 1]; assumes the model emits values in
            # [-1, 1] (e.g. tanh output) -- TODO confirm. Clip before the uint8
            # cast so out-of-range values saturate instead of wrapping around.
            img = np.clip((output_secrets[i].transpose(1, 2, 0) + 1) / 2, 0.0, 1.0)
            pil_image = Image.fromarray((img * 255).astype('uint8'))
            if save_images:
                pil_image.save(f'image{i}.png')

            # NOTE(review): the decoder defined just above this cell returns
            # (decoded_string, corrected_text). If this is that function, these two
            # names are swapped relative to its return order -- confirm.
            restored_text, decoded_text = rgb_image_to_text(pil_image)
            restored_text = restored_text[:len(target_text)]
            decoded_text = decoded_text[:len(target_text)]

            restored_texts.append(restored_text)
            decoded_texts.append(decoded_text)
            ground_truth_texts.append(target_text)

            error_counts_restored.append(_count_char_errors(restored_text, target_text))
            error_counts_decoded.append(_count_char_errors(decoded_text, target_text))

    if num_batches == 0:
        raise ValueError("evaluate_text_restoration_accuracy: the dataloaders yielded no batches")

    avg_mse_loss = total_mse_loss / num_batches
    avg_ssim_score = total_ssim_score / num_batches

    # Restored texts are treated as the spell-corrected variant (matching the
    # "Corrected" label on their error histogram below).
    char_accuracy, exact_match_accuracy, confusion_matrix = calculate_confusion_matrix_and_accuracy(
        restored_texts, ground_truth_texts
    )
    print(f"Character-Level Accuracy Corrected (Dataset): {char_accuracy:.2f}%")
    print(f"Exact Match Accuracy Corrected (Dataset): {exact_match_accuracy:.2f}%")
    print(f"Average MSE Loss (Cover Images): {avg_mse_loss:.4f}")
    print(f"Average SSIM Score (Cover Images): {avg_ssim_score:.4f}")
    _plot_confusion_matrix(confusion_matrix, characters,
                           "Character-Level Reconstruction Confusion Matrix (Corrected)")

    char_accuracy, exact_match_accuracy, confusion_matrix = calculate_confusion_matrix_and_accuracy(
        decoded_texts, ground_truth_texts
    )
    print(f"Character-Level Accuracy Decoded (Dataset): {char_accuracy:.2f}%")
    print(f"Exact Match Accuracy Decoded (Dataset): {exact_match_accuracy:.2f}%")
    _plot_confusion_matrix(confusion_matrix, characters,
                           "Character-Level Reconstruction Confusion Matrix (Decoded)")

    _plot_error_distribution(error_counts_decoded, "Decoded")
    _plot_error_distribution(error_counts_restored, "Corrected")
