In [None]:
def extract_face(image_path, show=False, ref_size=224):
    """Crop the face region of an image using its companion bounding-box file.

    The box is read from ``image_path`` with its last 4 characters replaced by
    ``_BB.txt``. Only the first line is used, and its first four numbers are
    taken as x, y, width, height. Those numbers are expressed in a
    ``ref_size`` x ``ref_size`` frame and are rescaled to the real image size
    before cropping.

    Args:
        image_path: path to the image; assumed to end in a 4-character
            extension such as ``.jpg`` (the last 4 chars are stripped).
        show: if True, display the cropped face with matplotlib.
        ref_size: side length of the frame the bbox coordinates refer to
            (224 by default, the value previously hard-coded here).

    Returns:
        The cropped image returned by ``Image.crop``.
    """
    image = Image.open(image_path)
    real_w, real_h = image.size

    # Context manager closes the bbox file promptly; only the first line is needed.
    with open(image_path[:-4] + '_BB.txt') as bbox_file:
        first_line = bbox_file.readline()

    # int(float(...)) also accepts coordinates written as e.g. "61.0";
    # anything after the 4th field (e.g. a confidence score) is ignored.
    x, y, w, h = (int(float(v)) for v in first_line.split()[:4])

    scale_x = real_w / ref_size
    scale_y = real_h / ref_size
    x1 = int(x * scale_x)
    y1 = int(y * scale_y)
    w1 = int(w * scale_x)
    h1 = int(h * scale_y)

    face = image.crop((x1, y1, x1 + w1, y1 + h1))
    if show:
        plt.imshow(face)
        plt.show()
    return face

In [None]:
# Root of the image data. Each split directory holds numbered sub-folders
# (e.g. Data/train/3329/...), presumably one per subject, so the numbers
# printed below count those sub-folders, not individual images.
data_dir = '/kaggle/input/celeba-spoof-for-face-antispoofing/CelebA_Spoof_/CelebA_Spoof/Data/'
split_counts = {
    split: len(os.listdir(os.path.join(data_dir, split)))
    for split in ('train', 'test')
}
train_size = split_counts['train']
test_size = split_counts['test']

print('train: {}; test: {}'.format(train_size, test_size))

In [None]:
# Dataset root; the label JSONs used here live under metas/intra_test/.
path_local = '/kaggle/input/celeba-spoof-for-face-antispoofing/CelebA_Spoof_/CelebA_Spoof/'
_meta_dir = path_local + 'metas/intra_test/'
path_train_json = _meta_dir + 'train_label.json'
path_test_json = _meta_dir + 'test_label.json'

In [None]:
def _load_label_frame(json_path):
    """Load a label JSON keyed by relative image path into a DataFrame.

    The JSON keys become the ``'Filepath'`` column; the per-image values become
    the remaining columns (this notebook later indexes them as ``[43]``).
    """
    return (
        pd.read_json(json_path, orient='index')
        .reset_index()
        .rename(columns={'index': 'Filepath'})
    )


df_train = _load_label_frame(path_train_json)
df_test = _load_label_frame(path_test_json)

In [None]:
# Turn the relative paths from the label JSON into absolute paths under the
# dataset root. Paths that already start with path_local are left alone, so
# re-running this cell does not prepend the root a second time.
for frame in (df_train, df_test):
    paths = frame['Filepath']
    frame['Filepath'] = paths.where(paths.str.startswith(path_local), path_local + paths)

In [None]:
# Preview the label table; head() avoids rendering the entire frame.
df_train.head()

In [None]:
# Exclude one known-bad training image by its exact path.
# NOTE(review): why this file is invalid (missing? corrupt?) isn't recorded — confirm.
invalid_file_name = '/kaggle/input/celeba-spoof-for-face-antispoofing/CelebA_Spoof_/CelebA_Spoof/Data/train/3329/spoof/004046.jpg'

keep_mask = df_train['Filepath'] != invalid_file_name
df_train = df_train[keep_mask]  # index labels are kept (not reset)
df_train

In [None]:
# Path of the row with index label 0 (label-based lookup, not positional).
df_train.at[0, 'Filepath']

In [None]:
# Inspect the raw bounding-box file of one sample (row with index label 15).
sample_bb_path = df_train.at[15, 'Filepath'][:-4] + '_BB.txt'
bbox = np.loadtxt(sample_bb_path)
bbox

In [None]:
# Show the face crop for the row with index label 0.
first_image_path = df_train.at[0, 'Filepath']
face = extract_face(first_image_path, show=True)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('using \'{}\' device'.format(device))

# Seed torch's global RNG so the new classifier head's random init and the
# random train-time augmentations are reproducible across runs.
torch.manual_seed(42)

