In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

In [2]:
# Names given to the 11 columns that are read from the UCI Adult files.
COLUMNS = [
    "age", "workclass", "edu_level", "marital_status", "occupation",
    "relationship", "race", "sex", "hours_per_week", "native_country",
    "income",
]

# Parser options that are identical for the train and the test file:
# only the listed column positions are read, fields are split on commas
# with optional surrounding whitespace, and "?" is parsed as missing.
_ADULT_READ_OPTIONS = dict(
    names=COLUMNS,
    engine='python',
    usecols=[0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 14],
    sep=r'\s*,\s*',
    na_values="?",
)

train_df = pd.read_csv(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
    **_ADULT_READ_OPTIONS,
)

# For the test file the first line (row 0) is skipped.
test_df = pd.read_csv(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test",
    skiprows=[0],
    **_ADULT_READ_OPTIONS,
)


In [3]:
# Discard every row that contains at least one missing value, in both splits.
train_df, test_df = (
    frame.dropna(axis="index", how="any") for frame in (train_df, test_df)
)

# To reduce the complexity, every attribute is binarized to 0/1.


def mapping(tuple):
    """Binarize one Adult record.

    Parameters
    ----------
    tuple : mapping-like (e.g. ``pandas.Series`` or ``dict``)
        One record keyed by the names in ``COLUMNS``. It must support
        ``.copy()`` and item access by key. The parameter name shadows the
        builtin ``tuple``; it is kept so existing callers are unaffected.

    Returns
    -------
    A copy of the record in which every attribute is replaced by 0 or 1.

    Notes
    -----
    The input record is left untouched. ``DataFrame.apply`` does not support
    functions that mutate the object they are given, so the binarized values
    are written to a copy instead.
    """
    row = tuple.copy()

    # age: 1 when older than 37
    row['age'] = 1 if row['age'] > 37 else 0
    # workclass: 1 for 'Private'
    row['workclass'] = 1 if row['workclass'] == 'Private' else 0
    # edu-level: 1 when above 9
    row['edu_level'] = 1 if row['edu_level'] > 9 else 0
    # marital status: 1 for "Married-civ-spouse"
    row['marital_status'] = 1 if row['marital_status'] == "Married-civ-spouse" else 0
    # occupation: 1 for "Craft-repair"
    row['occupation'] = 1 if row['occupation'] == "Craft-repair" else 0
    # relationship: 0 for "Not-in-family", 1 otherwise
    row['relationship'] = 0 if row['relationship'] == "Not-in-family" else 1
    # race: 1 for "White"
    row['race'] = 1 if row['race'] == "White" else 0
    # sex: 1 for "Male"
    row['sex'] = 1 if row['sex'] == "Male" else 0
    # hours per week: 1 when above 40
    row['hours_per_week'] = 1 if row['hours_per_week'] > 40 else 0
    # native country: 1 for "United-States"
    row['native_country'] = 1 if row['native_country'] == "United-States" else 0
    # income: 1 for '>50K' (the label also occurs with a trailing dot)
    row['income'] = 1 if row['income'] in ('>50K', '>50K.') else 0
    return row


# Binarize both splits by running `mapping` on each row.
train_df = train_df.apply(mapping, axis="columns")
test_df = test_df.apply(mapping, axis="columns")

In [56]:
# Mean of the binarized native_country column in the training split.
train_df.loc[:, "native_country"].mean()

0.911875870300378

In [10]:
# Turn each frame into a tensor, then stack train and test rows into one dataset.
train_data, test_data = (
    torch.from_numpy(frame.to_numpy()) for frame in (train_df, test_df)
)
dataset = torch.cat([train_data, test_data], dim=0)

In [11]:
from torch.utils.data import Dataset,DataLoader
class AdultDataset(Dataset):
    """Dataset whose items are the rows (dim 0 entries) of one tensor."""

    def __init__(self, data_set):
        # Keep the tensor itself and cache the size of its first dimension.
        self.x = data_set
        self.len = data_set.shape[0]

    def __len__(self):
        # Number of rows recorded at construction time.
        return self.len

    def __getitem__(self, index):
        # Plain tensor indexing along dim 0.
        return self.x[index]
