## Working GAN 26X26

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pickle
import matplotlib.pyplot as plt
import os
import time
from IPython import display

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
def spectral_norm_conv(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
    """Build a 2-D convolution whose weight is spectrally normalized.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.
        bias: Whether the convolution has a bias term.

    Returns:
        The ``nn.Conv2d`` module after ``nn.utils.spectral_norm`` was applied.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    return nn.utils.spectral_norm(conv)

def spectral_norm_linear(in_features, out_features, bias=True):
    """Build a fully connected layer whose weight is spectrally normalized.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the layer has a bias term.

    Returns:
        The ``nn.Linear`` module after ``nn.utils.spectral_norm`` was applied.
    """
    dense = nn.Linear(in_features, out_features, bias=bias)
    return nn.utils.spectral_norm(dense)

In [None]:
# ------------------------------
# Circle Layer
# ------------------------------
class CircleConstraintLayer(nn.Module):
    """Replace negative values inside the image's inscribed circle with 0.

    Pixels outside the circle, and non-negative pixels anywhere, pass through
    unchanged.

    The inside/outside masks are built once in ``__init__`` and registered as
    non-persistent buffers: they follow the module across ``.to(...)`` calls
    but never appear in ``state_dict()``, so the set of checkpoint keys does
    not depend on whether ``forward`` has already been called.

    Args:
        image_size: Side length of the (square) images this layer is applied to.
    """

    def __init__(self, image_size=26):
        super().__init__()
        self.image_size = image_size
        inside_mask, outside_mask = self.create_masks(torch.device("cpu"))
        self.register_buffer('circle_mask_inside', inside_mask, persistent=False)
        self.register_buffer('circle_mask_outside', outside_mask, persistent=False)

    def create_masks(self, device):
        """Return ``(inside_mask, outside_mask)`` float tensors of shape [1, 1, S, S].

        ``inside_mask`` is 1.0 where the pixel's distance to the image center
        is at most ``image_size / 2`` and 0.0 elsewhere; ``outside_mask`` is
        its complement.
        """
        center = self.image_size / 2 - 0.5
        radius = self.image_size / 2

        coords = torch.arange(self.image_size, device=device)
        # Explicit 'ij' indexing keeps the row/column layout of the default
        # behaviour while avoiding the deprecation warning.
        y, x_coords = torch.meshgrid(coords, coords, indexing='ij')

        distance = torch.sqrt((x_coords - center) ** 2 + (y - center) ** 2)
        inside_mask = (distance <= radius).float()
        outside_mask = 1.0 - inside_mask

        return inside_mask.unsqueeze(0).unsqueeze(0), outside_mask.unsqueeze(0).unsqueeze(0)

    def forward(self, x):
        # Safety net: rebuild the masks if the input lives on a device the
        # module was never moved to.
        if self.circle_mask_inside is None or self.circle_mask_inside.device != x.device:
            inside_mask, outside_mask = self.create_masks(x.device)
            self.circle_mask_inside = inside_mask
            self.circle_mask_outside = outside_mask

        # The [1, 1, S, S] mask broadcasts over the batch and channel axes.
        inside = self.circle_mask_inside > 0.5

        # Push values to be >= 0 inside the circle.
        # Values outside the circle are deliberately NOT forced to -1: doing
        # so makes the generated outline too perfect a circle.
        return torch.where(inside & (x < 0), torch.zeros_like(x), x)

In [None]:
# ------------------------------
# Quantize to Ternary
# ------------------------------
class QuantizeToTernary(nn.Module):
    """Snap activations to {-1, 0, +1} while passing gradients straight through.

    Values whose magnitude exceeds ``threshold`` become their sign; all other
    values become 0. The quantization offset is detached, so the backward
    pass sees the layer as the identity (straight-through estimator).
    """

    def __init__(self, threshold=0.33):
        super().__init__()
        self.threshold = threshold

    def forward(self, inputs):
        beyond_threshold = inputs.abs() > self.threshold
        ternary = torch.where(beyond_threshold, inputs.sign(), torch.zeros_like(inputs))
        # Forward value is `ternary`; the detached offset carries no gradient.
        ste_offset = (ternary - inputs).detach()
        return inputs + ste_offset

In [None]:
# ------------------------------
# Feature Extractor
# ------------------------------
class FeatureExtractor(nn.Module):
    """Small convolutional encoder that maps a 1-channel image to a 128-d vector.

    Pipeline: conv(1->32) + LeakyReLU + 2x2 max-pool, conv(32->64) +
    instance norm + LeakyReLU, conv(64->128), global average pooling.
    """

    def __init__(self):
        super().__init__()
        # Stage 1
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        self.max_pool = nn.MaxPool2d(2)

        # Stage 2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.instance_norm = nn.InstanceNorm2d(64)
        self.leaky_relu2 = nn.LeakyReLU(0.2)

        # Stage 3 + pooling head
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # Spatial size is halved here (26x26 -> 13x13 for a 26x26 input).
        stem = self.max_pool(self.leaky_relu1(self.conv1(x)))
        mid = self.leaky_relu2(self.instance_norm(self.conv2(stem)))
        pooled = self.global_avg_pool(self.conv3(mid))
        # Collapse [B, 128, 1, 1] to [B, 128].
        return pooled.view(pooled.size(0), -1)

In [None]:
# ------------------------------
# Generator
# ------------------------------
class ResBlock(nn.Module):
    """Residual block: two shape-preserving transposed convolutions plus a skip.

    Each 3x3 ``ConvTranspose2d`` (padding 1) is followed by instance
    normalization; a ReLU sits between the two. The block input is added to
    the result, so the output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        same_shape = dict(kernel_size=3, padding=1)
        self.conv1 = nn.ConvTranspose2d(channels, channels, **same_shape)
        self.instance_norm1 = nn.InstanceNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.ConvTranspose2d(channels, channels, **same_shape)
        self.instance_norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        branch = self.instance_norm1(self.conv1(x))
        branch = self.instance_norm2(self.conv2(self.relu(branch)))
        # Skip connection around both convolutions.
        return branch + x
class Generator(nn.Module):
    """Map a latent vector to a 1x26x26 ternary image.

    Pipeline: linear projection to a 256x13x13 feature map, one residual
    block, a stride-2 transposed convolution up to 26x26, a final transposed
    convolution down to one channel, ``tanh`` and ternary quantization.

    Args:
        latent_dim: Length of the latent input vector.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        self.latent_dim = latent_dim

        # Latent vector -> flattened 256 x 13 x 13 feature map.
        self.fc = nn.Linear(latent_dim, 13*13*256)

        self.res_block = ResBlock(256)

        # Upsampling stage: 13x13 -> 26x26.
        self.conv_transpose1 = nn.ConvTranspose2d(
            256, 128, kernel_size=5, stride=2, padding=2, output_padding=1
        )
        self.instance_norm = nn.InstanceNorm2d(128)
        self.relu = nn.ReLU()

        # Output head: single channel, tanh, ternary quantization.
        self.conv_transpose2 = nn.ConvTranspose2d(128, 1, kernel_size=5, padding=2)
        self.tanh = nn.Tanh()
        self.quantizer = QuantizeToTernary(threshold=0.33)
        # Constructed but not used in forward (see note there).
        self.circle_constraint = CircleConstraintLayer()

    def forward(self, x):
        feature_map = self.fc(x).view(-1, 256, 13, 13)
        feature_map = self.res_block(feature_map)

        upsampled = self.relu(self.instance_norm(self.conv_transpose1(feature_map)))

        image = self.tanh(self.conv_transpose2(upsampled))
        image = self.quantizer(image)

        # The circle constraint is implemented but intentionally not applied;
        # it can be enabled later depending on the results:
        # image = self.circle_constraint(image)

        return image

In [None]:
# ------------------------------
# Discriminator/Critic
# ------------------------------
class Critic(nn.Module):
    """WGAN critic: two spectrally normalized stride-2 convolutions and a linear head.

    Returns one unbounded score per input image. The linear layer expects a
    128 x 7 x 7 feature map, which is what the two stride-2 convolutions
    produce for a 26x26 input (26 -> 13 -> 7).
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 64 channels, stride 2.
        self.conv1 = spectral_norm_conv(1, 64, kernel_size=5, stride=2, padding=2)
        self.leaky_relu1 = nn.LeakyReLU(0.2)

        # Stage 2: 64 -> 128 channels, stride 2.
        self.conv2 = spectral_norm_conv(64, 128, kernel_size=5, stride=2, padding=2)
        self.leaky_relu2 = nn.LeakyReLU(0.2)

        # Scoring head on the flattened feature map (flattening is done in forward()).
        self.linear = spectral_norm_linear(128 * 7 * 7, 1)

    def forward(self, x):
        features = self.leaky_relu2(self.conv2(self.leaky_relu1(self.conv1(x))))
        flat = features.view(features.size(0), -1)
        return self.linear(flat)


In [None]:
# ------------------------------
# Data Loading
# ------------------------------
class WaferDataset(Dataset):
    """Dataset wrapper that serves items of an indexable container as float tensors."""

    def __init__(self, data):
        # Kept by reference; no copy is made.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return torch.FloatTensor(sample)

def load_dataset(file_path):
    """Load a pickled set of 26x26 maps as a shuffled NCHW array in [-1, 1].

    Accepted layouts of the pickled object (after ``np.array`` conversion):
    ``[N, 26, 26]``, ``[N, 676]`` (flattened), ``[N, 26, 26, 1]`` (NHWC) or
    ``[N, 1, 26, 26]``. If the value range is not already [-1, 1] the data is
    min-max rescaled to that range and a warning is printed.

    The samples are shuffled in place along the first axis after seeding
    NumPy's global random generator with 42, so the order is reproducible.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only call this on
    files from a trusted source.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        ``np.ndarray`` of shape ``[N, 1, 26, 26]``.

    Raises:
        ValueError: If the data is empty, cannot be brought into shape
            ``[N, 1, 26, 26]``, or is constant (so it cannot be rescaled).
    """
    with open(file_path, "rb") as f:
        data = np.array(pickle.load(f))

    # Bring the data into PyTorch's NCHW layout.
    if data.ndim == 3:  # [num_samples, height, width]
        data = np.expand_dims(data, axis=1)
    elif data.ndim == 2:  # [num_samples, flattened_dim]
        data = data.reshape(-1, 1, 26, 26)
    elif data.ndim == 4 and data.shape[3] == 1:  # NHWC
        data = np.transpose(data, (0, 3, 1, 2))

    # Explicit raise instead of `assert`, which disappears under `python -O`.
    if data.shape[1:] != (1, 26, 26):
        raise ValueError(f"Expected shape (*, 1, 26, 26), got {data.shape}")
    if data.size == 0:
        raise ValueError("Dataset is empty")

    # Rescale to [-1, 1] unless the data already spans that range.
    data_min, data_max = np.min(data), np.max(data)
    if not (np.isclose(data_min, -1) and np.isclose(data_max, 1)):
        if data_max == data_min:
            # Min-max scaling would divide by zero and fill the array with NaN.
            raise ValueError(
                f"Cannot normalize constant data (all values equal {data_min})"
            )
        print(f"Warning: Data range is [{data_min}, {data_max}], normalizing to [-1, 1]")
        data = 2 * (data - data_min) / (data_max - data_min) - 1

    np.random.seed(42)  # For reproducibility
    np.random.shuffle(data)

    return data

In [None]:
# ------------------------------
# Gradient Penalty
# ------------------------------
def gradient_penalty(critic, real_images, fake_images):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolates.

    For every sample a random point on the line between the real and the fake
    image is drawn, the critic is evaluated there, and the squared deviation
    of the input-gradient's L2 norm from 1 is averaged over the batch.

    Args:
        critic: Callable mapping an image batch to critic scores.
        real_images: Batch of real images; its first axis is the batch size
            and its device is used for the interpolation factors.
        fake_images: Batch of generated images, broadcast-compatible with
            ``real_images``.

    Returns:
        Scalar tensor ``mean((||grad|| - 1) ** 2)``. The graph is kept
        (``create_graph=True``) so the penalty can be back-propagated into
        the critic's parameters.
    """
    batch_size = real_images.size(0)

    # One interpolation factor per sample, created on the inputs' device
    # rather than on a module-level global.
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_images.device)

    interpolated = alpha * real_images + (1 - alpha) * fake_images
    interpolated.requires_grad_(True)

    pred = critic(interpolated)

    # Gradient of the critic scores w.r.t. the interpolated images.
    gradients = torch.autograd.grad(
        outputs=pred,
        inputs=interpolated,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view) so non-contiguous gradients are handled as well.
    gradients = gradients.reshape(batch_size, -1)
    # The epsilon keeps the sqrt differentiable when a gradient is exactly zero.
    grad_norms = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    return torch.mean((grad_norms - 1.0) ** 2)

# ------------------------------
# Perceptual Loss
# ------------------------------
def perceptual_loss(feature_extractor, real, fake):
    """Return the mean squared difference between the features of two batches.

    Args:
        feature_extractor: Callable mapping a batch to a feature tensor.
        real: Batch of real images (passed to the extractor first).
        fake: Batch of generated images.

    Returns:
        Scalar tensor: mean over all elements of the squared feature difference.
    """
    feature_gap = feature_extractor(real) - feature_extractor(fake)
    return feature_gap.pow(2).mean()


In [None]:
# ------------------------------
# Circle Penalty
# ------------------------------
def circle_constraint_loss(generated_images, image_size=26, weight=0.5):
    """
    Count pixels that violate the "wafer is a disc" layout and return a weighted penalty.

    Penalized pixels:
    - Values of -1 inside the inscribed circle (counted with a factor of 2)
    - Values other than -1 outside the inscribed circle (counted once)

    A value counts as -1 when it lies within 1e-5 of -1, so tiny floating
    point rounding in the quantized generator output does not flip the count.

    NOTE: the penalty is built from comparisons, so it is piecewise constant
    and contributes no gradient; it acts as a monitoring value rather than a
    trainable loss term.

    Args:
        generated_images: Tensor of shape [batch_size, 1, height, width]
        image_size: Size of the square image (assuming height=width)
        weight: Weight factor for this loss component

    Returns:
        Scalar tensor: weight * (2 * inside_count + outside_count) / batch_size
    """
    batch_size = generated_images.size(0)

    # Inscribed circle around the (0-indexed) image center.
    center = image_size / 2 - 0.5
    radius = image_size / 2

    coords = torch.arange(image_size, device=generated_images.device)
    # Explicit 'ij' indexing: same layout as the legacy default, no warning.
    y, x = torch.meshgrid(coords, coords, indexing='ij')

    distance = torch.sqrt((x - center) ** 2 + (y - center) ** 2)

    # Masks of shape [1, 1, H, W]; broadcasting covers the batch axis, so no
    # per-sample copies are needed.
    inside_mask = (distance <= radius).float().unsqueeze(0).unsqueeze(0)
    outside_mask = 1 - inside_mask

    # 1.0 where the pixel is (numerically) -1, else 0.0.
    is_background = ((generated_images + 1).abs() <= 1e-5).float()

    # 1. -1 values inside the circle (should be 0 or 1)
    inside_penalty = inside_mask * is_background

    # 2. Non -1 values outside the circle (should be -1)
    outside_penalty = outside_mask * (1 - is_background)

    total_penalty = inside_penalty.sum() * 2 + outside_penalty.sum()

    # Average per image in the batch, then apply the weight.
    avg_penalty = total_penalty / batch_size
    return weight * avg_penalty

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_training_losses(gen_losses, critic_losses_real, critic_losses_fake, smoothing_factor=128):
    """
    Plot the training losses over epochs, raw and moving-average smoothed.

    Smoothing is applied only when ``smoothing_factor`` is positive and at
    least ``smoothing_factor`` epochs are available; otherwise the "smoothed"
    curves are the raw values. The figure is shown and then closed so that
    repeated calls from a training loop do not accumulate open figures.

    Args:
        gen_losses (list): Generator losses for each epoch
        critic_losses_real (list): Critic losses for real images for each epoch
        critic_losses_fake (list): Critic losses for fake images for each epoch
        smoothing_factor (int): Window size for moving average smoothing
    """
    epochs = len(gen_losses)
    figure = plt.figure(figsize=(12, 8))

    def smooth(y, box_pts):
        """Moving average of width ``box_pts`` with shrinking windows at both ends."""
        box = np.ones(box_pts) / box_pts
        y_smooth = np.convolve(y, box, mode='same')
        # Near the edges the full window would extend beyond the data; use a
        # smaller centered window (1, 3, 5, ... points) there instead.
        for i in range(box_pts // 2):
            if i < len(y):
                y_smooth[i] = np.mean(y[:i * 2 + 1])
                y_smooth[-(i + 1)] = np.mean(y[-(i * 2 + 1):])
        return y_smooth

    # Only smooth with a valid window and enough data.
    if 0 < smoothing_factor <= epochs:
        gen_smooth = smooth(gen_losses, smoothing_factor)
        critic_real_smooth = smooth(critic_losses_real, smoothing_factor)
        critic_fake_smooth = smooth(critic_losses_fake, smoothing_factor)
    else:
        gen_smooth = gen_losses
        critic_real_smooth = critic_losses_real
        critic_fake_smooth = critic_losses_fake

    # Epochs are numbered from 1 on the x-axis.
    x = np.arange(1, epochs + 1)

    # Smoothed curves
    plt.plot(x, gen_smooth, label='Generator Loss (smoothed)', color='blue', linewidth=2)
    plt.plot(x, critic_real_smooth, label='Critic Loss - Real (smoothed)', color='green', linewidth=2)
    plt.plot(x, critic_fake_smooth, label='Critic Loss - Fake (smoothed)', color='red', linewidth=2)

    # Raw curves as light dotted lines for reference
    plt.plot(x, gen_losses, label='Generator Loss (raw)', color='blue', alpha=0.3, linestyle='dotted')
    plt.plot(x, critic_losses_real, label='Critic Loss - Real (raw)', color='green', alpha=0.3, linestyle='dotted')
    plt.plot(x, critic_losses_fake, label='Critic Loss - Fake (raw)', color='red', alpha=0.3, linestyle='dotted')

    # Labels and legend
    plt.xlabel('Epoch', fontsize=14)
    plt.ylabel('Loss', fontsize=14)
    plt.title('WGAN Training Losses Over Time', fontsize=16)
    plt.legend(loc='best', fontsize=12)
    plt.grid(True, alpha=0.3)

    # Horizontal reference line at y=0
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)

    plt.show()
    # Release the figure; this function is called repeatedly during training.
    plt.close(figure)

In [None]:
# ------------------------------
# Training Functions
# ------------------------------
def train_critic(critic, generator, critic_optimizer, real_images, latent_dim, GP_WEIGHT):
    """Run one critic update (Wasserstein loss plus gradient penalty).

    Args:
        critic: Critic network being trained.
        generator: Generator used to produce the fake batch; it is not updated.
        critic_optimizer: Optimizer over the critic's parameters.
        real_images: Batch of real images; its batch size and device are used
            for the latent noise.
        latent_dim: Length of the generator's latent input vector.
        GP_WEIGHT: Multiplier of the gradient penalty term.

    Returns:
        The total critic loss (Wasserstein term + weighted penalty) as a float.
    """
    batch_size = real_images.size(0)

    critic_optimizer.zero_grad()

    # Noise lives on the same device as the real batch.
    noise = torch.randn(batch_size, latent_dim, device=real_images.device)

    # The generator is not trained in this step, so skip building its
    # autograd graph entirely.
    with torch.no_grad():
        fake_images = generator(noise)

    real_scores = critic(real_images)
    fake_scores = critic(fake_images)

    # Wasserstein critic loss: E[critic(fake)] - E[critic(real)].
    critic_loss = torch.mean(fake_scores) - torch.mean(real_scores)

    gp = gradient_penalty(critic, real_images, fake_images)
    total_critic_loss = critic_loss + GP_WEIGHT * gp

    total_critic_loss.backward()
    critic_optimizer.step()

    return total_critic_loss.item()

def train_generator(generator, critic, feature_extractor, generator_optimizer, batch_size, latent_dim):
    """Run one generator update using the adversarial loss.

    The perceptual term needs real images, which are not available here, so it
    is always zero in this function; ``feature_extractor`` is accepted for
    interface compatibility only. The circle constraint term is evaluated with
    weight 0 and therefore does not influence the update either.

    Note: ``train`` in this file performs its own generator update (including
    the perceptual loss) and does not call this function.

    Args:
        generator: Generator network being trained.
        critic: Critic used to score the generated images.
        feature_extractor: Unused; kept so existing callers keep working.
        generator_optimizer: Optimizer over the generator's parameters.
        batch_size: Number of latent vectors to sample.
        latent_dim: Length of the generator's latent input vector.

    Returns:
        The total generator loss as a float.
    """
    generator_optimizer.zero_grad()

    # Sample noise where the generator lives; fall back to the module-level
    # `device` for a generator without parameters.
    first_param = next(generator.parameters(), None)
    noise_device = first_param.device if first_param is not None else device
    noise = torch.randn(batch_size, latent_dim, device=noise_device)

    fake_images = generator(noise)

    # Adversarial loss: the generator wants high critic scores.
    adv_loss = -torch.mean(critic(fake_images))

    # No real batch is available here, so the perceptual term is zero.
    perc_loss = torch.tensor(0.0, device=fake_images.device)

    # Circle constraint is wired in with weight 0 for now.
    # TODO: let the weight grow gradually as the epochs increase.
    circle_loss = circle_constraint_loss(fake_images, image_size=26, weight=0)

    total_gen_loss = adv_loss + perc_loss + circle_loss

    total_gen_loss.backward()
    generator_optimizer.step()

    return total_gen_loss.item()

# ------------------------------
# Generate and Save Images
# ------------------------------
def generate_and_save_images(model, epoch, test_input, checkpoint_dir):
    """Render up to 16 generated maps in a 4x4 grid, save the figure and show it.

    The model is switched to evaluation mode for inference and afterwards put
    back into whatever mode it was in before the call, even if inference fails.

    Args:
        model: Generator to sample from.
        epoch: Epoch number used in the output file name.
        test_input: Latent batch fed to the model; at most the first 16
            outputs are plotted.
        checkpoint_dir: Directory that receives ``epoch_XXXX.png``; created
            if it does not exist.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = model(test_input)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)

    predictions = predictions.cpu().numpy()

    plt.figure(figsize=(8, 8))
    for i in range(min(16, test_input.size(0))):
        plt.subplot(4, 4, i+1)

        # Map the generated values from [-1, 1] to [0, 1] for the colormap.
        wafer_map = (predictions[i, 0, :, :] + 1) / 2.0

        plt.imshow(wafer_map, cmap='brg', vmin=0, vmax=1)

        plt.title(f"Defect Map {i+1}")
        plt.axis('off')

    os.makedirs(checkpoint_dir, exist_ok=True)
    save_path = os.path.join(checkpoint_dir, f'epoch_{epoch:04d}.png')
    plt.savefig(save_path)

    plt.show()
    plt.close()

