In [3]:
import datetime
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchvision
from d2l import torch as d2l
from sklearn.metrics import confusion_matrix,roc_curve,auc,precision_recall_curve
from sklearn.model_selection import train_test_split
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [4]:
# Training hyperparameters (tune these here; later cells read them by name)
num_epochs = 50         # full passes over the training set
batch_size = 128        # samples per mini-batch
learning_rate = 1e-4    # initial optimizer learning rate

In [5]:
# Channel Attention Module
class ChannelAttention(nn.Module):
    """Channel attention built from a grid of spatially pooled tiles.

    The input feature map is cut into a ``num_regions x num_regions`` grid of
    equally sized tiles. Each tile is globally pooled to one value per
    channel, passed through its own bias-free ``Conv1d(1, 1, 1)`` (i.e. one
    learnable scalar per tile), and the per-tile results are summed and
    squashed with a sigmoid.

    Args:
        in_planes (int): Number of input channels. Stored for reference only;
            no layer in this module is sized by it.
        num_regions (int): Number of tiles along EACH spatial axis, so the
            feature map is split into ``num_regions ** 2`` tiles (default: 4)
        pool_type (str): Pooling type - 'avg' or 'max' (default: 'avg')

    Raises:
        ValueError: If ``num_regions`` is smaller than 1 or ``pool_type`` is
            neither 'avg' nor 'max'.
    """
    def __init__(self, in_planes, num_regions=4, pool_type='avg'):
        super(ChannelAttention, self).__init__()

        if num_regions < 1:
            raise ValueError("num_regions must be a positive integer.")

        self.in_planes = in_planes
        self.num_regions = num_regions
        self.pool_type = pool_type

        # Region configuration (square grid: same tile count on both axes)
        self.num_region_rows = num_regions
        self.num_region_cols = num_regions

        # Pooling layer selection: every tile is reduced to a 1x1 map
        if pool_type == 'avg':
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
        elif pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool2d((1, 1))
        else:
            raise ValueError("Invalid pool_type. Choose 'avg' or 'max'.")

        # One 1x1 Conv1d per tile, stored in row-major tile order.
        # The attribute name and ordering define the state_dict keys, so they
        # must stay as they are for existing checkpoints to load.
        self.conv1d_list = nn.ModuleList([
            nn.Conv1d(1, 1, 1, bias=False)
            for _ in range(num_regions * num_regions)
        ])
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Compute channel attention weights.

        Args:
            x (Tensor): Feature map of shape ``(b, c, h, w)``.

        Returns:
            Tensor: Weights in ``(0, 1)`` of shape ``(b, c, 1, 1)``, ready to
            be broadcast-multiplied with ``x``.

        Raises:
            ValueError: If ``h`` or ``w`` is smaller than ``num_regions``
                (the tiles would be empty).
        """
        _, _, h, w = x.size()

        if h < self.num_region_rows or w < self.num_region_cols:
            raise ValueError(
                f"Feature map of size {h}x{w} is too small to be split into a "
                f"{self.num_region_rows}x{self.num_region_cols} grid of regions."
            )

        # Tile size uses floor division: when h or w is not a multiple of
        # num_regions, the trailing rows/columns are not covered by any tile.
        region_height = h // self.num_region_rows
        region_width = w // self.num_region_cols

        # Pool each tile and run it through its dedicated 1x1 convolution
        processed_features = []
        for i in range(self.num_region_rows):
            row_slice = slice(i * region_height, (i + 1) * region_height)
            for j in range(self.num_region_cols):
                col_slice = slice(j * region_width, (j + 1) * region_width)
                # (b, c, 1, 1) -> (b, c, 1) -> (b, 1, c): channels become the
                # Conv1d length axis, with a single Conv1d input channel.
                pooled = self.pool(x[:, :, row_slice, col_slice])
                pooled = pooled.squeeze(-1).permute(0, 2, 1)
                conv = self.conv1d_list[i * self.num_region_cols + j]
                processed_features.append(conv(pooled))

        # Combine the per-tile responses through summation -> (b, 1, c)
        combined = torch.sum(torch.stack(processed_features, dim=0), dim=0)

        # Compute channel attention weights
        attention_weights = self.sigmoid(combined)

        # (b, 1, c) -> (b, c, 1, 1) so the weights broadcast over h and w
        attention_weights = attention_weights.permute(0, 2, 1).unsqueeze(-1)

        return attention_weights


# Spatial Attention Module
class SpatialAttention(nn.Module):
    """Spatial attention module with grouped feature processing.

    The channel axis is cut into ``groups`` consecutive slices of
    ``in_channels // groups`` channels. Each slice is averaged over its
    channels into a single-channel map, filtered by its own bias-free
    ``Conv2d(1, 1, kernel_size)``, and the filtered maps are summed and passed
    through a sigmoid.

    Args:
        in_channels (int): Number of input channels
        groups (int): Number of channel groups (default: 4). When
            ``in_channels`` is not a multiple of ``groups`` the trailing
            ``in_channels % groups`` channels are not used.
        kernel_size (int): Convolution kernel size (3 or 7, default: 7)

    Raises:
        ValueError: If ``groups`` < 1 or ``in_channels`` < ``groups``. The
            latter would give every group zero channels, and the mean over an
            empty slice is NaN, which would silently poison the output.
        AssertionError: If ``kernel_size`` is not 3 or 7.
    """
    def __init__(self, in_channels, groups=4, kernel_size=7):
        super(SpatialAttention, self).__init__()

        if groups < 1:
            raise ValueError("groups must be a positive integer.")
        if in_channels < groups:
            raise ValueError(
                f"in_channels ({in_channels}) must be at least groups ({groups}); "
                "otherwise every channel group would be empty."
            )

        self.groups = groups
        self.channels_per_group = in_channels // groups

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "same" padding so the attention map keeps the input's h x w
        padding = 3 if kernel_size == 7 else 1

        # One Conv2d per channel group. The attribute is called conv1d_list
        # although it holds 2D convolutions; the name defines the state_dict
        # keys and is kept so existing checkpoints still load.
        self.conv1d_list = nn.ModuleList([
            nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
            for _ in range(groups)
        ])
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Compute spatial attention weights.

        Args:
            x (Tensor): Feature map of shape ``(b, c, h, w)``.

        Returns:
            Tensor: Weights in ``(0, 1)`` of shape ``(b, 1, h, w)``.
        """
        processed_features = []
        for group_idx, conv in enumerate(self.conv1d_list):
            start = group_idx * self.channels_per_group
            stop = start + self.channels_per_group
            # Channel-wise average pooling of this group -> (b, 1, h, w)
            pooled = torch.mean(x[:, start:stop, :, :], dim=1, keepdim=True)
            processed_features.append(conv(pooled))

        # Combine the per-group maps through summation
        combined = torch.sum(torch.stack(processed_features, dim=0), dim=0)

        # Compute spatial attention weights
        attention_weights = self.sigmoid(combined)

        return attention_weights