# Wrap the merged tensor and serve it in shuffled mini-batches of 128 rows.
adultDataset = AdultDataset(data_set=dataset)
dataLoader = DataLoader(adultDataset, batch_size=128, shuffle=True)



In [178]:
# a basic Generator


class Generator(nn.Module):
    """Three linear layers with the activation `f` after the first two.

    The last layer's output is returned as-is (no final activation).
    """

    def __init__(self, f, input_size, hidden_size, output_size):
        super().__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        # f is the activation function
        self.f = f

    def forward(self, x):
        hidden = self.f(self.map1(x))
        hidden = self.f(self.map2(hidden))
        return self.map3(hidden)

# a basic Discriminator

In [134]:
class Discriminator(nn.Module):
    """Three linear layers; `f` follows the first two, a sigmoid follows the last."""

    def __init__(self, f, input_size, hidden_size, output_size):
        super().__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        # f is the activation function
        self.f = f

    def forward(self, x):
        # Two hidden layers, each followed by the activation.
        for layer in (self.map1, self.map2):
            x = self.f(layer(x))
        # The sigmoid squashes the final layer's output into (0, 1).
        logits = self.map3(x)
        return torch.sigmoid(logits)

In [121]:
class CFGAN(nn.Module):
    """Generator made of one sub-network per attribute, chained along a graph.

    Every attribute has its own ``Generator`` sub-network. Its input is the
    attribute's own noise column concatenated with the already generated
    values of the attributes it depends on, so each sub-network's input size
    is ``1 + number of parents``.

    Parameters
    ----------
    f : callable
        Activation function handed to every sub-network.
    """

    def __init__(self, f):
        super(CFGAN, self).__init__()

        # (input_size, hidden_size, output_size) per attribute; the creation
        # order is kept as-is so parameter initialisation order is unchanged.
        self.age_net = Generator(
            f, 1, 2, 1)
        self.workclass_net = Generator(
            f, 5, 10, 1)
        self.edu_level_net = Generator(
            f, 6, 12, 1)
        self.marital_status_net = Generator(
            f, 5, 10, 1)
        self.occupation_net = Generator(
            f, 6, 12, 1)
        self.relationship_net = Generator(
            f, 6, 12, 1)
        self.race_net = Generator(
            f, 1, 2, 1)
        self.sex_net = Generator(
            f, 1, 2, 1)
        self.hours_per_week_net = Generator(
            f, 7, 14, 1)
        self.native_country_net = Generator(
            f, 1, 2, 1)
        self.income_net = Generator(
            f, 11, 22, 1)

    def forward(self, input, intervention=-1):
        """Generate one row per noise vector.

        Parameters
        ----------
        input : tensor of shape (batch, 11)
            Noise; column ``i`` is the noise of the ``i``-th entry of ``name``
            below.
        intervention : int, default -1
            ``-1`` generates ``sex`` with its sub-network; ``0`` fixes ``sex``
            to all zeros; any other value fixes it to all ones.

        Returns
        -------
        Tensor of shape (batch, 11) with columns ordered as: age, workclass,
        edu_level, marital_status, occupation, relationship, race, sex,
        hours_per_week, native_country, income.

        The generated per-attribute tensors are also stored on ``self``
        (``self.age``, ``self.sex``, ...).
        """
        name = ["race", "age", "sex", "native_country", "marital_status",
                "edu_level", "occupation", "hours_per_week", "workclass", "relationship", "income"]
        # One (batch, 1) noise column per attribute name.
        Z = dict(zip(name, input.transpose(0, 1).view(len(name), -1, 1)))

        # height = 0 in the graph
        # sex is the attribute that can be intervened on
        if intervention == -1:
            # Fix: sex is generated by its own network (was race_net).
            self.sex = self.sex_net(Z["sex"])
        elif intervention == 0:
            # *_like keeps the constant on the noise's dtype and device.
            self.sex = torch.zeros_like(Z["sex"])
        else:
            self.sex = torch.ones_like(Z["sex"])
        self.age = self.age_net(Z["age"])
        # Fix: race is generated by its own network (was sex_net).
        self.race = self.race_net(Z["race"])
        self.native_country = self.native_country_net(Z["native_country"])

        # height = 1 in the graph
        self.marital_status = self.marital_status_net(torch.cat(
            [Z["marital_status"], self.race, self.age,
                self.sex, self.native_country], 1
        ))

        # height = 2 in the graph
        self.edu_level = self.edu_level_net(torch.cat(
            [Z["edu_level"], self.race, self.age, self.sex,
                self.native_country, self.marital_status], 1
        ))

        # height = 3 in the graph
        self.occupation = self.occupation_net(torch.cat(
            [Z["occupation"], self.race, self.age, self.sex,
                self.marital_status, self.edu_level], 1
        ))

        self.hours_per_week = self.hours_per_week_net(torch.cat(
            [Z["hours_per_week"], self.race, self.age, self.sex,
             self.native_country, self.marital_status, self.edu_level], 1
        ))

        self.workclass = self.workclass_net(torch.cat(
            [Z["workclass"], self.age, self.marital_status,
                self.edu_level, self.native_country], 1
        ))

        self.relationship = self.relationship_net(torch.cat(
            [Z["relationship"], self.age, self.sex, self.native_country,
                self.marital_status, self.edu_level], 1
        ))

        # height = 4 in the graph
        self.income = self.income_net(torch.cat(
            [Z["income"], self.race, self.age, self.sex, self.native_country, self.marital_status,
                self.edu_level, self.occupation, self.hours_per_week, self.workclass, self.relationship], 1
        ))

        return torch.cat([self.age, self.workclass, self.edu_level, self.marital_status,
                          self.occupation, self.relationship, self.race, self.sex,
                          self.hours_per_week, self.native_country, self.income], 1)

