# Dataloader

In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [3]:
import glob
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
from natsort import natsorted
from PIL import Image
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [4]:
DIV2K_path = "/data/whq/data/DIV2K"

In [25]:
transform = T.Compose([
    T.ToTensor(),
    T.RandomCrop(128)
])
transform_A =A.Compose([
    A.RandomCrop(400,400),
    A.augmentations.transforms.ChannelShuffle(0.5),
    ToTensorV2()
])


**使用 torchvision 进行变换**
1. 使用 PIL 进行图像读取

In [26]:
class DIV2K_Dataset(Dataset):
    def __init__(self, transforms_=None, mode='train'):
        self.transform = transforms_
        self.mode = mode
        if mode == 'train':
            self.files = natsorted(sorted(glob.glob(DIV2K_path+"/train"+"/*."+"png")))
        else:
            self.files = natsorted(
                sorted(glob.glob(DIV2K_path+"/valid"+"/*."+"png")))

    def __getitem__(self, index):
        # try:
        image = Image.open(self.files[index]).convert(
            'RGB')  # 使用 PIL 读取并转为 RGB
        item = self.transform(image)
        return item
        # except:
        #     return self.__getitem__(index+1)  # 这个不理解

    def __len__(self):
        return len(self.files)


DIV2K_train_loader = DataLoader(
    DIV2K_Dataset(transforms_=transform, mode="train"),
    batch_size=16,
    shuffle=True,
    pin_memory=True,
    num_workers=0,
    drop_last=True
)

DIV2K_val_loader = DataLoader(
    DIV2K_Dataset(transforms_=transform, mode="val"),
    batch_size=16,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=True
)

2. 使用 cv2 进行图像读取，并使用 albumentations 进行数据增强

In [36]:
class DIV2K_Dataset(Dataset):
    def __init__(self, transforms_=None, mode='train'):
        self.transform = transforms_
        self.mode = mode
        if mode == 'train':
            self.files = natsorted(
                sorted(glob.glob(DIV2K_path+"/DIV2K_train_HR"+"/*."+"png")))
        else:
            self.files = natsorted(
                sorted(glob.glob(DIV2K_path+"/DIV2K_valid_HR"+"/*."+"png")))

    def __getitem__(self, index):
        img = cv2.imread(self.files[index])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # 转换为 RGB
        trans_img = self.transform(img)
        item= trans_img['image']
        return item

    def __len__(self):
        return len(self.files)


DIV2K_train_loader = DataLoader(
    DIV2K_Dataset(transforms_=transform_A, mode="train"),
    batch_size=16,
    shuffle=True,
    pin_memory=True,
    num_workers=0,
    drop_last=True
)

DIV2K_val_loader = DataLoader(
    DIV2K_Dataset(transforms_=transform_A, mode="val"),
    batch_size=16,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=True
)

**使用 albumentations 进行变换**

In [34]:
class DIV2K_Dataset(Dataset):
    def __init__(self, transforms_=None, mode='train'):
        self.transform = transforms_
        self.mode = mode
        if mode == 'train':
            self.files = natsorted(sorted(glob.glob(DIV2K_path+"/train"+"/*."+"png")))
        else:
            self.files = natsorted(
                sorted(glob.glob(DIV2K_path+"/valid"+"/*."+"png")))

    def __getitem__(self, index):
        # try:
        image = Image.open(self.files[index]).convert(
            'RGB')  # 使用 PIL 读取并转为 RGB
        image = np.array(image) # 转为 np 数组
        transformed_item = self.transform(image=image)
        # item=T.ToTensor(transformed_item['image'])
        return transformed_item['image']
        # except:
        #     return self.__getitem__(index+1)  # 这个不理解

    def __len__(self):
        return len(self.files)


DIV2K_train_loader = DataLoader(
    DIV2K_Dataset(transforms_=transform_A, mode="train"),
    batch_size=16,
    shuffle=True,
    pin_memory=True,
    num_workers=0,
    drop_last=True
)

DIV2K_val_loader = DataLoader(
    DIV2K_Dataset(transforms_=transform_A, mode="val"),
    batch_size=16,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
    drop_last=True
)

使用 tensorboard 来查看数据

In [35]:
writer = SummaryWriter("DIV2K_train_dataloader")
step =0
for data in DIV2K_train_loader:
    imgs=data
    # print(imgs.type())
    writer.add_images("train_data_droplast",imgs,step)
    step+=1
    print(step)

