In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import brentq
from sklearn.metrics import roc_curve, roc_auc_score
from scipy.interpolate import interp1d
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import torch.nn as nn
import torch.optim as optim
import timm
import random
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter

In [None]:
# Seed every random number generator used by the notebook so runs are repeatable
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# Folder locations for each data split
train_labeled_data_dir = '/data/train/'
train_unlabeled_data_dir = '/data/unlabel/'
val_data_dir = '/data/validation/'
test_data_dir = '/data/test/'


# Preprocessing shared by every split: resize to 224x224, convert to a tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5
_channel_mean = (0.5, 0.5, 0.5)
_channel_std = (0.5, 0.5, 0.5)
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
transform = transforms.Compose(_preprocessing_steps)

# Custom dataset class for semi-supervised learning
class SemiSupervisedDataset(Dataset):
    """Serve one labeled sample together with one unlabeled image per index.

    Both folders are read with ``datasets.ImageFolder``. The dataset is as
    long as the larger of the two folders; the shorter one is cycled through
    by taking the index modulo its length.

    Args:
        labeled_data_dir: Root folder of the labeled images.
        unlabeled_data_dir: Root folder of the unlabeled images.
        transform: Optional transform passed to both image folders.
    """

    def __init__(self, labeled_data_dir, unlabeled_data_dir, transform=None):
        # Labeled folder is opened first, then the unlabeled one.
        self.labeled_dataset, self.unlabeled_dataset = (
            datasets.ImageFolder(root=folder, transform=transform)
            for folder in (labeled_data_dir, unlabeled_data_dir)
        )

    def __len__(self):
        """Return the size of the larger of the two underlying datasets."""
        labeled_count = len(self.labeled_dataset)
        unlabeled_count = len(self.unlabeled_dataset)
        return labeled_count if labeled_count >= unlabeled_count else unlabeled_count

    def __getitem__(self, index):
        """Return ``(labeled_image, label, unlabeled_image)`` for ``index``."""
        # Wrap around so the shorter dataset is reused once it is exhausted.
        labeled_position = index % len(self.labeled_dataset)
        labeled_image, label = self.labeled_dataset[labeled_position]

        unlabeled_position = index % len(self.unlabeled_dataset)
        # The class index that ImageFolder attaches to unlabeled images is discarded.
        unlabeled_image, _ignored_target = self.unlabeled_dataset[unlabeled_position]

        return labeled_image, label, unlabeled_image

# Build the datasets: paired labeled/unlabeled images for training,
# plain image folders for validation and test
train_dataset = SemiSupervisedDataset(
    train_labeled_data_dir,
    train_unlabeled_data_dir,
    transform=transform,
)
val_dataset = datasets.ImageFolder(root=val_data_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_data_dir, transform=transform)

# Report how many images each dataset contains
print(f"Number of labeled images in the training set: {len(train_dataset.labeled_dataset)}")
print(f"Number of unlabeled images in the training set: {len(train_dataset.unlabeled_dataset)}")
print(f"Number of images in the validation set: {len(val_dataset)}")
print(f"Number of images in the test set: {len(test_dataset)}")

# Wrap the datasets in DataLoaders; only the training data is shuffled
batch_size = 8  # Adjust according to your needs
_shared_loader_options = {"batch_size": batch_size, "num_workers": 4}
train_loader = DataLoader(train_dataset, shuffle=True, **_shared_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_shared_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_shared_loader_options)

In [None]:
# Define the model architecture
class SimpleDenseNet(nn.Module):
    """DenseNet-201 from timm with its classifier replaced by a small MLP head.

    The replacement head is Linear(in_features, 128) -> ReLU -> Dropout(0.5)
    -> Linear(128, num_classes).

    Args:
        num_classes: Number of output units of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = timm.create_model('densenet201', pretrained=True)

        # The new head takes the same input width as the classifier it replaces.
        head_layers = [
            nn.Linear(backbone.classifier.in_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        ]
        backbone.classifier = nn.Sequential(*head_layers)

        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        output = self.model(x)
        return output

# Build the classifier with two output classes (binary classification)
num_classes = 2
model = SimpleDenseNet(num_classes=num_classes)

# Cross-entropy objective, optimised with SGD plus momentum
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-4, momentum=0.9)

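# Note: the Gaussian filter for the confidence scores (scipy.ndimage.gaussian_filter) is applied inside the training loop; nothing is defined in this cell.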

In [None]:

# Train the model with semi-supervised learning.
#
# Each step combines the supervised loss on the labeled batch with a
# pseudo-label loss on the unlabeled batch, weighted per sample by the
# (Gaussian-filtered) prediction confidence. After every epoch the model is
# validated; the weights and EER threshold of the epoch with the lowest
# validation loss are kept and restored once training ends.
#
# Names left behind for later cells: `model` (best weights, eval mode),
# `device`, `best_val_loss`, and `threshold` (validation EER threshold of the
# best epoch, only set if an EER could be computed).
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

best_val_loss = float('inf')
early_stopping_patience = 3
patience_counter = 0

# Standard deviation of the Gaussian filter applied to the confidence scores.
# Loop-invariant, so it is set once instead of on every batch.
sigma = 3.0  # Adjust the standard deviation based on your requirements

# Snapshot of the best epoch so the final model and threshold belong together.
best_model_state = None
best_threshold = None

for epoch in range(num_epochs):
    model.train()
    for labeled_images, labels, unlabeled_images in train_loader:
        labeled_images = labeled_images.to(device)
        labels = labels.to(device)
        unlabeled_images = unlabeled_images.to(device)

        optimizer.zero_grad()

        # Forward pass for labeled images
        labeled_outputs = model(labeled_images)
        labeled_loss = criterion(labeled_outputs, labels)

        # Forward pass for unlabeled images (no labels used): the most likely
        # class becomes the pseudo-label and its probability the confidence.
        unlabeled_outputs = model(unlabeled_images)
        unlabeled_probs = F.softmax(unlabeled_outputs, dim=1)
        confidence_scores, pseudo_labels = unlabeled_probs.max(dim=1)

        # reduction='none' keeps one loss value per sample, so the confidence
        # weights below act on individual samples. With the mean-reduced
        # `criterion` the weights would only rescale the batch average.
        pseudo_loss = F.cross_entropy(unlabeled_outputs, pseudo_labels, reduction='none')

        # Apply Gaussian filter to confidence scores (detached: the weights
        # are treated as constants and receive no gradient).
        # NOTE(review): the filter runs along the batch axis, so it blends the
        # confidences of neighbouring samples in a shuffled batch — confirm
        # this is intended.
        filtered_confidence_scores = gaussian_filter(confidence_scores.detach().cpu().numpy(), sigma)

        # Convert back to torch tensor
        filtered_confidence_scores = torch.tensor(filtered_confidence_scores, dtype=torch.float32, device=device)

        # Combine losses
        weighted_pseudo_loss = pseudo_loss * filtered_confidence_scores
        loss = labeled_loss + weighted_pseudo_loss.mean()

        loss.backward()
        optimizer.step()

    # Validate the model
    model.eval()
    val_loss = 0.0
    all_labels = []
    all_scores = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Score = predicted probability of class 1
            scores = F.softmax(outputs, dim=1)[:, 1]
            all_labels.extend(labels.cpu().numpy())
            all_scores.extend(scores.cpu().numpy())

    val_loss /= len(val_loader)

    # Calculate the EER and its threshold on the validation set. This runs
    # before the early-stopping decision so every epoch, including the one
    # that triggers the stop, has a threshold matching its own weights.
    epoch_threshold = None
    fpr, tpr, thresholds = roc_curve(all_labels, all_scores, pos_label=1)

    # Check for NaN values in the arrays
    if np.isnan(fpr).any() or np.isnan(tpr).any() or np.isnan(thresholds).any():
        print("Error: NaN values encountered in fpr, tpr, or thresholds during EER calculation. Skipping this epoch.")
    else:
        # EER is the point on the ROC curve where FPR equals 1 - TPR.
        eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
        epoch_threshold = thresholds[np.nanargmin(np.abs(fpr - eer))]

        print(f"Epoch {epoch + 1}/{num_epochs}, Validation EER: {eer * 100:.2f}%")
        print(f"Validation EER Threshold: {epoch_threshold:.4f}")

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Keep a CPU copy of the best weights so they can be restored later
        # without holding a second copy of the model on the GPU.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        if epoch_threshold is not None:
            best_threshold = epoch_threshold
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Validation loss didn't improve for {} epochs. Early stopping...".format(early_stopping_patience))
            break

# Restore the weights of the best validation epoch; without this the model
# would keep the weights from the epochs trained after the best one.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
model.eval()

# Export the EER threshold that belongs to the restored model.
if best_threshold is not None:
    threshold = best_threshold
else:
    print("Warning: no validation EER threshold could be computed; `threshold` was not set.")

In [None]:
# Evaluate the trained model on the test set: HTER at the validation EER
# threshold, plus AUC. Expects `model`, `test_loader` and `threshold` from the
# earlier cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Put the model in inference mode explicitly (dropout off, batch-norm using
# running statistics) instead of relying on the state left by the training cell.
model.eval()

test_labels = []
test_scores = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        # Score = predicted probability of class 1
        scores = torch.nn.functional.softmax(outputs, dim=1)[:, 1]

        test_labels.extend(labels.cpu().numpy())
        test_scores.extend(scores.cpu().numpy())

# Apply Gaussian weighting to test set scores
# NOTE(review): this smooths each score with its neighbours in loader order
# (the test loader is not shuffled). That is only meaningful if adjacent
# samples are related (e.g. consecutive frames) — confirm. Also note that the
# threshold below was derived from unsmoothed validation scores.
sigma = 3.0  # Adjust the standard deviation based on your requirements
weighted_test_scores = gaussian_filter(test_scores, sigma)

# Calculate the HTER on the testing set using the EER threshold
threshold_test = threshold  # Use the EER threshold from the validation set for testing
predicted_labels_test = [1 if score > threshold_test else 0 for score in weighted_test_scores]

test_labels_array = np.asarray(test_labels)
predicted_labels_array = np.asarray(predicted_labels_test)
negative_mask = test_labels_array == 0
positive_mask = test_labels_array == 1

# False acceptance: a class-0 sample predicted as class 1.
# False rejection: a class-1 sample predicted as class 0.
false_acceptance_test = int(np.sum(negative_mask & (predicted_labels_array == 1)))
false_rejection_test = int(np.sum(positive_mask & (predicted_labels_array == 0)))

total_samples_test = len(test_labels)
negative_samples_test = int(np.sum(negative_mask))
positive_samples_test = int(np.sum(positive_mask))

# HTER is the mean of the two per-class error rates, each normalised by the
# size of its own class. Dividing the summed error count by 2 * total samples
# would give half the overall error rate instead. A rate is undefined (NaN)
# when its class is absent from the test set.
far_test = false_acceptance_test / negative_samples_test if negative_samples_test else float('nan')
frr_test = false_rejection_test / positive_samples_test if positive_samples_test else float('nan')
hter_test = ((far_test + frr_test) / 2) * 100
# The message no longer mentions L2 regularization: the optimizer defined in
# this notebook sets no weight decay.
print(f"HTER using EER threshold and Gaussian weighting: {hter_test:.2f}%")

# Calculate AUC on the test set (computed on the raw, unsmoothed scores)
auc_test = roc_auc_score(test_labels, test_scores)
print(f"Area Under the ROC Curve (AUC) on the test set: {auc_test * 100:.2f}%")