In [179]:
# Training schedule
num_epochs, g_steps, g2_steps = 100, 5, 50
batch = 128
LR = 0.001
print_interval = 1

# activation functions
discriminator_activation_function = nn.LeakyReLU(negative_slope=0.2)
generator_activation_function = torch.tanh

# networks: the discriminator is built first, then the generator
discriminator_1 = Discriminator(
    f=discriminator_activation_function,
    input_size=11,
    hidden_size=64,
    output_size=1,
)
generator = CFGAN(f=generator_activation_function)

# Binary cross entropy: https://pytorch.org/docs/stable/nn.html?highlight=bceloss#torch.nn.BCELoss
criterion = nn.BCELoss()

# one Adam optimizer per network, same learning rate and betas
generator_optim = torch.optim.Adam(
    params=generator.parameters(), lr=LR, betas=(0.9, 0.99))
discriminator_1_optim = torch.optim.Adam(
    params=discriminator_1.parameters(), lr=LR, betas=(0.9, 0.99))



In [16]:
import copy

# debug
## test for parameters

In [173]:
# Debug: one pass over the loader that updates only discriminator_1.
# The loader is iterated directly; every `for` starts a fresh pass, so no
# copy of it is needed.
for real_data in dataLoader:

    # 1A: Train D1 on real
    discriminator_1.zero_grad()
    d_real_data = real_data
    # real data's label should be true (1)
    d_real_labe = torch.ones(d_real_data.size(0))
    d_real_decision = discriminator_1(d_real_data.float())
    # squeeze(1) removes only the output dim, so a batch of one row still
    # yields shape (1,) and matches the label shape.
    d_real_loss = criterion(d_real_decision.squeeze(1), d_real_labe)
    d_real_loss.backward()

    # 1B: Train D1 on fake data
    # detach(): this step only updates D1, so no gradient needs to flow back
    # into (and accumulate in) the generator's parameters.
    d_fake_data = generator(torch.randn(batch, 11)).detach()
    d_fake_lable = torch.zeros(batch)
    d_fake_decision = discriminator_1(d_fake_data)
    d_fake_loss = criterion(d_fake_decision.squeeze(1), d_fake_lable)
    d_fake_loss.backward()
    # Only optimizes D1's parameters
    discriminator_1_optim.step()

