In [1]:
import warnings
import os
import sys
# Prefer an installed `peal`; if it cannot be imported, put the parent directory on
# sys.path so that the later `from peal...` imports can resolve against a local checkout.
try:
    import peal

except ImportError:
    # peal is not installed, but the project may be checked out one directory above
    module_path = os.path.abspath(os.path.join('..'))
    if module_path not in sys.path:
        sys.path.append(module_path)

# torch is used throughout the notebook; `device` records whether CUDA is available.
# NOTE(review): `device` is not referenced anywhere in the visible part of this notebook.
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Silence every warning for the rest of the session.
warnings.filterwarnings('ignore')

In [2]:
from peal.global_utils import load_yaml_config
from peal.data.datasets import SymbolicDataset

# Load the YAML configuration of the "unpoisened" circle dataset.
unpoisened_dataset_config = load_yaml_config('<PEAL_BASE>/configs/data/circle_dataset_diffusion_unpoisened.yaml')

# Build the dataset in 'train' mode from the path stored in that configuration.
dataset = SymbolicDataset(data_dir=unpoisened_dataset_config.dataset_path, mode='train', config=unpoisened_dataset_config)

In [3]:
# Random 95% / 5% split; the test part receives whatever is left after rounding down.
train_size = int(0.95 * len(dataset))  # 95% for training
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])

In [4]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm, trange
from typing import Tuple
import logging
from torch.utils.data import DataLoader
import math
%matplotlib inline

logging.getLogger().setLevel(logging.INFO)