# ------------------------------
# Create GIF of Progress
# ------------------------------
def create_progress_gif(image_dir):
    """Assemble the ``epoch_*`` images of ``image_dir`` into ``training_progress.gif``.

    Frames are added in sorted file-name order and the GIF is written to the
    current working directory. This is best effort: any failure (including a
    missing ``imageio`` package) is reported on stdout and otherwise ignored.
    """
    try:
        import imageio

        frame_names = [name for name in os.listdir(image_dir) if name.startswith('epoch_')]
        frame_names.sort()

        with imageio.get_writer('training_progress.gif', mode='I') as gif_writer:
            for frame_name in frame_names:
                frame_path = os.path.join(image_dir, frame_name)
                gif_writer.append_data(imageio.imread(frame_path))

        print("GIF created successfully.")
    except Exception as e:
        print(f"Failed to create GIF: {e}")

# ------------------------------
# Main Training Loop
# ------------------------------
def train(dataloader, generator, critic, feature_extractor, generator_optimizer,
          critic_optimizer, epochs, latent_dim, checkpoint_dir, save_interval, GP_WEIGHT):
    """Run the WGAN-GP training loop.

    For every batch the critic is updated 5 times, then the generator once
    (adversarial loss plus 0.1 x perceptual loss when a feature extractor is
    given). Per-epoch averages are printed; every 5th epoch (starting with the
    first) the loss curves are plotted and sample images are saved; every
    ``save_interval`` epochs a checkpoint is written. After the last epoch a
    progress GIF is created and the full generator and critic modules are
    saved to ``checkpoint_dir``.

    Args:
        dataloader: Iterable yielding batches of real images.
        generator: Generator network.
        critic: Critic network.
        feature_extractor: Network for the perceptual loss, or ``None`` to
            disable that term.
        generator_optimizer: Optimizer over the generator's parameters.
        critic_optimizer: Optimizer over the critic's parameters.
        epochs: Number of passes over ``dataloader``.
        latent_dim: Length of the generator's latent input vector.
        checkpoint_dir: Directory for images, checkpoints and final models.
        save_interval: Checkpoint period in epochs.
        GP_WEIGHT: Multiplier of the critic's gradient penalty.
    """
    # Fixed noise so the visualized samples are comparable across epochs.
    fixed_noise = torch.randn(16, latent_dim, device=device)

    # Per-epoch averages used for plotting.
    gen_losses = []
    critic_losses = []
    # Components of the Wasserstein critic loss, measured after each
    # generator step: -E[critic(real)] and E[critic(fake)].
    critic_losses_real = []
    critic_losses_fake = []

    start_time = time.time()

    for epoch in range(epochs):
        epoch_start = time.time()

        total_gen_loss = 0
        total_critic_loss = 0
        total_critic_real = 0
        total_critic_fake = 0
        batches = 0

        for batch_images in dataloader:
            batch_images = batch_images.to(device)
            batch_size = batch_images.size(0)

            # Train critic multiple times per generator step.
            for _ in range(5):
                c_loss = train_critic(critic, generator, critic_optimizer, batch_images, latent_dim, GP_WEIGHT)
                total_critic_loss += c_loss

            # Generator step.
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)

            if feature_extractor is not None:
                perc_loss = 0.1 * perceptual_loss(feature_extractor, batch_images, fake_images)
            else:
                perc_loss = torch.tensor(0.0, device=device)

            fake_scores = critic(fake_images)

            adv_loss = -torch.mean(fake_scores)
            total_gen_loss_tensor = adv_loss + perc_loss

            generator_optimizer.zero_grad()
            total_gen_loss_tensor.backward()
            generator_optimizer.step()

            total_gen_loss += total_gen_loss_tensor.item()

            # Record the two critic-loss components for the loss plot. The
            # fake part reuses the scores computed above; the real part needs
            # one gradient-free critic pass.
            total_critic_fake += -adv_loss.item()
            with torch.no_grad():
                total_critic_real += -torch.mean(critic(batch_images)).item()

            batches += 1

            # Print occasional batch updates
            if batches % 20 == 0:
                print(f"  Batch {batches}, G Loss: {total_gen_loss_tensor.item():.4f}, C Loss: {c_loss:.4f}")

        # Average losses for the epoch (NaN if the dataloader was empty).
        avg_gen_loss = total_gen_loss / batches if batches > 0 else float('nan')
        avg_critic_loss = total_critic_loss / (batches * 5) if batches > 0 else float('nan')
        avg_critic_real = total_critic_real / batches if batches > 0 else float('nan')
        avg_critic_fake = total_critic_fake / batches if batches > 0 else float('nan')

        gen_losses.append(avg_gen_loss)
        critic_losses.append(avg_critic_loss)
        critic_losses_real.append(avg_critic_real)
        critic_losses_fake.append(avg_critic_fake)

        epoch_time = time.time() - epoch_start
        print(f'Epoch {epoch+1}/{epochs}, Gen Loss: {avg_gen_loss:.4f}, '
              f'Critic Loss: {avg_critic_loss:.4f}, Time: {epoch_time:.2f}s')

        # Plot losses and generate/save/show images every 5th epoch.
        if (epoch % 5 == 0):
            plot_training_losses(gen_losses, critic_losses_real, critic_losses_fake, smoothing_factor=min(128, len(gen_losses)))
            generate_and_save_images(generator, epoch + 1, fixed_noise, checkpoint_dir)

        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'critic_state_dict': critic.state_dict(),
                'generator_optimizer_state_dict': generator_optimizer.state_dict(),
                'critic_optimizer_state_dict': critic_optimizer.state_dict(),
                'epoch': epoch,
                'gen_loss': avg_gen_loss,
                'critic_loss': avg_critic_loss
            }, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt'))

    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")

    # Create GIF showing training progress
    create_progress_gif(checkpoint_dir)

    # Save final models (whole modules, not just state dicts)
    torch.save(generator, os.path.join(checkpoint_dir, 'Scratch_generator_final.pt'))
    torch.save(critic, os.path.join(checkpoint_dir, 'Scratch_critic_final.pt'))

