### Code taken from https://github.com/albblgb/Deep-Steganalysis

### Importing necessary libraries and modules

In [17]:
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter
import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm
import logging
import numpy as np
import math
from torchvision.utils import save_image
from tqdm.contrib import tzip

import torch
import numpy as np
import torchvision
import os
import glob
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T


# Module-level default device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Dataloader and Dataset

In [18]:
# Dataset and DataLoader definitions
class dataset_(Dataset):
    """Dataset of paired cover/stego images.

    The file names of both directories are sorted and matched by position,
    so item ``i`` pairs the ``i``-th cover file with the ``i``-th stego file.
    The dataset length is the number of files in the cover directory.
    """

    def __init__(self, cover_dir, stego_dir, transform):
        """Remember both directories, the transform and the sorted listings.

        Args:
            cover_dir: directory holding the cover images.
            stego_dir: directory holding the stego images.
            transform: callable applied to each loaded image; a falsy value
                leaves the PIL images untouched.
        """
        self.cover_dir, self.stego_dir = cover_dir, stego_dir
        self.transforms = transform
        self.cover_filenames = sorted(os.listdir(cover_dir))
        self.stego_filenames = sorted(os.listdir(stego_dir))

    def __len__(self):
        """Return the number of cover files."""
        return len(self.cover_filenames)

    def __getitem__(self, index):
        """Return the ``index``-th pair as a dict.

        Returns:
            ``{"cover": image, "stego": image, "label": [0, 1]}`` where the
            labels are 0-dim ``torch.long`` tensors (0 = cover, 1 = stego).
        """
        # Resolve both paths first so an out-of-range index fails before any
        # file is opened.
        cover_path = os.path.join(self.cover_dir, self.cover_filenames[index])
        stego_path = os.path.join(self.stego_dir, self.stego_filenames[index])

        images = [Image.open(path).convert("RGB") for path in (cover_path, stego_path)]
        if self.transforms:
            images = [self.transforms(image) for image in images]
        cover_img, stego_img = images

        return {
            "cover": cover_img,
            "stego": stego_img,
            "label": [
                torch.tensor(0, dtype=torch.long),
                torch.tensor(1, dtype=torch.long),
            ],
        }

# Data transforms
# The only preprocessing step is conversion to a tensor; no resizing, cropping
# or augmentation is applied. The train, validation and test loaders below all
# share this pipeline.
transformation = T.Compose([
    T.ToTensor()
])

In [21]:
# Data loaders
def get_train_loader(data_dir, batchsize=4, subset_size=None):
    """Build the shuffled training loader over ``<data_dir>/cover`` and ``<data_dir>/stego``.

    Args:
        data_dir: directory containing the ``cover`` and ``stego`` sub-folders.
        batchsize: number of cover/stego pairs per batch.
        subset_size: when given and smaller than the dataset, train on a
            random sample (without replacement) of that many pairs; otherwise
            the full dataset is used.

    Returns:
        A ``DataLoader`` that shuffles, pins memory and drops the last
        incomplete batch.
    """
    # ``Subset`` is not among the names imported at the top of the file, so it
    # is imported here; without it every call with ``subset_size`` set raised
    # a NameError.
    from torch.utils.data import Subset

    dataset = dataset_(os.path.join(data_dir, 'cover'), os.path.join(data_dir, 'stego'), transformation)

    if subset_size is not None and subset_size < len(dataset):
        # Random subset of the requested size; plain Python ints keep the
        # indices usable with list-based lookups inside the dataset.
        indices = np.random.choice(len(dataset), subset_size, replace=False)
        dataset = Subset(dataset, indices.tolist())

    train_loader = DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=True,
        pin_memory=True,
        drop_last=True
    )
    return train_loader

def get_val_loader(data_dir, batchsize=4):
    """Build the validation loader over ``<data_dir>/cover`` and ``<data_dir>/stego``.

    Args:
        data_dir: directory containing the ``cover`` and ``stego`` sub-folders.
        batchsize: number of cover/stego pairs per batch.

    Returns:
        A shuffling ``DataLoader`` that keeps the last (possibly smaller)
        batch and does not pin memory.
    """
    cover_root = os.path.join(data_dir, 'cover')
    stego_root = os.path.join(data_dir, 'stego')
    paired_set = dataset_(cover_root, stego_root, transformation)

    loader_options = dict(
        batch_size=batchsize,
        shuffle=True,
        pin_memory=False,
        drop_last=False,
    )
    return DataLoader(paired_set, **loader_options)

