In [None]:
import os
import copy
import torch
import torchvision
import torch.nn as nn
import scipy
import torchvision.transforms as transforms
from torchvision import datasets as ds
from torch.utils.data import DataLoader,Dataset,Subset,ConcatDataset,random_split
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import tqdm
import celeba_dataset
import torch.nn.functional as F

In [None]:
# Show the working directory: the relative data/model paths used below resolve against it.
working_dir = os.getcwd()
print(working_dir)

In [None]:

# Preprocessing applied to every FairFace image (the same pipeline is also used for
# the validation data): resize to 224x224, convert the PIL image into a float tensor
# with values in [0, 1], then normalise each RGB channel with the standard ImageNet
# per-channel mean / std.
transform_train = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [None]:
# Optimiser step size and DataLoader batch size shared by the cells below.
learning_rate = 1e-3
batch_size = 128

In [None]:
##########################
### MODEL
##########################


class VGG16(torch.nn.Module):
    """VGG-style image classifier: five conv blocks followed by a 3-layer MLP head.

    Note: the conv stack has 2 + 2 + 4 + 4 + 4 = 16 conv layers (the VGG-19
    layout), despite the class name.

    The classifier is sized for a 512 x 4 x 4 feature map, which is what a
    128x128 input yields after the five 2x2 max-pools. An adaptive average pool
    to 4x4 sits in front of the classifier so other input sizes (e.g. 224x224)
    also work. For 128x128 inputs it is an identity, and it has no parameters,
    so existing state_dicts still load.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super(VGG16, self).__init__()

        # Every conv uses kernel 3, stride 1, padding 1 ("same" padding:
        # p = (s(o-1) - w + k)/2 = 1 for s=1, k=3), so only the pools shrink
        # the spatial size, each by a factor of 2.
        self.block_1 = self._conv_block(3, 64, num_convs=2)
        self.block_2 = self._conv_block(64, 128, num_convs=2)
        self.block_3 = self._conv_block(128, 256, num_convs=4)
        self.block_4 = self._conv_block(256, 512, num_convs=4)
        self.block_5 = self._conv_block(512, 512, num_convs=4)

        # Pools any H x W feature map to the 4 x 4 grid the classifier expects.
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))

        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes),
        )

        # Re-initialise every conv / linear weight from N(0, 0.05) and zero the biases.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight, mean=0.0, std=0.05)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _conv_block(in_channels, out_channels, num_convs):
        """Return `num_convs` x (3x3 conv + ReLU) followed by a 2x2 / stride-2 max-pool."""
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(nn.Conv2d(in_channels=channels,
                                    out_channels=out_channels,
                                    kernel_size=(3, 3),
                                    stride=(1, 1),
                                    padding=1))
            layers.append(nn.ReLU())
            channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an (N, 3, H, W) batch (H, W >= 32) to (N, num_classes) logits."""
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        x = self.block_5(x)
        x = self.avgpool(x)
        # flatten(x, 1) keeps the batch dimension, unlike view(-1, ...), which
        # silently folds spatial positions into the batch on size mismatch.
        logits = self.classifier(torch.flatten(x, 1))
        return logits

In [None]:
# Seed torch's global generator so model initialisation (and later random_split
# calls, which draw from the same generator) are reproducible.
torch.manual_seed(1)
# Use the first GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = VGG16(2).to(device)
print(net)
# Optimiser only -- the cross-entropy loss is computed inside the training loop.
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# FairFace数据集

In [None]:
class FairFaceDataset(Dataset):
    """FairFace images paired with their encoded label rows.

    Args:
        df: label table indexed by image file name (relative to
            ``<root_path>/fairface``); every column is returned as part of the label.
        root_path: directory containing the ``fairface`` image folder; a trailing
            slash is optional.
        transform: optional callable applied to the PIL image.

    ``__getitem__`` returns ``(image, label_row)``. ``label_row`` is a row *view*
    into ``self.y``; ``shuffle_dataset`` corrupts labels by writing through that
    view, so it must not be turned into a copy.
    """

    def __init__(self, df, root_path, transform=None):
        # os.path.join works whether or not root_path ends with a separator.
        self.img_dir = os.path.join(root_path, "fairface")
        self.img_names = df.index.values
        self.y = df.values
        self.transform = transform

    def __getitem__(self, index):
        path = os.path.join(self.img_dir, self.img_names[index])
        # Close the file handle deterministically, and force 3 channels so a
        # 3-channel Normalize never receives grayscale / RGBA images.
        with Image.open(path) as img:
            img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        label = self.y[index]
        return img, label

    def __len__(self):
        return self.y.shape[0]

