In [2]:
# Restrict this process to a fixed set of CPU cores and configure file logging.
import psutil
import os

# Cores we would like to run on. Cores that do not exist on this machine are
# dropped below, so the cell also works on hosts with fewer CPUs.
requested_cores = [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]

# Get the current process
current_process = psutil.Process(os.getpid())

# Keep only the requested cores that are actually present
available_cores = set(range(os.cpu_count() or 1))
usable_cores = [core for core in requested_cores if core in available_cores]

if usable_cores:
    current_process.cpu_affinity(usable_cores)
else:
    print("None of the requested cores exist on this machine; CPU affinity left unchanged.")

# Verify if the CPU affinity is set correctly
print(f"CPU Affinity: {current_process.cpu_affinity()}")

# Configure logging
import logging
import time

# Timestamp for the log file name. No spaces or colons, so the name is valid on
# every common filesystem and needs no shell quoting.
time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
log_file = f"train-{time_stamp}.log"

# Set up the root logger to write INFO-and-above messages to the log file
logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    format='%(asctime)s - %(message)s'
)

CPU Affinity: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31]


In [None]:
from captum import attr
from tqdm import tqdm
import numpy as np
from WFlib import models
from WFlib.tools import data_processor, analyzer
import torch
import warnings

# Silence UserWarning-category warnings for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock2d(nn.Module):
    """
    Residual 2D convolution block.

    The main path is two (Conv2d -> BatchNorm2d -> ReLU) stages with "same" padding and
    stride 1, so the spatial size is preserved. The input is added back to the main-path
    output (through a 1x1 convolution when the channel count changes) and a final ReLU
    is applied to the sum.
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()

        def conv_bn_relu(stage_in_channels):
            # One Conv2d -> BatchNorm2d -> ReLU stage of the main path.
            return [
                nn.Conv2d(
                    stage_in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    padding="same",
                ),
                nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True),
                nn.ReLU(),
            ]

        self.net = nn.Sequential(
            *conv_bn_relu(in_channels),
            *conv_bn_relu(out_channels),
        )

        # A 1x1 convolution matches the channel count of the skip path when needed;
        # its weights start from a narrow normal distribution (std 0.01).
        if in_channels == out_channels:
            self.downsample = None
        else:
            self.downsample = nn.Conv2d(in_channels, out_channels, 1)
            self.downsample.weight.data.normal_(0, 0.01)
        self.last_relu = nn.ReLU()

    def forward(self, x):
        main_path = self.net(x)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        return self.last_relu(main_path + shortcut)


class Encoder2d(nn.Module):
    """
    A 2D convolutional encoder: a stack of ConvBlock2d layers, each followed by max pooling
    and dropout.

    Channel widths are 16, 32, 64, ... for the intermediate blocks, and the final block
    always emits `out_channels`. This also holds for `conv_num_layers == 1`, where the
    single block maps `in_channels` directly to `out_channels`.

    Every block except the last is pooled with (1, 2), halving only the last axis; the
    last block is pooled with (2, 2), halving both spatial axes.

    Parameters:
        in_channels: Number of channels of the input tensor
        out_channels: Number of channels produced by the final block
        conv_num_layers: Number of ConvBlock2d stages
    """
    def __init__(self, in_channels, out_channels, conv_num_layers):
        super(Encoder2d, self).__init__()
        layers = []
        cur_in_channels = in_channels
        for i in range(conv_num_layers):
            is_last = i == conv_num_layers - 1
            # Intermediate widths double from 16; the last stage uses the requested width.
            cur_out_channels = out_channels if is_last else 16 * (2 ** i)
            layers.append(ConvBlock2d(cur_in_channels, cur_out_channels, (3, 7)))
            layers.append(nn.MaxPool2d((2, 2) if is_last else (1, 2)))
            layers.append(nn.Dropout(0.1))
            cur_in_channels = cur_out_channels
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input through every block/pool/dropout stage in order."""
        return self.layers(x)