writer.close()
    

KeyboardInterrupt: 

# Model

## 库引入

In [1]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# Feed Forward(MLP)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# self attention


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer block


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])  # 存储 Transformer 的每一个块
        for _ in range(depth):  # 看要堆叠多少个 Transformer blcok
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),  # self attention
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))    # mlp
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class model1_encoder(nn.Module):
    def __init__(self, *, image_size, patch_size, dim, extract_dim, depth, heads, mlp_dim, channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        """
            image_size: 输入图像大小
            patch_size: patch size
            dim: token's length
            extract_dim: feature extractor output's dim
            depth: numbers of transformer block
            heads: numbers of multi-attention head
            mlp_dim: mlp's dim
            dim_head: one of the head's dim
        """
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'  # 防止划分不完全

        num_patches = (image_height//patch_height)*(image_width//patch_width)
        self.proportion = image_height//patch_height  # 记录 patch 大小和图像大小之间的比例，以便后面进行转换
        self.patch_size = patch_height
        patch_dim = extract_dim*patch_height*patch_width   # 每个 patch 看成是一个 Token，其维度为 H'xW'xC
        self.to_patch_embedding = nn.Sequential(    # 输入进行线性嵌入，似乎有点不一样？
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),  # 转为 Token
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))  # 位置嵌入是可学习的？
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.feature_extractor = nn.Conv2d(in_channels=channels*2, out_channels=extract_dim, kernel_size=3, stride=1, padding=1)  # 使用 CNN 初步提取特征

        self.to_latent = nn.Identity()    # 这是啥

    def forward(self, img):
        extract_feature = self.feature_extractor(img)
        x = self.to_patch_embedding(extract_feature)
        b, n, _ = x.shape
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        x = self.transformer(x)  # (B,H'xW',dim)

        x = rearrange(x, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=self.proportion, p1=self.patch_size, p2=self.patch_size)  # 转化为图像表示，(B,C,H,W)

        return x

class model1_decoder(nn.Module):
    def __init__(self, *, image_size, patch_size, dim, depth=1, heads, mlp_dim, channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        """
            image_size: 输入图像大小
            patch_size: patch size
            dim: token's length
            extract_dim: feature extractor output's dim
            depth: numbers of transformer block
            heads: numbers of multi-attention head
            mlp_dim: mlp's dim
            dim_head: one of the head's dim
        """
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'  # 防止划分不完全

        num_patches = (image_height//patch_height)*(image_width//patch_width)
        self.proportion = image_height//patch_height  # 记录 patch 大小和图像大小之间的比例，以便后面进行转换
        self.patch_size = patch_height
        patch_dim = channels*patch_height*patch_width   # 每个 patch 看成是一个 Token，其维度为 H'xW'xC
        self.to_patch_embedding = nn.Sequential(    # 输入进行线性嵌入，似乎有点不一样？
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),  # 转为 Token
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))  # 位置嵌入是可学习的？
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.to_latent = nn.Identity()    # 这是啥
        
    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        x += self.pos_embedding[:, :n]
        x = self.dropout(x)

        x = self.transformer(x)  # (B,H'xW',dim)

        x = rearrange(x, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=self.proportion, p1=self.patch_size, p2=self.patch_size)  # 转化为图像表示，(B,C,H,W)

        return x
    
test_model_encoder=model1_encoder(image_size=128,patch_size=16,dim=768,extract_dim=128,depth=6,heads=8,mlp_dim=1024,)   # dim 是 token 维度，dim=(patch_height x patch_width x channel)
test_model_decoder=model1_decoder(image_size=128,patch_size=16,dim=768,depth=1,heads=8,mlp_dim=1024)
input_data = torch.rand((1,6,128,128))
output=test_model_encoder(input_data)
output=test_model_decoder(output)
print(output.shape)

torch.Size([1, 3, 128, 128])


### v0.2 模块化测试
1. 双线模块

