<a href="https://colab.research.google.com/github/Judyxyang/judyxyang/blob/master/HSi_UH2013_P7_AB_VIM_V3_6_0330.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# HyperMamba Model

In [1]:
pip install spectral mat73  einops

Collecting spectral
  Downloading spectral-0.23.1-py3-none-any.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.9/212.9 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting mat73
  Downloading mat73-0.63-py3-none-any.whl (19 kB)
Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: spectral, einops, mat73
Successfully installed einops-0.7.0 mat73-0.63 spectral-0.23.1


In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import math

from einops import rearrange
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from scipy import io
import torch.utils.data
import scipy.io as sio
import mat73
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 0 Upload Data

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
! ls '/content/drive/MyDrive/A02_RemoteSensingData/UHS_2013_DFTC/'

 2013_DFTC
 2013_IEEE_GRSS_DF_Contest_CASI_349_1905_144.mat
 2013_IEEE_GRSS_DF_Contest_LiDAR.mat
 ablationp5_spatialonly_UH2013_model_state_dict.pth
 ablationp7_removeZpath_Uh2013_model_state_dict.pth
 ablationp7_removeZpath_UH2013_model_state_dict.pth
 ablationp7_spatialonly_Uh2013_model_state_dict.pth
 ablationp7_spatialonly_UH2013_model_state_dict.pth
 ablationp7_Uh2013_model_state_dict.pth
 ablationp7_UH2013_model_state_dict.pth
 ablation_Uh2013_model_state_dict.pth
 ablation_UH2013_model_state_dict.pth
 Autoencodermodel.pth
 Autoencodermodel_uh2013_20Ksample.pth
 Autoencodermodel_uh2013_adam_20Ksample.pth
 Autoencodermodel_uh2013_adamp13_20Ksample.pth
 Autoencodermodel_uh2013_adamp9_20Ksample.pth
 Autoencodermodel_uh2013_admp3fuseddata_50Ksample.pth
 Autoencodermodel_uh2013_admp5fuseddata_50Ksample.pth
 Autoencodermodel_uh2013_admp7fuseddata_50Ksample.pth
 Autoencodermodel_uh2013.pth
 Autoencodermodel_uh2013_rms_20Ksample.pth
 Autoencodermodel_uh2013_sgdp11fuseddata_20Ksample.pth


In [5]:
# # Define the path
path='/content/drive/MyDrive/A02_RemoteSensingData/UHS_2013_DFTC/'

In [6]:
# 2.1 Loads Data
# Load hyperpsectral data
hsi_2013_data=sio.loadmat(path+'2013_IEEE_GRSS_DF_Contest_CASI_349_1905_144.mat')['ans']
print('hsi_2013_data shape:', hsi_2013_data.shape)

# Loader Lidar  data
import mat73
lidar_2013_data = sio.loadmat(path+'2013_IEEE_GRSS_DF_Contest_LiDAR.mat')['LiDAR_data']

print('Lidar_2013_data shape:', lidar_2013_data.shape)

#Load ground truth labels
gt_2013_data=sio.loadmat(path+'GRSS2013.mat')['name']
print('gt_2013_data.shape:', gt_2013_data.shape)

hsi_2013_data shape: (349, 1905, 144)
Lidar_2013_data shape: (349, 1905, 1)
gt_2013_data.shape: (349, 1905)


#1.0  Model Building

In [7]:
# Configuration class
class Config:
    def __init__(self, in_channels, num_patches, kernel_size, patch_size, emb_size, dim, depth, heads, dim_head, mlp_ratio, num_classes, dropout, pos_emb_size, class_emb_size, stride, output_dim):  # Set default output_dim to 1
        self.in_channels = in_channels
        self.num_patches = num_patches
        self.kernel_size = kernel_size
        self.patch_size = patch_size
        self.emb_size = emb_size
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.dim_head = dim_head
        self.mlp_ratio = mlp_ratio
        self.num_classes = num_classes
        self.dropout = dropout
        self.pos_emb_size = pos_emb_size
        self.class_emb_size = class_emb_size
        self.stride = stride
        self.output_dim = output_dim  # Ensure output_dim is a part of the config

### 1.1 Full Architecture Of Forward backward Processing