class TransformerEncoderBlock(nn.Module):
    """
    Post-norm Transformer encoder block: multi-head self-attention and a two-layer
    feed-forward network, each wrapped in dropout, a residual connection and LayerNorm.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.2):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        feed_forward = [
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        ]
        self.ff = nn.Sequential(*feed_forward)
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (seq_len, batch_size, embed_dim); the output has the same shape.
        # Self-attention sub-layer (query = key = value = x); attention weights are discarded.
        attended = self.attn(x, x, x)[0]
        hidden = self.layernorm1(x + self.dropout(attended))
        # Feed-forward sub-layer.
        transformed = self.ff(hidden)
        return self.layernorm2(hidden + self.dropout(transformed))


class AttentionPooling(nn.Module):
    """
    Collapse a sequence into one vector using learned, softmax-normalised per-position weights.
    """
    def __init__(self, embed_dim):
        super().__init__()
        # Scores each position with a single scalar.
        self.attn = nn.Linear(embed_dim, 1)

    def forward(self, x):
        # x: (batch_size, seq_len, embed_dim)
        scores = self.attn(x)                   # (batch_size, seq_len, 1)
        weights = torch.softmax(scores, dim=1)  # normalised over the sequence axis
        weighted = x * weights
        return weighted.sum(dim=1)              # (batch_size, embed_dim)


class CTFNet(nn.Module):
    """
    Multi-scale time series network with convolutional encoders, Transformer encoders,
    attention pooling, and classification/projection heads.

    Each scale is processed by its own Encoder2d and TransformerEncoderBlock; the per-scale
    outputs are concatenated along the feature axis, fused by multi-head self-attention,
    pooled over the sequence, and fed to two heads.

    Parameters:
        num_classes: Number of output classes of the classification head
        num_scales: Number of input scales (size of dim 1 of the input)
        num_features: Number of input channels per scale
        sub_dim: Stored on the module but not used by the layers. The forward pass
            requires the encoder to reduce the height axis to 1 (it is squeezed away),
            which is the case for an input height of 2.
        seq_len: Input sequence length. The positional encodings are sized for the
            length left after the encoder's three width-halving poolings (seq_len // 8).
        embed_dim: Per-scale embedding width
        num_heads: Number of attention heads (per-scale and fusion attention)
        ff_dim: Hidden width of the Transformer feed-forward layers

    forward(x) returns `(output, projection)`: class logits of shape
    (batch_size, num_classes) and an L2-normalised projection of shape (batch_size, 64).
    """
    def __init__(self, num_classes=10, num_scales=3, num_features=5, sub_dim=2, seq_len=1000, embed_dim=64, num_heads=4, ff_dim=128):
        super(CTFNet, self).__init__()

        self.num_scales = num_scales
        self.num_features = num_features
        self.sub_dim = sub_dim

        # Each of the encoder's conv stages halves the sequence axis once.
        conv_num_layers = 3
        seq_len_after = seq_len // (2 ** conv_num_layers)  # 1000 -> 125

        # Learnable positional encodings, one per scale
        self.pos_encodings = nn.ParameterList([
            nn.Parameter(torch.zeros(1, seq_len_after, embed_dim))
            for _ in range(num_scales)
        ])

        # Initialize positional encodings
        for pos_enc in self.pos_encodings:
            nn.init.trunc_normal_(pos_enc, std=0.02)

        # Use Encoder2d to process data for each scale
        self.encoder2d_blocks = nn.ModuleList([
            Encoder2d(in_channels=num_features, out_channels=embed_dim, conv_num_layers=conv_num_layers)
            for _ in range(num_scales)
        ])

        # Transformer encoders
        self.transformer_encoders = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, ff_dim)
            for _ in range(num_scales)
        ])

        # Multi-head self-attention fusion over the concatenated scales
        self.fusion_attn = nn.MultiheadAttention(embed_dim * num_scales, num_heads, dropout=0.3)

        # Attention pooling
        self.attn_pool = AttentionPooling(embed_dim * num_scales)

        # Projection head for contrastive learning
        projection_dim = 64
        self.projection_head = nn.Sequential(
            nn.Linear(embed_dim * num_scales, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, projection_dim)
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim * num_scales, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        # Input shape: (batch_size, num_scales, num_features, sub_dim, seq_len)

        # Multi-scale 2D convolution processing.
        # Encoder output per scale is (batch_size, embed_dim, 1, seq_len // 8); the
        # singleton height axis is squeezed away to give (batch_size, embed_dim, seq_len // 8).
        conv_outputs = [
            self.encoder2d_blocks[i](x[:, i]).squeeze(2)
            for i in range(self.num_scales)
        ]

        # Adjust dimensions for Transformer and add positional encodings
        transformer_inputs = []
        for i, conv_out in enumerate(conv_outputs):
            # (batch_size, embed_dim, seq) -> (batch_size, seq, embed_dim)
            t_input = conv_out.permute(0, 2, 1)

            # Add positional encodings (broadcast over the batch axis)
            t_input = t_input + self.pos_encodings[i]

            # Convert to Transformer expected input shape: (seq, batch_size, embed_dim)
            transformer_inputs.append(t_input.permute(1, 0, 2))

        # Transformer encoding, one encoder per scale: each (seq, batch_size, embed_dim)
        transformer_outputs = [
            self.transformer_encoders[i](transformer_inputs[i])
            for i in range(self.num_scales)
        ]

        # Multi-scale fusion
        fused = torch.cat(transformer_outputs, dim=2)  # (seq, batch_size, embed_dim * num_scales)
        fused, _ = self.fusion_attn(fused, fused, fused)  # (seq, batch_size, embed_dim * num_scales)

        # Adjust dimensions for attention pooling
        fused = fused.permute(1, 0, 2)  # (batch_size, seq, embed_dim * num_scales)

        # Attention pooling
        pooled = self.attn_pool(fused)  # (batch_size, embed_dim * num_scales)

        # Projection head, L2-normalised per sample
        projection = self.projection_head(pooled)  # (batch_size, projection_dim)
        projection = F.normalize(projection, p=2, dim=1)

        # Fully connected layer output
        output = self.classifier(pooled)  # (batch_size, num_classes)

        return output, projection

In [None]:
from pytorch_metric_learning.losses import SupConLoss
import time

def train_ctf_model(X_train, y_train, X_val, y_val, num_classes, epochs=50,
                     batch_size=128, val_batch_size=128, save_dir='./checkpoints',
                     resume_from=None):
    """
    Train the ctf model, supporting batch evaluation, best model saving, and training resumption.

    The training loss is cross-entropy on the logits plus supervised contrastive loss on
    the projection; the validation loss is cross-entropy only. The learning rate is
    reduced by ReduceLROnPlateau on the mean validation loss. The best weights (by
    validation accuracy) are written to `<save_dir>/best_ctf_model.pt`, and a full
    checkpoint is written every 5 epochs and after the last epoch.

    Parameters:
        X_train, y_train: Training data and labels
        X_val, y_val: Validation data and labels
        num_classes: Number of classes
        epochs: Number of training epochs
        batch_size: Training batch size
        val_batch_size: Validation batch size
        save_dir: Directory to save the model
        resume_from: Path to resume training from a checkpoint

    Returns:
        (model, best_val_acc, best_epoch): the model with the best saved weights loaded
        (when that file exists), the best validation accuracy as a fraction in [0, 1],
        and the zero-based epoch index at which it was reached.
    """
    import os
    import time

    # Create save directory
    os.makedirs(save_dir, exist_ok=True)

    # Data preparation
    X_train = torch.FloatTensor(X_train)
    y_train = torch.LongTensor(y_train)
    X_val = torch.FloatTensor(X_val)
    y_val = torch.LongTensor(y_val)

    train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Create validation data loader
    val_dataset = torch.utils.data.TensorDataset(X_val, y_val)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

    # Initialize model
    model = CTFNet(num_classes=num_classes)

    # Prefer the second GPU, fall back to the first one on single-GPU machines
    # (requesting cuda:1 there raises an invalid-device error), then to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")
    model = model.to(device)

    # Optimizer and loss functions
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    ce_criterion = nn.CrossEntropyLoss()
    sc_criterion = SupConLoss(temperature=0.07)

    # `verbose` is deprecated on the scheduler; learning-rate changes are logged
    # explicitly after each scheduler step instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.6, patience=5
    )

    # Initialize training state
    start_epoch = 0
    best_val_acc = 0
    best_epoch = 0

    # Resume training (if checkpoint is provided)
    if resume_from and os.path.isfile(resume_from):
        logging.info(f"Resuming from checkpoint: {resume_from}")
        # map_location lets a checkpoint written on one device be restored on another.
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_acc = checkpoint['best_val_acc']
        best_epoch = checkpoint['best_epoch']
        logging.info(f"Resuming from epoch {start_epoch}, best accuracy: {best_val_acc*100:.2f}% (epoch {best_epoch})")
    elif resume_from:
        logging.warning(f"Checkpoint not found, training from scratch: {resume_from}")

    # Training loop
    for epoch in range(start_epoch, epochs):
        epoch_start_time = time.time()

        # Training phase
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for inputs, targets in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{epochs}"):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs, projection = model(inputs)
            ce_loss = ce_criterion(outputs, targets)
            sc_loss = sc_criterion(projection, targets)
            loss = ce_loss + sc_loss
            loss.backward()

            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        train_acc = correct / total

        # Validation phase - batch evaluation (cross-entropy only)
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{epochs}"):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = model(inputs)

                loss = ce_criterion(outputs, targets)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()

        val_acc = val_correct / val_total
        avg_val_loss = val_loss / len(val_loader)

        # Adjust learning rate. This runs before any checkpoint is written so that the
        # saved scheduler state already includes this epoch and a resumed run continues
        # with the same plateau bookkeeping.
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            logging.info(f"Learning rate reduced: {lr_before:.6g} -> {lr_after:.6g}")

        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            torch.save(model.state_dict(), f"{save_dir}/best_ctf_model.pt")
            logging.info(f"Saved best model, validation accuracy: {val_acc*100:.2f}%")

        # Save checkpoint every 5 epochs or at the last epoch
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_acc': best_val_acc,
                'best_epoch': best_epoch
            }
            torch.save(checkpoint, f"{save_dir}/checkpoint_epoch_{epoch+1}.pt")
            logging.info(f"Saved checkpoint: epoch_{epoch+1}")

        # Log statistics
        epoch_time = time.time() - epoch_start_time
        logging.info(f"Epoch {epoch+1}/{epochs} - Time: {epoch_time:.1f}s")
        logging.info(f"Training Loss: {train_loss/len(train_loader):.4f} | "
              f"Training Accuracy: {train_acc*100:.2f}%")
        logging.info(f"Validation Loss: {avg_val_loss:.4f} | "
              f"Validation Accuracy: {val_acc*100:.2f}%")
        logging.info(f"Best Validation Accuracy: {best_val_acc*100:.2f}% (Epoch {best_epoch+1})")
        logging.info("-" * 60)

    # Load the best model after training
    best_model_path = f"{save_dir}/best_ctf_model.pt"
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        logging.info(f"Loaded best model (Epoch {best_epoch+1}, Accuracy: {best_val_acc*100:.2f}%)")

    return model, best_val_acc, best_epoch


In [6]:
import torch
import numpy as np
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed


def fast_count_burst(arr):
    """
    Run-length encode a 1-D array, signing each run length by the run's value.

    Consecutive equal values form one run. The result holds one entry per run: the run
    length multiplied by the sign of the run's value (positive runs -> +length, negative
    runs -> -length, runs of zeros -> 0).

    Parameters:
        arr: 1-D array-like of numbers (e.g. packet directions +1 / -1)

    Returns:
        np.ndarray with one signed length per run; an empty array for empty input.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # No runs at all; keep the dtype the non-empty path would produce.
        return np.zeros(0, dtype=np.result_type(np.intp, arr.dtype))

    # Positions i where arr[i] != arr[i + 1], i.e. the last index of every run but the final one.
    change_indices = np.nonzero(np.diff(arr))[0]
    segment_starts = np.insert(change_indices + 1, 0, 0)
    segment_ends = np.append(change_indices, len(arr) - 1)
    segment_lengths = segment_ends - segment_starts + 1
    segment_signs = np.sign(arr[segment_starts])

    return segment_lengths * segment_signs

