**SSWIN-AD Alzheimer's Disease Classification**
- Contrast Limited Adaptive Histogram Equalization (CLAHE)
- Spectral Swin Transformer (SSWIN)
- Adaptive Contrast-aware Fine-Tuning (ACaFT)
- CNN Ensemble (AlexNet, GoogLeNet, ResNet-18)

## Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import cv2

## Hyperparameters

In [2]:
# Shared training hyperparameters, read by the data-loading, model and
# fine-tuning cells further down the notebook.
num_classes = 4        # output size of the replaced CNN classification heads
batch_size = 32        # images per DataLoader batch
num_epochs = 100       # passes over the training set per fine-tuned model
learning_rate = 1e-4   # SGD step size used during fine-tuning

## CLAHE Implementation

In [None]:
class CLAHETransform:
    """Apply CLAHE to an image as a torchvision-style callable transform.

    Colour images are converted to the LAB colour space and CLAHE is applied to
    the lightness (L) channel only, so the a/b colour channels are untouched.
    Single-channel images are equalised directly.

    Args:
        clipLimit: Contrast clipping threshold passed to ``cv2.createCLAHE``.
        tileGridSize: Tile grid size passed to ``cv2.createCLAHE``.
    """

    def __init__(self, clipLimit=2.0, tileGridSize=(8, 8)):
        self.clipLimit = clipLimit
        self.tileGridSize = tileGridSize

    def __repr__(self):
        return (f"{type(self).__name__}(clipLimit={self.clipLimit}, "
                f"tileGridSize={self.tileGridSize})")

    def __call__(self, img):
        """Return the contrast-equalised image as a ``numpy`` array.

        Args:
            img: A PIL image or array-like. Colour input is expected as RGB
                ``(H, W, 3)``; grayscale input may be ``(H, W)`` or ``(H, W, 1)``.
                OpenCV's CLAHE accepts 8- or 16-bit data only (presumably uint8
                here, from the PIL loader).

        Returns:
            ``numpy.ndarray`` with the same shape as the input.
        """
        img = np.array(img)
        # Built per call rather than stored in __init__: cv2 objects are not
        # picklable, and DataLoader worker processes may pickle the transform.
        clahe = cv2.createCLAHE(clipLimit=self.clipLimit, tileGridSize=self.tileGridSize)

        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            # Grayscale: there is no colour to preserve, equalise intensities directly.
            return clahe.apply(img.reshape(img.shape[:2])).reshape(img.shape)

        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        lab = cv2.merge((clahe.apply(l), a, b))
        return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

## Data Loading and Preprocessing

In [None]:
# Shared preprocessing: resize to 224x224, CLAHE on the lightness channel,
# then tensor conversion and normalisation with the ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        CLAHETransform(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Each split root holds one sub-directory per class; both splits share the
# same (non-augmenting) transform.
train_dataset, test_dataset = (
    datasets.ImageFolder(root, transform=transform)
    for root in ('datasets/train', 'datasets/test')
)

# Only the training split is shuffled, so evaluation order is deterministic.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

## Spectral Swin Transformer Implementation

In [None]:
class SpectralSwinTransformer(nn.Module):
    """Swin-V2-Tiny classifier applied to the Fourier magnitude spectrum of its input.

    The input batch is mapped to its centred 2-D FFT magnitude spectrum. That
    spectrum is passed through an ImageNet-pretrained ``swin_v2_t`` with 1000
    outputs. A linear layer then projects the result to ``num_classes`` logits.

    Args:
        embed_dim: Currently unused; kept so existing call sites stay valid.
        num_classes: Number of output logits.
    """

    def __init__(self, embed_dim=96, num_classes=4):
        super().__init__()
        self.transformer = models.swin_v2_t(weights='DEFAULT')
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, x):
        """Return class logits for an image batch ``x``.

        ``x`` is assumed to be ``(batch, channels, H, W)`` (TODO: confirm for
        callers other than the 3x224x224 loaders in this notebook). The last
        dimension of the result has size ``num_classes``.
        """
        x = self.transform_to_spectral(x)
        # The backbone is registered as ``self.transformer``; there is no ``self.model``.
        x = self.transformer(x)
        return self.fc(x)

    def transform_to_spectral(self, x):
        """Return the centred 2-D FFT magnitude spectrum of ``x`` (same shape as ``x``).

        NOTE(review): magnitudes are not log-scaled, and the zero-frequency term
        grows with H*W. Confirm whether the pretrained backbone needs rescaled input.
        """
        x_fft = torch.fft.fft2(x)  # FFT over the last two (spatial) dimensions
        # Shift only the spatial dims so the zero-frequency term sits at the
        # image centre. The default ``dim=None`` would also roll the batch and
        # channel axes, mixing different samples of the batch together.
        x_fft_shifted = torch.fft.fftshift(x_fft, dim=(-2, -1))
        return torch.abs(x_fft_shifted)

