In [2]:
import os
import pandas as pd
import torch
import warnings
from PIL import Image
import random
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
import torchvision.models as models
from tqdm import tqdm,trange
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
label_path = 'recognizing-faces-in-the-wild/train_relationships.csv'
label_df = pd.read_csv(label_path)
label_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3598 entries, 0 to 3597
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   p1      3598 non-null   object
 1   p2      3598 non-null   object
dtypes: object(2)
memory usage: 56.3+ KB


In [3]:
VAL_FAM = "F09"
def set_val(row):
    if VAL_FAM in row:
        return 1
    else:
        return 0    
    
all_img = []
for dir,_,f_list in os.walk('train'):
    for f in f_list:
        all_img.append(os.path.join(dir,f))
        
label_df['val'] = label_df['p1'].apply(set_val)
label_df['label'] = 1
train_pos = label_df[label_df['val']==0].count()['p1']
val_pos = len(label_df) - train_pos

Unnamed: 0,p1,p2,val,pos
0,F0002/MID1,F0002/MID3,0,1
1,F0002/MID2,F0002/MID3,0,1
2,F0005/MID1,F0005/MID2,0,1
3,F0005/MID3,F0005/MID2,0,1
4,F0009/MID1,F0009/MID4,0,1
...,...,...,...,...
3593,F1000/MID5,F1000/MID8,0,1
3594,F1000/MID5,F1000/MID9,0,1
3595,F1000/MID6,F1000/MID9,0,1
3596,F1000/MID7,F1000/MID8,0,1


In [5]:
def sample_neg(all_img):
    sampled = random.sample(all_img,2)
    first_sampled = sampled[0]
    family1 = first_sampled.split('/')[1]
    person1 = first_sampled.split('/')[2]
    second_sampled = sampled[1]
    family2 = second_sampled.split('/')[1]
    person2 = second_sampled.split('/')[2]
    while family1 == family2:
        second_sampled = random.sample(all_img,1)[0]
        family2 = second_sampled.split('/')[1]
        person2 = second_sampled.split('/')[2]
    return (family1+'/'+person1,family2+'/'+person2)

neg_samples = []
for i in range(train_pos):
    neg_samples.append(sample_neg(all_img))

In [6]:
neg_df = pd.DataFrame(neg_samples,columns=['p1','p2'])
neg_df['label'] = 0
neg_df['val'] = neg_df['p1'].apply(set_val)

Unnamed: 0,p1,p2,pos,val
0,F0552/MID1,F0425/MID1,0,0
1,F0439/MID1,F0099/MID1,0,0
2,F0164/MID1,F0730/MID1,0,0
3,F0601/MID6,F0879/MID1,0,0
4,F0672/MID7,F0599/MID2,0,0
...,...,...,...,...
3291,F0017/MID6,F0009/MID5,0,0
3292,F0828/MID5,F0690/MID9,0,0
3293,F0794/MID1,F0315/MID1,0,0
3294,F0763/MID3,F0683/MID1,0,0


In [7]:
train_neg = neg_df[neg_df['val']==0].count()['p1']
val_neg = len(neg_df) - train_neg

2979
317


In [8]:
final_df = pd.concat([label_df,neg_df])

Unnamed: 0,p1,p2,val,pos
0,F0002/MID1,F0002/MID3,0,1
1,F0002/MID2,F0002/MID3,0,1
2,F0005/MID1,F0005/MID2,0,1
3,F0005/MID3,F0005/MID2,0,1
4,F0009/MID1,F0009/MID4,0,1
...,...,...,...,...
3291,F0017/MID6,F0009/MID5,0,0
3292,F0828/MID5,F0690/MID9,0,0
3293,F0794/MID1,F0315/MID1,0,0
3294,F0763/MID3,F0683/MID1,0,0


In [10]:

final_df['check_exists_p1'] = final_df['p1'].map(lambda x: os.path.exists('train/'+x) and bool(os.listdir('train/'+x)))
final_df['check_exists_p2'] = final_df['p2'].map(lambda x: os.path.exists('train/'+x)and bool(os.listdir('train/'+x)))
final_df = final_df[(final_df['check_exists_p1'] is True) & (final_df['check_exists_p2'] is True)]

Unnamed: 0,p1,p2,val,pos,check_exists_p1,check_exists_p2
0,F0002/MID1,F0002/MID3,0,1,True,True
1,F0002/MID2,F0002/MID3,0,1,True,True
2,F0005/MID1,F0005/MID2,0,1,True,True
3,F0005/MID3,F0005/MID2,0,1,True,True
4,F0009/MID1,F0009/MID4,0,1,True,True
...,...,...,...,...,...,...
3291,F0017/MID6,F0009/MID5,0,0,True,True
3292,F0828/MID5,F0690/MID9,0,0,True,True
3293,F0794/MID1,F0315/MID1,0,0,True,True
3294,F0763/MID3,F0683/MID1,0,0,True,True


In [11]:
train_df = final_df[final_df['val']==0]

Unnamed: 0,p1,p2,val,pos,check_exists_p1,check_exists_p2
0,F0002/MID1,F0002/MID3,0,1,True,True
1,F0002/MID2,F0002/MID3,0,1,True,True
2,F0005/MID1,F0005/MID2,0,1,True,True
3,F0005/MID3,F0005/MID2,0,1,True,True
4,F0009/MID1,F0009/MID4,0,1,True,True
...,...,...,...,...,...,...
3290,F0866/MID1,F0403/MID3,0,0,True,True
3291,F0017/MID6,F0009/MID5,0,0,True,True
3292,F0828/MID5,F0690/MID9,0,0,True,True
3293,F0794/MID1,F0315/MID1,0,0,True,True


In [12]:
val_df = final_df[final_df['val']==1]

Unnamed: 0,p1,p2,val,pos,check_exists_p1,check_exists_p2
3274,F0900/MID2,F0900/MID1,1,1,True,True
3275,F0900/MID3,F0900/MID1,1,1,True,True
3276,F0901/MID1,F0901/MID4,1,1,True,True
3277,F0901/MID2,F0901/MID1,1,1,True,True
3278,F0901/MID2,F0901/MID4,1,1,True,True
...,...,...,...,...,...,...
3268,F0960/MID2,F0411/MID2,1,0,True,True
3276,F0983/MID3,F0903/MID3,1,0,True,True
3282,F0982/MID1,F0708/MID2,1,0,True,True
3285,F0973/MID2,F0920/MID6,1,0,True,True


In [13]:
class MyDataset(Dataset):
    def __init__(self,root_path,df,transforms=None):
        super(MyDataset,self).__init__()
        self.df = df
        self.root_path = root_path
        self.transforms = transforms
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self,idx):
        p1 = self.df.iloc[idx]['p1']
        p2 = self.df.iloc[idx]['p2']
        label = self.df.iloc[idx]['label']
        
        
        p1_path = os.path.join(self.root_path,p1)
        p2_path = os.path.join(self.root_path,p2)
        
        p1_sample = random.sample([os.path.join(p1_path,img) for img in os.listdir(p1_path)],1)[0]
        p2_sample = random.sample([os.path.join(p2_path,img) for img in os.listdir(p2_path)],1)[0]
            
        img1 = Image.open(p1_sample).convert('RGB')
        img2 = Image.open(p2_sample).convert('RGB')
        
        if self.transforms:
            img1 = self.transforms(img1)
            img2 = self.transforms(img2)
            
        return img1,img2,label

In [14]:

train_transformation = transforms.Compose([
                                     transforms.Resize(255),
                                     transforms.RandomCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                                     transforms.RandomErasing(),
                                    ])
valid_transformation = transforms.Compose([
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                                    ])
img_root = 'train'
train_data = MyDataset(img_root,train_df,train_transformation)
train_loader = DataLoader(train_data,shuffle=True,batch_size=64)