# ImageNet-pretrained MobileNetV2. torchvision >= 0.13 deprecates
# `pretrained=True` in favour of the `weights=` enum (IMAGENET1K_V1 is what
# `pretrained=True` maps to); fall back to the old flag on older versions.
try:
    from torchvision.models import MobileNet_V2_Weights
    model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
except ImportError:
    model = mobilenet_v2(pretrained=True)

# Replace the final layer so the network outputs 2 classes (live/spoof).
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)

model = model.to(device)

In [None]:
# Display the architecture; classifier[1] should now be the 2-output Linear layer.
model

In [None]:
# Preprocessing pipelines keyed by split ('train' / 'test').
# The torchvision transforms module is imported under the alias T because
# this cell rebinds the name `transforms` to the dict below. Without the
# alias, re-running the cell fails with
# AttributeError: 'dict' object has no attribute 'Compose'.
from torchvision import transforms as T

# ImageNet channel statistics, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transforms = {
    # Random crop + flip augmentation for training.
    'train': T.Compose([
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
    # Deterministic resize + center crop for evaluation.
    'test': T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]),
}


In [None]:
# prepare data
class FASDataset(Dataset):
    """Face anti-spoofing dataset backed by a label DataFrame.

    Each row must provide the image path in the ``'Filepath'`` column and the
    class label in column ``label_col`` (43 by default, the column this
    notebook uses as the live/spoof target). Items are face crops produced by
    ``extract_face``, passed through ``transforms`` when one is given.
    """

    def __init__(self, df, transforms=None, label_col=43):
        self.df = df                  # label table for THIS dataset
        self.transforms = transforms  # callable applied to the cropped face, or None
        self.label_col = label_col    # column holding the class label

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for positional index ``idx``."""
        row = self.df.iloc[idx]
        image = extract_face(row['Filepath'])
        # The label must come from this dataset's own frame, positionally
        # aligned with the image path above.
        label = int(row[self.label_col])

        if self.transforms is not None:
            image = self.transforms(image)

        return image, label

In [None]:
# Class distribution of the full training table (column 43 is the target label).
df_train.loc[:, 43].value_counts()

In [None]:
# Downsample the training set to a small random fraction to keep training
# fast; the fixed random_state makes the subset reproducible.
SAMPLE_FRAC = 0.005
df_train_sample = df_train.sample(frac=SAMPLE_FRAC, random_state=43)
df_train_sample.loc[:, 43].value_counts()

In [None]:
# Build a class-balanced training subset: the same number of rows with
# label 1 and label 0 (column 43), capped at MAX_PER_CLASS per class.
# Taking the minority-class count when it is below the cap keeps the subset
# balanced even if the sampling fraction or seed above changes.
MAX_PER_CLASS = 799
sample_targets = df_train_sample[43]
n_per_class = min(
    MAX_PER_CLASS,
    int((sample_targets == 1).sum()),
    int((sample_targets == 0).sum()),
)
df_1 = df_train_sample[sample_targets == 1].head(n_per_class)
df_2 = df_train_sample[sample_targets == 0].head(n_per_class)
df_train_sample_balanced = (
    pd.concat([df_1, df_2])
    .sample(frac=1, random_state=42)  # shuffle rows once
    .reset_index(drop=True)
)
assert (df_train_sample_balanced[43] == 1).sum() == (df_train_sample_balanced[43] == 0).sum()

In [None]:
# Confirm the balanced subset has equal counts per class.
df_train_sample_balanced.loc[:, 43].value_counts()

In [None]:
# Training loader. The balanced frame was shuffled only once, so without
# shuffle=True every epoch would see the same batch order. The seeded
# generator keeps the per-epoch order reproducible.
train_dataset = FASDataset(df_train_sample_balanced, transforms['train'])
dataloader_train = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)

In [None]:
print('length of dataset = ', len(train_dataset), '\n')
# Pull a single batch and show the image and label tensor sizes.
img, label = next(iter(dataloader_train))
tuple(t.size() for t in (img, label))

In [None]:
# Adam over all model parameters (the whole network is fine-tuned, not just
# the new head) and cross-entropy over the 2 output logits.
# NOTE: the optimizer holds references to the current `model` parameters;
# rebuild it if the model cell is re-run.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


In [None]:
num_epochs = 100
batches_per_epoch = len(dataloader_train)

# Supervised fine-tuning: one optimizer step per mini-batch; the reported
# epoch loss is the mean of the per-batch losses.
for epoch in range(num_epochs):
    model.train()  # training mode (dropout / batch-norm statistics active)
    running_loss = 0.0
    progress_bar = tqdm(dataloader_train, desc=f'Epoch {epoch + 1}/{num_epochs}')
    for inputs, labels in progress_bar:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()                    # clear gradients from the previous step
        loss = criterion(model(inputs), labels)  # forward pass + loss
        loss.backward()                          # backpropagate
        optimizer.step()                         # update parameters

        running_loss += loss.item()

    mean_loss = running_loss / batches_per_epoch
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}")


In [None]:

# Persist only the learned parameters (state_dict) of the trained model.
model_save_path = '/kaggle/working/mobilenetv2_weights.pth'
trained_state = model.state_dict()
torch.save(trained_state, model_save_path)

print(f"Model weights saved to {model_save_path}")

In [None]:
# Optional: list the files available in the pretrained-weights input dataset.
# os.listdir('/kaggle/input/fas_mn2/pytorch/v1/2')

In [None]:
# Previously trained weights attached as a Kaggle input dataset. This is a
# different file from the one written to /kaggle/working by the save cell.
_weights_dir = '/kaggle/input/fas_mn2/pytorch/v1/2'
pretrain_weights = _weights_dir + '/mobilenetv2_weights.pth'

In [None]:
# Evaluation loader over the full test table, in fixed (unshuffled) order.
test_dataset = FASDataset(df_test, transforms=transforms['test'])
dataloader_test = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Rebuild the architecture with no pretrained weights requested (the trained
# checkpoint is loaded afterwards) and install the same 2-class head.
model = mobilenet_v2()
head_in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(head_in_features, 2)

In [None]:
# map_location remaps tensors onto `device`, so weights saved on a GPU still
# load when `device` fell back to 'cpu'.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
state_dict = torch.load(pretrain_weights, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode for evaluation

In [None]:
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

# Evaluate the loaded model on dataloader_test. Reports accuracy, a per-class
# report, a confusion matrix and an ROC curve, and plots a few sample
# predictions.

CLASS_IDS = [0, 1]          # the two outputs of the classifier head
MAX_SAMPLES_TO_KEEP = 10    # 2 per batch -> the first 5 batches
# Normalize() statistics of the test transform, used to undo it for display.
display_mean = np.array([0.485, 0.456, 0.406])
display_std = np.array([0.229, 0.224, 0.225])

# Initialize variables
correct = 0
total = 0
all_labels = []
all_preds = []
all_probs = []
sample_images = []  # To store sample images for visualization
sample_labels = []
sample_pred_probs = []

# Disable gradient calculation for evaluation
with torch.no_grad():
    progress_bar = tqdm(dataloader_test)
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)

        # Class probabilities and predicted class per image
        outputs = model(inputs)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class = torch.argmax(probs, dim=1)

        # Update running accuracy
        total += labels.size(0)
        correct += (predicted_class == labels).sum().item()

        # Store for metrics calculation
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted_class.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

        # Keep 2 samples per batch until MAX_SAMPLES_TO_KEEP are collected
        if len(sample_images) < MAX_SAMPLES_TO_KEEP:
            sample_images.extend(inputs.cpu().numpy()[:2])
            sample_labels.extend(labels.cpu().numpy()[:2])
            sample_pred_probs.extend(probs.cpu().numpy()[:2])

        accuracy = correct / total
        progress_bar.set_postfix({'acc': accuracy})

# Final accuracy (NaN rather than ZeroDivisionError if the loader was empty)
accuracy = correct / total if total else float('nan')
print(f'Accuracy: {accuracy * 100:.2f}%')

# Passing labels explicitly keeps both classes in the report and matrix even
# when one class is absent from the labels/predictions. Without it,
# classification_report rejects the 2 target_names.
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, labels=CLASS_IDS,
                            target_names=['Class 0', 'Class 1']))

cm = confusion_matrix(all_labels, all_preds, labels=CLASS_IDS)
print("\nConfusion Matrix:")
print(cm)

# ROC curve, scored on the probability of class 1; only defined when both
# classes occur in the labels.
if len(np.unique(all_labels)) == 2:
    fpr, tpr, _ = roc_curve(all_labels, np.asarray(all_probs)[:, 1])
    roc_auc = auc(fpr, tpr)

    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()

# Visualize sample predictions; "Prob" is the predicted probability of the TRUE class.
print("\nSample Predictions:")
plt.figure(figsize=(15, 5))
for i in range(min(5, len(sample_images))):
    plt.subplot(1, 5, i + 1)
    img = sample_images[i].transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    if img.shape[2] == 1:  # Grayscale to RGB
        img = np.repeat(img, 3, axis=2)
    # Undo the test-time Normalize() to recover displayable colours in [0, 1].
    # This replaces per-image min-max stretching, which distorts colours and
    # divides by zero on a constant image.
    img = np.clip(img * display_std + display_mean, 0.0, 1.0)

    plt.imshow(img)
    true_label = sample_labels[i]
    pred_prob = sample_pred_probs[i]
    title = f'True: {true_label}\nProb: {pred_prob[true_label]:.2f}'
    plt.title(title)
    plt.axis('off')
plt.tight_layout()
plt.show()