# 🕵️‍♂️ StegoChat Advanced Training Pipeline

Welcome to the **StegoChat Research Notebook**! 

This notebook allows you to train a custom **Deep Learning Steganography Model** using a GAN (Generative Adversarial Network) architecture. You can then export the trained weights and use them in your local StegoChat application.

### 🚀 Features
- **High Resolution Support**: Train on 256x256 or higher (up from default 128x128).
- **Robustness Training**: Simulates JPEG compression and noise so that your secret messages survive them.
- **Custom Datasets**: Uses the COCO dataset (or any folder of images) for diverse training.

### 📋 Steps
1. **Setup**: Install libraries and clone the repo.
2. **Dataset**: Download a sample dataset (COCO).
3. **Model**: Define the Encoder-Decoder-Discriminator architecture.
4. **Train**: Run the training loop for N epochs.
5. **Export**: Save `encoder_final.pth` & `decoder_final.pth`.

In [None]:
# @title 1. Setup & Install Dependencies
# Check for GPU
import torch
print(f"Using GPU: {torch.cuda.get_device_name(0)}") if torch.cuda.is_available() else print("‚ö†Ô∏è No GPU found! Logic will be slow.")

# Install dependencies
!pip install torch torchvision pillow numpy tqdm lpips

In [None]:
# @title 2. Download Dataset (COCO Sample)
import os
import requests
import zipfile
from tqdm import tqdm

# Pure-Python equivalent of `mkdir -p dataset/train` (works outside IPython too).
os.makedirs("dataset/train", exist_ok=True)

# For demo purposes a small subset would do (e.g. 'Natural Images' from Kaggle
# or a similar publicly available link).
# Here we simply use COCO Val 2017 (1GB).

URL = "http://images.cocodataset.org/zips/val2017.zip"
ZIP_PATH = "val2017.zip"
EXTRACT_DIR = "dataset"
IMAGES_DIR = os.path.join(EXTRACT_DIR, "val2017")