# ------------------------------
# Generate Images from Trained Model
# ------------------------------
def generate_samples(generator, num_images, latent_dim, output_dir):
    """Generate ``num_images`` samples and save a preview grid.

    Samples are produced in batches of 16 without gradients. Up to 25 of them
    are plotted in a 5x5 grid that is saved as ``generated_sample_grid.png``
    in ``output_dir`` and shown. The generator is returned to the mode
    (train/eval) it was in before the call.

    Args:
        generator: Generator to sample from.
        num_images: Number of images to generate; must be at least 1.
        latent_dim: Length of the generator's latent input vector.
        output_dir: Directory for the preview grid; created if missing.

    Returns:
        CPU tensor containing all generated images, concatenated along the
        batch axis.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    # Sample noise where the generator lives; fall back to the module-level
    # `device` for a generator without parameters.
    first_param = next(generator.parameters(), None)
    sample_device = first_param.device if first_param is not None else device

    # Generate in batches to avoid memory issues
    batch_size = 16
    all_images = []

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            for start in range(0, num_images, batch_size):
                batch_count = min(batch_size, num_images - start)
                noise = torch.randn(batch_count, latent_dim, device=sample_device)
                all_images.append(generator(noise).cpu())
    finally:
        generator.train(was_training)

    all_images = torch.cat(all_images, dim=0)

    # Preview grid of the first (up to) 25 samples.
    figure = plt.figure(figsize=(10, 10))
    for i in range(min(25, num_images)):
        plt.subplot(5, 5, i+1)
        plt.imshow(all_images[i, 0, :, :], cmap='brg')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'generated_sample_grid.png'))
    plt.show()
    plt.close(figure)

    return all_images


In [None]:

# ------------------------------
# Main Function
# ------------------------------
# Script entry point: trains the 26x26 WGAN on the pickled dataset, then
# generates 100 samples and stores them as a pickle. Everything assigned here
# stays at module level, so later notebook cells can reuse it.
if __name__ == "__main__":
    # Parameters
    latent_dim = 256        # length of the generator's latent input vector
    BATCH_SIZE = 64
    GP_WEIGHT = 10.0        # multiplier of the gradient penalty in train_critic
    EPOCHS = 100
    save_interval = 5       # a checkpoint is written every `save_interval` epochs
    checkpoint_dir = '/content/Donut 26X26'

    # Create checkpoint directory
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Load dataset
    # NOTE(review): load_dataset unpickles this file; only use trusted data.
    print("Loading dataset...")
    dataset = load_dataset("/content/Donut_output.pkl")  # Update with your path
    print(f"Dataset loaded with shape: {dataset.shape}")

    # Create data loader
    # drop_last=True: every batch has exactly BATCH_SIZE samples.
    wafer_dataset = WaferDataset(dataset)
    dataloader = DataLoader(wafer_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    # Initialize models
    # The feature extractor is not handed to any optimizer below, so its
    # weights stay at their initial values.
    feature_extractor = FeatureExtractor().to(device)
    feature_extractor.eval()  # Set to evaluation mode since it's not trained

    generator = Generator(latent_dim=latent_dim).to(device)
    critic = Critic().to(device)

    # Initialize optimizers (same RMSprop settings for generator and critic)
    generator_optimizer = optim.RMSprop(generator.parameters(), lr=5e-5)
    critic_optimizer = optim.RMSprop(critic.parameters(), lr=5e-5)

    # Print model summaries
    print("\nGenerator architecture:")
    print(generator)

    print("\nCritic architecture:")
    print(critic)

    print("\nFeature Extractor architecture:")
    print(feature_extractor)

    # Start training
    print("\nStarting training...")
    train(dataloader, generator, critic, feature_extractor, generator_optimizer,
          critic_optimizer, EPOCHS, latent_dim, checkpoint_dir, save_interval, GP_WEIGHT)

    print("Training completed successfully!")

    # Generate samples
    # The preview grid goes to ./generated_samples (relative to the working directory).
    print("Generating samples...")
    generated_samples = generate_samples(generator, 100, latent_dim, 'generated_samples')

    # Save to pickle file (as a NumPy array)
    with open('./generated_data.pkl', 'wb') as f:
        pickle.dump(generated_samples.numpy(), f)

    print("Samples generated and saved successfully!")

## WGAN 128X128

In [None]:
# ------------------------------
# Generator
# ------------------------------
class ResBlock(nn.Module):
    """Residual block: two shape-preserving transposed convolutions plus a skip.

    Each 3x3 ``ConvTranspose2d`` (padding 1) is followed by instance
    normalization; a ReLU sits between the two. The block input is added to
    the result, so the output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        same_shape = dict(kernel_size=3, padding=1)
        self.conv1 = nn.ConvTranspose2d(channels, channels, **same_shape)
        self.instance_norm1 = nn.InstanceNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.ConvTranspose2d(channels, channels, **same_shape)
        self.instance_norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        branch = self.instance_norm1(self.conv1(x))
        branch = self.instance_norm2(self.conv2(self.relu(branch)))
        # Skip connection around both convolutions.
        return branch + x