def process_HMTF(index, sequence, intervals, max_lens):
    """
    Compute the multi-scale feature matrix for one sequence.

    `sequence` is a 1-D array whose absolute values are used as timestamps and whose signs
    are used as packet directions (> 0 outgoing, < 0 incoming); leading and trailing zeros
    are stripped. For each scale `(interval, max_len)` the packets are split into `max_len`
    consecutive time windows of width `interval`, measured from the first packet; the last
    window takes every remaining packet. Ten statistics are written per non-empty window.

    Feature rows:
        0-1: number of outgoing / incoming packets
        2-3: number of outgoing / incoming bursts
        4-5: mean outgoing / incoming burst length
        6-7: mean gap between consecutive outgoing / incoming packets
        8:   fraction of outgoing packets
        9:   packets per unit of time spanned by the window

    Returns:
        (index, features) with features of shape (len(intervals), 10, max(max_lens));
        all zeros when the sequence has no non-zero entries.
    """
    packets = np.trim_zeros(sequence, "fb")
    features = np.zeros((len(intervals), 10, max(max_lens)))
    if len(packets) == 0:
        return index, features

    timestamps = np.abs(packets)
    first_time = timestamps[0]
    num_packets = timestamps.shape[0]

    for scale_idx, (interval, max_len) in enumerate(zip(intervals, max_lens)):
        lo = 0

        for slot in range(max_len):
            # Right edge of this window: everything left for the final slot, otherwise
            # the first packet at or after the window's end time.
            if slot == max_len - 1:
                hi = num_packets
            else:
                hi = np.searchsorted(timestamps, first_time + (slot + 1) * interval)

            if lo < hi:
                window = packets[lo:hi]
                window_times = timestamps[lo:hi]
                out_mask = window > 0
                in_mask = window < 0
                num_out = len(window[out_mask])
                num_in = len(window[in_mask])

                # 1-2: Directional statistics
                features[scale_idx, 0, slot] = num_out
                features[scale_idx, 1, slot] = num_in

                # 3-6: Burst analysis (signed run lengths of the direction sequence)
                bursts = fast_count_burst(np.sign(window))
                out_bursts = bursts[bursts > 0]
                in_bursts = np.abs(bursts[bursts < 0])
                features[scale_idx, 2, slot] = len(out_bursts)
                features[scale_idx, 3, slot] = len(in_bursts)
                if len(out_bursts) > 0:
                    features[scale_idx, 4, slot] = np.mean(out_bursts)
                else:
                    features[scale_idx, 4, slot] = 0
                if len(in_bursts) > 0:
                    features[scale_idx, 5, slot] = np.mean(in_bursts)
                else:
                    features[scale_idx, 5, slot] = 0

                # 7-8: Mean gap between same-direction packets (needs at least two of them)
                if num_out > 1:
                    features[scale_idx, 6, slot] = np.mean(np.diff(window_times[out_mask]))
                if num_in > 1:
                    features[scale_idx, 7, slot] = np.mean(np.diff(window_times[in_mask]))

                # 9: Outgoing share of the directed packets
                num_directed = num_out + num_in
                if num_directed > 0:
                    features[scale_idx, 8, slot] = num_out / num_directed
                else:
                    features[scale_idx, 8, slot] = 0

                # 10: Packets per unit of time between the window's first and last packet
                span = timestamps[hi - 1] - timestamps[lo]
                if span > 0:
                    features[scale_idx, 9, slot] = len(window) / span
                else:
                    features[scale_idx, 9, slot] = 0

            lo = hi

    return index, features