# Convolutional Block Attention Module (CBAM)
class CBAM(nn.Module):
    """Sequential channel-then-spatial attention.

    The input is first rescaled by the channel attention weights; the rescaled
    tensor is then used both to compute the spatial attention weights and as
    the tensor those weights are applied to.

    Args:
        in_planes (int): Number of input channels
    """
    def __init__(self, in_planes):
        super().__init__()
        # Sub-modules keep the names `ca` / `sa` (they are part of the
        # state_dict keys).
        self.ca = ChannelAttention(in_planes)
        self.sa = SpatialAttention(in_planes)

    def forward(self, x):
        # Stage 1: reweight channels
        channel_scale = self.ca(x)
        refined = x * channel_scale
        # Stage 2: reweight spatial positions of the channel-refined tensor
        spatial_scale = self.sa(refined)
        return refined * spatial_scale

In [6]:
class RHAB(nn.Module):  #@save
    """Residual block with a CBAM stage and two skip additions.

    Main path: conv3x3 (stride ``strides``) -> BN -> GELU -> conv3x3 -> BN.
    The shortcut (optionally projected by a strided 1x1 convolution) is added
    to the main path, the sum goes through CBAM, the same shortcut is added a
    second time, and a final GELU is applied.

    Args:
        input_channels (int): Channels of the incoming tensor.
        num_channels (int): Channels produced by the block.
        use_1x1conv (bool): Project the shortcut with a 1x1 convolution
            (default: True). Without it the input is added as-is.
        strides (int): Stride of the first 3x3 convolution and of the 1x1
            shortcut projection (default: 2).
    """
    def __init__(self, input_channels, num_channels,
                 use_1x1conv=True, strides=2):
        super().__init__()
        # Layers are created in a fixed order (conv1, conv2, conv3, bn1, bn2,
        # CBAM) so seeded initialisation and parameter ordering stay stable.
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, stride=strides, padding=1)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, stride=1, padding=1)
        self.conv3 = (
            nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
            if use_1x1conv else None
        )
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.CBAM = CBAM(num_channels)
        self.gelu = nn.GELU()

    def forward(self, X):
        # Main convolutional path
        main = self.conv1(X)
        main = self.gelu(self.bn1(main))
        main = self.bn2(self.conv2(main))

        # Shortcut path: projected when a 1x1 convolution was configured
        shortcut = X if self.conv3 is None else self.conv3(X)

        # First skip addition, attention, then the second skip addition
        main = main + shortcut
        main = self.CBAM(main)
        main = main + shortcut

        return self.gelu(main)

