In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/bossbase/boss_256_0.4/cover/4353.png
/kaggle/input/bossbase/boss_256_0.4/cover/7968.png
/kaggle/input/bossbase/boss_256_0.4/cover/6490.png
/kaggle/input/bossbase/boss_256_0.4/cover/5511.png
/kaggle/input/bossbase/boss_256_0.4/cover/6262.png
/kaggle/input/bossbase/boss_256_0.4/cover/2664.png
/kaggle/input/bossbase/boss_256_0.4/cover/8419.png
/kaggle/input/bossbase/boss_256_0.4/cover/2539.png
/kaggle/input/bossbase/boss_256_0.4/cover/5703.png
/kaggle/input/bossbase/boss_256_0.4/cover/1231.png
/kaggle/input/bossbase/boss_256_0.4/cover/1017.png
/kaggle/input/bossbase/boss_256_0.4/cover/4803.png
/kaggle/input/bossbase/boss_256_0.4/cover/7197.png
/kaggle/input/bossbase/boss_256_0.4/cover/2437.png
/kaggle/input/bossbase/boss_256_0.4/cover/7530.png
/kaggle/input/bossbase/boss_256_0.4/cover/3217.png
/kaggle/input/bossbase/boss_256_0.4/cover/5695.png
/kaggle/input/bossbase/boss_256_0.4/cover/8565.png
/kaggle/input/bossbase/boss_256_0.4/cover/6441.png
/kaggle/input/bossbase/boss_256

In [5]:
# Install required packages
!pip install torch torchvision torchsummary pillow numpy scipy matplotlib scikit-learn tqdm

# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix
import time
from tqdm import tqdm
import zipfile
import warnings
import shutil
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Create directories
!mkdir -p /kaggle/working/data/cover
!mkdir -p /kaggle/working/data/stego
!mkdir -p /kaggle/working/models
!mkdir -p /kaggle/working/results