def get_test_loader(data_dir, batch_size):
    """Build the test loader from an ``ImageFolder`` rooted at ``data_dir``.

    Args:
        data_dir: root directory handed to ``ImageFolder``.
        batch_size: number of images per batch.

    Returns:
        A non-shuffling ``DataLoader`` using two worker processes that keeps
        the last (possibly smaller) batch.
    """
    folder_set = ImageFolder(root=data_dir, transform=transformation)
    return DataLoader(
        folder_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        drop_last=False,
    )

### SRNet

In [3]:
class ConvBn(nn.Module):
    """3x3 convolution (stride 1, padding 1, no bias) followed by batch norm."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Create the convolution and its batch-norm layer.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
        """
        super().__init__()
        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=False)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, inp: Tensor) -> Tensor:
        """Apply the convolution, then batch normalisation.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: output of Conv2d -> BatchNorm2d.
        """
        convolved = self.conv(inp)
        return self.batch_norm(convolved)


class Type1(nn.Module):
    """SRNet type-1 layer: ``ConvBn`` followed by ReLU."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Create the layer.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
        """
        super().__init__()
        self.convbn = ConvBn(in_channels, out_channels)
        self.relu = nn.ReLU()

    def forward(self, inp: Tensor) -> Tensor:
        """Apply conv + batch norm, then ReLU.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: output of the type-1 layer.
        """
        normalised = self.convbn(inp)
        return self.relu(normalised)


class Type2(nn.Module):
    """SRNet type-2 layer: residual block ``inp + ConvBn(Type1(inp))``.

    The skip connection adds the input to the branch output, and the second
    ``ConvBn`` is declared with ``in_channels`` inputs while it consumes the
    ``out_channels`` maps produced by ``Type1``; the layer is therefore meant
    to be used with ``in_channels == out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Create the two stacked convolutional stages of the residual branch.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
        """
        super().__init__()
        self.type1 = Type1(in_channels, out_channels)
        self.convbn = ConvBn(in_channels, out_channels)

    def forward(self, inp: Tensor) -> Tensor:
        """Return the input plus the residual branch.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: output of the type-2 layer.
        """
        residual = self.type1(inp)
        residual = self.convbn(residual)
        return inp + residual


class Type3(nn.Module):
    """SRNet type-3 layer: down-sampling residual block.

    The shortcut is a strided 1x1 convolution with batch norm; the main branch
    is ``Type1 -> ConvBn -> 3x3 average pool (stride 2)``. Both outputs are
    summed.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Create the shortcut and main-branch layers.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
        """
        super().__init__()
        shortcut_options = dict(kernel_size=1, stride=2, padding=0, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **shortcut_options)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.type1 = Type1(in_channels, out_channels)
        self.convbn = ConvBn(out_channels, out_channels)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Return shortcut + pooled main branch.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: output of the type-3 layer.
        """
        shortcut = self.conv1(inp)
        shortcut = self.batch_norm(shortcut)

        main = self.type1(inp)
        main = self.convbn(main)
        main = self.pool(main)

        return shortcut + main


class Type4(nn.Module):
    """SRNet type-4 layer: ``Type1 -> ConvBn -> global average pooling``."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Create the convolutional stages and the global pooling layer.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
        """
        super().__init__()
        self.type1 = Type1(in_channels, out_channels)
        self.convbn = ConvBn(out_channels, out_channels)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Return the globally pooled features.

        Args:
            inp (Tensor): input tensor.

        Returns:
            Tensor: output of the type-4 layer (spatial size 1x1).
        """
        features = self.type1(inp)
        features = self.convbn(features)
        return self.gap(features)