In [7]:
class BHBNet(nn.Module):
    """CNN (three RHAB stages) followed by a 3-layer GRU and an MLP head.

    Each input sample is reshaped to a 1x64x64 image, so it must contain
    exactly 4096 values. With the default RHAB stride of 2 the three blocks
    reduce 64x64 to 8x8, so every one of the 512 output channels is flattened
    into a 64-dimensional vector and the GRU reads the 512 channels as a
    sequence. The head maps the last GRU output to 6 logits.

    Args:
        gru_batch_first (bool): Layout the GRU assumes for its input.
            ``True`` (default) treats the ``(batch, 512, 64)`` tensor as
            ``(batch, seq, feature)``, which is what ``x[:, -1, :]`` in
            ``forward`` expects: every sample is processed independently and
            the last time step is kept.
            ``False`` reproduces the previous behaviour, where ``nn.GRU``
            interpreted the same tensor as ``(seq, batch, feature)``: the
            recurrence then ran across the samples of a mini-batch and only
            the last channel was kept, so a sample's output depended on the
            other samples in its batch. Use it only to evaluate checkpoints
            that were trained with that layout.
            NOTE(review): parameter shapes are identical in both modes, so old
            checkpoints load either way, but weights trained with the legacy
            layout should be re-trained or fine-tuned under the new default.
    """
    def __init__(self, gru_batch_first=True):
        super(BHBNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.b1 = RHAB(64, 128)
        self.b2 = RHAB(128, 256)
        self.b3 = RHAB(256, 512)
        # input_size=64 (flattened 8x8 map), hidden_size=128, 3 stacked layers
        self.gru = nn.GRU(64, 128, 3, batch_first=gru_batch_first)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 6)

    def forward(self, x):
        """Return logits of shape ``(batch, fc2.out_features)``."""
        # Stem: reshape each 4096-value sample to a single-channel image
        x = x.view(x.size(0), 1, 64, 64)
        x = self.conv1(x)
        x = self.bn(x)
        x = self.gelu(x)
        # RHAB stages
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        # (batch, 512, h, w) -> (batch, 512, h*w)
        x = x.view(x.size(0), 512, -1)
        # GRU module
        x, _ = self.gru(x)
        # Keep the last position along dim 1 -> (batch, 128)
        x = x[:, -1, :]
        # Output module
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [8]:
# Read spectral data CSV files and add label column
try:
    # Load spectral data for each class
    BHB_spec_df = pd.read_csv(r"D:\ZYH\data\BHB\LAMOST_BHB_spectra.csv")
    Hsd_spec_df = pd.read_csv(r"D:\ZYH\data\Hsd\LAMOST_Hsd_spectra.csv")
    A_spec_df = pd.read_csv(r"D:\ZYH\data\A\LAMOST_A_spectra.csv")
    B_spec_df = pd.read_csv(r"D:\ZYH\data\B\LAMOST_B_spectra.csv")
