In [None]:
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from torchvision import models
import torch.nn as nn

import torch
import torch.optim as optim

import os
from torchvision.utils import save_image

from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchvision.models import vgg16
import torch.nn.functional as F

In [None]:
# Preprocessing pipeline: produce 3x256x256 float tensors in [0, 1].
# Resize(256) with a single int only rescales the *shorter* side and keeps the
# aspect ratio, so a non-square image would come out as 256xW or Hx256.
# The decoder's LayerNorm shapes are hard-coded for a 256x256 output, so the
# centre crop guarantees a square input (it is a no-op for square images).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Training split: folder-per-class images, reshuffled every epoch.
dataset = ImageFolder(transform=transform, root='Path to Train dataset')
dataloader = DataLoader(dataset, shuffle=True, batch_size=12)

# Validation split: same preprocessing, fixed order so batches are repeatable.
val_dataset = ImageFolder(transform=transform, root='Path to Val dataset')
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=12)

In [None]:
# Encoder: ConvNeXt-Base truncated to its convolutional feature extractor.
# `weights=` replaces the deprecated `pretrained=True` flag; IMAGENET1K_V1 is
# the checkpoint that the old flag loaded.
convnext_base = models.convnext_base(weights=models.ConvNeXt_Base_Weights.IMAGENET1K_V1)
# torchvision's ConvNeXt exposes (features, avgpool, classifier) as children;
# dropping the last two removes both the pooling and the classifier head.
encoder = nn.Sequential(*list(convnext_base.children())[:-2])
encoder = encoder.to(device)

# 1x1 convolution applied to the encoder output to produce the latent code.
bottleneck_dim = 1024
bottleneck = nn.Conv2d(1024, bottleneck_dim, kernel_size=1).to(device)

In [None]:
class ConvNextDecoder(nn.Module):
    """Upsampling decoder that maps a (B, 1024, 8, 8) latent to a (B, 3, 256, 256) image.

    Five stages each double the spatial size with a stride-2 transposed
    convolution, followed by a LayerNorm over (C, H, W) and a GELU. The
    LayerNorm shapes are fixed, so the input must be 8x8 spatially. A final
    3x3 convolution plus tanh produces the 3-channel output in [-1, 1].
    """

    # (in_channels, out_channels, output height/width) of each upsampling stage.
    _STAGES = (
        (1024, 768, 16),
        (768, 512, 32),
        (512, 256, 64),
        (256, 128, 128),
        (128, 64, 256),
    )

    @staticmethod
    def _make_stage(c_in, c_out, size):
        """Build one 2x upsampling stage: ConvTranspose2d -> LayerNorm -> GELU."""
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
            nn.LayerNorm([c_out, size, size]),
            nn.GELU()
        )

    def __init__(self):
        super().__init__()
        stages = [self._make_stage(c_in, c_out, size) for c_in, c_out, size in self._STAGES]
        self.upsample_blocks = nn.ModuleList(stages)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x, encoder_features=None):
        """Decode `x`; optionally add skip features after the first stages.

        Args:
            x: latent tensor fed to the first upsampling stage.
            encoder_features: optional sequence of tensors. After stage `idx`
                the feature at position `-(idx + 1)` is added in place, for as
                many stages as there are features. Each feature must be
                broadcast-compatible with that stage's output.

        Returns:
            The tanh-activated output of the final convolution.
        """
        n_skips = 0 if encoder_features is None else len(encoder_features)
        for idx, stage in enumerate(self.upsample_blocks):
            x = stage(x)
            if idx < n_skips:
                # Features are consumed from the end of the list backwards.
                x += encoder_features[-(idx + 1)]
        return self.tanh(self.final_conv(x))
        
# Instantiate the decoder and move its parameters to the selected device.
decoder = ConvNextDecoder().to(device)

In [None]:
def extract_encoder_features(encoder, images):
    """Run `images` through the encoder's direct children, collecting conv outputs.

    Only the immediate children of `encoder` are visited, in order. The output
    of a child is recorded only when that child itself is an `nn.Conv2d`;
    convolutions nested inside container children are not recorded.

    Args:
        encoder: module whose direct children are applied sequentially.
        images: input tensor fed to the first child.

    Returns:
        A `(features, output)` tuple: the list of recorded Conv2d outputs
        (possibly empty) and the output of the last child.
    """
    conv_outputs = []
    activation = images
    for module in encoder.children():
        activation = module(activation)
        if not isinstance(module, nn.Conv2d):
            continue
        conv_outputs.append(activation)
    return conv_outputs, activation

