In [None]:
from google.colab import drive
import os
# Mount Google Drive inside the Colab runtime; force_remount=True re-mounts
# even if a mount already exists.
drive.mount('/content/gdrive', force_remount=True) # My Drive
# Change to your local path if needed
 
# Every relative path used later in the notebook ('UTKFace', 'checkpoints/...')
# is resolved against this directory because of the chdir below.
path = '/content/gdrive/My Drive'
os.chdir(path)

Mounted at /content/gdrive


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import random

from torch.utils.data import Dataset, DataLoader
import matplotlib.image as img
from sklearn.model_selection import train_test_split

import torchvision.models as models



In [None]:
def threshold_contrastive_trans(feats,label_feats,thresh=2,tau=0.25,eps=1e-7):
    """Thresholded pairwise contrastive loss.

    Pairs of samples are split into "same" and "different" groups by the mean
    absolute difference of their ``label_feats`` rows (``< thresh`` means same,
    ``>= thresh`` means different). For every pair the L1 distance between the
    corresponding ``feats`` rows is computed, and the loss is::

        -log( (sum of L1 over same pairs + eps) / (sum of L1 over different pairs + eps) )

    Args:
        feats: tensor whose first dimension is the batch; remaining dimensions
            are flattened per sample.
        label_feats: tensor with the same batch size, flattened the same way.
        thresh: cut-off applied to the mean absolute label-feature distance.
        tau: unused; kept so existing call sites keep working.
        eps: small constant added to numerator and denominator.

    Returns:
        Zero-dimensional tensor holding the loss.

    NOTE(review): minimising this value increases the "same" distance sum
    relative to the "different" one - confirm the ratio is oriented as intended.
    """
    n_labels = label_feats.shape[0]
    label_rows = torch.reshape(label_feats,(n_labels,1,-1))
    label_cols = torch.reshape(label_feats,(1,n_labels,-1))
    label_dist = torch.mean(torch.abs(label_rows-label_cols),dim=-1)

    n_feats = feats.shape[0]
    feat_rows = torch.reshape(feats,(n_feats,1,-1))
    feat_cols = torch.reshape(feats,(1,n_feats,-1))
    pair_dist = torch.sum(torch.abs(feat_rows-feat_cols),dim=-1)

    mask_same = (label_dist<thresh).to(pair_dist.dtype)
    mask_diff = (label_dist>=thresh).to(pair_dist.dtype)

    num = torch.sum(mask_same*pair_dist)
    den = torch.sum(mask_diff*pair_dist)

    # The diagonal always lands in mask_same but contributes zero distance, so
    # a batch with no off-diagonal "same" pair has num == 0. Without eps in the
    # numerator that gives log(0) -> an infinite loss.
    loss = -1*torch.log((num+eps)/(den+eps))
    return loss

In [None]:
# Seed every random number generator the notebook relies on.
sd = 0
for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    seed_fn(sd)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(sd)

In [None]:
# Lookup tables turning the string class labels into integer codes.
dict_age_to_number = {
    age_bucket: code
    for code, age_bucket in enumerate([
        '0-2',
        '3-9',
        '10-19',
        '20-29',
        '30-39',
        '40-49',
        '50-59',
        '60-69',
        'more than 70',
    ])
}

dict_gender_to_number = dict(Male=0, Female=1)

# Only these two race labels have a code; any other label raises KeyError on lookup.
dict_race_to_number = dict(Black=0, White=1)

In [None]:
data_path = './UTKFace/'

# os.listdir returns entries in arbitrary order; sorting makes the row order,
# and therefore the seeded splits and head-slices built from it, reproducible.
data_labels = sorted(os.listdir(data_path))

clean_labels = []

age = []
age2 = []
gender = []
race = []
for f in data_labels:
    # File names are split on '_' and read as age, gender, race, <rest>.
    temp = f.split('_')
    # Skip anything that does not follow that pattern: too few fields (stray
    # files), a non-numeric age, or a race field that is not a single character.
    if len(temp) < 3 or not temp[0].isdigit() or len(temp[2]) != 1:
        continue
    age.append(temp[0])
    age2.append(int(temp[0]))
    gender.append(temp[1])
    race.append(temp[2])
    clean_labels.append(f)



In [None]:
# Translate the raw fields parsed from the file names into class names.
# A value that matches no class is printed after an 'Error?' tag and skipped.
_AGE_RANGES = (
    (0, 2, '0-2'),
    (3, 9, '3-9'),
    (10, 19, '10-19'),
    (20, 29, '20-29'),
    (30, 39, '30-39'),
    (40, 49, '40-49'),
    (50, 59, '50-59'),
    (60, 69, '60-69'),
)
_GENDER_NAMES = {'0': 'Male', '1': 'Female'}
_RACE_NAMES = {'0': 'White', '1': 'Black', '2': 'Asian', '3': 'Indian', '4': 'Other'}

age_class = []
for a in age2:
    age_bucket = next((name for lo, hi, name in _AGE_RANGES if lo <= a <= hi), None)
    if age_bucket is None and a >= 70:
        age_bucket = 'more than 70'
    if age_bucket is not None:
        age_class.append(age_bucket)
    else:
        print('ErrorA')
        print(a)

gender_class = []
for g in gender:
    gender_name = _GENDER_NAMES.get(g)
    if gender_name is not None:
        gender_class.append(gender_name)
    else:
        print('ErrorG')
        print(g)

race_class = []
for g in race:
    race_name = _RACE_NAMES.get(g)
    if race_name is not None:
        race_class.append(race_name)
    else:
        print('ErrorR')
        print(g)



In [None]:
# One row per kept image: file name plus its age / gender / race class names.
df = pd.DataFrame(
    data={
        'file': clean_labels,
        'age': age_class,
        'gender': gender_class,
        'race': race_class,
    }
)

In [None]:
# categories
# NOTE(review): '0-2' is a key of dict_age_to_number but is not listed here -
# confirm whether the omission is intentional.
age_list = ['3-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', 'more than 70']