class Generator(nn.Module):
    """Map a latent vector to a 1x128x128 ternary image.

    Pipeline: linear projection to a 512x16x16 feature map, two residual
    blocks, three stride-2 transposed convolutions (16 -> 32 -> 64 -> 128),
    a final transposed convolution down to one channel, ``tanh`` and ternary
    quantization.

    Args:
        latent_dim: Length of the latent input vector.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        self.latent_dim = latent_dim

        # Latent vector -> flattened 512 x 16 x 16 feature map.
        self.fc = nn.Linear(latent_dim, 16*16*512)

        # Residual refinement at 16x16.
        self.res_block1 = ResBlock(512)
        self.res_block2 = ResBlock(512)

        # Every upsampling stage doubles the resolution.
        doubling = dict(kernel_size=5, stride=2, padding=2, output_padding=1)

        # 16x16 -> 32x32
        self.conv_transpose1 = nn.ConvTranspose2d(512, 256, **doubling)
        self.instance_norm1 = nn.InstanceNorm2d(256)
        self.relu1 = nn.ReLU()

        # 32x32 -> 64x64
        self.conv_transpose2 = nn.ConvTranspose2d(256, 128, **doubling)
        self.instance_norm2 = nn.InstanceNorm2d(128)
        self.relu2 = nn.ReLU()

        # 64x64 -> 128x128
        self.conv_transpose3 = nn.ConvTranspose2d(128, 64, **doubling)
        self.instance_norm3 = nn.InstanceNorm2d(64)
        self.relu3 = nn.ReLU()

        # Output head: single channel, tanh, ternary quantization.
        self.conv_transpose4 = nn.ConvTranspose2d(64, 1, kernel_size=5, padding=2)
        self.tanh = nn.Tanh()
        self.quantizer = QuantizeToTernary(threshold=0.33)
        # Constructed but not used in forward (see note there).
        self.circle_constraint = CircleConstraintLayer(image_size=128)

    def forward(self, x):
        feature_map = self.fc(x).view(-1, 512, 16, 16)
        feature_map = self.res_block2(self.res_block1(feature_map))

        # Three identical upsample -> normalize -> activate stages.
        stages = (
            (self.conv_transpose1, self.instance_norm1, self.relu1),
            (self.conv_transpose2, self.instance_norm2, self.relu2),
            (self.conv_transpose3, self.instance_norm3, self.relu3),
        )
        for upsample, normalize, activate in stages:
            feature_map = activate(normalize(upsample(feature_map)))

        image = self.tanh(self.conv_transpose4(feature_map))
        image = self.quantizer(image)

        # Uncomment to apply the circle constraint:
        # image = self.circle_constraint(image)

        return image

In [None]:
# ------------------------------
# Discriminator/Critic
# ------------------------------
class Critic(nn.Module):
    """WGAN critic: five spectrally normalized stride-2 convolutions and a linear head.

    Returns one unbounded score per input image. The linear layer expects a
    512 x 4 x 4 feature map, which is what the five stride-2 convolutions
    produce for a 128x128 input (128 -> 64 -> 32 -> 16 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()

        # 128x128 -> 64x64
        self.conv1 = spectral_norm_conv(1, 32, kernel_size=5, stride=2, padding=2)
        self.leaky_relu1 = nn.LeakyReLU(0.2)

        # 64x64 -> 32x32
        self.conv2 = spectral_norm_conv(32, 64, kernel_size=5, stride=2, padding=2)
        self.leaky_relu2 = nn.LeakyReLU(0.2)

        # 32x32 -> 16x16
        self.conv3 = spectral_norm_conv(64, 128, kernel_size=5, stride=2, padding=2)
        self.leaky_relu3 = nn.LeakyReLU(0.2)

        # 16x16 -> 8x8
        self.conv4 = spectral_norm_conv(128, 256, kernel_size=5, stride=2, padding=2)
        self.leaky_relu4 = nn.LeakyReLU(0.2)

        # 8x8 -> 4x4
        self.conv5 = spectral_norm_conv(256, 512, kernel_size=5, stride=2, padding=2)
        self.leaky_relu5 = nn.LeakyReLU(0.2)

        # Scoring head on the flattened 512 * 4 * 4 feature map.
        self.linear = spectral_norm_linear(512 * 4 * 4, 1)

    def forward(self, x):
        stages = (
            (self.conv1, self.leaky_relu1),
            (self.conv2, self.leaky_relu2),
            (self.conv3, self.leaky_relu3),
            (self.conv4, self.leaky_relu4),
            (self.conv5, self.leaky_relu5),
        )
        for conv, activation in stages:
            x = activation(conv(x))

        flat = x.view(x.size(0), -1)
        return self.linear(flat)

In [None]:
# ------------------------------
# Feature Extractor (Optional Modification)
# ------------------------------
class FeatureExtractor(nn.Module):
    """Convolutional encoder that maps a 1-channel image to a 256-d vector.

    Three conv stages, each followed by a 2x2 max-pool (128 -> 64 -> 32 -> 16
    for a 128x128 input), then a fourth convolution and global average
    pooling. Stages 2 and 3 use instance normalization.
    """

    def __init__(self):
        super().__init__()
        # Stage 1 (pool output: 64x64 for a 128x128 input)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.leaky_relu1 = nn.LeakyReLU(0.2)
        self.max_pool1 = nn.MaxPool2d(2)

        # Stage 2 (pool output: 32x32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.instance_norm1 = nn.InstanceNorm2d(64)
        self.leaky_relu2 = nn.LeakyReLU(0.2)
        self.max_pool2 = nn.MaxPool2d(2)

        # Stage 3 (pool output: 16x16)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.instance_norm2 = nn.InstanceNorm2d(128)
        self.leaky_relu3 = nn.LeakyReLU(0.2)
        self.max_pool3 = nn.MaxPool2d(2)

        # Stage 4 + pooling head
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        stage1 = self.max_pool1(self.leaky_relu1(self.conv1(x)))
        stage2 = self.max_pool2(self.leaky_relu2(self.instance_norm1(self.conv2(stage1))))
        stage3 = self.max_pool3(self.leaky_relu3(self.instance_norm2(self.conv3(stage2))))

        pooled = self.global_avg_pool(self.conv4(stage3))
        # Collapse [B, 256, 1, 1] to [B, 256].
        return pooled.view(pooled.size(0), -1)

In [None]:
# ------------------------------
# Circle Layer
# ------------------------------
class CircleConstraintLayer(nn.Module):
    """Clamp negative values to zero inside the circle inscribed in the image.

    Pixels whose centre lies within ``image_size / 2`` of the image centre are
    "inside"; any negative value there is replaced by 0. Pixels outside the
    circle are passed through unchanged.

    The masks are built once in ``__init__`` and registered as buffers, so
    ``state_dict()`` has the same keys before and after the first forward pass
    and a freshly constructed layer can load a checkpoint taken after training.

    Args:
        image_size: Height/width (in pixels) of the square inputs.
    """

    def __init__(self, image_size=128):
        super().__init__()
        self.image_size = image_size

        # Build the masks eagerly (on CPU); ``.to(device)`` moves them with the
        # rest of the module.
        inside_mask, outside_mask = self.create_masks(torch.device("cpu"))
        self.register_buffer('circle_mask_inside', inside_mask)
        self.register_buffer('circle_mask_outside', outside_mask)

    def create_masks(self, device):
        """Return ``(inside_mask, outside_mask)`` as float tensors of shape [1, 1, H, W].

        ``inside_mask`` is 1 where the pixel lies within the inscribed circle
        and 0 elsewhere; ``outside_mask`` is its complement.
        """
        center = self.image_size / 2 - 0.5  # centre in 0-indexed pixel coordinates
        radius = self.image_size / 2

        # Per-axis offsets from the centre; broadcasting a row against a column
        # yields the full distance grid without torch.meshgrid.
        offsets = torch.arange(self.image_size, device=device, dtype=torch.float32) - center
        distance = torch.sqrt(offsets.view(1, -1) ** 2 + offsets.view(-1, 1) ** 2)

        inside_mask = (distance <= radius).float()
        outside_mask = 1.0 - inside_mask

        return inside_mask.unsqueeze(0).unsqueeze(0), outside_mask.unsqueeze(0).unsqueeze(0)

    def forward(self, x):
        """Return ``x`` with negative values inside the circle replaced by 0."""
        # Rebuild the masks only if they are missing or live on another device
        # than the input.
        if self.circle_mask_inside is None or self.circle_mask_inside.device != x.device:
            inside_mask, outside_mask = self.create_masks(x.device)
            self.circle_mask_inside = inside_mask
            self.circle_mask_outside = outside_mask

        # The [1, 1, H, W] mask broadcasts over the batch dimension.
        inside = self.circle_mask_inside > 0.5

        # Push values to be >= 0 inside the circle
        x = torch.where(inside & (x < 0), torch.zeros_like(x), x)

        # The matching "force -1 outside the circle" clamp is intentionally not
        # applied: it made the output too perfect a circle.

        return x

In [None]:
# ------------------------------
# Circle Constraint Loss
# ------------------------------
def circle_constraint_loss(generated_images, image_size=128, weight=0.5):
    """
    Count pixels that violate the inscribed-circle layout and return a weighted penalty.

    Penalties, summed over the whole batch:
    - every pixel inside the inscribed circle that equals -1 counts 2
    - every pixel outside the inscribed circle that is not -1 counts 1

    The total is divided by the batch size and multiplied by ``weight``.

    The penalty is built from exact ``==`` / ``!=`` comparisons, so the
    returned value carries no gradient with respect to ``generated_images``.

    Args:
        generated_images: Tensor whose first dimension is the batch and whose
            last two dimensions are ``image_size`` x ``image_size``.
        image_size: Height/width of the square images.
        weight: Multiplier applied to the per-image average penalty.

    Returns:
        A scalar tensor: ``weight * total_penalty / batch_size``.
    """
    batch_size = generated_images.size(0)

    # Circle geometry in 0-indexed pixel coordinates
    center = image_size / 2 - 0.5
    radius = image_size / 2

    # Distance of every pixel from the centre; a row of offsets broadcast
    # against a column gives the full grid without torch.meshgrid.
    offsets = torch.arange(
        image_size, device=generated_images.device, dtype=torch.float32
    ) - center
    distance = torch.sqrt(offsets.view(1, -1) ** 2 + offsets.view(-1, 1) ** 2)

    # Masks of shape [1, 1, H, W]; they broadcast over the batch, so no
    # per-sample copies are needed.
    inside_mask = (distance <= radius).float().unsqueeze(0).unsqueeze(0)
    outside_mask = 1 - inside_mask

    # 1. Penalty for -1 values inside the circle
    inside_penalty = inside_mask * (generated_images == -1).float()

    # 2. Penalty for anything other than -1 outside the circle
    outside_penalty = outside_mask * (generated_images != -1).float()

    # Inside violations are weighted twice as heavily as outside ones
    total_penalty = inside_penalty.sum() * 2 + outside_penalty.sum()

    # Average per image in batch, then apply weight
    avg_penalty = total_penalty / batch_size
    return weight * avg_penalty

In [None]:
# ------------------------------
# Data Loading
# ------------------------------
def load_dataset(file_path, image_size=128):
    """
    Load a pickled image dataset and return it as a shuffled NCHW array in [-1, 1].

    Accepted input layouts (after ``np.array``):
    - ``[N, H, W]``      -> a channel axis is inserted
    - ``[N, H*W]``       -> reshaped to ``[N, 1, image_size, image_size]``
    - ``[N, H, W, 1]``   -> transposed from NHWC to NCHW
    - ``[N, 1, H, W]``   -> used as is

    If the data range is not already [-1, 1] it is min-max rescaled to it.
    Integer data is converted to float first so the rescaling cannot wrap
    around in a narrow integer dtype (e.g. uint8 images in [0, 255]).

    Side effect: seeds NumPy's global RNG with 42 before shuffling.

    Args:
        file_path: Path to the pickle file.
        image_size: Expected height/width of the square images.

    Returns:
        ``np.ndarray`` of shape ``[N, 1, image_size, image_size]``.

    Raises:
        ValueError: If the array cannot be brought to the expected shape, or
            if rescaling is required but every value is identical.
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(file_path, "rb") as f:
        data = np.array(pickle.load(f))

    # Bring the array to NCHW (PyTorch's layout)
    if data.ndim == 3:  # [num_samples, height, width]
        data = np.expand_dims(data, axis=1)
    elif data.ndim == 2:  # [num_samples, flattened_dim]
        data = data.reshape(-1, 1, image_size, image_size)
    elif data.ndim == 4 and data.shape[3] == 1:  # NHWC format from TensorFlow
        data = np.transpose(data, (0, 3, 1, 2))

    # Raise rather than assert so the check survives ``python -O``
    if data.shape[1:] != (1, image_size, image_size):
        raise ValueError(
            f"Expected shape (*, 1, {image_size}, {image_size}), got {data.shape}"
        )

    # Verify data range is [-1, 1]
    data_min, data_max = np.min(data), np.max(data)
    lo, hi = float(data_min), float(data_max)
    if not (np.isclose(lo, -1) and np.isclose(hi, 1)):
        if hi == lo:
            raise ValueError(
                f"Cannot normalize constant data (all values equal {data_min})"
            )
        print(f"Warning: Data range is [{data_min}, {data_max}], normalizing to [-1, 1]")
        if data.dtype.kind != "f":
            # e.g. 2 * uint8 would silently wrap around
            data = data.astype(np.float64)
        data = 2 * (data - lo) / (hi - lo) - 1

    np.random.seed(42)  # For reproducibility
    np.random.shuffle(data)

    return data

In [None]:
def visualize_input_data(data_path, samples_per_row=5):
    """
    Load a pickled dataset and print the shape of the resulting array.

    No figure is drawn by the current body; it only reports the shape the
    data takes once converted with ``np.array``.

    Args:
        data_path: Path to the pickle file containing the data
        samples_per_row: Accepted for interface compatibility; not used by
            the current body (default: 5)
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(data_path, "rb") as handle:
        loaded = pickle.load(handle)

    samples = np.array(loaded)
    print("Data shape after np.array:", samples.shape)
    


# Usage: runs as soon as this cell executes, on the hard-coded dataset path below.
visualize_input_data("/user/apurvara/Donut_augmented_data.pkl")

In [None]:
# ------------------------------
# Main Training Loop with Loss Plotting
# ------------------------------
def train(dataloader, generator, critic, feature_extractor, generator_optimizer,
          critic_optimizer, epochs, latent_dim, checkpoint_dir, save_interval, GP_WEIGHT):
    """
    Train the generator/critic pair with a Wasserstein loss plus gradient penalty.

    For every batch the critic is updated five times, then the generator is
    updated once (adversarial loss plus 0.1 x perceptual loss when a feature
    extractor is given). Tensors are created on the module-level ``device``.

    Relies on helpers defined elsewhere in this file: ``gradient_penalty``,
    ``perceptual_loss``, ``generate_and_save_images``, ``plot_training_losses``
    and ``create_progress_gif``.

    Args:
        dataloader: Iterable yielding batches of real images
        generator: Generator network (maps latent noise to images)
        critic: Critic network
        feature_extractor: Network for the perceptual loss, or None to skip it
        generator_optimizer: Optimizer for the generator
        critic_optimizer: Optimizer for the critic
        epochs: Number of epochs to train
        latent_dim: Dimension of the latent noise vector
        checkpoint_dir: Directory for checkpoints and final models
        save_interval: Interval (in epochs) between checkpoints
        GP_WEIGHT: Weight of the gradient penalty in the critic loss

    Returns:
        Tuple ``(gen_losses, critic_losses_real, critic_losses_fake)`` with one
        per-epoch average in each list.
    """
    # Critic updates per generator update
    critic_steps = 5

    # Create fixed noise vector for visualization
    fixed_noise = torch.randn(16, latent_dim, device=device)

    # Track losses for plotting
    gen_losses = []
    critic_losses_real = []
    critic_losses_fake = []

    start_time = time.time()

    for epoch in range(epochs):
        epoch_start = time.time()

        total_gen_loss = 0
        total_critic_loss_real = 0
        total_critic_loss_fake = 0
        batches = 0

        # Train on batches
        for batch_images in dataloader:
            batch_images = batch_images.to(device)
            batch_size = batch_images.size(0)

            # Train critic multiple times
            for _ in range(critic_steps):
                # Zero gradients
                critic_optimizer.zero_grad()

                # Generate latent vectors and fake images. The critic step never
                # back-propagates into the generator, so skip building the
                # generator's autograd graph.
                noise = torch.randn(batch_size, latent_dim, device=device)
                with torch.no_grad():
                    fake_images = generator(noise)

                # Get critic scores
                real_scores = critic(batch_images)
                fake_scores = critic(fake_images)

                # Track separate losses
                critic_loss_real = -torch.mean(real_scores)
                critic_loss_fake = torch.mean(fake_scores)

                # Calculate Wasserstein loss
                critic_loss = critic_loss_fake + critic_loss_real

                # Add gradient penalty
                gp = gradient_penalty(critic, batch_images, fake_images)
                total_critic_loss = critic_loss + GP_WEIGHT * gp

                # Update critic weights
                total_critic_loss.backward()
                critic_optimizer.step()

                # Accumulate losses for averaging
                total_critic_loss_real += critic_loss_real.item()
                total_critic_loss_fake += critic_loss_fake.item()

            # Train generator
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)

            # Calculate perceptual loss if feature extractor is provided
            if feature_extractor is not None:
                perc_loss = 0.1 * perceptual_loss(feature_extractor, batch_images, fake_images)
            else:
                perc_loss = torch.tensor(0.0, device=device)

            # Get fake scores from critic
            fake_scores = critic(fake_images)

            # Calculate generator loss
            adv_loss = -torch.mean(fake_scores)
            total_gen_loss_tensor = adv_loss + perc_loss

            # Update generator
            generator_optimizer.zero_grad()
            total_gen_loss_tensor.backward()
            generator_optimizer.step()

            total_gen_loss += total_gen_loss_tensor.item()

            batches += 1

            # Print occasional batch updates (critic values are from the last
            # critic step of this batch)
            if batches % 20 == 0:
                print(f"  Batch {batches}, G Loss: {total_gen_loss_tensor.item():.4f}, "
                      f"C Loss Real: {critic_loss_real.item():.4f}, "
                      f"C Loss Fake: {critic_loss_fake.item():.4f}")

        # Calculate average losses for the epoch (NaN if the loader was empty)
        avg_gen_loss = total_gen_loss / batches if batches > 0 else float('nan')
        avg_critic_loss_real = total_critic_loss_real / (batches * critic_steps) if batches > 0 else float('nan')
        avg_critic_loss_fake = total_critic_loss_fake / (batches * critic_steps) if batches > 0 else float('nan')

        # Store losses for plotting
        gen_losses.append(avg_gen_loss)
        critic_losses_real.append(avg_critic_loss_real)
        critic_losses_fake.append(avg_critic_loss_fake)

        # Print status
        epoch_time = time.time() - epoch_start
        print(f'Epoch {epoch+1}/{epochs}, Gen Loss: {avg_gen_loss:.4f}, '
              f'Critic Loss Real: {avg_critic_loss_real:.4f}, '
              f'Critic Loss Fake: {avg_critic_loss_fake:.4f}, Time: {epoch_time:.2f}s')

        # Generate/save/show images every 5 epochs
        if (epoch % 5 == 0):
            generate_and_save_images(generator, epoch + 1, fixed_noise, checkpoint_dir)

            # Refresh the running loss plot
            plot_training_losses(gen_losses, critic_losses_real, critic_losses_fake, smoothing_factor=min(128, len(gen_losses)))

        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'critic_state_dict': critic.state_dict(),
                'generator_optimizer_state_dict': generator_optimizer.state_dict(),
                'critic_optimizer_state_dict': critic_optimizer.state_dict(),
                'epoch': epoch,
                'gen_loss': avg_gen_loss,
                'critic_loss_real': avg_critic_loss_real,
                'critic_loss_fake': avg_critic_loss_fake
            }, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt'))

    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")

    # Create final loss plot
    plot_training_losses(gen_losses, critic_losses_real, critic_losses_fake, smoothing_factor=min(128, len(gen_losses)))

    # Create GIF showing training progress
    create_progress_gif(checkpoint_dir)

    # Save final models (whole pickled modules, not just state dicts)
    torch.save(generator, os.path.join(checkpoint_dir, 'Scratch_generator_final.pt'))
    torch.save(critic, os.path.join(checkpoint_dir, 'Scratch_critic_final.pt'))

    # Return losses for potential further analysis
    return gen_losses, critic_losses_real, critic_losses_fake