In [None]:
class Discriminator(nn.Module):
    """Latent-space critic that scores (B, bottleneck_dim, 8, 8) latent codes.

    Five 3x3 convolutions (padding 1, so the 8x8 grid is preserved) reduce the
    channels from `bottleneck_dim` to 64, each followed by GELU and a LayerNorm
    over (C, 8, 8). The result is flattened and passed through a two-layer MLP.

    The output is a raw, unbounded score (logit) of shape (B, 1). There is
    deliberately no final Sigmoid: this notebook trains the critic with a
    hinge loss using +/-1 margins and a gradient penalty, and applies
    `torch.sigmoid` itself when logging. A sigmoid-squashed output in (0, 1)
    can never satisfy the fake-side margin and would be squashed twice in the
    logs.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(bottleneck_dim, 768, kernel_size=3, padding=1),
            nn.GELU(),
            nn.LayerNorm([768, 8, 8]),

            nn.Conv2d(768, 512, kernel_size=3, padding=1),
            nn.GELU(),
            nn.LayerNorm([512, 8, 8]),

            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.GELU(),
            nn.LayerNorm([256, 8, 8]),

            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.GELU(),
            nn.LayerNorm([128, 8, 8]),

            nn.Dropout(0.3),  # Regularization
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.GELU(),
            nn.LayerNorm([64, 8, 8])
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)  # raw score; no Sigmoid (see class docstring)
        )

    def forward(self, latent_vectors):
        """Return one unbounded score per sample, shape (B, 1)."""
        x = self.conv_layers(latent_vectors)
        x = self.fc_layers(x)
        return x
# Instantiate the latent-space discriminator and move it to the selected device.
discriminator = Discriminator().to(device)

In [None]:
# Output folder for the comparison grids; created once, reused on re-runs.
os.makedirs('Reconstruct', exist_ok=True)

def save_reconstructed_images(epoch, images, reconstructed_images):
    """Write originals and reconstructions into one PNG grid.

    The two batches are stacked along the batch dimension (originals first)
    and saved with four images per row to
    `Reconstruct/epoch_<epoch>_to_see.png`.

    Args:
        epoch: value interpolated into the file name.
        images: batch of original images.
        reconstructed_images: batch of reconstructions with matching
            per-image shape.
    """
    stacked = torch.cat((images, reconstructed_images), dim=0)
    save_image(stacked, f'Reconstruct/epoch_{epoch}_to_see.png', nrow=4)

In [None]:
# Loss modules: pixel-wise MSE for reconstruction, binary cross-entropy for
# adversarial targets.
reconstruction_loss_fn = nn.MSELoss()
adversarial_loss_fn = nn.BCELoss()

# One Adam optimizer per network; the discriminator runs at half the learning
# rate of the encoder and decoder.
# NOTE(review): none of the optimizers created in this cell cover `bottleneck`.
encoder_optimizer = optim.Adam(params=encoder.parameters(), lr=1e-4)
decoder_optimizer = optim.Adam(params=decoder.parameters(), lr=1e-4)
discriminator_optimizer = optim.Adam(params=discriminator.parameters(), lr=5e-5)

In [None]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples):
    """Gradient penalty on random interpolations between real and fake samples.

    For each sample a coefficient `alpha ~ U[0, 1)` mixes the real and fake
    tensors. The discriminator is evaluated on the mixture, and the penalty is
    the batch mean of `(||d D(x_hat) / d x_hat||_2 - 1) ** 2`, with the norm
    taken over all non-batch dimensions.

    Args:
        discriminator: callable mapping a batch of samples to per-sample scores.
        real_samples: tensor of shape (B, ...); any rank >= 1 is supported.
        fake_samples: tensor with the same shape as `real_samples`.

    Returns:
        Scalar tensor. It is built with `create_graph=True`, so it can be
        back-propagated into the discriminator's parameters.
    """
    batch_size = real_samples.size(0)
    # One coefficient per sample, shaped (B, 1, ..., 1) to match the input rank.
    # A fixed (B, 1, 1, 1) shape would silently broadcast 2-D inputs to
    # (B, 1, B, D).
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device)

    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates = interpolates.requires_grad_(True)

    d_interpolates = discriminator(interpolates)

    # Differentiate the sum of the scores: all-ones upstream gradient.
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,  # keep the graph so the penalty itself is differentiable
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(batch_size, -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [None]:
def hinge_loss(real_pred, fake_pred):
    """Discriminator hinge loss with unit margins.

    Real scores are penalised by `max(0, 1 - score)` and fake scores by
    `max(0, 1 + score)`. Each term is averaged over its batch and the two
    means are summed.

    Args:
        real_pred: discriminator scores for real samples.
        fake_pred: discriminator scores for fake samples.

    Returns:
        Scalar tensor holding the summed loss.
    """
    relu = torch.nn.functional.relu
    return relu(1.0 - real_pred).mean() + relu(1.0 + fake_pred).mean()

In [None]:
# ---- Training configuration ----
num_epochs = 200
adv_weight_base = 0.0005   # upper bound for the adversarial term's weight
gp_weight = 5              # multiplier on the gradient penalty
n_critic = 5               # discriminator updates per encoder/decoder update
best_val_loss = float('inf')
patience = 5               # epochs without validation improvement before stopping
counter = 0

# The 1x1 bottleneck sits between encoder and decoder but is not covered by the
# encoder/decoder optimizers, so without its own optimizer it would stay at its
# random initialisation. It gets one here, using the encoder/decoder learning rate.
bottleneck_optimizer = optim.Adam(bottleneck.parameters(), lr=0.0001)

d_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(discriminator_optimizer, T_max=50)
e_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(encoder_optimizer, T_max=50)

# Integrate changes into the training loop
for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    discriminator.train()

    # Running sums over the epoch; divided by n_batches below.
    epoch_recon_loss = 0
    epoch_encoder_adv_loss = 0
    epoch_discriminator_loss = 0

    real_mean = 0
    fake_mean = 0
    n_batches = 0

    for images, _ in dataloader:
        # Use the configured device rather than a hard-coded .cuda(), so the
        # CPU fallback chosen in `device` actually works.
        images = images.to(device)
        n_batches += 1

        # ---- Discriminator updates ----
        for _ in range(n_critic):
            discriminator.zero_grad()

            # Encoder and bottleneck are not trained in this phase.
            with torch.no_grad():
                _, latent_vectors = extract_encoder_features(encoder, images)
                latent_vectors = bottleneck(latent_vectors)
            # "Real" samples are draws from a standard normal prior with the
            # same shape, device and dtype as the encoded latents.
            real_latent_vectors = torch.randn_like(latent_vectors)

            # The same small noise tensor is added to both inputs.
            noise = 0.01 * torch.randn_like(latent_vectors)
            latent_vectors = latent_vectors + noise
            real_latent_vectors = real_latent_vectors + noise

            real_pred = discriminator(real_latent_vectors)
            fake_pred = discriminator(latent_vectors.detach())

            # Compute hinge loss and gradient penalty
            d_loss = hinge_loss(real_pred, fake_pred)
            gradient_penalty = compute_gradient_penalty(discriminator, real_latent_vectors, latent_vectors.detach())
            d_loss = d_loss + gp_weight * gradient_penalty

            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=0.5)
            discriminator_optimizer.step()

        # ---- Train Encoder-Decoder (and bottleneck) ----
        encoder.zero_grad()
        bottleneck.zero_grad()
        decoder.zero_grad()

        encoder_features, latent_vectors = extract_encoder_features(encoder, images)
        latent_vectors = bottleneck(latent_vectors)
        reconstructed_images = decoder(latent_vectors, encoder_features)
        recon_loss = F.mse_loss(reconstructed_images, images)

        fake_pred = discriminator(latent_vectors)
        # Linear ramp: 0 at the first epoch, approaching adv_weight_base at the end.
        adv_weight = adv_weight_base * (epoch / num_epochs)  # Dynamic weighting
        g_loss = -fake_pred.mean()

        # Total loss prioritizing reconstruction
        total_loss = recon_loss + adv_weight * g_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=0.5)
        torch.nn.utils.clip_grad_norm_(bottleneck.parameters(), max_norm=0.5)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=0.5)
        encoder_optimizer.step()
        bottleneck_optimizer.step()
        decoder_optimizer.step()

        # Logging: mean sigmoid of the discriminator outputs. real_pred comes
        # from the last critic step, fake_pred from the encoder step.
        with torch.no_grad():
            real_mean += torch.sigmoid(real_pred).mean().item()
            fake_mean += torch.sigmoid(fake_pred).mean().item()

        epoch_recon_loss += recon_loss.item()
        epoch_encoder_adv_loss += g_loss.item()
        epoch_discriminator_loss += d_loss.item()

    epoch_recon_loss /= n_batches
    epoch_encoder_adv_loss /= n_batches
    epoch_discriminator_loss /= n_batches
    real_mean /= n_batches
    fake_mean /= n_batches

    d_scheduler.step()
    e_scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] | Recon Loss: {epoch_recon_loss:.4f} | "
          f"Encoder Adv Loss: {epoch_encoder_adv_loss:.6f} | Discriminator Loss: {epoch_discriminator_loss:.4f}")
    print(f"Real Mean: {real_mean:.4f}, Fake Mean: {fake_mean:.4f}")


    # NOTE(review): this fires when real scores are low or fake scores are high,
    # i.e. when the discriminator is NOT separating the two. The "overconfident"
    # wording may be inverted; confirm the intended condition.
    if real_mean < 0.6 or fake_mean > 0.6:
        print("Warning: Discriminator may be overconfident!")

    # Validation step
    encoder.eval()
    decoder.eval()
    val_loss = 0
    with torch.no_grad():
        for val_images, _ in val_dataloader:
            val_images = val_images.to(device)
            val_features = encoder(val_images)
            val_latent_vectors = bottleneck(val_features)
            val_reconstructed_images = decoder(val_latent_vectors)

            val_loss += F.mse_loss(val_reconstructed_images, val_images).item()

        val_loss /= len(val_dataloader)

    print(f"Validation Loss: {val_loss:.6f}")

    # Check for early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0

        # Whole module objects are pickled (not state_dicts); the reload cell
        # relies on this by loading with weights_only=False.
        torch.save({
            'encoder': encoder,
            'decoder': decoder,
            'discriminator': discriminator,
            'bottleneck': bottleneck
        }, "To save the complete AAE")
        print(f"Models Saved...")
        
        # Save a comparison grid for the first validation batch.
        with torch.no_grad():
            sample_images, _ = next(iter(val_dataloader))
            sample_images = sample_images.to(device)
            features = encoder(sample_images)
            latent_vectors = bottleneck(features)
            reconstructed_samples = decoder(latent_vectors)
            save_reconstructed_images(epoch, sample_images, reconstructed_samples)
    else:
        counter += 1
        print(f"Val loss didn't improve for {counter} epochs")
        if counter >= patience:
            print("Early stopping triggered.")
            break

In [None]:
# Reload the best checkpoint and write one reconstruction grid from it.
# SECURITY: weights_only=False unpickles arbitrary Python objects. That is
# needed here because whole modules were saved, but only load checkpoint files
# you created yourself.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
checkpoint = torch.load("To save the complete AAE", map_location=device, weights_only = False)
encoder = checkpoint['encoder'].to(device)
bottleneck = checkpoint['bottleneck'].to(device)
decoder = checkpoint['decoder'].to(device)
discriminator = checkpoint['discriminator'].to(device)

# Inference mode for all restored modules (disables dropout etc.).
encoder.eval()
decoder.eval()
bottleneck.eval()
discriminator.eval()

with torch.no_grad():
    sample_images, _ = next(iter(val_dataloader))
    # Follow the configured device instead of a hard-coded .cuda().
    sample_images = sample_images.to(device)

    features = encoder(sample_images)
    latent_vectors = bottleneck(features)

    # Add noise if required
    noise = 0.01 * torch.randn_like(latent_vectors)
    latent_vectors = latent_vectors + noise

    reconstructed_samples = decoder(latent_vectors)
    # 222 is only the tag used in the output file name.
    save_reconstructed_images(222, sample_images, reconstructed_samples)