class SRNet(nn.Module):
    """SRNet classifier for 3-channel images with two output classes.

    Layout: two type-1 layers, five type-2 residual layers, four
    down-sampling type-3 layers, one type-4 layer with global pooling and a
    final linear layer producing two logits.
    """

    def __init__(self) -> None:
        """Build all stages of the network."""
        super().__init__()
        self.type1s = nn.Sequential(Type1(3, 64), Type1(64, 16))
        self.type2s = nn.Sequential(*(Type2(16, 16) for _ in range(5)))

        type3_channels = ((16, 16), (16, 64), (64, 128), (128, 256))
        self.type3s = nn.Sequential(
            *(Type3(n_in, n_out) for n_in, n_out in type3_channels)
        )

        self.type4 = Type4(256, 512)
        self.dense = nn.Linear(512, 2)
        # Registered but not applied in ``forward``, which returns raw logits.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inp: Tensor) -> Tensor:
        """Return the class logits for a batch of images.

        Args:
            inp (Tensor): input image tensor of shape
                (Batch, stego_img_channel, stego_img_height, stego_img_width).

        Returns:
            Tensor: logits of shape (Batch, 2).
        """
        features = inp
        for stage in (self.type1s, self.type2s, self.type3s, self.type4):
            features = stage(features)

        # Collapse the 1x1 spatial map left by the global pooling.
        flat = features.view(features.size(0), -1)
        return self.dense(flat)

### XuNet

In [None]:
class ImageProcessing(nn.Module):
    """Convolves each of the three input channels with the 5x5 KV filter."""

    def __init__(self) -> None:
        """Create the KV kernel and register it as a parameter."""
        super().__init__()
        kv_rows = [
            [-1.0, 2.0, -2.0, 2.0, -1.0],
            [2.0, -6.0, 8.0, -6.0, 2.0],
            [-2.0, 8.0, -12.0, 8.0, -2.0],
            [2.0, -6.0, 8.0, -6.0, 2.0],
            [-1.0, 2.0, -2.0, 2.0, -1.0],
        ]
        kernel = torch.tensor(kv_rows).view(1, 1, 5, 5)
        self.kv_filter = nn.Parameter(kernel / 12.0)

    def forward(self, inp: Tensor) -> Tensor:
        """Return the input with every channel filtered independently.

        The same single-channel kernel is applied to channels 0, 1 and 2 with
        padding 2, so the spatial size is preserved; the three results are
        concatenated back along the channel axis.
        """
        filtered_channels = [
            F.conv2d(
                inp[:, channel, :, :].unsqueeze(dim=1),
                self.kv_filter,
                stride=1,
                padding=2,
            )
            for channel in range(3)
        ]
        return torch.cat(filtered_channels, dim=1)