# ==================== DATASET ORGANIZATION ====================
def organize_bossbase_numerically(input_path, cover_output, stego_output):
    """
    Organize BOSSBase dataset using numerical ordering with proper naming
    """
    # Find all image files
    all_image_files = []
    for root, dirs, files in os.walk(input_path):
        for file in files:
            if file.endswith(('.pgm', '.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
                all_image_files.append(os.path.join(root, file))
    
    print(f"Found {len(all_image_files)} image files")
    
    # Sort files numerically (not alphabetically)
    def numerical_sort_key(filepath):
        filename = os.path.basename(filepath)
        # Extract numbers from filename
        numbers = ''.join(filter(str.isdigit, filename))
        return int(numbers) if numbers else 0
    
    all_image_files.sort(key=numerical_sort_key)
    
    # Split files equally - first half as cover, second half as stego
    half_point = len(all_image_files) // 2
    cover_files = all_image_files[:half_point]
    stego_files = all_image_files[half_point:2*half_point]  # Ensure equal length
    
    print(f"Split into {len(cover_files)} cover and {len(stego_files)} stego images")
    
    # Clear existing directories
    shutil.rmtree(cover_output, ignore_errors=True)
    shutil.rmtree(stego_output, ignore_errors=True)
    os.makedirs(cover_output, exist_ok=True)
    os.makedirs(stego_output, exist_ok=True)
    
    # Copy cover images with consistent naming
    for i, img_path in enumerate(cover_files):
        filename = f"{i+1:05d}.png"  # Format as 00001.png, 00002.png, etc.
        destination = os.path.join(cover_output, filename)
        # Copy and convert if needed
        img = Image.open(img_path).convert('RGB')
        img.save(destination)
    
    # Copy stego images with consistent naming (same as cover)
    for i, img_path in enumerate(stego_files):
        filename = f"{i+1:05d}.png"  # Same naming convention as cover
        destination = os.path.join(stego_output, filename)
        img = Image.open(img_path).convert('RGB')
        img.save(destination)
    
    print(f"Organized {len(cover_files)} cover images and {len(stego_files)} stego images")
    return len(cover_files), len(stego_files)

# Organize the dataset
dataset_path = '/kaggle/input/bossbase'
cover_dir = '/kaggle/working/data/cover'
stego_dir = '/kaggle/working/data/stego'

cover_count, stego_count = organize_bossbase_numerically(dataset_path, cover_dir, stego_dir)

# Check the results
print("\nAfter organization:")
print(f"Cover images: {len(os.listdir(cover_dir))}")
print(f"Stego images: {len(os.listdir(stego_dir))}")

print("\nSample cover files:")
cover_files = sorted(os.listdir(cover_dir))
for f in cover_files[:5]:
    print(f"  {f}")

print("\nSample stego files:")
stego_files = sorted(os.listdir(stego_dir))
for f in stego_files[:5]:
    print(f"  {f}")

# Create simple 1:1 mapping (same filenames)
stego_to_cover_map = {}
for filename in cover_files:
    stego_to_cover_map[filename] = filename  # Same filename for both

print(f"\nCreated mapping for {len(stego_to_cover_map)} image pairs")
print("Sample mappings:")
for i, (cover, stego) in enumerate(list(stego_to_cover_map.items())[:5]):
    print(f"  {cover} → {stego}")

# Verify mapping is correct
print("\nVerifying mapping...")
all_good = True
for cover_file in cover_files[:10]:  # Check first 10
    stego_file = stego_to_cover_map[cover_file]
    if cover_file != stego_file:
        print(f"ERROR: {cover_file} → {stego_file}")
        all_good = False
    else:
        print(f"✓ {cover_file} → {stego_file}")

if all_good:
    print("✓ All mappings are correct!")
else:
    print("✗ Some mappings are incorrect!")

# ==================== SIMPLIFIED MODEL FOR TESTING ====================
print("\n" + "="*60)
print("CREATING SIMPLIFIED MODEL FOR TESTING")
print("="*60)

# Simplified Residual Feature Extractor
class SimpleResidualFeatureExtractor(nn.Module):
    def _init_(self):
        super(SimpleResidualFeatureExtractor, self)._init_()
        # Simple horizontal and vertical filters only
        self.horizontal_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.vertical_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        
        # Initialize with simple edge detection filters
        self.horizontal_filter.weight.data = torch.FloatTensor([[[[0, 0, 0], [-1, 0, 1], [0, 0, 0]]]])
        self.vertical_filter.weight.data = torch.FloatTensor([[[[0, -1, 0], [0, 0, 0], [0, 1, 0]]]])
        
    def forward(self, x):
        if x.shape[1] == 3:
            x_gray = 0.299 * x[:, 0:1, :, :] + 0.587 * x[:, 1:2, :, :] + 0.114 * x[:, 2:3, :, :]
        else:
            x_gray = x
        
        horizontal = self.horizontal_filter(x_gray)
        vertical = self.vertical_filter(x_gray)
        
        return torch.cat([horizontal, vertical], dim=1)

# Simplified Feature Fusion
class SimpleFeatureFusion(nn.Module):
    def _init_(self):
        super(SimpleFeatureFusion, self)._init_()
        self.fusion = nn.Sequential(
            nn.Conv2d(5, 32, kernel_size=3, padding=1),  # 2 residual + 3 RGB
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, residual_features, rgb_features):
        fused = torch.cat([residual_features, rgb_features], dim=1)
        return self.fusion(fused)

# Simplified Classifier
class SimpleStegoClassifier(nn.Module):
    def _init_(self):
        super(SimpleStegoClassifier, self)._init_()
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.classifier(x)

# Simplified Generator
class SimpleCoverGenerator(nn.Module):
    def _init_(self):
        super(SimpleCoverGenerator, self)._init_()
        self.generator = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x):
        return self.generator(x)

# Simplified Dual Network
class SimpleDualNetwork(nn.Module):
    def __init__(self):
        super(SimpleDualNetwork, self).__init__()
        self.residual_extractor = SimpleResidualFeatureExtractor()
        self.feature_fusion = SimpleFeatureFusion()
        self.classifier = SimpleStegoClassifier()
        self.generator = SimpleCoverGenerator()
        
    def forward(self, x):
        residual_features = self.residual_extractor(x)
        fused_features = self.feature_fusion(residual_features, x)
        classification = self.classifier(fused_features)
        reconstruction = self.generator(fused_features)
        return classification, reconstruction, None


# Simplified Dataset
class SimpleStegoDataset(Dataset):
    def __init__(self, cover_dir, stego_dir, transform=None, max_samples=1000):
        self.cover_dir = cover_dir
        self.stego_dir = stego_dir
        self.transform = transform

        # Get limited number of samples for testing
        self.cover_images = sorted([f for f in os.listdir(cover_dir) 
                                   if f.endswith('.png')])[:max_samples]
        self.stego_images = sorted([f for f in os.listdir(stego_dir) 
                                   if f.endswith('.png')])[:max_samples]

        print(f"Using {len(self.cover_images)} samples for testing")

    def __len__(self):
        return len(self.cover_images)

    def __getitem__(self, idx):
        # Get cover image
        cover_name = self.cover_images[idx]
        cover_path = os.path.join(self.cover_dir, cover_name)

        # Get corresponding stego image (same filename)
        stego_name = self.stego_images[idx]
        stego_path = os.path.join(self.stego_dir, stego_name)

        # Load images
        cover_img = Image.open(cover_path).convert('RGB')
        stego_img = Image.open(stego_path).convert('RGB')

        if self.transform:
            cover_img = self.transform(cover_img)
            stego_img = self.transform(stego_img)

        return stego_img, cover_img, 1  # 1 indicates stego image

# Simplified Loss
class SimpleStegoLoss(nn.Module):
    def _init_(self):
        super(SimpleStegoLoss, self)._init_()
        self.classification_loss = nn.BCELoss()
        self.reconstruction_loss = nn.MSELoss()
        
    def forward(self, classification_pred, reconstruction, classification_target, reconstruction_target):
        cls_loss = self.classification_loss(classification_pred, classification_target)
        rec_loss = self.reconstruction_loss(reconstruction, reconstruction_target)
        total_loss = cls_loss + 0.5 * rec_loss  # Simple weighted sum
        return total_loss, cls_loss, rec_loss, 0

# Simplified Trainer
class SimpleStegoTrainer:
    def __init__(self, model, device, learning_rate=0.001):   # ✅ double underscores
        self.model = model.to(device)
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.loss_fn = SimpleStegoLoss()
    
    def train_epoch(self, dataloader):
        self.model.train()
        total_loss, total_cls_loss, total_rec_loss = 0, 0, 0
        
        for batch_idx, (stego_imgs, cover_imgs, labels) in enumerate(dataloader):
            stego_imgs, cover_imgs, labels = (
                stego_imgs.to(self.device),
                cover_imgs.to(self.device),
                labels.float().to(self.device).unsqueeze(1)
            )
            
            # Forward pass
            classification_pred, reconstructed_cover, _ = self.model(stego_imgs)
            
            # Calculate loss
            loss, cls_loss, rec_loss, _ = self.loss_fn(
                classification_pred, reconstructed_cover, labels, cover_imgs
            )
            
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            # Accumulate losses
            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_rec_loss += rec_loss.item()
        
        return (total_loss / len(dataloader), 
                total_cls_loss / len(dataloader), 
                total_rec_loss / len(dataloader), 
                0)


# ==================== MAIN EXECUTION ====================
print("\n" + "="*60)
print("STARTING TRAINING WITH SIMPLIFIED MODEL")
print("="*60)

def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Simple transforms for testing
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # Smaller for testing
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    # Create small dataset for testing
    print("Creating small test dataset...")
    dataset = SimpleStegoDataset(cover_dir, stego_dir, transform, max_samples=100)
    
    # Split dataset
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    # Create data loaders
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    
    # Initialize simple model
    print("Initializing simplified model...")
    model = SimpleDualNetwork().to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    
    # Initialize trainer
    trainer = SimpleStegoTrainer(model, device, learning_rate=0.001)
    
    # Training loop - JUST 3 EPOCHS FOR TESTING
    num_epochs = 3
    print(f"\nStarting training for {num_epochs} epochs...")
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)
        
        # Train for one epoch
        avg_loss, avg_cls_loss, avg_rec_loss, _ = trainer.train_epoch(train_dataloader)
        
        print(f"Training - Loss: {avg_loss:.4f}, CLS: {avg_cls_loss:.4f}, REC: {avg_rec_loss:.4f}")
        
        # Simple validation
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for stego_imgs, cover_imgs, labels in val_dataloader:
                stego_imgs, labels = stego_imgs.to(device), labels.to(device)
                classification_pred, _, _ = model(stego_imgs)
                pred_labels = (classification_pred > 0.5).float()
                correct += (pred_labels == labels).sum().item()
                total += labels.size(0)
            
            accuracy = correct / total
            print(f"Validation - Accuracy: {accuracy:.4f}")
    
    print("\nTraining completed successfully!")
    print("Model is working correctly!")
    
    # Save the model
    torch.save(model.state_dict(), "/kaggle/working/models/test_model.pth")
    print("Model saved to /kaggle/working/models/test_model.pth")