In [None]:
# ------------------------------
# Main Function
# ------------------------------
if __name__ == "__main__":
    # Entry point: train the 128x128 model from scratch, then export samples.
    # Relies on WaferDataset, Generator, Critic, train and generate_samples
    # being defined elsewhere in this file.

    # Parameters
    latent_dim = 256   # size of the generator's noise vector
    BATCH_SIZE = 256   # samples per training batch
    GP_WEIGHT = 10.0   # gradient-penalty weight passed to train()
    EPOCHS = 500
    save_interval = 50  # epochs between checkpoints
    checkpoint_dir = '/user/apurvara/Donut 128X128'

    # Create checkpoint directory
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("/user/apurvara/Donut_augmented_data.pkl")  # Update with your path
    print(f"Dataset loaded with shape: {dataset.shape}")

    # Create data loader; drop_last=True discards a final partial batch
    wafer_dataset = WaferDataset(dataset)
    dataloader = DataLoader(wafer_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    # Initialize models
    feature_extractor = FeatureExtractor().to(device)
    # Evaluation mode: its parameters are not handed to any optimizer below
    feature_extractor.eval()

    generator = Generator(latent_dim=latent_dim).to(device)
    critic = Critic().to(device)

    # Initialize optimizers
    generator_optimizer = optim.RMSprop(generator.parameters(), lr=5e-5)
    critic_optimizer = optim.RMSprop(critic.parameters(), lr=5e-5)

    # Print model summaries
    print("\nGenerator architecture:")
    print(generator)

    print("\nCritic architecture:")
    print(critic)

    print("\nFeature Extractor architecture:")
    print(feature_extractor)

    # Start training
    print("\nStarting training...")
    train(dataloader, generator, critic, feature_extractor, generator_optimizer,
          critic_optimizer, EPOCHS, latent_dim, checkpoint_dir, save_interval, GP_WEIGHT)

    print("Training completed successfully!")

    # Generate samples
    # NOTE(review): the last argument is presumably an output directory — confirm
    # against generate_samples.
    print("Generating samples...")
    generated_samples = generate_samples(generator, 100, latent_dim, 'generated_samples')

    # Save to pickle file (written to the current working directory)
    # NOTE(review): .numpy() assumes generate_samples returns a CPU tensor — TODO confirm
    with open('./generated_data.pkl', 'wb') as f:
        pickle.dump(generated_samples.numpy(), f)

    print("Samples generated and saved successfully!")

In [None]:

def train(dataloader, generator, critic, feature_extractor, generator_optimizer,
          critic_optimizer, epochs, latent_dim, checkpoint_dir, save_interval, GP_WEIGHT,
          start_epoch=0, patience=20, min_delta=0.001, device=None):
    """
    Train WGAN with early stopping to prevent overfitting.

    For every batch the critic is updated five times (Wasserstein loss plus
    gradient penalty), then the generator once. After each epoch the estimated
    Wasserstein distance ``mean(critic(real)) - mean(critic(fake))`` is used to
    (a) stop when it changes by less than 0.01 for 10 consecutive epochs and
    (b) keep ``best_model.pt`` for the epoch with the smallest absolute value,
    stopping after ``patience`` epochs without improvement.

    Relies on helpers defined elsewhere in this file: ``gradient_penalty``,
    ``perceptual_loss``, ``generate_and_save_images``, ``plot_training_losses``
    and ``create_progress_gif``.

    Args:
        dataloader: Training data loader
        generator: Generator network
        critic: Critic network
        feature_extractor: Optional feature extractor for perceptual loss
        generator_optimizer: Optimizer for generator
        critic_optimizer: Optimizer for critic
        epochs: Maximum number of epochs to train
        latent_dim: Dimension of latent noise vector
        checkpoint_dir: Directory to save checkpoints and images
        save_interval: Interval (in epochs) to save checkpoints
        GP_WEIGHT: Weight for gradient penalty
        start_epoch: Epoch to start training from (for resuming)
        patience: Number of epochs to wait for improvement before stopping
        min_delta: Minimum change to qualify as an improvement
        device: Device to use for training (cpu or cuda)

    Returns:
        Tuple ``(gen_losses, critic_losses_real, critic_losses_fake)`` with one
        per-epoch average in each list (including any losses loaded on resume).
    """
    # Critic updates per generator update
    critic_steps = 5

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Set device if not specified
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move models to device
    generator = generator.to(device)
    critic = critic.to(device)
    if feature_extractor is not None:
        feature_extractor = feature_extractor.to(device)

    # Create fixed noise vector for visualization
    fixed_noise = torch.randn(16, latent_dim, device=device)

    loss_file = os.path.join(checkpoint_dir, 'training_losses.pkl')
    best_model_path = os.path.join(checkpoint_dir, 'best_model.pt')

    # Track losses for plotting
    # If resuming training, load previous losses if available
    if start_epoch > 0 and os.path.exists(loss_file):
        print("Loading previous training losses...")
        with open(loss_file, 'rb') as f:
            losses = pickle.load(f)
            gen_losses = losses['gen_losses']
            critic_losses_real = losses['critic_losses_real']
            critic_losses_fake = losses['critic_losses_fake']
    else:
        gen_losses = []
        critic_losses_real = []
        critic_losses_fake = []

    # Early stopping variables
    best_val_loss = float('inf')
    counter = 0
    best_epoch = 0

    start_time = time.time()

    # Wasserstein distance tracking for convergence
    prev_wasserstein_distance = None
    convergence_threshold = 0.01  # Threshold for considering model converged
    convergence_counter = 0
    convergence_patience = 10  # Number of epochs with minimal change to declare convergence

    for epoch in range(start_epoch, epochs):
        epoch_start = time.time()

        # Set models to training mode
        generator.train()
        critic.train()

        # Training phase
        total_gen_loss = 0
        total_critic_loss_real = 0
        total_critic_loss_fake = 0
        batches = 0

        # Train on batches
        for batch_images in dataloader:
            batch_images = batch_images.to(device)
            batch_size = batch_images.size(0)

            # Train critic multiple times
            for _ in range(critic_steps):
                # Zero gradients
                critic_optimizer.zero_grad()

                # Generate latent vectors and fake images. The critic step never
                # back-propagates into the generator, so skip building the
                # generator's autograd graph.
                noise = torch.randn(batch_size, latent_dim, device=device)
                with torch.no_grad():
                    fake_images = generator(noise)

                # Get critic scores
                real_scores = critic(batch_images)
                fake_scores = critic(fake_images)

                # Track separate losses
                critic_loss_real = -torch.mean(real_scores)
                critic_loss_fake = torch.mean(fake_scores)

                # Calculate Wasserstein loss
                critic_loss = critic_loss_fake + critic_loss_real

                # Add gradient penalty
                gp = gradient_penalty(critic, batch_images, fake_images)
                total_critic_loss = critic_loss + GP_WEIGHT * gp

                # Update critic weights
                total_critic_loss.backward()
                critic_optimizer.step()

                # Accumulate losses for averaging
                total_critic_loss_real += critic_loss_real.item()
                total_critic_loss_fake += critic_loss_fake.item()

            # Train generator
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)

            # Calculate perceptual loss if feature extractor is provided
            if feature_extractor is not None:
                perc_loss = 0.1 * perceptual_loss(feature_extractor, batch_images, fake_images)
            else:
                perc_loss = torch.tensor(0.0, device=device)

            # Get fake scores from critic
            fake_scores = critic(fake_images)

            # Calculate generator loss
            adv_loss = -torch.mean(fake_scores)
            total_gen_loss_tensor = adv_loss + perc_loss

            # Update generator
            generator_optimizer.zero_grad()
            total_gen_loss_tensor.backward()
            generator_optimizer.step()

            total_gen_loss += total_gen_loss_tensor.item()

            batches += 1

            # Print occasional batch updates (critic values are from the last
            # critic step of this batch)
            if batches % 20 == 0:
                print(f"  Batch {batches}, "
                      f"G Loss: {total_gen_loss_tensor.item():.4f}, "
                      f"C Loss Real: {critic_loss_real.item():.4f}, "
                      f"C Loss Fake: {critic_loss_fake.item():.4f}")

        # Calculate average training losses (NaN if the loader was empty)
        avg_gen_loss = total_gen_loss / batches if batches > 0 else float('nan')
        avg_critic_loss_real = total_critic_loss_real / (batches * critic_steps) if batches > 0 else float('nan')
        avg_critic_loss_fake = total_critic_loss_fake / (batches * critic_steps) if batches > 0 else float('nan')

        # Store training losses for plotting
        gen_losses.append(avg_gen_loss)
        critic_losses_real.append(avg_critic_loss_real)
        critic_losses_fake.append(avg_critic_loss_fake)

        # Estimated Wasserstein distance: mean(critic(real)) - mean(critic(fake)).
        # critic_loss_real is already -mean(real) and critic_loss_fake is
        # +mean(fake), so the estimate is the negated sum of the two losses.
        wasserstein_distance = -(avg_critic_loss_real + avg_critic_loss_fake)

        # Check for convergence. On convergence the loop stops immediately, so
        # this epoch's status print, plots, loss file and checkpoints are skipped.
        if prev_wasserstein_distance is not None:
            delta_wd = abs(wasserstein_distance - prev_wasserstein_distance)
            if delta_wd < convergence_threshold:
                convergence_counter += 1
                if convergence_counter >= convergence_patience:
                    print(f"\nModel converged at epoch {epoch+1}! Wasserstein distance stable for {convergence_patience} epochs.")
                    break
            else:
                convergence_counter = 0

        prev_wasserstein_distance = wasserstein_distance

        # Print status
        epoch_time = time.time() - epoch_start
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'  Train - Gen: {avg_gen_loss:.4f}, Critic Real: {avg_critic_loss_real:.4f}, '
              f'Critic Fake: {avg_critic_loss_fake:.4f}')
        print(f'  Wasserstein Distance: {wasserstein_distance:.4f}, Convergence Counter: {convergence_counter}/{convergence_patience}')
        print(f'  Time: {epoch_time:.2f}s')

        # Generate/save/show images periodically
        if (epoch % 5 == 0) or (epoch == epochs - 1):
            generate_and_save_images(generator, epoch + 1, fixed_noise, checkpoint_dir)

        # Plot losses
        plot_training_losses(
            gen_losses, critic_losses_real, critic_losses_fake,
            smoothing_factor=min(128, len(gen_losses))
        )

        # Save current losses so a resumed run can continue the curves
        losses = {
            'gen_losses': gen_losses,
            'critic_losses_real': critic_losses_real,
            'critic_losses_fake': critic_losses_fake
        }
        with open(loss_file, 'wb') as f:
            pickle.dump(losses, f)

        # Early stopping tracks the magnitude of the Wasserstein estimate
        current_val_loss = abs(wasserstein_distance)

        # Check if loss improved (a NaN value never counts as an improvement)
        if current_val_loss < best_val_loss - min_delta:
            best_val_loss = current_val_loss
            counter = 0
            best_epoch = epoch

            # Save best model
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'critic_state_dict': critic.state_dict(),
                'generator_optimizer_state_dict': generator_optimizer.state_dict(),
                'critic_optimizer_state_dict': critic_optimizer.state_dict(),
                'epoch': epoch,
                'best_val_loss': best_val_loss
            }, best_model_path)

            print(f"  ✓ Wasserstein distance improved to {best_val_loss:.6f} - Saved best model")
        else:
            counter += 1
            print(f"  ✗ No improvement in Wasserstein distance for {counter} epochs. "
                  f"Best: {best_val_loss:.6f} at epoch {best_epoch+1}")

        # Check if we should stop early
        if counter >= patience:
            print(f"\nEarly stopping triggered after {epoch+1} epochs!")
            print(f"Best model was at epoch {best_epoch+1} with Wasserstein distance {best_val_loss:.6f}")
            break

        # Regular checkpoint saving
        if (epoch + 1) % save_interval == 0 or (epoch == epochs - 1):
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'critic_state_dict': critic.state_dict(),
                'generator_optimizer_state_dict': generator_optimizer.state_dict(),
                'critic_optimizer_state_dict': critic_optimizer.state_dict(),
                'epoch': epoch,
                'gen_loss': avg_gen_loss,
                'critic_loss_real': avg_critic_loss_real,
                'critic_loss_fake': avg_critic_loss_fake
            }, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pt'))

    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")

    # Create final loss plot
    plot_training_losses(
        gen_losses, critic_losses_real, critic_losses_fake,
        smoothing_factor=min(128, len(gen_losses))
    )

    # Create GIF showing training progress
    create_progress_gif(checkpoint_dir)

    # Load the best model before returning. The file is absent when no epoch
    # ran or no epoch produced a finite improvement; keep the current weights
    # in that case instead of crashing.
    if os.path.exists(best_model_path):
        best_checkpoint = torch.load(best_model_path, map_location=device)
        generator.load_state_dict(best_checkpoint['generator_state_dict'])
        critic.load_state_dict(best_checkpoint['critic_state_dict'])
    else:
        print("No best_model.pt found; keeping the current weights for the final models.")

    # Save final models (best versions when available)
    torch.save(generator, os.path.join(checkpoint_dir, 'Scratch_generator_final.pt'))
    torch.save(critic, os.path.join(checkpoint_dir, 'Scratch_critic_final.pt'))

    return gen_losses, critic_losses_real, critic_losses_fake