class ConvBlock(nn.Module):
    """Building block of XuNet.

    Applies convolution -> optional absolute value -> batch norm ->
    activation -> 5x5 average pooling with stride 2.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        activation: str = "relu",
        abs: bool = False,
    ) -> None:
        """Create the block.

        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            kernel_size (int): convolution kernel size.
            activation (str): ``"tanh"`` selects Tanh; any other value
                selects ReLU.
            abs (bool): when true, take the absolute value of the convolution
                output before batch normalisation.
        """
        super().__init__()

        # Padding that preserves the spatial size is only provided for the
        # 5x5 case; every other kernel size is left unpadded.
        self.padding = 2 if kernel_size == 5 else 0
        self.activation = nn.Tanh() if activation == "tanh" else nn.ReLU()

        self.abs = abs
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=self.padding,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.pool = nn.AvgPool2d(kernel_size=5, stride=2, padding=2)

    def forward(self, inp: Tensor) -> Tensor:
        """Return conv -> (abs) -> batch norm -> activation -> pool."""
        out = self.conv(inp)
        if self.abs:
            out = torch.abs(out)
        return self.pool(self.activation(self.batch_norm(out)))


class XuNet(nn.Module):
    """XuNet classifier producing two logits per image."""

    # Captured when this class is created. A later definition in this file
    # reuses the global name ``ConvBlock`` for a class with a different
    # constructor; looking the name up lazily inside ``__init__`` would then
    # pick up that other class and fail on the ``activation``/``abs``
    # keywords.
    _conv_block_cls = ConvBlock

    def __init__(self) -> None:
        """Build the fixed KV pre-filter, five conv blocks and the classifier head."""
        super().__init__()
        block = self._conv_block_cls

        self.ImageProcessingLayer = ImageProcessing()
        self.layer1 = block(3, 8, kernel_size=5, activation="tanh", abs=True)
        self.layer2 = block(8, 16, kernel_size=5, activation="tanh")
        self.layer3 = block(16, 32, kernel_size=1)
        self.layer4 = block(32, 64, kernel_size=1)
        self.layer5 = block(64, 128, kernel_size=1)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.fully_connected = nn.Sequential(
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=2),
        )

    def forward(self, image: Tensor) -> Tensor:
        """Return the class logits of shape (Batch, 2) for ``image``."""
        # The KV pre-filter runs without gradient tracking, so its kernel
        # receives no gradient from the loss.
        with torch.no_grad():
            out = self.ImageProcessingLayer(image)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.fully_connected(out)
        return out

### YeNet

In [None]:
# Kernel array copied into ``SRM_conv2d.weight``; loaded once, when this cell
# runs, from a hard-coded Kaggle input path.
# NOTE(review): presumably the SRM filter bank; the array has to broadcast to
# (30, 3, 5, 5), the shape of ``SRM_conv2d.weight`` — confirm the file layout.
SRM_npy = np.load('/kaggle/input/models/SRM_Kernels.npy')

class SRM_conv2d(nn.Module):
    """5x5 convolution (3 -> 30 channels) whose weights start from ``SRM_npy``.

    Weight and bias are ordinary trainable parameters; only their initial
    values are fixed (kernels from ``SRM_npy``, zero bias).
    """

    def __init__(self, stride=1, padding=0):
        """Create the layer.

        Args:
            stride: int or pair of ints; an int is used for both dimensions.
            padding: int or pair of ints; an int is used for both dimensions.
        """
        super(SRM_conv2d, self).__init__()
        self.in_channels = 3
        self.out_channels = 30
        self.kernel_size = (5, 5)
        self.stride = (stride, stride) if isinstance(stride, int) else stride
        self.padding = (padding, padding) if isinstance(padding, int) else padding
        self.dilation = (1, 1)
        self.transpose = False
        self.output_padding = (0,)
        self.groups = 1
        # Allocated uninitialised; ``reset_parameters`` fills both tensors.
        self.weight = Parameter(torch.empty(30, self.in_channels, 5, 5),
                                requires_grad=True)
        self.bias = Parameter(torch.empty(30),
                              requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Load the kernels from ``SRM_npy`` into the weight and zero the bias.

        ``copy_`` broadcasts the source to the weight shape and converts
        dtype/device as needed, so this also works after the module has been
        moved off the CPU (writing through ``.numpy()`` only works for CPU
        tensors).
        """
        with torch.no_grad():
            self.weight.copy_(torch.as_tensor(SRM_npy))
            self.bias.zero_()

    def forward(self, input):
        """Return ``input`` convolved with the current weight and bias."""
        return F.conv2d(input, self.weight, self.bias,
                        self.stride, self.padding, self.dilation,
                        self.groups)

class ConvBlock(nn.Module):
    """YeNet building block: convolution -> ReLU -> optional batch norm."""

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, with_bn=False):
        """Create the block.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size (no padding is applied).
            stride: convolution stride.
            with_bn: when true, batch normalisation follows the ReLU.
        """
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride)
        self.relu = nn.ReLU()
        self.with_bn = with_bn
        # ``nn.Identity`` instead of a lambda: it is a no-op without
        # parameters (state dict unchanged) and, unlike a lambda attribute,
        # keeps the module picklable.
        self.norm = nn.BatchNorm2d(out_channels) if with_bn else nn.Identity()
        self.reset_parameters()

    def forward(self, x):
        """Return ``norm(relu(conv(x)))``."""
        return self.norm(self.relu(self.conv(x)))

    def reset_parameters(self):
        """Xavier-initialise the conv weight, set its bias to 0.2, reset batch norm."""
        nn.init.xavier_uniform_(self.conv.weight)
        self.conv.bias.data.fill_(0.2)
        if self.with_bn:
            self.norm.reset_parameters()

