In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import pandas as pd
from sklearn.metrics import accuracy_score
from torchvision.models import densenet121
import torchvision.models as models
# from vit_pytorch import ViT
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import logging
import matplotlib.pyplot as plt




In [2]:

# Define your dataset class
class SpoofingDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0][60:])
        image = Image.open(img_name).convert("RGB")
        label = self.data.iloc[idx, 1]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Define data augmentation transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(degrees=15),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your dataset
root_dir = '/kaggle/input/oulu-npu/'
train_dataset = SpoofingDataset(csv_file=os.path.join('/kaggle/input/oulu-npu-csv', 'train_data.csv'), root_dir=root_dir, transform=transform)
val_dataset = SpoofingDataset(csv_file=os.path.join('/kaggle/input/oulu-npu-csv/', 'dev_data.csv'), root_dir=root_dir, transform=transform)
test_dataset = SpoofingDataset(csv_file=os.path.join('/kaggle/input/oulu-npu-csv', 'test_data.csv'), root_dir=root_dir, transform=transform)

# Define data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the ViT model with DenseNet backbone
class ViT_DenseNet(nn.Module):
    def __init__(self, num_classes):
        super(ViT_DenseNet, self).__init__()
        # Load the pre-trained DenseNet
        self.densenet = models.densenet121(pretrained=True)
        
        # Remove the classification layer of DenseNet
        self.densenet = nn.Sequential(*list(self.densenet.children())[:-1])
        self.reduce_channels = nn.Conv2d(in_channels=1024, out_channels=3, kernel_size=1)
        # Create the ViT model
        self.vit = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
        # Separate branch for binary supervision loss calculation
        self.supervision_branch = nn.Conv2d(1024, 1, kernel_size=1)
        
        self.fc = nn.Linear(1000, num_classes)  # Output layer for binary classification

    def forward(self, x):
        densenet_output = self.densenet(x)

        supervision_output = self.supervision_branch(densenet_output)  # For supervision loss
        # print('supervision_output', supervision_output.shape)

        resized_output = F.interpolate(densenet_output, scale_factor=32, mode='bilinear', align_corners=False)
        # print('resized_output', resized_output.shape)
        reduced_output = self.reduce_channels(resized_output)
        # print('reduced_output', reduced_output.shape)
        
        vit_output = self.vit(reduced_output)  # For classification
        # print('vit_output', vit_output.shape)

        final_output = self.fc(vit_output)
        # print('final_output', final_output.shape)
        
        return supervision_output, final_output

# Define the model, loss functions, and optimizer
model = ViT_DenseNet(num_classes=1)  # 1 output for binary classification
# mse_criterion = nn.MSELoss()  # Mean Squared Error loss
bce_criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy loss
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
if torch.cuda.is_available():
    if torch.cuda.device_count()>1:
        model=torch.nn.DataParallel(model)
    model.to(device)
else:
    raise NotImplementedError('This code unable to use GPU')

