In [None]:
import os
import itertools
import torch
import torchvision
import mlflow
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torch.nn as NN
import torch.optim as optim
from PIL import Image
import random

# Select the compute device: use the first CUDA GPU when one is available,
# otherwise fall back to the CPU, and report the choice.
if torch.cuda.is_available():
    device = "cuda"
    print(f"Device used: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    print("No GPU available: Using CPU")

# ==============================================================================
# 1. NEW ARCHITECTURE: CycleGenerator (Encoder-Decoder)
# Adapted from the DCGAN layers, mirrored into an encoder-decoder to handle image-to-image translation
# ==============================================================================
class CycleGenerator(NN.Module):
    """Plain encoder-decoder generator for image-to-image translation.

    Four stride-2 convolutions (kernel 4, padding 1) downsample the input,
    four stride-2 transposed convolutions (kernel 4, padding 1) upsample it
    again, and a final Tanh bounds the output to [-1, 1]. There are no skip
    connections between encoder and decoder.

    Args:
        nc: number of channels of the input and of the generated image.
        ngf: base number of feature maps; deeper stages use multiples of it.
    """

    @staticmethod
    def _down(in_ch, out_ch, norm=True):
        """Build one encoder stage: stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        layers = [NN.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False)]
        if norm:
            layers.append(NN.BatchNorm2d(out_ch))
        layers.append(NN.LeakyReLU(0.2, inplace=True))
        return NN.Sequential(*layers)

    @staticmethod
    def _up(in_ch, out_ch):
        """Build one decoder stage: stride-2 transposed conv, BatchNorm, ReLU."""
        return NN.Sequential(
            NN.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
            NN.BatchNorm2d(out_ch),
            NN.ReLU(True),
        )

    def __init__(self, nc, ngf):
        super().__init__()

        # --- Encoder (downsampling); the first stage has no BatchNorm ---
        self.e1 = self._down(nc, ngf, norm=False)
        self.e2 = self._down(ngf, ngf * 2)
        self.e3 = self._down(ngf * 2, ngf * 4)
        self.e4 = self._down(ngf * 4, ngf * 8)

        # NOTE(review): this bottleneck is registered but forward() never calls it.
        # It is kept so the parameter set and state_dict keys stay unchanged.
        self.b1 = NN.Sequential(NN.Conv2d(ngf * 8, ngf * 8, 4, 1, 1, bias=False), NN.ReLU(True))

        # --- Decoder (upsampling) ---
        self.d1 = self._up(ngf * 8, ngf * 4)
        self.d2 = self._up(ngf * 4, ngf * 2)
        self.d3 = self._up(ngf * 2, ngf)

        # Output stage: back to `nc` channels, squashed by Tanh
        self.d4 = NN.Sequential(NN.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), NN.Tanh())

    def forward(self, x):
        """Encode `x` through e1..e4, decode through d1..d4 and return the result."""
        features = x
        for stage in (self.e1, self.e2, self.e3, self.e4):
            features = stage(features)
        for stage in (self.d1, self.d2, self.d3, self.d4):
            features = stage(features)
        return features

# ==============================================================================
# 2. DISCRIMINATOR (DCGAN-style convolutional discriminator, no final Sigmoid)
# ==============================================================================
class Discriminator(NN.Module):
    """Convolutional discriminator producing raw (un-squashed) scores.

    Five stride-2 convolutions (kernel 4, padding 1) are followed by a final
    4x4 convolution without padding that reduces to a single channel. No
    Sigmoid is applied, so the output is suitable for a least-squares (MSE)
    adversarial loss.

    Args:
        nc: number of channels of the input image.
        ndf: base number of feature maps; deeper stages use multiples of it.
    """

    def __init__(self, nc, ndf):
        super().__init__()

        # First stage (input is 128 x 128): no BatchNorm on the raw image.
        layers = [
            NN.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            NN.LeakyReLU(0.2, inplace=True),
        ]

        # Four further downsampling stages, each doubling the channel count
        # (ndf -> 2*ndf -> 4*ndf -> 8*ndf -> 16*ndf).
        for mult in (1, 2, 4, 8):
            in_ch, out_ch = ndf * mult, ndf * mult * 2
            layers += [
                NN.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                NN.BatchNorm2d(out_ch),
                NN.LeakyReLU(0.2, inplace=True),
            ]

        # Scoring head: single output channel, deliberately without Sigmoid.
        layers.append(NN.Conv2d(ndf * 16, 1, 4, 1, 0, bias=False))

        self.main = NN.Sequential(*layers)

    def forward(self, input):
        """Return the raw discriminator scores for `input`."""
        scores = self.main(input)
        return scores

# Weights init function:
def weights_init(m):
    """Re-initialise a single module in place; intended for use with `Module.apply`.

    Modules are matched by class name: any class whose name contains 'Conv'
    (which includes transposed convolutions) gets weights drawn from N(0, 0.02);
    any class whose name contains 'BatchNorm' gets weights drawn from
    N(1, 0.02) and a zero bias. Every other module is left untouched.
    """
    classname = m.__class__.__name__

    if 'Conv' in classname:
        NN.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in classname:
        NN.init.normal_(m.weight.data, mean=1.0, std=0.02)
        NN.init.constant_(m.bias.data, val=0)

# Helper to manage image buffer (CycleGAN stability trick)
class ReplayBuffer:
    """Pool of previously generated images used when updating a discriminator.

    Until the pool holds `max_size` images, every pushed image is stored and
    returned unchanged. Once it is full, each pushed image is, with
    probability 0.5, swapped with a randomly chosen stored image (the stored
    one is returned and the new one takes its place); otherwise it is returned
    as-is and not stored.

    Args:
        max_size: maximum number of images kept in the pool. 0 disables the
            pool, so `push_and_pop` just returns its (detached) input.

    Raises:
        ValueError: if `max_size` is negative.
    """

    def __init__(self, max_size=50):
        # Validate up front: a non-positive size used to surface only later,
        # as a random failure inside random.randint(0, max_size - 1).
        if max_size < 0:
            raise ValueError(f"max_size must be >= 0, got {max_size}")
        self.max_size = max_size
        self.data = []  # stored images, each kept with a leading batch dim of 1

    def push_and_pop(self, data):
        """Push a batch of images and return a batch of the same length.

        Args:
            data: tensor whose first dimension indexes the images of the batch.

        Returns:
            A tensor, detached from the autograd graph, mixing the pushed
            images with images stored earlier.
        """
        images = data.detach()

        # Pool disabled: nothing to store or sample from.
        if self.max_size == 0:
            return images

        to_return = []
        for element in images:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Pool still filling up: keep the image and pass it through.
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Hand out an older image and store the new one in its slot.
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)

# ==============================================================================
# 3. TRAINING LOOP ADAPTED FOR CYCLEGAN
# ==============================================================================
def _discriminator_loss(netD, real, fake, fake_buffer, criterion_GAN):
    """Return the loss of one discriminator on a real batch and on buffered fakes."""
    pred_real = netD(real).view(-1)
    # Targets are built from the prediction itself so they always match its
    # shape, dtype and device (real = 1.0, fake = 0.0).
    loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

    # Mix in previously generated images from the buffer to stabilise training.
    buffered_fake = fake_buffer.push_and_pop(fake)
    pred_fake = netD(buffered_fake.detach()).view(-1)
    loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

    return (loss_real + loss_fake) * 0.5


def train_cycle_epoch(dataloader_A, dataloader_B, netG_A2B, netG_B2A, netD_A, netD_B,
                      criterion_GAN, criterion_Cycle, optimizerG, optimizerD, device, lambda_cycle):
    """Train both generators and both discriminators for one epoch.

    The two dataloaders are iterated in lockstep with `zip`, so the epoch ends
    when the shorter one is exhausted. Each step first updates the generators
    (identity + adversarial + cycle-consistency losses) and then the
    discriminators (real batch vs. fakes drawn from a replay buffer).

    Args:
        dataloader_A, dataloader_B: loaders for domains A and B; the image
            tensor is read from index 0 of every batch.
        netG_A2B, netG_B2A: generators translating A -> B and B -> A.
        netD_A, netD_B: discriminators for domains A and B.
        criterion_GAN: adversarial loss, called as `criterion(prediction, target)`
            with a target of 1.0 for "real" and 0.0 for "fake".
        criterion_Cycle: reconstruction loss used for the cycle and identity terms.
        optimizerG: optimizer stepping the generators.
        optimizerD: optimizer stepping the discriminators.
        device: device the batches are moved to.
        lambda_cycle: weight of the cycle loss; the identity loss uses half of it.

    Returns:
        Tuple `(mean generator loss, mean discriminator loss)` over the batches
        actually processed; the discriminator value is the sum of the D_A and
        D_B losses.

    Raises:
        ValueError: if the loaders yield no batch at all.
    """
    # Metric accumulators:
    running_loss_G = 0.0
    running_loss_D = 0.0
    n_batches = 0  # counted while iterating: zip() stops at the shorter loader

    # Fake image buffers (created fresh on every call, i.e. once per epoch)
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    for data_A, data_B in zip(dataloader_A, dataloader_B):
        real_A = data_A[0].to(device)
        real_B = data_B[0].to(device)

        # --------------------------------------------------
        # Update Generators (G_A2B and G_B2A)
        # --------------------------------------------------
        optimizerG.zero_grad()

        # 1. Identity loss: a generator fed an image that is already in its
        #    target domain should return it unchanged.
        loss_id_A = criterion_Cycle(netG_B2A(real_A), real_A) * lambda_cycle * 0.5
        loss_id_B = criterion_Cycle(netG_A2B(real_B), real_B) * lambda_cycle * 0.5

        # 2. GAN loss: each generator tries to make the discriminator of its
        #    target domain answer "real".
        # D_B(G_A2B(A))
        fake_B = netG_A2B(real_A)
        output_fake_B = netD_B(fake_B).view(-1)
        loss_GAN_A2B = criterion_GAN(output_fake_B, torch.ones_like(output_fake_B))

        # D_A(G_B2A(B))
        fake_A = netG_B2A(real_B)
        output_fake_A = netD_A(fake_A).view(-1)
        loss_GAN_B2A = criterion_GAN(output_fake_A, torch.ones_like(output_fake_A))

        # 3. Cycle consistency loss
        # Forward cycle: A -> fake B -> reconstructed A
        rec_A = netG_B2A(fake_B)
        loss_cycle_A = criterion_Cycle(rec_A, real_A) * lambda_cycle

        # Backward cycle: B -> fake A -> reconstructed B
        rec_B = netG_A2B(fake_A)
        loss_cycle_B = criterion_Cycle(rec_B, real_B) * lambda_cycle

        # Total generator loss
        loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_A + loss_cycle_B + loss_id_A + loss_id_B
        loss_G.backward()
        optimizerG.step()

        # --------------------------------------------------
        # Update Discriminators (D_A and D_B)
        # --------------------------------------------------
        # zero_grad() also discards the discriminator gradients accumulated
        # during the generator backward pass above.
        optimizerD.zero_grad()

        loss_D_A = _discriminator_loss(netD_A, real_A, fake_A, fake_A_buffer, criterion_GAN)
        loss_D_A.backward()

        loss_D_B = _discriminator_loss(netD_B, real_B, fake_B, fake_B_buffer, criterion_GAN)
        loss_D_B.backward()

        optimizerD.step()

        running_loss_G += loss_G.item()
        running_loss_D += (loss_D_A.item() + loss_D_B.item())
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_cycle_epoch: the dataloaders produced no batches")

    return running_loss_G / n_batches, running_loss_D / n_batches

# Plotting function for CycleGAN:
def plot_cycle_epoch(epoch, netG_A2B, fixed_real_A, plots_filename):
    """Save a figure comparing fixed domain-A images with their A -> B translations.

    The first 8 images of `fixed_real_A` and the first 8 generated images are
    laid out in a grid with 8 images per row (inputs first, translations after)
    and written to `plots_filename`.

    Args:
        epoch: epoch number, used only in the figure title.
        netG_A2B: generator applied to `fixed_real_A`.
        fixed_real_A: batch of domain-A images, already on the generator's device.
        plots_filename: path of the image file to write; its parent directory
            is created when missing.
    """
    # Generate images (A -> fake B) without tracking gradients.
    # NOTE(review): the generator's train/eval mode is not changed here, so its
    # BatchNorm layers run in whatever mode the caller left it in.
    with torch.no_grad():
        fake_B = netG_A2B(fixed_real_A).detach().cpu()

    real_img = fixed_real_A.cpu()

    # Real images first, generated ones after: with nrow=8 they form two rows.
    combined_grid = torch.cat((real_img[:8], fake_B[:8]), 0)
    grid = torchvision.utils.make_grid(combined_grid, padding=2, normalize=True, nrow=8)

    # Create the output directory if needed. A bare filename has an empty
    # dirname, and os.makedirs('') would raise, so that case is skipped.
    out_dir = os.path.dirname(plots_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(8, 8))
    try:
        plt.axis("off")
        plt.title(f"CycleGAN Epoch {epoch}: Real (Top) -> Defective (Bottom)")
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        plt.savefig(plots_filename, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

# ==========================================================================
# ####### MODEL TRAINING EXECUTION #######
# ==========================================================================

# MLflow experiment and run names:
EXPERIMENT_NAME = "Defect_Generation_On_Manufactured_Pieces"
RUN_NAME = "Try: CycleGAN for casting-metal pieces for image generation"

# Hyperparameters
HP_NC = 3                # image channels (RGB)
HP_NGF = 64              # base feature-map count of the generators
HP_NDF = 64              # base feature-map count of the discriminators
HP_BATCH_SIZE = 16       # CycleGAN often uses 1 or 4; 16 works for simple cases
HP_N_EPOCHS = 50
HP_LR = 2e-4             # learning rate shared by both Adam optimizers
HP_LAMBDA_CYCLE = 10.0   # weight of the cycle-consistency loss

# Preprocessing: bring every image to 128x128 (the discriminator's final
# 4x4 convolution is sized for that input), convert it to a tensor and
# rescale each of the 3 channels with mean 0.5 / std 0.5.
transform = T.Compose(
    [
        T.Resize(size=128),
        T.CenterCrop(size=128),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# --------------------------------------------------------------------------
# DATA LOADING: one ImageFolder split into two unpaired domains
# --------------------------------------------------------------------------

# 1. Point ImageFolder to the PARENT directory containing both subfolders.
# Structure expected:
# .../train/
#       ├── ok_front/   (Good pieces)
#       └── def_front/  (Defective pieces)
root_dir = "../data/raw/casting/casting_data/casting_data/train"

full_dataset = torchvision.datasets.ImageFolder(
    root=root_dir,
    transform=transform
)

# 2. Identify the class mapping (folder name -> integer label).
# The labels are looked up by folder name, so their numeric order does not matter.
class_map = full_dataset.class_to_idx
print(f"Classes found: {class_map}")

idx_ok = class_map['ok_front']
idx_def = class_map['def_front']

# 3. Find which sample indices belong to which domain.
all_targets = torch.tensor(full_dataset.targets)

# Indices where target == ok_front
indices_A = (all_targets == idx_ok).nonzero(as_tuple=True)[0]

# Indices where target == def_front
indices_B = (all_targets == idx_def).nonzero(as_tuple=True)[0]

# 4. Create Subsets (virtual independent datasets).
# Subset is given plain Python ints: passing the index tensors directly makes
# every lookup index the underlying dataset with a 0-d tensor.
dataset_A = torch.utils.data.Subset(full_dataset, indices_A.tolist())  # Domain A: Good
dataset_B = torch.utils.data.Subset(full_dataset, indices_B.tolist())  # Domain B: Defective

print(f"Domain A (Good) images: {len(dataset_A)}")
print(f"Domain B (Defect) images: {len(dataset_B)}")

# 5. Create independent dataloaders; the training loop iterates them in lockstep.
# drop_last=True keeps every batch at exactly HP_BATCH_SIZE images.
loader_A = torch.utils.data.DataLoader(dataset_A, batch_size=HP_BATCH_SIZE, shuffle=True, drop_last=True)
loader_B = torch.utils.data.DataLoader(dataset_B, batch_size=HP_BATCH_SIZE, shuffle=True, drop_last=True)

# Build the networks: one generator and one discriminator per domain.
netG_A2B = CycleGenerator(nc=HP_NC, ngf=HP_NGF).to(device)  # Good -> Defective
netG_B2A = CycleGenerator(nc=HP_NC, ngf=HP_NGF).to(device)  # Defective -> Good
netD_A = Discriminator(nc=HP_NC, ndf=HP_NDF).to(device)     # Is it real Good?
netD_B = Discriminator(nc=HP_NC, ndf=HP_NDF).to(device)     # Is it real Defective?

# Apply the custom weight initialisation to every submodule of each network.
for _net in (netG_A2B, netG_B2A, netD_A, netD_B):
    _net.apply(weights_init)
del _net

# Losses: MSE for the adversarial term (LSGAN), L1 for cycle consistency.
criterion_GAN = NN.MSELoss()
criterion_Cycle = NN.L1Loss()

# Optimizers: one Adam over both generators, one Adam over both discriminators.
optimizerG = optim.Adam(
    [*netG_A2B.parameters(), *netG_B2A.parameters()],
    lr=HP_LR,
    betas=(0.5, 0.999),
)
optimizerD = optim.Adam(
    [*netD_A.parameters(), *netD_B.parameters()],
    lr=HP_LR,
    betas=(0.5, 0.999),
)

# Training Loop
import contextlib  # needed only here, for the no-op context when tracking is off

# Set to False to train without MLflow: training still runs, only the
# experiment tracking calls are skipped.
MLFLOW_TRACKING = True

if MLFLOW_TRACKING:
    mlflow.set_experiment(EXPERIMENT_NAME)
    run_context = mlflow.start_run(run_name=RUN_NAME)
else:
    run_context = contextlib.nullcontext()

model_path = "../models/netG_A2B.pth"

with run_context:

    # Log Params
    if MLFLOW_TRACKING:
        mlflow.log_param('architecture', "CycleGAN_Custom")
        mlflow.log_param('batch_size', HP_BATCH_SIZE)
        mlflow.log_param('lambda_cycle', HP_LAMBDA_CYCLE)

    # Fixed sample for visualization (a batch of good parts), reused at every
    # plot so the figures are comparable across epochs.
    fixed_real_A = next(iter(loader_A))[0].to(device)

    for epoch in range(HP_N_EPOCHS):
        avg_loss_G, avg_loss_D = train_cycle_epoch(
            loader_A, loader_B,
            netG_A2B, netG_B2A, netD_A, netD_B,
            criterion_GAN, criterion_Cycle, optimizerG, optimizerD,
            device, HP_LAMBDA_CYCLE
        )

        print(f"[{epoch}/{HP_N_EPOCHS}] Loss_G: {avg_loss_G:.4f} Loss_D: {avg_loss_D:.4f}")
        if MLFLOW_TRACKING:
            mlflow.log_metric('loss_g', avg_loss_G, step=epoch)
            mlflow.log_metric('loss_d', avg_loss_D, step=epoch)

        # Save a comparison figure every 5 epochs and after the last epoch.
        if epoch % 5 == 0 or epoch == HP_N_EPOCHS - 1:
            plots_filename = f"../reports/figures/CycleGAN/epoch_{epoch}.png"
            plot_cycle_epoch(epoch, netG_A2B, fixed_real_A, plots_filename)
            if MLFLOW_TRACKING:
                mlflow.log_artifact(plots_filename, 'plots')

    # Save the model (we mainly care about G_A2B: Good -> Defect).
    # The target directory is created first so a missing folder cannot make
    # the save fail after the whole training run.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(netG_A2B.state_dict(), model_path)
    if MLFLOW_TRACKING:
        mlflow.log_artifact(model_path, "models")
    print("CycleGAN Training Completed.")

Device used: NVIDIA GeForce RTX 5080
Classes found: {'def_front': 0, 'ok_front': 1}
Domain A (Good) images: 2875
Domain B (Defect) images: 3758
[0/50] Loss_G: 36.0320 Loss_D: 19.9476
[1/50] Loss_G: 16.8887 Loss_D: 3.4917
[2/50] Loss_G: 14.5513 Loss_D: 2.1973
[3/50] Loss_G: 12.3174 Loss_D: 1.3768
[4/50] Loss_G: 10.4103 Loss_D: 1.9147
[5/50] Loss_G: 7.2716 Loss_D: 0.6264
[6/50] Loss_G: 6.5894 Loss_D: 0.7915
[7/50] Loss_G: 6.2022 Loss_D: 0.7059
[8/50] Loss_G: 6.0760 Loss_D: 0.7542
[9/50] Loss_G: 5.8870 Loss_D: 0.7373
[10/50] Loss_G: 5.6783 Loss_D: 0.6622
[11/50] Loss_G: 5.4188 Loss_D: 0.6634
[12/50] Loss_G: 5.3609 Loss_D: 0.7897
[13/50] Loss_G: 4.9059 Loss_D: 0.6043
[14/50] Loss_G: 4.8377 Loss_D: 0.7393
[15/50] Loss_G: 4.5992 Loss_D: 0.6725
[16/50] Loss_G: 4.4612 Loss_D: 0.6570
[17/50] Loss_G: 4.3261 Loss_D: 0.6548
[18/50] Loss_G: 4.5677 Loss_D: 0.7003
[19/50] Loss_G: 4.1818 Loss_D: 0.6151
[20/50] Loss_G: 4.0979 Loss_D: 0.6513
[21/50] Loss_G: 4.1008 Loss_D: 0.7706
[22/50] Loss_G: 3.9398 L