# Comparative Analysis of Self-Supervised and Supervised Pretraining Approaches for Genshin Impact Character Classification Using ResNet-18

This notebook implements an experimental framework to compare self-supervised and supervised pretraining approaches for the task of Genshin Impact character classification using the ResNet-18 architecture. The goal is to evaluate the effectiveness of different pretraining strategies in leveraging limited labeled data for improved classification performance.

## Required Libraries
- PyTorch
- Torchvision
- NumPy
- Matplotlib
- Pandas
- Scikit-learn
- Seaborn
- Hugging Face Hub (optional, for model uploading)
- kagglehub (for dataset downloads)
- tqdm (for progress bars)

## Key elements of the implementation:
- **Image Preprocessing**:
    - Images are resized (stretched) to 256x256 pixels without preserving the aspect ratio.
    - Normalization using ImageNet statistics.

- **Data Augmentation**:
    - Random cropping, horizontal flipping, color jittering, and Gaussian blur for SimCLR pretraining.
    - Random scaling $[1.0, 1.5]$ with random center cropping, rotation $[-15^\circ, 15^\circ]$ and horizontal flipping for fine-tuning.

- **Model Architecture**:
    - ResNet-18 with a projection head for SimCLR pretraining.
    - Classification MLP head for fine-tuning with 6 output classes.

- **Training Configuration**:
    - SimCLR pretraining for 500 epochs with a batch size of 256 and a learning rate of 0.5.
    - Fine-tuning for 100 epochs with a batch size of 32 and a learning rate of 0.01, gradually unfreezing the ResNet layers.

- **Loss Functions**:
    - Contrastive loss with temperature scaling for SimCLR pretraining.
    - Cross-entropy loss for classification fine-tuning.

> The contrastive loss is computed from the cosine similarities between projected features: the two augmented views of an image form a positive pair, and the other samples in the batch serve as negatives. The classification loss is the cross-entropy computed from the softmax output of the final classification layer.
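
As a rough illustration of the contrastive loss described above, here is a minimal sketch of the NT-Xent loss used by SimCLR. The function name `nt_xent_loss` and the default temperature of 0.5 are assumptions for illustration; this notebook does not specify a temperature value.

```python
import torch


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent loss for a batch of positive pairs (z1[i], z2[i])."""
    n = z1.size(0)
    # L2-normalise so that dot products become cosine similarities.
    z = torch.nn.functional.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D)
    sim = z @ z.t() / temperature                                          # (2N, 2N)
    # A sample must never be contrasted with itself.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is row i + N, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Cross-entropy over each row: the positive is the "correct class",
    # the remaining 2N - 2 samples act as negatives.
    return torch.nn.functional.cross_entropy(sim, targets)
```

Here `z1` and `z2` would be the projection-head outputs for the two augmented views of the same batch. The full `torch.nn.functional` path is spelled out because this notebook already binds `F` to `torchvision.transforms.functional`.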


- **Evaluation Metrics**:
    - Cross-validation with 5 folds to ensure robustness.
    - Top-1 and Top-5 accuracy, F1-score, precision, recall, and confusion matrix analysis.
    - Visualization of learned features using t-SNE and Grad-CAM for interpretability.

- **Model Training Approaches**:
    - Purely supervised training on the final dataset (comparison baseline).
    - Self-supervised pretraining on the unlabeled dataset, followed by supervised fine-tuning on the final dataset.
    - Self-supervised pretraining on the unlabeled dataset, followed by semi-supervised fine-tuning on the final dataset plus unlabeled data.
    - ImageNet supervised pretraining, followed by supervised fine-tuning on the final dataset.

## 1. Dependencies import

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.datasets as datasets
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import kagglehub
import os
from tqdm import tqdm


## 2. Check if GPU is available

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / (1024 ** 3):.2f} GB")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / (1024 ** 3):.2f} GB")
    print(f"Memory Reserved: {torch.cuda.memory_reserved(0) / (1024 ** 3):.2f} GB")

In [None]:
!nvidia-smi

## 3. Dataset retrieval and transformation for Self-Supervised Learning dataset
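- Download unlabeled anime-style image datasets from Kaggle (anime faces, Pixiv illustration faces, and full-body anime characters).
- Apply the preprocessing transformations (resize to 256x256, ImageNet normalization).
- Save everything as a single portable dataset file for self-supervised learning.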
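
Because the portable dataset stores every image as a float32 tensor of shape 3x256x256, its size grows quickly with the number of images. The following back-of-the-envelope sketch makes that explicit. The helper name `portable_dataset_bytes` and the image count of 10,000 are illustrative assumptions, not properties of the actual datasets.

```python
def portable_dataset_bytes(num_images, channels=3, height=256, width=256, bytes_per_value=4):
    """Raw tensor size in bytes (float32 by default), ignoring file overhead."""
    return num_images * channels * height * width * bytes_per_value


per_image_mib = portable_dataset_bytes(1) / 1024 ** 2
total_gib = portable_dataset_bytes(10_000) / 1024 ** 3
print(f"{per_image_mib:.2f} MiB per image")        # 0.75 MiB per image
print(f"{total_gib:.2f} GiB for 10,000 images")    # 7.32 GiB for 10,000 images
```

If memory becomes a constraint, storing the images as `uint8` (`bytes_per_value=1`) and normalizing on the fly would reduce the footprint by a factor of four.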

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406] # Using ImageNet mean and std for normalization
IMAGENET_STD = [0.229, 0.224, 0.225]

def create_portable_dataset(source_paths, output_path, transform=None, verbose=True):
    """Create a portable dataset by preloading all images and storing them in a single file."""
    # Create the output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    if verbose:
        print(f"Creating portable dataset at {output_path}")
    
    # Accumulate image and label batches from every source dataset
    all_images = []
    all_labels = []
    
    for i, path in enumerate(source_paths):
        if verbose:
            print(f"Processing dataset {i+1}/{len(source_paths)} from {path}")
        
        try:
            # Apply the transform at load time so every image in a batch has the same
            # shape; raw images of different sizes cannot be stacked by the DataLoader.
            load_transform = transform if transform is not None else transforms.ToTensor()
            dataset = datasets.ImageFolder(root=path, transform=load_transform)
            
            # Create a loader without shuffling to preserve order
            loader = torch.utils.data.DataLoader(
                dataset, batch_size=32, shuffle=False, num_workers=2
            )
            
            # Process batches
            for images, labels in tqdm(loader, desc=f"Dataset {i+1}", disable=not verbose):
                all_images.append(images)
                # Offset labels by 1000 per source dataset so they stay distinct
                all_labels.append(labels + i * 1000)
        except Exception as e:
            print(f"Error processing dataset {path}: {e}")
            continue
    
    if len(all_images) == 0:
        raise ValueError("No images could be loaded from the provided paths.")
    
    # Combine all images and labels
    all_images = torch.cat(all_images, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    
    if verbose:
        print(f"Total images: {all_images.shape[0]}")
        print(f"Image shape: {all_images.shape[1:]}")
    
    # Create a dictionary to save
    portable_dataset = {
        'images': all_images,
        'labels': all_labels,
    }
    
    # Save the dataset
    torch.save(portable_dataset, output_path)
    if verbose:
        print(f"Dataset saved to {output_path}")
        print(f"File size: {os.path.getsize(output_path) / (1024 * 1024):.2f} MB")
    
    return portable_dataset

# Define the transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Create dataset directory
datasets_dir = os.path.join(os.getcwd(), "datasets")
os.makedirs(datasets_dir, exist_ok=True)

# SSL dataset creation
ssl_dataset_path = os.path.join(datasets_dir, "ssl-dataset-portable.pt")
if os.path.exists(ssl_dataset_path):
    print("SSL portable dataset already exists. Skipping creation.")
else:
    print("Creating portable SSL dataset...")
    
    # Download datasets from Kaggle
    try:
        ds1_path = kagglehub.dataset_download("soumikrakshit/anime-faces")
        ds2_path = kagglehub.dataset_download("stevenevan99/face-of-pixiv-top-daily-illustration-2020")
        ds3_path = kagglehub.dataset_download("hirunkulphimsiri/fullbody-anime-girls-datasets")
        
        # Create SSL dataset
        create_portable_dataset(
            [ds1_path, ds2_path, ds3_path],
            ssl_dataset_path,
            transform=transform
        )
    except Exception as e:
        print(f"Error creating SSL dataset: {e}")

class PortableDataset(torch.utils.data.Dataset):
    """Dataset class for loading preprocessed portable datasets."""
    def __init__(self, file_path):
        self.data = torch.load(file_path)
        self.images = self.data['images']
        self.labels = self.data['labels']
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

# Load SSL dataset
if os.path.exists(ssl_dataset_path):
    ssl_dataset = PortableDataset(ssl_dataset_path)
    print(f"SSL dataset loaded with {len(ssl_dataset)} images")
else:
    print("SSL dataset not found. Please run the dataset creation block first.")
    ssl_dataset = None

def visualize_dataset(dataset, num_images=5, title=None):
    """Visualize random samples from a dataset."""
    if dataset is None:
        print("Dataset is not available for visualization.")
        return
        
    indices = np.random.choice(len(dataset), num_images, replace=False)
    images = [dataset[i][0] for i in indices]
    labels = [dataset[i][1] for i in indices]

    fig, axes = plt.subplots(1, num_images, figsize=(15, 5))
    for ax, img, label in zip(axes, images, labels):
        # Denormalize image for visualization
        img_np = img.numpy().transpose(1, 2, 0)
        mean = np.array(IMAGENET_MEAN)
        std = np.array(IMAGENET_STD)
        img_np = std * img_np + mean
        img_np = np.clip(img_np, 0, 1)
        
        ax.imshow(img_np)
        ax.axis('off')
        ax.set_title(f"Label: {label.item()}")
    
    if title:
        fig.suptitle(title)
    plt.tight_layout()
    plt.show()

# Visualize samples from SSL dataset
if ssl_dataset:
    visualize_dataset(ssl_dataset, num_images=10, title="SSL Dataset Samples")

## 4. Dataset retrieval and transformation for Supervised fine-tuning

In [None]:
from google.colab import drive # Mount Google Drive to access dataset
drive.mount('/content/gdrive/', force_remount=True)

In [None]:
import zipfile
import shutil

# Fine-tuning dataset creation
finetune_dataset_path = os.path.join(datasets_dir, "finetune-dataset-portable.pt")
if os.path.exists(finetune_dataset_path):
    print("Fine-tuning portable dataset already exists. Skipping creation.")
else:
    print("Creating portable fine-tuning dataset...")
    
    dataset_path_compressed = "/content/gdrive/MyDrive/GenshinImageClassifier/dataset.zip"
    if not os.path.exists(dataset_path_compressed):
        print(f"Dataset file at {dataset_path_compressed} does not exist. Please download it or update the path.")
    else:
        # Prepare temp directory
        tmp_dir = os.path.join(os.getcwd(), "tmp")
        os.makedirs(tmp_dir, exist_ok=True)
        
        # Copy and extract dataset
        print(f"Copying dataset file to tmp directory...")
        shutil.copy(dataset_path_compressed, os.path.join(tmp_dir, "dataset.zip"))
        
        # Unzip the dataset
        with zipfile.ZipFile(os.path.join(tmp_dir, "dataset.zip"), 'r') as zip_ref:
            zip_ref.extractall(tmp_dir)
        
        # Create portable dataset
        dataset_path = os.path.join(tmp_dir, "dataset")
        create_portable_dataset(
            [dataset_path],
            finetune_dataset_path,
            transform=transform
        )

# Load fine-tuning dataset
if os.path.exists(finetune_dataset_path):
    dataset = PortableDataset(finetune_dataset_path)
    print(f"Fine-tuning dataset loaded with {len(dataset)} images")
else:
    print("Fine-tuning dataset not found. Please run the dataset creation block first.")
    dataset = None

# Visualize samples from fine-tuning dataset
if dataset:
    visualize_dataset(dataset, num_images=10, title="Fine-tuning Dataset Samples")

## 5. Models training, evaluation and feature extraction/visualization
- Train the baseline model on the final dataset.
- Train the self-supervised model on the unlabeled dataset.
- Fine-tune the self-supervised model on the final dataset using supervised training.
- Fine-tune the self-supervised model on the final dataset using semi-supervised training.
- Fine-tune the ImageNet pre-trained model on the final dataset using supervised training.

### 5.1. Baseline Model

In [None]:
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import copy

EPOCHS = 200
BATCH_SIZE = 64
LEARNING_RATE = 0.001
K_FOLDS = 5

# Initialize 5-fold cross-validation
kfold = KFold(n_splits=K_FOLDS, shuffle=True, random_state=42)

# Storage for results across folds
fold_results = {
    'train_losses': [],
    'train_accuracies': [],
    'val_losses': [],
    'val_accuracies': [],
    'test_metrics': []
}

# Get all indices for the dataset
dataset_size = len(dataset)
indices = list(range(dataset_size))

print(f"Starting {K_FOLDS}-fold cross-validation on {dataset_size} samples...")

for fold, (train_idx, val_idx) in enumerate(kfold.split(indices)):
    print(f"\n{'='*50}")
    print(f"FOLD {fold + 1}/{K_FOLDS}")
    print(f"{'='*50}")

    # Create data samplers for train and validation
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, sampler=train_sampler
    )
    val_loader = torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, sampler=val_sampler
    )

    # Initialize model for this fold
    model = models.resnet18(weights=None, num_classes=6).to(device)

    # Define loss function and optimizer
    loss_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Storage for this fold's training history
    fold_train_losses = []
    fold_train_accuracies = []
    fold_val_losses = []
    fold_val_accuracies = []

    best_val_accuracy = 0.0
    best_model_state = None

    # Training loop for this fold
    for epoch in range(EPOCHS):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_train_loss = running_loss / len(train_idx)
        epoch_train_accuracy = 100 * correct / total

        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = loss_criterion(outputs, labels)

                val_running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        epoch_val_loss = val_running_loss / len(val_idx)
        epoch_val_accuracy = 100 * val_correct / val_total

        # Store metrics
        fold_train_losses.append(epoch_train_loss)
        fold_train_accuracies.append(epoch_train_accuracy)
        fold_val_losses.append(epoch_val_loss)
        fold_val_accuracies.append(epoch_val_accuracy)

        # Save best model
        if epoch_val_accuracy > best_val_accuracy:
            best_val_accuracy = epoch_val_accuracy
            best_model_state = copy.deepcopy(model.state_dict())

        # Print progress every 50 epochs
        if (epoch + 1) % 50 == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}]")
            print(f"  Train - Loss: {epoch_train_loss:.4f}, Acc: {epoch_train_accuracy:.2f}%")
            print(f"  Val   - Loss: {epoch_val_loss:.4f}, Acc: {epoch_val_accuracy:.2f}%")

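    # Load best model for final evaluation (skip if no epoch improved on 0% accuracy)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Final evaluation on the validation fold (used as the test set for this fold).
    # Note: the checkpoint was selected on this same fold, so these metrics are optimistic.
    model.eval()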
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate detailed metrics for this fold
    fold_accuracy = accuracy_score(all_labels, all_predictions)
    fold_f1 = f1_score(all_labels, all_predictions, average='macro')
    fold_precision = precision_score(all_labels, all_predictions, average='macro')
    fold_recall = recall_score(all_labels, all_predictions, average='macro')

    # Store fold results
    fold_results['train_losses'].append(fold_train_losses)
    fold_results['train_accuracies'].append(fold_train_accuracies)
    fold_results['val_losses'].append(fold_val_losses)
    fold_results['val_accuracies'].append(fold_val_accuracies)
    fold_results['test_metrics'].append({
        'accuracy': fold_accuracy,
        'f1': fold_f1,
        'precision': fold_precision,
        'recall': fold_recall
    })

    print(f"\nFold {fold + 1} Results:")
    print(f"  Best Validation Accuracy: {best_val_accuracy:.2f}%")
    print(f"  Test Accuracy: {fold_accuracy*100:.2f}%")
    print(f"  Test F1-Score: {fold_f1:.4f}")
    print(f"  Test Precision: {fold_precision:.4f}")
    print(f"  Test Recall: {fold_recall:.4f}")

# Calculate and display overall results
print(f"\n{'='*60}")
print("CROSS-VALIDATION RESULTS SUMMARY")
print(f"{'='*60}")

# Extract test metrics
test_accuracies = [fold['accuracy'] for fold in fold_results['test_metrics']]
test_f1_scores = [fold['f1'] for fold in fold_results['test_metrics']]
test_precisions = [fold['precision'] for fold in fold_results['test_metrics']]
test_recalls = [fold['recall'] for fold in fold_results['test_metrics']]

# Calculate statistics
print(f"Test Accuracy:  {np.mean(test_accuracies)*100:.2f}% ± {np.std(test_accuracies)*100:.2f}%")
print(f"Test F1-Score:  {np.mean(test_f1_scores):.4f} ± {np.std(test_f1_scores):.4f}")
print(f"Test Precision: {np.mean(test_precisions):.4f} ± {np.std(test_precisions):.4f}")
print(f"Test Recall:    {np.mean(test_recalls):.4f} ± {np.std(test_recalls):.4f}")

# Plot training curves for all folds
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Average training loss
avg_train_losses = np.mean(fold_results['train_losses'], axis=0)
std_train_losses = np.std(fold_results['train_losses'], axis=0)
epochs_range = range(1, EPOCHS + 1)

axes[0, 0].plot(epochs_range, avg_train_losses, 'b-', label='Mean Training Loss')
axes[0, 0].fill_between(epochs_range,
                       avg_train_losses - std_train_losses,
                       avg_train_losses + std_train_losses,
                       alpha=0.3, color='blue')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Average Training Loss Across Folds')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Average validation loss
avg_val_losses = np.mean(fold_results['val_losses'], axis=0)
std_val_losses = np.std(fold_results['val_losses'], axis=0)

axes[0, 1].plot(epochs_range, avg_val_losses, 'r-', label='Mean Validation Loss')
axes[0, 1].fill_between(epochs_range,
                       avg_val_losses - std_val_losses,
                       avg_val_losses + std_val_losses,
                       alpha=0.3, color='red')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].set_title('Average Validation Loss Across Folds')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Average training accuracy
avg_train_acc = np.mean(fold_results['train_accuracies'], axis=0)
std_train_acc = np.std(fold_results['train_accuracies'], axis=0)

axes[1, 0].plot(epochs_range, avg_train_acc, 'g-', label='Mean Training Accuracy')
axes[1, 0].fill_between(epochs_range,
                       avg_train_acc - std_train_acc,
                       avg_train_acc + std_train_acc,
                       alpha=0.3, color='green')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Accuracy (%)')
axes[1, 0].set_title('Average Training Accuracy Across Folds')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Average validation accuracy
avg_val_acc = np.mean(fold_results['val_accuracies'], axis=0)
std_val_acc = np.std(fold_results['val_accuracies'], axis=0)

axes[1, 1].plot(epochs_range, avg_val_acc, 'm-', label='Mean Validation Accuracy')
axes[1, 1].fill_between(epochs_range,
                       avg_val_acc - std_val_acc,
                       avg_val_acc + std_val_acc,
                       alpha=0.3, color='magenta')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Accuracy (%)')
axes[1, 1].set_title('Average Validation Accuracy Across Folds')
axes[1, 1].legend()
axes[1, 1].grid(True)

plt.tight_layout()
plt.show()

# Bar plot for test metrics comparison
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
metrics_names = ['Accuracy', 'F1-Score', 'Precision', 'Recall']
metrics_means = [np.mean(test_accuracies)*100, np.mean(test_f1_scores)*100,
                np.mean(test_precisions)*100, np.mean(test_recalls)*100]
metrics_stds = [np.std(test_accuracies)*100, np.std(test_f1_scores)*100,
               np.std(test_precisions)*100, np.std(test_recalls)*100]

x_pos = np.arange(len(metrics_names))
bars = ax.bar(x_pos, metrics_means, yerr=metrics_stds, capsize=5,
              color=['skyblue', 'lightgreen', 'lightcoral', 'lightsalmon'])

ax.set_xlabel('Metrics')
ax.set_ylabel('Score (%)')
ax.set_title('Cross-Validation Test Metrics (Mean ± Std)')
ax.set_xticks(x_pos)
ax.set_xticklabels(metrics_names)
ax.grid(True, alpha=0.3)

# Add value labels on bars
for bar, mean_val, std_val in zip(bars, metrics_means, metrics_stds):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + std_val + 0.5,
            f'{mean_val:.1f}±{std_val:.1f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()
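
The evaluation above reports top-1 metrics only, while the overview also lists top-5 accuracy. A minimal sketch of a top-k accuracy helper is given below. The function name `topk_accuracy` is an assumption, and it expects raw logits of shape (N, num_classes).

```python
import torch


def topk_accuracy(logits: torch.Tensor, targets: torch.Tensor, k: int = 5) -> float:
    """Fraction of samples whose true class is among the k highest-scoring classes."""
    k = min(k, logits.size(1))
    topk = logits.topk(k, dim=1).indices               # (N, k) predicted class indices
    hits = (topk == targets.unsqueeze(1)).any(dim=1)   # (N,) True if target is in the top k
    return hits.float().mean().item()
```

With only 6 classes, a top-5 prediction is wrong only when the true class is ranked last, so a random classifier already reaches about 83% top-5 accuracy. The metric is therefore much less discriminative here than top-1 accuracy.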