In [None]:
# Alias of the full labelled frame (same object, not a copy); this is what the
# train/val/test split cell consumes.
train_labels = df

In [None]:
#Dataloader
class UTKFaceDataset(Dataset):
    """Dataset yielding ``(image, gender_label, race_label)`` triples.

    ``data`` is a DataFrame whose ``.values`` rows are read positionally:
    column 0 is the file name, column 2 the gender name and column 3 the race
    name. ``path`` is the directory holding the images and ``transform`` is an
    optional callable applied to each loaded image.
    """

    def __init__(self, data, path , transform = None):
        super().__init__()
        self.data = data.values
        self.path = path
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self,index):
        row = self.data[index]

        # Target label: gender name -> integer code -> tensor.
        label = torch.tensor(dict_gender_to_number[row[2]])

        image = img.imread(os.path.join(self.path, row[0]))

        # Group label: race name -> integer code -> tensor.
        gp_label = torch.tensor(dict_race_to_number[row[3]])

        if self.transform is None:
            return image, label, gp_label
        return self.transform(image), label, gp_label

In [None]:
# Each split gets its own pipeline; all three only convert the image to a tensor.
train_transform, test_transform, valid_transform = (
    transforms.Compose([transforms.ToTensor()]) for _ in range(3)
)

In [None]:
#functions for data resampling
def resample_dataset_race(data,frac):
    """Build a shuffled subset with a chosen Black/White mix.

    Each race subset is first gender-balanced with ``equalize_dataset_gender``
    and sorted by gender. Rows are then trimmed symmetrically from both ends of
    each sorted subset, aiming at ``int(frac*baseline)`` Black rows and
    ``int((1-frac)*baseline)`` White rows, where ``baseline`` is the smaller
    balanced subset size minus 2. The baseline and the two removal counts are
    printed.

    Returns the concatenation of both trimmed subsets in random order.
    """
    dark = equalize_dataset_gender(data[data['race']=='Black'],0.5)
    light = equalize_dataset_gender(data[data['race']=='White'],0.5)

    baseline = min(len(dark),len(light))-2

    dark = dark.sort_values('gender')
    light = light.sort_values('gender')

    # Number of rows to drop from each subset; half comes off each end.
    remD = len(dark)-int(frac*baseline)
    remL = len(light)-int((1-frac)*baseline)
    cut_dark = int(0.5*remD)
    cut_light = int(0.5*remL)
    tempD = dark[cut_dark:-cut_dark]
    tempL = light[cut_light:-cut_light]

    print(baseline,remD,remL)

    return pd.concat([tempD,tempL]).sample(frac=1)

def resample_dataset_race_old(data,frac):
    """Earlier resampler: take leading rows of each race subset, then shuffle.

    Keeps the first ``int(frac*baseline)`` Black rows and the first
    ``int((1-frac)*baseline)`` White rows, where ``baseline`` is the size of the
    smaller race subset. No gender balancing is applied.
    """
    dark_rows = data[data['race']=='Black']
    light_rows = data[data['race']=='White']

    baseline = min(len(dark_rows),len(light_rows))
    n_dark = int(frac*baseline)
    n_light = int((1-frac)*baseline)

    combined = pd.concat([dark_rows[:n_dark],light_rows[:n_light]])
    return combined.sample(frac=1)

def equalize_dataset_gender(data,frac):
    """Return a shuffled frame with equally many 'Male' and 'Female' rows.

    The leading rows of each gender are kept, up to the size of the smaller
    gender group. ``frac`` is accepted but not used.
    """
    males = data[data['gender']=='Male']
    females = data[data['gender']=='Female']
    keep = int(min(len(males),len(females)))

    balanced = pd.concat([males[:keep],females[:keep]])
    return balanced.sample(frac=1)


#functions for data
def resample_dataset_equal(data):
    """Return separate gender-balanced Black and White subsets.

    Each race subset is gender-balanced with ``equalize_dataset_gender`` and
    sorted by gender; rows are then trimmed symmetrically from both ends,
    aiming at ``int(0.5*baseline)`` rows each, where ``baseline`` is the size of
    the smaller balanced subset.

    Returns:
        ``(tempD, tempL)`` - the trimmed Black and White frames, not shuffled.
    """
    dark = equalize_dataset_gender(data[data['race']=='Black'],0.5)
    light = equalize_dataset_gender(data[data['race']=='White'],0.5)

    baseline = min(len(dark),len(light))

    dark = dark.sort_values('gender')
    light = light.sort_values('gender')

    # Rows to drop per subset; half is removed from each end of the sorted frame.
    cut_dark = int(0.5*(len(dark)-int(0.5*baseline)))
    cut_light = int(0.5*(len(light)-int(0.5*baseline)))

    tempD = dark[cut_dark:-cut_dark]
    tempL = light[cut_light:-cut_light]

    return tempD,tempL

In [None]:
#Split datasets
# Batch size shared by every DataLoader in the notebook (the training functions
# read this global as well).
batch_size = 30

# def resample_dataset_gender(data,frac):
#     flagsM = data['gender']=='Male'
#     flagsF = data['gender']=='Female'
#     data_M = data[flagsM]
#     data_F = data[flagsF]
#     baseline = min(len(data_M),len(data_F))
    
#     tempM = data_M[0:int(frac*baseline)]
#     tempF = data_F[0:int((1-frac)*baseline)]
#     frames = [tempM,tempF]
#     final_split = pd.concat(frames)
#     final_split = final_split.sample(frac=1)
    
#     return final_split


# 80% train / 10% validation / 10% test, each split stratified by gender.
# No random_state is passed to either call.
train_split_t, val_split_t_dash = train_test_split(train_labels, stratify=train_labels.gender, test_size=0.2)
val_split_t, test_labels = train_test_split(val_split_t_dash, stratify=val_split_t_dash.gender, test_size=0.5)