In [8]:
# Version 2.0 This involves reversing the input tensor for the backward path before applying the backward_conv1d operation
import torch
import torch.nn as nn
import torch.nn.functional as F

class HSIVimBlock(nn.Module):
    def __init__(self, spatial_dim, num_bands, hidden_dim, output_dim, delta_param_init):
        super(HSIVimBlock, self).__init__()
        # Initialization with self.hidden_dim
        self.spatial_dim = spatial_dim
        self.num_bands = num_bands
        self.hidden_dim = hidden_dim

        # LayerNorm is now expecting a flattened feature vector of Bands*H*W elements
        self.norm = nn.LayerNorm(num_bands * spatial_dim * spatial_dim)

        # Adjust linear layers according to the new input dimension
        self.linear_x = nn.Linear(num_bands * spatial_dim * spatial_dim, hidden_dim)
        self.linear_z = nn.Linear(num_bands * spatial_dim * spatial_dim, hidden_dim)

        self.forward_conv1d = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
        self.backward_conv1d = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)

        self.A = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        self.B = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        #self.C = nn.Parameter(torch.randn(output_dim, hidden_dim))
        self.delta_param = nn.Parameter(torch.full((hidden_dim,), delta_param_init))

        self.linear_forward = nn.Linear(hidden_dim, output_dim)
        self.linear_backward = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        Batch, H, W, Bands = x.shape  # Correct shape extraction assuming [Batch, Height, Width, Bands]

        # Correctly reshape for LayerNorm to flatten all spatial and spectral information
        x = x.reshape(Batch, -1)  # New shape: [Batch, Bands*H*W]

        # Normalize across the flattened spatial-spectral data
        x = self.norm(x)

        # Projection to hidden dimensions
        x_proj = self.linear_x(x)
        z_proj = self.linear_z(x)

        # Ensure correct reshaping for Conv1d compatibility
        x_proj = x_proj.view(Batch, self.hidden_dim, -1)
        z_proj = z_proj.view(Batch, self.hidden_dim, -1)

        # Reverse z_proj for the backward path
        z_proj_reversed = torch.flip(z_proj, dims=[-1])

        # Bidirectional Conv1d processing using reversed input for the backward path
        x_forward = F.silu(self.forward_conv1d(x_proj))
        x_backward = F.silu(self.backward_conv1d(z_proj_reversed))

        # Apply delta parameter correctly
        delta_expanded = self.delta_param.unsqueeze(0).unsqueeze(2)  # Correct shape for broadcasting

        # SSM processing with delta applied, using the original and reversed inputs for forward and backward paths respectively
        forward_ssm_output = torch.tanh(self.forward_conv1d(x_proj) + self.A * delta_expanded)
        backward_ssm_output = torch.tanh(self.backward_conv1d(z_proj_reversed) + self.B * delta_expanded)

        # Combine forward and backward outputs into a single representation
        forward_reduced = forward_ssm_output.mean(dim=2)
        backward_reduced = backward_ssm_output.mean(dim=2)

        # Combine the reduced forward and backward paths
        y_forward = self.linear_forward(forward_reduced)
        y_backward = self.linear_backward(backward_reduced)

        # Element-wise sum of forward and backward outputs
        y_combined = y_forward + y_backward

        # Return the combined output
        return y_combined


### 1.2 S[atialFeature processing

In [9]:
# New version
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialFeatureProcessing(nn.Module):
    def __init__(self, input_channels):
        super(SpatialFeatureProcessing, self).__init__()
        self.conv_layers = nn.Sequential(
            # First convolutional layer with dilation rate of 1 (standard convolution)
            nn.Conv2d(in_channels=input_channels, out_channels=256, kernel_size=(3, 3), padding=1, dilation=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            # Second convolutional layer with a higher dilation rate to increase the receptive field
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), padding=2, dilation=2),  # Note the increased padding to maintain the spatial dimensions
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # Adding global average pooling

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.global_avg_pool(x)  # Apply global average pooling
        x = torch.flatten(x, start_dim=1)  # Flatten all dimensions except batch
        return x


### 1.3 Classifier

In [10]:
class Classifier(nn.Module):
    def __init__(self, in_features, num_classes):
        super(Classifier, self).__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=1024, out_features=num_classes),
        )

    def forward(self, x):
        x = self.fc_layers(x)
        # Remove softmax here if you're using a loss function that includes it, such as nn.CrossEntropyLoss
        return x


