# In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, r2_score, mean_squared_error
import seaborn as sns
import torch.nn.functional as F
import copy
import random
def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``manual_seed``, ``cuda.manual_seed`` and ``cuda.manual_seed_all``),
    then sets the cuDNN ``deterministic`` flag and clears the ``benchmark``
    flag.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN auto-tuning for run-to-run repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all generators first so the split and the weight initialisation repeat.
seed = 1
set_seed(seed)

csv_path = '/data/home/wxl22/Classification_of_Varieties/process_utils/processed_spectra.csv' 
df = pd.read_csv(csv_path)

# Every cell of the 'data' column is a string that is evaluated into a Python
# object; the 'label' column is taken as-is.
# NOTE(review): eval() executes whatever expression the CSV contains. If this
# file can ever come from an untrusted source, switch to ast.literal_eval
# (after confirming the cell format is a plain literal).
spectra_data = df['data'].apply(eval).to_numpy()
labels = df['label'].to_numpy()

# Stack the per-row objects into one array of features.
X = np.array(list(spectra_data))
y = labels

# Map the raw labels to consecutive integer class ids.
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

label_mapping = {
    name: code
    for name, code in zip(label_encoder.classes_,
                          label_encoder.transform(label_encoder.classes_))
}
print("Label Mapping: ", label_mapping)
print("Number of classes:", len(label_mapping))


# Split BEFORE fitting the scaler: fitting StandardScaler on the full dataset
# lets test-set mean/variance leak into the preprocessing of the training data
# and makes the reported test metrics optimistic.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded, test_size=0.2, random_state=seed, stratify=y_encoded
)

# Standardisation statistics come from the training portion only; the test
# portion is transformed with those same statistics.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Kept so later code that refers to X_scaled still works: the complete
# dataset standardised with the train-fitted statistics.
X_scaled = scaler.transform(X)

# Convert to tensors: float32 features, int64 class ids.
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)