#Ensure val set has equal representation
flagsD = val_split_t['race']=='Black'
flagsL = val_split_t['race']=='White'
data_L = val_split_t[flagsL]
data_L = equalize_dataset_gender(data_L,0.5)

# NOTE(review): data_D is gender-balanced here but never used afterwards.
data_D = val_split_t[flagsD]
data_D = equalize_dataset_gender(data_D,0.5)

# NOTE(review): the validation split is built from data_L only, i.e. it holds
# White rows exclusively - confirm this is intended given the comment above.
val_split = data_L.sample(frac=1) 

# Two separate, gender-balanced test sets: one per race.
test_split_D,test_split_L = resample_dataset_equal(test_labels)

# Sanity-check the class balance of each evaluation split.
print(val_split['gender'].value_counts())
print(val_split['race'].value_counts())

print(test_split_D['gender'].value_counts())
print(test_split_D['race'].value_counts())

print(test_split_L['gender'].value_counts())
print(test_split_L['race'].value_counts())

#dataloaders
# Evaluation loaders are not shuffled.
valid_data = UTKFaceDataset(val_split, data_path, valid_transform )
test_data_D = UTKFaceDataset(test_split_D, data_path, test_transform )
test_data_L = UTKFaceDataset(test_split_L, data_path, test_transform )
valid_loader = DataLoader(dataset = valid_data, batch_size = batch_size, shuffle=False, num_workers=0)
test_loader_D = DataLoader(dataset = test_data_D, batch_size = batch_size, shuffle=False, num_workers=0)
test_loader_L = DataLoader(dataset = test_data_L, batch_size = batch_size, shuffle=False, num_workers=0)

Female    446
Male      446
Name: gender, dtype: int64
White    892
Name: race, dtype: int64
Female    112
Male      112
Name: gender, dtype: int64
Black    224
Name: race, dtype: int64
Female    112
Male      112
Name: gender, dtype: int64
White    224
Name: race, dtype: int64


In [None]:
from re import X
#Model
from torch.autograd import Variable

def swish(x):
    """Apply ``F.relu`` element-wise to ``x``.

    NOTE(review): despite the name, this computes ReLU rather than
    ``x * sigmoid(x)`` - confirm whether that is intentional.
    """
    activated = F.relu(x)
    return activated

class Network1(nn.Module):
    """Two-headed ResNet-34 classifier with per-sample weight and bias offsets.

    ``head1`` (ResNet-34 without its final ``fc`` layer) produces a feature
    vector. ``head2`` (a full ResNet-34 whose ``fc`` outputs ``(512+1)*D_out``
    values) predicts a per-sample offset for a shared ``512 x D_out`` weight
    ``w`` and a shared ``D_out`` bias ``b``. The logits are
    ``features @ (w + weight_offset) + (b + bias_offset)``.

    Args:
        D_out: number of output classes.
        dtype: tensor type used to initialise ``w`` and ``b``.
        device: device on which ``w`` and ``b`` are created.
    """

    def __init__(self,D_out=2,dtype = torch.FloatTensor,device = 'cpu'):
        super().__init__()

        self.D_out = D_out

        model1 = models.resnet34(pretrained=False)
        # Drop the final fully-connected layer so head1 returns pooled features.
        newmodel1 = torch.nn.Sequential(*(list(model1.children())[:-1]))

        model2 = models.resnet34(pretrained=False)
        # 512*D_out weight offsets followed by D_out bias offsets.
        model2.fc = nn.Linear(512, (512+1)*D_out)

        self.head1 = newmodel1
        self.head2 = model2

        # Registered as nn.Parameter so that model.parameters() (and therefore
        # the optimizer), model.to(...) and state_dict() all include them. A
        # plain tensor attribute is invisible to all three, which left w and b
        # at their random initial values.
        # NOTE: state_dict() now contains 'w' and 'b'; checkpoints saved before
        # this change need load_state_dict(..., strict=False).
        self.w = nn.Parameter(torch.randn(1, 512, D_out).type(dtype).to(device))
        self.b = nn.Parameter(torch.randn(1, D_out).type(dtype).to(device))

    def forward(self,x):
        """Return ``(logits, features, weight_offsets)`` for an image batch ``x``."""

        x1 = self.head1(x) ##Feature head
        x2 = self.head2(x) ##Offset head

        # Features become (batch, 1, channels) so they can left-multiply the weights.
        x1 = torch.reshape(x1,(x1.shape[0],1,x1.shape[1]))

        # Trailing D_out outputs of the offset head are the bias offsets.
        bias_offset = x2[:,512*self.D_out:]

        # Leading 512*D_out outputs are the weight offsets, one matrix per sample.
        x2 = torch.reshape(x2[:,0:512*self.D_out],(x2.shape[0],x1.shape[2],-1))

        # Broadcasting adds the shared w / b to every sample's offsets.
        update_bias = bias_offset+self.b
        update_wt = x2+self.w

        x = torch.matmul(x1,update_wt)
        x = torch.reshape(x,(x.shape[0],-1))+update_bias

        return x, x1, x2