In [19]:
class stg_block(nn.Module):
    def __init__(self, *, image_size, patch_size, dim, heads, mlp_dim, channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        """
            image_size: 输入图像大小
            patch_size: patch size
            dim: token's length
            heads: numbers of multi-attention head
            mlp_dim: mlp's dim
            dim_head: one of the head's dim
            channels:每张图片的通道数！！！
        """
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        num_patches = (image_height//patch_height)*(image_width//patch_width)
        self.proportion = image_height//patch_height  # 记录 patch 大小和图像大小之间的比例，以便后面进行转换
        self.patch_size = patch_height

        self.transformer_block = nn.Sequential(
            PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
            PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
        )
        self.cnn_blcok = nn.Sequential(
            nn.Conv2d(in_channels=2*channels, out_channels=4*channels, kernel_size=3, stride=2, padding=1),  # 特征通道数翻倍，分辨率减半
            nn.BatchNorm2d(num_features=4*channels),
            nn.ConvTranspose2d(in_channels=4*channels, out_channels=2*channels, kernel_size=3, stride=2, padding=1,output_padding=1)   # 恢复分辨率
        )

    def forward(self, img):
        input_img = img   # (B,2C,H,W)
        img_token = rearrange(input_img, 'b (n c) (h p1) (w p2) -> b (n h w) (p1 p2 c)',n=2,p1=self.patch_size,p2=self.patch_size)  # 转为 Token（B,2N,L)
        transformer_output=self.transformer_block(img_token)
        transformer_output=rearrange(transformer_output,'b (n w) l -> b n w l',n=2)
        transformer_output=rearrange(transformer_output,'b n (h w) (p1 p2 c) -> b (n c) (h p1) (w p2)', h=self.proportion,p1=self.patch_size,p2=self.patch_size)    # 转为图像表示

        cnn_input = img
        cnn_output = self.cnn_blcok(cnn_input)

        res_output=cnn_output+transformer_output+img
        return res_output

**实例化测试**

In [27]:
stg_b=stg_block(image_size=128,patch_size=16,dim=2048,heads=8,mlp_dim=1024,channels=8,dim_head=256)
input_test=torch.randn((1,16,128,128))
test=input_test

img_divide = rearrange(input_test, 'b (n c) (h p1) (w p2) -> b (n h w) (p1 p2 c)',n=2,p1=16,p2=16)  # 转为 Token（B,2N,L)
img_divide=rearrange(img_divide,'b (n h w) (p1 p2 c) -> b (n c) (h p1) (w p2)', n=2,h=8,p1=16,p2=16)    # 转为图像表示
print(img_divide == test)


tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ...

**实例化测试**


In [None]:
test_model_encoder=model1_encoder(image_size=128,patch_size=16,dim=768,extract_dim=128,depth=6,heads=8,mlp_dim=1024,)   # dim 是 token 维度，dim=(patch_height x patch_width x channel)
test_model_decoder=model1_decoder(image_size=128,patch_size=16,dim=768,depth=1,heads=8,mlp_dim=1024)
input_data = torch.rand((1,6,128,128))
output=test_model_encoder(input_data)
output=test_model_decoder(output)
print(output.shape)

torch.Size([1, 3, 128, 128])


# Train

In [5]:
import datasets
import torchvision

使用单张图片让模型过拟合以测试性能

In [3]:
import torch
import torch.nn as nn
import torch.optim
import math
import numpy as np
from model import *
from tensorboardX import SummaryWriter
import datasets
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as T

transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(512),
    T.ToTensor()
])

# 以类的方式定义参数
class Args:
    def __init__(self) -> None:
        self.batch_size = 1
        self.image_size = 256
        self.patch_size = 16
        self.lr = 1e-3
        self.epochs = 10
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args = Args()

# 设置随机种子
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# 损失函数
def steg_loss(img1,img2):
    loss_fn = torch.nn.MSELoss(reduce=True,size_average=False)
    loss = loss_fn(img1,img2)
    return loss.to(args.device)

def reconstruction_loss(img1,img2):
    loss_fn = torch.nn.MSELoss(reduce=True,size_average=False)
    loss = loss_fn(img1,img2)
    return loss.to(args.device)

# tensorboard
writer = SummaryWriter(log_dir='logs')

# 模型初始化
encoder = model1_encoder(image_size=512, patch_size=16, dim=768, extract_dim=128, depth=6, heads=8, mlp_dim=1024,)   # dim 是 token 维度，dim=(patch_height x patch_width x channel)
decoder = model1_decoder(image_size=512, patch_size=16, dim=768, depth=1, heads=8, mlp_dim=1024)
encoder.cuda()
decoder.cuda()

