# CAE Training on Google Colab
**Weld Defect Detection - Convolutional Autoencoder**

This notebook trains the CAE model using Colab's free GPU (10-20x faster than CPU).

## 1. Setup - Enable GPU
Go to **Runtime → Change runtime type → GPU (T4)**

In [None]:
# Report the installed PyTorch build and, when a CUDA device is visible,
# the name and total memory of device 0.
import torch

has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; show it in GB (1e9 bytes)
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 2. Mount Google Drive & Upload Dataset
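Upload your `cae/dataset` folder to Google Drive first. The default paths in the next cell expect it at `MyDrive/RIAWELC/cae/dataset`; adjust `DATASET_PATH` if you put it somewhere else.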

In [None]:
import os

from google.colab import drive

# Make Google Drive available under /content/drive.
drive.mount('/content/drive')

# Where the dataset is read from and where models are written.
# Adjust both if your Drive layout differs.
DATASET_PATH = '/content/drive/MyDrive/RIAWELC/cae/dataset'
OUTPUT_PATH = '/content/drive/MyDrive/RIAWELC/cae/models'

# Create the output folder up front (no error if it already exists).
os.makedirs(OUTPUT_PATH, exist_ok=True)

for label, location in (("Dataset", DATASET_PATH), ("Output", OUTPUT_PATH)):
    print(f"{label} path: {location}")

## 3. Define Model & Dataset Classes

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
from tqdm import tqdm

# Per-channel mean and standard deviation passed to transforms.Normalize
# (ImageNet normalization, same as the CNN).
IMAGENET_MEAN = [
    0.485,
    0.456,
    0.406,
]
IMAGENET_STD = [
    0.229,
    0.224,
    0.225,
]


class ConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting feature map."""
        return self.block(x)


class DeconvBlock(nn.Module):
    """Upsampling block: stride-2 transposed convolution, then a 3x3 convolution.

    The transposed convolution (kernel 2, stride 2) doubles height and width
    while mapping ``in_channels`` to ``out_channels``; the 3x3 convolution
    (``padding=1``) keeps that shape. Each is followed by batch norm and ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        upsample_norm = nn.BatchNorm2d(out_channels)
        refine = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        refine_norm = nn.BatchNorm2d(out_channels)
        self.block = nn.Sequential(
            upsample, upsample_norm, nn.ReLU(inplace=True),
            refine, refine_norm, nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the block to ``x`` and return the upsampled feature map."""
        return self.block(x)


class CAE(nn.Module):
    """Convolutional autoencoder used for anomaly detection.

    Encoder: three ConvBlock + 2x2 max-pool stages (32, 64, 128 channels),
    reducing height and width by a factor of 8 overall.
    Bottleneck: one ConvBlock producing ``latent_dim`` channels.
    Decoder: three DeconvBlock stages (128, 64, 32 channels), each doubling
    height and width, followed by a 3x3 convolution back to ``in_channels``.

    Height and width of the input should be divisible by 8; otherwise the
    pooling rounds down and the reconstruction is smaller than the input.
    """

    def __init__(self, in_channels=3, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        c1, c2, c3 = 32, 64, 128  # channel widths of the three stages

        # Encoder: convolution block followed by 2x2 max pooling, three times
        self.enc1 = ConvBlock(in_channels, c1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = ConvBlock(c1, c2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = ConvBlock(c2, c3)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Bottleneck at the lowest resolution
        self.bottleneck = ConvBlock(c3, latent_dim)

        # Decoder: mirror of the encoder
        self.dec3 = DeconvBlock(latent_dim, c3)
        self.dec2 = DeconvBlock(c3, c2)
        self.dec1 = DeconvBlock(c2, c1)

        # Output layer has no Sigmoid because targets are ImageNet-normalized
        # values rather than values in [0, 1].
        self.output = nn.Conv2d(c1, in_channels, kernel_size=3, padding=1)

    def encode(self, x):
        """Map an image batch to its latent feature map."""
        stages = (
            (self.enc1, self.pool1),
            (self.enc2, self.pool2),
            (self.enc3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(conv(x))
        return self.bottleneck(x)

    def decode(self, z):
        """Reconstruct an image batch from a latent feature map."""
        for upsample in (self.dec3, self.dec2, self.dec1):
            z = upsample(z)
        return self.output(z)

    def forward(self, x):
        """Return the reconstruction of ``x`` (encode, then decode)."""
        return self.decode(self.encode(x))


class NormalDataset(Dataset):
    """Training dataset that serves the images found directly in one folder.

    Only files matching the known image extensions are picked up (the folder
    is not searched recursively) and they are kept in sorted order.
    Each item is ``(image, path_string)``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        patterns = ('*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tif', '*.tiff')
        self.image_paths = sorted(
            match for pattern in patterns for match in self.root_dir.glob(pattern)
        )
        print(f"NormalDataset: Found {len(self.image_paths)} images")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Always hand out 3-channel images, whatever the file's own mode is.
        image = Image.open(path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, str(path)


class AnomalyDataset(Dataset):
    """Labelled dataset for validation/testing: normal images plus defects.

    Folder layout read from ``root_dir``::

        normal/      -> label 0, defect type 'ND'
        defect/CR/   -> label 1, defect type 'CR'
        defect/LP/   -> label 1, defect type 'LP'
        defect/PO/   -> label 1, defect type 'PO'

    Missing folders are skipped. Within each folder the files are listed in
    sorted order so that sample indices are reproducible; the raw order of
    ``Path.glob`` depends on the filesystem.

    Each item is ``(image, label, defect_type, path_string)``.
    """

    # Glob patterns of the image files that are picked up.
    EXTENSIONS = ('*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tif', '*.tiff')
    # Sub-folders of ``defect/`` that are scanned, in this order.
    DEFECT_TYPES = ('CR', 'LP', 'PO')

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.defect_types = []

        # Normal images (label=0)
        self._add_folder(self.root_dir / 'normal', label=0, defect_type='ND')

        # Defect images (label=1), one sub-folder per defect type
        for defect_type in self.DEFECT_TYPES:
            self._add_folder(
                self.root_dir / 'defect' / defect_type,
                label=1,
                defect_type=defect_type,
            )

        n_defect = sum(self.labels)
        n_normal = len(self.labels) - n_defect
        print(f"AnomalyDataset: {len(self.image_paths)} images ({n_normal} normal, {n_defect} defects)")

    def _add_folder(self, folder, label, defect_type):
        """Append every image directly inside ``folder``, sorted, with the given label and type."""
        if not folder.is_dir():
            return
        found = sorted(
            match for pattern in self.EXTENSIONS for match in folder.glob(pattern)
        )
        self.image_paths.extend(found)
        self.labels.extend([label] * len(found))
        self.defect_types.extend([defect_type] * len(found))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Always hand out 3-channel images, whatever the file's own mode is.
        image = Image.open(path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx], self.defect_types[idx], str(path)


def get_transforms(image_size=224, augment=False):
    """Build the image preprocessing pipeline.

    Every pipeline resizes to ``image_size`` x ``image_size``, converts to a
    tensor and normalizes with IMAGENET_MEAN / IMAGENET_STD. With
    ``augment=True`` random horizontal/vertical flips, a rotation of up to
    10 degrees and a small brightness/contrast jitter are applied between
    the resize and the tensor conversion.

    Returns:
        A ``transforms.Compose`` instance.
    """
    steps = [transforms.Resize((image_size, image_size))]
    if augment:
        steps += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)


print("Model and Dataset classes defined!")

## 4. Training Configuration

In [None]:
# Hyperparameters for the training run (chosen for a Colab GPU).
CONFIG = dict(
    image_size=224,      # full resolution: images are resized to 224 x 224
    batch_size=64,       # larger batch, since a GPU is used
    epochs=50,
    learning_rate=1e-3,
    weight_decay=1e-5,
    latent_dim=128,
    num_workers=2,       # Colab works well with 2 loader workers
)

# Echo the settings so they are recorded in the notebook output.
print("Training Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## 5. Create DataLoaders

In [None]:
# Preprocessing: augmentation for training only; validation and testing share
# the plain resize + normalize pipeline.
train_transform = get_transforms(CONFIG['image_size'], augment=True)
val_transform = get_transforms(CONFIG['image_size'], augment=False)

# Training uses normal images only; validation and testing mix normal and defect images.
train_dataset = NormalDataset(f"{DATASET_PATH}/training/normal", train_transform)
val_dataset = AnomalyDataset(f"{DATASET_PATH}/validation", val_transform)
test_dataset = AnomalyDataset(f"{DATASET_PATH}/testing", val_transform)

# Settings shared by all three loaders.
loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)

# Only the training data is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"\nTrain: {len(train_loader)} batches")
print(f"Val: {len(val_loader)} batches")
print(f"Test: {len(test_loader)} batches")

## 6. Initialize Model & Training

In [None]:
from sklearn.metrics import roc_auc_score, f1_score
import torch.optim as optim

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Using device: {device}")

# Build the autoencoder on the chosen device and report its size.
model = CAE(latent_dim=CONFIG['latent_dim']).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

# Reconstruction loss and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)
# mode='max': the monitored metric is expected to increase; the learning rate
# is halved after 5 epochs without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5
)

# Per-epoch records plus the best validation result seen so far.
history = {key: [] for key in ('train_loss', 'val_auc', 'threshold')}
best_auc = 0
best_threshold = 0

## 7. Training Loop

In [None]:
def validate(model, val_loader, device):
    """Score the model on a labelled loader and pick a decision threshold.

    The anomaly score of an image is its mean squared reconstruction error
    over all channels and pixels.

    Args:
        model: the autoencoder; it is switched to eval mode and left there.
        val_loader: yields ``(images, labels, defect_types, paths)`` batches,
            with label 0 for normal and 1 for defect images.
        device: device the images are moved to before the forward pass.

    Returns:
        ``(auc, best_thresh, best_f1)`` as plain Python floats: the ROC AUC of
        the scores, the threshold with the highest F1 among 100 evenly spaced
        candidates between the smallest and largest score, and that F1.
        If no candidate reaches F1 > 0, the 95th percentile of the normal
        images' scores is returned as the threshold.

    Raises:
        ValueError: raised by ``roc_auc_score`` when the labels contain only
            one class.
    """
    model.eval()
    all_errors, all_labels = [], []

    with torch.no_grad():
        for images, labels, _, _ in val_loader:
            images = images.to(device)
            recon = model(images)
            # One score per image: average over channel, height and width
            errors = torch.mean((images - recon) ** 2, dim=[1, 2, 3])
            all_errors.extend(errors.cpu().numpy())
            all_labels.extend(labels.numpy())

    all_errors = np.array(all_errors)
    all_labels = np.array(all_labels)

    # Threshold-free ranking quality of the scores
    auc = roc_auc_score(all_labels, all_errors)

    # Fallback threshold, used only when the sweep below never gets F1 > 0
    best_thresh = np.percentile(all_errors[all_labels == 0], 95)
    best_f1 = 0.0

    # Sweep candidate thresholds and keep the one with the best F1
    for t in np.linspace(all_errors.min(), all_errors.max(), 100):
        preds = (all_errors > t).astype(int)
        # zero_division=0 keeps the score at 0 for degenerate candidates
        # (e.g. nothing predicted as defect) without emitting a warning
        f1 = f1_score(all_labels, preds, zero_division=0)
        if f1 > best_f1:
            best_f1, best_thresh = f1, t

    # Plain floats (not NumPy scalars) so callers can store them in checkpoints
    # that stay loadable with torch.load's weights-only mode.
    return float(auc), float(best_thresh), float(best_f1)


# Training loop: each epoch makes one pass over the training loader, then runs
# validation; the checkpoint with the best validation AUC is kept on disk.
print("Starting training...")
print("=" * 60)

best_model_path = f"{OUTPUT_PATH}/best_cae_model.pth"

for epoch in range(CONFIG['epochs']):
    model.train()  # validate() leaves the model in eval mode
    train_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
    for images, _ in pbar:
        images = images.to(device)

        # Autoencoder objective: reconstruct the input itself
        recon = model(images)
        loss = criterion(recon, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.6f}'})

    # Mean of the per-batch losses
    train_loss /= len(train_loader)

    # Validate
    val_auc, threshold, val_f1 = validate(model, val_loader, device)
    # Convert to plain Python floats: NumPy scalars stored in a checkpoint are
    # rejected by torch.load's weights-only mode (the default since PyTorch 2.6).
    val_auc, threshold, val_f1 = float(val_auc), float(threshold), float(val_f1)
    scheduler.step(val_auc)

    # Save history
    history['train_loss'].append(train_loss)
    history['val_auc'].append(val_auc)
    history['threshold'].append(threshold)

    print(f"Epoch {epoch+1}: Loss={train_loss:.6f}, AUC={val_auc:.4f}, F1={val_f1:.4f}, Thresh={threshold:.6f}")

    # Save best model
    if val_auc > best_auc:
        best_auc = val_auc
        best_threshold = threshold
        torch.save({
            'model_state_dict': model.state_dict(),
            'threshold': threshold,
            'auc': val_auc,
            'config': CONFIG,
            'epoch': epoch  # zero-based epoch index
        }, best_model_path)
        print(f"  ★ New best model saved! AUC: {val_auc:.4f}")

print("\n" + "=" * 60)
print(f"Training complete! Best AUC: {best_auc:.4f}")

## 8. Test Final Model

In [None]:
# Load the best checkpoint written by the training loop.
# - map_location: lets a checkpoint saved on GPU load on whichever device is in use.
# - weights_only=False: the file holds more than tensors (threshold, config) and
#   may contain NumPy scalars, which the weights-only loader (the default since
#   PyTorch 2.6) rejects. This unpickles arbitrary objects, so only use it on
#   checkpoints you created yourself.
checkpoint = torch.load(
    f"{OUTPUT_PATH}/best_cae_model.pth",
    map_location=device,
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])
threshold = float(checkpoint['threshold'])

# Score every test image with its mean squared reconstruction error.
model.eval()
all_errors, all_labels, all_types = [], [], []

with torch.no_grad():
    for images, labels, defect_types, _ in tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        recon = model(images)
        errors = torch.mean((images - recon) ** 2, dim=[1, 2, 3])
        all_errors.extend(errors.cpu().numpy())
        all_labels.extend(labels.numpy())
        all_types.extend(defect_types)

all_errors = np.array(all_errors)
all_labels = np.array(all_labels)

# Compute metrics; an image is flagged as defect when its error exceeds the
# threshold stored in the checkpoint.
test_auc = float(roc_auc_score(all_labels, all_errors))
predictions = (all_errors > threshold).astype(int)
test_f1 = float(f1_score(all_labels, predictions))
accuracy = float(np.mean(predictions == all_labels))

print(f"\n{'='*40}")
print(f"TEST RESULTS")
print(f"{'='*40}")
print(f"AUC Score: {test_auc:.4f}")
print(f"F1 Score: {test_f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Threshold: {threshold:.6f}")

# Per-class analysis
print(f"\nPer-class Detection Rate:")
type_array = np.array(all_types)  # built once instead of once per class
for dtype in ['ND', 'CR', 'LP', 'PO']:
    mask = type_array == dtype
    if mask.sum() > 0:
        if dtype == 'ND':
            rate = np.mean(predictions[mask] == 0)  # Normal should be predicted as 0
            print(f"  {dtype} (Normal): {rate:.2%} correct")
        else:
            rate = np.mean(predictions[mask] == 1)  # Defects should be predicted as 1
            print(f"  {dtype} (Defect): {rate:.2%} detected")

## 9. Download Model

In [None]:
# Save the final model together with its threshold, test metrics and config.
final_model_path = f"{OUTPUT_PATH}/cae_final.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    # Stored as plain Python floats: NumPy scalars in a checkpoint are rejected
    # by torch.load's weights-only mode (the default since PyTorch 2.6).
    'threshold': float(threshold),
    'test_auc': float(test_auc),
    'test_f1': float(test_f1),
    'config': CONFIG,
}, final_model_path)

print(f"Model saved to: {final_model_path}")
print("\nDownload from Google Drive or use:")

# Download directly through the browser
from google.colab import files
files.download(final_model_path)

## 10. Plot Training History

In [None]:
import matplotlib.pyplot as plt

# Two panels side by side: training loss (left) and validation AUC (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, auc_ax = axes

loss_ax.plot(history['train_loss'])
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('MSE Loss')

auc_ax.plot(history['val_auc'])
auc_ax.set_title('Validation AUC')
auc_ax.set_xlabel('Epoch')
auc_ax.set_ylabel('AUC Score')
# Dashed horizontal line marks the best validation AUC
best_label = f'Best: {best_auc:.4f}'
auc_ax.axhline(y=best_auc, color='r', linestyle='--', label=best_label)
auc_ax.legend()

# Save the figure to the output folder, then display it.
plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/training_history.png", dpi=150)
plt.show()