except FileNotFoundError as e:
    print(f"Error: File not found - {e}")
    # Re-raise instead of calling exit(): inside a notebook exit() tears down
    # the kernel, which hides the traceback and makes every later cell fail
    # with unrelated NameErrors. Raising stops "Run All" at this cell.
    raise

# Binary target: 1 = BHB, 0 = every other class (Hsd, A, B).
# The label is appended as the LAST column; SpectraDataset relies on that.
BHB_spec_df['label'] = 1
Hsd_spec_df['label'] = 0
A_spec_df['label'] = 0
B_spec_df['label'] = 0

# Combine all datasets including label column
All_spec_df = pd.concat([BHB_spec_df, Hsd_spec_df, A_spec_df, B_spec_df], ignore_index=True)

In [9]:
def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    and asks cuDNN for deterministic kernels with auto-tuning disabled, so
    repeated runs with the same seed pick the same algorithms.

    Args:
        seed (int): Seed value shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without a GPU: it is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False


setup_seed(42)

In [10]:
# Dataset wrapper around a spectra table; yields (features, label) pairs
class SpectraDataset(Dataset):
    """Map-style dataset over a DataFrame whose last column is the label.

    Every row is one sample: all columns but the last are the feature vector,
    the last column is the label. Both are returned as float32 tensors.
    """

    def __init__(self, specdata):
        # The frame is stored as-is: no copy and no dtype validation happen
        # here, so every feature column must already be numeric.
        self.specdata = specdata

    def __len__(self):
        # One sample per DataFrame row
        return len(self.specdata)

    def __getitem__(self, index):
        table = self.specdata
        # Positional access: everything left of the final column is the spectrum
        feature_values = table.iloc[index, :-1].values
        label_value = table.iloc[index, -1]
        features = torch.tensor(feature_values, dtype=torch.float32)
        target = torch.tensor(label_value, dtype=torch.float32)
        return features, target

In [11]:
# Create dataset instance
Spectra_dataset = SpectraDataset(All_spec_df)

# Split into training (80 %), validation (10 %) and test (10 %) sets:
# first hold out 20 % of the samples, then halve that hold-out set.
train_dataset, valid_test_dataset = train_test_split(Spectra_dataset, test_size=0.2, random_state=42)
# Fixed: the second split previously referenced the undefined name
# `train_valid_dataset`, which raised a NameError.
valid_dataset, test_dataset = train_test_split(valid_test_dataset, test_size=0.5, random_state=42)

# Create data loaders (only the training data is reshuffled every epoch)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Report library versions for reproducibility
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Device configuration: first CUDA GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

In [13]:
# Instantiate the model and restore the pretrained six-class weights.
# map_location="cpu" lets a checkpoint that was saved on a GPU be restored on
# a CPU-only machine as well; the model is moved to `device` further down.
model = BHBNet()
model.load_state_dict(torch.load(r"D:\ZYH\代码\六分类\单模态_光谱分类模型\save_model\best_six_class.pth", map_location="cpu"))

# Transfer learning: swap the 6-way head for a fresh 2-way head and freeze
# every parameter except the new head.
model.fc2 = nn.Linear(64, 2)
for name, param in model.named_parameters():
    if 'fc2' not in name:
        param.requires_grad = False
# print(model)
model.to(device)

# Set up loss function
loss = nn.CrossEntropyLoss()

