In [1]:
!pip install torch torchvision pandas math

[31mERROR: Could not find a version that satisfies the requirement math (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for math[0m[31m
[0m

In [8]:
from sklearn.model_selection import train_test_split

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    images, labels, test_size=0.25, random_state=42
)

print(f"Training data shape: {X_train.shape}")
print(f"Testing data shape: {X_test.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Testing labels shape: {y_test.shape}")

Training data shape: (2267, 50, 37)
Testing data shape: (756, 50, 37)
Training labels shape: (2267,)
Testing labels shape: (756,)


In [3]:
# Import required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import math
from typing import Optional, Tuple
import warnings
warnings.filterwarnings('ignore')

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


In [10]:
class PatchEmbedding(nn.Module):
    """Convert image patches to embeddings"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2)  # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x

In [11]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention mechanism"""
    def __init__(self, embed_dim=768, n_heads=12, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape

        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention calculation
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

In [12]:
class MLP(nn.Module):
    """Multi-layer perceptron block"""
    def __init__(self, embed_dim=768, hidden_dim=3072, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

In [13]:
class TransformerBlock(nn.Module):
    """Transformer encoder block with attention and MLP"""
    def __init__(self, embed_dim=768, n_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout)

    def forward(self, x):
        # Pre-norm design
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

In [14]:
class VisionTransformer(nn.Module):
    """Vision Transformer for morphed image detection"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=12, n_heads=12, mlp_ratio=4,
                 num_classes=2, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.n_patches = self.patch_embed.n_patches

        # Class token and positional embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x, return_features=False):
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add positional embeddings
        x = x + self.pos_embed
        x = self.dropout(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        cls_token_final = x[:, 0]  # Extract class token

        if return_features:
            return cls_token_final

        # Classification
        logits = self.head(cls_token_final)
        return logits

In [15]:
class Generator(nn.Module):
    """Generator for Multi-Collaborative GAN"""
    def __init__(self, latent_dim=100, img_channels=3, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.init_size = img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, img_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

In [16]:
class Discriminator(nn.Module):
    """Single discriminator for collaborative training"""
    def __init__(self, img_channels=3, img_size=224):
        super().__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, 2, 1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_filters, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(img_channels, 16, normalize=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
        )

        # Calculate the size of the flattened features
        ds_size = img_size // 2 ** 5  # 5 downsampling layers
        self.adv_layer = nn.Sequential(
            nn.Linear(256 * ds_size ** 2, 1),
            nn.Sigmoid()
        )

        # Feature extraction layer for ensemble
        self.feature_layer = nn.Linear(256 * ds_size ** 2, 512)

    def forward(self, img, return_features=False):
        features = self.model(img)
        features_flat = features.view(features.shape[0], -1)

        if return_features:
            return self.feature_layer(features_flat)

        validity = self.adv_layer(features_flat)
        return validity

In [17]:
class MultiCollaborativeGAN(nn.Module):
    """Multi-Collaborative GAN with multiple discriminators"""
    def __init__(self, num_discriminators=3, latent_dim=100, img_channels=3, img_size=224):
        super().__init__()
        self.num_discriminators = num_discriminators
        self.latent_dim = latent_dim

        # Generator
        self.generator = Generator(latent_dim, img_channels, img_size)

        # Multiple discriminators
        self.discriminators = nn.ModuleList([
            Discriminator(img_channels, img_size) for _ in range(num_discriminators)
        ])

        # Feature fusion layer
        self.feature_fusion = nn.Sequential(
            nn.Linear(512 * num_discriminators, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128)
        )

    def forward(self, x, return_features=False):
        # Extract features from all discriminators
        disc_features = []
        disc_scores = []

        for discriminator in self.discriminators:
            features = discriminator(x, return_features=True)
            scores = discriminator(x, return_features=False)
            disc_features.append(features)
            disc_scores.append(scores)

        # Concatenate features from all discriminators
        combined_features = torch.cat(disc_features, dim=1)
        fused_features = self.feature_fusion(combined_features)

        if return_features:
            return fused_features

        # Average discriminator scores
        avg_score = torch.mean(torch.stack(disc_scores), dim=0)
        return avg_score

    def generate_samples(self, batch_size):
        """Generate samples using the generator"""
        z = torch.randn(batch_size, self.latent_dim, device=next(self.parameters()).device)
        return self.generator(z)

In [18]:
class EnsembleMetaModel(nn.Module):
    """Meta-model for ensemble learning combining ViT and Multi-Collaborative GAN"""
    def __init__(self, vit_feature_dim=768, gan_feature_dim=128, num_classes=2):
        super().__init__()

        # Feature normalization layers
        self.vit_norm = nn.LayerNorm(vit_feature_dim)
        self.gan_norm = nn.LayerNorm(gan_feature_dim)

        # Feature fusion network
        total_features = vit_feature_dim + gan_feature_dim
        self.fusion_network = nn.Sequential(
            nn.Linear(total_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Final classification layers
        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

        # Attention-based weighting
        self.attention = nn.Sequential(
            nn.Linear(total_features, 256),
            nn.ReLU(),
            nn.Linear(256, 2),  # 2 weights for ViT and GAN
            nn.Softmax(dim=1)
        )

    def forward(self, vit_features, gan_features):
        # Normalize features
        vit_features = self.vit_norm(vit_features)
        gan_features = self.gan_norm(gan_features)

        # Concatenate features
        combined_features = torch.cat([vit_features, gan_features], dim=1)

        # Calculate attention weights
        attention_weights = self.attention(combined_features)

        # Apply attention weighting
        weighted_vit = vit_features * attention_weights[:, 0:1]
        weighted_gan = gan_features * attention_weights[:, 1:2]

        # Combine weighted features
        final_features = torch.cat([weighted_vit, weighted_gan], dim=1)

        # Feature fusion
        fused_features = self.fusion_network(final_features)

        # Final classification
        logits = self.classifier(fused_features)

        return logits, attention_weights

In [19]:
class MorphDetectionEnsemble(nn.Module):
    """Complete ensemble model for morphed image detection"""
    def __init__(self, img_size=224, num_classes=2):
        super().__init__()

        # ViT models (B/16 and L/32 variants)
        self.vit_b16 = VisionTransformer(
            img_size=img_size, patch_size=16, embed_dim=768,
            depth=12, n_heads=12, num_classes=num_classes
        )

        self.vit_l32 = VisionTransformer(
            img_size=img_size, patch_size=32, embed_dim=1024,
            depth=24, n_heads=16, num_classes=num_classes
        )

        # Multi-Collaborative GAN
        self.mc_gan = MultiCollaborativeGAN(
            num_discriminators=3, img_size=img_size
        )

        # Meta-models for different combinations
        self.meta_model_b16_gan = EnsembleMetaModel(
            vit_feature_dim=768, gan_feature_dim=128, num_classes=num_classes
        )

        self.meta_model_l32_gan = EnsembleMetaModel(
            vit_feature_dim=1024, gan_feature_dim=128, num_classes=num_classes
        )

        # Final ensemble fusion
        self.final_fusion = nn.Sequential(
            nn.Linear(num_classes * 2, 128),  # 2 meta-models
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        # Extract features from ViT models
        vit_b16_features = self.vit_b16(x, return_features=True)
        vit_l32_features = self.vit_l32(x, return_features=True)

        # Extract features from Multi-Collaborative GAN
        gan_features = self.mc_gan(x, return_features=True)

        # Meta-model predictions
        meta1_logits, attention1 = self.meta_model_b16_gan(vit_b16_features, gan_features)
        meta2_logits, attention2 = self.meta_model_l32_gan(vit_l32_features, gan_features)

        # Combine meta-model predictions
        combined_logits = torch.cat([meta1_logits, meta2_logits], dim=1)

        # Final ensemble prediction
        final_logits = self.final_fusion(combined_logits)

        return {
            'final_logits': final_logits,
            'meta1_logits': meta1_logits,
            'meta2_logits': meta2_logits,
            'attention1': attention1,
            'attention2': attention2,
            'vit_b16_features': vit_b16_features,
            'vit_l32_features': vit_l32_features,
            'gan_features': gan_features
        }

In [20]:
class EnsembleLoss(nn.Module):
    """Combined loss function for ensemble training"""
    def __init__(self, alpha=0.6, beta=0.2, gamma=0.2):
        super().__init__()
        self.alpha = alpha  # Weight for final loss
        self.beta = beta    # Weight for meta-model 1 loss
        self.gamma = gamma  # Weight for meta-model 2 loss

        self.criterion = nn.CrossEntropyLoss()
        self.consistency_loss = nn.MSELoss()

    def forward(self, outputs, targets):
        # Main classification losses
        final_loss = self.criterion(outputs['final_logits'], targets)
        meta1_loss = self.criterion(outputs['meta1_logits'], targets)
        meta2_loss = self.criterion(outputs['meta2_logits'], targets)

        # Consistency loss between meta-models
        consistency = self.consistency_loss(
            F.softmax(outputs['meta1_logits'], dim=1),
            F.softmax(outputs['meta2_logits'], dim=1)
        )

        # Combine losses
        total_loss = (self.alpha * final_loss +
                     self.beta * meta1_loss +
                     self.gamma * meta2_loss +
                     0.1 * consistency)

        return total_loss, {
            'final_loss': final_loss.item(),
            'meta1_loss': meta1_loss.item(),
            'meta2_loss': meta2_loss.item(),
            'consistency_loss': consistency.item()
        }

In [21]:
def train_ensemble_epoch(model, dataloader, optimizer, criterion, device):
    """Train the ensemble model for one epoch"""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, targets) in enumerate(dataloader):
        data, targets = data.to(device), targets.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(data)

        # Calculate loss
        loss, loss_dict = criterion(outputs, targets)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        total_loss += loss.item()
        _, predicted = outputs['final_logits'].max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}: Loss = {loss.item():.4f}, Acc = {100.*correct/total:.2f}%')

    epoch_loss = total_loss / len(dataloader)
    epoch_acc = 100. * correct / total

    return epoch_loss, epoch_acc

In [22]:
def evaluate_ensemble(model, dataloader, criterion, device):
    """Evaluate the ensemble model"""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, targets in dataloader:
            data, targets = data.to(device), targets.to(device)

            outputs = model(data)
            loss, _ = criterion(outputs, targets)

            total_loss += loss.item()
            _, predicted = outputs['final_logits'].max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total

    return avg_loss, accuracy

In [32]:
# Data transformations
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Example dataset loading (modify paths as needed)
train_dataset = X_train
val_dataset = X_test

train_loader = y_train
val_loader = y_test

print("Data loading setup complete. Update dataset paths as needed.")

Data loading setup complete. Update dataset paths as needed.


In [33]:
# Initialize the complete ensemble model
model = MorphDetectionEnsemble(img_size=224, num_classes=2).to(device)

# Loss function and optimizer
criterion = EnsembleLoss(alpha=0.6, beta=0.2, gamma=0.2)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model initialized successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Model initialized successfully!
Total parameters: 455,739,748
Trainable parameters: 455,739,748


In [34]:
# Training loop (example - modify as needed)
def train_ensemble_model(model, train_loader, val_loader, num_epochs=50):
    """Complete training loop for the ensemble model"""
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 50)

        # Training
        train_loss, train_acc = train_ensemble_epoch(
            model, train_loader, optimizer, criterion, device
        )

        # Validation
        val_loss, val_acc = evaluate_ensemble(
            model, val_loader, criterion, device
        )

        # Update learning rate
        scheduler.step()

        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc,
            }, 'best_ensemble_model.pth')
            print(f'New best model saved with validation accuracy: {best_val_acc:.2f}%')

    print(f'\nTraining completed! Best validation accuracy: {best_val_acc:.2f}%')
    return model

# Uncomment to start training (ensure you have data loaders)
#trained_model = train_ensemble_model(model, train_loader, val_loader, num_epochs=50)

print("Training function ready. Uncomment the line above to start training.")

Training function ready. Uncomment the line above to start training.


In [37]:
def predict_single_image(model, image_path, transform, device):
    """Predict on a single image"""
    from PIL import Image

    model.eval()

    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image_tensor)

        # Get probabilities
        probabilities = F.softmax(outputs['final_logits'], dim=1)
        prediction = probabilities.argmax(dim=1).item()
        confidence = probabilities.max().item()

        # Get attention weights
        attention1 = outputs['attention1'].cpu().numpy()[0]
        attention2 = outputs['attention2'].cpu().numpy()[0]

    result = {
        'prediction': 'Morphed' if prediction == 1 else 'Genuine',
        'confidence': confidence,
        'probabilities': probabilities.cpu().numpy()[0],
        'attention_weights_1': attention1,
        'attention_weights_2': attention2
    }

    return result

# Example usage:
#result = predict_single_image(model, 'path/to/image.jpg', val_transforms, device)
# print(f"Prediction: {result['prediction']} (Confidence: {result['confidence']:.3f})")

print("Inference function ready.")

Inference function ready.


In [38]:
# Model architecture summary
def print_model_summary(model):
    """Print detailed model architecture summary"""
    print("Ensemble Model Architecture Summary:")
    print("=" * 50)

    print("1. Vision Transformer B/16:")
    print(f"   - Patch size: 16x16")
    print(f"   - Embedding dimension: 768")
    print(f"   - Transformer blocks: 12")
    print(f"   - Attention heads: 12")

    print("2. Vision Transformer L/32:")
    print(f"   - Patch size: 32x32")
    print(f"   - Embedding dimension: 1024")
    print(f"   - Transformer blocks: 24")
    print(f"   - Attention heads: 16")

    print("3. Multi-Collaborative GAN:")
    print(f"   - Number of discriminators: 3")
    print(f"   - Feature dimension: 128")
    print(f"   - Collaborative training: Yes")

    print("4. Meta-Models:")
    print(f"   - Meta-model 1: ViT-B/16 + GAN")
    print(f"   - Meta-model 2: ViT-L/32 + GAN")
    print(f"   - Attention mechanism: Yes")

    print("5. Final Ensemble:")
    print(f"   - Fusion strategy: Feature-based super learning")
    print(f"   - Output classes: 2 (Genuine/Morphed)")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTotal Parameters: {total_params:,}")

print_model_summary(model)

Ensemble Model Architecture Summary:
1. Vision Transformer B/16:
   - Patch size: 16x16
   - Embedding dimension: 768
   - Transformer blocks: 12
   - Attention heads: 12
2. Vision Transformer L/32:
   - Patch size: 32x32
   - Embedding dimension: 1024
   - Transformer blocks: 24
   - Attention heads: 16
3. Multi-Collaborative GAN:
   - Number of discriminators: 3
   - Feature dimension: 128
   - Collaborative training: Yes
4. Meta-Models:
   - Meta-model 1: ViT-B/16 + GAN
   - Meta-model 2: ViT-L/32 + GAN
   - Attention mechanism: Yes
5. Final Ensemble:
   - Fusion strategy: Feature-based super learning
   - Output classes: 2 (Genuine/Morphed)

Total Parameters: 455,739,748


In [39]:
def save_complete_model(model, optimizer, epoch, accuracy, filepath='morphed_detection_ensemble.pth'):
    """Save the complete model with all components"""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy': accuracy,
        'model_config': {
            'img_size': 224,
            'num_classes': 2,
        }
    }, filepath)
    print(f"Model saved to {filepath}")

def load_complete_model(filepath='morphed_detection_ensemble.pth', device='cpu'):
    """Load the complete model"""
    checkpoint = torch.load(filepath, map_location=device)

    # Initialize model
    model = MorphDetectionEnsemble(
        img_size=checkpoint['model_config']['img_size'],
        num_classes=checkpoint['model_config']['num_classes']
    ).to(device)

    # Load state dict
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"Model loaded from {filepath}")
    print(f"Training epoch: {checkpoint['epoch']}")
    print(f"Best accuracy: {checkpoint['accuracy']:.2f}%")

    return model, checkpoint

# Example usage:
# save_complete_model(model, optimizer, epoch=0, accuracy=0.0)
# loaded_model, checkpoint = load_complete_model()

print("Model save/load functions ready.")

Model save/load functions ready.


In [40]:
# Test the model with random input (for verification)
def test_model_forward_pass():
    """Test the model with random input to verify architecture"""
    model.eval()

    # Create random input
    test_input = torch.randn(2, 3, 224, 224).to(device)

    with torch.no_grad():
        outputs = model(test_input)

    print("Model Forward Pass Test:")
    print("=" * 30)
    print(f"Input shape: {test_input.shape}")
    print(f"Final logits shape: {outputs['final_logits'].shape}")
    print(f"Meta1 logits shape: {outputs['meta1_logits'].shape}")
    print(f"Meta2 logits shape: {outputs['meta2_logits'].shape}")
    print(f"Attention1 shape: {outputs['attention1'].shape}")
    print(f"Attention2 shape: {outputs['attention2'].shape}")
    print(f"ViT-B/16 features shape: {outputs['vit_b16_features'].shape}")
    print(f"ViT-L/32 features shape: {outputs['vit_l32_features'].shape}")
    print(f"GAN features shape: {outputs['gan_features'].shape}")

    # Test probabilities
    probs = F.softmax(outputs['final_logits'], dim=1)
    print(f"\nSample predictions (probabilities):")
    for i in range(test_input.size(0)):
        pred_class = probs[i].argmax().item()
        confidence = probs[i].max().item()
        class_name = 'Morphed' if pred_class == 1 else 'Genuine'
        print(f"  Sample {i+1}: {class_name} (confidence: {confidence:.3f})")

    print("\nModel architecture test completed successfully!")

# Run the test
test_model_forward_pass()

Model Forward Pass Test:
Input shape: torch.Size([2, 3, 224, 224])
Final logits shape: torch.Size([2, 2])
Meta1 logits shape: torch.Size([2, 2])
Meta2 logits shape: torch.Size([2, 2])
Attention1 shape: torch.Size([2, 2])
Attention2 shape: torch.Size([2, 2])
ViT-B/16 features shape: torch.Size([2, 768])
ViT-L/32 features shape: torch.Size([2, 1024])
GAN features shape: torch.Size([2, 128])

Sample predictions (probabilities):
  Sample 1: Morphed (confidence: 0.504)
  Sample 2: Morphed (confidence: 0.504)

Model architecture test completed successfully!
