In [1]:
import torch
import torch.nn as nn
import torchvision.models as models


In [None]:
import numpy as np
import random

def set_seed(seed=42):
    """
    Set the seed for reproducibility in PyTorch, NumPy, and Python's random module on MPS.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # For MPS (Apple Silicon)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
        print("Seed set for MPS.")
    
    torch.use_deterministic_algorithms(True, warn_only=True)
        
    print(f"Seed set to: {seed}")

# Example Usage
set_seed(42)

Seed set for MPS.
Seed set to: 42


In [3]:

class VGG_trans_conv(nn.Module):
    def __init__(self):
        super(VGG_trans_conv, self).__init__()
        
        # Encoder: VGG16 without last maxpool and FC layers
        vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        self.encoder = nn.Sequential(*list(vgg16.features.children())[:-1])
        
        # Decoder: Transposed Convolution layers for upsampling
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),  # 1 channel for edge detection
            nn.Sigmoid()  # Output in range [0, 1]
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [4]:
import torch.nn.functional as F

class BalancedBCELoss(nn.Module):
    def __init__(self):
        super(BalancedBCELoss, self).__init__()

    def forward(self, pred, target):
        """
        Compute class-balanced binary cross-entropy loss using HED paper's formula.

        Args:
            pred (torch.Tensor): Predicted probabilities (values between 0 and 1).
            target (torch.Tensor): Ground truth edge maps (binary: 0 or 1).

        Returns:
            torch.Tensor: Class-balanced BCE loss.
        """
        # Ensure predictions are in range for numerical stability
        pred = torch.clamp(pred, 1e-6, 1 - 1e-6)
        
        # Calculate class counts
        pos_count = torch.sum(target)
        neg_count = target.numel() - pos_count
        
        # Compute β using the HED paper formula
        beta = neg_count / (pos_count + neg_count + 1e-6)

        # Compute class-balanced BCE loss using PyTorch's BCE function
        weights = beta * target + (1 - beta) * (1 - target)
        loss = F.binary_cross_entropy(pred, target, weight=weights)

        return loss

In [5]:
import logging
import os
import torch
import csv

# Configure logger
logging.basicConfig(filename='training.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

def train_and_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs=100,
                        save_path='checkpoints', model_filename='model.pth', csv_filename='losses.csv'):
    
    # Check and set device
    if torch.backends.mps.is_available():
        device = torch.device('mps')
        logging.info("Using MPS (Apple Metal) for acceleration.")
    elif torch.cuda.is_available():
        device = torch.device('cuda')
        logging.info("Using CUDA for acceleration.")
    else:
        device = torch.device('cpu')
        logging.info("Using CPU for training.")

        
    model.to(device)

    logging.info("Starting training...")
    logging.info(f"Using device: {device}")
    logging.info(f"Number of epochs: {num_epochs}")
    logging.info(f"Batch size: {train_loader.batch_size}")
    logging.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")
    logging.info(f"Model architecture: {model}")
    logging.info(f"Criterion: {criterion}")
    logging.info(f"Optimizer: {optimizer}")
    logging.info(f"Training data size: {len(train_loader.dataset)}")
    logging.info(f"Validation data size: {len(val_loader.dataset)}")
    logging.info(f"Save path: {save_path}")
    logging.info("Training started...")
    logging.info("Creating save directory if it doesn't exist...")
    
    os.makedirs(save_path, exist_ok=True)
    model_path = os.path.join(save_path, model_filename)
    csv_path = os.path.join(save_path, csv_filename)

    # Create CSV and write headers
    with open(csv_path, mode='w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["Epoch", "Train Loss", "Validation Loss"])

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        epoch_loss = 0
        for i, (images, edges) in enumerate(train_loader):
            images, edges = images.to(device), edges.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, edges)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            
            logging.info(f'Iteration {i+1}/{len(train_loader)} of Epoch {epoch+1}/{num_epochs}')

        # Save model after every epoch (saving the model with architecture)
        torch.save(model, model_path)
        logging.info(f'Epoch {epoch+1}/{num_epochs}: Model saved to {model_path}')

        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Validation Phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for images, edges in val_loader:
                images, edges = images.to(device), edges.to(device)

                outputs = model(images)
                loss = criterion(outputs, edges)
                val_loss += loss.item()

        val_loss /= len(val_loader)
        val_losses.append(val_loss)

        # Save to CSV
        with open(csv_path, mode='a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch+1, train_loss, val_loss])

        # Log Progress
        log_msg = f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss}, Validation Loss: {val_loss}'
        logging.info(log_msg)
        print(log_msg)
    
    logging.info("Training completed.")
    
    return train_losses, val_losses

In [None]:
# Dataloader for BSDS500 dataset
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class BSDS500(Dataset):
    def __init__(self, image_dir, edge_dir, transform=None, edge_transform=None):
        """
        Custom dataloader for BSDS500 edge detection dataset using JPG ground truth.

        Args:
            image_dir (str): Path to image directory (train, val, test).
            edge_dir (str): Path to corresponding edge ground truth directory.
            transform (callable, optional): Transformations for images.
            edge_transform (callable, optional): Transformations for edge maps.
        """
        self.image_dir = image_dir
        self.edge_dir = edge_dir
        self.transform = transform
        self.edge_transform = edge_transform
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load Image
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        # Load Ground Truth Edge Image
        edge_path = os.path.join(self.edge_dir, img_name)
        edge_image = Image.open(edge_path).convert('L')

        # Apply transformations
        if self.transform:
            image = self.transform(image)
        if self.edge_transform:
            edge_image = self.edge_transform(edge_image)

        return image, edge_image

# Separate transforms
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

edge_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

g = torch.Generator()
g.manual_seed(42)

# Create Dataloaders
train_dataset = BSDS500(image_dir='archive/images/train', edge_dir='archive/ground_truth_boundaries/train',
                         transform=image_transform, edge_transform=edge_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0, generator=g)

val_dataset = BSDS500(image_dir='archive/images/val', edge_dir='archive/ground_truth_boundaries/val',
                       transform=image_transform, edge_transform=edge_transform)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=0, generator=g)


In [7]:
import torch.optim as optim
# Initialize model, criterion, and optimizer
model = VGG_trans_conv()
criterion = BalancedBCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train and Validate
train_losses, val_losses = train_and_validate(model, train_loader, val_loader, criterion, optimizer, num_epochs=100)


Epoch [1/100], Train Loss: 0.0173794745395963, Validation Loss: 0.029778603427112103


KeyboardInterrupt: 

In [None]:
test_dataset = BSDS500(image_dir='archive/images/test', edge_dir='archive/ground_truth_boundaries/test',
                        transform=image_transform, edge_transform=edge_transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True) 

In [None]:
import matplotlib.pyplot as plt

def plot_results(model, dataloader, threshold=0.25, device='cpu', num_batches=2):
    model.eval()
    batch_count = 0
    
    with torch.no_grad():
        for images, edges in dataloader:
            images = images.to(device)
            edges = edges.unsqueeze(1).to(device)

            outputs = model(images)
            predictions = (outputs > threshold).float()

            for i in range(len(images)):
                plt.figure(figsize=(12, 4))
                
                plt.subplot(1, 3, 1)
                plt.imshow(images[i].cpu().permute(1, 2, 0))
                plt.title("Input Image")

                plt.subplot(1, 3, 2)
                plt.imshow(edges[i].cpu().squeeze(), cmap='gray')
                plt.title("Ground Truth")

                plt.subplot(1, 3, 3)
                plt.imshow(predictions[i].cpu().squeeze(), cmap='gray')
                plt.title("Predicted Edges")

                plt.show()
                
            batch_count += 1
            if batch_count >= num_batches:
                break


In [None]:
plot_results(model, test_loader, threshold=0.28, device='cpu', num_batches=2)