# Set up optimizer (only parameters that still require gradients are passed in)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate,weight_decay=5e-2)

# Define polynomial decay function
def poly_decay(epoch, total_epochs=100, power=1):
    """Polynomial learning-rate multiplier ``(1 - epoch / total_epochs) ** power``.

    Args:
        epoch (int): Current epoch index (0-based).
        total_epochs (int): Epoch at which the multiplier reaches zero.
        power (float): Exponent of the decay curve (1 = linear).

    Returns:
        float: Multiplier in ``[0, 1]`` for ``epoch >= 0``. It is clamped at 0
        once ``epoch`` exceeds ``total_epochs``; without the clamp the factor
        would turn negative (odd powers) or start growing again (even powers).
    """
    remaining = max(0.0, 1 - epoch / total_epochs)
    return remaining ** power
# Each scheduler step rescales the optimizer's base learning rate by
# poly_decay(epoch, num_epochs, 2), i.e. a quadratic decay over num_epochs.
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: poly_decay(epoch, num_epochs, 2))

In [16]:
# Evaluation
def evaluate_accuracy_gpu(net, data_iter, loss, device=None, fwd_passes=50): #@save
    """Evaluate ``net`` with Monte Carlo dropout; return ``(mean_loss, accuracy)``.

    The network is put in eval mode, then every module whose class name starts
    with "Dropout" is switched back to train mode so dropout stays active.
    Each batch is forwarded ``fwd_passes`` times. The loss is computed on the
    mean of the logits; the predicted class is the argmax of the mean of the
    per-pass softmax probabilities.

    Side effect: ``net`` is left in this eval-with-active-dropout state; call
    ``net.train()`` before resuming training.

    Args:
        net (nn.Module): Model to evaluate.
        data_iter (iterable): Yields ``(X, y)`` batches; ``X`` is a tensor or
            a list of tensors, ``y`` holds integer class ids (any dtype).
        loss (callable): Called as ``loss(logits, y.long())``; expected to
            return the batch MEAN, since it is re-weighted by batch size here.
        device (torch.device, optional): Device for the batches. Defaults to
            the device of the first parameter of ``net``.
        fwd_passes (int): Number of stochastic forward passes per batch
            (default: 50).

    Returns:
        tuple[float, float]: Sample-weighted mean loss and accuracy.

    Raises:
        ValueError: If ``fwd_passes`` is smaller than 1.
    """
    if fwd_passes < 1:
        raise ValueError("fwd_passes must be at least 1.")

    net.eval()  # Set to evaluation mode
    for module in net.modules():
        # Re-enable dropout only, so repeated passes give different outputs
        if module.__class__.__name__.startswith("Dropout"):
            module.train()
    if device is None:
        device = next(iter(net.parameters())).device

    total_loss, num_correct, num_samples = 0.0, 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # Required for BERT fine-tuning (to be discussed later)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)

            # (fwd_passes, batch, classes)
            logits = torch.stack([net(X) for _ in range(fwd_passes)], dim=0)
            batch_loss = loss(torch.mean(logits, dim=0), y.long())
            mean_probs = torch.mean(F.softmax(logits, dim=2), dim=0)
            y_hat = mean_probs.argmax(dim=1)

            # Batch size is taken from y so the list-input branch works too
            n = y.shape[0]
            total_loss += float(batch_loss) * n
            num_correct += float((y_hat.type(y.dtype) == y).sum())
            num_samples += n
    return total_loss / num_samples, num_correct / num_samples

