In [1]:
import os,argparse,time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim 
import torch.utils.data
import torch.utils.data.distributed

from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix

from dataloaders.deep_moji import DeepMojiDataset
from networks.deepmoji_sa import DeepMojiModel
from networks.discriminator import Discriminator


from tqdm import tqdm, trange
from networks.customized_loss import DiffLoss

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from networks.eval_metrices import group_evaluation, leakage_evaluation

from pathlib import Path, PureWindowsPath

In [2]:
from scripts_deepmoji import adv_train_epoch
from scripts_deepmoji import adv_eval_epoch
from scripts_deepmoji import train_epoch
from scripts_deepmoji import eval_main

In [3]:
class Args:
    use_fp16 = False
    cuda = "cuda"
    hidden_size = 300
    emb_size = 2304
    num_classes = 2
    adv = True
    adv_level = -1
    lr = 0.00003
    LAMBDA = 1
    n_discriminator = 3
    adv_units = 256
    ratio = 0.8
    DL = True
    diff_LAMBDA = 10**(3.7)
    data_path = "D:\\Project\\User_gender_removal\\data\\deepmoji\\split2\\"

args = Args()

In [4]:
# file names
experiment_type = "adv_Diverse"

# DataLoader Parameters
params = {'batch_size': 512,
        'shuffle': True,
        'num_workers': 0}
# Device
device = torch.device("cuda")

data_path = args.data_path
# Load data
train_data = DeepMojiDataset(args, data_path, "train", ratio=args.ratio, n = 100000)
dev_data = DeepMojiDataset(args, data_path, "dev")
test_data = DeepMojiDataset(args, data_path, "test")

# Data loader
training_generator = torch.utils.data.DataLoader(train_data, **params)
validation_generator = torch.utils.data.DataLoader(dev_data, **params)
test_generator = torch.utils.data.DataLoader(test_data, **params)

Loading preprocessed deepMoji Encoded data
Done, loaded data shapes: (99998, 2304), (99998,), (99998,)
Loading preprocessed deepMoji Encoded data
Done, loaded data shapes: (8000, 2304), (8000,), (8000,)
Loading preprocessed deepMoji Encoded data
Done, loaded data shapes: (7998, 2304), (7998,), (7998,)


In [5]:
# Init model
model = DeepMojiModel(args)

model = model.to(device)

# Init discriminators
# Number of discriminators
n_discriminator = args.n_discriminator

discriminators = [Discriminator(args, args.hidden_size, 2) for _ in range(n_discriminator)]
discriminators = [dis.to(device) for dis in discriminators]

diff_loss = DiffLoss()
args.diff_loss = diff_loss

# Init optimizers
LEARNING_RATE = args.lr
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)

adv_optimizers = [Adam(filter(lambda p: p.requires_grad, dis.parameters()), lr=1e-1*LEARNING_RATE) for dis in discriminators]

# Init learing rate scheduler
scheduler = ReduceLROnPlateau(optimizer, mode = 'max', factor = 0.5, patience = 2)

# Init criterion
criterion = torch.nn.CrossEntropyLoss()

In [6]:
main_model_path = "models\\deepnoji_model_{}.pt".format(experiment_type)
adv_model_path = "models\\discriminator_{}_{}.pt"

In [7]:
best_loss, valid_preds, valid_labels, _ = eval_main(
                                                    model = model, 
                                                    iterator = validation_generator, 
                                                    criterion = criterion, 
                                                    device = device, 
                                                    args = args
                                                    )

best_acc = accuracy_score(valid_labels, valid_preds)
best_epoch = 60

for i in trange(60):
    train_epoch(
                model = model, 
                discriminators = discriminators, 
                iterator = training_generator, 
                optimizer = optimizer, 
                criterion = criterion, 
                device = device, 
                args = args
                )

    valid_loss, valid_preds, valid_labels, _ = eval_main(
                                                        model = model, 
                                                        iterator = validation_generator, 
                                                        criterion = criterion, 
                                                        device = device, 
                                                        args = args
                                                        )
    valid_acc = accuracy_score(valid_preds, valid_labels)
    # learning rate scheduler
    scheduler.step(valid_loss)

    # early stopping
    if valid_loss < best_loss:
        if i >= 5:
            best_acc = valid_acc
            best_loss = valid_loss
            best_epoch = i
            torch.save(model.state_dict(), main_model_path)
    else:
        if best_epoch+5<=i:
            break

    # Train discriminator untile converged
    # evaluate discriminator 
    best_adv_loss, _, _, _ = adv_eval_epoch(
                                            model = model, 
                                            discriminators = discriminators, 
                                            iterator = validation_generator, 
                                            criterion = criterion, 
                                            device = device, 
                                            args = args
                                            )
    best_adv_epoch = -1
    for k in range(100):
        adv_train_epoch(
                        model = model, 
                        discriminators = discriminators, 
                        iterator = training_generator, 
                        adv_optimizers = adv_optimizers, 
                        criterion = criterion, 
                        device = device, 
                        args = args
                        )
        adv_valid_loss, _, _, _ = adv_eval_epoch(
                                                model = model, 
                                                discriminators = discriminators, 
                                                iterator = validation_generator, 
                                                criterion = criterion, 
                                                device = device, 
                                                args = args
                                                )
            
        if adv_valid_loss < best_adv_loss:
                best_adv_loss = adv_valid_loss
                best_adv_epoch = k
                for j in range(args.n_discriminator):
                    torch.save(discriminators[j].state_dict(), adv_model_path.format(experiment_type, j))
        else:
            if best_adv_epoch + 5 <= k:
                break
    for j in range(args.n_discriminator):
        discriminators[j].load_state_dict(torch.load(adv_model_path.format(experiment_type, j)))

model.load_state_dict(torch.load(main_model_path))

 70%|███████   | 42/60 [1:41:29<43:29, 144.99s/it]


<All keys matched successfully>

In [8]:
test_loss, preds, labels, p_labels = eval_main(model, test_generator, criterion, device, args)
preds = np.array(preds)
labels = np.array(labels)
p_labels = np.array(p_labels)

In [9]:
eval_metrices = group_evaluation(preds, labels, p_labels, silence=False)

Accuracy 0: 0.8005
Accuracy 1: 0.6993496748374187
TPR 0: 0.803
TPR 1: 0.7943971985992997
TNR 0: 0.798
TNR 1: 0.6043021510755378
TPR gap: 0.008602801400700355
TNR gap: 0.1936978489244623


In [10]:
(eval_metrices["Accuracy_0"]+eval_metrices["Accuracy_1"])/2

0.7499248374187093