###1.4 Integrated into Main Model

In [11]:
class HSIClassificationMambaModel(nn.Module):
    def __init__(self, spatial_dim, num_bands, hidden_dim, output_dim, delta_param_init, num_classes):
        super(HSIClassificationMambaModel, self).__init__()
        self.vim_block = HSIVimBlock(spatial_dim, num_bands, hidden_dim, output_dim, delta_param_init)
        self.output_dim = output_dim  # Save output_dim as an attribute of the class

        # Initialize SpatialFeatureProcessing and Classifier here
        # Adjusted to pass 'output_dim' as 'input_channels' to SpatialFeatureProcessing
        self.spatial_processing = SpatialFeatureProcessing(input_channels=output_dim)
        # Assuming the output of SpatialFeatureProcessing matches the in_features expected by Classifier
        self.classifier = Classifier(in_features=512, num_classes=num_classes)

    def forward(self, x):
        x = self.vim_block(x)
        # This is a placeholder. Actual reshaping depends on the output of HSIVimBlock and the input expectation of SpatialFeatureProcessing
        x = x.view(-1, self.output_dim, 1, 1)  # Reshape to include spatial dimensions if needed
        x = self.spatial_processing(x)

        # Flatten the output from spatial processing if it's not already flat
        x = torch.flatten(x, start_dim=1)

        x = self.classifier(x)
        return x


# Instance the Model

In [12]:

# Instantiate the model
model = HSIClassificationMambaModel(
    spatial_dim=7,
    num_bands=144,
    hidden_dim=256,
    output_dim=128,  # Make sure this matches the actual output of your HSIVimBlock
    delta_param_init=0.01,
    num_classes=15
)

# Print the model architecture
print(model)


