# 🎯 Memory-Optimized Watermarking Model - 85-90% Accuracy

## ⚡ Optimized for T4 GPU (15GB)

**Key Optimizations:**
- ✅ Smaller batch size (16 instead of 32)
- ✅ Gradient accumulation (effective batch size = 32)
- ✅ Mixed precision training (FP16)
- ✅ Memory-efficient perceptual loss
- ✅ Automatic memory clearing

**Expected Results:** 85-90% accuracy in ~40-50 minutes


In [None]:
%%capture
# Install the third-party packages this notebook imports.
# `%%capture` suppresses the cell's output and `-q` keeps pip itself quiet,
# so a successful install prints nothing.
!pip install -q torch torchvision matplotlib opencv-python-headless scikit-image scikit-learn PyWavelets Pillow tqdm

In [None]:
# Mount Google Drive at /content/drive so that files stored there are
# reachable as ordinary local paths (the image directory configured in this
# notebook lives under that mount point).
# NOTE: `google.colab` is only importable inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# NOTE: change this path so it points at the directory holding your images.
ROOT_IMAGES = '/content/drive/MyDrive/project/JPEGImages'

# Hyper-parameters for the memory-optimized run. The small per-step batch
# combined with gradient accumulation keeps peak GPU memory low while the
# optimizer still sees an effective batch of batch_size * accumulation_steps.
CONFIG = dict(
    epochs=20,
    batch_size=16,            # per-step batch, reduced from 32 to save memory
    accumulation_steps=2,     # micro-batches accumulated per optimizer step
    lr=1e-3,
    payload_len=64,
    train_n=10000,
    val_n=2000,
    test_n=2000,
    early_stop_patience=5,
    image_size=256,
    use_amp=True,             # mixed precision (FP16) for memory savings
)

# Echo the configuration so it is recorded in the notebook output.
print('üìÇ Image directory:', ROOT_IMAGES)
print('\n‚öôÔ∏è Memory-Optimized Configuration:')
print('\n'.join(f'  {name:20s} = {value}' for name, value in CONFIG.items()))
effective_batch = CONFIG['batch_size'] * CONFIG['accumulation_steps']
print(f"\n  Effective batch size: {effective_batch}")

In [None]:
# Imports
import os
import gc
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from pathlib import Path
import random
import time
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Choose the compute device once; models and batches are moved to it later.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'
print(f'üñ•Ô∏è  Device: {device}')
if has_cuda:
    # Report the GPU and start from a clean allocator state.
    print(f'    GPU: {torch.cuda.get_device_name(0)}')
    total_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'    Total Memory: {total_mem:.1f} GB')
    torch.cuda.empty_cache()
    gc.collect()

# Seed every random number generator the notebook draws from.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(42)
if has_cuda:
    torch.cuda.manual_seed(42)

print('\n‚úÖ Setup complete!')