class Network2(nn.Module):
    """Two-headed ResNet-34 classifier with a per-sample feature offset.

    ``head1`` (ResNet-34 without its final ``fc`` layer) produces a feature
    vector and ``head2`` (a full ResNet-34 whose ``fc`` outputs 512 values)
    produces an additive offset to it. The logits are
    ``(features + offset) @ w + b`` with a shared ``512 x D_out`` weight ``w``
    and ``D_out`` bias ``b``.

    Args:
        D_out: number of output classes.
        dtype: tensor type used to initialise ``w`` and ``b``.
        device: device on which ``w`` and ``b`` are created.
    """

    def __init__(self,D_out=2,dtype = torch.FloatTensor,device = 'cpu'):
        super().__init__()

        model1 = models.resnet34(pretrained=False)
        # Drop the final fully-connected layer so head1 returns pooled features.
        newmodel1 = torch.nn.Sequential(*(list(model1.children())[:-1]))

        model2 = models.resnet34(pretrained=False)
        model2.fc = nn.Linear(512, 512)

        self.head1 = newmodel1
        self.head2 = model2

        # Registered as nn.Parameter so that model.parameters() (and therefore
        # the optimizer), model.to(...) and state_dict() all include them. A
        # plain tensor attribute is invisible to all three.
        # NOTE: state_dict() now contains 'w' and 'b'; checkpoints saved before
        # this change need load_state_dict(..., strict=False).
        self.w = nn.Parameter(torch.randn(1, 512, D_out).type(dtype).to(device))
        self.b = nn.Parameter(torch.randn(1, D_out).type(dtype).to(device))

    def forward(self,x):
        """Return ``(logits, features, offsets)`` for an image batch ``x``."""

        x1 = self.head1(x) ##Feature head
        x2 = self.head2(x) ##Offset head

        # Both become (batch, 1, channels) so they can be summed and left-multiply w.
        x1 = torch.reshape(x1,(x1.shape[0],1,x1.shape[1]))
        x2 = torch.reshape(x2,(x2.shape[0],1,x2.shape[1]))

        # Broadcasting applies the shared weight to every sample.
        x = torch.matmul(x1+x2,self.w)

        x = torch.reshape(x,(x.shape[0],-1))+self.b

        return x, x1, x2


####Feature transformation networks
class NetworkFC(nn.Module):

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 1)

        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)

    def forward(self,x):

        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.bn3(F.relu(self.fc3(x)))
        x = F.sigmoid(self.fc4(x))

        return x


class NetworkTransform(nn.Module):
    """Two stacked 2048 -> 2048 MLPs.

    The first stack (``fc1``-``fc3``) yields an intermediate representation;
    the second (``fc4``-``fc6``) maps that representation to a second
    2048-dimensional output. ``forward`` returns both.
    """

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(2048, 2048)
        self.fc2 = nn.Linear(2048, 2048)
        self.fc3 = nn.Linear(2048, 2048)

        self.fc4 = nn.Linear(2048, 2048)
        self.fc5 = nn.Linear(2048, 2048)
        self.fc6 = nn.Linear(2048, 2048)

        self.bn1 = nn.BatchNorm1d(2048)
        self.bn2 = nn.BatchNorm1d(2048)

        self.bn4 = nn.BatchNorm1d(2048)
        self.bn5 = nn.BatchNorm1d(2048)

    def forward(self,x):
        """Return ``(intermediate, output)`` for a ``(batch, 2048)`` input."""

        # First stack: two Linear -> ReLU -> BatchNorm stages, then a plain Linear.
        hidden = x
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = norm(F.relu(linear(hidden)))
        x_temp = self.fc3(hidden)

        # Second stack: same layout, fed with the intermediate representation.
        hidden = x_temp
        for linear, norm in ((self.fc4, self.bn4), (self.fc5, self.bn5)):
            hidden = norm(F.relu(linear(hidden)))

        return x_temp, self.fc6(hidden)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

In [None]:
def test_performance_trans(model,group_disc,model_trans,dataL,criterion):
    """Evaluate the transform + classifier pipeline on one data loader.

    For every batch the images go through ``group_disc`` (feature extractor),
    ``model_trans`` (which returns an intermediate representation and a second
    output) and ``model`` (classifier applied to the intermediate
    representation).

    Args:
        model: classifier producing one probability per sample.
        group_disc: feature extractor applied to the raw images.
        model_trans: feature transform returning ``(intermediate, output)``.
        dataL: DataLoader yielding ``(image, target, group)`` batches.
        criterion: loss applied to the classifier output and the float targets.

    Returns:
        ``(test_print, ttacc)`` - the formatted summary line (also printed) and
        the accuracy against ``target``. The summary additionally reports the
        fraction of predictions equal to 0 ("2") and the mean squared
        difference between the second output of ``model_trans`` and the
        extractor features ("3").
    """
    model.eval()
    model.to(device)

    # Evaluate the BatchNorm-carrying helper networks in eval mode so metrics
    # use running statistics and the test batches do not update them. Their
    # previous modes are restored afterwards so a surrounding training loop is
    # unaffected.
    trans_was_training = model_trans.training
    disc_was_training = group_disc.training
    model_trans.eval()
    group_disc.eval()

    test_loss = 0
    test_acc = 0
    test_acc2 = 0
    test_acc3 = 0

    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for data, target, _ in dataL:

                data = data.to(device)
                target = target.to(device)

                # All-zero reference labels, used for the "2" metric.
                target_dash = torch.zeros((data.shape[0])).to(device)

                output_dash = group_disc(data)                    # forward pass
                # Keep only the first two dimensions of the extractor output.
                output_dash = torch.reshape(output_dash,(output_dash.shape[0],output_dash.shape[1])).detach()

                output_detangle, output_est = model_trans(output_dash)
                output = model(output_detangle)

                loss = criterion(output.squeeze(-1).float(), target.float())
                # Weight by batch size so the final average is per sample.
                test_loss += loss.item() * data.size(0)

                op_temp = output.squeeze(-1).detach().cpu().numpy()
                op_temp = (op_temp>0.5).astype(np.uint8)

                test_acc += np.mean(op_temp==target.detach().cpu().numpy())*data.size(0)
                test_acc2 += np.mean(op_temp==target_dash.detach().cpu().numpy())*data.size(0)
                test_acc3 += np.mean((output_est.detach().cpu().numpy()-output_dash.detach().cpu().numpy())**2)*data.size(0)
    finally:
        model_trans.train(trans_was_training)
        group_disc.train(disc_was_training)

    n_samples = len(dataL.sampler)
    ttacc  = test_acc/n_samples
    ttacc2  = test_acc2/n_samples
    ttacc3  = test_acc3/n_samples
    test_loss_M = test_loss/n_samples

    test_print = 'Test Loss: {:.3f} \tTest Acc1: {:.3f} \t2: {:.3f} \t3: {:.3f}'.format(
        test_loss_M, ttacc, ttacc2, ttacc3)

    print(test_print)
    return test_print,ttacc