HSIClassificationMambaModel(
  (vim_block): HSIVimBlock(
    (norm): LayerNorm((7056,), eps=1e-05, elementwise_affine=True)
    (linear_x): Linear(in_features=7056, out_features=256, bias=True)
    (linear_z): Linear(in_features=7056, out_features=256, bias=True)
    (forward_conv1d): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (backward_conv1d): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
    (linear_forward): Linear(in_features=256, out_features=128, bias=True)
    (linear_backward): Linear(in_features=256, out_features=128, bias=True)
  )
  (spatial_processing): SpatialFeatureProcessing(
    (conv_layers): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2))
      (4): ReLU()
      (5): BatchNorm2d(512, 

### Optional 1   Data Preparation

In [13]:
# 2.1 Define the class information
class_info = [(1, "Healthy grass", 'training_sample', 198, 'test_sample', 1053,  'total', 1251),
    (2, "Stressed grass",'training_sample', 190, 'test_sample', 1064,  'total', 1254),
    (3, "Synthetic grass", 'training_sample', 192, 'test_sample', 505,  'total', 697),
    (4, "Trees", 'training_sample', 188, 'test_sample', 1058,  'total', 1244),
    (5, "Soil",'training_sample', 186, 'test_sample', 1056,  'total', 1242),
    (6, "Water", 'training_sample', 182, 'test_sample', 141,  'total', 325),
    (7, "Residential", 'training_sample', 196, 'test_sample', 1072,  'total', 1268),
    (8, "Commercial", 'training_sample', 191, 'test_sample', 1053,  'total', 1244),
    (9, "Road", 'training_sample', 193, 'test_sample', 1059,  'total', 1252),
    (10, "Highway", 'training_sample', 191, 'test_sample', 1036,  'total', 1227),
    (11, "Railway", 'training_sample', 181, 'test_sample', 1054,  'total', 1235),
    (12, "Parking lot 1", 'training_sample', 192, 'test_sample', 1041,  'total', 1233),
    (13, "Parking lot 2", 'training_sample', 184, 'test_sample',285,  'total', 469),
    (14, "Tennis court",'training_sample', 181, 'test_sample', 247,  'total', 428),
    (15, "Running track", 'training_sample', 187, 'test_sample', 473,  'total', 660)]

# Create a dictionary to store class number, class name, and class samples
class_dict = {class_number: {"class_name": class_name,
                             'training_sample': training_sample,
                             'test_sample': test_sample,
                             "total_samples": total}
              for class_number, class_name, _, training_sample, _, test_sample, _, total in class_info}

print(class_dict)


{1: {'class_name': 'Healthy grass', 'training_sample': 198, 'test_sample': 1053, 'total_samples': 1251}, 2: {'class_name': 'Stressed grass', 'training_sample': 190, 'test_sample': 1064, 'total_samples': 1254}, 3: {'class_name': 'Synthetic grass', 'training_sample': 192, 'test_sample': 505, 'total_samples': 697}, 4: {'class_name': 'Trees', 'training_sample': 188, 'test_sample': 1058, 'total_samples': 1244}, 5: {'class_name': 'Soil', 'training_sample': 186, 'test_sample': 1056, 'total_samples': 1242}, 6: {'class_name': 'Water', 'training_sample': 182, 'test_sample': 141, 'total_samples': 325}, 7: {'class_name': 'Residential', 'training_sample': 196, 'test_sample': 1072, 'total_samples': 1268}, 8: {'class_name': 'Commercial', 'training_sample': 191, 'test_sample': 1053, 'total_samples': 1244}, 9: {'class_name': 'Road', 'training_sample': 193, 'test_sample': 1059, 'total_samples': 1252}, 10: {'class_name': 'Highway', 'training_sample': 191, 'test_sample': 1036, 'total_samples': 1227}, 11: 

In [14]:
# 2.2 Samples Extraction

# Define patch size and stride
patch_size = 7
stride = 1

# Create an empty list to store patches and labels
hsi_samples = []
lidar_samples = []
labels = []

# Initialize a dictionary to store class count
class_count = {i: 0 for i in class_dict.keys()}

# Function to check if all classes have the required number of samples
def all_classes_completed(class_count, class_dict):
    return all(class_count[class_num] == class_dict[class_num]["total_samples"] for class_num in class_dict.keys())

while not all_classes_completed(class_count, class_dict):
    # Loop through the ground truth data
    for label in class_dict.keys():
        # Get the coordinates of the ground truth pixels
        #coords = np.argwhere((gt_2013_data == label) & (mask > 0))
        coords = np.argwhere(gt_2013_data == label)

        # Shuffle the coordinates to randomize the patch extraction
        np.random.shuffle(coords)

        for coord in coords:
            i, j = coord
            # Calculate the patch indices
            i_start, i_end = i - patch_size // 2, i + patch_size // 2 + 1
            j_start, j_end = j - patch_size // 2, j + patch_size // 2 + 1

            # Check if the indices are within the bounds of the HSI data
            if i_start >= 0 and i_end <= hsi_2013_data.shape[0] and j_start >= 0 and j_end <= hsi_2013_data.shape[1]:
                # Extract the patch
                hsi_patch = hsi_2013_data[i_start:i_end, j_start:j_end, :]

                # Extract the LiDAR patch
                lidar_patch = lidar_2013_data[i_start:i_end, j_start:j_end, :]

                # If the class count is less than the required samples
                if class_count[label] < class_dict[label]["total_samples"]:
                    # Append the patch and its label to the list
                    hsi_samples.append(hsi_patch)
                    lidar_samples.append(lidar_patch)
                    labels.append(label)
                    class_count[label] += 1

                    # If all classes have the required number of samples, exit the loop
                    if all_classes_completed(class_count, class_dict):
                        break

# Convert the list of patches and labels into arrays
hsi_samples = np.array(hsi_samples)
lidar_samples = np.array(lidar_samples)
labels = np.array(labels)
print('hsi_samples shape:', hsi_samples.shape)
print('lidar_samples shape:', lidar_samples.shape)
print('labels shape:', labels.shape)

hsi_samples shape: (15029, 7, 7, 144)
lidar_samples shape: (15029, 7, 7, 1)
labels shape: (15029,)


In [15]:
# Create training_samples_dict based on class_dict
training_samples_dict = {class_num: class_info["training_sample"] for class_num, class_info in class_dict.items()}

# Assuming `hsi_samples`, `lidar_samples`, and `labels` have been previously defined
# Convert the list of patches and labels into arrays if they aren't already
hsi_samples = np.array(hsi_samples)
lidar_samples = np.array(lidar_samples)
labels = np.array(labels)

# Create lists to store training and test samples and labels
hsi_training_samples, lidar_training_samples, training_labels = [], [], []
hsi_test_samples, lidar_test_samples, test_labels = [], [], []

# Split samples into training and test sets based on the desired number of training samples
for label, train_samples in training_samples_dict.items():
    # Get indices of the current class
    class_indices = np.where(labels == label)[0]

    # Randomly shuffle the indices
    np.random.shuffle(class_indices)

    # Split the indices into training and test set indices
    train_indices = class_indices[:train_samples]
    test_indices = class_indices[train_samples:]

    # Add training samples and labels for the current class
    hsi_training_samples.extend(hsi_samples[train_indices])
    lidar_training_samples.extend(lidar_samples[train_indices])
    training_labels.extend(labels[train_indices])

    # Add test samples and labels for the current class
    hsi_test_samples.extend(hsi_samples[test_indices])
    lidar_test_samples.extend(lidar_samples[test_indices])
    test_labels.extend(labels[test_indices])

# Convert lists back to numpy arrays
hsi_training_samples = np.array(hsi_training_samples)
lidar_training_samples = np.array(lidar_training_samples)
training_labels = np.array(training_labels)

hsi_test_samples = np.array(hsi_test_samples)
lidar_test_samples = np.array(lidar_test_samples)
test_labels = np.array(test_labels)

# Print shapes to verify
print('hsi_training_samples shape:', hsi_training_samples.shape)
print('lidar_training_samples shape:', lidar_training_samples.shape)
print('training_labels shape:', training_labels.shape)

print('hsi_test_samples shape:', hsi_test_samples.shape)
print('lidar_test_samples shape:', lidar_test_samples.shape)
print('test_labels shape:', test_labels.shape)


hsi_training_samples shape: (2832, 7, 7, 144)
lidar_training_samples shape: (2832, 7, 7, 1)
training_labels shape: (2832,)
hsi_test_samples shape: (12197, 7, 7, 144)
lidar_test_samples shape: (12197, 7, 7, 1)
test_labels shape: (12197,)


In [22]:
import numpy as np
from scipy.ndimage import rotate

def augment_training_data(hsi_training_data, lidar_training_data, training_labels, rotations=[45, 90, 135], flip_up_down=True, flip_left_right=True):
    augmented_hsi = []
    augmented_lidar = []
    augmented_labels = []

    for hsi, lidar, label in zip(hsi_training_data, lidar_training_data, training_labels):
        # Original data
        augmented_hsi.append(hsi)
        augmented_lidar.append(lidar)
        augmented_labels.append(label)

        # Rotations
        for angle in rotations:
            hsi_rotated = rotate(hsi, angle, axes=(0, 1), reshape=False, mode='nearest')
            lidar_rotated = rotate(lidar, angle, axes=(0, 1), reshape=False, mode='nearest')

            augmented_hsi.append(hsi_rotated)
            augmented_lidar.append(lidar_rotated)
            augmented_labels.append(label)

        # Flip up-down
        if flip_up_down:
            hsi_flipped_ud = np.flipud(hsi)
            lidar_flipped_ud = np.flipud(lidar)

            augmented_hsi.append(hsi_flipped_ud)
            augmented_lidar.append(lidar_flipped_ud)
            augmented_labels.append(label)

        # Flip left-right
        if flip_left_right:
            hsi_flipped_lr = np.fliplr(hsi)
            lidar_flipped_lr = np.fliplr(lidar)

            augmented_hsi.append(hsi_flipped_lr)
            augmented_lidar.append(lidar_flipped_lr)
            augmented_labels.append(label)

    return np.array(augmented_hsi), np.array(augmented_lidar), np.array(augmented_labels)

# Augmenting the training samples
augmented_hsi_training_samples, augmented_lidar_training_samples, augmented_training_labels = augment_training_data(hsi_training_samples, lidar_training_samples, training_labels)

# Print shapes to verify the augmented training data
print('Augmented HSI training samples shape:', augmented_hsi_training_samples.shape)
print('Augmented LiDAR training samples shape:', augmented_lidar_training_samples.shape)
print('Augmented training labels shape:', augmented_training_labels.shape)

Augmented HSI training samples shape: (16992, 7, 7, 144)
Augmented LiDAR training samples shape: (16992, 7, 7, 1)
Augmented training labels shape: (16992,)


In [24]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

# Split the augmented training data into training, validationsets
X_train, X_val, y_train, y_val = train_test_split(
    augmented_hsi_training_samples, augmented_training_labels, test_size=0.1, random_state=42, stratify=augmented_training_labels
)
X_test=hsi_test_samples
y_test=test_labels

print('X_train shape:', X_train.shape)
print('X_train_val shape:', X_val.shape)
print('y_train shape:', y_train.shape)

print('X_test shape:', X_test.shape)
print('y_test shape:', y_test.shape)


# Convert the splitted datasets to tensor datasets
train_dataset = TensorDataset(torch.tensor(X_train.astype(np.float32)), torch.tensor(y_train).long())
val_dataset = TensorDataset(torch.tensor(X_val.astype(np.float32)), torch.tensor(y_val).long())
test_dataset = TensorDataset(torch.tensor(X_test.astype(np.float32)), torch.tensor(y_test).long())

# Create DataLoader instances for training, validation, and testing
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


X_train shape: (15292, 7, 7, 144)
X_train_val shape: (1700, 7, 7, 144)
y_train shape: (15292,)
X_test shape: (12197, 7, 7, 144)
y_test shape: (12197,)


# 5.0 Training Model Memeory and Time calcualtion

In [23]:
# Before the training loop, to record the initial memory usage (GPU)
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()  # Reset peak memory stats at the start
    initial_memory = torch.cuda.memory_allocated()
    print(f"Initial Memory Allocated: {initial_memory / 1e6} MB")

Initial Memory Allocated: 45.670912 MB


### 5.1 Training Model for complete forward and backward archtoeture

In [25]:
# Training Model in GPU
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import copy
import time  # Step 1: Import the time module

model = HSIClassificationMambaModel(
    spatial_dim=7, num_bands=144, hidden_dim=256, output_dim=128, delta_param_init=0.01, num_classes=15
).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)

epochs = 50
best_val_loss = float('inf')
best_model_wts = copy.deepcopy(model.state_dict())
patience = 10

start_time = time.time()  # Step 2: Record the start time

for epoch in range(epochs):
    model.train()
    running_train_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.cuda(), labels.cuda() # Move the data into CUDA
        optimizer.zero_grad()
        labels -= 1
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()

    epoch_train_loss = running_train_loss / len(train_loader.dataset)

    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.cuda(), labels.cuda() # Move the data into CUDA
            outputs = model(inputs)
            labels -= 1
            loss = criterion(outputs, labels)
            val_running_loss += loss.item()

        epoch_val_loss = val_running_loss / len(val_loader.dataset)

    print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}')

    if epoch_val_loss < best_val_loss:
        print(f'Validation Loss Decreased({best_val_loss:.6f}--->{epoch_val_loss:.6f}) \t Saving The Model')
        best_val_loss = epoch_val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1

    if no_improve_epochs > patience:
        print('Early stopping!')
        model.load_state_dict(best_model_wts)
        break