def evaluate_accuracy_gpu_test(net, data_iter, loss, device=None, fwd_passes=50,
                               save_path=r"D:\ZYH\code\bianry\figure\combined_plot.png"): #@save
    """Evaluate the binary classifier with Monte Carlo dropout and plot results.

    The network is put in eval mode with its Dropout modules re-enabled, every
    batch is forwarded ``fwd_passes`` times, and the class probabilities are
    the mean of the per-pass softmax outputs. A figure with the per-class
    precision-recall curves (left) and the confusion matrix (right) is drawn
    and written to ``save_path``.

    Side effect: ``net`` is left in the eval-with-active-dropout state.

    Args:
        net (nn.Module): Model to evaluate; must output 2 logits per sample.
        data_iter (iterable): Yields ``(X, y)`` batches with labels 0 or 1.
        loss: Accepted for signature compatibility with
            ``evaluate_accuracy_gpu``; no loss is reported by this function.
        device (torch.device, optional): Device for the batches. Defaults to
            the device of the first parameter of ``net``.
        fwd_passes (int): Stochastic forward passes per batch (default: 50).
        save_path (str): Where the figure is saved.
            NOTE(review): the default directory is spelled "bianry" while the
            other paths in this notebook use "binary" - confirm the folder.

    Returns:
        tuple: ``(accuracy, precision, recall, f1_score)`` where accuracy is a
        float and the other three are per-class tensors produced by
        ``calculate_metrics``.
    """
    class_names = ["Hsd/A/B","BHB"]
    # Size everything from the class list; the confusion matrix used to be
    # 6x6 (left over from the six-class model), which broke the heatmap.
    num_classes = len(class_names)

    net.eval()  # Set to evaluation mode
    for module in net.modules():
        # Re-enable dropout only, so repeated passes give different outputs
        if module.__class__.__name__.startswith("Dropout"):
            module.train()
    if device is None:
        device = next(iter(net.parameters())).device

    num_correct, num_samples = 0.0, 0
    # Rows are true labels, columns are predicted labels
    confusion_mat = torch.zeros(num_classes, num_classes, dtype=torch.int64, device="cpu")
    y_trues = []   # per-batch arrays of integer labels
    y_probs = []   # per-batch arrays of shape (batch, num_classes)
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # Required for BERT fine-tuning (to be discussed later)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)

            # (fwd_passes, batch, classes) -> mean probability per class
            logits = torch.stack([net(X) for _ in range(fwd_passes)], dim=0)
            mean_probs = torch.mean(F.softmax(logits, dim=2), dim=0).cpu()
            y_hat = mean_probs.argmax(dim=1)

            y = y.cpu()
            num_correct += float((y_hat.type(y.dtype) == y).sum())
            num_samples += y.shape[0]

            true_ids = y.long()
            y_trues.append(true_ids.numpy())
            y_probs.append(mean_probs.numpy())
            # Update confusion matrix
            for true_idx, pred_idx in zip(true_ids.tolist(), y_hat.tolist()):
                confusion_mat[true_idx, pred_idx] += 1

    accuracy = num_correct / num_samples
    precision, recall, f1_score = calculate_metrics(confusion_mat, num_classes)

    y_true_all = np.concatenate(y_trues)
    y_prob_all = np.concatenate(y_probs, axis=0)

    # Create a figure with two subplots
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))  # 1 row, 2 columns

    # Precision-Recall curve (one-vs-rest for each class)
    for i, name in enumerate(class_names):
        is_class_i = (y_true_all == i).astype(int)  # Convert labels to binary form
        p, r, _ = precision_recall_curve(is_class_i, y_prob_all[:, i])
        axs[0].plot(r, p, label=f'{name} (area = {auc(r, p):0.2f})')
    axs[0].set_xlabel('Recall', fontsize=14)
    axs[0].set_ylabel('Precision', fontsize=14)
    axs[0].legend(loc="lower left", fontsize=14)
    axs[0].tick_params(labelsize=14)

    # Confusion matrix
    sns.heatmap(confusion_mat.numpy(), annot=True, fmt='d', cmap='Purples', ax=axs[1])
    axs[1].set_xlabel('Predicted labels', fontsize=14)
    axs[1].set_ylabel('True labels', fontsize=14)
    axs[1].set_xticks([0.5, 1.5])
    axs[1].set_yticks([0.5, 1.5])
    axs[1].set_xticklabels(class_names, fontsize=14)
    axs[1].set_yticklabels(class_names, fontsize=14, rotation=0)

    fig.tight_layout()

    fig.savefig(save_path, dpi=600, bbox_inches='tight')

    return accuracy, precision, recall, f1_score