class YeNet(nn.Module):
    """YeNet classifier producing two logits per image.

    The final linear layer expects 144 flattened features (16 channels x 3 x
    3), which is what this layer stack yields for e.g. 256x256 inputs; other
    input sizes change the flattened size and will not match ``ip1``.
    """

    def __init__(self, with_bn=False, threshold=3):
        """Build the network.

        Args:
            with_bn: when true, batch normalisation is used after the
                truncation layer and inside every conv block.
            threshold: clipping bound of the truncation (Hardtanh) layer.
        """
        super(YeNet, self).__init__()
        self.with_bn = with_bn
        self.preprocessing = SRM_conv2d(1, 0)
        self.TLU = nn.Hardtanh(-threshold, threshold, True)
        # ``nn.Identity`` instead of a lambda: a parameter-free no-op that
        # keeps the model picklable.
        self.norm1 = nn.BatchNorm2d(30) if with_bn else nn.Identity()
        self.block2 = ConvBlock(30, 30, 3, with_bn=self.with_bn)
        self.block3 = ConvBlock(30, 30, 3, with_bn=self.with_bn)
        self.block4 = ConvBlock(30, 30, 3, with_bn=self.with_bn)
        self.pool1 = nn.AvgPool2d(2, 2)
        self.block5 = ConvBlock(30, 32, 5, with_bn=self.with_bn)
        self.pool2 = nn.AvgPool2d(3, 2)
        self.block6 = ConvBlock(32, 32, 5, with_bn=self.with_bn)
        self.pool3 = nn.AvgPool2d(3, 2)
        self.block7 = ConvBlock(32, 32, 5, with_bn=self.with_bn)
        self.pool4 = nn.AvgPool2d(3, 2)
        self.block8 = ConvBlock(32, 16, 3, with_bn=self.with_bn)
        self.block9 = ConvBlock(16, 16, 3, 3, with_bn=self.with_bn)
        # Flattened size of the block9 output: 16 channels * 3 * 3.
        self.num_of_neurons = 144

        self.ip1 = nn.Linear(self.num_of_neurons, 2)
        self.reset_parameters()

    def forward(self, x):
        """Return the class logits of shape (Batch, 2) for ``x``."""
        x = x.float()
        x = self.preprocessing(x)
        x = self.TLU(x)
        x = self.norm1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.pool1(x)
        x = self.block5(x)
        x = self.pool2(x)
        x = self.block6(x)
        x = self.pool3(x)
        x = self.block7(x)
        x = self.pool4(x)
        x = self.block8(x)
        x = self.block9(x)
        x = x.view(x.size(0), -1)
        x = self.ip1(x)
        return x

    def reset_parameters(self):
        """Re-initialise every sub-module.

        SRM, batch-norm and conv-block modules use their own
        ``reset_parameters``; linear layers get N(0, 0.01) weights and a zero
        bias.
        """
        for mod in self.modules():
            if isinstance(mod, (SRM_conv2d, nn.BatchNorm2d, ConvBlock)):
                mod.reset_parameters()
            elif isinstance(mod, nn.Linear):
                nn.init.normal_(mod.weight, 0., 0.01)
                mod.bias.data.zero_()

def accuracy(outputs, labels):
    """Return the fraction of rows of ``outputs`` whose arg-max equals ``labels``.

    Args:
        outputs: score tensor with one row per sample and one column per class.
        labels: tensor of target class indices.

    Returns:
        A 0-dim float tensor in the range [0, 1].
    """
    predicted = torch.max(outputs, 1)[1]
    hits = labels == predicted.squeeze()
    return hits.float().mean()

### Train function

