In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import torch
import csv
# Disable benchmark mode and enable deterministic mode
# NOTE(review): no random seed is set anywhere in the visible code, so these
# flags alone do not make a run reproducible end-to-end — confirm whether
# seeding is intended.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# ----------------------------
# Enable Debugging
# ----------------------------
# Turned on globally, so it applies to every backward pass in the training
# loop below.
# NOTE(review): PyTorch documents anomaly detection as a debugging aid that
# slows execution; consider disabling it for the full 1000-epoch run.
torch.autograd.set_detect_anomaly(True)

# ----------------------------
# Device Setup
# ----------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ----------------------------
# Data Preparation
# ----------------------------
# Every image is resized to 64x64, converted to a tensor, and normalised with
# mean 0.5 / std 0.5 on each of three channels.
# NOTE(review): with ToTensor producing values in [0, 1] this maps pixels to
# [-1, 1], which matches the Tanh output of the Decoder — confirm inputs are RGB.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# The dataset root is an absolute, machine-specific path. The training loop
# unpacks each batch as (images, _) and ignores the ImageFolder labels.
dataset = datasets.ImageFolder(root="C:/Users/abhis/Downloads/Abisheck_Chandru_Python_code/code/Dataset/healthy", transform=transform)
# Shuffled mini-batches of 16 images, loaded by 4 worker processes.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