def extract_HMTF(sequences, num_workers=30):
    """
    Extract hybrid multi-scale traffic features from sequences.

    Each row of `sequences` is scaled by 1000 (the original code describes this as
    converting to milliseconds, i.e. the input is assumed to be in seconds) and passed to
    `process_HMTF` in a worker process. The input array itself is left unmodified, so
    calling this function repeatedly on the same array gives the same result.

    Parameters:
        sequences: 2-D array, one sequence per row
        num_workers: Upper bound on the number of worker processes

    Returns:
        np.ndarray of shape [num_sequences, num_scales, num_features, max_length]
        = (len(sequences), 3, 10, 1000).
    """
    # Three time scales: window widths of 50, 100 and 200 ms, with at most
    # 1000, 800 and 600 windows respectively.
    intervals = [50, 100, 200]
    max_lens = [1000, 800, 600]

    num_sequences = sequences.shape[0]

    # Final feature shape: [num_sequences, num_scales, num_features, max_length]
    hmtf = np.zeros((num_sequences, len(intervals), 10, max(max_lens)))
    if num_sequences == 0:
        # Nothing to do; a process pool cannot be created with zero workers.
        return hmtf

    with ProcessPoolExecutor(max_workers=min(num_workers, num_sequences)) as executor:
        futures = [
            executor.submit(
                # Scale a per-row copy rather than the caller's array in place.
                process_HMTF, index, sequences[index] * 1000, intervals, max_lens
            ) for index in range(num_sequences)
        ]
        with tqdm(total=num_sequences) as pbar:
            for future in as_completed(futures):
                index, result = future.result()
                hmtf[index] = result
                pbar.update(1)

    return hmtf