def calculate_metrics(confusion_mat, num_classes):
    """Per-class precision, recall and F1 from a confusion matrix.

    The matrix must follow the layout used when it is filled in this
    notebook: ``confusion_mat[true_label, predicted_label]``, i.e. rows are
    true classes and columns are predicted classes. Hence for class ``i``

    * column ``i`` sums to TP + FP (everything predicted as ``i``), and
    * row ``i`` sums to TP + FN (everything that really is ``i``).

    The previous implementation read FP from the row and FN from the column,
    which swapped precision and recall.

    Args:
        confusion_mat (Tensor): Square integer matrix of counts with at least
            ``num_classes`` rows and columns.
        num_classes (int): Number of leading classes to report.

    Returns:
        tuple[Tensor, Tensor, Tensor]: ``(precision, recall, f1_score)``, each
        a float tensor of length ``num_classes``. A metric whose denominator
        is zero is reported as 0.
    """
    precision = torch.zeros(num_classes)
    recall = torch.zeros(num_classes)
    f1_score = torch.zeros(num_classes)

    for i in range(num_classes):
        tp = confusion_mat[i, i]
        fp = torch.sum(confusion_mat[:, i]) - tp  # predicted i, actually another class
        fn = torch.sum(confusion_mat[i, :]) - tp  # actually i, predicted another class
        precision[i] = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall[i] = tp / (tp + fn) if (tp + fn) > 0 else 0
        denom = precision[i] + recall[i]
        f1_score[i] = 2 * (precision[i] * recall[i]) / denom if denom > 0 else 0

    return precision, recall, f1_score

In [None]:
# Training on the training set and validation on the validation set.
# After every epoch the model is validated, and its weights are saved whenever
# the validation loss improves on the best value seen so far.
strattime = datetime.datetime.now()
for epoch in range(num_epochs):
    # Running sums: (loss * batch size, number correct, number of samples)
    metric = d2l.Accumulator(3)
    # NOTE(review): model.train() also puts the frozen BatchNorm and Dropout
    # layers in training mode, so BatchNorm running statistics keep updating
    # although their weights are frozen - confirm this is intended.
    model.train()
    for X, y in train_loader:
        optimizer.zero_grad()
        X, y = X.to(device), y.to(device)
        y_hat = model(X)
        l = loss(y_hat, y.long())
        l.backward()
        optimizer.step()
        with torch.no_grad():
            metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
    # Epoch-level averages (computed once, after the last batch)
    train_l = metric[0] / metric[2]
    train_acc = metric[1] / metric[2]

    valid_l, valid_acc = evaluate_accuracy_gpu(model, valid_loader, loss)
    scheduler.step()

    # Always checkpoint the first epoch, afterwards only on improvement
    if epoch == 0 or valid_l < best_valid_l:
        best_valid_l = valid_l
        torch.save(model.state_dict(), r"D:\ZYH\code\binary\model\binary_fc.pth")
    print(train_l,valid_l)

In [None]:
# Evaluation on the test set
# Load the best checkpoint written by the training loop. The path previously
# pointed at "...\codebinary\save_model\binary_fc_best.pth", a file this
# notebook never writes, so a fresh Restart & Run All could not reach this cell.
# NOTE(review): if a different, externally trained checkpoint is meant to be
# evaluated here, point this path at it explicitly.
model.load_state_dict(torch.load(r"D:\ZYH\code\binary\model\binary_fc.pth", map_location="cpu"))
test_acc, test_pre, test_rec, test_f1 = evaluate_accuracy_gpu_test(model,test_loader, loss)
for i in range(2):
    print(f'Precision {test_pre[i]:.4f}, Recall {test_rec[i]:.4f}, F_1 score {test_f1[i]:.4f}')
print(f'Total Accuracy {test_acc:.4f}')