In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, auc, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
from torch.utils.data import Dataset
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms
from torchvision.datasets.folder import pil_loader
from torchvision.models import shufflenet_v2_x1_0
from tqdm import tqdm
from transformers import MobileViTForImageClassification, MobileViTConfig
from transformers import ConvNextImageProcessor, ConvNextForImageClassification

# Device configuration: the default CUDA device when one is available, otherwise the CPU.
# train_model reads this global directly, evaluate_model uses it as its default `device`,
# and process_datasets moves each model onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class FrameDataset(Dataset):
    """Frame-level dataset built from folders of per-video frame directories.

    Each ``(folder, class_name)`` pair points at a folder whose immediate
    sub-directories are videos; every regular file inside a video directory is
    one independent sample carrying that folder's label.
    """

    def __init__(self, folders_with_labels, transform=None):
        """
        Args:
            folders_with_labels (list of tuples): List of (path_to_folder, class_name) pairs.
                A class's numeric label is the list index of its (last) pair, e.g.
                ``[(..., "manipulated"), (..., "original")]`` gives manipulated -> 0,
                original -> 1.
            transform: Transform applied to each frame (an RGB PIL image).
        """
        self.data = []
        self.transform = transform

        # Numeric label = position of the class in folders_with_labels.
        self.class_to_idx = {class_name: idx for idx, (_, class_name) in enumerate(folders_with_labels)}

        for folder, class_name in folders_with_labels:
            label = self.class_to_idx[class_name]

            # os.listdir order is arbitrary; sorting keeps sample order (and therefore
            # any seeded random_split over this dataset) reproducible across machines.
            for video_name in sorted(os.listdir(folder)):
                video_folder = os.path.join(folder, video_name)
                if not os.path.isdir(video_folder):
                    continue  # only sub-directories are videos
                for frame_name in sorted(os.listdir(video_folder)):
                    frame_path = os.path.join(video_folder, frame_name)
                    if os.path.isfile(frame_path):
                        self.data.append((frame_path, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(frame, label)`` where label is a float32 scalar tensor."""
        frame_path, label = self.data[idx]

        # pil_loader opens the file inside a `with` block and returns an RGB PIL image.
        frame = pil_loader(frame_path)

        if self.transform:
            frame = self.transform(frame)

        # Float target, as expected by BCEWithLogitsLoss in this notebook.
        label = torch.tensor(label, dtype=torch.float32)

        return frame, label

In [6]:
class EarlyStopping:
    """Signals when a monitored metric has stopped improving for ``patience`` calls."""

    def __init__(self, patience=3, mode='min', delta=0, verbose=True):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated before stopping.
            mode (str): 'min' for metrics like validation loss (smaller is better), 'max' for metrics like accuracy.
            delta (float): Minimum change over the best score that qualifies as an improvement.
            verbose (bool): If True, prints a message on every non-improving call.

        Raises:
            ValueError: If ``mode`` is neither 'min' nor 'max'.
        """
        # An unknown mode would otherwise make every call look like an improvement.
        if mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.mode = mode
        self.delta = delta
        self.verbose = verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def _is_improvement(self, current_score):
        """True when ``current_score`` beats the best score by strictly more than ``delta``."""
        # Strict comparison: a score that merely ties the best (e.g. a plateaued accuracy)
        # must count against patience, otherwise the counter is reset forever.
        if self.mode == 'min':
            return current_score < self.best_score - self.delta
        return current_score > self.best_score + self.delta

    def __call__(self, current_score):
        """Record a new score and update ``best_score``, ``counter`` and ``early_stop``."""
        if self.best_score is None or self._is_improvement(current_score):
            self.best_score = current_score
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping: No improvement in {self.counter}/{self.patience} epochs.")
        if self.counter >= self.patience:
            self.early_stop = True


In [121]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=20):
    """
    Trains the model on the training set and evaluates it on the validation set each epoch.

    Training stops early once the validation loss has not improved for 3 consecutive
    epochs (``EarlyStopping(patience=3, mode='min')``). The model is updated in place;
    the weights left in ``model`` are those of the last epoch that ran.

    Args:
        model (torch.nn.Module): The model to train. It may return the logits tensor
            directly or an output object exposing them as ``.logits`` (Hugging Face models).
        train_loader (DataLoader): DataLoader for the training dataset.
        val_loader (DataLoader): DataLoader for the validation dataset.
        criterion (torch.nn.Module): Loss on raw logits (e.g., BCEWithLogitsLoss for binary classification).
        optimizer (torch.optim.Optimizer): Optimization algorithm (e.g., Adam, SGD).
        num_epochs (int, optional): Maximum number of training epochs. Default is 20.

    Returns:
        None

    Note:
        Inputs and labels are moved to the notebook-global ``device``; the model is
        assumed to already live on that device.
    """
    early_stopping = EarlyStopping(patience=3, mode='min', verbose=True)

    for epoch in range(num_epochs):
        # ---------------- Training Phase ----------------
        model.train()
        running_loss = 0.0
        correct_predictions = 0

        for frames, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            frames, labels = frames.to(device), labels.to(device)

            # Exactly one forward pass per batch: extra passes in train mode would
            # re-update BatchNorm running statistics and multiply the compute cost.
            outputs = model(frames)
            # HF models wrap the logits in `.logits`; other models return the tensor itself.
            # squeeze(-1) turns the (assumed) single-logit (B, 1) output into (B,) like labels.
            outputs = getattr(outputs, "logits", outputs).squeeze(-1)

            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Assumes the criterion averages over the batch (reduction='mean'), so weight by
            # batch size to get an exact per-sample epoch mean.
            running_loss += loss.item() * frames.size(0)

            # Sigmoid + 0.5 threshold -> binary predictions.
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct_predictions += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = correct_predictions / len(train_loader.dataset)
        print(f"Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.4f}")

        # ---------------- Validation Phase ----------------
        model.eval()
        val_loss = 0.0
        val_correct_predictions = 0

        with torch.no_grad():
            for frames, labels in tqdm(val_loader, desc="Validation"):
                frames, labels = frames.to(device), labels.to(device)

                outputs = model(frames)
                outputs = getattr(outputs, "logits", outputs).squeeze(-1)

                loss = criterion(outputs, labels)
                val_loss += loss.item() * frames.size(0)

                predicted = (torch.sigmoid(outputs) > 0.5).float()
                val_correct_predictions += (predicted == labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_acc = val_correct_predictions / len(val_loader.dataset)
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

        # Stop once validation loss has stalled for `patience` epochs.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break


In [111]:
def evaluate_model(model, test_loader, device=device):
    """
    Evaluates the trained model on the test dataset.

    Args:
        model (torch.nn.Module): The trained model to evaluate. It may return the logits
            tensor directly or an output object exposing them as ``.logits``.
        test_loader (DataLoader): DataLoader for the test dataset.
        device (torch.device): Device to perform computations (CPU or GPU).

    Returns:
        dict: Accuracy, precision, recall, F1-score and ROC-AUC (None when undefined)
              for the positive class (label 1), plus 'all_labels' and 'all_probs'
              for further analysis such as ROC plotting.
    """
    # Evaluation mode: disables dropout and freezes BatchNorm statistics.
    model.eval()

    all_labels = []  # Ground truth labels
    all_preds = []   # Predicted binary labels
    all_probs = []   # Predicted probabilities of label 1

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # One forward pass per batch; HF models wrap the logits in `.logits`.
            # squeeze(-1) turns the (assumed) single-logit (B, 1) output into (B,).
            outputs = model(inputs)
            outputs = getattr(outputs, "logits", outputs).squeeze(-1)

            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    # roc_auc_score raises ValueError when only one class is present in the labels.
    try:
        roc_auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        roc_auc = None

    print(f"Test Accuracy : {accuracy:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall   : {recall:.4f}")
    print(f"Test F1-Score : {f1:.4f}")
    if roc_auc is not None:
        print(f"Test ROC-AUC  : {roc_auc:.4f}")
    else:
        print("Test ROC-AUC  : Undefined (only one class present in labels)")

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'roc_auc': roc_auc,
        'all_labels': all_labels,  # Needed for plotting the ROC curve later
        'all_probs': all_probs     # Needed for plotting the ROC curve later
    }


def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [2]:
def reinitialize_model(model_name):
    """
    Builds a fresh pretrained model with a single-output binary classification head.

    Args:
        model_name (str): Name of the model to load (see Notes for supported names).

    Returns:
        torch.nn.Module: A model whose classification head outputs one logit per sample.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.

    Notes:
        - timm: "mobilenetv2_100", "tiny_vit_5m_224.dist_in22k_ft_in1k", "tf_efficientnet_lite0"
        - torchvision: "shufflenet_v2_x1_0"
        - torch.hub (huawei-noah/ghostnet): "ghostnet_1x"
        - transformers: "apple/mobilevit-small", "facebook/convnext-tiny-224"
          (these return an output object whose ``.logits`` holds the prediction;
          the others return the logits tensor directly).
    """
    timm_models = ["mobilenetv2_100", 'tiny_vit_5m_224.dist_in22k_ft_in1k', "tf_efficientnet_lite0"]

    if model_name in timm_models:
        # num_classes=1 makes timm attach a freshly initialized 1-output head.
        return timm.create_model(model_name, pretrained=True, num_classes=1)

    if model_name == "shufflenet_v2_x1_0":
        # `weights=` replaces the deprecated `pretrained=True` (same ImageNet-1k checkpoint).
        model = shufflenet_v2_x1_0(weights="IMAGENET1K_V1")
        model.fc = nn.Linear(model.fc.in_features, 1)
        return model

    if model_name == "ghostnet_1x":
        model = torch.hub.load('huawei-noah/ghostnet', 'ghostnet_1x', pretrained=True)
        model.classifier = nn.Linear(model.classifier.in_features, 1)
        return model

    if model_name == "apple/mobilevit-small":
        config = MobileViTConfig.from_pretrained('apple/mobilevit-small', num_labels=1)
        # ignore_mismatched_sizes lets the checkpoint's differently-sized classifier weights
        # be skipped so the 1-label head is initialized fresh.
        return MobileViTForImageClassification.from_pretrained(
            'apple/mobilevit-small',
            config=config,
            ignore_mismatched_sizes=True,
        )

    if model_name == "facebook/convnext-tiny-224":
        model = ConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
        num_features = model.classifier.in_features
        # Replace the classifier with a single-logit head.
        model.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_features, 1),
        )
        return model

    # Fail loudly here instead of returning None and crashing later on `.parameters()`.
    raise ValueError(f"Model not found! Unknown model_name: {model_name!r}")