In [7]:
# extract_HMTF
# Load the dataset
# train_data = np.load("datasets/OW/train.npz")
# valid_data = np.load("datasets/OW/valid.npz")
# test_data = np.load("datasets/OW/test.npz")

# X_train, y_train = train_data["X"], train_data["y"]
# X_val, y_val = valid_data["X"], valid_data["y"]
# X_test, y_test = test_data["X"], test_data["y"]

In [8]:
# X_train =  data_processor.length_align(X_train, 5000)
# X_val =  data_processor.length_align(X_val, 5000)
# X_test =  data_processor.length_align(X_test, 5000)
# X_train.shape, X_val.shape, X_test.shape

In [9]:
# X_train_HMTF = extract_HMTF(X_train)
# X_val_HMTF = extract_HMTF(X_val)
# X_test_HMTF = extract_HMTF(X_test)

In [10]:
# save it to npz
# np.savez("datasets/OW/train_HMTF.npz", X=X_train_HMTF, y=y_train)
# np.savez("datasets/OW/valid_HMTF.npz", X=X_val_HMTF, y=y_val)
# np.savez("datasets/OW/test_HMTF.npz", X=X_test_HMTF, y=y_test)

In [11]:
# Load the pre-extracted HMTF features ("X") and labels ("y") for each split.
# Each .npz archive is opened in a `with` block so its file handle is closed as soon
# as the arrays have been read into memory.
with np.load("datasets/CW/train_HMTF.npz") as train_data:
    X_train_HMTF, y_train = train_data["X"], train_data["y"]