In [None]:
# ------------------------------ 
# Main Function for Resuming Training 
# ------------------------------ 
if __name__ == "__main__":
    # Entry point: resume training from the newest checkpoint in checkpoint_dir
    # (or start from scratch if there is none), then export samples.
    # Relies on WaferDataset, Generator, Critic, train and generate_samples
    # being defined elsewhere in this file.

    # Parameters
    latent_dim = 256
    BATCH_SIZE = 256
    GP_WEIGHT = 10.0
    ADDITIONAL_EPOCHS = 1000  # Epochs to run beyond start_epoch (upper bound; train() may stop early)
    save_interval = 50
    checkpoint_dir = '/user/apurvara/Donut 128X128'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create checkpoint directory
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Load dataset
    print("Loading dataset...")
    dataset = load_dataset("/user/apurvara/Donut_augmented_data.pkl")
    print(f"Dataset loaded with shape: {dataset.shape}")

    # Create data loader; drop_last=True discards a final partial batch
    wafer_dataset = WaferDataset(dataset)
    dataloader = DataLoader(wafer_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    # Initialize models
    feature_extractor = FeatureExtractor().to(device)
    # Evaluation mode: its parameters are not handed to any optimizer below
    feature_extractor.eval()

    generator = Generator(latent_dim=latent_dim).to(device)
    critic = Critic().to(device)

    # Initialize optimizers with lower learning rate for better convergence
    generator_optimizer = optim.RMSprop(generator.parameters(), lr=1e-5)  # Reduced from 5e-5
    critic_optimizer = optim.RMSprop(critic.parameters(), lr=1e-5)  # Reduced from 5e-5

    # Find the latest checkpoint (files named checkpoint_epoch_<N>.pt)
    checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('checkpoint_epoch_')]
    if checkpoint_files:
        # Extract epoch numbers from filenames: text after the last '_' and
        # before the first '.' must parse as an integer
        epoch_numbers = [int(f.split('_')[-1].split('.')[0]) for f in checkpoint_files]
        latest_epoch = max(epoch_numbers)
        latest_checkpoint = f'checkpoint_epoch_{latest_epoch}.pt'
        latest_checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)

        print(f"Loading checkpoint: {latest_checkpoint}")
        checkpoint = torch.load(latest_checkpoint_path, map_location=device)

        # Load model states
        generator.load_state_dict(checkpoint['generator_state_dict'])
        critic.load_state_dict(checkpoint['critic_state_dict'])

        # Load optimizer states
        generator_optimizer.load_state_dict(checkpoint['generator_optimizer_state_dict'])
        critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

        # Loading the optimizer state restores the learning rate stored in the
        # checkpoint, so re-apply the reduced rate afterwards
        for param_group in generator_optimizer.param_groups:
            param_group['lr'] = 1e-5

        for param_group in critic_optimizer.param_groups:
            param_group['lr'] = 1e-5

        # Get the starting epoch (checkpoints store the 0-based epoch just finished)
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
    else:
        start_epoch = 0
        print("No checkpoint found. Starting from scratch.")

    print("\nGenerator architecture:")
    print(generator)

    print("\nCritic architecture:")
    print(critic)

    print("\nFeature Extractor architecture:")
    print(feature_extractor)

    # Start training
    print("\nStarting/Resuming training...")

    # Set early-stopping parameters forwarded to train()
    patience = 30  # Increased patience for a more thorough assessment
    min_delta = 0.001  # Minimum improvement threshold

    train(dataloader=dataloader,
          generator=generator,
          critic=critic,
          feature_extractor=feature_extractor,
          generator_optimizer=generator_optimizer,
          critic_optimizer=critic_optimizer,
          epochs=start_epoch + ADDITIONAL_EPOCHS,
          latent_dim=latent_dim,
          checkpoint_dir=checkpoint_dir,
          save_interval=save_interval,
          GP_WEIGHT=GP_WEIGHT,
          start_epoch=start_epoch,
          patience=patience,
          min_delta=min_delta,
          device=device)

    print("Training completed successfully!")

    # Generate samples
    # NOTE(review): 'generated_samples' is presumably an output directory — confirm
    # against generate_samples.
    print("Generating samples...")
    generated_samples = generate_samples(generator, 100, latent_dim, 'generated_samples', device)

    # Save to pickle file (written to the current working directory)
    # NOTE(review): .numpy() assumes generate_samples returns a CPU tensor — TODO confirm
    with open('./generated_data.pkl', 'wb') as f:
        pickle.dump(generated_samples.numpy(), f)

    print("Samples generated and saved successfully!")