In [None]:
def train_model(model, train_loader, val_loader, epochs, lr=2e-4, weight_decay=1e-5, device='cuda'):
    """Train ``model`` and return the weights of its best validation epoch.

    Uses cross-entropy loss, Adam and a step schedule that halves the
    learning rate every 30 epochs. After each epoch the model is evaluated on
    ``val_loader`` and, when the validation accuracy improves, a snapshot of
    the weights is stored.

    Args:
        model: network mapping an image batch to (Batch, 2) logits.
        train_loader: iterable of batches shaped like ``dataset_`` samples
            (``"cover"``, ``"stego"`` and a two-element ``"label"`` list).
        val_loader: same batch format, used for model selection.
        epochs: number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        device: device the batches are moved to.

    Returns:
        Tuple ``(best_model_state, best_val_acc)``: a state dict that is
        independent of the live model, and the matching validation accuracy
        in percent. If no epoch runs, the current weights are returned with
        an accuracy of 0.
    """
    # Local import: ``copy`` is not among the modules imported at the top of
    # the file.
    import copy

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

    best_val_acc = 0
    best_model_state = None

    for epoch in range(epochs):
        # Training
        train_loss_avg, train_acc_avg = _run_epoch(
            model, train_loader, loss_fn, device,
            desc=f"Epoch {epoch+1}/{epochs} - Training",
            optimizer=optimizer,
        )
        scheduler.step()

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss_avg:.4f}, Train Acc: {train_acc_avg:.4f}")

        # Validation
        val_loss_avg, val_acc_avg = _run_epoch(
            model, val_loader, loss_fn, device,
            desc=f"Epoch {epoch+1}/{epochs} - Validating",
        )

        print(f"Epoch {epoch+1}/{epochs} - Val Loss: {val_loss_avg:.4f}, Val Acc: {val_acc_avg:.4f}")

        # Save best model. The first epoch always counts, so a state is
        # available even when the accuracy never rises above zero.
        if best_model_state is None or val_acc_avg > best_val_acc:
            best_val_acc = val_acc_avg
            # ``state_dict()`` returns references to the live tensors and
            # ``dict.copy()`` is shallow, so later optimizer steps would
            # overwrite the snapshot; a deep copy freezes it.
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"New best model with val acc: {best_val_acc:.4f}")

    if best_model_state is None:
        # No epoch ran: hand back the current weights so callers can still
        # pass the result to ``load_state_dict``.
        best_model_state = copy.deepcopy(model.state_dict())

    return best_model_state, best_val_acc


def _run_epoch(model, loader, loss_fn, device, desc, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns:
        Tuple ``(mean loss, accuracy in percent)``, both weighted by the
        number of samples so a smaller final batch does not skew the result.
        Both are NaN when the loader yields no batch.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            # Covers first, then stegos, matching the order of the labels.
            inputs = torch.cat((batch["cover"], batch["stego"]), 0)
            labels = torch.cat((batch["label"][0], batch["label"][1]), 0)
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, 100.0 * correct / seen


### Test function

In [None]:
def test_model(model, test_loader, device='cuda'):
    """Evaluate ``model`` on ``test_loader`` and return the accuracy in percent.

    Accuracy is computed as correct predictions over all samples. Averaging
    per-batch accuracies instead would over-weight a smaller final batch,
    which the test loader keeps.

    Args:
        model: network mapping an image batch to per-class scores.
        test_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent as a float; NaN when the loader yields no sample.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Testing"):
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.long)

            prediction = model(inputs).argmax(dim=1)
            correct += (prediction == labels).sum().item()
            total += labels.size(0)

    test_acc_avg = 100.0 * correct / total if total else float("nan")
    print(f'Test Accuracy: {test_acc_avg:.4f}%')

    return test_acc_avg

### Incremental training function

In [23]:
# Main incremental training function
def incremental_training(data_dir, model_type='SRNet', start_size=10, max_size=100, step_size=10,
                         epochs=100, batch_size=4, val_batch_size=4, test_batch_size=16,
                         device='cuda'):
    """
    Train models incrementally with increasing dataset sizes and evaluate on test set

    A fresh model is created for every dataset size, trained with
    ``train_model``, restored to its best validation state and scored with
    ``test_model``.

    Args:
        data_dir: Directory containing train, val, and test data
        model_type: String specifying which model to use ('SRNet', 'XuNet', or 'YeNet')
        start_size: Initial number of cover-stego pairs to use
        max_size: Maximum number of cover-stego pairs to use
        step_size: Increment size for dataset expansion
        epochs: Number of training epochs per dataset size
        batch_size: Batch size for training
        val_batch_size: Batch size for validation
        test_batch_size: Batch size for testing
        device: Device to use for training ('cuda' or 'cpu')

    Returns:
        Dictionary of test accuracies for each dataset size

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    # Lambdas keep the class lookup lazy: only the selected model's name has
    # to be defined when it is built.
    model_builders = {
        'SRNet': lambda: SRNet(),
        'XuNet': lambda: XuNet(),
        'YeNet': lambda: YeNet(),
    }
    # Fail before any data loader is built rather than inside the loop.
    if model_type not in model_builders:
        raise ValueError(f"Unknown model type: {model_type}. Choose from 'SRNet', 'XuNet', or 'YeNet'")

    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')
    test_dir = os.path.join(data_dir, 'test')

    # Prepare validation and test loaders
    val_loader = get_val_loader(val_dir, val_batch_size)
    test_loader = get_test_loader(test_dir, test_batch_size)

    # Dictionary to store results
    test_accuracies = {}

    # The range end ``max_size + step_size`` makes the last step reach
    # ``max_size``, but it overshoots when the span is not a multiple of
    # ``step_size``; clamp so no run uses more than ``max_size`` pairs.
    subset_sizes = [
        min(size, max_size)
        for size in range(start_size, max_size + step_size, step_size)
    ]

    # Loop through different dataset sizes
    for num_pairs in subset_sizes:
        print(f"\n{'='*50}")
        print(f"Training {model_type} with {num_pairs} cover-stego pairs")
        print(f"{'='*50}")

        # Get training loader with specified subset size
        train_loader = get_train_loader(train_dir, batch_size, subset_size=num_pairs)

        # Initialize the specified model for each run
        model = model_builders[model_type]().to(device)

        # Train the model
        best_model_state, best_val_acc = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=epochs,
            device=device
        )

        # Load best model for testing
        model.load_state_dict(best_model_state)

        # Test the model
        test_acc = test_model(model, test_loader, device)

        # Store results
        test_accuracies[num_pairs] = test_acc

        # Print results so far
        print(f"\nTest Accuracies for {model_type} so far:")
        for size, acc in test_accuracies.items():
            print(f"{size} pairs: {acc:.4f}%")

    return test_accuracies