logging.basicConfig(filename='training_log.txt', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


for epoch in range(num_epochs):
    model.train()
    train_losses_bcep = []
    train_losses_bce = []
    total_training_loss = []
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        
        # Forward pass through the model
        supervision_output, final_output = model(images)
        
        # Calculate the binary supervision loss using MSE
        batch_size = labels.size(0)
        # Repeat labels tensor over the dimensions of supervision_output
        supervision_labels = labels.view(batch_size, 1, 1, 1).repeat(1, 1, 7, 7)
        
        supervision_loss = bce_criterion(supervision_output.view(-1), supervision_labels.view(-1).float())
        
        # Calculate the binary cross-entropy loss using BCE

        # Create the target tensor with the same shape as final_output
        # Reshape labels to have shape [batch_size, 1]
        labels = labels.unsqueeze(1)
        
        # Create final_labels tensor
        final_labels = torch.ones_like(final_output) * labels  # Broadcasting to match final_output shape

        cross_entropy_loss = bce_criterion(final_output.view(-1), final_labels.view(-1).float())
        
        # Calculate the total loss
        theta = 0.3  # Adjust as needed
        total_loss = (1 - theta) * cross_entropy_loss + theta * supervision_loss
        
        total_loss.backward()
        optimizer.step()
        train_losses_bcep.append(supervision_loss.item())
        train_losses_bce.append(cross_entropy_loss.item())
        total_training_loss.append(total_loss.item())
    
    model.eval()
    val_losses_bcep = []
    val_losses_bce = []
    total_val_loss = []
    val_predictions = []
    val_true_labels = []
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images, labels = images.to(device), labels.to(device)
            supervision_output, final_output = model(images)
            batch_size = labels.size(0)
            # Repeat labels tensor over the dimensions of supervision_output
            supervision_labels = labels.view(batch_size, 1, 1, 1).repeat(1, 1, 7, 7)
            # Calculate the MSE loss for validation
            supervision_loss = bce_criterion(supervision_output.view(-1), supervision_labels.view(-1).float())
            val_losses_bcep.append(supervision_loss.item())
            
            # Calculate the BCE loss for validation
            labels = labels.unsqueeze(1)

            # Create final_labels tensor
            final_labels = torch.ones_like(final_output) * labels  # Broadcasting to match final_output shape

            cross_entropy_loss = bce_criterion(final_output.view(-1), final_labels.view(-1).float())
            val_losses_bce.append(cross_entropy_loss.item())

            theta = 0.3  # Adjust as needed
            total_valid_loss = (1 - theta) * cross_entropy_loss + theta * supervision_loss
            total_val_loss.append(total_valid_loss.item())

            val_predictions.extend(torch.sigmoid(final_output).cpu().numpy())
            val_true_labels.extend(labels.cpu().numpy())
    
    # Calculate validation accuracy
    val_predictions = [1 if x >= 0.5 else 0 for x in val_predictions]
    val_accuracy = accuracy_score(val_true_labels, val_predictions)

    # Save the model checkpoint
    checkpoint_path = f'densenetViT_OULU_NPU_checkpoint_epoch_{epoch+1}.pt'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses_bcep': train_losses_bcep,
        'train_losses_bce': train_losses_bce,
        'total_training_loss': total_training_loss,
        'val_losses_bcep': val_losses_bcep,
        'val_losses_bce': val_losses_bce,
        'total_val_loss': total_val_loss,
        'val_predictions': val_predictions,
        'val_true_labels': val_true_labels,
        'val_accuracy': val_accuracy
    }, checkpoint_path)
    logging.info(f"Model checkpoint saved at: {checkpoint_path}") 
    
    # Log the values
    logging.info(f"Epoch [{epoch+1}/{num_epochs}]")
    logging.info(f"Training BCE supervision Loss: {sum(train_losses_bcep) / len(train_losses_bcep):.4f}")
    logging.info(f"Training BCE classification Loss: {sum(train_losses_bce) / len(train_losses_bce):.4f}")
    logging.info(f"Training Total Loss: {sum(total_training_loss) / len(total_training_loss):.4f}")
    logging.info(f"Validation BCE supervision Loss: {sum(val_losses_bcep) / len(val_losses_bcep):.4f}")
    logging.info(f"Validation BCE classification Loss: {sum(val_losses_bce) / len(val_losses_bce):.4f}")
    logging.info(f"Validation Total Loss: {sum(total_val_loss) / len(total_val_loss):.4f}")
    logging.info(f"Validation Accuracy: {val_accuracy:.4f}")

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"Training BCE supervision Loss: {sum(train_losses_bcep) / len(train_losses_bcep):.4f}")
    print(f"Training BCE classification Loss: {sum(train_losses_bce) / len(train_losses_bce):.4f}")
    print(f"Validation BCE supervision Loss: {sum(val_losses_bcep) / len(val_losses_bcep):.4f}")
    print(f"Validation BCE classification Loss: {sum(val_losses_bce) / len(val_losses_bce):.4f}")
    print(f"Validation Accuracy: {val_accuracy:.4f}")
    