class PositionalEncoding(nn.Module):
    """Sinusoidal lookup table that maps integer timesteps to embedding vectors.

    Row ``t`` of the table holds the encoding of position ``t + 1`` (position 0 is
    never stored), with sines in the even columns and cosines in the odd columns.

    Args:
        embed_dim: Width of each embedding vector.
        max_len: Number of timesteps the table covers (valid indices ``0 .. max_len - 1``).
    """

    def __init__(self, embed_dim, max_len=500):

        super(PositionalEncoding, self).__init__()
        # Positions start at 1: equivalent to building max_len + 1 rows and dropping row 0.
        positions = torch.arange(1, max_len + 1, dtype=torch.float32)[:, None]
        wavelengths = torch.pow(10000, torch.arange(0, embed_dim, 2, dtype=torch.float32) / embed_dim)
        freqs = positions / wavelengths

        table = torch.zeros(max_len, embed_dim)
        table[:, 0::2] = torch.sin(freqs)
        # For an odd embed_dim there is one cosine column fewer than sine columns.
        table[:, 1::2] = torch.cos(freqs)[:, : embed_dim // 2]

        # A buffer follows the module across devices/dtypes; persistent=False keeps it
        # out of the state_dict so checkpoints saved without it still load.
        self.register_buffer("P", table, persistent=False)

    def forward(self, t):
        """Return the embedding row(s) selected by ``t`` (an int or an index tensor)."""
        return self.P[t]
    
class ScoreNetwork(nn.Module):
    """Small MLP that predicts a vector of size ``input_dim`` from an input and a time embedding.

    The time embedding is added to the output of the first linear layer, two
    SiLU-activated hidden layers follow, and a LayerNorm precedes the output layer.
    All linear layers are lazy, so their input widths are inferred on first use.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.layer1 = nn.LazyLinear(embed_dim)
        self.layer2 = nn.LazyLinear(embed_dim)
        self.layer3 = nn.LazyLinear(embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.layer4 = nn.LazyLinear(input_dim)

    def forward(self, x, time_embed):
        """Run the network on ``x`` conditioned on ``time_embed``."""
        hidden = self.layer1(x) + time_embed
        for hidden_layer in (self.layer2, self.layer3):
            hidden = F.silu(hidden_layer(hidden))
        normalised = self.norm(hidden)
        return self.layer4(normalised)
                           
class BasicDiscreteTimeModel(nn.Module):
    """Noise-prediction model for discrete timesteps.

    Combines a ``PositionalEncoding`` lookup for the timestep with a
    ``ScoreNetwork`` that consumes the input together with that embedding.
    """

    def __init__(self, input_dim: int, embed_dim: int, num_timesteps: int):
        super().__init__()
        self.positional_embeddings = PositionalEncoding(embed_dim=embed_dim, max_len=num_timesteps)
        self.score_network = ScoreNetwork(input_dim=input_dim, embed_dim=embed_dim)

    def forward(self, x, t):
        """Embed timestep(s) ``t`` and evaluate the score network on ``x``."""
        return self.score_network(x, self.positional_embeddings(t))

    

class CircleDiffusionAdaptor(nn.Module):
    def __init__(self, config, dataset, model_dir=None):
        super(CircleDiffusionAdaptor, self).__init__()
        # self.config = load_yaml_config(config)
        self.config = config
        
        if not model_dir is None:
            self.model_dir = model_dir
        else:
            self.model_dir = config['base_path']
        
        #if not os.path.exists(model_dir):
        #    os.mkdir(model_dir)
        #self.model_dir = model_dir
        self.input_dim = config['input_dim']
        try: 
            self.num_timesteps = config['num_timesteps']
        except KeyError: 
            pass
        
        self.dataset = dataset
        self.input_idx = [idx for idx, element in enumerate(self.dataset.attributes) if element not in ['Confounder', 'Target']]
        self.target_idx = [idx for idx, element in enumerate(self.dataset.attributes) if element == 'Target']
        #data = torch.zeros([len(dataset.data),len(dataset.attributes)], dtype=torch.float16)
        #for idx, key in enumerate(dataset.data):
        #    data[idx] = dataset.data[key]
        #self.model = self.train_and_load_diffusion(model_name='diffusion.pth')
        
        def schedules(num_timesteps: int, type: str='linear'):
 
            if type=='linear':
                scale = 1000 / num_timesteps
                min_var = scale * 1e-4
                max_var = scale * 1e-2
                return torch.linspace(min_var, max_var, num_timesteps, dtype=torch.float32)
            elif type=='cosine':
                steps = num_timesteps + 1
                x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64)
                alphas_cumprod = torch.cos(((x / num_timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
                alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
                betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
                return torch.clip(betas, 0, 0.999)
        
        betas = schedules(num_timesteps=config['num_timesteps'], type=config['var_schedule'])

        self.register_buffer("beta", betas)
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(0))
        
        
    def forward_diffusion(self, clean_x: torch.Tensor, noise: torch.tensor, timestep: torch.Tensor):
        
        if isinstance(timestep, int):
            timestep = torch.tensor([timestep])
            alpha_bar_t = self.alpha_bar[timestep].repeat(clean_x.shape[0])[:, None]
        else:
            alpha_bar_t = self.alpha_bar[timestep][:, None]
        mu = torch.sqrt(alpha_bar_t)
        std = torch.sqrt(1 - alpha_bar_t)
        noisy_x = mu * clean_x + std * noise
        return noisy_x
    

    def reverse_diffusion_ddpm(self, noisy_x: torch.Tensor, model: nn.Module, timestep: torch.Tensor):
        alpha_t = self.alpha[timestep].repeat(noisy_x.shape[0])[:, None]
        alpha_bar_t = self.alpha_bar[timestep].repeat(noisy_x.shape[0])[:, None]
        beta_t = 1 - alpha_t
        eps_hat = model(x=noisy_x, t=timestep)
        posterior_mean = (1 / torch.sqrt(alpha_t)) * (noisy_x - (beta_t / torch.sqrt(1 - alpha_bar_t) * eps_hat))
        z = torch.randn_like(noisy_x)
        
        if timestep > 0:
            alpha_bar_t_minus_1 = self.alpha_bar[timestep-1].repeat(noisy_x.shape[0])[:, None]
            sigma_t = beta_t * (1 - alpha_bar_t_minus_1) / (1 - alpha_bar_t)
            denoised_x = posterior_mean + torch.sqrt(sigma_t)*z #* z * (timestep > 0))  # variance = beta_t
        else:
            denoised_x = posterior_mean
                                           
        return denoised_x
    
    def train_and_load_diffusion(self, model_name='diffusion.pt', mode=None):
        """Create the noise-prediction model and either load or train its weights.

        The resulting model is stored in ``self.model``; the checkpoint location is
        stored in ``self.model_path``.

        Args:
            model_name: Checkpoint file name inside ``self.model_dir``.
            mode: ``'train'`` trains from scratch and overwrites the checkpoint.
                Any other value loads the checkpoint if it exists; if it does not,
                a message is logged and ``self.model`` is left untrained.
        """
        self.model_path = os.path.join(self.model_dir, model_name)
        model = BasicDiscreteTimeModel(input_dim=self.config['input_dim'], embed_dim=self.config['embed_dim'], num_timesteps=self.config['num_timesteps'])

        if mode == 'train':
            logging.info(
                f'Training model with path {self.model_path}'
            )
        elif os.path.isfile(self.model_path):
            model.load_state_dict(torch.load(self.model_path))
            logging.info(f'Model found with path {self.model_path}')
        else:
            logging.info('Model not found. Please run train_and_load_diffusion method and set its argument mode="train" ')

        def diffusion_loss(model: nn.Module, clean_x: torch.Tensor) -> torch.Tensor:
            """Summed squared error between true and predicted noise at random timesteps."""
            t = torch.randint(self.num_timesteps, (clean_x.shape[0],))
            eps_t = torch.randn_like(clean_x)
            x_t = self.forward_diffusion(clean_x=clean_x, noise=eps_t, timestep=t)
            eps_hat = model(x=x_t, t=t)
            return nn.MSELoss(reduction='sum')(eps_hat, eps_t)

        def run_epoch(model: nn.Module, dataloader: DataLoader, optimizer) -> float:
            """Train for one pass over ``dataloader`` and return the mean loss per sample."""
            model.train()
            epoch_loss = 0.0
            for x, _ in dataloader:
                optimizer.zero_grad()
                loss = diffusion_loss(model, x[:, self.input_idx])
                loss.backward()
                optimizer.step()
                # .item() detaches the value so the batch's autograd graph can be freed.
                epoch_loss += loss.item()
            return epoch_loss / len(dataloader.dataset)

        if mode == 'train':
            dataloader = DataLoader(self.dataset, batch_size=self.config['batch_size'], shuffle=True)
            learning_rate = 1e-4
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

            for i in trange(self.config['num_epochs']):
                train_loss = run_epoch(model, dataloader, optimizer)
                print(f'Epoch: {i}, train_loss: {train_loss}')

            # Make sure the target directory exists before writing the checkpoint.
            os.makedirs(self.model_dir, exist_ok=True)
            torch.save(model.state_dict(), self.model_path)

        self.model = model
    
    @torch.no_grad()
    def sample_ddpm(self, model: nn.Module, n_samples: int = 256, label=None):
        """
        iteratively denoises pure noise to produce a list of denoised samples at each timestep
        """
        model.eval()
        
        x_pred = []
        x = torch.randn(n_samples, self.input_dim)
        x_pred.append(x)
        
        step = 1
        for t in reversed(range(0, self.num_timesteps, step)):
                
            x = self.reverse_diffusion_ddpm(noisy_x=x, model=model, timestep=t)
                
            x_pred.append(x)
        return x_pred
    
    def sample_x(self, batch_size=1):
        x = self.sample_ddpm(model=self.model, n_samples=batch_size)[-1] 
        return x

    def sample_counterfactual_ddpm(self, clean_batch: torch.Tensor, model: nn.Module, classifier: nn.Module, num_noise_steps: int, target_classes: torch.Tensor, classifier_grad_weight: float):
        """Generate classifier-guided counterfactual candidates for a batch.

        The batch is noised in closed form and then denoised step by step. At each
        step the reverse-diffusion mean is shifted by the (scaled) gradient of the
        classifier's cross-entropy loss towards ``target_classes``. After every
        guided step the intermediate state is denoised the rest of the way without
        guidance; that cleaned batch is recorded as a candidate and is also the
        point where the classifier gradient for the next step is evaluated.

        Side effects: ``classifier`` is put in eval mode and stored on ``self``;
        the stacked results are stored in ``self.counterfactuals_series``,
        ``self.guided_grads`` and ``self.unguided_grads``.

        Args:
            clean_batch: Samples to edit; assumed 2-D (batch, features), since the
                stacked results are permuted with three axes.
            model: Noise-prediction model, called as ``model(x, t)``.
            classifier: Classifier whose logits are passed to ``F.cross_entropy``.
            num_noise_steps: Schedule index used for the initial noising; the
                guided loop then runs over indices ``num_noise_steps - 1 .. 0``.
            target_classes: Target labels passed to ``F.cross_entropy``.
            classifier_grad_weight: Multiplier applied to the classifier gradient.

        Returns:
            ``(counterfactuals, guided_grads, unguided_grads, total_series)``: the
            first three are tensors with the batch on axis 0 and the step on axis 1;
            ``counterfactuals`` starts with ``clean_batch`` followed by one cleaned
            candidate per guided step with ``i != 0``. ``total_series`` is a list
            (one entry per guided step) of lists of intermediate tensors.
        """
        
        
        classifier.eval()
        # NOTE(review): if both objects are nn.Modules this registers the classifier as a submodule of the adaptor.
        self.classifier = classifier
        
        # Batch size, used to expand scalar schedule values to one row per sample
        bs = clean_batch.shape[0]

        # Classifier gradient at the clean input, used as guidance in the first step
        
        classifier_criterion = lambda x: F.cross_entropy(classifier(x), target_classes)
        clean_batch_copy = torch.nn.Parameter(clean_batch)
        loss = classifier_criterion(clean_batch_copy)
        loss.backward()
        clean_grad = classifier_grad_weight * clean_batch_copy.grad.detach()
        
        # Forward diffusion of the clean batch to schedule index num_noise_steps
        # NOTE(review): noising uses index num_noise_steps while denoising below starts at
        # num_noise_steps - 1 — confirm this offset is intended (index num_noise_steps must also exist).
        eps_t = torch.randn_like(clean_batch)
        next_z = self.forward_diffusion(clean_x=clean_batch, noise=eps_t, timestep=num_noise_steps)
        counterfactuals = [] # candidate counterfactuals, starting with the unmodified batch
        counterfactuals.append(clean_batch)
        guided_grads = []  # classifier-guidance shift applied at each step
        unconditional_grads = [] # diffusion (unguided) shift applied at each step
        total_series = [] # per guided step: trajectory from the noisy state to the cleaned sample
        # NOTE(review): model calls in this loop are not wrapped in torch.no_grad(), so next_z
        # presumably carries an autograd graph across iterations — confirm memory use is acceptable.
        for i in tqdm(range(0, num_noise_steps)[::-1]):
            # Denoise z_t to create z_t-1 (next z)
            alpha_i = self.alpha[i].repeat(bs)[:, None]
            alpha_bar_i = self.alpha_bar[i].repeat(bs)[:, None]
            sigma_i = torch.sqrt(1 - self.alpha[i])
            eps_hat = model(next_z, i)

            # Unconditional mean
            unconditional_grad = -eps_hat / torch.sqrt(1 - alpha_bar_i)
            z_t_mean = (next_z + unconditional_grad * (1 - alpha_i)) / torch.sqrt(alpha_i)

            # Guided mean: step against the classifier-loss gradient
            z_t_mean -= sigma_i * (clean_grad / torch.sqrt(alpha_bar_i))

            # Noise is added at every step except the last one (i == 0)
            if i > 0:
                next_z = z_t_mean + (sigma_i * torch.randn_like(clean_batch))
            else:
                next_z = z_t_mean

            next_x = next_z.clone()
            # Denoise to create a cleaned x (next x)
            series = []
            series.append(next_x.detach())
            for t in range(0, i)[::-1]:
                # Never true here: range(0, i) is empty when i == 0
                if i == 0:
                    break
                next_x = self.reverse_diffusion_ddpm(noisy_x=next_x, model=model, timestep=t)
                series.append(next_x.detach())
            total_series.append(series)
            guided_grads.append(-sigma_i * clean_grad.detach() / torch.sqrt(alpha_bar_i))
            unconditional_grads.append(unconditional_grad.detach() * (1 - alpha_i) / torch.sqrt(alpha_i) )
            
            
            # The result of the final step (i == 0) is not recorded as a candidate
            if i != 0:
                counterfactuals.append(next_x.detach())

            # Gradient wrt denoised image (next_x); becomes the guidance of the next step
            next_x_copy = torch.nn.Parameter(next_x.clone())
            loss = classifier_criterion(next_x_copy)
            loss.backward()
            clean_classifier_grad = next_x_copy.grad.detach()
            clean_grad = classifier_grad_weight * clean_classifier_grad
            del clean_classifier_grad
            next_x_copy.grad.zero_()
            
            
        # Stack along a new step axis, then move the batch axis to the front
        counterfactuals = torch.stack(counterfactuals).permute(1, 0, 2) 
        guided_grads = torch.stack(guided_grads).permute(1, 0, 2) 
        unguided_grads = torch.stack(unconditional_grads).permute(1, 0, 2)
        
        self.counterfactuals_series = counterfactuals
        self.guided_grads = guided_grads
        self.unguided_grads = unguided_grads
        
        return counterfactuals, guided_grads, unguided_grads, total_series

    
    def discard_counterfactuals(self, counterfactuals, classifier, target_classes, target_confidence, minimal_counterfactuals, tolerance=0.1):
        
        # compute distance of current minimal_counterefactuals from radius 1.0
        #current_counterfactual_distance_from_manifold = torch.abs((torch.pow(minimal_counterfactuals, 2).sum(dim=-1) - 1.0))
        
        for i in range(len(counterfactuals)):  

            # compute classifier  for all the counterfactuals for each point
            new_counterfactuals_confidence = classifier(counterfactuals[i]).softmax(dim=-1)[:, target_classes[i]]
            
            
            # check if new counterfactuals satisfy the confidence constraint
            new_confidence_satisfied = new_counterfactuals_confidence > target_confidence
            
            new_confidence_satisfied_indices = torch.nonzero(new_counterfactuals_confidence > target_confidence)
            
            current_confidence_satisfied = classifier(minimal_counterfactuals[i:i+1]).softmax(dim=-1)[0][target_classes[i]].item() > target_confidence
            
            # if current counterfactual satisfies confidence and tolerance, maintain status quo 
          
            if new_confidence_satisfied_indices.nelement() != 0:
                print('new confidence satisfied')
                print(new_confidence_satisfied_indices[0].item())
                minimal_counterfactuals[i] = counterfactuals[i][new_confidence_satisfied_indices[0].item()]
                
            else:
                print('neither current nor new confidence satisfied')
                minimal_counterfactuals[i] = counterfactuals[i][-1]

        return minimal_counterfactuals
        
        
    def edit(
        self,
        x_in: torch.Tensor,
        target_confidence_goal: float,
        target_classes: torch.Tensor,
        classifier: nn.Module
    ) -> Tuple[list[torch.Tensor], torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """Search counterfactuals for ``x_in`` over all configured noise depths and gradient scales.

        For every iteration, noise depth (``config['noise_steps_for_counterfactuals']``)
        and gradient scale (``config['grad_scales']``) the current counterfactuals are
        refined by guided diffusion and reduced to one candidate per sample. After each
        refinement the result is drawn onto the current matplotlib axes and titled with
        the flip rate (fraction of samples whose predicted class differs from the original).

        Side effects: sets ``self.original_sample`` and ``self.counterfactuals``.

        Args:
            x_in: Batch of samples to edit.
            target_confidence_goal: Confidence a candidate must exceed to be accepted.
            target_classes: Target class per sample.
            classifier: Classifier returning logits.

        Returns:
            ``(counterfactuals, diff, confidences, originals)``: a list of counterfactual
            rows, the tensor ``x_in - counterfactuals``, a list with the final
            target-class confidence per sample, and a list of the original rows.
        """
        self.original_sample = x_in

        scales = self.config['grad_scales']
        noise_steps = self.config['noise_steps_for_counterfactuals']

        minimal_counterfactuals = x_in.clone()

        # The prediction for the original input never changes; compute it once.
        with torch.no_grad():
            original_pred = classifier(x_in).argmax(dim=-1)

        for it in range(self.config['num_iterations']):
            for steps in noise_steps:
                for s in scales:
                    counterfactuals, _, _, _ = self.sample_counterfactual_ddpm(clean_batch=minimal_counterfactuals, model=self.model, classifier=classifier, num_noise_steps=steps, target_classes=target_classes, classifier_grad_weight=s)

                    minimal_counterfactuals = self.discard_counterfactuals(counterfactuals=counterfactuals, classifier=classifier, target_confidence=target_confidence_goal, target_classes=target_classes, minimal_counterfactuals=minimal_counterfactuals)
                    self.counterfactuals = minimal_counterfactuals

                    with torch.no_grad():
                        current_pred = classifier(minimal_counterfactuals).argmax(dim=-1)
                    flip_rate = (current_pred != original_pred).float().mean().item()

                    self.plot_counterfactuals()
                    plt.title(f'Noise Steps: {steps}, Gradient Scale: {s}, Flip Rate: {round(flip_rate, 3)}')

        list_counterfactuals = list(minimal_counterfactuals)
        diff_latent = x_in - minimal_counterfactuals

        confidences = classifier(minimal_counterfactuals).softmax(dim=-1)
        y_target_end_confidence = [confidences[i][target_classes[i]].detach() for i in range(len(minimal_counterfactuals))]
        x_list = list(x_in)

        return list_counterfactuals, diff_latent, y_target_end_confidence, x_list
    
    
    def plot_counterfactuals(self):
        """Draw the dataset, the original samples, and arrows to their counterfactuals.

        Dataset points are coloured by their 'Target' column ('lightcyan' where it
        is 0, 'lightgray' otherwise). Original samples are green, counterfactuals
        red, and a blue arrow connects each pair. Everything is drawn on the
        current matplotlib axes.

        Reads ``self.dataset``, ``self.original_sample`` and ``self.counterfactuals``;
        stores the dataset as one tensor in ``self.data``.
        """
        # Use the dataset this adaptor was built with, not a notebook-level global.
        dataset = self.dataset
        data = torch.zeros([len(dataset.data), len(dataset.attributes)], dtype=torch.float16)
        for idx, key in enumerate(dataset.data):
            data[idx] = dataset.data[key]
        self.data = data

        # One colour per point (the class of each row), not the first row's colour for all.
        is_class_zero = (data[:, self.target_idx[0]] == 0).numpy()
        point_colors = np.where(is_class_zero, 'lightcyan', 'lightgray')
        plt.scatter(
            data[:, self.input_idx[0]].float().numpy(),
            data[:, self.input_idx[1]].float().numpy(),
            c=point_colors,
        )

        starts = self.original_sample.detach()
        ends = self.counterfactuals.detach()
        # One scatter call per group keeps a single 'start' / 'end' legend entry.
        plt.scatter(starts[:, 0], starts[:, 1], color='green', label='start')
        plt.scatter(ends[:, 0], ends[:, 1], color='red', label='end')
        for start, end in zip(starts, ends):
            plt.arrow(
                float(start[0]), float(start[1]),
                float(end[0] - start[0]),
                float(end[1] - start[1]),
                head_width=0.05, head_length=0.05, fc='blue', ec='blue',
            )
    
    
        

In [6]:
from peal.configs.adaptors.adaptor_template import AdaptorConfig
# Load the diffusion adaptor configuration as an AdaptorConfig.
adaptor_config = load_yaml_config('<PEAL_BASE>/configs/adaptors/circle_diffusion.yaml', AdaptorConfig)
# NOTE(review): load_yaml_config is applied a second time to the already loaded config object — confirm this is intended.
# NOTE(review): `train_set.dataset` is the dataset underlying the random split, i.e. presumably the full data rather than only the training part.
adaptor = CircleDiffusionAdaptor(config=load_yaml_config(adaptor_config).generator, dataset=train_set.dataset, model_dir=None)
# Without mode='train' an existing checkpoint is loaded instead of training.
adaptor.train_and_load_diffusion(model_name='diffusion.pt')  # pass mode='train' to (re)train
# Classifiers stored by earlier peal runs (poisened = student, unpoisened = teacher).
# NOTE(review): torch.load unpickles arbitrary objects; only load files from a trusted source.
student = torch.load('peal_runs/artificial_circle_poisened_classifier/model.cpl')
teacher = torch.load('peal_runs/artificial_circle_unpoisened_classifier/model.cpl')


INFO:root:Model found with path peal_runs/artificial_circle_diffusion_generator/diffusion.pt


In [None]:
# Run the reverse process for 100 samples; `series` holds the initial noise plus the batch after every denoising step.
series = adaptor.sample_ddpm(adaptor.model, 100)

In [None]:
# series[0] is pure noise and series[-1] the final sample, so the number of reverse steps is len - 1.
num_steps = len(series) - 1

# Snapshots at 0%, 80%, 90%, 94% and 100% of the reverse process
# (indices 0, 400, 450, 470 and 500 for a 500-step schedule).
indices = np.unique(np.round(np.array([0.0, 0.8, 0.9, 0.94, 1.0]) * num_steps).astype(int))

fig, axs = plt.subplots(1, len(indices), figsize=(20, 4), squeeze=False)

for ax, idx in zip(axs[0], indices):
    ax.scatter(series[idx][:, 0], series[idx][:, 1], s=14.0, color='teal')
    if idx == num_steps:
        ax.set_title('Final')
    else:
        ax.set_title(f'Timestep {num_steps - idx}')
plt.suptitle('Diffusion Samples', fontsize=20)
plt.tight_layout()
plt.show()

# Diffusion Samples

In [None]:

# Draw 500 fully denoised samples and show them as a small scatter plot.
samples = adaptor.sample_x(batch_size=500)
fig = plt.figure(figsize=(4, 4))
ax = fig.gca()
ax.scatter(samples[:, 0], samples[:, 1], s=2)
plt.show()

In [None]:
def circle_distance(samples, radius=1.0):
    """Mean squared deviation of the samples' squared norm from the squared radius.

    Args:
        samples: Tensor whose last dimension holds the coordinates of each point.
        radius: Radius of the reference circle (default: unit circle).

    Returns:
        Scalar tensor; 0 when every sample lies exactly on the circle.
    """
    squared_norm = samples.pow(2).sum(dim=-1)
    return (squared_norm - radius ** 2).pow(2).mean()

def angle_cdf(samples):
    """Return the polar angles of 2-D samples, sorted ascending, in ``[0, 2*pi)``.

    ``atan2`` resolves the quadrant directly, so points lying on a coordinate
    axis are kept and no division by a zero x-coordinate can occur.

    Args:
        samples: Tensor of shape (n, 2) with x in column 0 and y in column 1.

    Returns:
        1-D tensor with one angle per sample.
    """
    angles = torch.atan2(samples[:, 1], samples[:, 0])
    # atan2 yields (-pi, pi]; shift the negative half onto (pi, 2*pi).
    angles = torch.remainder(angles, 2 * torch.pi)
    thetas, _ = angles.sort(dim=-1)
    return thetas

def circle_ks(samples):
    """Kolmogorov-Smirnov statistic of the sample angles against Uniform(0, 2*pi).

    The empirical CDF jumps at every sorted angle, so the largest deviation from
    the uniform CDF is checked on both sides of each jump.

    Args:
        samples: Tensor of shape (n, 2).

    Returns:
        Scalar tensor with the KS statistic.

    Raises:
        ValueError: If no angle could be computed from ``samples``.
    """
    sample_thetas = angle_cdf(samples)
    # Use the number of angles actually returned, which is what the ECDF is built from.
    n = len(sample_thetas)
    if n == 0:
        raise ValueError('circle_ks needs at least one sample angle')

    dist = torch.distributions.uniform.Uniform(0, 2 * torch.pi)
    true_cdf = dist.cdf(sample_thetas)

    ranks = torch.arange(n, dtype=true_cdf.dtype, device=true_cdf.device)
    d_plus = ((ranks + 1) / n - true_cdf).max()   # ECDF just after each jump
    d_minus = (true_cdf - ranks / n).max()        # ECDF just before each jump
    return torch.max(d_plus, d_minus)

In [None]:
# Sample sizes to evaluate; the builtin `int` replaces the `np.int` alias, which no longer exists in current NumPy.
sizes = np.linspace(5, 30000, 50, dtype=int)
n_repeats = 5  # average each metric over this many independent draws
distances = []
ks_stats = []
for size in sizes:
    dist = 0.0
    ks = 0.0
    for it in range(n_repeats):
        samples = adaptor.sample_x(batch_size=int(size))
        dist += float(circle_distance(samples))
        ks += float(circle_ks(samples))
    distances.append(dist / n_repeats)
    ks_stats.append(ks / n_repeats)

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm, trange
from typing import Tuple
import logging
from torch.utils.data import DataLoader
import math
from peal.generators.interfaces import EditCapableGenerator
%matplotlib inline

logging.getLogger().setLevel(logging.INFO)

class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder and decoder are stacks of (LazyLinear, SiLU) blocks followed by a
    LayerNorm; two linear heads produce the latent mean and log-variance, and a
    final linear layer maps the decoder output back to ``input_dim``.

    Args:
        input_dim: Width of the data vectors.
        encoder_dims: Hidden widths of the encoder (at least one).
        decoder_dims: Hidden widths of the decoder (at least one).
        latent_dim: Width of the latent space.

    Raises:
        ValueError: If ``encoder_dims`` or ``decoder_dims`` is empty.
    """

    def __init__(self, input_dim: int, encoder_dims: list, decoder_dims: list, latent_dim: int):
        super(VAE, self).__init__()
        if len(encoder_dims) == 0 or len(decoder_dims) == 0:
            raise ValueError('encoder_dims and decoder_dims must each contain at least one layer width')

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential()
        for i, dim in enumerate(encoder_dims):
            self.encoder.add_module(f'layer_{i + 1}', nn.Sequential(nn.LazyLinear(dim), nn.SiLU()))
        # The norm acts on the output of the last encoder layer.
        self.encoder.add_module('norm_encoder', nn.LayerNorm(encoder_dims[-1]))

        self.latent_mean = nn.Sequential(nn.LazyLinear(latent_dim))
        self.latent_logvar = nn.Sequential(nn.LazyLinear(latent_dim))

        self.decoder = nn.Sequential()
        for i, dim in enumerate(decoder_dims):
            self.decoder.add_module(f'layer_{i + 1}', nn.Sequential(nn.LazyLinear(dim), nn.SiLU()))
        self.decoder.add_module('norm_decoder', nn.LayerNorm(decoder_dims[-1]))
        self.decoder.add_module('to_original', nn.Sequential(nn.LazyLinear(input_dim)))

    def reparameterize(self, mean, logvar):
        """Sample ``z`` around ``mean`` in training mode; return ``mean`` itself in eval mode."""
        if self.training:
            z = mean + torch.exp(logvar * 0.5) * torch.randn_like(logvar)
        else:
            z = mean
        return z

    def encode(self, x):
        """Return the latent mean and log-variance for ``x``."""
        x = self.encoder(x)
        mean, logvar = self.latent_mean(x), self.latent_logvar(x)
        return mean, logvar

    def decode(self, z):
        """Map latent vectors ``z`` back to data space."""
        x_hat = self.decoder(z)
        return x_hat

    def sample(self, num_samples):
        """Decode ``num_samples`` standard-normal latent vectors; the result carries no gradient."""
        # No autograd graph is needed for pure sampling.
        with torch.no_grad():
            eps = torch.randn([num_samples, self.latent_dim])
            return self.decode(eps).detach()

    def forward(self, x):
        """Return the reconstruction of ``x`` together with the latent mean and log-variance."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar


class CircleVAEAdaptor(EditCapableGenerator):
    def __init__(self, config, dataset, model_dir=None):
        """Set up the VAE adaptor and load its checkpoint if one exists.

        Args:
            config: Mapping-like configuration. Keys read here: 'input_dim',
                'encoder_dims', 'decoder_dims', 'latent_dim', 'model_name' and,
                when ``model_dir`` is None, 'base_path'.
            dataset: Object exposing ``attributes``; every attribute other than
                'Confounder' and 'Target' is treated as a model input column.
            model_dir: Directory for checkpoints; defaults to ``config['base_path']``.
        """
        super(CircleVAEAdaptor, self).__init__()
        self.config = config
        self.dataset = dataset
        self.model_dir = model_dir if model_dir is not None else config['base_path']

        self.input_dim = config['input_dim']
        self.encoder_dims = config['encoder_dims']
        self.decoder_dims = config['decoder_dims']
        self.latent_dim = config['latent_dim']

        # Column indices are resolved before the model is set up, because the
        # training path of train_and_load_vae reads self.input_idx.
        self.input_idx = [
            idx
            for idx, element in enumerate(self.dataset.attributes)
            if element not in ["Confounder", "Target"]
        ]
        self.target_idx = [
            idx
            for idx, element in enumerate(self.dataset.attributes)
            if element == "Target"
        ]

        self.train_and_load_vae(model_name=config['model_name'])

    def train_and_load_vae(self, model_name='vae.pt', mode=None):
        """Create the VAE and either load or train its weights.

        The resulting model is stored in ``self.model``; the checkpoint location is
        stored in ``self.model_path``.

        Args:
            model_name: Checkpoint file name inside ``self.model_dir``.
            mode: ``'train'`` trains from scratch and overwrites the checkpoint.
                Any other value loads the checkpoint if it exists; if it does not,
                a message is logged and ``self.model`` is left untrained.
        """
        self.model_path = os.path.join(self.model_dir, model_name)
        model = VAE(input_dim=self.input_dim, encoder_dims=self.encoder_dims, decoder_dims=self.decoder_dims,
                    latent_dim=self.latent_dim)

        if mode == 'train':
            logging.info(
                f'Training model with path {self.model_path}'
            )
        elif os.path.isfile(self.model_path):
            model.load_state_dict(torch.load(self.model_path))
            logging.info(f'Model found with path {self.model_path}')
        else:
            logging.info('Model not found. Please run train_and_load_vae method and set its argument mode="train" ')

        def VAELoss(x, x_hat, mean, logvar, beta=None):
            """Reconstruction MSE plus ``beta`` times the KL divergence to a standard normal."""
            if beta is None:
                # Read lazily so that merely loading a checkpoint does not require 'beta'.
                beta = self.config['beta']
            kl_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim=1), dim=0)
            reconstruction_loss = F.mse_loss(x, x_hat)
            return reconstruction_loss + beta * kl_loss

        def train(model, data_loader, epochs):
            """Optimise ``model`` with Adam and print the mean batch loss of every epoch."""
            model.train()
            optimizer = optim.Adam(model.parameters(), lr=1e-3)

            for epoch in tqdm(range(epochs)):
                total_loss = 0.0
                for x, y in data_loader:
                    inputs = x[:, self.input_idx]
                    optimizer.zero_grad()
                    x_hat, mean, logvar = model(inputs)
                    loss = VAELoss(inputs, x_hat=x_hat, mean=mean, logvar=logvar)
                    total_loss += loss.item()
                    loss.backward()
                    optimizer.step()

                # Report the epoch average rather than the loss of the last batch only.
                print(f'Epoch: {epoch}, Loss: {total_loss / max(len(data_loader), 1)}')

        if mode == 'train':
            dataloader = DataLoader(self.dataset, batch_size=self.config['batch_size'], shuffle=True)
            train(model=model, data_loader=dataloader, epochs=self.config['num_epochs'])
            # Make sure the target directory exists before writing the checkpoint.
            os.makedirs(self.model_dir, exist_ok=True)
            torch.save(model.state_dict(), self.model_path)

        self.model = model

    def sample_x(self, batch_size=1):
        """Draw ``batch_size`` samples from ``self.model`` and return them detached from autograd."""
        generated = self.model.sample(num_samples=batch_size)
        return generated.detach()

    def DIVE(self, clean_batch, target_classes, model, classifier):
        """Optimise a latent perturbation so the decoded sample is classified as the target.

        The batch is encoded with ``model``; a small random offset ``epsilon`` on the
        latent code is then optimised with Adam for ``config['num_iterations']`` steps.
        The loss is the classifier cross-entropy towards ``target_classes`` plus a
        reconstruction term (``config['reconstruction_weight']``) and an L1 term on
        the latent offset (``config['lasso_weight']``).

        Args:
            clean_batch: 2-D batch of samples to edit.
            target_classes: Target labels passed to ``F.cross_entropy``.
            model: VAE providing ``encode``, ``reparameterize`` and ``decode``.
            classifier: Classifier returning logits; it is put in eval mode.

        Returns:
            ``(clean_batch, x, latents, decoded)``: the input batch, the decoded batch
            after the last step, and the stacked trajectories of the latent code and
            of the decoded sample for the FIRST batch element only (both start with
            the unperturbed values).
        """
        classifier.eval()

        lasso_weight = self.config['lasso_weight']
        reconstruction_weight = self.config['reconstruction_weight']

        # Use the model passed in for every step, including the encoding.
        mean, logvar = model.encode(clean_batch)
        z = model.reparameterize(mean, logvar).detach()  # no grads required for latents

        epsilon = torch.randn_like(z, requires_grad=True)
        epsilon.data *= 0.01
        optimizer = torch.optim.Adam([epsilon], lr=self.config['lr_counterfactual'], weight_decay=0)

        def classifier_criterion(decoded_batch):
            return F.cross_entropy(classifier(decoded_batch), target_classes)

        list_z = [z[0]]
        list_counterfactuals = [clean_batch[0]]

        # Defined up front so that zero optimisation steps still return a decoded batch.
        with torch.no_grad():
            x = model.decode(z + epsilon)

        for it in range(self.config['num_iterations']):
            optimizer.zero_grad()

            z_perturbed = z + epsilon
            decoded = model.decode(z_perturbed)

            loss_attack = classifier_criterion(decoded)
            # NOTE(review): this takes |mean(diff)| per sample, so errors of opposite
            # sign cancel — confirm whether mean(|diff|) was intended.
            recon_regularizer = reconstruction_weight * torch.abs((clean_batch - decoded).mean(dim=-1)).sum()
            lasso_regularizer = lasso_weight * (torch.abs(z_perturbed - z)).sum()

            loss = loss_attack + recon_regularizer + lasso_regularizer
            loss.backward()
            optimizer.step()

            # Record the state after the update.
            with torch.no_grad():
                z_updated = z + epsilon
                x = model.decode(z_updated)
            list_z.append(z_updated[0])
            list_counterfactuals.append(x[0])

        return clean_batch, x, torch.stack(list_z), torch.stack(list_counterfactuals)
    
    def discard_counterfactuals(self, counterfactuals, classifier, target_classes, target_confidence, minimal_counterfactuals):
        """Pick, per sample, the first candidate that reaches the target confidence.

        For sample ``i`` the candidates ``counterfactuals[i]`` are classified in
        order; the first one whose softmax probability for ``target_classes[i]``
        exceeds ``target_confidence`` replaces ``minimal_counterfactuals[i]``. If
        none qualifies, the last candidate is used instead.

        Args:
            counterfactuals: Tensor indexed as ``[sample][candidate]``.
            classifier: Callable returning logits for a batch of candidates.
            target_classes: Target class index per sample.
            target_confidence: Confidence threshold (exclusive).
            minimal_counterfactuals: Tensor updated in place, one row per sample.

        Returns:
            ``minimal_counterfactuals`` (the same object, modified in place).
        """
        # Only confidences are compared, so no autograd graph is needed.
        with torch.no_grad():
            for i in range(len(counterfactuals)):
                confidence = classifier(counterfactuals[i]).softmax(dim=-1)[:, target_classes[i]]
                satisfied_indices = torch.nonzero(confidence > target_confidence)

                if satisfied_indices.nelement() != 0:
                    minimal_counterfactuals[i] = counterfactuals[i][satisfied_indices[0].item()]
                else:
                    minimal_counterfactuals[i] = counterfactuals[i][-1]

        return minimal_counterfactuals
             
            

    def edit(
            self,
            x_in: torch.Tensor,
            target_confidence_goal: float,
            target_classes: torch.Tensor,
            classifier: torch.nn.Module,
            **kwargs,
    ) -> Tuple[
        list[torch.Tensor], torch.Tensor, torch.Tensor, list[torch.Tensor]
    ]:
        """
        Edit a batch of samples towards their target classes, one sample at a time.
        Args:
            x_in: Batch of samples to edit.
            target_confidence_goal: Target confidence goal (currently not used to stop the search).
            target_classes: Target classes for each sample in the batch.
            classifier: Classifier to use for confidence estimation.
            **kwargs: Additional keyword arguments (ignored).
        Returns:
            Tuple of (list of counterfactual rows, tensor ``x_in - counterfactuals``,
            tensor with the final target-class confidence per sample, list of the original rows).
        """

        list_counterfactuals = torch.zeros_like(x_in)
        y_target_end_confidence = torch.zeros([x_in.shape[0]])
        for i in range(len(x_in)):
            _, counterfactual, _, _ = self.DIVE(x_in[i:i + 1], target_classes[i:i + 1], self.model, classifier)
            # Only the confidence value is needed, so skip building an autograd graph.
            with torch.no_grad():
                current_confidence = classifier(counterfactual).softmax(dim=-1)[0][target_classes[i].item()].item()
            y_target_end_confidence[i] = current_confidence
            list_counterfactuals[i] = counterfactual

        diff_latent = x_in - list_counterfactuals

        x_list = list(x_in)

        return list(list_counterfactuals), diff_latent, y_target_end_confidence, x_list

def circle_distance(samples, radius=1.0):
    """Mean squared deviation of the samples' squared norm from the squared radius.

    Args:
        samples: Tensor whose last dimension holds the coordinates of each point.
        radius: Radius of the reference circle (default: unit circle).

    Returns:
        Scalar tensor; 0 when every sample lies exactly on the circle.
    """
    squared_norm = samples.pow(2).sum(dim=-1)
    return (squared_norm - radius ** 2).pow(2).mean()

def angle_cdf(samples):
    """Return the polar angle of every 2-D sample, sorted ascending.

    Angles are measured counter-clockwise from the positive x axis and mapped
    to [0, 2*pi]. Using ``atan2`` keeps samples that lie exactly on an axis
    (x == 0 or y == 0); the earlier quadrant-mask version silently dropped
    them, so its output could be shorter than its input.

    Args:
        samples: Tensor indexed as ``samples[:, 0]`` (x) and ``samples[:, 1]`` (y).
    Returns:
        1-D tensor with one sorted angle per sample.
    """
    signed_angles = torch.atan2(samples[:, 1], samples[:, 0])  # in [-pi, pi]
    angles = torch.remainder(signed_angles, 2 * torch.pi)      # wrap negatives
    return angles.sort(dim=-1).values

def circle_ks(samples):
    """Kolmogorov-Smirnov statistic of the sample angles against Uniform(0, 2*pi).

    The empirical CDF is a step function, so the largest gap to the reference
    CDF can occur just before or just after each step. Both sides are checked:
    ``D+ = max(i/n - F(theta_i))`` and ``D- = max(F(theta_i) - (i-1)/n)``.

    Args:
        samples: 2-D points; their sorted angles are obtained via ``angle_cdf``.
    Returns:
        Scalar tensor holding ``max(D+, D-)``.
    Raises:
        ValueError: if ``angle_cdf`` yields no angles.
    """
    sample_thetas = angle_cdf(samples)
    # Size the ECDF from the angles actually returned, not from the raw input.
    n = len(sample_thetas)
    if n == 0:
        raise ValueError("circle_ks needs at least one sample angle")

    uniform = torch.distributions.uniform.Uniform(0, 2 * torch.pi)
    model_cdf = uniform.cdf(sample_thetas)

    ranks = torch.arange(1, n + 1, dtype=model_cdf.dtype, device=model_cdf.device)
    d_plus = torch.max(ranks / n - model_cdf)
    d_minus = torch.max(model_cdf - (ranks - 1) / n)
    return torch.max(d_plus, d_minus)

adaptor_config = load_yaml_config('<PEAL_BASE>/configs/adaptors/circle_vae.yaml')
adaptor = CircleVAEAdaptor(config=adaptor_config.generator, dataset=train_set.dataset, model_dir=None)
adaptor.train_and_load_vae(model_name='vae_beta_0.05.pt')#, mode='train')

# `np.int` was only an alias of the builtin `int` and no longer exists in
# recent NumPy releases; the builtin gives the same integer sample sizes.
sizes = np.linspace(5, 30000, 50, dtype=int)
distances_vae = []
ks_vae = []
n_repeats = 5  # independent draws averaged per sample size
for size in sizes:
    dist = 0.0
    ks = 0.0
    for it in range(n_repeats):
        samples = adaptor.sample_x(batch_size=size)
        dist += circle_distance(samples)
        ks += circle_ks(samples)
    distances_vae.append(dist / n_repeats)
    ks_vae.append(ks / n_repeats)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(20, 7))
manifold_ax, diversity_ax = axs
sample_sizes = np.array(sizes)

# Left panel: circle-distance metric per sample size.
manifold_ax.plot(sample_sizes, distances, label='DDPM')
manifold_ax.plot(sample_sizes, distances_vae, label='VAE')
manifold_ax.set_title('Closeness to the Manifold', fontsize='15')

# Right panel: KS statistics; the two beta-VAE curves are plotted negated.
diversity_ax.plot(sample_sizes, np.array(ks_stats), label='DDPM')
diversity_ax.plot(sample_sizes, np.array(ks_vae), label='VAE')
for curve, name in ((ks_bvae, 'VAE_beta'), (ks_bvae5, 'VAE_beta5')):
    diversity_ax.plot(sample_sizes, -1*np.array(curve), label=name)
diversity_ax.set_title('Diversity', fontsize='15')

for panel in (manifold_ax, diversity_ax):
    panel.legend(fontsize='12')
plt.show()

In [13]:
from peal.adaptors.counterfactual_knowledge_distillation import (
    CounterfactualKnowledgeDistillation,
)
from peal.global_utils import load_yaml_config, add_class_arguments, integrate_arguments

In [14]:
# Copy every entry of dataset.data into one dense float16 tensor
# with one row per entry and one column per attribute.
n_rows, n_cols = len(dataset.data), len(dataset.attributes)
data = torch.zeros([n_rows, n_cols], dtype=torch.float16)
for row, key in enumerate(dataset.data):
    data[row] = dataset.data[key]


In [16]:
cfkd = CounterfactualKnowledgeDistillation(adaptor_config=adaptor_config, overwrite=True)

INFO:root:Model found with path peal_runs/artificial_circle_diffusion_generator/diffusion.pt


In [None]:
x, y = next(iter(cfkd.test_dataloader))

In [None]:
list_counterfactuals, latent_diff, confidence, list_x = cfkd.generator.edit(x_in=x[:10], target_confidence_goal=0.7, source_classes=y[:10], target_classes=1-y[:10].type(torch.long), classifier=cfkd.student)

In [None]:
plt.scatter(data[:,0], data[:,1], color='lightgray')
# Green = original sample, red = its counterfactual, blue arrow = the move between them.
for i in range(len(list_counterfactuals)):
    start, end = list_x[i], list_counterfactuals[i]
    plt.scatter(start[0], start[1], color='green')
    plt.scatter(end[0], end[1], color='red')
    plt.arrow(
        start[0], start[1],
        end[0] - start[0],
        end[1] - start[1],
        head_width=0.05, head_length=0.05, fc='blue', ec='blue',
    )

In [None]:
cfkd.run()

Adaptor Config: <peal.configs.adaptors.adaptor_template.AdaptorConfig object at 0x7f671b85bbb0>
Create base_dir in: peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100
Generator performance: {}




  0%|                                                                                            | 0/26 [00:00<?, ?it/s][A[A

test_correct: 1.0, it: 0:   4%|██▏                                                       | 1/26 [00:00<00:01, 13.64it/s][A[A


  0%|                                                                                            | 0/31 [00:00<?, ?it/s][A[A


  0%|                                                                                            | 0/60 [00:00<?, ?it/s][A[A[A


  5%|████▏                                                                               | 3/60 [00:00<00:02, 22.90it/s][A[A[A


 12%|█████████▊                                                                          | 7/60 [00:00<00:01, 29.00it/s][A[A[A


 18%|███████████████▏                                                                   | 11/60 [00:00<00:01, 31.56it/s][A[A[A


 25%|████████████████████▊                                                          

 20%|████████████████▌                                                                  | 12/60 [00:00<00:01, 36.67it/s][A[A[A


 28%|███████████████████████▌                                                           | 17/60 [00:00<00:01, 39.23it/s][A[A[A


 37%|██████████████████████████████▍                                                    | 22/60 [00:00<00:00, 42.12it/s][A[A[A


 48%|████████████████████████████████████████                                           | 29/60 [00:00<00:00, 50.18it/s][A[A[A


 62%|███████████████████████████████████████████████████▏                               | 37/60 [00:00<00:00, 58.57it/s][A[A[A


100%|███████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 67.24it/s][A[A[A



  0%|                                                                                            | 0/60 [00:00<?, ?it/s][A[A[A


  7%|█████▌                                                                

> [0;32m/mnt/c/users/fahad/thesis/code/peal_private/peal/data/dataset_utils.py[0m(95)[0;36mparse_csv[0;34m()[0m
[0;32m     93 [0;31m    [0;32mif[0m [0mnp[0m[0;34m.[0m[0mrandom[0m[0;34m.[0m[0mrand[0m[0;34m([0m[0;34m)[0m [0;34m<[0m [0;36m0.5[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     94 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 95 [0;31m    [0;32mif[0m [0mlen[0m[0;34m([0m[0mconfig[0m[0;34m.[0m[0mconfounding_factors[0m[0;34m)[0m [0;34m==[0m [0;36m2[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     96 [0;31m[0;34m[0m[0m
[0m[0;32m     97 [0;31m        [0;32mdef[0m [0mextract_instances_tensor_confounder[0m[0;34m([0m[0midx[0m[0;34m,[0m [0mline[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


In [None]:
data = {
    'x_list': [torch.tensor([0.9983, 0.0577]), torch.tensor([0.6026, 0.7980]), torch.tensor([0.7544, 0.6564]), torch.tensor([-0.3068, -0.9518]), torch.tensor([0.4866, 0.8736]), torch.tensor([-0.4967, -0.8679]), torch.tensor([0.0115, 0.9999]), torch.tensor([0.9808, 0.1951]), torch.tensor([0.1724, 0.9850]), torch.tensor([-0.9552, -0.2958]), torch.tensor([0.9983, 0.0577]), torch.tensor([0.8502, 0.5264]), torch.tensor([-0.5264, -0.8502]), torch.tensor([-0.8792, -0.4765]), torch.tensor([-0.4039, -0.9148]), torch.tensor([-0.4663, -0.8846]), torch.tensor([0.6651, 0.7467]), torch.tensor([-0.4866, -0.8736]), torch.tensor([0.1837, 0.9830]), torch.tensor([0.9282, 0.3720]), torch.tensor([0.7910, 0.6118]), torch.tensor([0.9904, 0.1382]), torch.tensor([0.2064, 0.9785]), torch.tensor([0.4249, 0.9052]), torch.tensor([0.6026, 0.7980]), torch.tensor([0.9003, 0.4354]), torch.tensor([0.9366, 0.3504]), torch.tensor([0.6651, 0.7467]), torch.tensor([0.9052, 0.4249]), torch.tensor([0.1038, 0.9946]), torch.tensor([0.9967, 0.0808]), torch.tensor([-0.9933, -0.1152]), torch.tensor([-0.5166, -0.8562]), torch.tensor([0.8184, 0.5746]), torch.tensor([-0.3612, -0.9325]), torch.tensor([0.9760, 0.2177]), torch.tensor([0.7839, 0.6209]), torch.tensor([0.2402, 0.9707]), torch.tensor([-1.8370e-16, -1.0000e+00]), torch.tensor([0.8900, 0.4560]), torch.tensor([0.4866, 0.8736]), torch.tensor([-0.8679, -0.4967]), torch.tensor([0.7693, 0.6388]), torch.tensor([-0.0231, -0.9997]), torch.tensor([-0.2402, -0.9707]), torch.tensor([-0.9482, -0.3178]), torch.tensor([-0.9870, -0.1610]), torch.tensor([-0.1610, -0.9870]), torch.tensor([0.9785, 0.2064]), torch.tensor([0.9994, 0.0346]), torch.tensor([-0.9808, -0.1951]), torch.tensor([0.9850, 0.1724]), torch.tensor([-0.1382, -0.9904]), torch.tensor([0.3504, 0.9366]), torch.tensor([-0.4457, -0.8952]), torch.tensor([0.0231, 0.9997])],
    
    'x_counterfactual_list': [torch.tensor([-0.3278, -0.9271]), torch.tensor([ 0.6094, -0.7962]), torch.tensor([ 0.4023, -0.9154]), torch.tensor([ 0.8901, -0.4527]), torch.tensor([ 0.5666, -0.8311]), torch.tensor([ 0.8767, -0.4678]), torch.tensor([-0.7687,  0.6222]), torch.tensor([ 0.6639, -0.7749]), torch.tensor([-0.8234,  0.5644]), torch.tensor([ 0.9490, -0.3264]), torch.tensor([ 0.6660, -0.7403]), torch.tensor([ 0.6616, -0.7394]), torch.tensor([ 0.8022, -0.6140]), torch.tensor([ 0.7779, -0.6191]), torch.tensor([ 0.8159, -0.5691]), torch.tensor([ 0.7458, -0.6825]), torch.tensor([ 0.5259, -0.8478]), torch.tensor([ 0.7926, -0.6306]), torch.tensor([-0.8557,  0.5258]), torch.tensor([ 0.6206, -0.7707]), torch.tensor([-0.7720,  0.6367]), torch.tensor([ 0.3622, -0.9374]), torch.tensor([-0.8194,  0.5773]), torch.tensor([ 0.5147, -0.8544]), torch.tensor([-0.9915,  0.0854]), torch.tensor([ 0.6768, -0.7524]), torch.tensor([ 0.5410, -0.8401]), torch.tensor([-0.7550,  0.6432]), torch.tensor([ 0.5703, -0.8203]), torch.tensor([-0.9732, -0.2271]), torch.tensor([ 0.6819, -0.7523]), torch.tensor([-0.5816,  0.8087]), torch.tensor([ 0.8307, -0.5553]), torch.tensor([ 0.4141, -0.9123]), torch.tensor([ 0.7574, -0.6468]), torch.tensor([ 0.6505, -0.7626]), torch.tensor([ 0.6310, -0.7757]), torch.tensor([-0.9183,  0.3827]), torch.tensor([ 0.7822, -0.6366]), torch.tensor([ 0.3387, -0.9406]), torch.tensor([-0.3059,  0.1437]), torch.tensor([-0.6234,  0.7672]), torch.tensor([-0.8768,  0.4734]), torch.tensor([ 0.8038, -0.6209]), torch.tensor([ 0.7816, -0.6295]), torch.tensor([ 0.7685, -0.6677]), torch.tensor([-0.6739,  0.7387]), torch.tensor([ 0.6678, -0.7440]), torch.tensor([-0.8223,  0.5694]), torch.tensor([ 0.5737, -0.8168]), torch.tensor([ 0.4579, -0.8889]), torch.tensor([ 0.5802, -0.8145]), torch.tensor([ 0.8193, -0.5774]), torch.tensor([ 0.7992, -0.6015]), torch.tensor([ 0.9552, -0.2958])],
}

In [None]:
adaptor_config.data

In [None]:
data = torch.zeros([len(dataset.data),len(dataset.attributes)], dtype=torch.float16)
for idx, key in enumerate(dataset.data):
    data[idx] = dataset.data[key]


In [None]:
batch_out = {
    'x_list': [torch.tensor([-0.9983, -0.0577]), torch.tensor([-0.9239, -0.3827]), torch.tensor([0.2625, 0.9649]), torch.tensor([-0.2290, -0.9734]), torch.tensor([-0.4249, -0.9052]), torch.tensor([-0.1152, -0.9933]), torch.tensor([0.4967, 0.8679]), torch.tensor([-0.8621, -0.5067]), torch.tensor([0.9997, 0.0231]), torch.tensor([-0.2064, -0.9785])],
    'y_list': torch.tensor([0., 0., 1., 0., 0., 0., 1., 0., 1., 0.]),
    'y_source_list': torch.tensor([0, 0, 1, 0, 0, 0, 1, 0, 1, 0]),
    'y_target_list': torch.tensor([1, 1, 0, 1, 1, 1, 0, 1, 0, 1]),
    'y_target_start_confidence_list': torch.tensor([6.6621e-05, 4.2620e-05, 4.5146e-05, 6.3567e-05, 4.7831e-05, 7.9660e-05, 3.2440e-05, 3.9745e-05, 8.9317e-05, 6.6160e-05]),
    'x_counterfactual_list': [torch.tensor([-0.4479, 0.8914]), torch.tensor([-0.8168, -0.5625]), torch.tensor([-0.7237, 0.6845]), torch.tensor([0.7815, -0.6335]), torch.tensor([0.8844, -0.4574]), torch.tensor([0.9444, -0.3056]), torch.tensor([-0.7455, 0.6578]), torch.tensor([-0.6356, 0.7643]), torch.tensor([0.5736, -0.8212]), torch.tensor([0.8232, -0.5641])],
    'z_difference_list': torch.tensor([[-0.5504, -0.9491], [-0.1071, 0.1798], [0.9863, 0.2805], [-1.0104, -0.3399], [-1.3093, -0.4478], [-1.0596, -0.6878], [1.2421, 0.2102], [-0.2266, -1.2709], [0.4261, 0.8443], [-1.0296, -0.4143]]),
    'y_target_end_confidence_list': [torch.tensor(0.9810), torch.tensor(4.1007e-05), torch.tensor(0.6721), torch.tensor(0.8129), torch.tensor(0.9863), torch.tensor(0.9983), torch.tensor(0.7672), torch.tensor(0.7127), torch.tensor(0.9163), torch.tensor(0.9303)],
    'x_attribution_list': [torch.tensor([-0.9983, -0.0577]), torch.tensor([-0.9239, -0.3827]), torch.tensor([0.2625, 0.9649]), torch.tensor([-0.2290, -0.9734]), torch.tensor([-0.4249, -0.9052]), torch.tensor([-0.1152, -0.9933]), torch.tensor([0.4967, 0.8679]), torch.tensor([-0.8621, -0.5067]), torch.tensor([0.9997, 0.0231]), torch.tensor([-0.2064, -0.9785])],
    'collage_path_list': ['peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000000', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000001', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000002', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000003', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000004', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000005', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000006', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000007', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000008', 'peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100/0/validation_collages/0000009']
}


In [None]:
x_list = torch.stack(data['x_list'])
x_counterfactual_list = torch.stack(data['x_counterfactual_list'])

In [None]:
x = torch.tensor([[0.7451, 0.0000],
        [0.0000, 0.5355],
        [0.0000, 0.8923],
        [0.0000, 0.3088],
        [0.0000, 0.7999],
        [0.0000, 0.2300],
        [0.5908, 0.0000],
        [0.4921, 0.0000],
        [0.5737, 0.0000],
        [0.7567, 0.0000]])


In [None]:
# Draw ten batches from the dataloader mixer, keeping inputs and labels in parallel lists.
xs, ys = [], []
for i in range(10):
    x, y = cfkd.dataloader_mixer.sample()
    #if x.shape[0] == 100:
    xs += [x]
    ys += [y]
    

In [None]:
# Flatten the per-batch lists into one tensor whose rows are the individual samples.
xs = torch.stack([row for batch in xs for row in batch])
ys = torch.stack([row for batch in ys for row in batch])

In [None]:
%matplotlib inline
plt.scatter(data.data[:,0], data.data[:,1], c='lightgray')
#plt.scatter(xs[:,0], xs[:,1], c=ys)
#plt.scatter(x_counterfactual_list[:,0], x_counterfactual_list[:,1], color='green')
#plt.scatter(x_list[:, 0], x_list[:, 1], color='red')
plt.scatter(x[:,0], x[:,1])
plt.show()

# Counterfactual Production

In [None]:
test_dl = DataLoader(test_set, batch_size=3, shuffle=True)
sample, _ = next(iter(test_dl))
plt.figure(figsize=(5,3))
plt.scatter(sample[:,0], sample[:, 1])
plt.show()

In [None]:
adaptor_config.generator['grad_scales'] = [1.0]
adaptor_config.generator['noise_steps_for_counterfactuals'] = [40]
#adaptor_config.generator['num_iterations'] = 3

# Target = the class the student currently scores lowest for each sample.
features = sample[:,:2]
target_classes = student(features).argmin(dim=-1)
list_counterfactuals, diff_latent, y_target_end_confidence, x_list = adaptor.edit(
    x_in=features,
    target_confidence_goal=0.9,
    target_classes=target_classes,
    classifier=student,
)

input_idx = [0,1]
# Both grid axes share the same range: the global min/max over the selected columns, padded by 0.5.
axis_lo = float(adaptor.data[:, [input_idx]].min()-0.5)
axis_hi = float(adaptor.data[:, [input_idx]].max()+0.5)
xx1, xx2 = np.meshgrid(np.linspace(axis_lo, axis_hi, 500), np.linspace(axis_lo, axis_hi, 500))
grid = torch.from_numpy(np.array([xx1.flatten(), xx2.flatten()]).T).to(torch.float32)
z = student(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
plt.contour(xx1, xx2, z, levels=[0],linestyles='dashed', label='decision boundary')
adaptor.plot_counterfactuals()
#plt.legend()
#plt.show()

In [None]:
# Wrap a copy of the grid in a Parameter so autograd tracks gradients w.r.t. the inputs.
input_data = torch.nn.Parameter(grid.clone())  # Assuming 2 input features

# Forward pass of the student on every grid point.
logits = student(input_data)

# Back-propagate the summed class-0 output to get one gradient vector per grid point.
input_data.requires_grad = True
class0_total = logits[:,0].sum()
class0_total.backward(retain_graph=True)
input_gradients = input_data.grad
grid_x, grid_y = input_data[:,0].detach(), input_data[:,1].detach()
plt.quiver(grid_x, grid_y, input_gradients[:, 0], input_gradients[:, 1])


In [None]:
fig, axs = plt.subplots(1, 2, figsize=(20, 7))

# One quiver panel per output column of the student: gradient of that column w.r.t. the grid inputs.
for class_idx, ax in enumerate(axs):
    input_data = torch.nn.Parameter(grid.clone())  # fresh copy per class, assuming 2 input features
    logits = student(input_data)
    input_data.requires_grad = True
    logits[:, class_idx].sum().backward()
    input_gradients = input_data.grad
    ax.quiver(
        input_data[:,0].detach(), input_data[:,1].detach(),
        input_gradients[:, 0], input_gradients[:, 1],
    )
    input_data.grad.zero_()
    ax.set_title(f'Class:{class_idx}')
plt.show()

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(20, 7))

# Same per-class gradient field as for the student, evaluated on the fine-tuned model.
for class_idx, ax in enumerate(axs):
    input_data = torch.nn.Parameter(grid.clone())  # fresh copy per class, assuming 2 input features
    logits = fine_tuned(input_data)
    input_data.requires_grad = True
    logits[:, class_idx].sum().backward()
    input_gradients = input_data.grad
    ax.quiver(
        input_data[:,0].detach(), input_data[:,1].detach(),
        input_gradients[:, 0], input_gradients[:, 1],
    )
    input_data.grad.zero_()
    ax.set_title(f'Class:{class_idx}')
plt.show()

In [None]:
#target_classes = student(grid).argmax(dim=-1)
#classifier_criterion = lambda x: F.cross_entropy(student(x), target_classes)
sample_copy = torch.nn.Parameter(grid)
#loss = classifier_criterion(sample_copy)
#loss.backward()
y = student(sample_copy)
# Gradient of the sum over all outputs (grad_outputs of ones) w.r.t. each grid point.
grad = torch.autograd.grad(
    outputs=y,
    inputs=sample_copy,
    grad_outputs=torch.ones_like(y),
    create_graph=True,
)[0]
plt.quiver(sample_copy[:,0].detach(), sample_copy[:,1].detach(), grad[:,0].detach(), grad[:,1].detach())

# Rebuild a 500x500 evaluation grid; both axes use the same padded min/max range.
axis_lo = float(adaptor.data[:, [input_idx]].min()-0.5)
axis_hi = float(adaptor.data[:, [input_idx]].max()+0.5)
xx1, xx2 = np.meshgrid(np.linspace(axis_lo, axis_hi, 500), np.linspace(axis_lo, axis_hi, 500))
grid2 = torch.from_numpy(np.array([xx1.flatten(), xx2.flatten()]).T).to(torch.float32)
z = student(grid2).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
plt.contour(xx1, xx2, z, levels=[0],linestyles='dashed', label='decision boundary')



In [None]:
plt.figure(figsize=(5,5))

plt.scatter(data.data[:,0], data.data[:,1], c='lightgray')#c=np.where(data.label == 0, 'lightcyan', 'lightgray'))

# Draw every `step`-th state of one counterfactual trajectory as a chain of arrows.
step = 5
point = adaptor.counterfactuals_series[1]
last = 0  # index of the last state drawn; stays 0 if the trajectory is shorter than `step`
for j in range(0, point.shape[0]-step, step):
    plt.arrow(
        point[j, 0], point[j, 1],
        point[j+step, 0] - point[j, 0],
        point[j+step, 1] - point[j, 1],
        head_width=0.09, head_length=0.09, fc='blue', ec='blue'
    )
    plt.scatter(point[j+step, 0], point[j+step, 1], color='red', s=15)
    last = j + step
plt.scatter(point[0, 0], point[0, 1], color='black', label='start')
# Close the chain from the last drawn state (not from the one before it) to the final state.
plt.arrow(point[last, 0], point[last, 1],
              point[-1, 0] - point[last, 0],
              point[-1, 1] - point[last, 1],
              head_width=0.05, head_length=0.05, fc='blue', ec='blue')
# NOTE(review): xx1, xx2 and z come from an earlier cell and must have matching shapes.
plt.contour(xx1, xx2, z, levels=[0],linestyles='dashed', label='decision boundary')
# A single end marker; the previous duplicates were hidden underneath it.
plt.scatter(point[-1, 0], point[-1, 1], color='cyan', label='end')

plt.legend()
plt.show()


In [None]:
# Scatter the dataset in gray and overlay the counterfactuals, coloured along
# the viridis colormap by their position in list_counterfactuals.
cmap = plt.get_cmap('viridis')
plt.figure(figsize=(3,3))

plt.scatter(data[:,0], data[:,1], color='lightgray', s=10)
input_idx = [0,1]
# The loop variable idx is not used inside the comprehension, so both axes get the same range.
xx1, xx2 = np.meshgrid(*[np.linspace(float(data[:, [input_idx]].min()-0.5),float(data[:, [input_idx]].max()+0.5), 1000) for idx in input_idx])
grid = torch.from_numpy(np.array([xx1.flatten(), xx2.flatten()]).T).to(torch.float32)
# NOTE(review): z is not recomputed here; it must already hold values on a 1000x1000 grid
# from another cell, otherwise its shape will not match xx1/xx2 - confirm.
plt.contour(xx1, xx2, z, levels=[0],linestyles='dashed', label='decision boundary')

for i in range(0, len(list_counterfactuals)):
    color = cmap(i / len(list_counterfactuals))  # Gradual color change
    plt.scatter(list_counterfactuals[i][0], list_counterfactuals[i][1], c=color, s=10)

# Counterfactual Analysis

In [None]:
#target_classes = student(sample[:,:2]).softmax(dim=-1).argmin(dim=-1)
def plot_counterfactuals(counterfactuals, unguided_grads, guided_grads, pointwise_evolution=None, granularity=2,
                         guided_scale=5.0, unguided_scale=50.0):
    """Plot counterfactual trajectories together with their gradient arrows.

    For every trajectory the states are sub-sampled, consecutive sub-sampled
    states are joined by blue arrows, and at each arrow origin the guided
    (pink) and unguided (blue-gray) gradients are drawn. The start state is
    marked lime and the final state red. The dataset is read from the
    notebook-level ``data`` variable as background.

    Args:
        counterfactuals: Sequence of trajectories, each indexable as
            ``point[step, coordinate]`` and exposing ``point.shape[0]``.
        unguided_grads: ``unguided_grads[i][j]`` is the 2-D unguided gradient
            of trajectory ``i`` at step ``j``.
        guided_grads: Same layout as ``unguided_grads`` for the guided gradient.
        pointwise_evolution: Unused; kept for call compatibility.
        granularity: Number of segments each trajectory is split into.
        guided_scale: Length multiplier for guided arrows inside the chain.
        unguided_scale: Length multiplier for unguided arrows inside the chain.
    """
    plt.figure(figsize=(5,5))

    plt.scatter(data.data[:,0], data.data[:,1], c='lightcyan')#c=np.where(data.label == 0, 'lightcyan', 'lightgray'))

    for i, point in enumerate(counterfactuals):
        n_steps = point.shape[0]
        # At least 1, so range() never receives a zero step for short trajectories.
        skip = max(1, n_steps // granularity)
        last = 0  # last state reached by the arrow chain; 0 if the loop below is empty

        for j in range(0, n_steps - skip, skip):
            step = j + skip

            # Scale BOTH components so the arrow keeps the gradient's direction.
            plt.arrow(point[j, 0], point[j, 1],
                      guided_scale * guided_grads[i][j][0], guided_scale * guided_grads[i][j][1],
                      head_width=0.08, head_length=0.05, fc='deeppink', ec='deeppink')
            plt.arrow(point[j, 0], point[j, 1],
                      unguided_scale * unguided_grads[i][j][0], unguided_scale * unguided_grads[i][j][1],
                      head_width=0.08, head_length=0.05, fc='cadetblue', ec='cadetblue')

            # Segment of the trajectory itself.
            plt.arrow(
                point[j, 0], point[j, 1],
                point[step, 0] - point[j, 0],
                point[step, 1] - point[j, 1],
                head_width=0.05, head_length=0.05, fc='blue', ec='blue'
            )
            plt.scatter(point[step, 0], point[step, 1], color='black')
            last = step

        # Connect the last sub-sampled state to the true final state.
        plt.arrow(point[last, 0], point[last, 1],
                  point[-1, 0] - point[last, 0],
                  point[-1, 1] - point[last, 1],
                  head_width=0.05, head_length=0.05, fc='blue', ec='blue')
        if last > 0:
            # Unscaled gradients at the last sub-sampled state, as before.
            plt.arrow(point[last, 0], point[last, 1], guided_grads[i][last][0], guided_grads[i][last][1],
                      head_width=0.08, head_length=0.05, fc='deeppink', ec='deeppink')
            plt.arrow(point[last, 0], point[last, 1], unguided_grads[i][last][0], unguided_grads[i][last][1],
                      head_width=0.08, head_length=0.05, fc='cadetblue', ec='cadetblue')

        # Label only the first trajectory so the legend has exactly one entry per marker type.
        plt.scatter(point[0, 0], point[0, 1], color='lime', label='start' if i == 0 else None)
        plt.scatter(point[-1, 0], point[-1, 1], color='red', label='end' if i == 0 else None)

    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    plt.show()

#counterfactuals, guided_grads, unguided_grads, total_series = adaptor.sample_counterfactual_ddpm(clean_batch=sample[6:7,:2], model=adaptor.model, classifier=student, num_noise_steps=60, target_classes=target_classes[6:7], classifier_grad_weight=0.5)
plot_counterfactuals(adaptor.counterfactuals_series, adaptor.unguided_grads, adaptor.guided_grads)

In [None]:
plt.figure(figsize=(5,5))

plt.scatter(data.data[:,0], data.data[:,1], c='lightcyan')#c=np.where(data.label == 0, 'lightcyan', 'lightgray'))

# Draw every state of the first counterfactual trajectory as a chain of arrows.
step = 1
point = adaptor.counterfactuals_series[0]
for j in range(0, point.shape[0]-step, step):
    plt.arrow(
        point[j, 0], point[j, 1],
        point[j+step, 0] - point[j, 0],
        point[j+step, 1] - point[j, 1],
        head_width=0.05, head_length=0.05, fc='blue', ec='blue'
    )
    plt.scatter(point[j+step, 0], point[j+step, 1], color='black')

# Start/end markers do not depend on j: draw them once, after the chain, so they stay on top.
plt.scatter(point[0, 0], point[0, 1], color='lime')#, label='start')
plt.scatter(point[-1, 0], point[-1, 1], color='red');#, label='end')



In [None]:
counterfactuals[0].shape[0]

In [None]:
# plotting different series for the same point
# NOTE(review): total_series, guided_grads and unguided_grads are not defined in this cell;
# they presumably come from a sample_counterfactual_ddpm call made elsewhere - confirm.
point_no = 1

for num_series in [20, 25]:#, 32, 33, 40, 50, 78]: #range(0, len(total_series), len(total_series) // 4):  # only selecting some counterfactuals
    series = total_series[num_series]  # selecting a counterfactual evolution series for one point
    stack = torch.stack(series)
    # Trajectory of the single sample `point_no` across all entries of the series.
    first_points = stack[:, point_no, :] # second index for different points # we are only taking the first point as above

    skip = len(first_points) // 10 # set steps to skip in each series
    if skip == 0:
        skip = 1
    plt.figure(figsize=(4, 4))
    plt.scatter(data.data[:,0], data.data[:,1])#, c=np.where(data.label == 0, 'lightcyan', 'lightgray'))
    plt.scatter(first_points[0, 0], first_points[0, 1], color='lime', label='start')
    #plt.arrow(first_points[0, 0], first_points[0, 1],
    #              first_points[25, 0] - first_points[0, 0],
    #              first_points[25, 1] - first_points[0, 1],
    #              head_width=0.05, head_length=0.05, fc='blue', ec='blue')
       
    # Chain of arrows between every `skip`-th state of the trajectory.
    for i in range(0, len(stack)-skip, skip):
        #plt.scatter(first_points[i, 0], first_points[i, 1], color='black')
        plt.arrow(first_points[i, 0], first_points[i, 1],
                  first_points[i+skip, 0] - first_points[i, 0],
                  first_points[i+skip, 1] - first_points[i, 1],
                  head_width=0.00005, head_length=0.000005, fc='blue', ec='blue')
        plt.scatter(first_points[i+skip, 0], first_points[i+skip, 1], color='black')
        


    # Uses `i` left over from the loop above: connects the last drawn state to the final state.
    plt.arrow(first_points[i+skip, 0], first_points[i+skip, 1],
                  first_points[-1, 0] - first_points[i+skip, 0],
                  first_points[-1, 1] - first_points[i+skip, 1],
                  head_width=0.000005, head_length=0.000005, fc='blue', ec='blue')
    # Guided (pink) and unguided (blue-gray) gradients at the start state, both components scaled by 10.
    plt.arrow(first_points[0, 0], first_points[0, 1], 10*guided_grads[point_no][num_series][0], 10*guided_grads[point_no][num_series][1], head_width=0.05, head_length=0.05, fc='deeppink', ec='deeppink')
    plt.arrow(first_points[0, 0], first_points[0, 1], 10*unguided_grads[point_no][num_series][0], 10*unguided_grads[point_no][num_series][1], head_width=0.05, head_length=0.05, fc='cadetblue', ec='cadetblue')

    plt.scatter(first_points[-1, 0], first_points[-1, 1], color='red', label='end')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    # NOTE(review): 79 looks like (number of noise steps - 1) - confirm against the sampler settings.
    plt.title(f'Timestep: {79 - num_series}')
# tight_layout/show sit outside the loop, so they run once after all figures are created.
plt.tight_layout()
plt.show()

In [None]:
fine_tuned = torch.load('peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_100_grad_1/32/finetuned_model/model.cpl')

In [None]:
contour = fine_tuned(grid).detach()

In [None]:
contour_actual = contour.gather(1, contour.argmax(axis=-1).unsqueeze(1))
contour_actual_min = contour.gather(1, contour.argmin(axis=-1).unsqueeze(1))

In [None]:
# Per-grid-point maximum (z1) and minimum (z2) of the fine-tuned model's outputs,
# reshaped to the evaluation grid. `z` is rebound as well, as in the original cell.
z1 = contour_actual.reshape(xx1.shape)
z2 = z = contour_actual_min.reshape(xx1.shape)

fig, ax = plt.subplots()

# One colour per sample, chosen by the value in the last column of `data`.
ax.scatter(data[:,0], data[:,1], c=np.where(data[:,-1] == 0, 'lightcyan', 'lightgray'))

# Iso-lines of the larger output.
for level in (0.5, 0.51, 0.515, 0.52, 0.534, 0.55):
    ax.contour(xx1, xx2, z1, levels=[level], cmap='coolwarm', label=f'confidence: {level}')
# Iso-lines of the smaller output.
for level in (0.49, 0.3):
    ax.contour(xx1, xx2, z2, levels=[level], cmap='coolwarm', label=f'confidence: {level}')

# NOTE(review): `a` and `b` (student / teacher class maps) are computed in a later cell;
# that cell has to be executed before this one - confirm intended order.
ax.contour(xx1, xx2, a, levels=[0], linestyles='dashed', colors='red', label='student')
ax.contour(xx1, xx2, b, levels=[0], linestyles='dashed', colors='green', label='teacher')
plt.show()

In [None]:
contour = fine_tuned(grid)

In [None]:
# The original line ended after the dot, which is a syntax error; show the output's shape instead.
contour.shape

In [None]:
# Fraction of samples in `data` whose predicted class equals the stored label.
fine_tuned = torch.load('peal_runs/artificial_symbolic_100_classifier/cfkd_ddpm_oracle_50/model.cpl')
predicted = fine_tuned(torch.tensor(data.data)).softmax(dim=-1).argmax(dim=-1)
is_correct = predicted == torch.tensor(data.label)
sum(is_correct) / len(data.label)

In [None]:
# Compare the decision regions of the fine-tuned, student and teacher models on a dense grid.
fine_tuned = torch.load('peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_20_grad_1/model.cpl')
x1_span = np.linspace(-1.1, 1.1, 1000)
x2_span = np.linspace(-1.1, 1.1, 1000)
xx1, xx2 = np.meshgrid(x1_span, x2_span)
grid = torch.from_numpy(np.array([xx1.flatten(), xx2.flatten()]).T).to(torch.float32)
fine_tuned.eval()
# Predicted class index per grid point, reshaped to the grid.
z = fine_tuned(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
fig, ax = plt.subplots()
a = student(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
b = teacher(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)

# One colour per sample from the last column of `data`; indexing the colour array
# with [0] would paint every point with the first sample's colour.
ax.scatter(data[:,0], data[:,1], c=np.where(data[:,-1] == 0, 'lightcyan', 'lightgray'))
ax.contour(xx1, xx2, z, levels=[0], linestyles='dashed', label='fine-tuned')

ax.contour(xx1, xx2, a, levels=[0], linestyles='dashed', colors='red', label='student')
ax.contour(xx1, xx2, b, levels=[0], linestyles='dashed', colors='green', label='teacher')
plt.title('20 Iterations')
ax.grid()

In [None]:
# Fraction of samples in `data` whose predicted class equals the stored label.
fine_tuned = torch.load('peal_runs/artificial_symbolic_100_classifier/cfkd_ddpm_oracle_100/model.cpl')
predicted = fine_tuned(torch.tensor(data.data)).softmax(dim=-1).argmax(dim=-1)
is_correct = predicted == torch.tensor(data.label)
sum(is_correct) / len(data.label)

In [None]:
# Compare the decision regions of the fine-tuned, student and teacher models.
# Reuses xx1, xx2 and grid from an earlier cell.
fine_tuned = torch.load('peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_30_grad_1/model.cpl')
fine_tuned.eval()
# Predicted class index per grid point, reshaped to the grid.
z = fine_tuned(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
fig, ax = plt.subplots()
a = student(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
b = teacher(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)

# One colour per sample from the last column of `data`; indexing the colour array
# with [0] would paint every point with the first sample's colour.
ax.scatter(data[:,0], data[:,1], c=np.where(data[:,-1] == 0, 'lightcyan', 'lightgray'))
ax.contour(xx1, xx2, z, levels=[0], linestyles='dashed', label='fine-tuned')

ax.contour(xx1, xx2, a, levels=[0], linestyles='dashed', colors='red', label='student')
ax.contour(xx1, xx2, b, levels=[0], linestyles='dashed', colors='green', label='teacher')
# Title follows the loaded run (steps_30), matching the steps_20 cell's convention.
plt.title('30 Iterations')
ax.grid()

In [None]:
# Decision boundaries of the fine-tuned (50 CFKD steps), student and teacher
# classifiers. Reuses `grid`, `xx1`, `xx2` from an earlier plotting cell.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
fine_tuned = torch.load('peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_50_grad_1/model.cpl')
#x1_span = np.linspace(-1.1, 1.1, 1000)
#x2_span = np.linspace(-1.1, 1.1, 1000)
#xx1, xx2 = np.meshgrid(x1_span, x2_span)
#grid = torch.from_numpy(np.array([xx1.flatten(), xx2.flatten()]).T).to(torch.float32)
fine_tuned.eval()
# Predicted class index per grid point, reshaped back onto the mesh.
z = fine_tuned(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
fig, ax = plt.subplots()
a = student(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
b = teacher(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)


# `data` is expected to exist already; the commented lines show how it is built.
#data = torch.zeros([len(dataset.data),len(dataset.attributes)], dtype=torch.float16)
#for idx, key in enumerate(dataset.data):
#    data[idx] = dataset.data[key]
# One colour per sample, chosen from the last column of `data`. The original
# indexed the result with [0], which collapsed it to the first sample's colour.
point_colors = np.where(data[:,-1] == 0, 'lightcyan', 'lightgray')
ax.scatter(data[:,0], data[:,1], c=point_colors)
ax.contour(xx1, xx2, z, levels=[0], linestyles='dashed', label='fine-tuned')

ax.contour(xx1, xx2, a, levels=[0], linestyles='dashed', colors='red', label='student')
ax.contour(xx1, xx2, b, levels=[0], linestyles='dashed', colors='green', label='teacher')
# Title now matches the loaded checkpoint (steps_50); it was a copy-pasted '10 Iterations'.
plt.title('50 Iterations')
ax.grid()

In [None]:
# Decision boundaries of the fine-tuned (50 CFKD steps), student and teacher
# classifiers; this cell rebuilds both the evaluation grid and `data`.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
fine_tuned = torch.load('peal_runs/artificial_circle_poisened_classifier/cfkd_ddpm_oracle_steps_50_grad_1/model.cpl')
x1_span = np.linspace(-1.1, 1.1, 1000)
x2_span = np.linspace(-1.1, 1.1, 1000)
xx1, xx2 = np.meshgrid(x1_span, x2_span)
grid = torch.from_numpy(np.array([xx1.flatten(), xx2.flatten()]).T).to(torch.float32)
fine_tuned.eval()
# Predicted class index per grid point, reshaped back onto the mesh.
z = fine_tuned(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
fig, ax = plt.subplots()
a = student(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
b = teacher(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)


# Copy every dataset entry into one (n_samples, n_attributes) tensor.
data = torch.zeros([len(dataset.data),len(dataset.attributes)], dtype=torch.float16)
for idx, key in enumerate(dataset.data):
    data[idx] = dataset.data[key]
# One colour per sample, chosen from the last column of `data`. The original
# indexed the result with [0], which collapsed it to the first sample's colour.
point_colors = np.where(data[:,-1] == 0, 'lightcyan', 'lightgray')
ax.scatter(data[:,0], data[:,1], c=point_colors)
ax.contour(xx1, xx2, z, levels=[0], linestyles='dashed', label='fine-tuned')

ax.contour(xx1, xx2, a, levels=[0], linestyles='dashed', colors='red', label='student')
ax.contour(xx1, xx2, b, levels=[0], linestyles='dashed', colors='green', label='teacher')
ax.grid()

In [None]:
# Decision boundaries of the `cfkd_ddpm_new_oracle_100` fine-tuned classifier
# next to the student and teacher; rebuilds the evaluation grid and `data`.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
fine_tuned = torch.load('peal_runs/artificial_symbolic_100_classifier/cfkd_ddpm_new_oracle_100/model.cpl')
x1_span = np.linspace(-1.1, 1.1, 1000)
x2_span = np.linspace(-1.1, 1.1, 1000)
xx1, xx2 = np.meshgrid(x1_span, x2_span)
grid = torch.from_numpy(np.array([xx1.flatten(), xx2.flatten()]).T).to(torch.float32)
fine_tuned.eval()
# Predicted class index per grid point, reshaped back onto the mesh.
z = fine_tuned(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
fig, ax = plt.subplots()
a = student(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)
b = teacher(grid).to(torch.float32).detach().numpy().argmax(axis=1).reshape(xx1.shape)


# Copy every dataset entry into one (n_samples, n_attributes) tensor.
data = torch.zeros([len(dataset.data),len(dataset.attributes)], dtype=torch.float16)
for idx, key in enumerate(dataset.data):
    data[idx] = dataset.data[key]
# One colour per sample, chosen from the last column of `data`. The original
# indexed the result with [0], which collapsed it to the first sample's colour.
point_colors = np.where(data[:,-1] == 0, 'lightcyan', 'lightgray')
ax.scatter(data[:,0], data[:,1], c=point_colors)
ax.contour(xx1, xx2, z, levels=[0], linestyles='dashed', label='fine-tuned')

ax.contour(xx1, xx2, a, levels=[0], linestyles='dashed', colors='red', label='student')
ax.contour(xx1, xx2, b, levels=[0], linestyles='dashed', colors='green', label='teacher')
ax.grid()

In [None]:
import math
from copy import deepcopy
import torch
import torch.nn as nn


import torch.nn.functional as F
from tqdm import tqdm

class DDPM(nn.Module):
    """Discrete-time DDPM helper: noise schedule, training loss and samplers.

    The class owns no trainable parameters. It stores the variance schedule
    as buffers (``beta``, ``alpha = 1 - beta``, ``alpha_bar = cumprod(alpha)``)
    and drives an external noise-prediction network that is called as
    ``model(x, t)``. All batched arithmetic indexes the schedule with
    ``[:, None]``, i.e. inputs are treated as ``(batch, features)`` tensors.
    """

    def __init__(self, input_dim: int, embed_dim: int, num_timesteps: int, var_schedule='linear'):
        """Build the variance schedule.

        Args:
            input_dim: Feature dimension of the samples drawn by ``sample_ddpm``.
            embed_dim: Accepted for interface compatibility; not used here.
            num_timesteps: Number of diffusion steps ``T``.
            var_schedule: ``'linear'`` or ``'cosine'``.

        Raises:
            ValueError: If ``var_schedule`` names an unknown schedule.
        """
        super(DDPM, self).__init__()

        self.input_dim = input_dim
        self.num_timesteps = num_timesteps

        def linear_schedule(num_timesteps: int):
            scale = 1000 / num_timesteps
            min_var = scale * 1e-5
            max_var = scale * 1e-2
            return torch.linspace(min_var, max_var, num_timesteps, dtype=torch.float32)

        def cosine_schedule(num_timesteps, s=0.008):
            steps = num_timesteps + 1
            x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64)
            alphas_cumprod = torch.cos(((x / num_timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            # Computed in float64 for accuracy, stored as float32 so the buffers
            # do not silently promote float32 samples to float64.
            return torch.clip(betas, 0, 0.999).to(torch.float32)

        schedules = {'linear': linear_schedule, 'cosine': cosine_schedule}
        if var_schedule not in schedules:
            raise ValueError(
                f"Unknown var_schedule {var_schedule!r}; expected one of {sorted(schedules)}"
            )
        betas = schedules[var_schedule](num_timesteps)

        self.register_buffer("beta", betas)
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(0))

    def forward_diffusion(self, clean_x: torch.Tensor, noise: torch.tensor, timestep: torch.Tensor):
        """Noise ``clean_x`` to the level of ``timestep``.

        Args:
            clean_x: ``(batch, features)`` clean samples.
            noise: Noise tensor with the same shape as ``clean_x``.
            timestep: One schedule index shared by the batch (int or
                one-element tensor) or one index per sample (``batch`` elements).

        Returns:
            ``sqrt(alpha_bar_t) * clean_x + sqrt(1 - alpha_bar_t) * noise``.

        Raises:
            ValueError: If ``timestep`` holds neither 1 nor ``batch`` entries.
        """
        batch_size = clean_x.shape[0]
        # as_tensor accepts ints and tensors alike; the original wrapped the
        # argument in torch.tensor([...]), which fails for per-sample timesteps.
        timestep = torch.as_tensor(timestep, dtype=torch.long).reshape(-1)
        if timestep.numel() == 1:
            alpha_bar_t = self.alpha_bar[timestep].expand(batch_size)
        elif timestep.numel() == batch_size:
            alpha_bar_t = self.alpha_bar[timestep]
        else:
            raise ValueError(
                f"timestep must hold 1 or {batch_size} entries, got {timestep.numel()}"
            )
        alpha_bar_t = alpha_bar_t[:, None]
        mu = torch.sqrt(alpha_bar_t)
        std = torch.sqrt(1 - alpha_bar_t)
        noisy_x = mu * clean_x + std * noise
        return noisy_x

    def forward_diffusion_ddim(self):
        """Placeholder for a DDIM forward process (not implemented)."""
        raise NotImplementedError("forward_diffusion_ddim is not implemented")

    def loss(self, model: nn.Module, clean_x: torch.Tensor, loss='L1_simple', var_model: nn.Module=None) -> torch.Tensor:
        """Noise-prediction training loss for one batch.

        Draws one random timestep per sample, noises ``clean_x`` accordingly
        and compares the model's noise estimate with the true noise.

        Args:
            model: Noise predictor called as ``model(x=..., t=...)``.
            clean_x: ``(batch, features)`` clean samples.
            loss: Loss name. Only ``'L1_simple'`` is supported; despite the
                name it is the summed squared error (``MSELoss(reduction='sum')``).
            var_model: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If ``loss`` is not a supported name.
        """
        if loss != "L1_simple":
            raise ValueError(f"Unknown loss {loss!r}; only 'L1_simple' is supported")

        t = torch.randint(self.num_timesteps, (clean_x.shape[0],))
        eps_t = torch.randn_like(clean_x)
        alpha_bar_t = self.alpha_bar[t][:, None]

        x_t = torch.sqrt(alpha_bar_t) * clean_x + torch.sqrt(1 - alpha_bar_t) * eps_t

        eps_hat = model(x=x_t, t=t)
        return nn.MSELoss(reduction='sum')(eps_hat, eps_t)

    def reverse_diffusion_ddpm(self, noisy_x: torch.Tensor, model: nn.Module, timestep: torch.Tensor):
        """One ancestral DDPM step at schedule index ``timestep``.

        The posterior mean is computed from the model's noise estimate;
        Gaussian noise with variance ``beta_t`` is added unless ``timestep`` is 0.
        """
        batch_size = noisy_x.shape[0]
        alpha_t = self.alpha[timestep].repeat(batch_size)[:, None]
        alpha_bar_t = self.alpha_bar[timestep].repeat(batch_size)[:, None]
        beta_t = 1 - alpha_t
        eps_hat = model(x=noisy_x, t=timestep)
        posterior_mean = (1 / torch.sqrt(alpha_t)) * (noisy_x - (beta_t / torch.sqrt(1 - alpha_bar_t) * eps_hat))
        # Drawn unconditionally so the random stream matches previous runs.
        z = torch.randn_like(noisy_x)

        if timestep > 0:
            return posterior_mean + torch.sqrt(beta_t) * z
        return posterior_mean

    def reverse_diffusion_ddim(self):
        """Placeholder for a DDIM reverse step (not implemented)."""
        raise NotImplementedError("reverse_diffusion_ddim is not implemented")

    def sample_ddpm(self, model: nn.Module, n_samples: int = 256, label=None):
        """Draw samples by running the full reverse chain from pure noise.

        Args:
            model: Noise predictor.
            n_samples: Number of samples to draw.
            label: Accepted for interface compatibility; not used here.

        Returns:
            List of ``num_timesteps + 1`` tensors: the initial noise followed by
            the state after every reverse step (the last entry is the sample).
        """
        x = torch.randn(n_samples, self.input_dim)
        x_pred = [x]

        with torch.no_grad():
            for t in reversed(range(0, self.num_timesteps)):
                x = self.reverse_diffusion_ddpm(noisy_x=x, model=model, timestep=t)
                x_pred.append(x)

        return x_pred

    @staticmethod
    def _classifier_input_grad(classifier: nn.Module, inputs: torch.Tensor, targets: torch.Tensor):
        """Return (detached loss, d cross_entropy / d inputs) for ``classifier``."""
        # autograd.grad only produces the input gradient, so repeated calls do
        # not accumulate stale gradients in the classifier's own parameters.
        with torch.enable_grad():
            probe = inputs.detach().clone().requires_grad_(True)
            loss = F.cross_entropy(classifier(probe), targets)
            (grad,) = torch.autograd.grad(loss, probe)
        return loss.detach(), grad.detach()

    def sample_counterfactual_ddpm(self, clean_batch: torch.Tensor, model: nn.Module, classifier: nn.Module, num_noise_steps: int, counterfactual_class: torch.Tensor, classifier_grad_weight: float, perceptual_weight: float):
        """Classifier-guided counterfactual search.

        ``clean_batch`` is noised to schedule index ``num_noise_steps`` and then
        denoised step by step. Every step subtracts a guidance term built from
        the classifier gradient that was evaluated on the previous fully
        denoised estimate.

        Args:
            clean_batch: ``(batch, features)`` starting points.
            model: Noise predictor called as ``model(x, t)``.
            classifier: Classifier whose cross-entropy towards
                ``counterfactual_class`` supplies the guidance gradient.
            num_noise_steps: Number of guided reverse steps; must satisfy
                ``1 <= num_noise_steps < num_timesteps``.
            counterfactual_class: Target class per sample.
            classifier_grad_weight: Weight of the classifier gradient.
            perceptual_weight: Weight of the radius term ``2 * (|x| - 1)``.

        Returns:
            ``(counterfactuals, guided_grads, unguided_grads, total_series)``.
            The first three are stacked per sample with shape
            ``(batch, num_noise_steps, features)``; ``total_series`` is a list
            with one denoising trajectory (list of tensors) per guided step.

        Raises:
            ValueError: If ``num_noise_steps`` is outside the valid range.
        """
        if not 1 <= num_noise_steps < self.num_timesteps:
            raise ValueError(
                f"num_noise_steps must be in [1, {self.num_timesteps - 1}], got {num_noise_steps}"
            )

        classifier.eval()
        bs = clean_batch.shape[0]

        # Gradient on the clean inputs drives the first guided step.
        loss, clean_classifier_grad = self._classifier_input_grad(classifier, clean_batch, counterfactual_class)
        clean_grad = classifier_grad_weight * clean_classifier_grad

        # Forward diffusion in closed form.
        # NOTE(review): the noise level uses index num_noise_steps while the
        # first reverse step uses num_noise_steps - 1; kept as is -- confirm intent.
        eps_t = torch.randn_like(clean_batch)
        alpha_bar_t = self.alpha_bar[num_noise_steps].repeat(bs)[:, None]
        next_z = torch.sqrt(alpha_bar_t) * clean_batch + torch.sqrt(1 - alpha_bar_t) * eps_t

        counterfactuals = [clean_batch.detach()]  # clean start + one estimate per step
        guided_grads = []         # guidance term applied at each step
        unconditional_grads = []  # diffusion term applied at each step
        total_series = []         # per step: trajectory from noisy state to clean estimate
        losses = [loss]           # detached classifier losses

        # Exposed for inspection; the lists are filled in place by the loop.
        self.counterfactuals = counterfactuals
        self.guided_grads = guided_grads
        self.diffusion_grads = unconditional_grads
        self.pointwise_evolution = total_series

        for i in tqdm(range(0, num_noise_steps)[::-1]):
            alpha_i = self.alpha[i].repeat(bs)[:, None]
            alpha_bar_i = self.alpha_bar[i].repeat(bs)[:, None]
            sigma_i = torch.sqrt(1 - self.alpha[i])
            # The diffusion model is only evaluated, never trained here, so no
            # autograd graph is kept across the (nested) denoising loops.
            with torch.no_grad():
                eps_hat = model(next_z, i)

            # Unconditional mean
            unconditional_grad = -eps_hat / torch.sqrt(1 - alpha_bar_i)
            z_t_mean = (next_z + unconditional_grad * (1 - alpha_i)) / torch.sqrt(alpha_i)

            # Guided mean
            z_t_mean = z_t_mean - sigma_i * (clean_grad / torch.sqrt(alpha_bar_i))

            if i > 0:
                next_z = z_t_mean + (sigma_i * torch.randn_like(clean_batch))
            else:
                next_z = z_t_mean

            # Fully denoise a copy of z to obtain a clean estimate (next_x).
            next_x = next_z.clone()
            series = [next_x.detach()]
            for t in range(0, i)[::-1]:
                alpha_t = self.alpha[t].repeat(bs)[:, None]
                alpha_bar_t = self.alpha_bar[t].repeat(bs)[:, None]
                sigma_t = torch.sqrt(1 - self.alpha[t])
                with torch.no_grad():
                    eps_hat = model(next_x, t)
                next_x = (next_x - eps_hat * (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) / torch.sqrt(alpha_t)
                if t > 0:
                    next_x = next_x + (sigma_t * torch.randn_like(next_x))

                series.append(next_x.detach())
            total_series.append(series)
            guided_grads.append(-sigma_i * clean_grad.detach() / torch.sqrt(alpha_bar_i))
            unconditional_grads.append(unconditional_grad.detach() * (1 - alpha_i) / torch.sqrt(alpha_i))

            # The estimate of the last step (i == 0) is not recorded, which keeps
            # `counterfactuals` the same length as the gradient lists.
            if i != 0:
                counterfactuals.append(next_x.detach())

            # Gradient w.r.t. the denoised estimate guides the next step.
            loss, clean_classifier_grad = self._classifier_input_grad(classifier, next_x, counterfactual_class)
            losses.append(loss)
            radius = torch.sqrt((next_x**2).sum(dim=1))
            # NOTE(review): this is d(r - 1)^2 / dr broadcast over both
            # coordinates, not the gradient w.r.t. x -- confirm intent.
            perceptual_grad = (2 * (radius - 1))[:, None]
            clean_grad = classifier_grad_weight * clean_classifier_grad + perceptual_weight * perceptual_grad

        counterfactuals = torch.stack(counterfactuals).permute(1, 0, 2)
        guided_grads = torch.stack(guided_grads).permute(1, 0, 2)
        unguided_grads = torch.stack(unconditional_grads).permute(1, 0, 2)

        return counterfactuals, guided_grads, unguided_grads, total_series

    def plot_counterfactuals(self):
        """Placeholder for a plotting helper (not implemented)."""
        raise NotImplementedError("plot_counterfactuals is not implemented")



In [None]:
# Hyper-parameters shared by the noise-prediction network and the DDPM helper.
input_dim = 2
embed_dim = 256
T = 500

# Noise-prediction network and the schedule/sampling helper around it.
model = BasicDiscreteTimeModel(input_dim, embed_dim, num_timesteps=T)
ddpm = DDPM(input_dim=input_dim, embed_dim=embed_dim, num_timesteps=T)
learning_rate = 1e-4
# The optimizer is created here but not used within this cell.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Restore previously trained weights instead of training from scratch.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
model.load_state_dict(torch.load('./peal_runs/artificial_symbolic_100_generator/ddpm.pth'))



In [None]:
# Run the full reverse chain for 500 samples; `samples` holds every intermediate
# state and `samples[-1]` is the final denoised batch.
samples = ddpm.sample_ddpm(model, n_samples=500)
plt.scatter(samples[-1][:,0], samples[-1][:,1])
plt.show()

In [None]:
# Counterfactual target per sample: the class the student currently rates
# least likely (arg-min of its soft-max over the first two columns of `sample`).
target_classes = student(sample[:,:2]).softmax(dim=-1).argmin(dim=-1)

In [None]:
# Guided counterfactual search: 60 noise steps, classifier guidance only
# (the perceptual term is switched off with weight 0.0).
counterfactuals, guided_grads, unguided_grads, total_series = ddpm.sample_counterfactual_ddpm(clean_batch=sample[:,:2], model=model, classifier=student, num_noise_steps=60, counterfactual_class=target_classes, classifier_grad_weight=1.0, perceptual_weight=0.0)

In [None]:
from torch.utils.data import Dataset
class CircleDataset(Dataset):
    """Map-style dataset backed by a CSV file.

    The selected feature columns are stored as a float32 array in ``data`` and
    the label column(s) as an integer array in ``label``.
    """

    # Dataset identifier exposed on instances.
    __name__ = 'circle'

    def __init__(self, data_path: str, features: list[str], labels: list[str]):
        """Read ``data_path`` and split it into feature and label arrays."""
        super().__init__()
        table = pd.read_csv(data_path)
        feature_columns = table[features]
        label_columns = table[labels]
        self.data = feature_columns.to_numpy('float32')
        self.label = label_columns.to_numpy('long')

    def __len__(self):
        """Number of rows read from the CSV file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at row ``idx``."""
        row = self.data[idx, :]
        target = self.label[idx]
        return row, target

    def serialize_dataset(self, output_dir: str, x_list: list, y_list: list):
        """No-op placeholder; nothing is written."""
        return None
    
# Location and column layout of the circle CSV used as the scatter background.
data_config = {
    'data_path': 'datasets/circle/size_500_radius_1_seed_0.csv',
    'features': ['x1', 'x2'],
    'target': 'Target'
}

# Note: this rebinds the notebook-wide name `data` to a CircleDataset instance.
data = CircleDataset(data_path=data_config['data_path'], features=data_config['features'], labels=data_config['target'])

    
def plot_counterfactuals(counterfactuals, diffusion_grads, classifier_grads, pointwise_evolution=None, granularity=2):
    """Plot counterfactual trajectories together with their gradient arrows.

    Every entry of ``counterfactuals`` is one point's trajectory (a sequence of
    2-D positions). The trajectory is drawn in ``granularity`` coarse hops
    (blue arrows, black dots), followed by one arrow to the final position.
    At each hop origin the classifier-guidance gradient (pink) and the
    diffusion gradient (cadet blue) are drawn, scaled for visibility.
    The module-level ``data`` dataset supplies the background scatter.

    Args:
        counterfactuals: Indexable as ``[point][step]`` -> 2-D position.
        diffusion_grads: Indexable as ``[point][step]`` -> 2-D diffusion term;
            must be aligned with ``counterfactuals``.
        classifier_grads: Indexable as ``[point][step]`` -> 2-D guidance term;
            must be aligned with ``counterfactuals``.
        pointwise_evolution: Accepted for interface compatibility; not used.
        granularity: Desired number of coarse hops per trajectory.
    """
    def draw_move(start, end):
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='blue', ec='blue')

    def draw_gradients(origin, guided, unguided, guided_scale, unguided_scale):
        plt.arrow(origin[0], origin[1], guided_scale*guided[0], guided_scale*guided[1],
                  head_width=0.08, head_length=0.05, fc='deeppink', ec='deeppink')
        plt.arrow(origin[0], origin[1], unguided_scale*unguided[0], unguided_scale*unguided[1],
                  head_width=0.08, head_length=0.05, fc='cadetblue', ec='cadetblue')

    plt.figure(figsize=(5,5))

    plt.scatter(data.data[:,0], data.data[:,1], c='lightcyan')

    for i, point in enumerate(counterfactuals):
        n_steps = point.shape[0]
        # At least one step per hop: a zero stride would make range() raise.
        skip = max(1, n_steps // granularity)

        # Index of the last position reached by a coarse hop. It stays 0 when
        # the trajectory is too short for any hop (the original then read an
        # undefined or stale loop variable).
        last = 0
        for j in range(0, n_steps - skip, skip):
            step = j + skip
            # The gradients were used to leave position j, so they are drawn there.
            draw_gradients(point[j], classifier_grads[i][j], diffusion_grads[i][j], 5.0, 20.0)
            draw_move(point[j], point[step])
            plt.scatter(point[step, 0], point[step, 1], color='black')
            last = step

        # Final hop to the end of the trajectory, with more strongly scaled gradients.
        draw_move(point[last], point[-1])
        draw_gradients(point[last], classifier_grads[i][last], diffusion_grads[i][last], 10.0, 50.0)

        # Label only the first trajectory so the legend has no duplicate entries.
        plt.scatter(point[0, 0], point[0, 1], color='lime', label='start' if i == 0 else None)
        plt.scatter(point[-1, 0], point[-1, 1], color='red', label='end' if i == 0 else None)

    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    plt.show()

counterfactuals, guided_grads, unguided_grads, total_series = adaptor.sample_counterfactual_ddpm(clean_batch=sample[:,:2], model=adaptor.model, classifier=student, num_noise_steps=80, target_classes=target_classes, classifier_grad_weight=15.0)
# Slice the gradients exactly like the trajectories so point 1 is drawn with
# its own gradients (previously point 0's gradients were used).
plot_counterfactuals(counterfactuals[1:2], unguided_grads[1:2], guided_grads[1:2])

In [None]:
# Number of recorded states in the first entry of `total_series`
# (presumably the trajectory of the first guided step -- TODO confirm).
len(total_series[0])

In [None]:
# plotting different series for the same point
# Plot several denoising trajectories of one data point.
point_no = 1

for num_series in [20, 25]:  # indices of the guided steps to visualise
    series = total_series[num_series]  # denoising trajectory recorded at that guided step
    stack = torch.stack(series)
    first_points = stack[:, point_no, :]  # positions of the selected point only

    # Draw roughly ten hops per trajectory, but always advance at least one state.
    skip = max(1, len(first_points) // 10)

    plt.figure(figsize=(4, 4))
    plt.scatter(data.data[:,0], data.data[:,1])
    plt.scatter(first_points[0, 0], first_points[0, 1], color='lime', label='start')

    # Index of the last state reached by a hop. It stays 0 for trajectories
    # that are too short for any hop (the original then used an undefined or
    # stale loop variable from the previous figure).
    last = 0
    for i in range(0, len(stack)-skip, skip):
        nxt = i + skip
        plt.arrow(first_points[i, 0], first_points[i, 1],
                  first_points[nxt, 0] - first_points[i, 0],
                  first_points[nxt, 1] - first_points[i, 1],
                  head_width=0.00005, head_length=0.000005, fc='blue', ec='blue')
        plt.scatter(first_points[nxt, 0], first_points[nxt, 1], color='black')
        last = nxt

    # Final hop to the end of the trajectory.
    plt.arrow(first_points[last, 0], first_points[last, 1],
                  first_points[-1, 0] - first_points[last, 0],
                  first_points[-1, 1] - first_points[last, 1],
                  head_width=0.000005, head_length=0.000005, fc='blue', ec='blue')
    # Guidance (pink) and diffusion (cadet blue) terms applied at this guided step.
    plt.arrow(first_points[0, 0], first_points[0, 1], 10*guided_grads[point_no][num_series][0], 10*guided_grads[point_no][num_series][1], head_width=0.05, head_length=0.05, fc='deeppink', ec='deeppink')
    plt.arrow(first_points[0, 0], first_points[0, 1], 10*unguided_grads[point_no][num_series][0], 10*unguided_grads[point_no][num_series][1], head_width=0.05, head_length=0.05, fc='cadetblue', ec='cadetblue')

    plt.scatter(first_points[-1, 0], first_points[-1, 1], color='red', label='end')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid()
    # Derived from the number of recorded series instead of a hard-coded 79, so
    # the title stays right for runs with a different number of noise steps.
    # Presumably series are stored from the noisiest step downwards -- TODO confirm.
    plt.title(f'Timestep: {len(total_series) - 1 - num_series}')
plt.tight_layout()
plt.show()

In [None]:
# Hand-entered values for the accuracy plot in the next cell:
# x = number of iterations, y = accuracy.
x = np.array([10, 16, 20])
y = np.array([0.7680, 0.8680, 0.9020])

In [None]:
# Accuracy against the number of iterations, with each marker annotated by its value.
plt.plot(x, y, marker='o')
for iterations, accuracy in zip(x, y):
    plt.text(iterations, accuracy, f'({accuracy})', fontsize=12, ha='right', va='bottom')

plt.xlabel('No. of Iterations')
plt.ylabel('Accuracy')
plt.show()

In [None]:
# Leftover interactive help lookup for plt.plot removed; it produced no result
# and is IPython-only syntax. Use help(plt.plot) when the docs are needed.

In [None]:
# version 1
def discard_counterfactuals(self, counterfactuals, classifier, target_classes, target_confidence, minimal_counterfactuals, tolerance=0.1):
    """Refresh the per-point best counterfactual from a batch of new candidates.

    A point counts as "on the manifold" when ``|sum(x**2) - 1| < tolerance`` and
    as "confident" when the classifier's soft-max probability of the point's
    target class exceeds ``target_confidence``.

    For every point ``i`` the entry ``minimal_counterfactuals[i]`` is
      * kept when it is already confident (whether or not it is on the manifold);
      * otherwise replaced by the LAST on-manifold candidate when no candidate
        is confident;
      * otherwise replaced by the FIRST candidate that is both confident and
        on the manifold, if one exists;
      * kept in every remaining case.

    ``minimal_counterfactuals`` is modified in place and also returned.
    ``self`` is not used.
    """
    def manifold_gap(points):
        # Absolute deviation of the squared norm from 1.
        return torch.abs(torch.pow(points, 2).sum(dim=-1) - 1.0)

    # Gaps of the current best counterfactuals, taken before any replacement.
    current_gaps = manifold_gap(minimal_counterfactuals)

    for idx in range(len(counterfactuals)):
        candidates = counterfactuals[idx]
        wanted_class = target_classes[idx]

        # Candidate masks: confident, on the manifold, and both at once.
        candidate_confident = classifier(candidates).softmax(dim=-1)[:, wanted_class] > target_confidence
        candidate_on_manifold = manifold_gap(candidates) < tolerance
        candidate_valid = candidate_confident & candidate_on_manifold

        current_on_manifold = current_gaps[idx] < tolerance
        current_probability = classifier(minimal_counterfactuals[idx:idx+1]).softmax(dim=-1)[0][wanted_class].item()
        current_confident = current_probability > target_confidence

        if current_on_manifold and current_confident:
            # Already a valid counterfactual: keep it.
            continue
        if current_confident:
            # Confident but off the manifold: no replacement rule applies.
            continue

        if candidate_on_manifold.any() and not candidate_confident.any():
            # Nothing confident yet: move to the latest on-manifold candidate.
            chosen = torch.nonzero(candidate_on_manifold)[-1].item()
            minimal_counterfactuals[idx] = candidates[chosen]
        elif candidate_valid.any():
            # Take the earliest candidate that is both confident and on the manifold.
            chosen = torch.nonzero(candidate_valid)[0].item()
            minimal_counterfactuals[idx] = candidates[chosen]

    return minimal_counterfactuals


In [None]:
# version 2

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm, trange
from typing import Tuple
import logging
from torch.utils.data import DataLoader
import math
%matplotlib inline

logging.getLogger().setLevel(logging.INFO)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal embedding table for discrete timesteps.

    Row ``t`` of ``P`` holds the sine/cosine encoding of position ``t + 1``:
    position 0 is computed and then dropped, so timestep index 0 maps to
    position 1. Even columns carry sines, odd columns cosines.
    """

    def __init__(self, embed_dim, max_len=500):
        """Pre-compute the ``(max_len, embed_dim)`` lookup table.

        Args:
            embed_dim: Width of every embedding vector (even or odd).
            max_len: Number of timesteps that can be looked up.
        """
        super(PositionalEncoding, self).__init__()
        n_positions = max_len + 1
        positions = torch.arange(n_positions)[:, None]
        scales = torch.pow(10000, torch.arange(0, embed_dim, 2, dtype=torch.float32)/embed_dim)
        freqs = positions / scales

        table = torch.zeros(n_positions, embed_dim)
        table[:,0::2] = torch.sin(freqs)
        # For odd embed_dim there is one cosine column fewer than sine columns.
        table[:,1::2] = torch.cos(freqs)[:, :embed_dim // 2]

        # Registered as a buffer so the table follows the module across
        # .to(device) / .cuda(). It is non-persistent, which keeps it out of
        # the state_dict and leaves existing checkpoints loadable.
        self.register_buffer('P', table[1:].clone(), persistent=False)

    def forward(self, t):
        """Return the embedding row(s) selected by index or index tensor ``t``."""
        return self.P[t]
    
class ScoreNetwork(nn.Module):
    """Small MLP that maps a sample plus a time embedding to an ``input_dim`` output.

    The first layer projects the sample to ``embed_dim`` and the time embedding
    is added to that projection. Two SiLU-activated layers, a LayerNorm and a
    final projection to ``input_dim`` follow. All linear layers are lazy, so
    their input widths are inferred on the first forward pass.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.layer1 = nn.LazyLinear(embed_dim)
        self.layer2 = nn.LazyLinear(embed_dim)
        self.layer3 = nn.LazyLinear(embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        self.layer4 = nn.LazyLinear(input_dim)

    def forward(self, x, time_embed):
        """Apply the network to ``x`` conditioned on ``time_embed``."""
        # Condition on time by adding the embedding to the first projection.
        hidden = self.layer1(x) + time_embed
        for layer in (self.layer2, self.layer3):
            hidden = F.silu(layer(hidden))
        normalized = self.norm(hidden)
        return self.layer4(normalized)
                           
class BasicDiscreteTimeModel(nn.Module):
    """Time-conditioned network: timestep embedding lookup followed by ``ScoreNetwork``."""

    def __init__(self, input_dim: int, embed_dim: int, num_timesteps: int):
        super().__init__()

        # Embedding table with one row per timestep.
        self.positional_embeddings = PositionalEncoding(embed_dim=embed_dim, max_len=num_timesteps)
        self.score_network = ScoreNetwork(input_dim=input_dim, embed_dim=embed_dim)

    def forward(self, x, t):
        """Return the score network's output for ``x`` at timestep index(es) ``t``."""
        return self.score_network(x, self.positional_embeddings(t))

    

class CircleDiffusionAdaptor(nn.Module):
    def __init__(self, config, dataset, model_dir=None):
        super(CircleDiffusionAdaptor, self).__init__()
        # self.config = load_yaml_config(config)
        self.config = config
        
        if not model_dir is None:
            self.model_dir = model_dir
        else:
            self.model_dir = config['base_path']
        
        #if not os.path.exists(model_dir):
        #    os.mkdir(model_dir)
        #self.model_dir = model_dir
        self.input_dim = config['input_dim']
        try: 
            self.num_timesteps = config['num_timesteps']
        except KeyError: 
            pass
        
        self.dataset = dataset
        self.input_idx = [idx for idx, element in enumerate(self.dataset.attributes) if element not in ['Confounder', 'Target']]
        self.target_idx = [idx for idx, element in enumerate(self.dataset.attributes) if element == 'Target']
        #data = torch.zeros([len(dataset.data),len(dataset.attributes)], dtype=torch.float16)
        #for idx, key in enumerate(dataset.data):
        #    data[idx] = dataset.data[key]
        #self.model = self.train_and_load_diffusion(model_name='diffusion.pth')
        
        def schedules(num_timesteps: int, type: str='linear'):
 
            if type=='linear':
                scale = 1000 / num_timesteps
                min_var = scale * 1e-4
                max_var = scale * 1e-2
                return torch.linspace(min_var, max_var, num_timesteps, dtype=torch.float32)
            elif type=='cosine':
                steps = num_timesteps + 1
                x = torch.linspace(0, num_timesteps, steps, dtype=torch.float64)
                alphas_cumprod = torch.cos(((x / num_timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
                alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
                betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
                return torch.clip(betas, 0, 0.999)
        
        betas = schedules(num_timesteps=config['num_timesteps'], type=config['var_schedule'])

        self.register_buffer("beta", betas)
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(0))
        
        
    def forward_diffusion(self, clean_x: torch.Tensor, noise: torch.tensor, timestep: torch.Tensor):
        
        if isinstance(timestep, int):
            timestep = torch.tensor([timestep])
            alpha_bar_t = self.alpha_bar[timestep].repeat(clean_x.shape[0])[:, None]
        else:
            alpha_bar_t = self.alpha_bar[timestep][:, None]
        mu = torch.sqrt(alpha_bar_t)
        std = torch.sqrt(1 - alpha_bar_t)
        noisy_x = mu * clean_x + std * noise
        return noisy_x
    

    def reverse_diffusion_ddpm(self, noisy_x: torch.Tensor, model: nn.Module, timestep: torch.Tensor):
        alpha_t = self.alpha[timestep].repeat(noisy_x.shape[0])[:, None]
        alpha_bar_t = self.alpha_bar[timestep].repeat(noisy_x.shape[0])[:, None]
        beta_t = 1 - alpha_t
        eps_hat = model(x=noisy_x, t=timestep)
        posterior_mean = (1 / torch.sqrt(alpha_t)) * (noisy_x - (beta_t / torch.sqrt(1 - alpha_bar_t) * eps_hat))
        z = torch.randn_like(noisy_x)
        
        if timestep > 0:
            alpha_bar_t_minus_1 = self.alpha_bar[timestep-1].repeat(noisy_x.shape[0])[:, None]
            sigma_t = beta_t * (1 - alpha_bar_t_minus_1) / (1 - alpha_bar_t)
            denoised_x = posterior_mean + torch.sqrt(sigma_t)*z #* z * (timestep > 0))  # variance = beta_t
        else:
            denoised_x = posterior_mean
                                           
        return denoised_x
    
    def train_and_load_diffusion(self, model_name='diffusion.pt', mode=None):
        """Create the noise-prediction model and either train it or load its weights.

        * ``mode == 'train'``: train on ``self.dataset`` for
          ``config['num_epochs']`` epochs with Adam and save the state dict to
          ``<model_dir>/<model_name>`` (overwriting an existing file).
        * otherwise: load the state dict from that path if the file exists;
          if it does not, a message is logged and the model keeps its
          freshly constructed (untrained) weights.

        The resulting model is stored in ``self.model``. After training, the
        per-epoch mean losses are available as ``self.train_losses``.

        Args:
            model_name: File name of the checkpoint inside ``self.model_dir``.
            mode: ``'train'`` to (re)train; any other value only loads.
        """
        self.model_path = os.path.join(self.model_dir, model_name)
        model = BasicDiscreteTimeModel(
            input_dim=self.config['input_dim'],
            embed_dim=self.config['embed_dim'],
            num_timesteps=self.config['num_timesteps'],
        )

        def diffusion_loss(model: nn.Module, clean_x: torch.Tensor) -> torch.Tensor:
            """Summed squared error between predicted and true noise at random timesteps."""
            t = torch.randint(self.num_timesteps, (clean_x.shape[0],))
            eps_t = torch.randn_like(clean_x)
            x_t = self.forward_diffusion(clean_x=clean_x, noise=eps_t, timestep=t)
            eps_hat = model(x=x_t, t=t)
            return F.mse_loss(eps_hat, eps_t, reduction='sum')

        def run_epoch(model: nn.Module, dataloader: DataLoader, optimizer: torch.optim.Optimizer) -> float:
            """Run one optimisation pass over ``dataloader``; return the mean loss per sample."""
            model.train()
            epoch_loss = 0.0
            for x, _ in dataloader:
                optimizer.zero_grad()
                loss = diffusion_loss(model, x[:, self.input_idx])
                loss.backward()
                optimizer.step()
                # .item() detaches the value, so the autograd graph of each
                # batch can be freed instead of being kept for the whole epoch.
                epoch_loss += loss.item()
            return epoch_loss / len(dataloader.dataset)

        if mode == 'train':
            logging.info(
                f'Training model with path {self.model_path}'
            )
            num_epochs = self.config['num_epochs']
            dataloader = DataLoader(self.dataset, batch_size=self.config['batch_size'], shuffle=True)
            learning_rate = 1e-4
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

            losses = []
            for i in trange(num_epochs):
                train_loss = run_epoch(model, dataloader, optimizer)
                print(f'Epoch: {i}, train_loss: {train_loss}')
                losses.append(train_loss)
            self.train_losses = losses

            # Make sure the target directory exists before writing the checkpoint.
            os.makedirs(os.path.dirname(self.model_path) or '.', exist_ok=True)
            torch.save(model.state_dict(), self.model_path)
        elif os.path.isfile(self.model_path):
            model.load_state_dict(torch.load(self.model_path))
            logging.info(f'Model found with path {self.model_path}')
        else:
            logging.info('Model not found. Please run train_and_load_diffusion method and set its argument mode="train" ')

        self.model = model
    
    @torch.no_grad()
    def sample_ddpm(self, model: nn.Module, n_samples: int = 256, label=None):
        """
        iteratively denoises pure noise to produce a list of denoised samples at each timestep
        """
        model.eval()
        
        x_pred = []
        x = torch.randn(n_samples, self.input_dim)
        x_pred.append(x)

        for t in reversed(range(0, self.num_timesteps)):
                
            x = self.reverse_diffusion_ddpm(noisy_x=x, model=model, timestep=t)
                
            x_pred.append(x)
        return x_pred
    
    def sample_x(self, batch_size=1):
        x = self.sample_ddpm(model=self.model, n_samples=batch_size)[-1] 
        return x

    def sample_counterfactual_ddpm(self, clean_batch: torch.Tensor, model: nn.Module, classifier: nn.Module, num_noise_steps: int, target_classes: int, classifier_grad_weight: float):
        
        
        classifier.eval()
        self.classifier = classifier
        
        # DEFINE BATCH SIZE AND COUNTERFACTUAL CLASS
        bs = clean_batch.shape[0]

        # COMPUTE CLEAN GRADIENTS FOR THE FIRST STEP
        
        classifier_criterion = lambda x: F.cross_entropy(classifier(x), target_classes)
        clean_batch_copy = torch.nn.Parameter(clean_batch)
        loss = classifier_criterion(clean_batch_copy)
        loss.backward()
        clean_grad = classifier_grad_weight * clean_batch_copy.grad.detach()
        
        # PERFORMING FORWARD DIFFUSION UNTIL NUM_NOISE_STEPS
        eps_t = torch.randn_like(clean_batch)
        next_z = self.forward_diffusion(clean_x=clean_batch, noise=eps_t, timestep=num_noise_steps)
        counterfactuals = [] # total counterfactuals
        counterfactuals.append(clean_batch)
        guided_grads = []  # guided grads at the first step
        unconditional_grads = [] # diffusion grads at the first step
        total_series = [] # contains evolution from noisy to cleaned instance for each data point
        for i in tqdm(range(0, num_noise_steps)[::-1]):
            # Denoise z_t to create z_t-1 (next z)
            alpha_i = self.alpha[i].repeat(bs)[:, None]
            alpha_bar_i = self.alpha_bar[i].repeat(bs)[:, None]
            sigma_i = torch.sqrt(1 - self.alpha[i])
            eps_hat = model(next_z, i)

            # Unconditional mean
            unconditional_grad = -eps_hat / torch.sqrt(1 - alpha_bar_i)
            z_t_mean = (next_z + unconditional_grad * (1 - alpha_i)) / torch.sqrt(alpha_i)

            # Guided mean
            z_t_mean -= sigma_i * (clean_grad / torch.sqrt(alpha_bar_i))

            if i > 0:
                next_z = z_t_mean + (sigma_i * torch.randn_like(clean_batch))
            else:
                next_z = z_t_mean

            next_x = next_z.clone()
            # Denoise to create a cleaned x (next x)
            series = []
            series.append(next_x.detach())
            for t in range(0, i)[::-1]:
                if i == 0:
                    break
                next_x = self.reverse_diffusion_ddpm(noisy_x=next_x, model=model, timestep=t)
                series.append(next_x.detach())
            total_series.append(series)
            guided_grads.append(-sigma_i * clean_grad.detach() / torch.sqrt(alpha_bar_i))
            unconditional_grads.append(unconditional_grad.detach() * (1 - alpha_i) / torch.sqrt(alpha_i) )
            
            
            if i != 0:
                counterfactuals.append(next_x.detach())

            # Gradient wrt denoised image (next_x)
            next_x_copy = torch.nn.Parameter(next_x.clone())
            loss = classifier_criterion(next_x_copy)
            loss.backward()
            clean_classifier_grad = next_x_copy.grad.detach()
            clean_grad = classifier_grad_weight * clean_classifier_grad
            
            #self.counterfactuals = counterfactuals
            #self.guided_grads = guided_grads
            #self.diffusion_grads = unconditional_grads
            #self.pointwise_evolution = total_series 
            
        counterfactuals = torch.stack(counterfactuals).permute(1, 0, 2) 
        guided_grads = torch.stack(guided_grads).permute(1, 0, 2) 
        unguided_grads = torch.stack(unconditional_grads).permute(1, 0, 2)
        
        self.counterfactuals_series = counterfactuals
        
        return counterfactuals, guided_grads, unguided_grads, total_series

    
    def discard_counterfactuals(self, counterfactuals, classifier, target_classes, target_confidence, minimal_counterfactuals, tolerance=0.1):
        
        # compute distance of current minimal_counterefactuals from radius 1.0
        #current_counterfactual_distance_from_manifold = torch.abs((torch.pow(minimal_counterfactuals, 2).sum(dim=-1) - 1.0))
        
        for i in range(len(counterfactuals)):  

            # compute classifier  for all the counterfactuals for each point
            new_counterfactuals_confidence = classifier(counterfactuals[i]).softmax(dim=-1)[:, target_classes[i]]
            
            
            # check if new counterfactuals satisfy the confidence constraint
            new_confidence_satisfied = new_counterfactuals_confidence > target_confidence
            
            new_confidence_satisfied_indices = torch.nonzero(new_counterfactuals_confidence > target_confidence)
            
            
            # 
            #new_tolerance_satisfied = torch.abs((torch.pow(counterfactuals[i], 2).sum(dim=-1) - 1.0)) < tolerance
            #new_tolerance_satisfied_indices = torch.nonzero(torch.abs((torch.pow(counterfactuals[i], 2).sum(dim=-1) - 1.0)) < tolerance)
            
            # check where target confidence is reached AND if new counterfactuals are within the tolerance range
            #indices = torch.nonzero((confidence > target_confidence) 
            #                    (torch.abs((torch.pow(counterfactuals[i], 2).sum(dim=-1) - 1.0)) < tolerance))
            
            
            #new_confidence_and_tolerance_satisfied_indices = torch.nonzero(new_confidence_satisfied & new_tolerance_satisfied)
            
            #current_tolerance_satisfied = current_counterfactual_distance_from_manifold[i] < tolerance
            #classifier(minimal_counterfactuals[i:i+1]).softmax(dim=-1)[0][target_classes[i]]
            current_confidence_satisfied = classifier(minimal_counterfactuals[i:i+1]).softmax(dim=-1)[0][target_classes[i]].item() > target_confidence
            
            # if current counterfactual satisfies confidence and tolerance, maintain status quo 
          
            if new_confidence_satisfied_indices.nelement() != 0:
                print('new confidence satisfied')
                minimal_counterfactuals[i] = counterfactuals[i][new_confidence_satisfied_indices[0].item()]
                
            else:
                print('neither current nor new confidence satisfied')
                minimal_counterfactuals[i] = counterfactuals[i][-1]
                
            # if new target confidence and current target confidence not satisfied
            # but new tolerance is satisfied, then move the point closer
            #elif (new_tolerance_satisfied_indices.nelement() !=0) & (new_confidence_satisfied_indices.nelement() == 0) and (not current_confidence_satisfied):
            #    minimal_counterfactuals[i] = counterfactuals[i][torch.nonzero(new_tolerance_satisfied)[-1].item()]
            
              
            # if current counterfactual is on the manifold but is not actually a counterfactual, 
            # replace it with new counterfactual if there exists any 
            #elif (not current_confidence_satisfied) & (new_confidence_and_tolerance_satisfied_indices.nelement() != 0):
            #    # change this to only include the first where confidence and tolerance is satisfied
            #    minimal_counterfactuals[i] = counterfactuals[i][new_confidence_and_tolerance_satisfied_indices[0].item()]
                
            #else:
            #    continue
            

        return minimal_counterfactuals
        
        
    def edit(
        self,
        x_in: torch.Tensor,
        target_confidence_goal: float,
        target_classes: torch.Tensor,
        classifier: nn.Module
    ) -> Tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        
        self.original_sample = x_in
        #minimal_counterfactuals = torch.zeros(size=x_in.shape)

        scales = self.config['grad_scales']
        noise_steps = self.config['noise_steps_for_counterfactuals']
        
        minimal_counterfactuals = x_in.clone()
        
        for it in range(self.config['num_iterations']):
            for steps in noise_steps:

                for s in scales:

                    counterfactuals, guided_grads, unguided_grads, total_series = self.sample_counterfactual_ddpm(clean_batch=minimal_counterfactuals, model=self.model, classifier=classifier, num_noise_steps=steps, target_classes=target_classes, classifier_grad_weight=s)

                    minimal_counterfactuals = self.discard_counterfactuals(counterfactuals=counterfactuals, classifier=classifier, target_confidence=target_confidence_goal, target_classes=target_classes, minimal_counterfactuals=minimal_counterfactuals)
                    self.counterfactuals = minimal_counterfactuals
                    flip_rate = sum(classifier(minimal_counterfactuals).softmax(dim=-1).argmax(dim=-1) != classifier(x_in).softmax(dim=-1).argmax(dim=-1)) / len(x_in)
                    self.plot_counterfactuals()
                    plt.title(f'Noise Steps: {steps}, Gradient Scale: {s}, Flip Rate: {round(flip_rate.item(),3)}')
                    #plt.show()
        list_counterfactuals = [row_tensor for row_tensor in minimal_counterfactuals]
        diff_latent = x_in - minimal_counterfactuals
        
        confidences = classifier(minimal_counterfactuals).softmax(dim=-1)
        y_target_end_confidence = [confidences[i][target_classes[i]].detach() for i in range(len(minimal_counterfactuals))]
        x_list = [row_tensor for row_tensor in x_in]
        
        #self.counterfactuals = minimal_counterfactuals
        
        return list_counterfactuals, diff_latent, y_target_end_confidence, x_list
    
    
    def plot_counterfactuals(self):
        #plt.figure(figsize=(5,5))
        data = torch.zeros([len(dataset.data),len(dataset.attributes)], dtype=torch.float16)
        for idx, key in enumerate(dataset.data):
            data[idx] = dataset.data[key]
        self.data = data
        plt.scatter(data[:,self.input_idx[0]], data[:,self.input_idx[1]], c=np.where(data[:,self.target_idx] == 0, 'lightcyan', 'lightgray')[0])
        for i, point in enumerate(self.counterfactuals):
            plt.scatter(self.original_sample[i, 0], self.original_sample[i, 1], color='green', label='start')
            plt.scatter(point[0], point[1], color='red', label='end')
            plt.arrow(
                self.original_sample[i,0], self.original_sample[i, 1], # plot the original point plus arrow until (j+granularity)th point
                point[0] - self.original_sample[i, 0], 
                point[1] - self.original_sample[i, 1],
                head_width=0.05, head_length=0.05, fc='blue', ec='blue',

            )
        
        #plt.show()
    
    
    
        