### Importar librerías, configuraciones y la AttGAN

In [1]:
import torch
import torch.nn as nn
from nn import LinearBlock, Conv2dBlock, ConvTranspose2dBlock
from torchsummary import summary
import json
import torch.utils.data as data
import h5py
import pickle as pkl
import torchvision.transforms as transforms
from jmetal.algorithm.multiobjective import NSGAII
from jmetal.operator import SBXCrossover, PolynomialMutation
from jmetal.util.termination_criterion import StoppingByEvaluations
from jmetal.util.termination_criterion import StoppingByQualityIndicator
from jmetal.util.observer import ProgressBarObserver
from jmetal.core.problem import FloatProblem
from jmetal.core.solution import FloatSolution
from jmetal.lab.experiment import Experiment, Job, generate_summary_from_experiment
from jmetal.core.quality_indicator import *
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from jmetal.core.quality_indicator import HyperVolume
import os
import inspect

# custom scripts
from my_attgan import AttGAN
from my_mivolo_inference import mivolo_inference
from my_mivolo_inference import predictor
from cf_utils import *
from generate_gender_cfs import AttGanPlausibleCounterfactualProblem
from data import Custom
from data import CelebA_HQ_custom

Model summary (fused): 112 layers, 68,125,494 parameters, 0 gradients, 257.4 GFLOPs


In [2]:
# Restore the pretrained AttGAN: read its JSON training settings, keep the
# list of base attribute names, load the saved weights and switch to eval mode.

experiment_dir = './384_shortcut1_inject1_none_hq'

with open(f'{experiment_dir}/setting.txt', 'r') as settings_file:
    gan_args = json.load(settings_file)

# Attribute names the generator was trained on (reused by the loaders below).
base_attrs = gan_args.get('attrs')

attgan = AttGAN(gan_args)
attgan.load(f'{experiment_dir}/weights.149.pth')
attgan.eval()

### Cargar los datos para generar los CF

In [25]:
# Choose where the factual sample(s) come from:
#   'custom'     -> every image of the `Custom` dataset in ./data/custom
#   'random'     -> one shuffled sample of the CelebA-HQ training split
#   'from_front' -> the images/attributes stored in a previously saved results pickle (`path`)
# Every branch leaves training_set_images, training_set_attributes and
# training_set_file_names (a list of 1-tuples, so `[i][0]` is a file name).
custom_image = 'from_front'
path = './Counterfactuals/Front_13053_0.pkl'

if custom_image == 'custom':

  # Load dataloader ATTGAN repo

  test_dataset = Custom('./data/custom', './data/list_attr_custom.txt', gan_args.get('img_size'), gan_args.get('attrs'))
  test_dataloader = data.DataLoader(
      test_dataset, batch_size=1, num_workers=gan_args.get('num_workers'),
      shuffle=False, drop_last=False
  )
  # Per the original notes, the dataset normalizes with mean 0.5 / std 0.5:
  # PIL image (384, 384, 3) in [0, 255] -> ToTensor [0, 1] -> Normalize [-1, 1],
  # giving (1, 3, 384, 384) batches in [-1, 1].

  # Collect the batches and concatenate once instead of calling torch.cat on
  # every iteration (quadratic copying). The leading empty tensor reproduces
  # the former `torch.tensor([])` accumulator: same dtype promotion, and an
  # empty result when the folder holds no images.
  image_batches = [torch.tensor([])]
  attribute_batches = [torch.tensor([])]
  training_set_file_names = []

  for img_a, att_a, file_nm in test_dataloader:
    image_batches.append(img_a)
    attribute_batches.append(att_a)
    training_set_file_names.append(file_nm)

  training_set_images = torch.cat(image_batches, dim = 0)
  training_set_attributes = torch.cat(attribute_batches, dim = 0)

elif custom_image == 'random':

  # CelebaHQ N samples

  celeba_path = './celeba_hq_dataset/CelebA-HQ-img'
  atts_path = './celeba_hq_dataset/CelebAMask-HQ-attribute-anno.txt'
  base_attrs = gan_args.get('attrs')

  sample_celeba_data = CelebA_HQ_custom(
                        data_path = celeba_path,
                        attr_path = atts_path,
                        selected_attrs = base_attrs,
                        image_size = gan_args.get('img_size'),
                        mode = 'train'
                      )

  sample_celeba_dataloader = data.DataLoader(
                              sample_celeba_data, batch_size= 1, num_workers=gan_args.get('num_workers'),
                              shuffle=True, drop_last=False
                            )

  data_iterator = iter(sample_celeba_dataloader)
  training_set_images, training_set_attributes, batch_file_names = next(data_iterator)
  # The collated names are a flat sequence of strings, so indexing `[i][0]`
  # would return the first character. Wrap each name in a 1-tuple to match the
  # 'custom' and 'from_front' branches.
  # NOTE(review): assumes CelebA_HQ_custom yields one file-name string per sample.
  training_set_file_names = [(name,) for name in batch_file_names]