In [None]:
def write_file(fname,string,act):
    """Write ``string`` plus a trailing newline to ``fname`` opened in mode ``act``."""
    with open(fname, act) as handle:
        handle.write(string+'\n')

###Combined iteration
def train_model_trans(dark_frac,train_split_t,valid_loader,test_loader_D,test_loader_L):
    """Adversarially train a feature transform against a gender classifier.

    A frozen, ImageNet-pretrained ResNet-50 (without its final layer) extracts
    features. ``NetworkTransform`` maps them to an intermediate representation
    and a second output; ``NetworkFC`` classifies gender from the intermediate
    representation. Training alternates in blocks of epochs between:

    * transform updates (``flg == 0``): MSE between the extractor features and
      the transform's second output, plus ``lam1`` times the BCE of the
      classifier output against an all-zero target;
    * classifier updates (``flg == 1``): ``lam2`` times the BCE of the
      classifier output against the true gender label.

    Progress is printed and appended to ``checkpoints/<run>/log.txt``.

    Args:
        dark_frac: fraction passed to ``resample_dataset_race`` for the
            training set; also part of the run directory name.
        train_split_t: training DataFrame before resampling.
        valid_loader: validation DataLoader.
        test_loader_D: test DataLoader for the 'Black' subset.
        test_loader_L: test DataLoader for the 'White' subset.

    Returns:
        ``(model_trans, group_disc)`` - the trained transform and the frozen
        feature extractor.
    """
    num_epochs = 60
    learning_rate = 0.0005

    check_point_dir = 'October_perceptual_features_best_contrastive_bias_add_'+str(dark_frac)
    out_dir = "checkpoints/"+check_point_dir

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
        print("Output directory is created")

    #make logger text file
    text_path = out_dir+"/"+"log.txt"
    try:
        os.remove(text_path)
    except OSError:
        # Best effort: there is simply no previous log to remove on a first run.
        pass

    write_file(text_path,'********* Dark fraction: {} *********'.format(dark_frac),'a')

    train_split = resample_dataset_race(train_split_t,dark_frac)

    write_file(text_path,str(train_split['race'].value_counts()),'a')
    write_file(text_path,str(train_split['gender'].value_counts()),'a')

    #Dataloaders
    train_data = UTKFaceDataset(train_split, data_path, train_transform )
    train_loader = DataLoader(dataset = train_data, batch_size = batch_size, shuffle=True, num_workers=0)

    # Frozen feature extractor: pretrained ResNet-50 without its final layer.
    group_disc = models.resnet50(pretrained=True)
    group_disc = torch.nn.Sequential(*(list(group_disc.children())[:-1]))
    for param in group_disc.parameters():
        param.requires_grad = False
    group_disc.to(device)

    model = NetworkFC()
    model.to(device)

    criterion = nn.BCELoss()
    criterion_MSE = nn.MSELoss()

    model_trans = NetworkTransform()
    model_trans.to(device)

    # Weights of the adversarial (lam1) and classifier (lam2) BCE terms.
    lam1 = 3
    lam2 = 3

    optimizer_trans = torch.optim.AdamW(
        model_trans.parameters(),
        lr=learning_rate,
        betas=(0.5, 0.999),
        weight_decay=0.08
        )

    optimizer_disc = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.5, 0.999),
        weight_decay=0.08
        )

    scheduler_trans = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_trans, T_max=30,
        eta_min=0.01 * learning_rate, verbose=True
        )

    scheduler_disc = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_disc, T_max=30,
        eta_min=0.01 * learning_rate, verbose=True
        )

    # Per-epoch history.
    valid_accuracy = []
    test_accuracy_D = []
    test_accuracy_L = []

    print("Training model...")

    # flg selects which network is updated this epoch (0: transform,
    # 1: classifier); it flips every time flg_ctr reaches 4.
    gan_thresh = 0
    flg_ctr = 0
    flg = 0

    best_val_acc = 0

    for epoch in range(1, num_epochs+1):
        # keep track of train/val loss
        train_loss = 0.0
        valid_loss = 0.0

        ##Set flags
        if epoch>gan_thresh:
            flg_ctr +=1
            if flg_ctr%4==0:
                flg = (flg+1)%2
                flg_ctr = 0

        # training the model
        model.train()
        # Explicitly put the transform back in training mode: the validation
        # phase below switches it to eval mode.
        model_trans.train()
        group_disc.eval()
        temp_train_acc = 0.0
        temp_train_acc2 = 0.0
        temp_train_acc3 = 0.0
        for data, target, _ in train_loader:

            data = data.to(device)
            target = target.to(device)

            # All-zero labels: the adversarial target for the transform update
            # and the reference for the "2" metric.
            target_dash = torch.zeros((data.shape[0])).to(device)

            optimizer_trans.zero_grad()
            optimizer_disc.zero_grad()

            output_dash = group_disc(data)                    # forward pass
            output_dash = torch.reshape(output_dash,(output_dash.shape[0],output_dash.shape[1])).detach()

            output_detangle, output_est = model_trans(output_dash)
            output = model(output_detangle)

            if flg==0: #this means run the generator updates
                loss = criterion_MSE(output_dash,output_est)+lam1*criterion(output.squeeze(-1).float(), target_dash.float())

                loss.backward()                         # loss backwards
                optimizer_trans.step()                        # update model params
            else: #flg==1: this means run the discriminator updates
                loss = lam2*criterion(output.squeeze(-1).float(), target.float())

                loss.backward()                         # loss backwards
                optimizer_disc.step()                        # update model params

            train_loss += loss.item() * data.size(0)

            op_temp = output.squeeze(-1).detach().cpu().numpy()
            op_temp = (op_temp>0.5).astype(np.uint8)

            temp_train_acc += np.mean(op_temp==target.detach().cpu().numpy())*data.size(0)
            temp_train_acc2 += np.mean(op_temp==target_dash.detach().cpu().numpy())*data.size(0)
            temp_train_acc3 += np.mean((output_est.detach().cpu().numpy()-output_dash.detach().cpu().numpy())**2)*data.size(0)

        # validate-the-model
        model.eval()
        # The transform contains BatchNorm layers: evaluate it in eval mode so
        # validation uses running statistics and does not update them.
        model_trans.eval()
        group_disc.eval()
        temp_val_acc = 0.0
        temp_val_acc2 = 0.0
        temp_val_acc3 = 0.0
        # No gradients are needed for validation.
        with torch.no_grad():
            for data, target, _ in valid_loader:

                data = data.to(device)
                target = target.to(device)

                target_dash = torch.zeros((data.shape[0])).to(device)

                ##Validation of the main model
                output_dash = group_disc(data)                    # forward pass
                output_dash = torch.reshape(output_dash,(output_dash.shape[0],output_dash.shape[1])).detach()

                output_detangle, output_est = model_trans(output_dash)
                output = model(output_detangle)

                loss = criterion(output.squeeze(-1).float(), target.float())

                # update-average-validation-loss
                valid_loss += loss.item() * data.size(0)

                op_temp = output.squeeze(-1).detach().cpu().numpy()
                op_temp = (op_temp>0.5).astype(np.uint8)

                temp_val_acc += np.mean(op_temp==target.detach().cpu().numpy())*data.size(0)
                temp_val_acc2 += np.mean(op_temp==target_dash.detach().cpu().numpy())*data.size(0)
                temp_val_acc3 += np.mean((output_est.detach().cpu().numpy()-output_dash.detach().cpu().numpy())**2)*data.size(0)

        # Best-so-far tracking compares the un-normalised correct count.
        if temp_val_acc>best_val_acc:
            best_val_acc = temp_val_acc
            # NOTE: the checkpoint write is disabled; only the message is emitted.
            # torch.save(model.state_dict(), f"checkpoints/"+check_point_dir+"/model_best.pt")
            print('Model saved')
            write_file(text_path,'Model saved','a')

        # calculate-average-losses
        n_train = len(train_loader.sampler)
        n_valid = len(valid_loader.sampler)

        train_loss = train_loss/n_train
        valid_loss = valid_loss/n_valid

        ttacc  = temp_train_acc/n_train
        ttacc2  = temp_train_acc2/n_train
        ttacc3  = temp_train_acc3/n_train

        tvacc  = temp_val_acc/n_valid
        tvacc2  = temp_val_acc2/n_valid
        tvacc3  = temp_val_acc3/n_valid

        # Only the scheduler of the network updated this epoch advances.
        if flg==0: #this means run the generator updates
            scheduler_trans.step()

        if flg==1: #this means run the discriminator updates
            scheduler_disc.step()

        # print-training/validation-statistics
        train_print = 'Epoch: {} \tTr Loss: {:.3f} \tTr Acc1: {:.3f}, \t2: {:.3f}, \t3: {:.3f} \tVal Loss: {:.3f} \tVal Acc1: {:.3f} \t2: {:.3f} \t3: {:.3f}'.format(
            epoch, train_loss, ttacc, ttacc2, ttacc3, valid_loss, tvacc, tvacc2, tvacc3)
        print(train_print)

        test_print_D, ttacc_D = test_performance_trans(model,group_disc,model_trans,test_loader_D,criterion)
        test_print_L, ttacc_L = test_performance_trans(model,group_disc,model_trans,test_loader_L,criterion)

        valid_accuracy.append(tvacc)
        test_accuracy_D.append(ttacc_D)
        test_accuracy_L.append(ttacc_L)

        write_file(text_path,train_print,'a')
        write_file(text_path,test_print_D,'a')
        write_file(text_path,test_print_L,'a')

    path_val = out_dir+"/"+"validation_accuracy"
    path_D = out_dir+"/"+"test_accuracy_D"
    path_L = out_dir+"/"+"test_accuracy_L"
    valid_accuracy = np.array(valid_accuracy)
    test_accuracy_D = np.array(test_accuracy_D)
    test_accuracy_L = np.array(test_accuracy_L)
    # Saving of the accuracy histories is disabled.
    # np.save(path_val, valid_accuracy)
    # np.save(path_D, test_accuracy_D)
    # np.save(path_L, test_accuracy_L)

    # Hand the transform back in training mode, the state callers received
    # before eval-mode validation was introduced; call .eval() before using it
    # for inference.
    model_trans.train()

    return model_trans, group_disc
    