val_data = MyDataset(img_root,val_df,valid_transformation)
val_loader = DataLoader(val_data,shuffle=False,batch_size=1)

In [15]:

weights_path = 'resnet50-0676ba61.pth'
resnet = models.resnet50()
resnet.load_state_dict(torch.load(weights_path))
resnet_truncated = torch.nn.Sequential(*list(resnet.children())[:-1])
resnet_truncated.eval()

<All keys matched successfully>

In [17]:
class SiameseNet(nn.Module):
    def __init__(self,base_model):
        super(SiameseNet,self).__init__()
        self.base_model = base_model
        self.net = nn.Sequential(
            nn.Linear(12288,1024),
            nn.ReLU(),
            nn.Linear(1024,128),
            nn.ReLU(),
            nn.Linear(128,2)
        )
    
    def forward(self,img1,img2):
        b= img1.shape[0]
        feat1 = self.base_model(img1)
        feat2 = self.base_model(img2)
        feat1 = feat1.view(b,feat1.shape[1])
        feat2 = feat2.view(b,feat2.shape[1])
        feat1 = torch.cat((feat1,feat1),dim = 1)
        feat2 = torch.cat((feat2,feat2),dim = 1)

        cf1 = (feat1 * feat1) - (feat2 * feat2)
        cf2 = (feat1 - feat2) * (feat1 - feat2)
        cf3 = feat1 * feat2

        combined_features = torch.cat((cf1,cf2,cf3),dim=1)
        x = self.net(combined_features)
        return x

In [19]:
learning_rate = 0.0001
epochs = 10
model = SiameseNet(resnet_truncated).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate,weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
train_loss_plot = []
train_acc_plot = []
val_loss_plot = []
val_acc_plot = []
best_val_acc = 0.0
save_path = 'SiameseNet.path'
for i in trange(epochs):
    model.train()
    total_num = 0
    corrected_num = 0
    total_loss = 0
    for train_img1,train_img2,train_label in tqdm(train_loader):
        train_img1,train_img2,train_label = train_img1.to(device),train_img2.to(device),train_label.to(device)
        output = model(train_img1,train_img2)
        loss = criterion(output,train_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_num += train_label.shape[0]
        total_loss += train_label.shape[0] * loss.item()
        predicted = output.argmax(dim=-1)
        corrected_num += sum(predicted == train_label)

    model.eval()
    val_num = 0
    val_correct = 0
    val_loss = 0
    for val_img1,val_img2,val_label in tqdm(val_loader):
        val_img1,val_img2,val_label = val_img1.to(device),val_img2.to(device),val_label.to(device)
        val_output = model(val_img1,val_img2)
        val_loss += criterion(val_output,val_label).item()
        val_predicted = val_output.argmax(dim=-1)
        val_correct += sum(val_predicted == val_label)
        val_num += val_label.shape[0]

    epoch_train_loss = total_loss/total_num
    epoch_train_acc = corrected_num/total_num
    epoch_val_loss = val_loss/val_num
    epoch_val_acc = val_correct/val_num
    train_loss_plot.append(epoch_train_loss)
    train_acc_plot.append(epoch_train_acc)
    val_loss_plot.append(epoch_val_loss)
    val_acc_plot.append(epoch_val_acc)

    print(f'epoch: {i+1}: training loss: {epoch_train_loss} training accuracy: {epoch_train_acc:.2}% validation loss: {epoch_val_loss} validation accuracy: {epoch_val_acc:.2}%')
    print('-'*200)
    if not best_val_acc or epoch_val_loss < best_val_acc:
        torch.save(model.state_dict(), save_path)
        best_val_acc = epoch_val_loss
        print(f"Saving the model to \'{save_path}\'...")
        
        

100%|██████████| 95/95 [01:05<00:00,  1.44it/s]
100%|██████████| 613/613 [00:06<00:00, 99.77it/s] 
 10%|█         | 1/10 [01:12<10:48, 72.09s/it]

epoch: 1: training loss: 0.6239332280245766 training accuracy: 0.64% validation loss: 0.6259247287310247 validation accuracy: 0.67%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:05<00:00,  1.46it/s]
100%|██████████| 613/613 [00:06<00:00, 96.48it/s]
 20%|██        | 2/10 [02:23<09:33, 71.70s/it]