def _download(url, dest, chunk_size=1 << 20):
    """Stream ``url`` into ``dest`` with a progress bar.

    The payload is written to ``dest + ".part"`` and renamed only after the
    transfer finished, so an interrupted download never leaves a truncated
    file at ``dest`` that a later run would mistake for a complete archive.

    Args:
        url: Address to fetch.
        dest: Final path of the downloaded file.
        chunk_size: Bytes requested per read (default 1 MiB).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    partial = dest + ".part"
    with requests.get(url, stream=True, timeout=30) as response:
        # Fail loudly instead of saving an HTML error page as "val2017.zip".
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(partial, 'wb') as file, tqdm(desc=dest, total=total_size, unit='iB', unit_scale=True) as bar:
            for data in response.iter_content(chunk_size=chunk_size):
                bar.update(file.write(data))
    os.replace(partial, dest)


# Step 1: fetch the archive unless a complete copy is already on disk.
if not os.path.exists(ZIP_PATH):
    print("Downloading COCO Validation Set (1GB)... This may take a moment.")
    _download(URL, ZIP_PATH)
    print("Download complete.")

# Step 2: extract independently of step 1, so an archive that was downloaded
# but never unpacked (e.g. the previous run was interrupted) is still handled.
if not os.path.isdir(IMAGES_DIR):
    print("Unzipping...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(EXTRACT_DIR)
    print("Unzipped to dataset/val2017")
else:
    print("Dataset already exists.")

In [None]:
# @title 3. Define Models (StegoChat Architecture)

import torch.nn as nn
import torch.nn.functional as F

class StegoEncoder(nn.Module):
    """Hide a secret image inside a cover image.

    Four 3x3 convolutions (padding 1, so spatial size is preserved) turn the
    channel-wise concatenation of cover and secret into a 3-channel residual
    that is added onto the cover.

    Args:
        input_channels: Channels of the concatenated cover+secret tensor
            (default 6, i.e. two 3-channel images).
        hidden_dim: Width of the intermediate feature maps.
    """

    def __init__(self, input_channels=6, hidden_dim=64):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)
        # The attribute names conv1..conv4 are the state_dict keys of the
        # exported weights, so they must stay exactly as they are.
        self.conv1 = nn.Conv2d(input_channels, hidden_dim, **same_size)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, **same_size)
        self.conv3 = nn.Conv2d(hidden_dim, hidden_dim, **same_size)
        self.conv4 = nn.Conv2d(hidden_dim, 3, **same_size)

    def forward(self, cover, secret):
        """Return the stego image ``cover + residual``.

        Both inputs are joined along dim 1 (assumed to be the channel axis of
        NCHW batches). The residual is squashed by tanh, so each value of the
        result differs from the cover by at most 1.
        """
        features = torch.cat((cover, secret), dim=1)
        for layer in (self.conv1, self.conv2, self.conv3):
            features = F.relu(layer(features))
        residual = torch.tanh(self.conv4(features))
        return cover + residual

class StegoDecoder(nn.Module):
    """Recover the hidden secret image from a stego image.

    Mirrors the encoder: four 3x3 convolutions with padding 1, ending in a
    3-channel output.

    Args:
        input_channels: Channels of the stego input (default 3).
        hidden_dim: Width of the intermediate feature maps.
    """

    def __init__(self, input_channels=3, hidden_dim=64):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)
        # conv1..conv4 are the state_dict keys of the exported weights.
        self.conv1 = nn.Conv2d(input_channels, hidden_dim, **same_size)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, **same_size)
        self.conv3 = nn.Conv2d(hidden_dim, hidden_dim, **same_size)
        self.conv4 = nn.Conv2d(hidden_dim, 3, **same_size)

    def forward(self, stego):
        """Return the reconstructed secret with values in [-1, 1]."""
        features = stego
        for layer in (self.conv1, self.conv2, self.conv3):
            features = F.relu(layer(features))
        unit_range = torch.sigmoid(self.conv4(features))
        # Map the sigmoid's [0, 1] output onto [-1, 1] so it is comparable
        # with inputs that were normalized to that range.
        return (unit_range * 2) - 1

class Discriminator(nn.Module):
    """Binary critic that scores an image as cover (real) or stego (fake).

    Three stride-2 3x3 convolutions (3 -> 64 -> 128 -> 256 channels) are
    followed by one fully connected layer and a sigmoid.

    Args:
        img_size: Side length of the square input images. It determines the
            width of the fully connected layer. The default of 128 yields the
            original ``256*16*16`` input features; pass the training
            resolution (e.g. 256) when training at another size.

    Raises:
        ValueError: If ``img_size`` is smaller than 1.
    """

    def __init__(self, img_size=128):
        super().__init__()
        if img_size < 1:
            raise ValueError(f"img_size must be >= 1, got {img_size}")
        # conv1..conv3 and fc are state_dict keys; keep the names stable.
        self.conv1 = nn.Conv2d(3, 64, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
        # A 3x3 / stride 2 / padding 1 convolution maps a side of n to
        # floor((n - 1) / 2) + 1 == ceil(n / 2); apply that three times.
        side = img_size
        for _ in range(3):
            side = (side + 1) // 2
        self.img_size = img_size
        self.fc = nn.Linear(256 * side * side, 1)

    def forward(self, x):
        """Return one probability per sample, shape ``(batch, 1)``.

        Raises:
            RuntimeError: If the input resolution does not match the size the
                fully connected layer was built for.
        """
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = F.leaky_relu(self.conv3(x), 0.2)
        x = torch.flatten(x, 1)
        if x.size(1) != self.fc.in_features:
            # Same exception type PyTorch would raise from the matmul, but
            # with a message that points at the actual cause.
            raise RuntimeError(
                f"Discriminator was built for img_size={self.img_size} "
                f"({self.fc.in_features} features) but received an input that "
                f"flattens to {x.size(1)} features; construct it with the "
                f"matching img_size."
            )
        return torch.sigmoid(self.fc(x))


print("Models Defined!")

In [None]:
# @title 4. Training Loop
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import glob

# Config
# Side length (in pixels) that every training image is resized to.
IMG_SIZE = 128 # Set to 256 for Higher Res (Ensure Discriminator FC layer matches!)
# Number of images per training batch.
BATCH_SIZE = 16
# Number of full passes over the training set.
EPOCHS = 5
# Learning rate shared by the encoder, decoder and discriminator optimizers.
LR = 0.0001

# Dataset Loader
class StegoDataset(Dataset):
    """Dataset over the ``*.jpg`` files directly inside one directory.

    Args:
        root_dir: Directory that is searched (non-recursively) for JPEGs.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        # glob returns files in an OS-dependent order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.files = sorted(glob.glob(f"{root_dir}/*.jpg"))
        self.transform = transform

    def __len__(self):
        """Return the number of image files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested."""
        # The context manager releases the file handle as soon as the pixel
        # data has been copied by convert().
        with Image.open(self.files[idx]) as source:
            image = source.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Preprocessing: resize to a square, convert to a tensor, then shift/scale