end_time = time.time()  # Step 3: Record the end time
total_time = end_time - start_time  # Step 4: Calculate total training time

print(f'Finished training. Total training time: {total_time:.2f} seconds')  # Print the total training time

Epoch [1/50], Train Loss: 0.0482, Val Loss: 0.0288
Validation Loss Decreased(inf--->0.028779) 	 Saving The Model
Epoch [2/50], Train Loss: 0.0250, Val Loss: 0.0177
Validation Loss Decreased(0.028779--->0.017722) 	 Saving The Model
Epoch [3/50], Train Loss: 0.0175, Val Loss: 0.0139
Validation Loss Decreased(0.017722--->0.013879) 	 Saving The Model
Epoch [4/50], Train Loss: 0.0137, Val Loss: 0.0113
Validation Loss Decreased(0.013879--->0.011335) 	 Saving The Model
Epoch [5/50], Train Loss: 0.0115, Val Loss: 0.0088
Validation Loss Decreased(0.011335--->0.008838) 	 Saving The Model
Epoch [6/50], Train Loss: 0.0097, Val Loss: 0.0074
Validation Loss Decreased(0.008838--->0.007444) 	 Saving The Model
Epoch [7/50], Train Loss: 0.0082, Val Loss: 0.0099
Epoch [8/50], Train Loss: 0.0070, Val Loss: 0.0057
Validation Loss Decreased(0.007444--->0.005673) 	 Saving The Model
Epoch [9/50], Train Loss: 0.0063, Val Loss: 0.0060
Epoch [10/50], Train Loss: 0.0056, Val Loss: 0.0044
Validation Loss Decreased