In [None]:
# Fractions to sweep over; currently a single run at 0.5.
# NOTE(review): the value is printed as a "Male fraction" but is passed to
# train_model_trans as its dark_frac argument - confirm which label is intended.
male_fracs = [0.5]#np.linspace(0.0,1.0,11)
for male_frac in male_fracs:
    print('********* Male fraction: {} *********'.format(male_frac))
    # model_trans and group_disc from the last iteration stay in the notebook
    # namespace after the loop.
    model_trans, group_disc = train_model_trans(male_frac,train_split_t,valid_loader,test_loader_D,test_loader_L)

********* Male fraction: 0.5 *********
3512 1758 5616




Adjusting learning rate of group 0 to 5.0000e-04.
Adjusting learning rate of group 0 to 5.0000e-04.
Training model...


  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()


Model saved
Adjusting learning rate of group 0 to 4.9864e-04.
Epoch: 1 	Tr Loss: 2.325 	Tr Acc1: 0.500, 	2: 0.397, 	3: 0.155 	Val Loss: 0.693 	Val Acc1: 0.507 	2: 0.289 	3: 0.443
Test Loss: 0.784 	Test Acc1: 0.554 	2: 0.366 	3: 0.442
Test Loss: 0.775 	Test Acc1: 0.478 	2: 0.353 	3: 0.430
Model saved
Adjusting learning rate of group 0 to 4.9459e-04.
Epoch: 2 	Tr Loss: 2.263 	Tr Acc1: 0.482, 	2: 0.354, 	3: 0.104 	Val Loss: 0.655 	Val Acc1: 0.627 	2: 0.320 	3: 0.398
Test Loss: 0.971 	Test Acc1: 0.562 	2: 0.312 	3: 0.396
Test Loss: 0.922 	Test Acc1: 0.576 	2: 0.344 	3: 0.394
Adjusting learning rate of group 0 to 4.8789e-04.
Epoch: 3 	Tr Loss: 2.204 	Tr Acc1: 0.499, 	2: 0.347, 	3: 0.047 	Val Loss: 0.695 	Val Acc1: 0.492 	2: 0.526 	3: 0.516
Test Loss: 0.703 	Test Acc1: 0.424 	2: 0.567 	3: 0.513
Test Loss: 0.690 	Test Acc1: 0.554 	2: 0.571 	3: 0.510
Model saved
Adjusting learning rate of group 0 to 4.9864e-04.
Epoch: 4 	Tr Loss: 1.369 	Tr Acc1: 0.784, 	2: 0.492, 	3: 0.515 	Val Loss: 0.464 	Va