# 用于过拟合的两张图像
cover=Image.open("F:/dataset/test/cover/0802.png").convert('RGB')
secret=Image.open("F:/dataset/test/secret/0801.png").convert('RGB')
cover=transform(cover)
secret=transform(secret)
cover=torch.unsqueeze(cover,0)
secret=torch.unsqueeze(secret,0)

# 优化器
optim = torch.optim.AdamW([{'params': encoder.parameters()}, {'params': decoder.parameters()}], lr=args.lr)

使用单张图片进行过拟合测试

In [6]:
# 过拟合代码
for i_epooch in range(3000):
    cover = cover.to(args.device)
    secret = secret.to(args.device)
    input_img = torch.cat((cover, secret), 1)

    # encode
    encode_img = encoder(input_img)
    
    # decode
    decode_img = decoder(encode_img)

    if(i_epooch%100==0):
        torchvision.utils.save_image(encode_img,"encode_img_"+str(i_epooch)+".png",normalize=True)
        torchvision.utils.save_image(decode_img,"decode_img_"+str(i_epooch)+".png",normalize=True)


    h_loss=steg_loss(cover.cuda(),encode_img.cuda())
    r_loss=reconstruction_loss(secret.cuda(),decode_img.cuda())
    total_loss=h_loss+r_loss
    print("encode_loss:"+str(h_loss.item())+", recontruct_loss: "+str(r_loss.item()))

    total_loss.backward()
    optim.step()
    optim.zero_grad()



encode_loss:2088243.75, recontruct_loss: 1855798.0
encode_loss:2092841.0, recontruct_loss: 1811178.75
encode_loss:1528301.0, recontruct_loss: 1750252.25
encode_loss:1334093.5, recontruct_loss: 1693350.75
encode_loss:1148063.0, recontruct_loss: 1626662.75
encode_loss:1072494.125, recontruct_loss: 1539712.5
encode_loss:980572.9375, recontruct_loss: 1453182.75
encode_loss:929544.3125, recontruct_loss: 1384761.25
encode_loss:884019.5625, recontruct_loss: 1317560.75
encode_loss:858181.375, recontruct_loss: 1254858.25
encode_loss:944178.5, recontruct_loss: 1295881.75
encode_loss:875073.375, recontruct_loss: 1251150.75
encode_loss:839129.875, recontruct_loss: 1054337.25
encode_loss:819080.875, recontruct_loss: 952952.375
encode_loss:793683.1875, recontruct_loss: 879189.1875
encode_loss:744834.5625, recontruct_loss: 847020.8125
encode_loss:715429.25, recontruct_loss: 809703.625
encode_loss:677606.875, recontruct_loss: 783306.875
encode_loss:647172.75, recontruct_loss: 755442.75
encode_loss:615

KeyboardInterrupt: 

使用训练集进行训练测试


In [None]:
# 训练代码
for i_epooch in range(3000):
    for i_batch, (cover, secret) in enumerate(zip(datasets.DIV2K_train_cover_loader, datasets.DIV2K_train_secret_loader)) :
        cover = cover.to(args.device)
        secret = secret.to(args.device)
        input_img = torch.cat((cover, secret), 1)
        input_img=input_img.to(args.device)
        # print(input_img.shape)

        # encode
        encode_img = encoder(input_img)
        
        # decode
        decode_img = decoder(encode_img)

        h_loss=steg_loss(cover.cuda(),encode_img.cuda())
        r_loss=reconstruction_loss(secret.cuda(),decode_img.cuda())
        total_loss=h_loss+r_loss
        if(i_batch==0):
            print(total_loss)

        total_loss.backward()
        optim.step()
        optim.zero_grad()

# 测试指标

In [16]:
import numpy as np
from skimage.metrics import *

def calculate_ssim(img1,img2):
    img1=np.array(img1).astype(np.float64)*255
    img2=np.array(img2).astype(np.float64)*255
    img1=np.clip(img1,0,255)
    img2=np.clip(img2,0,255)
    ssim_score=structural_similarity(img1,img2,channel_axis=1)
    return ssim_score
    

def calculate_psnr(img1, img2):
    # 转为 float64 防止精度丢失
    test1=np.array(img1).astype(np.float64)
    test2=np.array(img2).astype(np.float64)
    return peak_signal_noise_ratio(test1,test2)



In [17]:
import torch
test1=torch.rand((16,3,16,16))
test2=torch.rand((16,3,16,16))

calculate_psnr(test1,test2)
# calculate_ssim(test1,test1)


7.781162383004529