# Run the main function
if __name__ == "__main__":
    main()

print("\n" + "="*60)
print("PROGRAM COMPLETED SUCCESSFULLY!")
print("="*60)

Using device: cuda
GPU: Tesla P100-PCIE-16GB
Found 20000 image files
Split into 10000 cover and 10000 stego images
Organized 10000 cover images and 10000 stego images

After organization:
Cover images: 10000
Stego images: 10000

Sample cover files:
  00001.png
  00002.png
  00003.png
  00004.png
  00005.png

Sample stego files:
  00001.png
  00002.png
  00003.png
  00004.png
  00005.png

Created mapping for 10000 image pairs
Sample mappings:
  00001.png → 00001.png
  00002.png → 00002.png
  00003.png → 00003.png
  00004.png → 00004.png
  00005.png → 00005.png

Verifying mapping...
✓ 00001.png → 00001.png
✓ 00002.png → 00002.png
✓ 00003.png → 00003.png
✓ 00004.png → 00004.png
✓ 00005.png → 00005.png
✓ 00006.png → 00006.png
✓ 00007.png → 00007.png
✓ 00008.png → 00008.png
✓ 00009.png → 00009.png
✓ 00010.png → 00010.png
✓ All mappings are correct!

CREATING SIMPLIFIED MODEL FOR TESTING