In [26]:
# Before the training loop, to record the initial memory usage (GPU)
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()  # Reset peak memory stats at the start
    initial_memory = torch.cuda.memory_allocated()
    print(f"Initial Memory Allocated: {initial_memory / 1e6} MB")

Initial Memory Allocated: 168.244736 MB


### Save the modle

In [27]:
# Assuming 'model' is your instance of HSIClassificationModel or any other model
# and it's been trained
torch.save(model.state_dict(), path+'p7_UH2013_model_state_dict.pth')


### Claculte th test time

In [28]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, cohen_kappa_score
import time  # Import the time module for timing the test phase

# Assuming 'model' is your instance of HSIClassificationModel or any other model
# and it's been trained

# Save the model
model_save_path = path+ 'p7_Uh2013_model_state_dict.pth'
torch.save(model.state_dict(), model_save_path)
print(f'Model saved to {model_save_path}')

# Load the model (make sure to initialize the model architecture first)
model.load_state_dict(torch.load(model_save_path))
model.to(device)

# Ensure the model is in evaluation mode
model.eval()

# Store predictions and actual labels
predictions = []
actual_labels = []

start_time = time.time()  # Start timing

with torch.no_grad():
    for hsi_patches, labels in test_loader:
        # Move data to the appropriate device
        hsi_patches = hsi_patches.to(device)
        labels -= 1  # Adjust labels if necessary

        # Forward pass
        outputs = model(hsi_patches)

        # Get predictions
        _, predicted = torch.max(outputs, 1)
        predictions.extend(predicted.cpu().numpy())
        actual_labels.extend(labels.cpu().numpy())