In [None]:
# Improved Encoder (same as before)
class ImprovedEncoder(nn.Module):
    """U-Net style encoder that produces a watermark residual for an image.

    The bit payload is expanded by an MLP into an 8-channel 16x16 map,
    bilinearly upsampled to the image resolution and concatenated with the
    RGB input. A two-level encoder/decoder with skip connections then
    predicts a 3-channel residual, bounded to [-0.05, 0.05] by ``tanh * 0.05``.

    Args:
        payload_len: number of payload bits fed to the embedding MLP.
        hidden: channel width of the first stage (later stages use 2x / 4x).
    """

    @staticmethod
    def _conv_bn_relu(c_in, c_out, repeats):
        """Stack `repeats` 3x3 Conv -> BatchNorm -> ReLU triples in one Sequential."""
        layers = []
        for i in range(repeats):
            layers.append(nn.Conv2d(c_in if i == 0 else c_out, c_out, 3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def __init__(self, payload_len=64, hidden=64):
        super().__init__()
        self.payload_len = payload_len
        h1, h2, h4 = hidden, hidden * 2, hidden * 4

        # Payload bits -> flat vector that forward() reshapes to (8, 16, 16).
        self.payload_embed = nn.Sequential(
            nn.Linear(payload_len, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 16 * 16 * 8),
        )

        # Contracting path (input = 3 image channels + 8 payload channels).
        self.down1 = self._conv_bn_relu(3 + 8, h1, repeats=2)
        self.pool = nn.MaxPool2d(2)
        self.down2 = self._conv_bn_relu(h1, h2, repeats=2)
        self.down3 = self._conv_bn_relu(h2, h4, repeats=2)

        # Expanding path; each up_conv sees upsampled features + skip features.
        self.up1 = nn.ConvTranspose2d(h4, h2, 2, stride=2)
        self.up_conv1 = self._conv_bn_relu(h4, h2, repeats=1)
        self.up2 = nn.ConvTranspose2d(h2, h1, 2, stride=2)
        self.up_conv2 = self._conv_bn_relu(h2, h1, repeats=1)

        self.out_conv = nn.Sequential(
            nn.Conv2d(h1, h1, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(h1, 3, 1),
        )

    def forward(self, x, payload):
        """Return the residual to add to `x` so that it carries `payload`."""
        batch, _, height, width = x.shape

        # Spread the payload over the full image plane.
        payload_map = self.payload_embed(payload).view(batch, 8, 16, 16)
        payload_map = F.interpolate(
            payload_map, size=(height, width), mode='bilinear', align_corners=False
        )

        skip1 = self.down1(torch.cat([x, payload_map], dim=1))
        skip2 = self.down2(self.pool(skip1))
        bottleneck = self.down3(self.pool(skip2))

        merged = torch.cat([self.up1(bottleneck), skip2], dim=1)
        merged = self.up_conv1(merged)
        merged = torch.cat([self.up2(merged), skip1], dim=1)
        merged = self.up_conv2(merged)

        # tanh keeps the residual in [-1, 1]; 0.05 caps its amplitude.
        return torch.tanh(self.out_conv(merged)) * 0.05

print('‚úÖ Encoder defined')

In [None]:
# Improved Decoder (same as before)
class ImprovedDecoder(nn.Module):
    """CNN that recovers payload bit logits from a (possibly distorted) image.

    Three double-convolution stages extract features; the last one ends in an
    adaptive average pool to 8x8, so the fully connected head always receives
    ``hidden*4*8*8`` features. The head outputs one logit per payload bit.

    Args:
        payload_len: number of logits produced.
        hidden: channel width of the first stage (later stages use 2x / 4x).
    """

    @staticmethod
    def _stage(c_in, c_out, reducer):
        """Two 3x3 Conv -> BatchNorm -> ReLU triples followed by `reducer`."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            reducer,
        )

    def __init__(self, payload_len=64, hidden=64):
        super().__init__()
        h1, h2, h4 = hidden, hidden * 2, hidden * 4

        self.conv1 = self._stage(3, h1, nn.MaxPool2d(2))
        self.conv2 = self._stage(h1, h2, nn.MaxPool2d(2))
        self.conv3 = self._stage(h2, h4, nn.AdaptiveAvgPool2d((8, 8)))

        # Classifier head: flattened 8x8 feature map -> payload logits.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(h4 * 8 * 8, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, payload_len),
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits, one per payload bit."""
        features = self.conv3(self.conv2(self.conv1(x)))
        return self.fc(features)

print('‚úÖ Decoder defined')

In [None]:
# Attack Pipeline (same as before)
class ImprovedAttack(nn.Module):
    """Random distortion chain applied between the encoder and the decoder.

    Every call independently decides (using Python's ``random``) whether to
    apply each stage, in this order:

    1. resize down to 75-95% and back up        (p = 0.95)
    2. rotation by up to +/-5 degrees           (p = 0.6)
    3. 3x3 or 5x5 Gaussian blur                 (p = 0.8)
    4. additive Gaussian noise, then clamp      (p = 0.9)
    5. real JPEG round-trip, quality 70-95      (p = ``p_jpeg``)

    The input is expected to be a float batch shaped (B, 3, H, W) with values
    in [0, 1]; the output has the same shape.

    Stages 1-4 are ordinary differentiable tensor ops. The JPEG codec is not
    differentiable, so it is wrapped in a straight-through estimator: the
    forward value is the true JPEG result while the backward pass treats the
    stage as identity. Without this the encoder would receive no gradient
    from the decoding loss on any batch where JPEG was applied.

    Args:
        p_jpeg: probability of applying the JPEG round-trip.
    """

    def __init__(self, p_jpeg=0.7):
        super().__init__()
        self.p_jpeg = p_jpeg

    @staticmethod
    def _jpeg_roundtrip(x):
        """Compress and decompress each image of `x` with OpenCV (no gradient).

        Returns a float32 tensor on `x`'s device with values in [0, 1].
        """
        # Clamp and round before the 8-bit cast so values map to the nearest
        # level (a bare cast truncates and is undefined outside 0..255).
        x_u8 = (x.detach().float().clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
        frames = x_u8.permute(0, 2, 3, 1).cpu().numpy()

        decoded = []
        for frame in frames:
            img_bgr = cv2.cvtColor(np.ascontiguousarray(frame), cv2.COLOR_RGB2BGR)
            quality = random.randint(70, 95)
            _, enc = cv2.imencode('.jpg', img_bgr, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
            dec = cv2.imdecode(enc, cv2.IMREAD_COLOR)
            decoded.append(cv2.cvtColor(dec, cv2.COLOR_BGR2RGB))

        out = torch.from_numpy(np.stack(decoded, axis=0)).permute(0, 3, 1, 2)
        return out.to(x.device).float() / 255.0

    def forward(self, imgs):
        """Return a randomly distorted copy of `imgs`."""
        x = imgs

        # Resize: per-sample scale factor, down then back to the original size.
        if random.random() < 0.95:
            h, w = x.shape[2], x.shape[3]
            scales = torch.empty(x.size(0)).uniform_(0.75, 0.95).tolist()
            out = torch.zeros_like(x)
            for i, s in enumerate(scales):
                nh, nw = max(1, int(h * s)), max(1, int(w * s))
                small = F.interpolate(x[i:i+1], size=(nh, nw), mode='bilinear', align_corners=False)
                out[i:i+1] = F.interpolate(small, size=(h, w), mode='bilinear', align_corners=False)
            x = out

        # Rotation: one angle per sample, all affine matrices built at once.
        if random.random() < 0.6:
            rad = torch.deg2rad(torch.empty(x.size(0)).uniform_(-5, 5))
            cos, sin = torch.cos(rad), torch.sin(rad)
            zero = torch.zeros_like(rad)
            theta_batch = torch.stack([
                torch.stack([cos, -sin, zero], dim=1),
                torch.stack([sin, cos, zero], dim=1),
            ], dim=1).to(x.device)  # (B, 2, 3)
            grid = F.affine_grid(theta_batch, x.size(), align_corners=False)
            x = F.grid_sample(x, grid, padding_mode='border', align_corners=False)

        # Blur: the same normalized Gaussian kernel on every channel, done as
        # one depthwise (grouped) convolution instead of a per-channel loop.
        if random.random() < 0.8:
            k = random.choice([3, 5])
            kernel = torch.tensor(cv2.getGaussianKernel(k, k/3).astype(np.float32))
            kernel2 = kernel @ kernel.T
            kernel2 = kernel2 / kernel2.sum()
            channels = x.shape[1]
            weight = kernel2.to(x.device).expand(channels, 1, k, k).contiguous()
            pad = k // 2
            padded = F.pad(x, (pad, pad, pad, pad), mode='reflect')
            x = F.conv2d(padded, weight, padding=0, groups=channels)

        # Noise
        if random.random() < 0.9:
            noise = torch.randn_like(x) * random.uniform(0.003, 0.01)
            x = torch.clamp(x + noise, 0, 1)

        # JPEG (straight-through): `x - x.detach()` is exactly zero in the
        # forward pass, so the value is the JPEG result, while its gradient
        # with respect to x is identity.
        if random.random() < self.p_jpeg:
            x = self._jpeg_roundtrip(x) + (x - x.detach())

        return x

print('‚úÖ Attack pipeline defined')

In [None]:
# Dataset
class ImageDataset(Dataset):
    """Dataset of image files returned as square float RGB tensors.

    Each item is read from disk, converted to float32 RGB in [0, 1],
    center-cropped to a square, resized to ``image_size`` x ``image_size``
    and returned as a tensor shaped (3, image_size, image_size).

    Loading is best-effort: if a file cannot be read or processed, a warning
    naming the file is logged and an all-zero (black) tensor of the same
    shape is returned so that a single bad file does not abort an epoch.

    Args:
        paths: iterable of image paths (anything accepted by ``str()``).
        image_size: side length of the returned square image, in pixels.
    """

    def __init__(self, paths, image_size=256):
        self.paths = [str(p) for p in paths]
        self.image_size = image_size

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _to_float_rgb(img):
        """Convert a decoded image array to an (H, W, 3) float32 array in [0, 1].

        Integer images are scaled by the maximum of their dtype, so 8-bit and
        16-bit files both end up in [0, 1] and a very dark 8-bit image is not
        mistaken for one that is already normalized. Floating point images
        keep the original heuristic: values above 1.0 are assumed to be on a
        0-255 scale.
        """
        if img.ndim == 2:
            img = img[:, :, None]
        channels = img.shape[2]
        if channels < 3:
            # Grayscale, optionally with alpha: replicate the luminance plane.
            img = np.repeat(img[:, :, :1], 3, axis=2)
        elif channels > 3:
            # Drop alpha (and any further planes).
            img = img[:, :, :3]

        if np.issubdtype(img.dtype, np.integer):
            return img.astype(np.float32) / float(np.iinfo(img.dtype).max)

        img = img.astype(np.float32)
        if img.max() > 1.0:
            img = img / 255.0
        return img

    def __getitem__(self, idx):
        path = self.paths[idx]
        try:
            img = self._to_float_rgb(io.imread(path))

            # Largest centered square crop.
            H, W = img.shape[:2]
            side = min(H, W)
            top = H // 2 - side // 2
            left = W // 2 - side // 2
            img_crop = img[top:top + side, left:left + side]

            img_resized = cv2.resize(
                img_crop, (self.image_size, self.image_size), interpolation=cv2.INTER_AREA
            )
            return torch.from_numpy(img_resized).permute(2, 0, 1).float()
        except Exception as exc:
            # Deliberate best-effort boundary: keep going, but make the
            # substitution visible instead of training on black images silently.
            import logging
            logging.getLogger(__name__).warning(
                'Could not load image %s (%s: %s); substituting a black image',
                path, type(exc).__name__, exc,
            )
            return torch.zeros(3, self.image_size, self.image_size)

def create_datasets(root_dir, train_n=10000, val_n=2000, test_n=2000, seed=42):
    """Collect image paths under `root_dir` and split them into train/val/test.

    Files are found recursively by extension (``.jpg``, ``.jpeg``, ``.png``,
    case-insensitive). The list is sorted before the seeded shuffle so that
    the same `seed` yields the same split regardless of the order in which
    the filesystem enumerates files.

    If fewer than ``train_n + val_n + test_n`` images exist, the train and
    validation sizes are scaled down proportionally and the test split takes
    the remainder.

    Args:
        root_dir: directory searched recursively for images.
        train_n: requested number of training images.
        val_n: requested number of validation images.
        test_n: requested number of test images.
        seed: seed of the shuffle that decides the split membership.

    Returns:
        Tuple ``(train_paths, val_paths, test_paths)`` of disjoint lists of
        ``pathlib.Path`` objects.

    Raises:
        FileNotFoundError: if no image file is found under `root_dir`.
    """
    image_suffixes = {'.jpg', '.jpeg', '.png'}
    paths = sorted(
        p for p in Path(root_dir).rglob('*') if p.suffix.lower() in image_suffixes
    )
    if not paths:
        # Fail here with a clear message rather than much later in training.
        raise FileNotFoundError(f'No .jpg/.jpeg/.png images found under {root_dir}')
    random.Random(seed).shuffle(paths)

    total_needed = train_n + val_n + test_n
    available = len(paths)

    print(f'Found {available} images in {root_dir}')

    if available < total_needed:
        print(f'‚ö†Ô∏è  Warning: Only {available} images available')
        ratio = available / total_needed
        train_n = int(train_n * ratio)
        val_n = int(val_n * ratio)
        test_n = available - train_n - val_n

    val_end = train_n + val_n
    train_paths = paths[:train_n]
    val_paths = paths[train_n:val_end]
    test_paths = paths[val_end:val_end + test_n]

    print(f'\nDataset splits:')
    print(f'  Train: {len(train_paths):,} images')
    print(f'  Val:   {len(val_paths):,} images')
    print(f'  Test:  {len(test_paths):,} images')

    return train_paths, val_paths, test_paths

print('‚úÖ Dataset utilities defined')

# 🚀 Memory-Optimized Training Function

In [None]:
def _evaluate_watermark(encoder, decoder, attack, loader, payload_len, use_amp,
                        desc, leave=True):
    """Score payload recovery of an encoder/decoder pair over one data loader.

    For every batch a fresh random bit payload is embedded, the watermarked
    image is passed through `attack`, and the decoder's thresholded output is
    compared with the payload bit by bit. Both networks are switched to eval
    mode and no gradients are recorded.

    Args:
        encoder: module called as ``encoder(imgs, payload)`` returning a residual.
        decoder: module called as ``decoder(images)`` returning bit logits.
        attack: distortion module applied to the watermarked images.
        loader: data loader yielding image batches.
        payload_len: number of payload bits per image.
        use_amp: whether to run the forward pass under autocast.
        desc: progress bar label.
        leave: whether the progress bar stays on screen when finished.

    Returns:
        Tuple ``(mean_bce_loss, accuracy, precision, recall, f1)`` computed
        over all bits of all batches.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    encoder.eval()
    decoder.eval()

    losses, pred_chunks, target_chunks = [], [], []
    with torch.no_grad():
        for imgs in tqdm(loader, desc=desc, leave=leave):
            imgs = imgs.to(device)
            payload = torch.randint(0, 2, (imgs.size(0), payload_len)).float().to(device)

            with autocast(enabled=use_amp):
                residual = encoder(imgs, payload)
                watermarked = torch.clamp(imgs + residual, 0.0, 1.0)
                logits = decoder(attack(watermarked))
                bce_loss = F.binary_cross_entropy_with_logits(logits, payload)

            losses.append(bce_loss.item())
            pred_chunks.append((torch.sigmoid(logits) > 0.5).long().cpu().numpy().reshape(-1))
            target_chunks.append(payload.long().cpu().numpy().reshape(-1))

    if not losses:
        raise ValueError(f'{desc.strip()}: the data loader produced no batches')

    preds = np.concatenate(pred_chunks)
    targets = np.concatenate(target_chunks)
    return (
        np.mean(losses),
        accuracy_score(targets, preds),
        precision_score(targets, preds, zero_division=0),
        recall_score(targets, preds, zero_division=0),
        f1_score(targets, preds, zero_division=0),
    )


def train_model_memory_efficient(root_images, epochs=20, batch_size=16, accumulation_steps=2,
                                  lr=1e-3, payload_len=64, train_n=10000, val_n=2000,
                                  test_n=2000, early_stop_patience=5, use_amp=True,
                                  image_size=256,
                                  checkpoint_path='/content/best_model_checkpoint.pt'):
    """Train the watermark encoder/decoder with a small memory footprint.

    Memory is kept low through gradient accumulation, optional mixed
    precision, a lightweight gradient-based perceptual loss (no VGG) and
    explicit cache clearing between phases. After every epoch the model is
    validated; the best model by validation bit accuracy is checkpointed,
    and training stops early after `early_stop_patience` epochs without
    improvement. Finally the best checkpoint is reloaded and evaluated on the
    test split.

    Args:
        root_images: directory searched for training images.
        epochs: maximum number of epochs.
        batch_size: per-step (micro) batch size.
        accumulation_steps: micro-batches accumulated per optimizer step.
        lr: initial AdamW learning rate (cosine-annealed down to 1e-5).
        payload_len: number of watermark bits per image.
        train_n: requested number of training images.
        val_n: requested number of validation images.
        test_n: requested number of test images.
        early_stop_patience: epochs without validation improvement tolerated.
        use_amp: enable autocast and gradient scaling.
        image_size: side length images are resized to. Should be a multiple
            of 4 so the encoder's two pooling stages and skip connections
            line up.
        checkpoint_path: file the best model is written to and reloaded from.

    Returns:
        Tuple ``(encoder, decoder, history)`` where `history` maps metric
        names to per-epoch lists.

    Raises:
        ValueError: if the training split yields no full batch, or an
            evaluation loader is empty.
    """

    print('='*70)
    print('MEMORY-OPTIMIZED TRAINING')
    print('='*70)
    print(f'\nBatch size: {batch_size}')
    print(f'Accumulation steps: {accumulation_steps}')
    print(f'Effective batch size: {batch_size * accumulation_steps}')
    print(f'Mixed precision: {use_amp}\n')

    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()

    # Create datasets
    train_paths, val_paths, test_paths = create_datasets(
        root_images, train_n=train_n, val_n=val_n, test_n=test_n
    )

    train_ds = ImageDataset(train_paths, image_size=image_size)
    val_ds = ImageDataset(val_paths, image_size=image_size)
    test_ds = ImageDataset(test_paths, image_size=image_size)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                             num_workers=2, pin_memory=True)

    num_train_batches = len(train_loader)
    if num_train_batches == 0:
        # drop_last=True discards a partial batch, so fewer than batch_size
        # training images would otherwise "train" on nothing.
        raise ValueError(
            f'Training split has {len(train_ds)} images, fewer than one batch of {batch_size}'
        )

    # Initialize models
    print('\nüèóÔ∏è  Initializing models...')
    encoder = ImprovedEncoder(payload_len=payload_len).to(device)
    decoder = ImprovedDecoder(payload_len=payload_len).to(device)
    attack = ImprovedAttack(p_jpeg=0.7).to(device)

    # Optimizer
    params = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)

    # A disabled GradScaler passes losses through unchanged and its step()
    # simply calls optimizer.step(), so one code path serves both modes.
    scaler = GradScaler(enabled=use_amp)

    # Simplified perceptual loss (no VGG to save memory)
    def simple_perceptual_loss(x, y):
        """Lightweight perceptual loss using gradient magnitude"""
        def get_gradients(img):
            dx = img[:, :, :, 1:] - img[:, :, :, :-1]
            dy = img[:, :, 1:, :] - img[:, :, :-1, :]
            return dx, dy

        dx_x, dy_x = get_gradients(x)
        dx_y, dy_y = get_gradients(y)

        return F.mse_loss(dx_x, dx_y) + F.mse_loss(dy_x, dy_y)

    # Training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_precision': [], 'val_recall': [], 'val_f1': []
    }

    best_val_acc = 0.0
    no_improve = 0
    checkpoint_saved = False

    print(f'\nüéØ Starting training for {epochs} epochs...')
    print('='*70)

    for epoch in range(epochs):
        # ========== TRAINING ==========
        encoder.train()
        decoder.train()

        train_losses = []
        train_accs = []
        optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for batch_idx, imgs in enumerate(pbar):
            imgs = imgs.to(device)
            B = imgs.size(0)

            payload = torch.randint(0, 2, (B, payload_len)).float().to(device)

            # Mixed precision forward pass
            with autocast(enabled=use_amp):
                residual = encoder(imgs, payload)
                watermarked = torch.clamp(imgs + residual, 0.0, 1.0)
                attacked = attack(watermarked)
                logits = decoder(attacked)

                # Losses
                bce_loss = F.binary_cross_entropy_with_logits(logits, payload)
                mse_loss = F.mse_loss(watermarked, imgs)
                perc_loss = simple_perceptual_loss(watermarked, imgs)

                # Combined loss (normalized by accumulation steps)
                loss = (bce_loss + 0.1 * mse_loss + 0.1 * perc_loss) / accumulation_steps

            # Backward pass
            scaler.scale(loss).backward()

            # Update weights after a full accumulation group. Also step on
            # the final batch so gradients of a trailing partial group are
            # applied instead of being zeroed at the start of the next epoch.
            is_last_batch = (batch_idx + 1) == num_train_batches
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Compute accuracy
            with torch.no_grad():
                pred_bits = (torch.sigmoid(logits) > 0.5).float()
                acc = (pred_bits == payload).float().mean().item()

            # Undo the accumulation normalization for reporting.
            batch_loss = loss.item() * accumulation_steps
            train_losses.append(batch_loss)
            train_accs.append(acc)

            # Progress bar: always loss/accuracy, memory usage every 50 batches.
            postfix = {
                'loss': f'{batch_loss:.4f}',
                'acc': f'{acc*100:.1f}%'
            }
            if batch_idx % 50 == 0 and torch.cuda.is_available():
                mem_used = torch.cuda.memory_allocated() / 1e9
                postfix['mem'] = f'{mem_used:.1f}GB'
            pbar.set_postfix(postfix)

        avg_train_loss = np.mean(train_losses)
        avg_train_acc = np.mean(train_accs)
        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(avg_train_acc)

        # Clear cache after training
        torch.cuda.empty_cache()

        # ========== VALIDATION ==========
        avg_val_loss, val_acc, val_prec, val_rec, val_f1 = _evaluate_watermark(
            encoder, decoder, attack, val_loader, payload_len, use_amp,
            desc=f"Epoch {epoch+1}/{epochs} [Val]  ", leave=False
        )

        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_acc)
        history['val_precision'].append(val_prec)
        history['val_recall'].append(val_rec)
        history['val_f1'].append(val_f1)

        # Print epoch summary
        print(f'\nEpoch {epoch+1}/{epochs}:')
        print(f'  Train - Loss: {avg_train_loss:.4f}, Acc: {avg_train_acc*100:.2f}%')
        print(f'  Val   - Loss: {avg_val_loss:.4f}, Acc: {val_acc*100:.2f}%, F1: {val_f1:.3f}')

        # Memory stats
        if torch.cuda.is_available():
            mem_used = torch.cuda.memory_allocated() / 1e9
            mem_max = torch.cuda.max_memory_allocated() / 1e9
            print(f'  Memory - Used: {mem_used:.2f}GB, Peak: {mem_max:.2f}GB')
            torch.cuda.reset_peak_memory_stats()

        # Learning rate scheduling
        scheduler.step()

        # Early stopping and checkpoint
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve = 0
            torch.save({
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
                # Plain float: a NumPy scalar in the checkpoint is rejected
                # by torch.load when it runs with weights_only=True.
                'val_acc': float(val_acc),
            }, checkpoint_path)
            checkpoint_saved = True
            print(f'  ‚úÖ New best model saved! (Val Acc: {val_acc*100:.2f}%)')
        else:
            no_improve += 1
            if no_improve >= early_stop_patience:
                print(f'\nüõë Early stopping triggered')
                break

        print('-'*70)

        # Clear cache after epoch
        torch.cuda.empty_cache()
        gc.collect()

    # Load best model. Only reload a checkpoint written by this run, so a
    # stale file from an earlier run is never evaluated by mistake.
    if checkpoint_saved:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
    else:
        print('\nNo checkpoint was saved during this run; evaluating the current weights.')

    # Test evaluation
    print('\n' + '='*70)
    print('FINAL TEST EVALUATION')
    print('='*70)

    _, test_acc, test_prec, test_rec, test_f1 = _evaluate_watermark(
        encoder, decoder, attack, test_loader, payload_len, use_amp, desc="Testing"
    )

    print(f'\nüìä Test Results:')
    print(f'   Accuracy:  {test_acc*100:.2f}%')
    print(f'   Precision: {test_prec:.4f}')
    print(f'   Recall:    {test_rec:.4f}')
    print(f'   F1-Score:  {test_f1:.4f}')

    if test_acc >= 0.85:
        print(f'\nüéâ SUCCESS! Achieved target accuracy')
    elif test_acc >= 0.75:
        print(f'\n‚ö†Ô∏è  Close to target')
    else:
        print(f'\n‚ùå Below target')

    print('\n' + '='*70)

    return encoder, decoder, history

print('‚úÖ Memory-optimized training function defined')

# ▶️ RUN TRAINING

In [None]:
# Start training. Only the CONFIG entries the trainer is called with are
# forwarded; CONFIG also holds other keys (such as 'image_size').
train_keys = (
    'epochs', 'batch_size', 'accumulation_steps', 'lr', 'payload_len',
    'train_n', 'val_n', 'test_n', 'early_stop_patience', 'use_amp',
)
train_kwargs = {key: CONFIG[key] for key in train_keys}
encoder, decoder, history = train_model_memory_efficient(
    root_images=ROOT_IMAGES, **train_kwargs
)

In [None]:
# Plot the training history as three side-by-side panels:
# loss, bit accuracy (with the 85% target line) and validation P/R/F1.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (axis, y-label, title, [(history key, legend label, scale)], target line)
panels = [
    (axes[0], 'Loss', 'Training and Validation Loss',
     [('train_loss', 'Train Loss', None), ('val_loss', 'Val Loss', None)], None),
    (axes[1], 'Accuracy (%)', 'Training and Validation Accuracy',
     [('train_acc', 'Train Acc', 100), ('val_acc', 'Val Acc', 100)], 85),
    (axes[2], 'Score (%)', 'Validation Metrics',
     [('val_f1', 'F1', 100), ('val_precision', 'Precision', 100),
      ('val_recall', 'Recall', 100)], None),
]

for ax, y_label, title, curves, target in panels:
    for key, label, scale in curves:
        # Fractions are shown as percentages; losses are plotted as recorded.
        series = history[key] if scale is None else [value * scale for value in history[key]]
        ax.plot(series, label=label, linewidth=2)
    if target is not None:
        ax.axhline(y=target, color='g', linestyle='--', label='Target', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('/content/training_history.png', dpi=150, bbox_inches='tight')
plt.show()
print('‚úÖ Saved to /content/training_history.png')

In [None]:
# Export the results: copy them to Google Drive, then download them
# through the browser.
from google.colab import files
import shutil

# Local artifact -> destination on the mounted Drive.
exports = {
    '/content/best_model_checkpoint.pt': '/content/drive/MyDrive/watermark_model.pt',
    '/content/training_history.png': '/content/drive/MyDrive/training_history.png',
}

for local_path, drive_path in exports.items():
    shutil.copy(local_path, drive_path)
print('‚úÖ Saved to Google Drive')

for local_path in exports:
    files.download(local_path)
print('‚úÖ Files downloaded!')