STARTING TRAINING WITH SIMPLIFIED MODEL
Using device: cuda
Creating small test dataset...
Using 100 samp

ValueError: optimizer got an empty parameter list

In [2]:
# First, install required packages
!pip install torch torchvision pillow numpy scipy matplotlib scikit-learn tqdm

# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix
import time
from tqdm import tqdm
import zipfile
from google.colab import files

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [3]:
# Install required packages


# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix
import time
from tqdm import tqdm
import zipfile
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: Tesla P100-PCIE-16GB


In [4]:
# Check what files are available
import os

# List files in the dataset
dataset_path = '/kaggle/input/bossbase'
print("Dataset contents:")
for root, dirs, files in os.walk(dataset_path):
    level = root.replace(dataset_path, '').count(os.sep)
    indent = ' ' * 2 * level
    print(f'{indent}{os.path.basename(root)}/')
    subindent = ' ' * 2 * (level + 1)
    for file in files[:5]:  # Show first 5 files
        print(f'{subindent}{file}')
    if len(files) > 5:
        print(f'{subindent}... and {len(files) - 5} more files')

# Check the actual structure - BOSSBase might have different organization
bossbase_files = []
for root, dirs, files in os.walk(dataset_path):
    for file in files:
        if file.endswith(('.pgm', '.png', '.jpg', '.bmp')):
            bossbase_files.append(os.path.join(root, file))

print(f"\nFound {len(bossbase_files)} image files")
print("Sample files:", bossbase_files[:5])

Dataset contents:
bossbase/
  boss_256_0.4/
    cover/
      4353.png
      7968.png
      6490.png
      5511.png
      6262.png
      ... and 8995 more files
    stego/
      4353.png
      7968.png
      6490.png
      5511.png
      6262.png
      ... and 8995 more files
  boss_256_0.4_test/
    cover/
      9273.png
      9292.png
      9703.png
      9110.png
      9938.png
      ... and 995 more files
    stego/
      9273.png
      9292.png
      9703.png
      9110.png
      9938.png
      ... and 995 more files

Found 20000 image files
Sample files: ['/kaggle/input/bossbase/boss_256_0.4/cover/4353.png', '/kaggle/input/bossbase/boss_256_0.4/cover/7968.png', '/kaggle/input/bossbase/boss_256_0.4/cover/6490.png', '/kaggle/input/bossbase/boss_256_0.4/cover/5511.png', '/kaggle/input/bossbase/boss_256_0.4/cover/6262.png']


In [5]:
# Create directories for organized dataset
!mkdir -p /kaggle/working/data/cover
!mkdir -p /kaggle/working/data/stego
!mkdir -p /kaggle/working/models
!mkdir -p /kaggle/working/results

# Organize the dataset - this depends on how the BOSSBase dataset is structured
# Typically, we need to separate cover and stego images