end_time = time.time()  # End timing
test_time = end_time - start_time  # Calculate the test time

# Optionally, calculate accuracy or other metrics using predictions and actual_labels

# Convert lists to NumPy arrays for easier manipulation
predictions_array = np.array(predictions)
actual_labels_array = np.array(actual_labels)

# Overall Accuracy
oa = accuracy_score(actual_labels_array, predictions_array)

# Confusion Matrix
cm = confusion_matrix(actual_labels_array, predictions_array)
# Calculate per-class accuracy from the confusion matrix
class_accuracy = cm.diagonal() / cm.sum(axis=1)
# Average Accuracy
aa = np.mean(class_accuracy)

# Kappa Coefficient
kappa = cohen_kappa_score(actual_labels_array, predictions_array)

print(f'Overall Accuracy (OA): {oa:.4f}')
print(f'Average Accuracy (AA): {aa:.4f}')
print(f'Kappa Coefficient: {kappa:.4f}')
print(f'Test time: {test_time:.2f} seconds')  # Print the test time

Model saved to /content/drive/MyDrive/A02_RemoteSensingData/UHS_2013_DFTC/p7_Uh2013_model_state_dict.pth
Overall Accuracy (OA): 0.9656
Average Accuracy (AA): 0.9724
Kappa Coefficient: 0.9626
Test time: 1.43 seconds


In [29]:
for i, acc in enumerate(class_accuracy): print(f'Class {i+1} Accuracy: {acc:.4f}')


Class 1 Accuracy: 0.9962
Class 2 Accuracy: 0.9925
Class 3 Accuracy: 1.0000
Class 4 Accuracy: 0.9792
Class 5 Accuracy: 0.9991
Class 6 Accuracy: 1.0000
Class 7 Accuracy: 0.9757
Class 8 Accuracy: 0.9839
Class 9 Accuracy: 0.9235
Class 10 Accuracy: 0.9701
Class 11 Accuracy: 0.8397
Class 12 Accuracy: 0.9472
Class 13 Accuracy: 0.9789
Class 14 Accuracy: 1.0000
Class 15 Accuracy: 1.0000