# Size the Swin head from the shared hyperparameter rather than the
# constructor default, so the two cannot drift apart.
sswin_model = SpectralSwinTransformer(num_classes=num_classes)

## CNN Model Initialization

In [None]:
# ImageNet-pretrained backbones. ``weights='DEFAULT'`` matches the Swin cell and
# replaces the deprecated ``pretrained=True`` flag; it selects the same weights.
alexnet = models.alexnet(weights='DEFAULT')
googlenet = models.googlenet(weights='DEFAULT')
resnet18 = models.resnet18(weights='DEFAULT')

# Replace each pretrained classification head with a ``num_classes``-way layer.
alexnet.classifier[6] = nn.Linear(alexnet.classifier[6].in_features, num_classes)
googlenet.fc = nn.Linear(googlenet.fc.in_features, num_classes)
resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)

models_list = [alexnet, googlenet, resnet18]

## Adaptive Contrast-aware Fine-Tuning (ACaFT)

In [None]:
def acaft_fine_tune(model, loader, epochs=None, lr=None, spectral_fn=None):
    """Fine-tune the trainable parameters of ``model`` on spectral inputs.

    Each batch of images is mapped through ``spectral_fn`` and fed to ``model``.
    Only parameters with ``requires_grad=True`` are updated, using SGD
    (momentum 0.9) on a cross-entropy loss. The mean loss of each epoch is
    printed, and the model is left in training mode.

    NOTE(review): despite the "adaptive contrast-aware" name, no contrast-adaptive
    logic is visible here; this is plain SGD fine-tuning.

    Args:
        model: Classifier whose trainable parameters are updated in place.
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        epochs: Number of passes over ``loader``; defaults to ``num_epochs``.
        lr: SGD learning rate; defaults to ``learning_rate``.
        spectral_fn: Callable mapping an image batch to the model input; defaults
            to ``sswin_model.transform_to_spectral``.
    """
    if epochs is None:
        epochs = num_epochs
    if lr is None:
        lr = learning_rate
    if spectral_fn is None:
        # ``spectral_swin_transformer`` was never defined. The Swin classifier's
        # logits also cannot feed image CNNs, so use its image-shaped spectral
        # transform instead. NOTE(review): confirm this is the intended pipeline.
        spectral_fn = sswin_model.transform_to_spectral

    criterion = nn.CrossEntropyLoss()
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable, lr=lr, momentum=0.9)
    model.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        for images, labels in loader:
            outputs = model(spectral_fn(images))
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # An empty loader would otherwise raise ZeroDivisionError here.
        mean_loss = epoch_loss / max(len(loader), 1)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}")

## Fine-tune Models

In [None]:
# Freeze every backbone, then re-enable gradients only for its last 10
# parameter tensors (in ``parameters()`` order) before fine-tuning it.
for model in models_list:
    model.requires_grad_(False)
    for param in list(model.parameters())[-10:]:
        param.requires_grad_(True)
    acaft_fine_tune(model, train_loader)

## Ensemble Evaluation

In [None]:
# Soft-voting ensemble: average the softmax probabilities of the fine-tuned
# CNNs and take the arg-max class for each test image.
ensemble_preds, ensemble_labels = [], []

# Fine-tuning leaves the models in training mode. Switch to inference mode so
# dropout is disabled and BatchNorm uses its running statistics.
for model in models_list:
    model.eval()

with torch.no_grad():
    for images, labels in test_loader:
        # ``spectral_swin_transformer`` was never defined, and the Swin
        # classifier's logits could not be fed to image CNNs. Use the
        # image-shaped spectral transform instead.
        # NOTE(review): this must match the input used during fine-tuning.
        spectral_features = sswin_model.transform_to_spectral(images)
        outputs = [torch.softmax(model(spectral_features), dim=1) for model in models_list]
        ensemble_output = torch.mean(torch.stack(outputs), dim=0)
        pred_labels = torch.argmax(ensemble_output, dim=1)

        ensemble_preds.extend(pred_labels.cpu().numpy())
        ensemble_labels.extend(labels.cpu().numpy())

## Performance Metrics

In [None]:
# 'weighted' averaging weights each class's score by its support, so the
# summary reflects the class balance of the test split.
accuracy = accuracy_score(ensemble_labels, ensemble_preds)
precision, recall, f1 = (
    scorer(ensemble_labels, ensemble_preds, average='weighted')
    for scorer in (precision_score, recall_score, f1_score)
)

print("Model Performance Metrics:")
print(f"Accuracy: {accuracy * 100:.2f}%")
print(f"Precision: {precision * 100:.2f}%")
print(f"Recall: {recall * 100:.2f}%")
print(f"F1 Score: {f1 * 100:.2f}%")