def organize_bossbase_dataset(input_path, cover_output, stego_output):
    """
    Organize BOSSBase dataset into cover and stego folders
    This function will need to be adjusted based on the actual dataset structure
    """
    all_image_files = []
    for root, dirs, files in os.walk(input_path):
        for file in files:
            if file.endswith(('.pgm', '.png', '.jpg', '.bmp', '.tiff')):
                all_image_files.append(os.path.join(root, file))
    
    print(f"Found {len(all_image_files)} image files to organize")
    
    # This is a guess - you'll need to adjust based on actual filenames
    # Typically, cover images might be named like "1.pgm" and stego like "1_S-UNIWARD.pgm"
    cover_count = 0
    stego_count = 0
    
    for img_path in all_image_files:
        filename = os.path.basename(img_path)
        
        # Heuristic to identify stego vs cover images
        # Adjust these conditions based on your actual filenames
        if any(keyword in filename.lower() for keyword in ['stego', 's-uniward', 'wow', 'hill', 'juniward']):
            # This is likely a stego image
            destination = os.path.join(stego_output, filename)
            stego_count += 1
        else:
            # This is likely a cover image
            destination = os.path.join(cover_output, filename)
            cover_count += 1
        
        # Copy the file
        !cp "{img_path}" "{destination}"
    
    print(f"Organized {cover_count} cover images and {stego_count} stego images")
    return cover_count, stego_count

# Organize the dataset
cover_dir = '/kaggle/working/data/cover'
stego_dir = '/kaggle/working/data/stego'
cover_count, stego_count = organize_bossbase_dataset(dataset_path, cover_dir, stego_dir)

# If organization didn't work well, we might need manual inspection
print("\nFirst few cover files:")
!ls -la "{cover_dir}" | head -5

print("\nFirst few stego files:")
!ls -la "{stego_dir}" | head -5

Found 20000 image files to organize
Organized 20000 cover images and 0 stego images

First few cover files:
total 113000
drwxr-xr-x 2 root root 262144 Sep  5 10:36 .
drwxr-xr-x 4 root root   4096 Sep  5 09:55 ..
-rw-r--r-- 1 root root  13952 Sep  5 10:38 10000.png
-rw-r--r-- 1 root root  11089 Sep  5 10:15 1000.png
ls: write error: Broken pipe

First few stego files:
total 8
drwxr-xr-x 2 root root 4096 Sep  5 09:55 .
drwxr-xr-x 4 root root 4096 Sep  5 09:55 ..


In [6]:
# Create mapping between cover and stego files
# This is CRUCIAL for BOSSBase as filenames usually don't match directly

def create_filename_mapping(cover_dir, stego_dir):
    """
    Create mapping between cover and stego filenames
    """
    cover_files = sorted([f for f in os.listdir(cover_dir) 
                         if f.endswith(('.pgm', '.png', '.jpg', '.bmp', '.tiff'))])
    stego_files = sorted([f for f in os.listdir(stego_dir) 
                         if f.endswith(('.pgm', '.png', '.jpg', '.bmp', '.tiff'))])
    
    print(f"Cover files: {len(cover_files)}, Stego files: {len(stego_files)}")
    
    # Create mapping - this depends on the naming convention
    stego_to_cover_map = {}
    
    # Common BOSSBase pattern: cover: "1.pgm", stego: "1_S-UNIWARD.pgm"
    for stego_file in stego_files:
        # Try to find the corresponding cover file
        base_name = stego_file.split('_')[0]  # Get the number part
        cover_candidate = f"{base_name}.pgm"  # Assumes .pgm extension
        
        # Check if this cover file exists
        if cover_candidate in cover_files:
            stego_to_cover_map[cover_candidate] = stego_file
        else:
            # Try other extensions
            for ext in ['.png', '.jpg', '.bmp', '.tiff']:
                cover_candidate = f"{base_name}{ext}"
                if cover_candidate in cover_files:
                    stego_to_cover_map[cover_candidate] = stego_file
                    break
    
    print(f"Created mapping for {len(stego_to_cover_map)} image pairs")
    
    # Show some examples
    print("Sample mappings:")
    for i, (cover, stego) in enumerate(list(stego_to_cover_map.items())[:5]):
        print(f"  {cover} → {stego}")
    
    return stego_to_cover_map

# Create filename mapping
stego_to_cover_map = create_filename_mapping(cover_dir, stego_dir)

Cover files: 10000, Stego files: 0
Created mapping for 0 image pairs
Sample mappings:


In [7]:
# ==================== FEATURE EXTRACTION ====================
class ResidualFeatureExtractor(nn.Module):
    def __init__(self):
        super(ResidualFeatureExtractor, self).__init__()
        self.srm_filters = self._create_srm_filters()
        
    def _create_srm_filters(self):
        filters = []
        for angle in [0, 45, 90, 135]:
            filter_size = 5
            filt = np.zeros((filter_size, filter_size))
            center = filter_size // 2
            
            if angle == 0:  # Horizontal
                filt[center, :] = 1
                filt[center, center] = -filter_size + 1
            elif angle == 90:  # Vertical
                filt[:, center] = 1
                filt[center, center] = -filter_size + 1
            elif angle == 45:  # Diagonal
                for i in range(filter_size):
                    filt[i, i] = 1
                filt[center, center] = -filter_size + 1
            elif angle == 135:  # Anti-diagonal
                for i in range(filter_size):
                    filt[i, filter_size-1-i] = 1
                filt[center, center] = -filter_size + 1
            
            filters.append(filt)
        
        filters = [torch.FloatTensor(filt).unsqueeze(0).unsqueeze(0) for filt in filters]
        return nn.ModuleList([nn.Conv2d(1, 1, kernel_size=5, padding=2, bias=False) for _ in range(len(filters))])
    
    def forward(self, x):
        if x.shape[1] == 3:
            x_gray = 0.299 * x[:, 0:1, :, :] + 0.587 * x[:, 1:2, :, :] + 0.114 * x[:, 2:3, :, :]
        else:
            x_gray = x
        
        residuals = []
        for i, filter_layer in enumerate(self.srm_filters):
            residual = filter_layer(x_gray)
            residuals.append(residual)
        
        return torch.cat(residuals, dim=1)

class FeatureFusion(nn.Module):
    def __init__(self, residual_channels=4, rgb_channels=3):
        super(FeatureFusion, self).__init__()
        self.residual_conv = nn.Sequential(
            nn.Conv2d(residual_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.rgb_conv = nn.Sequential(
            nn.Conv2d(rgb_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, residual_features, rgb_features):
        residual_processed = self.residual_conv(residual_features)
        rgb_processed = self.rgb_conv(rgb_features)
        fused = torch.cat([residual_processed, rgb_processed], dim=1)
        return self.fusion_conv(fused)

# ==================== ATTENTION NETWORK ====================
class ResidualAttentionModule(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ResidualAttentionModule, self).__init__()
        self.in_channels = in_channels
        self.reduction_ratio = reduction_ratio
        
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1),
            nn.Sigmoid()
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1),
            nn.BatchNorm2d(in_channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction_ratio, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.filter_gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_channels, in_channels * 2),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels * 2, in_channels),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        batch_size, channels, height, width = x.shape
        channel_weights = self.channel_attention(x)
        spatial_weights = self.spatial_attention(x)
        filter_weights = self.filter_gate(x).view(batch_size, channels, 1, 1)
        attention_weights = channel_weights * spatial_weights * filter_weights
        attended_features = x * attention_weights
        return x + attended_features, attention_weights

class ResidualAttentionNetwork(nn.Module):
    def __init__(self, in_channels, num_attention_blocks=3):
        super(ResidualAttentionNetwork, self).__init__()
        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.attention_blocks = nn.ModuleList()
        for i in range(num_attention_blocks):
            self.attention_blocks.append(ResidualAttentionModule(64))
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        features = self.initial_conv(x)
        attention_maps = []
        for attention_block in self.attention_blocks:
            features, attention_weights = attention_block(features)
            attention_maps.append(attention_weights)
        features = self.final_conv(features)
        return features, attention_maps

# ==================== DUAL-ADVERSARIAL NETWORK ====================
class StegoClassifier(nn.Module):
    def __init__(self, in_channels):
        super(StegoClassifier, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        features = self.feature_extractor(x)
        features = features.view(features.size(0), -1)
        return self.classifier(features)

class CoverGenerator(nn.Module):
    def __init__(self, in_channels, out_channels=3):
        super(CoverGenerator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )
        
    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)

class ConsistencyChecker(nn.Module):
    def __init__(self):
        super(ConsistencyChecker, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1)
        )
        self.consistency_predictor = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        
    def forward(self, original, reconstructed):
        combined = torch.cat([original, reconstructed], dim=1)
        features = self.feature_extractor(combined)
        features = features.view(features.size(0), -1)
        return self.consistency_predictor(features)

class DualAdversarialNetwork(nn.Module):
    def __init__(self, in_channels):
        super(DualAdversarialNetwork, self).__init__()
        self.classifier = StegoClassifier(in_channels)
        self.generator = CoverGenerator(in_channels)
        self.consistency_checker = ConsistencyChecker()
        
    def forward(self, x, original_image=None):
        classification = self.classifier(x)
        reconstructed_cover = self.generator(x)
        consistency_score = None
        if original_image is not None:
            original_normalized = (original_image - 0.5) * 2
            consistency_score = self.consistency_checker(original_normalized, reconstructed_cover)
        return classification, reconstructed_cover, consistency_score

