In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.io as sio
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, cohen_kappa_score, confusion_matrix

## Co-ordinate Attention Module:

In [2]:
class CoordinateAttention(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=32):
        super(CoordinateAttention, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.reduce_channels = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1)
        self.bn = nn.BatchNorm2d(in_channels // reduction)
        self.act = nn.ReLU(inplace=True)
        self.fc_h = nn.Conv2d(in_channels // reduction, out_channels, kernel_size=1)
        self.fc_w = nn.Conv2d(in_channels // reduction, out_channels, kernel_size=1)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()

        # Apply horizontal pooling and transformations
        x_h = self.pool_h(x).permute(0, 1, 3, 2)
        x_h = self.reduce_channels(x_h).permute(0, 1, 3, 2)
        x_h = self.bn(x_h)
        x_h = self.fc_h(self.act(x_h))
        x_h = x_h.sigmoid()

        # Apply vertical pooling and transformations
        x_w = self.pool_w(x)
        x_w = self.reduce_channels(x_w)
        x_w = self.bn(x_w)
        x_w = self.fc_w(self.act(x_w))
        x_w = x_w.sigmoid()

        # Fusion of coordinate attention across height and width
        out = identity * x_h * x_w
        return out

## Multi-Scale Fusion Network:

In [3]:
class MultiScaleFusionNetwork(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(MultiScaleFusionNetwork, self).__init__()

        # Convolutions for different patch scales
        self.conv_17 = nn.Conv3d(in_channels, 64, kernel_size=(1, 1, 17), padding=(0, 0, 8))
        self.conv_19 = nn.Conv3d(in_channels, 64, kernel_size=(1, 1, 19), padding=(0, 0, 9))
        self.conv_21 = nn.Conv3d(in_channels, 64, kernel_size=(1, 1, 21), padding=(0, 0, 10))

        # Batch Normalization and Activation
        self.bn_17 = nn.BatchNorm3d(64)
        self.bn_19 = nn.BatchNorm3d(64)
        self.bn_21 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU()

        # Fusion layer to concatenate and combine the multi-scale features
        self.fusion = nn.Conv3d(192, 128, kernel_size=1)  # 64 + 64 + 64 = 192 channels after concatenation
        self.fusion_bn = nn.BatchNorm3d(128)
        
        # Fully connected layers for classification
        self.fc = nn.Sequential(
            nn.Linear(128 * 145 * 145, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        # Apply convolutions for different scales
        x_17 = self.relu(self.bn_17(self.conv_17(x)))
        x_19 = self.relu(self.bn_19(self.conv_19(x)))
        x_21 = self.relu(self.bn_21(self.conv_21(x)))

        # Concatenate the features from all scales
        fused_features = torch.cat([x_17, x_19, x_21], dim=1)

        # Apply fusion convolution and batch normalization
        fused_features = self.relu(self.fusion_bn(self.fusion(fused_features)))

        # Flatten the features and apply the final classification layer
        fused_features = fused_features.view(fused_features.size(0), -1)
        out = self.fc(fused_features)

        return out

## Model with Attention and Multi-scale Fusion:

In [4]:
class HyperspectralClassificationModel(nn.Module):
    def __init__(self, in_channels, num_classes, use_attention=True, use_multi_scale_fusion=True):
        super(HyperspectralClassificationModel, self).__init__()

        # First few layers: simple 3D convolutions
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(64)
        self.bn2 = nn.BatchNorm3d(128)
        self.relu = nn.ReLU()

        # Optional Coordinate Attention module
        self.use_attention = use_attention
        if self.use_attention:
            self.coord_attention = CoordinateAttention(128, 128)

        # Optional Multi-Scale Fusion module
        self.use_multi_scale_fusion = use_multi_scale_fusion
        if self.use_multi_scale_fusion:
            self.multi_scale_fusion = MultiScaleFusionNetwork(in_channels=128, num_classes=num_classes)

        # Fully connected layers for final classification
        self.fc = nn.Sequential(
            nn.Linear(128 * 145 * 145, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        # Initial convolutions
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        # Coordinate Attention (if enabled)
        if self.use_attention:
            x = self.coord_attention(x)

        # Multi-Scale Fusion (if enabled)
        if self.use_multi_scale_fusion:
            x = self.multi_scale_fusion(x)
        else:
            # Flatten the features and apply the final classification layer directly
            x = x.view(x.size(0), -1)
            x = self.fc(x)

        return x

### Dataset Splitting and Training and Testing: 

In [5]:
# Custom Dataset for Hyperspectral Images with Patches
class HyperspectralDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data.astype(np.float16)  # Ensure data is of type float32
        self.labels = labels.astype(np.int32)  # Ensure labels are of type int64

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        patch = self.data[idx]  # Get the patch of data at index idx
        label = self.labels[idx]  # Get the label at index idx
        return torch.tensor(patch, dtype=torch.float16), torch.tensor(label, dtype=torch.long)

In [6]:
# Function to extract patches from hyperspectral images
def extract_patches(img, labels, patch_size):
    margin = patch_size // 2
    img_padded = np.pad(img, [(margin, margin), (margin, margin), (0, 0)], mode='constant')
    
    patches, patch_labels = [], []
    for i in range(margin, img.shape[0] - margin):
        for j in range(margin, img.shape[1] - margin):
            if labels[i, j] != 0:  # Exclude background
                patch = img_padded[i-margin:i+margin+1, j-margin:j+margin+1, :]
                patches.append(patch)
                patch_labels.append(labels[i, j] - 1)  # Convert to zero-based index
    return np.array(patches), np.array(patch_labels)

In [7]:
# Function to calculate metrics
def calculate_metrics(y_true, y_pred, num_classes):
    cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))
    OA = accuracy_score(y_true, y_pred)
    class_acc = cm.diagonal() / cm.sum(axis=1)
    class_acc[np.isnan(class_acc)] = 0  # Handle divide by zero for empty classes
    AA = np.mean(class_acc)
    Kappa = cohen_kappa_score(y_true, y_pred)
    return {"OA": OA, "AA": AA, "Kappa": Kappa}

In [8]:
# Function to split Indian Pines data as per the provided class-wise splits
def load_indian_pines_data_with_splits(patch_size):
    # Load the Indian Pines dataset
    data = sio.loadmat('Indian_pines_corrected.mat')['indian_pines_corrected']
    labels = sio.loadmat('Indian_pines_gt.mat')['indian_pines_gt']
    
    # Define the per-class splits (Train, Val, Test)
    splits = {
        0: (8, 4, 34),
        1: (105, 105, 1218),
        2: (100, 100, 630),
        3: (50, 50, 130),
        4: (80, 80, 323),
        5: (100, 100, 530),
        6: (10, 4, 14),
        7: (40, 40, 398),
        8: (8, 4, 8),
        9: (110, 110, 752),
        10: (110, 110, 2235),
        11: (100, 100, 393),
        12: (40, 40, 125),
        13: (100, 100, 1065),
        14: (80, 80, 226),
        15: (20, 10, 63)
    }
    
    patches_train, labels_train = [], []
    patches_val, labels_val = [], []
    patches_test, labels_test = [], []
    
    # Extract patches and split into train, validation, and test sets
    for class_label, (train_count, val_count, test_count) in splits.items():
        class_indices = np.where(labels == class_label + 1)  # One-based indexing in labels
        class_samples = list(zip(class_indices[0], class_indices[1]))
        np.random.shuffle(class_samples)  # Shuffle indices for randomness
        
        for idx, (i, j) in enumerate(class_samples):
            patch = extract_patches(data, labels, patch_size)[0]  # Extract patch
            
            if idx < train_count:
                patches_train.append(patch)
                labels_train.append(class_label)
            elif idx < train_count + val_count:
                patches_val.append(patch)
                labels_val.append(class_label)
            else:
                patches_test.append(patch)
                labels_test.append(class_label)
    
    # Convert lists to arrays
    patches_train, labels_train = np.array(patches_train), np.array(labels_train)
    patches_val, labels_val = np.array(patches_val), np.array(labels_val)
    patches_test, labels_test = np.array(patches_test), np.array(labels_test)
    
    # Create datasets and dataloaders
    train_dataset = HyperspectralDataset(patches_train, labels_train)
    val_dataset = HyperspectralDataset(patches_val, labels_val)
    test_dataset = HyperspectralDataset(patches_test, labels_test)
    
    train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)
    
    return train_loader, val_loader, test_loader, labels_test

In [9]:
# Training function (similar to previous code)
def train_model(train_loader, val_loader, in_channels, num_classes, epochs=20, lr=0.001):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = HyperspectralClassificationModel(in_channels=in_channels, num_classes=num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.permute(0, 3, 1, 2).to(device), labels.to(device)  # [B, C, H, W]
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # Validation step
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.permute(0, 3, 1, 2).to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        val_accuracy = correct / total
        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {running_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}, Val Accuracy: {val_accuracy:.4f}")
    
    return model

In [10]:
# Testing the model and calculating metrics
def test_model(model, test_loader, y_test, num_classes):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    
    y_pred = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.permute(0, 3, 1, 2).to(device)  # [B, C, H, W]
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            y_pred.extend(predicted.cpu().numpy())

    metrics = calculate_metrics(y_test, y_pred, num_classes)
    return metrics

In [11]:
if __name__ == "__main__":
    patch_sizes = [17, 19, 21]  # List of window sizes
    results = {}  # Dictionary to store results for each model and patch size

    for patch_size in patch_sizes:
        print(f"Evaluating window size: {patch_size}x{patch_size}")

        train_loader, val_loader, test_loader, y_test = load_indian_pines_data_with_splits(patch_size)

        # Initialize models for training
        model1 = HyperspectralClassificationModel(
            in_channels=200, 
            num_classes=16, 
            use_attention=False, 
            use_multi_scale_fusion=False
        )

        model2 = HyperspectralClassificationModel(
            in_channels=200, 
            num_classes=16, 
            use_attention=True, 
            use_multi_scale_fusion=False
        )

        model3 = HyperspectralClassificationModel(
            in_channels=200, 
            num_classes=16, 
            use_attention=True, 
            use_multi_scale_fusion=True
        )

        # Train model 1
        train_loss1, val_loss1, val_accuracy1 = train_model(train_loader, val_loader, model1, epochs=100, lr=3e-4)
        results[f'Model 1 - {patch_size}x{patch_size}'] = {
            'Train Loss': train_loss1,
            'Validation Loss': val_loss1,
            'Validation Accuracy': val_accuracy1
        }

        # Train model 2
        train_loss2, val_loss2, val_accuracy2 = train_model(train_loader, val_loader, model2, epochs=100, lr=3e-4)
        results[f'Model 2 - {patch_size}x{patch_size}'] = {
            'Train Loss': train_loss2,
            'Validation Loss': val_loss2,
            'Validation Accuracy': val_accuracy2
        }

        # Train model 3
        train_loss3, val_loss3, val_accuracy3 = train_model(train_loader, val_loader, model3, epochs=100, lr=3e-4)
        results[f'Model 3 - {patch_size}x{patch_size}'] = {
            'Train Loss': train_loss3,
            'Validation Loss': val_loss3,
            'Validation Accuracy': val_accuracy3
        }

        # Print the results for all models for the current patch size
        for model_name, metrics in results.items():
            if patch_size in model_name:  # Print only the results for the current patch size
                print(f"{model_name}:")
                print(f"  Train Loss: {metrics['Train Loss']:.4f}")
                print(f"  Validation Loss: {metrics['Validation Loss']:.4f}")
                print(f"  Validation Accuracy: {metrics['Validation Accuracy']:.4f}")

        # Test each model and calculate metrics
        for model_key in ['Model 1', 'Model 2', 'Model 3']:
            model = eval(model_key.replace(" ", "").lower())  # Use model1, model2, model3
            metrics = test_model(model, test_loader, y_test, 16)
            print(f"Metrics for {model_key} at window size {patch_size}x{patch_size}: {metrics}")


Evaluating window size: 17x17


MemoryError: Unable to allocate 1005. MiB for an array with shape (9119, 17, 17, 200) and data type uint16