In [None]:
# Integer code tables for the categorical FairFace label columns. Train and
# validation must use identical tables, so they are defined once here.
AGE_CODES = {'0-2': 0, '3-9': 1, '10-19': 2, '20-29': 3, '30-39': 4,
             '40-49': 5, '50-59': 6, '60-69': 7, 'more than 70': 8}
GENDER_CODES = {'Female': 0, 'Male': 1}
RACE_CODES = {'East Asian': 0, 'Indian': 1, 'Black': 2, 'White': 3,
              'Middle Eastern': 4, 'Latino_Hispanic': 5, 'Southeast Asian': 6}
SERVICE_TEST_CODES = {False: 0, True: 1}
LABEL_CODES = {'age': AGE_CODES,
               'gender': GENDER_CODES,
               'race': RACE_CODES,
               'service_test': SERVICE_TEST_CODES}


def encode_fairface_labels(df):
    """Return a copy of a FairFace label table with the categorical columns as int64 codes.

    Raises:
        ValueError: if a column contains a value missing from its code table
            (the value maps to NaN and the int64 cast fails), instead of silently
            leaving strings / object dtype in the labels.
    """
    encoded = df.copy()
    for column, codes in LABEL_CODES.items():
        encoded[column] = encoded[column].map(codes).astype('int64')
    return encoded


df_ff_train = encode_fairface_labels(pd.read_csv("../fairface/fairface_label_train.csv", index_col=0))
df_ff_val = encode_fairface_labels(pd.read_csv("../fairface/fairface_label_val.csv", index_col=0))
for index, column in enumerate(df_ff_train.columns):
    print(str(index) + " " + column)
df_ff_train.head()

In [None]:
# Number of training rows per encoded gender (0 = Female, 1 = Male).
gender_column = df_ff_train['gender']
for gender_code in (0, 1):
    print((gender_column == gender_code).sum())

In [None]:
# Same batch size as the hyper-parameter cell, repeated so the DataLoader cells can be re-run alone.
batch_size = 128

## 划分人种

In [None]:
ff_train_dataset = FairFaceDataset(df_ff_train, "../", transform_train)
ff_val_dataset = FairFaceDataset(df_ff_val, "../", transform_train)
ff_train_dataloader = DataLoader(ff_train_dataset,
                                 batch_size=batch_size,
                                 shuffle=True)
# Evaluation order does not change accuracy, so the validation loader is not shuffled.
ff_val_dataloader = DataLoader(ff_val_dataset,
                               batch_size=batch_size,
                               shuffle=False)

# Preview one training image. transform_train normalises with the ImageNet
# mean/std, so undo that and clamp to [0, 1] before plotting (imshow would
# otherwise clip the out-of-range values and show a distorted image).
# permute(1, 2, 0) converts CHW to the HWC layout matplotlib expects.
_preview_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_preview_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
_preview_img = ff_train_dataset[4][0]
plt.imshow((_preview_img * _preview_std + _preview_mean).clamp(0, 1).permute(1, 2, 0))