In [None]:
def test_performance(model,dataL,criterion):
    """Evaluate a multi-class model on one data loader.

    Args:
        model: network whose forward pass returns a 3-tuple with the class
            scores as first element.
        dataL: DataLoader yielding ``(image, target, group)`` batches.
        criterion: loss applied to the class scores and the integer targets.

    Returns:
        ``(test_print, ttacc)`` - the formatted summary line (also printed) and
        the argmax accuracy, both averaged per sample.
    """
    model.eval()
    model.to(device)

    test_loss = 0
    test_acc = 0

    # Gradients are not needed for evaluation; skipping them avoids building
    # an autograd graph for every batch.
    with torch.no_grad():
        for data, target, _ in dataL:

            data = data.to(device)
            target = target.to(device)

            output,_,_ = model(data)

            loss = criterion(output, target)
            # Weight by batch size so the final average is per sample.
            test_loss += loss.item() * data.size(0)

            op_temp = output.detach().cpu().numpy()
            op_temp = np.argmax(op_temp,axis=1)

            test_acc += np.mean(op_temp==target.detach().cpu().numpy())*data.size(0)

    n_samples = len(dataL.sampler)
    ttacc  = test_acc/n_samples
    test_loss_M = test_loss/n_samples

    test_print = 'Test Loss: {:.3f} \tTest Acc: {:.3f}'.format(
        test_loss_M, ttacc)

    print(test_print)
    return test_print,ttacc

In [None]:
def write_file(fname,string,act):
    """Write ``string`` plus a trailing newline to ``fname`` opened in mode ``act``."""
    with open(fname, act) as handle:
        handle.write(string+'\n')