epoch: 2: training loss: 0.5974000159524708 training accuracy: 0.68% validation loss: 0.6093457493853951 validation accuracy: 0.68%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:05<00:00,  1.46it/s]
100%|██████████| 613/613 [00:05<00:00, 103.69it/s]
 30%|███       | 3/10 [03:34<08:20, 71.46s/it]

epoch: 3: training loss: 0.572193338986366 training accuracy: 0.7% validation loss: 0.607402968671261 validation accuracy: 0.69%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:04<00:00,  1.46it/s]
100%|██████████| 613/613 [00:06<00:00, 100.63it/s]
 40%|████      | 4/10 [04:45<07:07, 71.29s/it]

epoch: 4: training loss: 0.5527511061421103 training accuracy: 0.71% validation loss: 0.5601847342308155 validation accuracy: 0.7%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:05<00:00,  1.46it/s]
100%|██████████| 613/613 [00:06<00:00, 100.48it/s]
 50%|█████     | 5/10 [05:56<05:56, 71.24s/it]

epoch: 5: training loss: 0.5373675191570058 training accuracy: 0.73% validation loss: 0.5417819227354804 validation accuracy: 0.72%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:05<00:00,  1.46it/s]
100%|██████████| 613/613 [00:05<00:00, 102.67it/s]
 60%|██████    | 6/10 [07:07<04:44, 71.20s/it]

epoch: 6: training loss: 0.5319788143987789 training accuracy: 0.73% validation loss: 0.5711873298749057 validation accuracy: 0.7%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:05<00:00,  1.46it/s]
100%|██████████| 613/613 [00:05<00:00, 103.75it/s]
 70%|███████   | 7/10 [08:19<03:33, 71.16s/it]

epoch: 7: training loss: 0.5141719959115469 training accuracy: 0.74% validation loss: 0.5759375457513088 validation accuracy: 0.73%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:04<00:00,  1.46it/s]
100%|██████████| 613/613 [00:05<00:00, 103.62it/s]
 80%|████████  | 8/10 [09:29<02:22, 71.07s/it]

epoch: 8: training loss: 0.4988955305665855 training accuracy: 0.75% validation loss: 0.5212545060380978 validation accuracy: 0.74%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:04<00:00,  1.46it/s]
100%|██████████| 613/613 [00:05<00:00, 102.67it/s]
 90%|█████████ | 9/10 [10:40<01:11, 71.01s/it]

epoch: 9: training loss: 0.49756974877159493 training accuracy: 0.75% validation loss: 0.5271019583299633 validation accuracy: 0.75%
----------------------------------------------------------------------------------------------------


100%|██████████| 95/95 [01:04<00:00,  1.46it/s]
100%|██████████| 613/613 [00:06<00:00, 101.00it/s]
100%|██████████| 10/10 [11:51<00:00, 71.18s/it]

epoch: 10: training loss: 0.4874902180446091 training accuracy: 0.77% validation loss: 0.5456147156324246 validation accuracy: 0.75%
----------------------------------------------------------------------------------------------------





In [33]:
train_acc_plot = [acc.cpu() for acc in train_acc_plot]
val_acc_plot = [acc.cpu() for acc in val_acc_plot]

In [None]:
fig,(ax1,ax2) = plt.subplots(2,1)
plt.subplots_adjust(left=0.1,right=0.9,top = 0.9,bottom=0.1,wspace=0.5)
ax1.plot(train_loss_plot,label='train')
ax1.plot(val_loss_plot,label='val')
ax1.set_title('Loss')
ax2.plot(train_acc_plot,label='train')
ax2.plot(val_acc_plot,label='val')
ax2.set_title('Accuracy')
fig.legend()
fig.show()