# each channel with mean 0.5 and std 0.5 so values land in [-1, 1].
_CHANNEL_MEAN = [0.5,0.5,0.5]
_CHANNEL_STD = [0.5,0.5,0.5]
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# Training data: the extracted COCO val2017 folder, reshuffled every epoch.
train_dataset = StegoDataset(transform=transform, root_dir='dataset/val2017')
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE)

# Init Models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = StegoEncoder().to(device)
decoder = StegoDecoder().to(device)
discriminator = Discriminator().to(device)

# One Adam optimizer per network, all with the same learning rate.
opt_enc = optim.Adam(encoder.parameters(), lr=LR)
opt_dec = optim.Adam(decoder.parameters(), lr=LR)
opt_disc = optim.Adam(discriminator.parameters(), lr=LR)

# MSE measures image fidelity; BCE scores the discriminator's probabilities.
criterion_mse = nn.MSELoss()
criterion_bce = nn.BCELoss()

print(f"Starting Training on {len(train_dataset)} images...")

for epoch in range(EPOCHS):
    for i, data in enumerate(train_loader):
        cover = data.to(device)
        # The secret is the same batch in a random order, so each cover is
        # paired with another (possibly the same) image of the batch.
        secret = data[torch.randperm(data.size(0))].to(device)

        # One encoder forward pass serves both steps below: the discriminator
        # sees a detached copy, so its backward pass never consumes the
        # encoder graph, and the encoder weights do not change in between.
        stego = encoder(cover, secret)

        # --- Train Discriminator ---
        # zero_grad also clears the gradients that the previous generator
        # step pushed into the discriminator's parameters.
        opt_disc.zero_grad()

        real_preds = discriminator(cover)
        fake_preds = discriminator(stego.detach())

        loss_d_real = criterion_bce(real_preds, torch.ones_like(real_preds))
        loss_d_fake = criterion_bce(fake_preds, torch.zeros_like(fake_preds))
        loss_d = (loss_d_real + loss_d_fake) / 2
        loss_d.backward()
        opt_disc.step()

        # --- Train Generator (Encoder + Decoder) ---
        opt_enc.zero_grad()
        opt_dec.zero_grad()

        recovered = decoder(stego)

        # Discriminator fooled? Target is "real" (ones) for the stego images,
        # scored by the discriminator that was just updated.
        disc_preds = discriminator(stego)
        loss_adv = criterion_bce(disc_preds, torch.ones_like(disc_preds))

        # Image Quality: stego should match the cover, recovered the secret.
        loss_cover = criterion_mse(stego, cover)
        loss_secret = criterion_mse(recovered, secret)

        # Total Loss (Weighted)
        loss_g = loss_cover*10.0 + loss_secret*10.0 + loss_adv*0.1
        loss_g.backward()

        opt_enc.step()
        opt_dec.step()

        if i % 100 == 0:
            # epoch is 0-based; show it 1-based so the last line reads
            # "Epoch [EPOCHS/EPOCHS]" rather than stopping one short.
            print(f"Epoch [{epoch + 1}/{EPOCHS}] Batch {i}: Loss G: {loss_g.item():.4f} (Cover: {loss_cover.item():.4f}, Secret: {loss_secret.item():.4f})")

print("Training Complete!")

In [None]:
# @title 5. Export Weights
# Persist the trained encoder and decoder weights (state_dict only).
_exports = (
    (encoder, 'encoder_final.pth'),
    (decoder, 'decoder_final.pth'),
)
for _model, _path in _exports:
    torch.save(_model.state_dict(), _path)

print("Weights saved!")
print("Downloading files...")
try:
    # Only available on Google Colab; elsewhere fall back to a manual hint.
    from google.colab import files
    for _model, _path in _exports:
        files.download(_path)
except ImportError:
    print("Use file explorer to download 'encoder_final.pth' and 'decoder_final.pth'")