class SpectraDataset(Dataset):
    """Map-style dataset that pairs each sample with its label.

    Both containers are stored as given (no copy); indexing and length are
    delegated to them.

    Args:
        data: Indexable container of samples.
        labels: Indexable, sized container of labels.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # The dataset size is defined by the label container.
        sample_count = len(self.labels)
        return sample_count

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target


# Wrap the train/test tensors so DataLoader can batch them.
train_dataset, test_dataset = (
    SpectraDataset(X_train, y_train),
    SpectraDataset(X_test, y_test),
)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)


class VectorQuantizer(nn.Module):
    """Vector-quantisation layer with a learnable codebook (VQ-VAE style).

    Every position of a ``(B, C, L)`` input, where ``C == embedding_dim``, is
    replaced by its nearest codebook vector under squared Euclidean distance.
    Gradients bypass the discrete lookup through a straight-through estimator.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Length of each codebook vector; must equal the channel
            dimension of the input.
        commitment_cost: Weight of the loss term that pulls the inputs toward
            the (detached) selected codebook vectors.
        init_embeddings: Optional initial codebook of shape
            ``(num_embeddings, embedding_dim)``, given as a NumPy array or a
            tensor. When omitted the codebook is drawn uniformly from
            ``[-1/num_embeddings, 1/num_embeddings]``.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, init_embeddings=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        self.embeddings = nn.Embedding(self.num_embeddings, self.embedding_dim)
        with torch.no_grad():
            if init_embeddings is not None:
                # as_tensor takes ndarrays and tensors alike; copy_ converts
                # the values to the codebook's dtype.
                self.embeddings.weight.copy_(torch.as_tensor(init_embeddings))
            else:
                self.embeddings.weight.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)

    def forward(self, inputs):
        """Quantise ``inputs`` and compute the codebook/commitment loss.

        Args:
            inputs: Tensor of shape ``(B, embedding_dim, L)``.

        Returns:
            Tuple ``(quantized, loss)``: ``quantized`` has the shape of
            ``inputs``; ``loss`` is the scalar
            ``mse(q, x.detach()) + commitment_cost * mse(q.detach(), x)``.
        """
        # Channels last so that each row of the flattened view is one vector.
        inputs = inputs.permute(0, 2, 1).contiguous()  # (B, L, C)
        input_shape = inputs.shape
        flat_input = inputs.view(-1, self.embedding_dim)  # (B*L, C)
        codebook = self.embeddings.weight

        # ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e for every (vector, code) pair.
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, codebook.t()))

        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # The one-hot matrix must share the codebook's dtype; a fixed float32
        # matrix makes the matmul fail once the module is cast to another
        # floating type.
        encodings = torch.zeros(encoding_indices.size(0), self.num_embeddings,
                                dtype=codebook.dtype, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        quantized = torch.matmul(encodings, codebook).view(input_shape)

        # Commitment term moves the inputs; codebook term moves the codes.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the code, gradient is
        # that of the identity with respect to the inputs.
        quantized = inputs + (quantized - inputs).detach()

        quantized = quantized.permute(0, 2, 1).contiguous()  # back to (B, C, L)

        return quantized, loss


class ImportanceSplitNet(nn.Module):
    """Gate convolved features per channel and split them by gate value.

    A convolution reduces ``6 * embedding_dim`` input channels to
    ``2 * embedding_dim``. A sigmoid gate per channel is computed from the
    length-averaged convolution output and multiplied into the features.
    Channels whose gate exceeds 0.5 form the "important" output; the
    remaining channels form the "unimportant" output.

    Args:
        embedding_dim: Base channel width.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        fused_channels = embedding_dim * 6
        gated_channels = embedding_dim * 2

        self.conv_multi = nn.Conv1d(fused_channels, gated_channels, kernel_size=3, padding=1)
        self.gate_fc = nn.Linear(gated_channels, gated_channels)

    def forward(self, x):
        """Return ``(important_features, unimportant_features)``.

        Args:
            x: Tensor of shape ``(B, 6*embedding_dim, L)``. The three branch
                outputs are assumed to have been concatenated by the caller
                to give the ``6*embedding_dim`` channels.

        Returns:
            Two tensors of shape ``(B, 2*embedding_dim, L)``. For each
            channel one of them holds the gated features and the other
            holds zeros.
        """
        out = self.conv_multi(x)
        seq_len = out.size(-1)

        # One gate value per (sample, channel), repeated along the length axis.
        pooled = F.adaptive_avg_pool1d(out, output_size=1).squeeze(-1)
        gating_score = torch.sigmoid(self.gate_fc(pooled)).unsqueeze(-1).expand(-1, -1, seq_len)
        weighted_out = out * gating_score

        # Hard 0/1 channel mask derived from the mean gate value.
        ch_score = gating_score.mean(dim=-1)
        ch_mask = (ch_score > 0.5).float().unsqueeze(-1).expand(-1, -1, seq_len)

        important_features = weighted_out * ch_mask
        unimportant_features = weighted_out * (1 - ch_mask)
        return important_features, unimportant_features

class Orthogonal_Model(nn.Module):
    """Project fused features onto two learnable, mutually orthogonal axes.

    A convolution maps ``6 * embedding_dim`` channels to ``4 * embedding_dim``.
    Every position of the result is projected onto a unit "invariant" axis
    and a unit "relevant" axis living in that channel space. The axes are
    initialised orthonormal and ``orthogonal_loss`` penalises any later
    loss of orthogonality.

    Args:
        embedding_dim: Base channel width.
        mlp_in: Unused; retained only so existing constructor calls keep
            working.
    """

    _EPS = 1e-10

    def __init__(self, embedding_dim, mlp_in):
        super().__init__()
        # ``mlp_in`` is intentionally not used: both axes are vectors in the
        # channel space produced by ``conv_multi``.
        self.conv_multi = nn.Conv1d(embedding_dim * 6, embedding_dim * 4, kernel_size=3, padding=1)

        self.invariant_axis = nn.Parameter(torch.randn(4 * embedding_dim))
        self.relevant_axis = nn.Parameter(torch.randn(4 * embedding_dim))

        # Gram-Schmidt step so the two axes start out orthonormal.
        with torch.no_grad():
            self.invariant_axis /= torch.norm(self.invariant_axis) + self._EPS
            self.relevant_axis -= torch.dot(self.relevant_axis, self.invariant_axis) * self.invariant_axis
            self.relevant_axis /= torch.norm(self.relevant_axis) + self._EPS

    def _unit_axes(self):
        """Return both axes rescaled to unit length (differentiable)."""
        invariant_axis = self.invariant_axis / (torch.norm(self.invariant_axis) + self._EPS)
        relevant_axis = self.relevant_axis / (torch.norm(self.relevant_axis) + self._EPS)
        return invariant_axis, relevant_axis

    def similarity(self, feature1, feature2, type='cos'):
        """Return the row-wise similarity of largest magnitude.

        Args:
            feature1: 2-D tensor, one feature vector per row.
            feature2: 2-D tensor with the same shape as ``feature1``.
            type: ``'cos'`` for cosine similarity or ``'pearson'`` for the
                Pearson correlation coefficient.

        Returns:
            For ``'cos'`` a scalar tensor holding the maximum absolute cosine
            similarity; for ``'pearson'`` the signed coefficient with the
            largest absolute value; ``0`` for any other ``type``.
        """
        if type == 'cos':
            norm1 = torch.norm(feature1, dim=1, keepdim=True)
            norm2 = torch.norm(feature2, dim=1, keepdim=True)
            # Clamp so that an all-zero row yields similarity 0, not NaN.
            denom = (norm1 * norm2).clamp_min(self._EPS)
            cos = torch.sum(feature1 * feature2 / denom, dim=1)
            return torch.max(torch.abs(cos))
        if type == 'pearson':
            rows1 = feature1.detach().cpu().numpy()
            rows2 = feature2.detach().cpu().numpy()
            coeffs = [np.corrcoef(rows1[i], rows2[i])[0][1] for i in range(rows1.shape[0])]
            return max(coeffs, key=abs)
        return 0

    def orthogonal_loss(self):
        """Return the squared cosine between the two axes (scalar tensor)."""
        invariant_axis, relevant_axis = self._unit_axes()
        o_loss = torch.dot(invariant_axis, relevant_axis)  # scalar
        return o_loss ** 2

    def forward(self, x):
        """Project the convolved features onto the two axes.

        Args:
            x: Tensor of shape ``(B, 6*embedding_dim, L)``.

        Returns:
            Tuple ``(invariant_features, relevant_features)``, each of shape
            ``(B, 2*embedding_dim, L)``: the first half of the channels of
            the invariant projection and the second half of the channels of
            the relevant projection.
        """
        out = self.conv_multi(x)
        channels = out.shape[1]

        # Unit axes give the true orthogonal projection (x . a) a. Dividing
        # the axis by ||a||^2 and then using it twice would scale the result
        # by an extra 1/||a||^2 as soon as training changes the axis norm.
        invariant_axis, relevant_axis = self._unit_axes()

        per_position = out.transpose(1, 2)  # (B, L, C)
        invariant_dot = torch.matmul(per_position, invariant_axis).unsqueeze(1)  # (B, 1, L)
        relevant_dot = torch.matmul(per_position, relevant_axis).unsqueeze(1)    # (B, 1, L)

        invariant_features = invariant_dot * invariant_axis.view(1, -1, 1)  # (B, C, L)
        relevant_features = relevant_dot * relevant_axis.view(1, -1, 1)      # (B, C, L)

        half = channels // 2
        return invariant_features[:, :half, :], relevant_features[:, half:, :]

class MLP(nn.Module):
    """Fully connected network with ReLU and dropout between linear layers.

    The stack is ``in_dim -> hidden``, then ``n_layers - 2`` layers of
    ``hidden -> hidden``, then ``hidden -> out_dim``; at least two linear
    layers are always created. All weights use Xavier-uniform initialisation
    with the ReLU gain.

    Args:
        n_layers: Requested number of linear layers.
        in_dim: Input feature size.
        hidden: Hidden feature size.
        out_dim: Output feature size.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, n_layers, in_dim, hidden, out_dim, dropout):
        super().__init__()
        widths = [in_dim, hidden] + [hidden] * (n_layers - 2) + [out_dim]
        self.lins = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])
        )
        self.dropout = dropout

        # Re-initialise the weights only after all layers exist.
        relu_gain = nn.init.calculate_gain('relu')
        for lin in self.lins:
            nn.init.xavier_uniform_(lin.weight, gain=relu_gain)

    def forward(self, features):
        """Apply the network; the last layer has no activation or dropout."""
        *hidden_layers, head = self.lins
        for lin in hidden_layers:
            features = F.dropout(F.relu(lin(features)), p=self.dropout, training=self.training)
        return head(features)