# ==================== DATASET AND TRAINING ====================
class StegoDataset(Dataset):
    def __init__(self, cover_dir, stego_dir, stego_to_cover_map=None, transform=None):
        self.cover_dir = cover_dir
        self.stego_dir = stego_dir
        self.transform = transform
        self.stego_to_cover_map = stego_to_cover_map
        
        # Get list of images
        self.cover_images = sorted([f for f in os.listdir(cover_dir) 
                                  if f.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.pgm', '.tiff'))])
        
        if stego_to_cover_map:
            # Use the mapping to find corresponding stego images
            self.stego_images = [stego_to_cover_map[cover] for cover in self.cover_images 
                               if cover in stego_to_cover_map]
            # Filter cover images to only those with matching stego
            self.cover_images = [cover for cover in self.cover_images 
                               if cover in stego_to_cover_map]
        else:
            # Assume filenames match exactly
            self.stego_images = sorted([f for f in os.listdir(stego_dir) 
                                      if f.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.pgm', '.tiff'))])
            
            # Verify they match
            if len(self.cover_images) != len(self.stego_images):
                print("Warning: Number of cover and stego images don't match!")
            
            # For BOSSBase, sometimes we need to ensure matching
            if not all(c == s for c, s in zip(self.cover_images, self.stego_images)):
                print("Warning: Cover and stego filenames don't match exactly!")
                print("Consider providing a stego_to_cover_map")
    
    def __len__(self):
        return min(len(self.cover_images), len(self.stego_images))
    
    def __getitem__(self, idx):
        # Get cover image
        cover_name = self.cover_images[idx]
        cover_path = os.path.join(self.cover_dir, cover_name)
        
        # Handle different image formats
        try:
            cover_img = Image.open(cover_path).convert('RGB')
        except:
            # Try different approaches for .pgm files
            if cover_path.endswith('.pgm'):
                cover_img = Image.open(cover_path)
                cover_img = cover_img.convert('RGB')
            else:
                raise ValueError(f"Could not open image: {cover_path}")
        
        # Get corresponding stego image
        if self.stego_to_cover_map:
            stego_name = self.stego_to_cover_map[cover_name]
        else:
            stego_name = self.stego_images[idx]
        stego_path = os.path.join(self.stego_dir, stego_name)
        
        try:
            stego_img = Image.open(stego_path).convert('RGB')
        except:
            if stego_path.endswith('.pgm'):
                stego_img = Image.open(stego_path)
                stego_img = stego_img.convert('RGB')
            else:
                raise ValueError(f"Could not open image: {stego_path}")
        
        if self.transform:
            cover_img = self.transform(cover_img)
            stego_img = self.transform(stego_img)
        
        return stego_img, cover_img, 1  # 1 indicates stego image

class StegoLoss(nn.Module):
    def __init__(self, alpha=1.0, beta=0.5, gamma=0.1):
        super(StegoLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.classification_loss = nn.BCELoss()
        self.reconstruction_loss = nn.MSELoss()
        self.consistency_loss = nn.BCELoss()
    
    def forward(self, classification_pred, reconstruction, consistency_pred, 
                classification_target, reconstruction_target, consistency_target=None):
        cls_loss = self.classification_loss(classification_pred, classification_target)
        rec_loss = self.reconstruction_loss(reconstruction, reconstruction_target)
        cons_loss = 0
        if consistency_target is not None and consistency_pred is not None:
            cons_loss = self.consistency_loss(consistency_pred, consistency_target)
        total_loss = (self.alpha * cls_loss + self.beta * rec_loss + self.gamma * cons_loss)
        return total_loss, cls_loss, rec_loss, cons_loss

class StegoTrainer:
    def __init__(self, model, device, learning_rate=0.001):
        self.model = model.to(device)
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=5, factor=0.5)
        self.loss_fn = StegoLoss(alpha=1.0, beta=0.5, gamma=0.1)
    
    def train_epoch(self, dataloader):
        self.model.train()
        total_loss, total_cls_loss, total_rec_loss, total_cons_loss = 0, 0, 0, 0
        
        for batch_idx, (stego_imgs, cover_imgs, labels) in enumerate(tqdm(dataloader, desc="Training")):
            stego_imgs, cover_imgs, labels = stego_imgs.to(self.device), cover_imgs.to(self.device), labels.float().to(self.device).unsqueeze(1)
            
            # Extract features
            residual_extractor = ResidualFeatureExtractor().to(self.device)
            residual_features = residual_extractor(stego_imgs)
            
            feature_fusion = FeatureFusion().to(self.device)
            fused_features = feature_fusion(residual_features, stego_imgs)
            
            # Apply attention
            attention_net = ResidualAttentionNetwork(in_channels=64).to(self.device)
            attended_features, _ = attention_net(fused_features)
            
            # Forward pass through DAN
            classification_pred, reconstructed_cover, consistency_pred = self.model(attended_features, stego_imgs)
            
            # Calculate loss
            loss, cls_loss, rec_loss, cons_loss = self.loss_fn(
                classification_pred, reconstructed_cover, consistency_pred,
                labels, cover_imgs, torch.ones_like(consistency_pred) if consistency_pred is not None else None
            )
            
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            # Accumulate losses
            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_rec_loss += rec_loss.item()
            if cons_loss != 0:
                total_cons_loss += cons_loss.item()
        
        return (total_loss / len(dataloader), 
                total_cls_loss / len(dataloader), 
                total_rec_loss / len(dataloader), 
                total_cons_loss / len(dataloader) if total_cons_loss > 0 else 0)
    
    def validate(self, dataloader):
        self.model.eval()
        total_loss, correct, total = 0, 0, 0
        all_predictions, all_labels = [], []
        
        with torch.no_grad():
            for stego_imgs, cover_imgs, labels in tqdm(dataloader, desc="Validation"):
                stego_imgs, cover_imgs, labels = stego_imgs.to(self.device), cover_imgs.to(self.device), labels.float().to(self.device).unsqueeze(1)
                
                # Extract features
                residual_extractor = ResidualFeatureExtractor().to(self.device)
                residual_features = residual_extractor(stego_imgs)
                
                feature_fusion = FeatureFusion().to(self.device)
                fused_features = feature_fusion(residual_features, stego_imgs)
                
                # Apply attention
                attention_net = ResidualAttentionNetwork(in_channels=64).to(self.device)
                attended_features, _ = attention_net(fused_features)
                
                # Forward pass
                classification_pred, reconstructed_cover, consistency_pred = self.model(attended_features, stego_imgs)
                
                # Calculate loss
                loss, _, _, _ = self.loss_fn(
                    classification_pred, reconstructed_cover, consistency_pred,
                    labels, cover_imgs, torch.ones_like(consistency_pred) if consistency_pred is not None else None
                )
                
                total_loss += loss.item()
                
                # Calculate accuracy
                pred_labels = (classification_pred > 0.5).float()
                correct += (pred_labels == labels).sum().item()
                total += labels.size(0)
                
                # Store for ROC calculation
                all_predictions.extend(classification_pred.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        # Calculate ROC curve
        all_predictions, all_labels = np.array(all_predictions), np.array(all_labels)
        fpr, tpr, _ = roc_curve(all_labels, all_predictions)
        roc_auc = auc(fpr, tpr)
        
        return total_loss / len(dataloader), correct / total, fpr, tpr, roc_auc, all_predictions, all_labels

def plot_training_history(train_losses, val_losses, val_accuracies, roc_info):
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
    
    # Plot training and validation loss
    ax1.plot(train_losses, label='Training Loss')
    ax1.plot(val_losses, label='Validation Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # Plot validation accuracy
    ax2.plot(val_accuracies, label='Validation Accuracy', color='green')
    ax2.set_title('Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True)
    
    # Plot ROC curve
    fpr, tpr, roc_auc = roc_info[:3]
    ax3.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    ax3.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax3.set_xlim([0.0, 1.0])
    ax3.set_ylim([0.0, 1.05])
    ax3.set_xlabel('False Positive Rate')
    ax3.set_ylabel('True Positive Rate')
    ax3.set_title('Receiver Operating Characteristic')
    ax3.legend(loc="lower right")
    ax3.grid(True)
    
    # Calculate confusion matrix
    if len(roc_info) > 5:
        all_predictions, all_labels = roc_info[5], roc_info[6]
        cm = confusion_matrix(all_labels, (all_predictions > 0.5).astype(int))
        ax4.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax4.set_title('Confusion Matrix')
        ax4.set_ylabel('True label')
        ax4.set_xlabel('Predicted label')
        tick_marks = np.arange(2)
        ax4.set_xticks(tick_marks)
        ax4.set_yticks(tick_marks)
        ax4.set_xticklabels(['Cover', 'Stego'])
        ax4.set_yticklabels(['Cover', 'Stego'])
        
        thresh = cm.max() / 2.
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax4.text(j, i, format(cm[i, j], 'd'),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")
    
    plt.tight_layout()
    plt.savefig('/kaggle/working/results/training_results.png')
    plt.show()