In [1]:
from imblearn.over_sampling import RandomOverSampler, SMOTE
from imblearn.under_sampling import RandomUnderSampler, NearMiss
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import matplotlib.pyplot as plt
import torch
import numpy as np
import pandas as pd 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns
%run gan_.ipynb

class ImbalanceHandler:
    """Apply one class-imbalance strategy, selected by name, to a dataset.

    Recognised ``method`` values are ``"none"``, ``"oversampling"``,
    ``"undersampling"``, ``"nearmiss"``, ``"smote"``, ``"gan"`` and
    ``"classweights"``.  The GAN strategy treats label ``1`` as the minority
    class and label ``0`` as the majority class.

    Attributes:
        method: Name of the strategy to apply.
        sampler: Last imblearn sampler built by ``resampling``/``sampler_``;
            a value passed to the constructor is replaced when one of the
            imblearn-based methods runs.
        class_weights: Mapping ``label -> weight``.  Computed on first use
            for ``"classweights"`` unless supplied up front.
        gan: Last ``Gan`` instance trained by ``resampling``; a value passed
            to the constructor is replaced when a GAN is trained.
        gan_epochs: Stored only; not read by any method of this class.
        gan_noise: Stored only; not read by any method of this class.
        columns: Column names used to build the feature DataFrame returned
            by the ``"gan"`` strategy.
    """

    def __init__(self, method="none", sampler=None, class_weights=None, gan=None,
                 gan_epochs=None, gan_noise=None, columns=None):
        self.method = method
        self.sampler = sampler
        self.class_weights = class_weights
        self.gan = gan
        self.gan_epochs = gan_epochs
        self.gan_noise = gan_noise
        self.columns = columns

    def report_distributions(self, y):
        """Print per-class counts and percentages and show a small bar chart.

        Args:
            y: Labels; must provide ``value_counts()`` (e.g. a pandas Series).
        """
        counts = y.value_counts()
        total = counts.sum()

        report = {
            label: {"count": count, "percentage": f"{(count / total) * 100}%"}
            for label, count in counts.items()
        }
        print(report)

        plt.figure(figsize=(3, 2))
        plt.bar(counts.index, counts.values, color="blue")
        plt.show()

    def resampling(self, X, y):
        """Return ``(X, y)`` after applying the configured strategy.

        ``"none"`` and ``"classweights"`` return the inputs untouched (the
        latter also fills ``self.class_weights`` if it is still ``None``).
        The imblearn-based methods return whatever ``fit_resample`` yields.
        ``"gan"`` returns a DataFrame / Series pair with synthetic label-1
        rows appended.

        Raises:
            ValueError: If ``self.method`` is not a recognised name, or if
                ``"gan"`` is requested and there are no label-1 rows to
                learn from.
        """
        if self.method == "none":
            return X, y

        if self.method == "gan":
            return self._gan_resample(X, y)

        if self.method == "classweights":
            self._ensure_class_weights(y)
            return X, y

        sampler = self._make_sampler()
        if sampler is None:
            raise ValueError(f"{self.method} is not a valid method!")

        self.sampler = sampler
        X_resampled, y_resampled = self.sampler.fit_resample(X, y)
        return X_resampled, y_resampled

    def sampler_(self, y):
        """Return the object that implements the configured strategy.

        Returns ``None`` for ``"none"``, the class-weight dict for
        ``"classweights"`` and a fresh, unfitted imblearn sampler for the
        imblearn-based methods (also stored on ``self.sampler``).

        Raises:
            ValueError: For any other method name.  This includes ``"gan"``,
                which has no sampler object in this class.
        """
        if self.method == "none":
            return None

        if self.method == "classweights":
            return self._ensure_class_weights(y)

        sampler = self._make_sampler()
        if sampler is None:
            raise ValueError(f"{self.method} is not a valid method!")

        self.sampler = sampler
        return self.sampler

    def _make_sampler(self):
        """Build a new imblearn sampler for ``self.method``, or ``None``."""
        if self.method == "oversampling":
            return RandomOverSampler()
        if self.method == "undersampling":
            return RandomUnderSampler()
        if self.method == "nearmiss":
            return NearMiss(version=1)
        if self.method == "smote":
            return SMOTE()
        return None

    def _ensure_class_weights(self, y):
        """Compute balanced class weights once and return the cached dict."""
        # https://www.analyticsvidhya.com/blog/2020/10/improve-class-imbalance-class-weights/
        if self.class_weights is None:
            classes = np.unique(y)
            weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
            self.class_weights = dict(zip(classes, weights))
        return self.class_weights

    def _gan_resample(self, X, y):
        """Append GAN-generated label-1 rows until labels 0 and 1 are balanced."""
        labels = np.asarray(y)
        minority_mask = labels == 1
        n_minority = int(minority_mask.sum())
        n_majority = int((labels == 0).sum())
        need = n_majority - n_minority

        X_frame = pd.DataFrame(X, columns=self.columns)
        y_series = pd.Series(y).reset_index(drop=True)

        # Nothing to generate: skip training and hand the data back in the
        # same DataFrame / Series form the normal path returns.
        if need <= 0:
            return X_frame.reset_index(drop=True), y_series

        if n_minority == 0:
            raise ValueError("gan resampling needs at least one sample with label 1")

        features = X_frame.to_numpy()
        n_features = features.shape[1]

        # Generator / Discriminator / Gan are not defined in this cell;
        # presumably they come from the `%run gan_.ipynb` line above.
        # Network width follows the data instead of a fixed 30 columns.
        generator = Generator(noise=64, output=n_features, hidden=64)
        discriminator = Discriminator(input_dim=n_features, hidden=64)
        self.gan = Gan(generator, discriminator, lr_g=0.01, lr_d=0.01)

        # Only the label-1 rows are used as training data.
        tensor_x = torch.tensor(features[minority_mask], dtype=torch.float32)
        tensor_y = torch.tensor(labels[minority_mask], dtype=torch.float32)
        dataloader = DataLoader(TensorDataset(tensor_x, tensor_y), batch_size=64, shuffle=True)

        self.gan.train(dataloader)
        synthetic_samples = self.gan.creating_new(need)

        synth_data = pd.DataFrame(synthetic_samples, columns=X_frame.columns)
        synth_data_y = pd.Series([1] * need)

        X_resampled = pd.concat([X_frame, synth_data], ignore_index=True)
        y_resampled = pd.concat([y_series, synth_data_y], ignore_index=True)
        return X_resampled, y_resampled