# ----------------------------
# Define Multi-Scale Perceptual Loss Network with Residual Connections
# ----------------------------
class MultiScalePerceptualLossNetwork(nn.Module):
    """Two-branch convolutional feature extractor with a 1x1 skip projection.

    Both branches map a 3-channel image to 256 feature channels through three
    Conv -> BatchNorm -> ReLU stages (3 -> 64 -> 128 -> 256). The first branch
    uses 3x3 kernels with padding 1, the second 5x5 kernels with padding 2;
    every convolution has stride 1, so the spatial size is preserved. A 1x1
    convolution projects the input straight to 256 channels, and the three
    results are summed.
    """

    def __init__(self):
        super().__init__()

        def make_stage(c_in, c_out, k, pad):
            """Build one stride-1 Conv -> BatchNorm -> in-place ReLU stage."""
            layers = [
                nn.Conv2d(c_in, c_out, k, 1, pad),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
            return nn.Sequential(*layers)

        # Channel progression shared by both branches.
        widths = [(3, 64), (64, 128), (128, 256)]

        # Small-kernel branch (3x3, padding 1).
        self.branch1 = nn.Sequential(
            *[make_stage(c_in, c_out, 3, 1) for c_in, c_out in widths]
        )

        # Large-kernel branch (5x5, padding 2).
        self.branch2 = nn.Sequential(
            *[make_stage(c_in, c_out, 5, 2) for c_in, c_out in widths]
        )

        # Skip path: project the raw input directly into the feature space.
        self.residual = nn.Conv2d(3, 256, kernel_size=1)

    def forward(self, x):
        """Return the element-wise sum of both branch outputs and the skip path."""
        fine = self.branch1(x)
        coarse = self.branch2(x)
        skip = self.residual(x)
        return fine + coarse + skip

perceptual_net = MultiScalePerceptualLossNetwork().to(device)
# The feature network is used purely as a fixed comparison space: in the code
# shown, its parameters are never handed to an optimizer. Freeze it so that
#   * BatchNorm uses the same (running) statistics for real and fake batches —
#     in train mode each batch is normalised with its own statistics, which
#     hides exactly the mean/scale differences the loss should measure, and
#     the running statistics drift on every call;
#   * no gradients are accumulated on weights that are never updated.
perceptual_net.eval()
perceptual_net.requires_grad_(False)


def perceptual_loss(fake, real):
    """Return the MSE between perceptual features of ``fake`` and ``real``.

    Args:
        fake: Batch of generated images; gradients flow back through the
            frozen feature network into this tensor.
        real: Batch of reference images with the same shape as ``fake``.

    Returns:
        Scalar tensor: mean squared error between
        ``perceptual_net(fake)`` and ``perceptual_net(real)``.
    """
    # The target features are constants for the optimisation, so skip building
    # an autograd graph for them.
    with torch.no_grad():
        real_feats = perceptual_net(real)
    fake_feats = perceptual_net(fake)
    return torch.nn.functional.mse_loss(fake_feats, real_feats)

# ----------------------------
# Define the Vector Quantizer (EMA based)
# ----------------------------
class VectorQuantizer(nn.Module):
    """Vector quantizer whose codebook is maintained by exponential moving averages.

    Each spatial position of the input feature map is replaced by its nearest
    codebook vector (squared Euclidean distance). During training the codebook
    is updated from EMA statistics of the assigned encoder outputs; it is not
    trained through the returned loss.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Dimensionality of each codebook vector; must equal the
            channel dimension of the input.
        commitment_cost: Weight of the commitment term in the returned loss.
        decay: EMA decay factor for the cluster statistics.
        epsilon: Laplace-smoothing constant that keeps cluster sizes non-zero.

    Notes:
        * The EMA statistics are registered buffers, so they follow
          ``.to(device)`` and are stored in ``state_dict()``. Checkpoints
          written before they were buffers lack these two keys and must be
          loaded with ``strict=False``.
        * NOTE(review): ``ema_weights`` starts at zero (as in the previous
          implementation), so a code that receives no assignments is set to
          the zero vector by the first update — confirm this is intended.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay=0.99, epsilon=1e-5):
        super(VectorQuantizer, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon

        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

        # Buffers (not parameters): moved with the module and checkpointed,
        # but never touched by an optimizer.
        self.register_buffer("ema_cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("ema_weights", torch.zeros(num_embeddings, embedding_dim))

    def forward(self, inputs):
        """Quantize ``inputs`` and return ``(quantized, loss)``.

        Args:
            inputs: Tensor of shape (batch, embedding_dim, height, width).

        Returns:
            quantized: Tensor with the same shape as ``inputs`` whose values
                are the selected codebook vectors. Gradients pass straight
                through to ``inputs`` (straight-through estimator).
            loss: Scalar commitment loss,
                ``commitment_cost * mean((quantized - inputs) ** 2)``, with
                the quantized values treated as constants.
        """
        input_shape = inputs.shape
        # (B, C, H, W) -> (B*H*W, C): one row per spatial position.
        flat_input = inputs.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        # Code selection and the EMA update are non-differentiable bookkeeping,
        # so none of it needs an autograd graph.
        with torch.no_grad():
            codebook = self.embedding.weight

            # ||x||^2 + ||e||^2 - 2 x.e for every (position, code) pair.
            distances = (
                torch.sum(flat_input ** 2, dim=1, keepdim=True)
                + torch.sum(codebook ** 2, dim=1)
                - 2 * torch.matmul(flat_input, codebook.t())
            )

            encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
            encodings = torch.zeros(
                encoding_indices.shape[0],
                self.num_embeddings,
                device=inputs.device,
                dtype=flat_input.dtype,
            )
            encodings.scatter_(1, encoding_indices, 1)

            # Look up the selected codes *before* the codebook is updated.
            quantized = torch.matmul(encodings, codebook)
            quantized = quantized.view(input_shape[0], input_shape[2], input_shape[3], self.embedding_dim)
            quantized = quantized.permute(0, 3, 1, 2).contiguous()

            # EMA update: only while training, so evaluation / sampling passes
            # leave the codebook untouched.
            if self.training:
                self.ema_cluster_size.mul_(self.decay).add_(
                    encodings.sum(0), alpha=1 - self.decay
                )
                self.ema_weights.mul_(self.decay).add_(
                    torch.matmul(encodings.t(), flat_input), alpha=1 - self.decay
                )
                # Laplace smoothing keeps every cluster size strictly positive
                # while preserving the total count n.
                n = self.ema_cluster_size.sum()
                self.ema_cluster_size.copy_(
                    (self.ema_cluster_size + self.epsilon)
                    / (n + self.num_embeddings * self.epsilon) * n
                )
                self.embedding.weight.data.copy_(
                    self.ema_weights / self.ema_cluster_size.unsqueeze(1)
                )

        # Commitment loss pulls encoder outputs toward their assigned codes.
        # There is no gradient codebook term: the codebook is overwritten by
        # the EMA update on every training step.
        e_latent_loss = torch.mean((quantized - inputs) ** 2)
        loss = self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the quantized tensor,
        # backward gradient is copied unchanged to the encoder output.
        quantized = inputs + (quantized - inputs).detach()

        return quantized, loss

# ----------------------------
# Define Residual Block for Encoder/Decoder
# ----------------------------
class ResidualBlock(nn.Module):
    """Channel-preserving residual unit: ``x + F(x)``.

    ``F`` is Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm, with stride 1
    and padding 1 throughout, so the output has the same shape as the input.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        same_size = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(channels, channels, **same_size),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **same_size),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Add the transformed features back onto the input."""
        transformed = self.block(x)
        return x + transformed

# ----------------------------
# Define the Encoder with Residual Connections
# ----------------------------
class Encoder(nn.Module):
    """Convolutional encoder: two stride-2 convolutions, each followed by a ResidualBlock.

    Each 4x4 / stride-2 / padding-1 convolution halves an even spatial size,
    so the output feature map is a quarter of the input height and width.

    Args:
        in_channels: Channels of the input image.
        hidden_dim: Channels after the first downsampling convolution.
        latent_dim: Channels of the returned latent feature map.
    """

    def __init__(self, in_channels=3, hidden_dim=256, latent_dim=256):
        super().__init__()
        downsample = dict(kernel_size=4, stride=2, padding=1)
        stages = [
            nn.Conv2d(in_channels, hidden_dim, **downsample),
            ResidualBlock(hidden_dim),
            nn.Conv2d(hidden_dim, latent_dim, **downsample),
            ResidualBlock(latent_dim),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` into a latent feature map."""
        latents = self.encoder(x)
        return latents