###Combined iteration
def train_model(dark_frac,train_split_t,valid_loader,test_loader_D,test_loader_L,model_trans,group_disc,
                num_epochs=60,learning_rate=0.0005,lam=0.5,beta=0.8):
    """Train a fresh `Network1` classifier with a decaying contrastive regulariser.

    For each batch the loss is
    `CrossEntropy(output, target) + lam * beta**(epoch-1) * threshold_contrastive_trans(x2, feats)`
    where `x2` is the third output of the classifier (flattened per sample) and
    `feats` is the first output of `model_trans` applied to `group_disc(data)`.
    Both `group_disc` and `model_trans` are frozen and only used for inference.

    Parameters
    ----------
    dark_frac :
        Forwarded to `resample_dataset_race`; also used in the checkpoint
        directory name and the log header.
    train_split_t :
        Training table handed to `resample_dataset_race`. The result must have
        'race' and 'gender' columns (their value counts are logged).
    valid_loader, test_loader_D, test_loader_L :
        Loaders yielding `(data, target, grp)` batches and exposing `.sampler`.
    model_trans, group_disc : torch.nn.Module
        Frozen feature networks. NOTE: their parameters get
        `requires_grad = False` and they are moved to `device` in place.
    num_epochs, learning_rate, lam, beta :
        Training length, AdamW learning rate, regulariser weight and its
        per-epoch decay factor. Defaults reproduce the original constants.

    Side effects
    ------------
    Under `checkpoints/<run dir>/`: rewrites `log.txt`, saves `model_best.pt`
    whenever validation accuracy improves, and at the end saves
    `validation_accuracy`, `test_accuracy_D` and `test_accuracy_L` with
    `np.save`. Returns None.

    Relies on notebook-level names: `device`, `batch_size`, `data_path`,
    `train_transform`, `UTKFaceDataset`, `Network1`, `resample_dataset_race`,
    `threshold_contrastive_trans`, `test_performance` and `write_file`.
    """

    num_classes = 2

    check_point_dir = 'October_perceptual_features_best_contrastive_bias_add_'+str(dark_frac)
    out_dir = os.path.join("checkpoints", check_point_dir)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
        print("Output directory is created")

    # Start every run from an empty log; a missing file is not an error.
    text_path = os.path.join(out_dir, "log.txt")
    try:
        os.remove(text_path)
    except OSError:
        pass

    write_file(text_path,'********* Dark fraction: {} *********'.format(dark_frac),'a')

    train_split = resample_dataset_race(train_split_t,dark_frac)

    write_file(text_path,str(train_split['race'].value_counts()),'a')
    write_file(text_path,str(train_split['gender'].value_counts()),'a')

    # Dataloaders
    train_data = UTKFaceDataset(train_split, data_path, train_transform)
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=0)

    model = Network1(D_out=num_classes, device=device)
    model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Freeze both auxiliary networks and make sure BOTH live on `device`
    # (the original moved `group_disc` twice and never moved `model_trans`).
    # NOTE(review): neither network is switched to eval mode here; if they
    # contain BatchNorm/Dropout, confirm the caller has already done so.
    for frozen_net in (group_disc, model_trans):
        for param in frozen_net.parameters():
            param.requires_grad = False
        frozen_net.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.5, 0.999),
        weight_decay=0.08
        )

    # NOTE(review): T_max=30 is half of the default 60 epochs, so the cosine
    # schedule reaches eta_min at epoch 30 and then rises again - confirm intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=30,
        eta_min=0.01 * learning_rate, verbose=True
        )

    valid_accuracy = []
    test_accuracy_D = []
    test_accuracy_L = []

    print("Training model...")

    best_val_acc = 0
    n_train = len(train_loader.sampler)

    for epoch in range(1, num_epochs+1):
        # Weight of the contrastive term decays geometrically with the epoch.
        reg_weight = lam*(beta**(epoch-1))

        train_loss = 0.0
        train_correct = 0
        dropped_reg = 0

        # ---- training ----
        model.train()
        for data, target, _ in train_loader:

            data = data.to(device)
            target = target.to(device)

            optimizer.zero_grad()

            # Frozen feature extractor + disentangler: inference only.
            with torch.no_grad():
                output_dash = group_disc(data)
                output_dash = torch.reshape(output_dash,(output_dash.shape[0],output_dash.shape[1]))
                output_detangle, _ = model_trans(output_dash)

            output,_,x2 = model(data)
            x2 = torch.reshape(x2,(x2.shape[0],-1))

            loss = criterion(output, target)
            reg = threshold_contrastive_trans(x2,output_detangle)
            if torch.isfinite(reg).all():
                loss = loss + reg_weight*reg
            else:
                # A single inf/nan regulariser value would turn every weight
                # into nan on the next step (seen in earlier runs as
                # "Tr Loss: inf" followed by nan); train this batch on
                # cross-entropy alone instead.
                dropped_reg += 1

            loss.backward()
            optimizer.step()

            train_loss += loss.item() * data.size(0)

            preds = np.argmax(output.detach().cpu().numpy(),axis=1)
            train_correct += int(np.sum(preds==target.cpu().numpy()))

        # ---- validation ----
        valid_loss, tvacc = _evaluate_loader(model, valid_loader, criterion)

        if tvacc>best_val_acc:
            best_val_acc = tvacc
            torch.save(model.state_dict(), os.path.join(out_dir, "model_best.pt"))
            print('Model saved')
            write_file(text_path,'Model saved','a')

        train_loss = train_loss/n_train
        ttacc = train_correct/n_train

        scheduler.step()

        train_print = 'Epoch: {} \tTr Loss: {:.3f} \tTr Acc: {:.3f} \tVal Loss: {:.3f} \tVal Acc: {:.3f}'.format(
            epoch, train_loss, ttacc, valid_loss, tvacc)
        print(train_print)

        test_print_D, ttacc_D = test_performance(model,test_loader_D,criterion)
        test_print_L, ttacc_L = test_performance(model,test_loader_L,criterion)

        valid_accuracy.append(tvacc)
        test_accuracy_D.append(ttacc_D)
        test_accuracy_L.append(ttacc_L)

        write_file(text_path,train_print,'a')
        write_file(text_path,test_print_D,'a')
        write_file(text_path,test_print_L,'a')

        if dropped_reg:
            warn_print = 'Warning: contrastive term was non-finite for {} batch(es); cross-entropy only was used'.format(dropped_reg)
            print(warn_print)
            write_file(text_path,warn_print,'a')

    # Persist the per-epoch curves (np.save appends ".npy" to each path).
    np.save(os.path.join(out_dir, "validation_accuracy"), np.array(valid_accuracy))
    np.save(os.path.join(out_dir, "test_accuracy_D"), np.array(test_accuracy_D))
    np.save(os.path.join(out_dir, "test_accuracy_L"), np.array(test_accuracy_L))


def _evaluate_loader(model, loader, criterion):
    """Return (mean loss, accuracy) of `model` over `loader` with autograd off.

    Puts `model` in eval mode. `loader` yields `(data, target, grp)` batches
    and exposes `.sampler`; uses the notebook-level global `device`.
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0

    with torch.no_grad():
        for data, target, _ in loader:
            data = data.to(device)
            target = target.to(device)

            output,_,_ = model(data)

            # Weight by batch size so the average is per-sample.
            total_loss += criterion(output, target).item() * data.size(0)

            preds = np.argmax(output.cpu().numpy(),axis=1)
            n_correct += int(np.sum(preds==target.cpu().numpy()))

    n_samples = len(loader.sampler)
    return total_loss/n_samples, n_correct/n_samples

In [None]:
# Sweep over the fraction handed to train_model as `dark_frac` (it drives the
# race resampling there). The variable names are kept unchanged in case other
# cells refer to them; only the banner is corrected to match what train_model
# logs. Full sweep: np.linspace(0.0,1.0,11)
male_fracs = [0.5]
for male_frac in male_fracs:
    print('********* Dark fraction: {} *********'.format(male_frac))
    train_model(male_frac,train_split_t,valid_loader,test_loader_D,test_loader_L,model_trans,group_disc)

********* Male fraction: 0.5 *********
3512 1758 5616




Adjusting learning rate of group 0 to 5.0000e-04.
Training model...
Model saved
Adjusting learning rate of group 0 to 4.9864e-04.
Epoch: 1 	Tr Loss: inf 	Tr Acc: 0.564 	Val Loss: nan 	Val Acc: 0.500
Test Loss: nan 	Test Acc: 0.500
Test Loss: nan 	Test Acc: 0.500
Adjusting learning rate of group 0 to 4.9459e-04.
Epoch: 2 	Tr Loss: nan 	Tr Acc: 0.500 	Val Loss: nan 	Val Acc: 0.500
Test Loss: nan 	Test Acc: 0.500
Test Loss: nan 	Test Acc: 0.500
Adjusting learning rate of group 0 to 4.8789e-04.
Epoch: 3 	Tr Loss: nan 	Tr Acc: 0.500 	Val Loss: nan 	Val Acc: 0.500
Test Loss: nan 	Test Acc: 0.500


KeyboardInterrupt: ignored