In [113]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm  # For progress visualization

def process_datasets(model_name, datasets, transform=None, seed=42, batch_size=32, num_epochs=20):
    """
    Trains and evaluates a freshly initialized model on each dataset in turn.

    For every dataset name, frames from ``./{dataset_name}`` ("manipulated", label 0)
    and ``./youtube`` ("original", label 1) are pooled and randomly split
    70% / 15% / 15% into train / validation / test sets.

    Args:
        model_name (str): The name of the model, as understood by ``reinitialize_model``.
        datasets (list of str): List of dataset folder names to process.
        transform (callable, optional): Transform applied to every frame of every split.
            Defaults to Resize(224x224) -> ToTensor -> ImageNet Normalize.
        seed (int or None, optional): Seed for the train/val/test split so that every model
            is compared on the same split. ``None`` uses PyTorch's global RNG. Default 42.
        batch_size (int, optional): DataLoader batch size. Default 32.
        num_epochs (int, optional): Maximum number of training epochs. Default 20.

    Returns:
        list: One ``evaluate_model`` result dict per dataset, in input order.
    """
    # Resolved at call time: a `transform=transform` default would be evaluated when this
    # cell runs, before the later notebook cell that defines the global `transform`.
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    results = []

    for dataset_name in datasets:
        print(f"Training on {dataset_name} dataset...")

        # Fresh pretrained weights for each dataset; PyTorch recommends moving a model to its
        # device before constructing its optimizer.
        model = reinitialize_model(model_name).to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.AdamW(model.parameters(), lr=0.001)

        # FrameDataset assigns labels by list position: manipulated -> 0, original -> 1.
        # NOTE(review): label 1 ("original") is therefore the positive class for the
        # precision / recall / F1 / ROC-AUC computed in evaluate_model — confirm intended.
        folders_with_labels = [
            (f"./{dataset_name}", "manipulated"),
            ('./youtube', "original"),
        ]
        dataset = FrameDataset(folders_with_labels, transform=transform)

        # 70% / 15% / 15% split; the test split absorbs the rounding remainder.
        # NOTE(review): the split is over individual frames, so frames of one video can appear
        # in both train and test, which likely inflates test metrics; a video-level split
        # would avoid this leakage.
        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

        train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=num_epochs)
        results.append(evaluate_model(model, test_loader))

    return results