In [169]:
# Debug: one pass over the loader; each batch does one D1 update followed by
# four generator updates.
# The loader is iterated directly; every `for` starts a fresh pass, so no
# copy of it is needed.
for real_data in dataLoader:

    # 1A: Train D1 on real
    discriminator_1.zero_grad()
    d_real_data = real_data
    # real data's label should be true (1)
    d_real_labe = torch.ones(d_real_data.size(0))
    d_real_decision = discriminator_1(d_real_data.float())
    # squeeze(1) removes only the output dim, so a batch of one row still
    # yields shape (1,) and matches the label shape.
    d_real_loss = criterion(d_real_decision.squeeze(1), d_real_labe)
    d_real_loss.backward()

    # 1B: Train D1 on fake data
    # detach(): this step only updates D1, so no gradient needs to flow back
    # into the generator's parameters.
    d_fake_data = generator(torch.randn(batch, 11)).detach()
    d_fake_lable = torch.zeros(batch)
    d_fake_decision = discriminator_1(d_fake_data)
    d_fake_loss = criterion(d_fake_decision.squeeze(1), d_fake_lable)
    d_fake_loss.backward()
    # Only optimizes D1's parameters
    discriminator_1_optim.step()

    for g_index in range(4):
        # Train G on D's response: generated rows should be judged real (1)
        generator.zero_grad()
        g_fake_data = generator(torch.randn(batch, 11))
        d_g_fake_decision = discriminator_1(g_fake_data)
        g_fake_lable = torch.ones(batch)
        g_loss = criterion(d_g_fake_decision.squeeze(1), g_fake_lable)
        g_loss.backward()
        generator_optim.step()

In [180]:
# Show 10 generated rows, then D1's verdict on a full generated batch.
generated_preview = generator(torch.randn(10, 11))
print(generated_preview)
dis = discriminator_1(generator(torch.randn(batch, 11)))
dis