class SpectraCNNWithAttentionVQ(nn.Module):
    """1-D CNN spectra classifier with VQ, orthogonal projection and attention.

    Pipeline: three conv/BN/ReLU/dropout stages; three parallel conv branches
    (kernel sizes 3, 5, 7) that are pooled, concatenated and projected by
    ``Orthogonal_Model``; a vector-quantised copy of the pooled stage-3
    features; a gated blend of the invariant projection and the quantised
    features; multi-head self-attention; and a two-layer classification head.
    A separate MLP head classifies from the invariant projection alone.

    Args:
        num_classes: Number of output classes.
        num_embeddings: Codebook size of the vector quantiser.
        embedding_dim: Base channel width; most layers use ``2 * embedding_dim``.
        commitment_cost: Commitment weight passed to the vector quantiser.
        num_heads: Number of attention heads; must divide ``2 * embedding_dim``.
        init_embeddings: Optional initial codebook passed to the vector
            quantiser.
        input_length: Number of points per input spectrum. The default of 256
            yields the flattened size ``embedding_dim * 256`` expected by the
            classification heads.
    """

    def __init__(self, num_classes, num_embeddings=512, embedding_dim=128, commitment_cost=0.25,
                 num_heads=4, init_embeddings=None, input_length=256):
        super().__init__()
        dropout = 0.3
        wide = embedding_dim * 2
        # Channels times the length left after one stride-2 pooling.
        flat_features = wide * (input_length // 2)

        self.conv1 = nn.Conv1d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(64, embedding_dim, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(embedding_dim)
        self.dropout2 = nn.Dropout(dropout)

        self.conv3 = nn.Conv1d(embedding_dim, wide, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(wide)
        self.dropout3 = nn.Dropout(dropout)

        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)

        # Three parallel branches with growing receptive fields.
        self.conv4 = nn.Conv1d(wide, wide, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm1d(wide)
        self.pool4 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv1d(wide, wide, kernel_size=5, padding=2)
        self.bn5 = nn.BatchNorm1d(wide)
        self.pool5 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv6 = nn.Conv1d(wide, wide, kernel_size=7, padding=3)
        self.bn6 = nn.BatchNorm1d(wide)
        self.pool6 = nn.MaxPool1d(kernel_size=2, stride=2)

        # NOTE: conv_multi and residual_conv are not used in forward(); they
        # are kept so parameter names and initialisation order stay the same.
        self.conv_multi = nn.Conv1d(embedding_dim * 6, wide, kernel_size=3, padding=1)
        self.orthogonal = Orthogonal_Model(embedding_dim, wide)

        self.residual_conv = nn.Conv1d(1, wide, kernel_size=1)

        self.vq = VectorQuantizer(num_embeddings, wide, commitment_cost, init_embeddings)

        self.attention = nn.MultiheadAttention(embed_dim=wide, num_heads=num_heads, batch_first=True)

        self.gate = nn.Conv1d(wide, wide, kernel_size=1)

        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.mlp = MLP(3, flat_features, 128, num_classes, dropout)

        self.dropout4 = nn.Dropout(0.58)

    def forward(self, x):
        """Classify a batch of spectra.

        Args:
            x: Tensor of shape ``(B, input_length)``.

        Returns:
            Tuple ``(logits, probabilities, vq_loss, o_loss, invariant_logits)``:
            raw class scores of the main head, their softmax, the vector
            quantisation loss, the axis orthogonality loss, and the raw class
            scores of the invariant-feature head.
        """
        batch_size = x.size(0)
        x = x.unsqueeze(1)  # (B, 1, L)

        x = self.dropout1(F.relu(self.bn1(self.conv1(x))))
        x = self.dropout2(F.relu(self.bn2(self.conv2(x))))
        x = self.dropout3(F.relu(self.bn3(self.conv3(x))))  # (B, 2E, L)

        # Multi-scale branches, each halving the length.
        x1 = self.pool4(F.relu(self.bn4(self.conv4(x))))
        x2 = self.pool5(F.relu(self.bn5(self.conv5(x))))
        x3 = self.pool6(F.relu(self.bn6(self.conv6(x))))

        xn = torch.cat([x1, x2, x3], dim=1)  # (B, 6E, L/2)

        invariant, spurious = self.orthogonal(xn)
        o_loss = self.orthogonal.orthogonal_loss()

        # Vector quantisation of the pooled stage-3 features.
        x = self.pool3(x)  # (B, 2E, L/2)
        x_vq, vq_loss = self.vq(x)

        # NOTE(review): tanh gives gate values in (-1, 1), so this is not a
        # convex blend of the two feature maps; confirm sigmoid was not meant.
        gate = torch.tanh(self.gate(x))
        x = gate * invariant + (1 - gate) * x_vq

        # Self-attention over positions, then back to channels-first.
        x = x.permute(0, 2, 1).contiguous()  # (B, L/2, 2E)
        attn_output, _ = self.attention(x, x, x)
        attn_output = attn_output.permute(0, 2, 1).contiguous()
        x = attn_output.view(batch_size, -1)

        x = self.dropout4(F.relu(self.fc1(x)))
        logits = self.fc2(x)
        probabilities = F.softmax(logits, dim=1)

        # Return RAW scores: this output is fed to nn.CrossEntropyLoss, which
        # applies log-softmax itself. A softmax here would be applied twice
        # and flatten the gradients of the invariant head.
        invariant_logits = self.mlp(invariant.reshape(batch_size, -1))

        return logits, probabilities, vq_loss, o_loss, invariant_logits






# The output layer must match the labels actually present in the data. A
# hard-coded class count breaks the loss as soon as the dataset has more
# classes than the constant and wastes outputs when it has fewer.
num_classes = len(label_encoder.classes_)
print("Number of classes:", num_classes)

# Model hyper-parameters.
num_embeddings = 128
embedding_dim = 64
commitment_cost = 0.4
num_heads = 64

model = SpectraCNNWithAttentionVQ(num_classes=num_classes,
                                    num_embeddings=num_embeddings,
                                    embedding_dim=embedding_dim,
                                    commitment_cost=commitment_cost,
                                    num_heads=num_heads)
print(model)

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

criterion_cls = nn.CrossEntropyLoss()

# Train on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# Snapshot of the epoch with the highest training accuracy seen so far.
max_train_accuracy = test_accuracy_at_max_train = test_loss_at_max_train = 0.0
precision_at_max_train = recall_at_max_train = f1_at_max_train = 0.0
max_train_epoch = 0

# Test labels and predictions recorded at that epoch.
best_test_labels, best_test_predictions = [], []

# Per-epoch histories; every name gets its own independent list.
(train_losses, test_losses,
 train_accuracies, test_accuracies,
 precision_scores, recall_scores, f1_scores,
 r2_scores, rmse_scores) = ([] for _ in range(9))


num_epochs = 500
for epoch in range(num_epochs):

    # ---- training pass ----
    model.train()
    running_loss = 0.0
    running_cls_loss = 0.0
    correct_train = 0
    total_train = 0
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)

        logits, probabilities, vqloss, o_loss, invariant_logits = model(data)
        # Keep the main classification term separately so it can be logged on
        # its own next to the total loss.
        cls_loss = criterion_cls(logits, labels)
        loss = cls_loss + criterion_cls(invariant_logits, labels) + vqloss + o_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_cls_loss += cls_loss.item()
        predicted = logits.detach().argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    # Losses are averaged over batches, accuracy over samples.
    avg_train_loss = running_loss / len(train_loader)
    avg_train_cls_loss = running_cls_loss / len(train_loader)
    train_accuracy = correct_train / total_train

    # ---- evaluation pass ----
    model.eval()
    running_test_loss = 0.0
    correct_test = 0
    total_test = 0
    all_labels = []
    all_probs = []
    predicted_classes = []
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            logits, probabilities, vqloss, o_loss, invariant_logits = model(data)
            loss = criterion_cls(logits, labels) + criterion_cls(invariant_logits, labels) + vqloss + o_loss
            running_test_loss += loss.item()

            predicted = logits.argmax(dim=1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probabilities.cpu().numpy())
            predicted_classes.extend(predicted.cpu().numpy())

    avg_test_loss = running_test_loss / len(test_loader)
    test_accuracy = correct_test / total_test

    # Support-weighted metrics; classes without predictions count as 0.
    precision = precision_score(all_labels, predicted_classes, average='weighted', zero_division=0)
    recall = recall_score(all_labels, predicted_classes, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, predicted_classes, average='weighted', zero_division=0)

    # The reported epoch is chosen by TRAINING accuracy, so the test set is
    # never used for model selection.
    if train_accuracy > max_train_accuracy:
        max_train_accuracy = train_accuracy
        max_train_epoch = epoch
        test_accuracy_at_max_train = test_accuracy
        test_loss_at_max_train = avg_test_loss
        precision_at_max_train = precision
        recall_at_max_train = recall
        f1_at_max_train = f1
        best_test_labels = copy.deepcopy(all_labels)
        best_test_predictions = copy.deepcopy(predicted_classes)

    train_losses.append(avg_train_loss)
    test_losses.append(avg_test_loss)
    train_accuracies.append(train_accuracy)
    test_accuracies.append(test_accuracy)
    precision_scores.append(precision)
    recall_scores.append(recall)
    f1_scores.append(f1)

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {avg_train_loss:.4f} (Cls: {avg_train_cls_loss:.4f}), "
          f"Train Accuracy: {train_accuracy:.4f}, "
          f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, "
          f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")


# Confusion matrix of the test predictions recorded at the epoch with the
# highest training accuracy.
cm_test = confusion_matrix(best_test_labels, best_test_predictions)
print(len(best_test_predictions))

# Summary of that epoch, emitted with a single print call.
summary_lines = [
    "\n==========================",
    f"训练准确率最高的 Epoch: {max_train_epoch + 1}",
    f"训练准确率: {max_train_accuracy:.4f}",
    f"对应的测试准确率: {test_accuracy_at_max_train:.4f}",
    f"对应的测试损失: {test_loss_at_max_train:.4f}",
    f"对应的 Precision: {precision_at_max_train:.4f}",
    f"对应的 Recall: {recall_at_max_train:.4f}",
    f"对应的 F1-Score: {f1_at_max_train:.4f}",
    "==========================\n",
]
print("\n".join(summary_lines))