In [114]:
def plot_multiple_roc_auc_curves(results_list, model_labels=None, dataset_labels=None):
    """
    Plots one ROC curve per result on a shared set of axes.

    Args:
        results_list (list of dicts): Each dict should contain:
            - 'all_labels': True labels (0 or 1)
            - 'all_probs': Predicted probabilities for the positive class (label 1)
        model_labels (list of str, optional): Model name per result; when given it is
            prefixed to that curve's legend entry. Default is None.
        dataset_labels (list of str, optional): Dataset name per result. Default is None,
            which labels curves "Dataset 1", "Dataset 2", ...

    Returns:
        None (Displays the ROC-AUC plot)
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    for i, results in enumerate(results_list):
        fpr, tpr, _ = roc_curve(results['all_labels'], results['all_probs'])
        roc_auc = auc(fpr, tpr)  # area under the ROC curve just computed

        dataset_name = dataset_labels[i] if dataset_labels else f'Dataset {i+1}'
        legend_label = f'{dataset_name} (AUC = {roc_auc:.4f})'
        if model_labels:
            legend_label = f'{model_labels[i]} - {legend_label}'

        ax.plot(fpr, tpr, lw=2, label=legend_label)

    # Baseline: a random classifier lies on the diagonal.
    ax.plot([0, 1], [0, 1], color='gray', linestyle='--', lw=2, label='Chance (AUC = 0.50)')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc='lower right')
    ax.grid()
    plt.show()


In [115]:
def plot_performance_metrics(results, dataset_names):
    """
    Plots a grouped bar chart comparing evaluation metrics across datasets.

    Args:
        results (list of dict): One dict per dataset (as returned by ``evaluate_model``)
            with keys 'accuracy', 'precision', 'recall', 'f1_score' and 'roc_auc'.
            A ``None`` value (how ``evaluate_model`` reports an undefined ROC-AUC) is
            left as an empty slot instead of raising.
        dataset_names (list of str): Names of the datasets corresponding to the results.
    """
    metrics = ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc']
    colors = ['blue', 'orange', 'green', 'red', 'purple']

    # None cannot be plotted or offset for the in-bar label; NaN leaves the bar empty.
    metric_values = {
        metric: [np.nan if result[metric] is None else result[metric] for result in results]
        for metric in metrics
    }

    x = np.arange(len(dataset_names))  # Positions for datasets on x-axis
    width = 0.15  # Width of each bar

    plt.figure(figsize=(10, 6))
    for i, (metric, color) in enumerate(zip(metrics, colors)):
        values = metric_values[metric]
        bars = plt.bar(x + i * width, values, width, label=metric.capitalize(), color=color)

        # Write the metric name inside each drawn bar, just below its top.
        for bar, value in zip(bars, values):
            if np.isnan(value):
                continue
            plt.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() - 0.05,
                f'{metric.capitalize()}',
                ha='center', va='top', fontsize=8, color='white'
            )

    plt.yticks(np.arange(0.0, 1.1, 0.05))
    plt.xticks(x + width * (len(metrics) - 1) / 2, dataset_names)  # center names under each group
    plt.xlabel("Dataset")
    plt.ylabel("Metric Value")
    plt.title("Performance Metrics Comparison")
    plt.ylim(0, 1.1)
    plt.legend()
    plt.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.show()


In [116]:
def train_test_all_models(model_names, datasets, transform=None):
    """
    Trains and evaluates every model in ``model_names`` on every dataset in ``datasets``.

    Args:
        model_names (list of str): Model names understood by ``reinitialize_model``.
        datasets (list of str): List of dataset names for training and evaluation.
        transform (callable, optional): Transform applied to every frame. When None
            (the default) it is not forwarded, so ``process_datasets`` uses its own default.

    Returns:
        dict: Maps each model name to the list returned by ``process_datasets``
              (one evaluation-metrics dict per dataset, in order).
    """
    # Only forward `transform` when given; a `transform=transform` default would require the
    # global `transform` (defined in a later cell) to exist when this cell runs.
    forwarded = {} if transform is None else {"transform": transform}

    results_by_model = {}
    for model_name in model_names:
        print(f"Training {model_name}...")
        results_by_model[model_name] = process_datasets(model_name, datasets, **forwarded)

    return results_by_model


In [118]:
# Model names to train and evaluate; each must be a name handled by reinitialize_model.
# - `mobilenetv2_100`: A lightweight CNN model from TIMM.
# - `tiny_vit_5m_224.dist_in22k_ft_in1k`: A small Vision Transformer (ViT) model from TIMM.
# - `tf_efficientnet_lite0`: A compact EfficientNet variant from TIMM.
# - `shufflenet_v2_x1_0`: A lightweight CNN from torchvision.
# - `ghostnet_1x`: GhostNet, loaded from Torch Hub (huawei-noah/ghostnet).
# - `apple/mobilevit-small`: A MobileViT model from Hugging Face transformers.
# - `facebook/convnext-tiny-224`: A small ConvNeXt model from Hugging Face transformers.
model_names = [
    "mobilenetv2_100",
    "tiny_vit_5m_224.dist_in22k_ft_in1k",
    "tf_efficientnet_lite0",
    "shufflenet_v2_x1_0",
    "ghostnet_1x",
    "apple/mobilevit-small",
    "facebook/convnext-tiny-224"
]

# Manipulation-method datasets to train and evaluate on. Each name is read as the folder
# `./<name>` of manipulated videos (one sub-folder of frames per video) and is paired with
# the authentic videos in `./youtube` inside process_datasets.
# NOTE(review): this rebinds `datasets`, shadowing the `torchvision.datasets` module imported
# in the first cell; no visible code uses that module afterwards, but any later
# `datasets.<...>` call would hit this list instead.
datasets = ["Deepfakes", "Face2Face", "FaceShifter", "FaceSwap", "NeuralTextures"]

# Baseline (clean) transform applied to frames before feeding them into models:
# - `Resize((224, 224))`: Resizes all images to 224x224 pixels to match model input size.
# - `ToTensor()`: Converts PIL images to float tensors in [0, 1] (HWC → CHW format).
# - `Normalize(mean, std)`: Normalizes images using ImageNet mean and standard deviation.
#   - Mean: [0.485, 0.456, 0.406] (Red, Green, Blue channel mean values).
#   - Std:  [0.229, 0.224, 0.225] (Standard deviation for each channel).
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),  # Convert images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize using ImageNet stats
])


#### Transforms

In [None]:
# Noise transform: additive Gaussian noise (std 0.1 in [0, 1] pixel units) is added after
# ToTensor and before normalization, then clamped back to [0, 1].
# NOTE(review): the noise is resampled on every call, and process_datasets applies a single
# transform to the train, validation AND test splits, so test metrics are computed on
# noisy, non-deterministic inputs — confirm this is the intended experiment.
# NOTE(review): lambdas cannot be pickled, so DataLoader workers (num_workers=2) started with
# the "spawn" method (the default on Windows/macOS) will likely fail — confirm platform.
transform_noisy = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.1),  # Noise before normalization
    transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),  # Keep values in valid range
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Random-crop transform: resize to 256x256, then take a random 224x224 window.
# NOTE(review): if passed to process_datasets, the random crop is applied to the
# validation and test splits as well as training.
transform_random_crop = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to larger size first
    transforms.RandomCrop(224),     # Randomly crop to 224×224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Lighting variations: random brightness / contrast / saturation (factor in [0.5, 1.5])
# and hue (shift in [-0.1, 0.1]) jitter before tensor conversion and normalization.
# Bound to its own name, like transform_noisy / transform_random_crop, so the clean
# baseline `transform` defined earlier is not silently replaced.
transform_lighting = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Train and evaluate every model on every dataset with the noisy-input transform.
# `data` maps each model name to its list of per-dataset evaluate_model result dicts.
data = train_test_all_models(model_names, datasets, transform=transform_noisy)