tensor([[ 0.1253, -0.0360, -0.1340, -0.1488, -0.2294,  0.2546,  0.2110, -0.2229,
         -0.0851, -0.3133, -0.0554],
        [-0.0318,  0.1591, -0.1530,  0.0957, -0.2356,  0.1763,  0.1169, -0.2250,
         -0.0033, -0.3591, -0.0544],
        [-0.0186,  0.1610, -0.1391,  0.0399, -0.2191,  0.2612,  0.3327, -0.2050,
         -0.1094, -0.2992, -0.0640],
        [-0.0609,  0.0481, -0.1449,  0.0693, -0.2487,  0.2059,  0.2075, -0.2491,
         -0.0944, -0.2996, -0.0546],
        [ 0.0174,  0.1076, -0.1303, -0.0503, -0.2395,  0.2269,  0.3320, -0.0618,
         -0.0193, -0.3014, -0.0522],
        [ 0.0957,  0.1393, -0.1297, -0.0981, -0.1831,  0.2016,  0.3509, -0.1763,
         -0.0739, -0.2546, -0.0613],
        [ 0.1455,  0.1808, -0.1272, -0.1302, -0.2531,  0.2288,  0.3741, -0.1687,
         -0.0565, -0.2619, -0.0736],
        [ 0.1853,  0.0538, -0.1528,  0.0675, -0.2008,  0.2613,  0.1255, -0.2005,
         -0.0897, -0.2610, -0.0306],
        [ 0.1193,  0.1260, -0.1226,  0.0014, -0.1626,  0

tensor([[0.4699],
        [0.4710],
        [0.4695],
        [0.4714],
        [0.4712],
        [0.4700],
        [0.4705],
        [0.4700],
        [0.4698],
        [0.4733],
        [0.4686],
        [0.4705],
        [0.4689],
        [0.4707],
        [0.4698],
        [0.4730],
        [0.4693],
        [0.4728],
        [0.4720],
        [0.4722],
        [0.4704],
        [0.4710],
        [0.4717],
        [0.4720],
        [0.4717],
        [0.4712],
        [0.4721],
        [0.4710],
        [0.4695],
        [0.4714],
        [0.4688],
        [0.4714],
        [0.4696],
        [0.4685],
        [0.4692],
        [0.4736],
        [0.4703],
        [0.4710],
        [0.4702],
        [0.4708],
        [0.4685],
        [0.4713],
        [0.4732],
        [0.4697],
        [0.4697],
        [0.4704],
        [0.4696],
        [0.4685],
        [0.4705],
        [0.4723],
        [0.4719],
        [0.4695],
        [0.4691],
        [0.4730],
        [0.4725],
        [0

In [166]:
# Take the first (index, batch) pair from a pass over a copy of the loader.
data = enumerate(copy.copy(dataLoader))
i, real_data_0 = next(data)
real_data_0

tensor([[0, 1, 1,  ..., 1, 1, 0],
        [0, 1, 0,  ..., 0, 1, 0],
        [1, 1, 1,  ..., 1, 0, 1],
        ...,
        [1, 1, 0,  ..., 1, 0, 1],
        [0, 1, 1,  ..., 0, 1, 0],
        [0, 1, 1,  ..., 0, 1, 0]])

In [23]:
# Shape of the batch taken above.
real_data_0.shape

torch.Size([128, 11])

In [175]:
# D1's output on a real batch (cast to float32 first).
discriminator_1(real_data_0.to(torch.float32))

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1

In [146]:
 for g_index in range(300):
    # Generator-only update: D1's verdict on generated rows is pushed
    # towards the "real" label (1); D1 itself is not stepped here.
    generator.zero_grad()
    noise = torch.randn(batch, 11)
    g_fake_data = generator(noise)
    d_g_fake_decision = discriminator_1(g_fake_data)
    g_fake_lable = torch.ones(batch)
    g_loss = criterion(d_g_fake_decision.squeeze(), g_fake_lable)
    g_loss.backward()
    generator_optim.step()

In [21]:
import copy
for epoch in range(num_epochs):
    # ---- GAN1 -----------------------------------------------------------
    # 1. Train D1 on real + fake for one full pass over the loader.
    #    The loader is iterated directly; every `for` starts a fresh pass.
    for real_data in dataLoader:

        # 1A: Train D1 on real
        discriminator_1.zero_grad()
        d_real_data = real_data
        # real data's label should be true (1)
        d_real_labe = torch.ones(d_real_data.size(0))
        d_real_decision = discriminator_1(d_real_data.float())
        # squeeze(1) removes only the output dim, so a batch of one row
        # still yields shape (1,) and matches the label shape.
        d_real_loss = criterion(d_real_decision.squeeze(1), d_real_labe)
        d_real_loss.backward()

        # 1B: Train D1 on fake data
        # detach(): this step only updates D1, so no gradient needs to flow
        # back into the generator's parameters.
        d_fake_data = generator(torch.randn(batch, 11)).detach()
        d_fake_lable = torch.zeros(batch)
        d_fake_decision = discriminator_1(d_fake_data)
        d_fake_loss = criterion(d_fake_decision.squeeze(1), d_fake_lable)
        d_fake_loss.backward()
        # Only optimizes D1's parameters
        discriminator_1_optim.step()

        # Losses of the most recent batch, reported at the end of the epoch.
        drl, dfl = d_real_loss.item(), d_fake_loss.item()

    # 2. Train G on D1's response (g_steps updates per epoch, after the
    #    discriminator pass above).
    for g_index in range(g_steps):
        generator.zero_grad()

        g_fake_data = generator(torch.randn(batch, 11))
        d_g_fake_decision = discriminator_1(g_fake_data)
        g_fake_lable = torch.ones(batch)
        g_loss = criterion(d_g_fake_decision.squeeze(1), g_fake_lable)
        g_loss.backward()
        generator_optim.step()
        gl = g_loss.item()

    # ---- GAN2 -----------------------------------------------------------
    # For each group of generated samples, pull the income generated under
    # sex := 0 towards the income generated under sex := 1.
    for g_index in range(g2_steps):
        generator.zero_grad()
        noise_z = torch.randn(batch, 11)
        # Only used to assign samples to groups, so no graph is needed.
        with torch.no_grad():
            fake_data = generator(noise_z)

        # O = {race; native country}: (0,0) (0,1) (1,0) (1,1)
        # Fix: race is output column 6. Column 7, which was used before, is
        # sex, i.e. the attribute being intervened on. Native country is
        # column 9.
        race = fake_data[:, 6]
        native_country = fake_data[:, 9]
        group_masks = [
            (race < 0.5) & (native_country < 0.5),
            (race < 0.5) & (native_country >= 0.5),
            (race >= 0.5) & (native_country < 0.5),
        ]
        # Everything not matched above falls into the last group.
        group_masks.append(~(group_masks[0] | group_masks[1] | group_masks[2]))

        # NOTE(review): nn.BCELoss expects inputs and targets in [0, 1],
        # but Generator ends in a plain linear layer, so the income column
        # is unbounded. Confirm this loss is intended here.
        g_errors = [None] * len(group_masks)
        for group, mask in enumerate(group_masks):
            if not mask.any():
                continue  # no generated sample fell into this group
            group_noise = noise_z[mask]
            income_do_0 = generator(group_noise, 0)[:, -1]
            # The sex := 1 branch is the fixed target, hence detach().
            income_do_1 = generator(group_noise, 1)[:, -1].detach()
            g_error = criterion(income_do_0, income_do_1)
            # Gradients of all groups accumulate before the single step below.
            g_error.backward()
            g_errors[group] = g_error.item()
        ge0, ge1, ge2, ge3 = g_errors

        generator_optim.step()

    if epoch % print_interval == 0:
        print("Epoch %s: D (%s real_err, %s fake_err) G_l (%s err) G_0l (%s) G_1l (%s) G_2l (%s) G_3l (%s);" % (
            epoch, drl, dfl, gl, ge0, ge1, ge2, ge3))

Epoch 0: D (0.22125402092933655 real_err, 0.059217192232608795 fake_err) G_l (0.015695534646511078 err) G_0l (None) G_1l (None) G_2l (0.07190127670764923) G_3l (0.07072243094444275);
Epoch 1: D (0.8784480094909668 real_err, 0.14986123144626617 fake_err) G_l (0.05311468243598938 err) G_0l (None) G_1l (None) G_2l (None) G_3l (0.04634556174278259);
Epoch 2: D (0.32248008251190186 real_err, 0.18921847641468048 fake_err) G_l (0.010749738663434982 err) G_0l (None) G_1l (None) G_2l (None) G_3l (0.06491663306951523);
Epoch 3: D (0.2830143868923187 real_err, 0.25805848836898804 fake_err) G_l (0.02052653394639492 err) G_0l (None) G_1l (None) G_2l (None) G_3l (0.029447432607412338);
Epoch 4: D (0.08148318529129028 real_err, 0.15486261248588562 fake_err) G_l (0.0037686177529394627 err) G_0l (None) G_1l (None) G_2l (None) G_3l (0.09675496816635132);
Epoch 5: D (0.6273917555809021 real_err, 0.5055474638938904 fake_err) G_l (0.02917487919330597 err) G_0l (None) G_1l (None) G_2l (None) G_3l (0.0302715

KeyboardInterrupt: 