In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim 
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split
import gurobipy as gp
from gurobipy import GRB
from tqdm import tqdm 
from datasets import LFW_Dataset 
from helper import *  
from models import BasicBlock, ResNet

In [None]:
# Task definition: the label to predict and the attribute treated as sensitive.
# nominal=True trains with epoch_nominal; nominal=False uses the flipped
# (z-based) epoch_flipped_unstructured path in the training cell.
target_attribute, sens_attribute = "Smiling", "Male"
nominal = False

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# set parameters
num_epochs = 20   # number of passes over train_loader
lr_theta = 1e-3   # Adam learning rate for the model weights
lr_z = 1e-1       # step size for z, passed to epoch_flipped_unstructured
epsilon = 0.01    # passed to proj_z_unstructured / epoch_flipped_unstructured (semantics defined in helper)
seed = 0          # RNG seed: makes weight init and DataLoader shuffling reproducible

# Seed before the model and DataLoader are created in later cells.
# torch.manual_seed also seeds every CUDA device.
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Build the LFW dataset (train=True) labelled by `target_attribute`, with
# `sens_attribute` as the sensitive attribute, and wrap it in a shuffling loader.
dataset = LFW_Dataset(
    targ=target_attribute,
    sens_attr=sens_attribute,
    train=True,
    vgg=False,
    sub=False,
)
# drop_last=True discards the final incomplete batch.
train_loader = DataLoader(dataset, shuffle=True, drop_last=True, batch_size=64)
# Per-sample sensitive-attribute values and target labels as returned by
# get_values(); converted to tensors in a later cell.
train_sens_1, train_y_1 = dataset.get_values()

In [None]:
# 18-layer ResNet from the project `models` module with num_classes=1
# (a single output unit), moved to `device` and optimised with Adam at lr_theta.
model = ResNet(img_channels=3, num_layers=18, block=BasicBlock, num_classes=1).to(device)
opt = optim.Adam(params=model.parameters(), lr=lr_theta)

In [None]:
z_list = []  # NOTE(review): unused in the visible cells; kept in case later cells rely on it

# Per-sample sensitive attribute and target label (from dataset.get_values())
# as float32 tensors; these stay on the CPU, as before.
train_sens_1 = torch.tensor(list(train_sens_1), dtype=torch.float32)
train_y_1 = torch.tensor(list(train_y_1), dtype=torch.float32)

# One z entry per training sample, initialised to zero. Use a float tensor:
# z is later stepped with lr_z inside epoch_flipped_unstructured, so an
# integer dtype (what a list of Python 0s produces) could truncate updates.
z0_train = torch.zeros(len(train_y_1), device=device)

flag = get_flag(train_sens_1, train_y_1)
# Project the initial z with the helper's constraint (parameterised by epsilon).
z0_train = proj_z_unstructured(z0_train, train_sens_1, train_y_1, flag, epsilon)

In [None]:
# Train for num_epochs epochs. In the flipped setting the per-sample variable z
# returned by one epoch is fed into the next.
for epoch in range(1, num_epochs + 1):
    if nominal:
        train_err, train_loss = epoch_nominal(train_loader, model, opt=opt)
    else:
        train_err, train_loss, z = epoch_flipped_unstructured(
            train_loader, model, z0_train, train_sens_1, train_y_1,
            epsilon=epsilon, lr_z=lr_z, opt=opt,
        )
        z0_train = z  # carry the updated z into the next epoch
    # Report progress in both modes (previously only the flipped branch printed).
    print(epoch, *("{:.6f}".format(v) for v in (train_err, train_loss)), sep="\t")

# Checkpoint: epoch count, model weights and optimizer state.
state = {
    'epoch': num_epochs,
    'state_dict': model.state_dict(),
    'optimizer': opt.state_dict(),
}

nom_str = 'nominal' if nominal else 'flipped'
# Any sensitive attribute other than 'Male' is labelled 'race' in the filename.
sens_str = 'gender' if sens_attribute == 'Male' else 'race'

savepath = f'LFW_{target_attribute}_checkpoint_{nom_str}_{sens_str}.pth'
torch.save(state, savepath)