In [24]:
def run_incremental_training(data_dir, model_type='SRNet', start_size=10, max_size=100, step_size=10, epochs=100):
    """Pick the device, run ``incremental_training`` and print a summary.

    Args:
        data_dir: directory containing the train, val and test data.
        model_type: 'SRNet', 'XuNet' or 'YeNet'.
        start_size: initial number of cover-stego pairs.
        max_size: maximum number of cover-stego pairs.
        step_size: increment between consecutive dataset sizes.
        epochs: training epochs per dataset size.

    Returns:
        Dictionary mapping dataset size to test accuracy.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    print(f"Running incremental training for {model_type} model")

    settings = dict(
        data_dir=data_dir,
        model_type=model_type,
        start_size=start_size,
        max_size=max_size,
        step_size=step_size,
        epochs=epochs,
        device=device,
    )
    accuracies = incremental_training(**settings)

    print(f"\n=== Final Results for {model_type} ===")
    for size in sorted(accuracies):
        print(f"{size} pairs: {accuracies[size]:.4f}%")

    return accuracies

In [None]:
# Example of how to run for all three models
def run_all_models(data_dir, start_size=10, max_size=100, step_size=10, epochs=100):
    """Run the incremental-training experiment for SRNet, XuNet and YeNet in turn.

    Args:
        data_dir: directory containing the train, val and test data.
        start_size: initial number of cover-stego pairs.
        max_size: maximum number of cover-stego pairs.
        step_size: increment between consecutive dataset sizes.
        epochs: training epochs per dataset size.

    Returns:
        Dictionary mapping each model name to its size -> accuracy dictionary.
    """
    all_results = {}
    banner = '#' * 60

    for model_type in ('SRNet', 'XuNet', 'YeNet'):
        print(f"\n\n{banner}")
        print(f"# Starting evaluation for {model_type}")
        print(f"{banner}\n")

        all_results[model_type] = run_incremental_training(
            data_dir=data_dir,
            model_type=model_type,
            start_size=start_size,
            max_size=max_size,
            step_size=step_size,
            epochs=epochs,
        )

    return all_results

In [None]:
# Run the full experiment for all three models.
# NOTE(review): ``run_all_models`` has a required ``data_dir`` parameter that
# is not supplied here, so this call raises TypeError as written; pass the
# dataset root (the directory holding ``train``, ``val`` and ``test``).
final_results = run_all_models()

In [None]:
# Show the nested {model name: {dataset size: test accuracy}} results.
print(final_results)