with np.load("datasets/CW/valid_HMTF.npz") as valid_data:
    X_val_HMTF, y_val = valid_data["X"], valid_data["y"]
with np.load("datasets/CW/test_HMTF.npz") as test_data:
    X_test_HMTF, y_test = test_data["X"], test_data["y"]

<!--  -->

In [12]:
# Print the shapes of the feature and label arrays of every split.
print(X_train_HMTF.shape, y_train.shape, X_val_HMTF.shape, y_val.shape, X_test_HMTF.shape, y_test.shape)


(85641, 3, 10, 1000) (85641,) (9516, 3, 10, 1000) (9516,) (10573, 3, 10, 1000) (10573,)


In [13]:
# Split the 10 feature rows into 5 groups of 2: (N, 3, 10, 1000) -> (N, 3, 5, 2, 1000).
# N is read from each array instead of being hard-coded, so the cell keeps working when
# the dataset size changes.
X_train_HMTF_reshaped = X_train_HMTF.reshape(X_train_HMTF.shape[0], 3, 5, 2, 1000)
X_val_HMTF_reshaped = X_val_HMTF.reshape(X_val_HMTF.shape[0], 3, 5, 2, 1000)

print(X_train_HMTF_reshaped.shape, y_train.shape, X_val_HMTF_reshaped.shape, y_val.shape)

(85641, 3, 5, 2, 1000) (85641,) (9516, 3, 5, 2, 1000) (9516,)


In [None]:
# Release cached GPU memory that PyTorch is holding but not currently using.
torch.cuda.empty_cache()

# Number of distinct labels in the training split.
num_classes = len(np.unique(y_train))