# ----------------------------
# Define the Decoder with Residual Connections
# ----------------------------
class Decoder(nn.Module):
    """Transposed-convolution decoder that mirrors the encoder layout.

    A ResidualBlock precedes each of the two 4x4 / stride-2 / padding-1
    transposed convolutions, each of which doubles the spatial size. A final
    Tanh bounds the output to (-1, 1).

    Args:
        out_channels: Channels of the reconstructed image.
        latent_dim: Channels of the incoming latent feature map.
        hidden_dim: Channels after the first upsampling layer.
    """

    def __init__(self, out_channels=3, latent_dim=256, hidden_dim=256):
        super().__init__()
        upsample = dict(kernel_size=4, stride=2, padding=1)
        stages = [
            ResidualBlock(latent_dim),
            nn.ConvTranspose2d(latent_dim, hidden_dim, **upsample),
            ResidualBlock(hidden_dim),
            nn.ConvTranspose2d(hidden_dim, out_channels, **upsample),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Decode a latent feature map into an image tensor."""
        image = self.decoder(x)
        return image

# ----------------------------
# Define the VQGAN Model
# ----------------------------
class VQGAN(nn.Module):
    """Generator: Encoder -> VectorQuantizer -> Decoder.

    The quantizer is built with 1024 codes of dimension 256 and a commitment
    cost of 0.25; encoder and decoder use their default sizes.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.quantizer = VectorQuantizer(
            num_embeddings=1024,
            embedding_dim=256,
            commitment_cost=0.25,
        )
        self.decoder = Decoder()

    def forward(self, x):
        """Reconstruct ``x`` and return ``(reconstruction, quantizer_loss)``."""
        latents = self.encoder(x)
        codes, vq_loss = self.quantizer(latents)
        reconstruction = self.decoder(codes)
        return reconstruction, vq_loss

# ----------------------------
# Define the Discriminator
# ----------------------------
class Discriminator(nn.Module):
    """Convolutional discriminator that scores image patches.

    Two 4x4 / stride-2 convolutions (3 -> 64 -> 128 channels, LeakyReLU 0.2,
    BatchNorm on the second) are followed by an unpadded 4x4 / stride-1
    convolution down to a single channel. The resulting score map is flattened
    per sample and squashed with a sigmoid, so the output holds one
    probability per patch.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (batch, num_patches) tensor of probabilities in [0, 1]."""
        batch_size = x.size(0)
        scores = self.net(x)
        flat_scores = scores.view(batch_size, -1)
        return torch.sigmoid(flat_scores)

# ----------------------------
# Initialize Models
# ----------------------------
# Instantiate the generator (VQGAN) and the discriminator on the selected device.
vqgan = VQGAN().to(device)
discriminator = Discriminator().to(device)

# Separate Adam optimizers; the discriminator uses half the generator's
# learning rate (1e-4 vs 2e-4). Both use betas=(0.5, 0.999).
optimizer_vq = optim.Adam(vqgan.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Pixel-wise reconstruction term (L1).
recon_loss_fn = nn.L1Loss()
# BCELoss expects probabilities; Discriminator.forward already applies a sigmoid.
adv_loss_fn = nn.BCELoss()
# Fixed weight of the perceptual term in the generator objective.
perceptual_weight = 0.05


# Create CSV file and write headers. The file stays open for the whole run so
# that one row can be appended (and flushed) per epoch.
with open("training_losses.csv", mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(["Epoch", "Recon Loss", "Adv Loss", "D Loss", "Perceptual Loss"])

    # ----------------------------
    # Training Loop with Image Saving
    # ----------------------------
    num_epochs = 1000
    sample_every = 100  # save a reconstruction grid every N epochs (and at the end)

    for epoch in range(num_epochs):
        # Gradually increase adversarial weight: 0.1 at epoch 0, capped at 0.5.
        adv_weight = min(0.1 + epoch * 0.0005, 0.5)

        # Sampling at the end of an epoch switches the generator to eval mode,
        # so make sure both networks are back in training mode here.
        vqgan.train()
        discriminator.train()

        # Running sums of [recon, adv, d, perceptual] so the logged numbers are
        # epoch means rather than whatever the last mini-batch produced.
        # Kept on the device to avoid a host sync on every batch.
        epoch_sums = torch.zeros(4, device=device)
        num_batches = 0

        for images, _ in dataloader:
            images = images.to(device)

            # ---- Train Discriminator ----
            optimizer_d.zero_grad()
            real_preds = discriminator(images)
            # Keep the quantizer loss: it is part of the generator objective below.
            fake_images, vq_loss = vqgan(images)
            # detach(): the discriminator step must not backprop into the generator.
            fake_preds = discriminator(fake_images.detach())
            real_loss = adv_loss_fn(real_preds, torch.ones_like(real_preds))
            fake_loss = adv_loss_fn(fake_preds, torch.zeros_like(fake_preds))
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_d.step()

            # ---- Train VQGAN (Generator) ----
            optimizer_vq.zero_grad()
            # Re-score the fakes with the just-updated discriminator, this time
            # keeping the graph back to the generator.
            fake_preds = discriminator(fake_images)
            adv_loss = adv_loss_fn(fake_preds, torch.ones_like(fake_preds))
            recon_loss = recon_loss_fn(fake_images, images)
            perceptual_loss_val = perceptual_loss(fake_images, images)
            # vq_loss was previously discarded, so the quantizer's latent loss
            # never contributed to the optimisation.
            total_loss = (
                recon_loss
                + vq_loss
                + adv_weight * adv_loss
                + perceptual_weight * perceptual_loss_val
            )
            total_loss.backward()
            optimizer_vq.step()

            epoch_sums += torch.stack(
                [recon_loss, adv_loss, d_loss, perceptual_loss_val]
            ).detach()
            num_batches += 1

        if num_batches == 0:
            raise RuntimeError("dataloader produced no batches; nothing to train on")

        recon_avg, adv_avg, d_avg, perceptual_avg = (epoch_sums / num_batches).tolist()

        # Save losses at the end of each epoch; flush so an interrupted run
        # still leaves every completed epoch on disk.
        writer.writerow([epoch + 1, recon_avg, adv_avg, d_avg, perceptual_avg])
        file.flush()

        # Save images every `sample_every` epochs & final epoch.
        if epoch % sample_every == 0 or epoch == num_epochs - 1:
            # eval(): sampling must not update BatchNorm running statistics.
            vqgan.eval()
            with torch.no_grad():
                sample_images, _ = next(iter(dataloader))
                sample_images = sample_images.to(device)
                recon_sample, _ = vqgan(sample_images)
                vutils.save_image(recon_sample, f"for report_epoch_{epoch}.png", normalize=True)
            vqgan.train()

        print(f"Epoch {epoch+1}/{num_epochs} | Recon Loss: {recon_avg:.4f} | Adv Loss: {adv_avg:.4f} | D Loss: {d_avg:.4f} | Perceptual Loss: {perceptual_avg:.4f}")

# Save final model
# Only the two state_dicts are written; optimizer states are not saved.
# NOTE(review): the filenames differ in spelling ("novgg" vs "nongg") —
# possibly a typo; confirm before other code relies on either name.
torch.save(vqgan.state_dict(), "report2novgg_vqgan.pth")
torch.save(discriminator.state_dict(), "report2nongg_discriminator.pth")


Using device: cuda
Epoch 1/1000 | Recon Loss: 0.1914 | Adv Loss: 0.8340 | D Loss: 0.6188 | Perceptual Loss: 0.7564
Epoch 2/1000 | Recon Loss: 0.1719 | Adv Loss: 1.5738 | D Loss: 0.3653 | Perceptual Loss: 0.7164
Epoch 3/1000 | Recon Loss: 0.2083 | Adv Loss: 1.8808 | D Loss: 0.2331 | Perceptual Loss: 0.8009