In [None]:
def generate_samples_from_best_model(checkpoint_dir, output_dir, num_images=100, latent_dim=256, device=None,
                                     batch_size=16):
    """
    Generate samples using the best model saved during training.

    Loads ``best_model.pt`` from ``checkpoint_dir``, restores the generator
    weights stored under the ``'generator_state_dict'`` key, samples
    ``num_images`` images, and writes into ``output_dir``:

    * ``best_model_samples_grid.png`` - a 5x5 grid of the first (up to) 25 images
    * ``sample_<k>.png`` - one file per image, for the first (up to) 100 images
    * ``best_model_samples.pkl`` - all images pickled as a NumPy array

    Args:
        checkpoint_dir: Directory containing model checkpoints
        output_dir: Directory to save generated samples (created if missing)
        num_images: Number of images to generate (must be positive)
        latent_dim: Dimension of latent noise vector
        device: Device to use for inference; defaults to CUDA when available,
            otherwise CPU
        batch_size: Number of images generated per forward pass (must be positive)

    Returns:
        A CPU tensor with all generated images concatenated along dim 0, or
        None when ``best_model.pt`` is not present in ``checkpoint_dir``.

    Raises:
        ValueError: If ``num_images`` or ``batch_size`` is not positive.
    """
    # Fail early with a clear message: with no batches, torch.cat() below would
    # otherwise be handed an empty list after the model was already loaded.
    if num_images <= 0:
        raise ValueError(f"num_images must be positive, got {num_images}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(output_dir, exist_ok=True)

    # Locate the best model; bail out (returning None) if it was never saved
    best_model_path = os.path.join(checkpoint_dir, 'best_model.pt')
    if not os.path.exists(best_model_path):
        print("Error: Best model checkpoint not found!")
        return None

    print(f"Loading best model from {best_model_path}")
    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    checkpoint = torch.load(best_model_path, map_location=device)

    # Rebuild the generator, restore its weights and switch to evaluation mode
    generator = Generator(latent_dim=latent_dim).to(device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.eval()

    # Generate in batches; the last batch may be smaller than batch_size
    batches = []
    with torch.no_grad():
        for start in range(0, num_images, batch_size):
            batch_count = min(batch_size, num_images - start)
            noise = torch.randn(batch_count, latent_dim, device=device)
            batches.append(generator(noise).cpu())

    # Concatenate all batches
    all_images = torch.cat(batches, dim=0)

    # Plot a grid of generated images. Only index 0 along dim 1 is drawn
    # (presumably the channel axis - adjust based on your image format).
    grid_fig = plt.figure(figsize=(15, 15))
    for idx, image in enumerate(all_images[:25]):
        plt.subplot(5, 5, idx + 1)
        plt.imshow(image[0].numpy(), cmap='brg')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'best_model_samples_grid.png'), dpi=300)
    plt.show()
    # Release the large grid figure; show() does not close it on every backend
    plt.close(grid_fig)

    # Save individual images (capped at the first 100)
    for idx, image in enumerate(all_images[:100]):
        plt.figure(figsize=(3, 3))
        plt.imshow(image[0].numpy(), cmap='brg')
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, f'sample_{idx + 1}.png'), dpi=150)
        plt.close()

    # Save as pickle file
    with open(os.path.join(output_dir, 'best_model_samples.pkl'), 'wb') as f:
        pickle.dump(all_images.numpy(), f)

    print(f"Generated {num_images} samples from the best model and saved to {output_dir}")
    return all_images

In [None]:
# Once training has finished, draw 200 samples from the best model checkpoint
# and keep the returned samples in `best_samples` for later cells.
output_dir = 'best_model_samples'
print("Generating samples from the best model...")
best_samples = generate_samples_from_best_model(
    checkpoint_dir,
    output_dir,
    num_images=200,
    latent_dim=latent_dim,
    device=device,
)