test_losses_bcep = []
test_losses_bce = []
total_test_loss = []
test_predictions = []
test_true_labels = []
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images, labels = images.to(device), labels.to(device)
        supervision_output, final_output = model(images)
        batch_size = labels.size(0)
        # Repeat labels tensor over the dimensions of supervision_output
        supervision_labels = labels.view(batch_size, 1, 1, 1).repeat(1, 1, 7, 7)
        # Calculate the MSE loss for validation
        supervision_loss = bce_criterion(supervision_output.view(-1), supervision_labels.view(-1).float())
        test_losses_bcep.append(supervision_loss.item())

        # Calculate the BCE loss for validation
        labels = labels.unsqueeze(1)

        # Create final_labels tensor
        final_labels = torch.ones_like(final_output) * labels  # Broadcasting to match final_output shape

        cross_entropy_loss = bce_criterion(final_output.view(-1), final_labels.view(-1).float())
        test_losses_bce.append(cross_entropy_loss.item())

        theta = 0.3  # Adjust as needed
        total_testing_loss = (1 - theta) * cross_entropy_loss + theta * supervision_loss
        total_test_loss.append(total_testing_loss.item())

        test_predictions.extend(torch.sigmoid(final_output).cpu().numpy())
        test_true_labels.extend(labels.cpu().numpy())

# Calculate validation accuracy
test_predictions = [1 if x >= 0.5 else 0 for x in test_predictions]
test_accuracy = accuracy_score(test_true_labels, test_predictions)
logging.info(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")


Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 59.7MB/s]
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100%|██████████| 330M/330M [00:09<00:00, 35.7MB/s] 
100%|██████████| 7466/7466 [6:46:31<00:00,  3.27s/it]  
100%|██████████| 5413/5413 [1:07:06<00:00,  1.34it/s]


Epoch [1/5]
Training BCE supervision Loss: 0.0177
Training BCE classification Loss: 0.0105
Validation BCE supervision Loss: 0.0192
Validation BCE classification Loss: 0.0240
Validation Accuracy: 0.9937


  1%|          | 47/7466 [02:36<6:53:00,  3.34s/it]


KeyboardInterrupt: 

In [None]:
# from IPython.display import FileLink
# FileLink(r'/kaggle/working/densenetViT_OULU_NPU_checkpoint_epoch_1.pth')

In [None]:
# Save the model
saved_model_path = 'densenet_vit_OULU_NPU.pth'
torch.save(model.state_dict(), saved_model_path)
print(f"Model saved at {saved_model_path}")

In [None]:
log_file = './exp3_training_log.txt'

epochs = []
train_bcep_losses = []
train_bce_losses = []
train_total_losses = []
val_bcep_losses = []
val_bce_losses = []
val_total_losses = []
val_accuracies = []

with open(log_file, 'r') as f:
    lines = f.readlines()
    for line in lines:
        if 'Epoch' in line:
            epochs.append(int(line.split('/')[0].split('[')[1]))
        elif 'Training BCE supervision Loss' in line:
            train_bcep_losses.append(float(line.split(': ')[1]))
        elif 'Training BCE classification Loss' in line:
            train_bce_losses.append(float(line.split(': ')[1]))
        # elif 'Training Total Loss' in line:
        #     train_total_losses.append(float(line.split(': ')[1]))
        elif 'Validation BCE supervision Loss' in line:
            val_bcep_losses.append(float(line.split(': ')[1]))
        elif 'Validation BCE classification Loss' in line:
            val_bce_losses.append(float(line.split(': ')[1]))
        # elif 'Validation Total Loss' in line:
        #     val_total_losses.append(float(line.split(': ')[1]))
        elif 'Validation Accuracy' in line:
            val_accuracies.append(float(line.split(': ')[1]))

# Now you have the extracted information, you can use it to create plots


In [None]:
# Plot total loss
# plt.figure(figsize=(12, 6))
# plt.subplot(2, 2, 1)
# plt.plot(epochs, train_total_losses, label='Training')
# plt.plot(epochs, val_total_losses, label='Validation')
# plt.xlabel('Epoch')
# plt.ylabel('Total Loss')
# plt.title('Total Loss Over Epochs')
# plt.legend()

# Plot MSE loss
plt.subplot(2, 2, 2)
plt.plot(epochs, train_bcep_losses, label='Training')
plt.plot(epochs, val_bcep_losses, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('BCE supervision Loss')
plt.title('BCE supervision Loss Over Epochs')
plt.legend()

# Plot BCE loss
plt.subplot(2, 2, 3)
plt.plot(epochs, train_bce_losses, label='Training')
plt.plot(epochs, val_bce_losses, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('BCE classification Loss')
plt.title('BCE classification Loss Over Epochs')
plt.legend()

# Plot accuracy
plt.subplot(2, 2, 4)
plt.plot(epochs, val_accuracies, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy Over Epochs')
plt.legend()

plt.tight_layout()
plt.show()