In [None]:
# Validation loader limited to the races used for training: race codes 4
# (Middle Eastern) and 5 (Latino_Hispanic) are dropped. Replaces the
# full-validation loader built in the previous cell.
df_ff_val_kept_races = df_ff_val[~df_ff_val['race'].isin([4, 5])]
ff_val_dataloader = DataLoader(
    FairFaceDataset(df_ff_val_kept_races, "../", transform_train),
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Slice the encoded labels by race and by (race, gender) for train and validation.
# Race codes: 0 East Asian, 1 Indian, 2 Black, 3 White, 4 Middle Eastern,
# 5 Latino_Hispanic, 6 Southeast Asian. Gender codes: 0 Female, 1 Male.


def _race_rows(df, race):
    """Rows of `df` whose encoded race equals `race`."""
    return df[df['race'] == race]


def _race_gender_rows(df, race, gender):
    """Rows of `df` whose encoded race and gender both match."""
    return df[(df['race'] == race) & (df['gender'] == gender)]


def _ff_dataset(df):
    """FairFaceDataset over `df` with the shared image root and preprocessing."""
    return FairFaceDataset(df, '../', transform_train)


def _ff_eval_loader(df):
    """Unshuffled DataLoader over `df`, used for per-group evaluation."""
    return DataLoader(_ff_dataset(df), batch_size=batch_size)


# Training set
df_ff_train_eastAsian = _race_rows(df_ff_train, 0)
df_ff_train_indian = _race_rows(df_ff_train, 1)
df_ff_train_black = _race_rows(df_ff_train, 2)
df_ff_train_white = _race_rows(df_ff_train, 3)
df_ff_train_middleEastern = _race_rows(df_ff_train, 4)
df_ff_train_LatinoHispanic = _race_rows(df_ff_train, 5)
df_ff_train_southeastAsian = _race_rows(df_ff_train, 6)
df_ff_train_eastAsian_gender0 = _race_gender_rows(df_ff_train, 0, 0)
df_ff_train_eastAsian_gender1 = _race_gender_rows(df_ff_train, 0, 1)
df_ff_train_indian_gender0 = _race_gender_rows(df_ff_train, 1, 0)
df_ff_train_indian_gender1 = _race_gender_rows(df_ff_train, 1, 1)
df_ff_train_black_gender0 = _race_gender_rows(df_ff_train, 2, 0)
df_ff_train_black_gender1 = _race_gender_rows(df_ff_train, 2, 1)
df_ff_train_white_gender0 = _race_gender_rows(df_ff_train, 3, 0)
df_ff_train_white_gender1 = _race_gender_rows(df_ff_train, 3, 1)
df_ff_train_middleEastern_gender0 = _race_gender_rows(df_ff_train, 4, 0)
df_ff_train_middleEastern_gender1 = _race_gender_rows(df_ff_train, 4, 1)
df_ff_train_LatinoHispanic_gender0 = _race_gender_rows(df_ff_train, 5, 0)
df_ff_train_LatinoHispanic_gender1 = _race_gender_rows(df_ff_train, 5, 1)
df_ff_train_southeastAsian_gender0 = _race_gender_rows(df_ff_train, 6, 0)
df_ff_train_southeastAsian_gender1 = _race_gender_rows(df_ff_train, 6, 1)

ff_train_eastAsian_dataset = _ff_dataset(df_ff_train_eastAsian)
ff_train_indian_dataset = _ff_dataset(df_ff_train_indian)
ff_train_black_dataset = _ff_dataset(df_ff_train_black)
ff_train_white_dataset = _ff_dataset(df_ff_train_white)
ff_train_middleEastern_dataset = _ff_dataset(df_ff_train_middleEastern)
ff_train_LatinoHispanic_dataset = _ff_dataset(df_ff_train_LatinoHispanic)
ff_train_southeastAsian_dataset = _ff_dataset(df_ff_train_southeastAsian)

ff_train_eastAsian_gender0_dataset = _ff_dataset(df_ff_train_eastAsian_gender0)
ff_train_eastAsian_gender1_dataset = _ff_dataset(df_ff_train_eastAsian_gender1)
ff_train_indian_gender0_dataset = _ff_dataset(df_ff_train_indian_gender0)
ff_train_indian_gender1_dataset = _ff_dataset(df_ff_train_indian_gender1)
ff_train_black_gender0_dataset = _ff_dataset(df_ff_train_black_gender0)
ff_train_black_gender1_dataset = _ff_dataset(df_ff_train_black_gender1)
ff_train_white_gender0_dataset = _ff_dataset(df_ff_train_white_gender0)
ff_train_white_gender1_dataset = _ff_dataset(df_ff_train_white_gender1)
ff_train_middleEastern_gender0_dataset = _ff_dataset(df_ff_train_middleEastern_gender0)
ff_train_middleEastern_gender1_dataset = _ff_dataset(df_ff_train_middleEastern_gender1)
ff_train_LatinoHispanic_gender0_dataset = _ff_dataset(df_ff_train_LatinoHispanic_gender0)
ff_train_LatinoHispanic_gender1_dataset = _ff_dataset(df_ff_train_LatinoHispanic_gender1)
ff_train_southeastAsian_gender0_dataset = _ff_dataset(df_ff_train_southeastAsian_gender0)
ff_train_southeastAsian_gender1_dataset = _ff_dataset(df_ff_train_southeastAsian_gender1)

# Validation (test) set
df_ff_val_eastAsian = _race_rows(df_ff_val, 0)
df_ff_val_indian = _race_rows(df_ff_val, 1)
df_ff_val_black = _race_rows(df_ff_val, 2)
df_ff_val_white = _race_rows(df_ff_val, 3)
df_ff_val_middleEastern = _race_rows(df_ff_val, 4)
df_ff_val_LatinoHispanic = _race_rows(df_ff_val, 5)
df_ff_val_southeastAsian = _race_rows(df_ff_val, 6)

df_ff_val_eastAsian_gender0 = _race_gender_rows(df_ff_val, 0, 0)
df_ff_val_eastAsian_gender1 = _race_gender_rows(df_ff_val, 0, 1)
df_ff_val_indian_gender0 = _race_gender_rows(df_ff_val, 1, 0)
df_ff_val_indian_gender1 = _race_gender_rows(df_ff_val, 1, 1)
df_ff_val_black_gender0 = _race_gender_rows(df_ff_val, 2, 0)
df_ff_val_black_gender1 = _race_gender_rows(df_ff_val, 2, 1)
df_ff_val_white_gender0 = _race_gender_rows(df_ff_val, 3, 0)
df_ff_val_white_gender1 = _race_gender_rows(df_ff_val, 3, 1)
df_ff_val_middleEastern_gender0 = _race_gender_rows(df_ff_val, 4, 0)
df_ff_val_middleEastern_gender1 = _race_gender_rows(df_ff_val, 4, 1)
df_ff_val_LatinoHispanic_gender0 = _race_gender_rows(df_ff_val, 5, 0)
df_ff_val_LatinoHispanic_gender1 = _race_gender_rows(df_ff_val, 5, 1)
df_ff_val_southeastAsian_gender0 = _race_gender_rows(df_ff_val, 6, 0)
df_ff_val_southeastAsian_gender1 = _race_gender_rows(df_ff_val, 6, 1)

ff_val_eastAsian_gender0_dataset = _ff_dataset(df_ff_val_eastAsian_gender0)
ff_val_eastAsian_gender1_dataset = _ff_dataset(df_ff_val_eastAsian_gender1)
ff_val_indian_gender0_dataset = _ff_dataset(df_ff_val_indian_gender0)
ff_val_indian_gender1_dataset = _ff_dataset(df_ff_val_indian_gender1)
ff_val_black_gender0_dataset = _ff_dataset(df_ff_val_black_gender0)
ff_val_black_gender1_dataset = _ff_dataset(df_ff_val_black_gender1)
ff_val_white_gender0_dataset = _ff_dataset(df_ff_val_white_gender0)
ff_val_white_gender1_dataset = _ff_dataset(df_ff_val_white_gender1)
ff_val_middleEastern_gender0_dataset = _ff_dataset(df_ff_val_middleEastern_gender0)
ff_val_middleEastern_gender1_dataset = _ff_dataset(df_ff_val_middleEastern_gender1)
ff_val_LatinoHispanic_gender0_dataset = _ff_dataset(df_ff_val_LatinoHispanic_gender0)
ff_val_LatinoHispanic_gender1_dataset = _ff_dataset(df_ff_val_LatinoHispanic_gender1)
ff_val_southeastAsian_gender0_dataset = _ff_dataset(df_ff_val_southeastAsian_gender0)
ff_val_southeastAsian_gender1_dataset = _ff_dataset(df_ff_val_southeastAsian_gender1)

# The Southeast Asian loaders below must read the Southeast Asian frames
# (they previously reused the East Asian ones, a copy-paste slip).
ff_val_eastAsian_dataloader = _ff_eval_loader(df_ff_val_eastAsian)
ff_val_indian_dataloader = _ff_eval_loader(df_ff_val_indian)
ff_val_black_dataloader = _ff_eval_loader(df_ff_val_black)
ff_val_white_dataloader = _ff_eval_loader(df_ff_val_white)
ff_val_middleEastern_dataloader = _ff_eval_loader(df_ff_val_middleEastern)
ff_val_LatinoHispanic_dataloader = _ff_eval_loader(df_ff_val_LatinoHispanic)
ff_val_southeastAsian_dataloader = _ff_eval_loader(df_ff_val_southeastAsian)

ff_val_eastAsian_gender0_dataloader = _ff_eval_loader(df_ff_val_eastAsian_gender0)
ff_val_eastAsian_gender1_dataloader = _ff_eval_loader(df_ff_val_eastAsian_gender1)
ff_val_indian_gender0_dataloader = _ff_eval_loader(df_ff_val_indian_gender0)
ff_val_indian_gender1_dataloader = _ff_eval_loader(df_ff_val_indian_gender1)
ff_val_black_gender0_dataloader = _ff_eval_loader(df_ff_val_black_gender0)
ff_val_black_gender1_dataloader = _ff_eval_loader(df_ff_val_black_gender1)
ff_val_white_gender0_dataloader = _ff_eval_loader(df_ff_val_white_gender0)
ff_val_white_gender1_dataloader = _ff_eval_loader(df_ff_val_white_gender1)
ff_val_middleEastern_gender0_dataloader = _ff_eval_loader(df_ff_val_middleEastern_gender0)
ff_val_middleEastern_gender1_dataloader = _ff_eval_loader(df_ff_val_middleEastern_gender1)
ff_val_LatinoHispanic_gender0_dataloader = _ff_eval_loader(df_ff_val_LatinoHispanic_gender0)
ff_val_LatinoHispanic_gender1_dataloader = _ff_eval_loader(df_ff_val_LatinoHispanic_gender1)
ff_val_southeastAsian_gender0_dataloader = _ff_eval_loader(df_ff_val_southeastAsian_gender0)
ff_val_southeastAsian_gender1_dataloader = _ff_eval_loader(df_ff_val_southeastAsian_gender1)

### 构造不平衡数据集

In [None]:
def splitByRatio(dataset, ratio):
    """Randomly split `dataset` into two Subsets.

    The first Subset holds ``int(len(dataset) * ratio)`` items and the second holds
    the rest. The split draws from torch's global RNG.
    """
    total = len(dataset)
    first_len = int(total * ratio)
    return random_split(dataset, [first_len, total - first_len])

In [None]:
# Imbalanced training set: 4000 images per (race, gender) group, except White,
# which gets 2000 female / 6000 male. The Middle Eastern and Latino_Hispanic
# groups are left out. Each random_split draws from torch's global RNG, so the
# order of this list matters for reproducibility.
_unbalance_plan = [
    (ff_train_eastAsian_gender0_dataset, 4000),
    (ff_train_eastAsian_gender1_dataset, 4000),
    (ff_train_indian_gender0_dataset, 4000),
    (ff_train_indian_gender1_dataset, 4000),
    (ff_train_black_gender0_dataset, 4000),
    (ff_train_black_gender1_dataset, 4000),
    (ff_train_white_gender0_dataset, 2000),
    (ff_train_white_gender1_dataset, 6000),
    (ff_train_southeastAsian_gender0_dataset, 4000),
    (ff_train_southeastAsian_gender1_dataset, 4000),
]
unbalance_Dataset = ConcatDataset([
    random_split(group, [count, len(group) - count])[0]
    for group, count in _unbalance_plan
])
unbalance_dataloader = DataLoader(unbalance_Dataset, batch_size, shuffle=True)

In [None]:
import numpy as np
def shuffle_dataset(dataset, target_index, ratio, max_range=1, excluded_labels=(4, 5), seed=1):
    """Return a deep copy of `dataset` in which a fraction of one label column is corrupted.

    Args:
        dataset: indexable dataset whose items are ``(input, label_row)``. The
            ``label_row`` must be a mutable view into the dataset's own storage
            (as ``FairFaceDataset`` returns), because the copy is edited through it.
        target_index: position of the label to corrupt inside ``label_row``.
        ratio: probability of *keeping* the correct label. ``ratio < 0`` corrupts
            every item; ``ratio >= 1`` corrupts none.
        max_range: replacement labels are drawn uniformly from ``0..max_range``,
            excluding the current label and `excluded_labels`.
        excluded_labels: labels never used as replacements. The default (4, 5)
            matches the codes the original implementation hard-coded.
        seed: seeds private RNGs, so the corruption is reproducible without touching
            the global numpy / ``random`` state.

    Returns:
        The corrupted deep copy; `dataset` itself is left unchanged.

    Raises:
        ValueError: if an item to corrupt has no valid replacement label (the
            earlier rejection-sampling loop would spin forever in that case).
    """
    # RandomState(seed) yields the same keep/corrupt draws as np.random.seed(seed)
    # followed by np.random.rand, without mutating numpy's global state.
    keep_rng = np.random.RandomState(seed)
    label_rng = random.Random(seed)
    ds = copy.deepcopy(dataset)
    for i in range(len(ds)):
        if keep_rng.rand() > ratio:
            # Fetch the item once: every ds[i] re-reads (and re-transforms) the image.
            label = ds[i][1]
            current = label[target_index]
            candidates = [t for t in range(max_range + 1)
                          if t != current and t not in excluded_labels]
            if not candidates:
                raise ValueError(
                    "shuffle_dataset: no replacement label available for item %d" % i)
            # Uniform choice over the candidates == the original rejection sampling.
            label[target_index] = label_rng.choice(candidates)
    return ds

In [None]:
# Balanced subset: about 1/10 of the imbalanced set's size, with the same number
# of images drawn from each active (race, gender) group.
# NOTE(review): the per-group size divides by 14 (7 races x 2 genders), although
# only the 10 groups below are used, so the set holds about 10/14 of len_balance.
# Confirm this is intended.
len_balance = len(unbalance_Dataset) / 10
_balance_per_group = int(len_balance / 14)
_balance_groups = [
    ff_train_eastAsian_gender0_dataset,
    ff_train_eastAsian_gender1_dataset,
    ff_train_indian_gender0_dataset,
    ff_train_indian_gender1_dataset,
    ff_train_black_gender0_dataset,
    ff_train_black_gender1_dataset,
    ff_train_white_gender0_dataset,
    ff_train_white_gender1_dataset,
    ff_train_southeastAsian_gender0_dataset,
    ff_train_southeastAsian_gender1_dataset,
]
balance_Dataset = ConcatDataset([
    random_split(group, [_balance_per_group, len(group) - _balance_per_group])[0]
    for group in _balance_groups
])
balance_dataloader = DataLoader(balance_Dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Small balanced subset (about 1/100 of the imbalanced set's size) whose race labels
# (column 2) are all replaced by wrong ones (ratio=-1 keeps none).
# NOTE(review): as with the balanced set, the per-group size divides by 14 while
# only 10 groups are used. Confirm this is intended.
len_balance = len(unbalance_Dataset) / 100
_shuffle_per_group = int(len_balance / 14)
_shuffle_groups = [
    ff_train_eastAsian_gender0_dataset,
    ff_train_eastAsian_gender1_dataset,
    ff_train_indian_gender0_dataset,
    ff_train_indian_gender1_dataset,
    ff_train_black_gender0_dataset,
    ff_train_black_gender1_dataset,
    ff_train_white_gender0_dataset,
    ff_train_white_gender1_dataset,
    ff_train_southeastAsian_gender0_dataset,
    ff_train_southeastAsian_gender1_dataset,
]
shuffle_Dataset = ConcatDataset([
    random_split(group, [_shuffle_per_group, len(group) - _shuffle_per_group])[0]
    for group in _shuffle_groups
])
shuffle_Dataset = shuffle_dataset(shuffle_Dataset, target_index=2, ratio=-1, max_range=6)
shuffle_dataloader = DataLoader(shuffle_Dataset, batch_size=batch_size, shuffle=True)

In [None]:
# 训练模型的方法定义

def test(loader, net,target_index):
    net.eval()
    correct_pred=0
    num_examples=0
    for batch, (data, target) in enumerate(loader):
        data, target = data.to(device), target[:,target_index].to(device)
        logits = net(data)
        probas = F.softmax(logits,dim=1)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += target.size(0)
        correct_pred += (predicted_labels == target).sum()
        # acc += torch.sum(torch.argmax(output, dim=1) == target).item()
        # sum += len(target)
        # loss_sum += loss.item()
    print('test  acc: %.2f%% ' %(100 * correct_pred.float() / num_examples))
    return 100 * correct_pred.float()/ num_examples

def train(loader, model, target_index, training_type, checkpoint_dir="../models/23_1_4/"):
    '''
    Run one training epoch, print the epoch accuracy and checkpoint model + optimiser.

    Uses the module-level ``optimizer``, which must wrap ``model``'s parameters.
    ``load_model`` and the fine-tuning cells rebind it.

    :param loader: iterable of (data, target) batches; target is 2-D
    :param model: classifier returning logits
    :param target_index: label column to train on
    :param training_type: model name, used for the checkpoint file name
    :param checkpoint_dir: directory receiving "<training_type>_checkpoint.pth";
        it is created if missing
    :return: True when every training example in this epoch was classified
        correctly (callers use this as an early-stop signal), else False
    :raises ValueError: if loader yields no examples
    '''
    model.train()
    device = next(model.parameters()).device
    num_seen = 0
    correct_pred = 0
    for data, target in tqdm.tqdm(loader, desc="模型训练中：", total=len(loader)):
        data = data.to(device)
        target = target[:, target_index].long().to(device)
        optimizer.zero_grad()
        logits = model(data)
        cost = F.cross_entropy(logits, target)
        cost.backward()
        optimizer.step()
        num_seen += target.size(0)
        correct_pred += (logits.argmax(dim=1) == target).sum().item()
    if num_seen == 0:
        raise ValueError("train(): loader produced no examples")
    acc = 100.0 * correct_pred / num_seen
    print('train acc: %.2f%%' % (acc))
    # torch.save fails if the directory does not exist yet.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(checkpoint_dir, str(training_type) + "_checkpoint.pth"))
    return correct_pred == num_seen

def testRES(loader, net,target_index):
    net.eval()
    correct_pred=0
    num_examples=0
    for batch, (data, target) in enumerate(loader):
        data, target = data.to(device), target[:,target_index].to(device)
        probas = net(data)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += target.size(0)
        correct_pred += (predicted_labels == target).sum()
        # acc += torch.sum(torch.argmax(output, dim=1) == target).item()
        # sum += len(target)
        # loss_sum += loss.item()
    print('test  acc: %.2f%% ' %(100 * correct_pred.float() / num_examples))
    return 100 * correct_pred.float()/ num_examples

def FairFaceTest(model, target_index):
    """Report accuracy on the full validation loader and on each (race, gender) subgroup.

    Prints every accuracy and the variance of the subgroup accuracies, then
    returns those subgroup accuracies (percent, as Python floats) in the order
    listed below. Middle Eastern and Latino_Hispanic subgroups are not evaluated.
    """
    print("全部测试集：")
    testRES(ff_val_dataloader, model, target_index=target_index)

    subgroups = [
        ("eastAsian", 0, ff_val_eastAsian_gender0_dataloader),
        ("eastAsian", 1, ff_val_eastAsian_gender1_dataloader),
        ("indian", 0, ff_val_indian_gender0_dataloader),
        ("indian", 1, ff_val_indian_gender1_dataloader),
        ("black", 0, ff_val_black_gender0_dataloader),
        ("black", 1, ff_val_black_gender1_dataloader),
        ("white", 0, ff_val_white_gender0_dataloader),
        ("white", 1, ff_val_white_gender1_dataloader),
        ("southeastAsian", 0, ff_val_southeastAsian_gender0_dataloader),
        ("southeastAsian", 1, ff_val_southeastAsian_gender1_dataloader),
    ]
    accuracies = []
    for race_name, gender, loader in subgroups:
        print(race_name + " gender" + str(gender) + "测试集：")
        accuracies.append(testRES(loader, model, target_index=target_index).item())

    print("方差：")
    print(np.var(accuracies))
    return accuracies
def load_model(model_path=None, num_classes=1000):
    """Build a ResNet-34, rebind the global Adam ``optimizer`` to it and optionally restore a checkpoint.

    Args:
        model_path: checkpoint written by ``train`` (a dict with
            ``model_state_dict`` and ``optimizer_state_dict``), or None for a
            freshly initialised network.
        num_classes: size of the final fc layer. The default 1000 is torchvision's
            default head, so checkpoints saved with the default head keep loading.

    Returns:
        The network, moved to the first GPU when one is available, else the CPU.
    """
    net = torchvision.models.resnet34(num_classes=num_classes)
    global optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    if model_path is not None:
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return net

In [None]:
%%time
# Baseline training on the imbalanced set: 50 epochs, evaluating every 5th epoch.
# Target column 2 is race. (Original note: the last two label columns are the
# sensitive attributes.)
# NOTE(review): train() writes checkpoints under ../models/23_1_4/, while later
# cells load the baseline from ../models/22_12_30/. Confirm which run is intended.
NUM_EPOCHS = 50
EVAL_EVERY = 5
net = load_model()
# Replaces the optimiser load_model created with an equivalent fresh Adam.
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
for epoch in range(NUM_EPOCHS):
    print('epoch: %d' % epoch)
    train(unbalance_dataloader, net, target_index=2, training_type="fairface_RESNET_race")
    if (epoch + 1) % EVAL_EVERY == 0:
        FairFaceTest(net, target_index=2)

In [None]:
# Label-corruption ("shuffle") training: start from the baseline checkpoint and
# train one epoch on shuffle_dataloader, whose race labels were all replaced.
baseline_ckpt = '../models/22_12_30/fairface_RESNET_race_checkpoint.pth'
net = load_model(baseline_ckpt)
for epoch in range(1):
    print('epoch: %d' % epoch)
    train(shuffle_dataloader, net, target_index=2, training_type="fairface_shuffle_RESNET_race")
n = FairFaceTest(net, target_index=2)
for accuracy in n:
    print(accuracy)

In [None]:
%%time
# Recovery training: continue from the label-corrupted checkpoint on the balanced
# subset until one epoch classifies every training example correctly
# (at most 1000 epochs).
shuffled_ckpt = '../models/23_1_4/fairface_shuffle_RESNET_race_checkpoint.pth'
net1 = load_model(shuffled_ckpt)
for epoch in range(1000):
    print('epoch: %d' % epoch)
    fits_all = train(balance_dataloader, net1, target_index=2,
                     training_type="fairface_balance_RESNET_race")
    if fits_all:
        break
n = FairFaceTest(net1, target_index=2)
for accuracy in n:
    print(accuracy)

In [None]:
# Fine-tune the baseline checkpoint directly on the balanced subset until one
# epoch fits every training example (at most 1000 epochs).
# NOTE(review): the discriminative-LR cell reuses training_type
# "fairface_finetune_RESNET_race" and therefore overwrites this checkpoint.
baseline_ckpt = '../models/22_12_30/fairface_RESNET_race_checkpoint.pth'
net2 = load_model(baseline_ckpt)
for epoch in range(1000):
    print('epoch: %d' % epoch)
    fits_all = train(balance_dataloader, net2, target_index=2,
                     training_type="fairface_finetune_RESNET_race")
    if fits_all:
        break
n = FairFaceTest(net2, target_index=2)
for accuracy in n:
    print(accuracy)

In [None]:
# Fine-tune with discriminative learning rates: backbone at 3e-4, fc head at 1e-3.
# The new global optimizer replaces the one (and the checkpoint's optimizer state)
# that load_model restored; train() picks it up via the module-level name.
net3 = load_model('../models/23_1_3/fairface_RESNET_race_checkpoint.pth')
head_param_names = {"fc.weight", "fc.bias"}
params_1x = [param for name, param in net3.named_parameters()
             if name not in head_param_names]
optimizer = torch.optim.Adam(
    [
        {'params': params_1x},
        {'params': net3.fc.parameters(), 'lr': 0.001},
    ],
    lr=0.0003,
)
for epoch in range(100):
    print('epoch: %d' % epoch)
    fits_all = train(balance_dataloader, net3, target_index=2,
                     training_type="fairface_finetune_RESNET_race")
    if fits_all:
        break
n = FairFaceTest(net3, target_index=2)
for accuracy in n:
    print(accuracy)