elif custom_image == 'from_front':
  # SECURITY: pickle.load can execute arbitrary code; only open result files you produced.
  # Objects are read back in their stored order: front, images, attributes.
  with open(path, 'rb') as f:
    pareto_front = pkl.load(f)
    training_set_images = pkl.load(f)
    training_set_attributes = pkl.load(f)
  # The image id is the second '_'-separated field of the file name
  # (e.g. Front_13053_0.pkl -> 13053.jpg). Parse only the base name so an
  # underscore in a directory name cannot shift the fields.
  image_id = os.path.basename(path).split('_')[1]
  training_set_file_names = [(image_id + '.jpg', )]

else:
  raise ValueError(
    f"Unknown custom_image mode {custom_image!r}; expected 'custom', 'random' or 'from_front'"
  )

### Generar los CFs

In [34]:
# Select the factual sample and decide which class the counterfactual targets.

# When True, target the class the classifier already predicts (a null/control
# intervention) instead of the opposite class.
null_intervention = True

sample_idx = 0
row_selector = torch.tensor([sample_idx])
factual_img = training_set_images.index_select(0, row_selector)
factual_atts = training_set_attributes.index_select(0, row_selector)
img_file_name = training_set_file_names[sample_idx][0]

# Discriminator outputs for the factual image: `d_real` is later passed as the
# original discriminator score, `dc_real` as the attribute code of the search.
d_real, dc_real = attgan.D(factual_img)
prediction_orig, _ = mivolo_inference(factual_img, True)

# Flip the classifier's decision at the 0.5 threshold...
desired_pred = 0 if prediction_orig >= 0.5 else 1

# ...and flip it back for a null intervention (target == original class).
if null_intervention:
  desired_pred = 1 - desired_pred

In [18]:
%matplotlib inline

# Generate CFs

# Hyperparameters

pop_size = 100
max_evals = 25000

# Solver 

problem = AttGanPlausibleCounterfactualProblem(
            image = factual_img, 
            code = dc_real, # use the predicted scores for each attribute
            decoder = attgan.G, 
            discriminator = attgan.D, 
            classifier = mivolo_inference, 
            original_pred = prediction_orig,
            original_discriminator_score = d_real,
            desired_pred = desired_pred,
            use_lpips = True,
            non_actionable_features = None
          )

algorithm = NSGAII(
             problem=problem,
             population_size=pop_size,
             offspring_population_size=pop_size,
             mutation=PolynomialMutation(
                 probability=1/problem._number_of_variables,
                 distribution_index=20),
             crossover=SBXCrossover(probability=1.0, distribution_index=20),
             termination_criterion=StoppingByEvaluations(max_evaluations=max_evals)
         )

#jobs = [Job(algorithm, algorithm_tag = 'NSGAII', problem_tag = 'Gender_CF', run = max_evals)]
#experiment = Experiment(output_dir='./Counterfactuals', jobs=jobs)
#pareto_front = experiment.jobs[0].algorithm.result()

progress_bar = ProgressBarObserver(max=max_evals)
algorithm.observable.register(progress_bar)
        
algorithm.run()
pareto_front = algorithm.result()
runtime_in_seconds = round(algorithm.total_computing_time, 3)

write_pkl_file(img_file_name, algorithm, problem, factual_img, factual_atts, pareto_front, runtime_in_seconds, desired_pred, pop_size, max_evals)

[2025-11-26 08:59:57,102] [jmetal.core.algorithm] [DEBUG] Creating initial set of solutions...
[2025-11-26 08:59:57,103] [jmetal.core.algorithm] [DEBUG] Evaluating solutions...
[2025-11-26 09:01:13,258] [jmetal.core.algorithm] [DEBUG] Initializing progress...
Progress:   0%|          | 0/25000 [00:00<?, ?it/s][2025-11-26 09:01:13,292] [jmetal.core.algorithm] [DEBUG] Running main loop until termination criteria is met
Progress: 100%|##########| 25000/25000 [5:57:19<00:00,  1.17it/s]  
[2025-11-26 14:58:32,662] [jmetal.core.algorithm] [DEBUG] Finished!


In [33]:
# Re-save the results: same call and arguments as at the end of the CF-generation cell.
# NOTE(review): this cell uses whatever `algorithm`, `pareto_front`, `factual_img`, ... the
# kernel currently holds. If the data-loading cell was re-run with custom_image == 'from_front'
# in between, `pareto_front` was overwritten from the loaded pickle (and, if the factual cell
# also re-ran, so were `factual_img`/`factual_atts`). The saved file can then mix two runs.
# Prefer re-running the generation cell instead of this one.
write_pkl_file(img_file_name, algorithm, problem, factual_img, factual_atts, pareto_front, runtime_in_seconds, desired_pred, pop_size, max_evals)