# Train for 60 epochs; the best weights and periodic checkpoints are written to ./ctf_train.
# Uncomment `resume_from` to continue from an existing checkpoint instead of starting fresh.
model, best_acc, best_epoch = train_ctf_model(
    X_train_HMTF_reshaped, y_train, X_val_HMTF_reshaped, y_val, 
    num_classes=num_classes, 
    epochs=60,
    save_dir='./ctf_train',
    batch_size=128,
    #resume_from='./ModelV1.7-2m-OW/checkpoint_epoch_40.pt'
)

In [None]:

# Rebuild the network and load the best weights saved during training.
# The class count is taken from the training labels, matching how the classifier was
# sized when the weights were produced; a test split that happens to miss a class would
# otherwise create a mismatching output layer.
model = CTFNet(num_classes=len(np.unique(y_train)))
#checkpoint = torch.load("./ModelV1.7/checkpoint_epoch_60.pt")
# map_location="cpu" lets weights saved from a GPU be restored on any machine; the model
# is moved to the evaluation device later.
model.load_state_dict(torch.load("./ctf_train/best_ctf_model.pt", map_location="cpu"))
model.eval()

In [20]:
print(X_test_HMTF.shape, y_test.shape)
# Split the 10 feature rows into 5 groups of 2: (N, 3, 10, 1000) -> (N, 3, 5, 2, 1000).
# N is read from the array instead of being hard-coded.
X_test_HMTF_reshaped = X_test_HMTF.reshape(X_test_HMTF.shape[0], 3, 5, 2, 1000)
print(X_test_HMTF_reshaped.shape, y_test.shape)

(10573, 3, 10, 1000) (10573,)
(10573, 3, 5, 2, 1000) (10573,)


In [1]:
from sklearn.metrics import accuracy_score, classification_report
def evaluate_ctf_model(model, X_test, y_test, batch_size=64, device=None):
    """
    Batch evaluation of the ctf model to avoid memory overflow.

    Prints a per-class classification report and the headline metrics, and returns
    the metrics. The model is moved to `device` and switched to evaluation mode.

    Parameters:
    model: Trained ctf model; its forward pass returns (logits, projection)
    X_test: Test data in the layout the model expects, e.g. [num_samples, 3, 5, 2, 1000]
    y_test: Test labels
    batch_size: Batch size
    device: Computation device; defaults to the second GPU when present, otherwise the
        first GPU, otherwise the CPU

    Returns:
    dict: "accuracy", macro-averaged "precision" / "recall" / "f1", and "confusion_matrix"
    """
    import numpy as np
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

    # Set device. cuda:1 is only requested when a second GPU actually exists.
    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
        else:
            device = torch.device("cpu")

    # Convert to Tensor
    X_test = torch.FloatTensor(X_test)
    y_test = torch.LongTensor(y_test)

    # Create data loader
    test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Set model to evaluation mode
    model = model.to(device)
    model.eval()

    # Store predictions and true labels
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Evaluating"):
            inputs, targets = inputs.to(device), targets.to(device)

            # Model prediction; the projection output is not needed for evaluation
            outputs, _ = model(inputs)
            _, predictions = outputs.max(1)

            # Collect results
            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    # Convert to NumPy arrays
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Compute evaluation metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro')

    print(classification_report(all_labels, all_preds))

    # Compute confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    # Create evaluation results dictionary
    results = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": cm,
    }

    # Print main evaluation results
    print(f"Test accuracy: {accuracy*100:.2f}%")
    print(f"Macro average precision: {precision:.4f}")
    print(f"Macro average recall: {recall:.4f}")
    print(f"Macro average F1 score: {f1:.4f}")

    return results

# Example usage
# model, best_acc, best_epoch = train_ctf_model(X_train, y_train, X_val, y_val, num_classes, epochs=50)
# eval_results = evaluate_ctf_model(model, X_test, y_test, batch_size=64)

In [None]:
# Evaluate the loaded model on the reshaped test split; metrics are printed and returned as a dict.
results = evaluate_ctf_model(model, X_test_HMTF_reshaped, y_test, batch_size=128)

In [None]:
# Display the metrics dictionary returned by evaluate_ctf_model.
results