# Import

In [1]:
# Import standard libraries
import os
import gc
import pickle
import json
import math

# Import scientific computing libraries
import numpy as np
import pandas as pd
import scipy.signal
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns

# Import data processing libraries
import polars as pl
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import accuracy_score, confusion_matrix

# Import deep learning libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
from torch.cuda.amp import autocast, GradScaler

# Import signal processing functions
from scipy.signal import butter, filtfilt

# Import parallel processing libraries
from concurrent.futures import ThreadPoolExecutor, as_completed

# Import utility libraries
from tqdm import tqdm
import wandb
from types import SimpleNamespace

# Import custom modules
import _MultiResUNet as MultiResUNet

In [2]:
# Run-level settings: loader batch size, number of training epochs, and the
# prefix prepended to every artifact name written for this run.
BATCH_SIZE, EPOCHS = 4, 30
PREFIX = '20240612_1_'

# Display names of the target classes (used for plot tick labels and titles).
class_names = [
    'ramp ascent',
    'ramp descent',
    'stair ascent',
    'stair descent',
    'walk',
]

In [3]:
# Model and preprocessing hyperparameters, kept in one namespace so they can be
# dumped to JSON and passed to the experiment tracker as a single object.
config = SimpleNamespace(**{
    # Root directory for run artifacts.
    'SAVE_DIR': 'model',

    # Arguments forwarded to the MultiResUNet constructor.
    'model_depth': 8,
    'model_width': 32,
    'kernel_size': 5,
    'problem_type': 'Classification',
    'ds': True,   # presumably deep supervision (the training loop handles list outputs) — confirm
    'ae': False,
    'feature_number': 512,
    'is_transconv': True,

    # Optimizer step size.
    'learning_rate': 0.00001,

    # Preprocessing.
    'WINDOW_SIZES': [0.1, 0.3, 0.6, 1.2],  # rolling-window sizes, in seconds
    'SAMPLE_RATE_TARGET': 100,             # Hz
    'MAX_TIME': 50,                        # sec
})

In [4]:
# Sampling rate of the raw recordings; passed to fit() as the original rate
# from which the data is resampled to config.SAMPLE_RATE_TARGET.
SAMPLE_RATE = 200  # Hz (OpenDataSet)

In [5]:
# Per-run output directory: the config dump, checkpoints and evaluation
# figures of this run are all written below it.
SAVE_DIR = os.path.join(config.SAVE_DIR, PREFIX)

def save_config(config, save_path):
    """Write the attributes of ``config`` to ``save_path`` as indented JSON.

    Args:
        config: Object whose instance attributes (``vars(config)``) are the
            settings to persist, e.g. a ``SimpleNamespace``. Every value must
            be JSON-serializable.
        save_path: Destination file path. Its parent directory is created if
            it does not exist yet.

    Raises:
        TypeError: If an attribute value cannot be serialized to JSON.
    """
    # Create the directory that will actually contain the file, rather than
    # relying on a module-level directory that may not match ``save_path``.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(save_path, 'w') as f:
        json.dump(vars(config), f, indent=4)
        
# Persist this run's hyperparameters alongside its other artifacts.
save_config(config, os.path.join(SAVE_DIR, PREFIX + 'config.json'))

In [6]:
# Round the padded sequence length up to the next multiple of 2 ** model_depth.
# NOTE(review): presumably required so the length divides evenly at every
# depth level of the network — confirm against _MultiResUNet.
factor = 2 ** config.model_depth
raw_target_length = config.SAMPLE_RATE_TARGET * config.MAX_TIME
MAX_LENGTH_TARGET = factor * math.ceil(raw_target_length / factor)
print(f'Max recording time: {MAX_LENGTH_TARGET/config.SAMPLE_RATE_TARGET} sec')

Max recording time: 51.2 sec


# Load Data

In [7]:
# Load the recordings (X) and their per-sample label sequences (Y).
# NOTE: allow_pickle=True unpickles arbitrary Python objects and can execute
# code embedded in the file — only load .npy files from a trusted source.
X_data, Y_data = (
    np.load(npy_path, allow_pickle=True)
    for npy_path in ('X_data.npy', 'Y_data.npy')
)

for array_name, array in (('X_data', X_data), ('Y_data', Y_data)):
    print(f'{array_name} shape:', array.shape)

X_data shape: (2990,)
Y_data shape: (2990,)


## Set Columns

In [8]:
# Sensor channels fed to the model: 3-axis accelerometer and gyroscope columns
# of the foot and shank segments. The thigh and trunk segments are available
# under the same naming scheme but are currently left out; add 'thigh' and/or
# 'trunk' to the segment tuple to include them.
required_columns = [
    f'{segment}_{sensor}_{axis}'
    for segment in ('foot', 'shank')
    for sensor in ('Accel', 'Gyro')
    for axis in ('X', 'Y', 'Z')
]
# Keep only the selected channels of every recording.
# assumes each element of X_data is a pandas DataFrame — TODO confirm
X_data_custom = [df[required_columns] for df in X_data]


# Feature Engineering

In [9]:
class TimeSeriesFeatureEngineer:
    """Prepare variable-length sensor recordings for a sequence model.

    :meth:`fit` runs the whole pipeline: raw activity labels are mapped onto
    five target classes and one-hot encoded, signals are low-pass filtered and
    resampled, labels are resampled to the same length, the recordings are
    split into train/validation sets, and one feature file plus one label file
    is written per recording.

    Args:
        window_sizes: Rolling-window lengths in seconds.
        sampling_rate: Rate in Hz used to convert ``window_sizes`` to sample
            counts. It is also used to derive the low-pass cutoff in
            :meth:`resample_x_data` (10% of this rate).
    """

    def __init__(self, window_sizes, sampling_rate):
        self.sampling_rate = sampling_rate
        # Round to the nearest sample: plain truncation turns e.g.
        # 0.29 s * 100 Hz (= 28.999999999999996 in floating point) into 28.
        self.window_sizes = np.rint(np.multiply(window_sizes, sampling_rate)).astype(int)
        self.encoder = None
        # Raw label -> target class. Transition labels are assigned to the
        # ramp/stair activity they involve; everything else becomes 'walk'.
        self.label_mapping = {
            'idle': 'walk',
            'rampascent': 'rampascent',
            'rampascent-walk': 'rampascent',
            'rampdescent': 'rampdescent',
            'rampdescent-walk': 'rampdescent',
            'stairascent': 'stairascent',
            'stairascent-walk': 'stairascent',
            'stairdescent': 'stairdescent',
            'stairdescent-walk': 'stairdescent',
            'stand': 'walk',
            'stand-walk': 'walk',
            'turn1': 'walk',
            'turn2': 'walk',
            'walk': 'walk',
            'walk-rampascent': 'rampascent',
            'walk-rampdescent': 'rampdescent',
            'walk-stairascent': 'stairascent',
            'walk-stairdescent': 'stairdescent',
            'walk-stand': 'walk'
        }

    def map_labels(self, Y_data):
        """Map every raw label of every sequence through ``label_mapping``.

        Args:
            Y_data: Iterable of label sequences.

        Returns:
            List with one NumPy array of mapped labels per input sequence.

        Raises:
            KeyError: If a label is not present in ``label_mapping``.
        """
        return [
            np.array([self.label_mapping[label] for label in y_seq])
            for y_seq in Y_data
        ]

    def create_encoder(self, Y_data):
        """Fit a one-hot encoder on the mapped classes found in ``Y_data``.

        The fitted encoder is stored on ``self.encoder``, its categories are
        printed, and the encoder is returned.
        """
        Y_data_mapped = self.map_labels(Y_data)

        # Collect the distinct classes across all sequences.
        all_labels = np.concatenate(Y_data_mapped)
        all_labels_unique = np.unique(all_labels).reshape(-1, 1)

        self.encoder = OneHotEncoder(sparse_output=False)
        self.encoder.fit(all_labels_unique)

        print("Encoder classes:", self.encoder.categories_)
        return self.encoder

    def fit_transform_labels(self, Y_data):
        """One-hot encode every label sequence with the fitted encoder.

        Returns:
            List with one 2-D array (time steps x classes) per sequence.

        Raises:
            ValueError: If :meth:`create_encoder` has not been called yet.
        """
        if self.encoder is None:
            raise ValueError("Encoder has not been created. Call create_encoder first.")

        Y_data_mapped = self.map_labels(Y_data)
        return [
            self.encoder.transform(np.array(y).reshape(-1, 1))
            for y in Y_data_mapped
        ]

    def feature_engineering(self, df: "pl.DataFrame"):
        """Append rolling-window features for every column of ``df``.

        For each column and each window size (in samples) the rolling mean,
        standard deviation, minimum and maximum, the difference over one
        window, and lags of 1 to 5 windows are added. Positions without a
        value (null or NaN) are filled with 0.

        Returns:
            A polars DataFrame holding the original and the derived columns.
        """
        # Accumulate the new columns on a LazyFrame and materialize once.
        lf = df.lazy()

        for col in df.columns:
            for window in self.window_sizes:
                window = int(window)  # plain Python int for the polars API
                window_str = str(window)
                # Summary statistics over the trailing window.
                lf = lf.with_columns([
                    df[col].rolling_mean(window).alias(col + '_mean_' + window_str),
                    df[col].rolling_std(window).alias(col + '_std_' + window_str),
                    df[col].rolling_min(window).alias(col + '_min_' + window_str),
                    df[col].rolling_max(window).alias(col + '_max_' + window_str),
                    df[col].diff(window).alias(col + '_diff_' + window_str)
                ])
                # Values from 1..5 whole windows in the past.
                for lag in [1, 2, 3, 4, 5]:
                    lf = lf.with_columns([
                        df[col].shift(lag * window).alias(col + f'_lag_{lag}_' + window_str)
                    ])

        features_df = lf.collect().fill_nan(0).fill_null(0)
        return features_df

    def lowpass_filter(self, data, cutoff_freq=100, sampling_rate=200, filter_order=6, axis=-1):
        """Zero-phase Butterworth low-pass filter (``butter`` + ``filtfilt``).

        Args:
            data: Signal to filter.
            cutoff_freq: Cutoff frequency in Hz. It must lie strictly between
                0 and ``sampling_rate / 2``; note that the default pair
                (100 Hz at 200 Hz) is not valid and makes ``butter`` raise.
            sampling_rate: Sampling rate of ``data`` in Hz.
            filter_order: Order of the Butterworth filter.
            axis: Axis of ``data`` along which to filter.

        Returns:
            The filtered signal, same shape as ``data``.
        """
        nyquist = 0.5 * sampling_rate
        normal_cutoff = cutoff_freq / nyquist
        b, a = butter(filter_order, normal_cutoff, btype='low', analog=False)
        return filtfilt(b, a, data, axis=axis)

    def fit_transform_features(self, X_data):
        """Run :meth:`feature_engineering` on every sequence.

        Returns:
            List with one 2-D NumPy array of features per sequence.
        """
        X_features = []
        for seq in X_data:
            seq_df = pl.DataFrame(seq)
            features_df = self.feature_engineering(seq_df)
            X_features.append(features_df.to_numpy())
        return X_features

    def resample_x_data(self, X_data, original_sampling_rate, target_sampling_rate):
        """Low-pass filter and resample every recording along time.

        Each channel is filtered with a cutoff of 10% of the rate given to the
        constructor, then the recording is resampled to
        ``len(seq) * target_sampling_rate / original_sampling_rate`` samples.
        The inputs are not modified.

        Args:
            X_data: Iterable of recordings (time along the first axis), e.g.
                pandas DataFrames or 2-D arrays.
            original_sampling_rate: Rate of the input recordings in Hz; the
                filter is designed for this rate.
            target_sampling_rate: Desired output rate in Hz.

        Returns:
            List with one NumPy array per recording.
        """
        cutoff_freq = int(self.sampling_rate * 0.1)
        resampled_X_data = []
        for seq in X_data:
            values = np.asarray(seq, dtype=float)
            # The filter must be designed for the rate the data was recorded
            # at, i.e. the rate passed by the caller, not a module-level value.
            filtered = self.lowpass_filter(
                values,
                cutoff_freq=cutoff_freq,
                sampling_rate=original_sampling_rate,
                filter_order=6,
                axis=0,
            )
            num_samples = int(len(values) * target_sampling_rate / original_sampling_rate)
            resampled_X_data.append(scipy.signal.resample(filtered, num_samples, axis=0))
        return resampled_X_data

    def resample_y_data(self, Y_data, original_sampling_rate, target_sampling_rate):
        """Resample encoded label sequences to the target rate.

        Labels are categorical, so each output step takes the label row of the
        input step at the same point in time. Interpolating one-hot columns
        and rounding them can produce rows with no class or several classes
        around label changes; picking rows cannot.

        Args:
            Y_data: Iterable of encoded label arrays (time along the first
                axis).
            original_sampling_rate: Rate of the input sequences in Hz.
            target_sampling_rate: Desired output rate in Hz.

        Returns:
            List with one float array per sequence, of length
            ``int(len(seq) * target_sampling_rate / original_sampling_rate)``.
        """
        resampled_Y_data = []
        for seq in Y_data:
            seq = np.asarray(seq)
            num_samples = int(len(seq) * target_sampling_rate / original_sampling_rate)
            # Input index located at the time of each output step.
            source_idx = np.floor(
                np.arange(num_samples) * original_sampling_rate / target_sampling_rate
            ).astype(int)
            source_idx = np.clip(source_idx, 0, max(len(seq) - 1, 0))
            resampled_Y_data.append(seq[source_idx].astype(float))
        return resampled_Y_data

    def fit(self, X_data, Y_data, original_sampling_rate, target_sampling_rate, train_dir="train_batches", val_dir="val_batches", test_size=0.2, max_workers=4):
        """Run the full preprocessing pipeline and write the results to disk.

        Args:
            X_data: Recordings.
            Y_data: Raw label sequences, one per recording.
            original_sampling_rate: Rate of the inputs in Hz.
            target_sampling_rate: Rate to resample to, in Hz.
            train_dir: Output directory for the training split.
            val_dir: Output directory for the validation split.
            test_size: Fraction of recordings put in the validation split.
            max_workers: Number of threads used for feature extraction.
        """
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        # Encode the labels.
        self.create_encoder(Y_data)
        Y_data_encoded = self.fit_transform_labels(Y_data)

        # Resample signals and labels.
        X_data_resampled = self.resample_x_data(X_data, original_sampling_rate, target_sampling_rate)
        Y_data_resampled = self.resample_y_data(Y_data_encoded, original_sampling_rate, target_sampling_rate)

        # Report sequence-length statistics.
        sequence_length = [len(seq) for seq in X_data_resampled]
        print(f'Max sequence length: {max(sequence_length)}')
        print(f'Min sequence length: {min(sequence_length)}')
        print(f'Mean sequence length: {np.mean(sequence_length)}')

        # Train/validation split with a fixed seed.
        X_train, X_val, Y_train, Y_val = train_test_split(X_data_resampled, Y_data_resampled, test_size=test_size, random_state=42)

        self._process_and_save_individual(X_train, Y_train, train_dir, max_workers)
        self._process_and_save_individual(X_val, Y_val, val_dir, max_workers)

    def _process_and_save_individual(self, X_data, Y_data, save_dir, max_workers):
        """Extract features for every sample in a thread pool and pickle them.

        Sample ``idx`` is written to ``X_data_{idx}.pkl`` and
        ``Y_data_{idx}.pkl`` inside ``save_dir``.

        Raises:
            Exception: Whatever a worker raised; a failure is propagated
                instead of leaving a silently incomplete directory.
        """
        def process_and_save(idx):
            X_features = self.fit_transform_features([X_data[idx]])[0]
            Y_encoded = Y_data[idx]

            with open(os.path.join(save_dir, f"X_data_{idx}.pkl"), 'wb') as f:
                pickle.dump(X_features, f)
            with open(os.path.join(save_dir, f"Y_data_{idx}.pkl"), 'wb') as f:
                pickle.dump(Y_encoded, f)

            del X_features, Y_encoded
            gc.collect()

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_and_save, idx) for idx in range(len(X_data))]
            for future in tqdm(as_completed(futures), total=len(futures), desc=f"Processing data in {save_dir}", unit="sample"):
                # A future stores its worker's exception until result() is
                # called; without this call failures would go unnoticed.
                future.result()

In [10]:
# Window sizes (seconds) are converted to sample counts at the target rate,
# because the rolling features are computed on the resampled signals.
feature_engineer = TimeSeriesFeatureEngineer(config.WINDOW_SIZES, config.SAMPLE_RATE_TARGET)

In [11]:
# Encode labels, resample to the target rate, split into train/validation and
# write one feature pickle and one label pickle per recording into the
# method's default directories, train_batches/ and val_batches/.
feature_engineer.fit(X_data_custom, Y_data, SAMPLE_RATE, config.SAMPLE_RATE_TARGET, max_workers=16)

Encoder classes: [array(['rampascent', 'rampdescent', 'stairascent', 'stairdescent', 'walk'],
      dtype='<U12')]
Max sequence length: 4918
Min sequence length: 1000
Mean sequence length: 1658.2311036789297


Processing data in train_batches: 100%|██████████| 2392/2392 [03:04<00:00, 12.96sample/s]
Processing data in val_batches: 100%|██████████| 598/598 [00:44<00:00, 13.33sample/s]


In [12]:
# # 예시 데이터 설정
# sample_data = sample_data = X_data[31]['foot_Accel_X']
# cutoff_rate = 0.49

# original_sampling_rate = 200
# target_sampling_rate = 100

# # 필터링된 원본 데이터
# filtered_sample_data = feature_engineer.lowpass_filter(sample_data, cutoff_freq=int(target_sampling_rate*cutoff_rate), sampling_rate=original_sampling_rate)

# # 리샘플된 데이터
# resampled_data = feature_engineer.resample_x_data([sample_data], original_sampling_rate, target_sampling_rate)[0]

# # 리샘플된 데이터에 필터링 적용
# filtered_resampled_data = feature_engineer.lowpass_filter(resampled_data, cutoff_freq=int(target_sampling_rate*cutoff_rate), sampling_rate=target_sampling_rate)

# # 필터링된 원본 데이터를 리샘플링
# resampled_filtered_sample_data = feature_engineer.resample_x_data([filtered_sample_data], original_sampling_rate, target_sampling_rate)[0]

# # 원본 데이터의 시간 축
# time_original = np.linspace(0, len(sample_data) / original_sampling_rate, len(sample_data))

# # 리샘플된 데이터의 시간 축
# time_resampled = np.linspace(0, len(sample_data) / original_sampling_rate, len(resampled_data))

# # 데이터 시각화
# plt.figure(figsize=(100, 10))

# plt.plot(time_original, sample_data, label='Original Data')
# # plt.plot(time_original, filtered_sample_data, label='Filtered Original Data', linestyle='--')
# # plt.plot(time_resampled, resampled_data, label='Resampled Data', color='orange')
# # plt.plot(time_resampled, filtered_resampled_data, label='Filtered Resampled Data', color='red', linestyle='--')
# plt.plot(time_resampled, resampled_filtered_sample_data, label='Resampled Filtered Original Data', color='green', linestyle=':')

# plt.title('Original, Resampled, and Filtered Data')
# plt.xlabel('Time (s)')
# plt.ylabel('Amplitude')
# plt.legend()
# plt.show()

# Dataloader

In [13]:
class TimeSeriesDataset(Dataset):
    """Dataset of pickled per-sample feature/label arrays, padded to one length.

    Sample ``idx`` is read from ``X_data_{idx}.pkl`` in ``X_dir`` and
    ``Y_data_{idx}.pkl`` in ``Y_dir``. Both arrays are trimmed or zero-padded
    along the first axis to ``max_length``.

    Args:
        X_dir: Directory containing the feature pickles.
        Y_dir: Directory containing the label pickles.
        num_samples: Number of samples; valid indices are ``0..num_samples-1``.
        max_length: Length every returned sequence is trimmed/padded to.
    """

    def __init__(self, X_dir, Y_dir, num_samples, max_length):
        self.X_dir = X_dir
        self.Y_dir = Y_dir
        self.num_samples = num_samples
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples given at construction."""
        return self.num_samples

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(X_padded, Y_padded, X_mask)``.

        ``X_mask`` is 1 at real time steps and 0 at padded ones.
        """
        # NOTE: pickle.load can execute arbitrary code; only point this
        # dataset at directories written by a trusted preprocessing run.
        with open(os.path.join(self.X_dir, f"X_data_{idx}.pkl"), 'rb') as f:
            X_data = pickle.load(f)
        with open(os.path.join(self.Y_dir, f"Y_data_{idx}.pkl"), 'rb') as f:
            Y_data = pickle.load(f)

        X_padded, X_mask = self.pad_or_trim_sequence(X_data)
        Y_padded, _ = self.pad_or_trim_sequence(Y_data)

        return X_padded, Y_padded, X_mask

    def pad_or_trim_sequence(self, sequence):
        """Trim or zero-pad ``sequence`` along its first axis to ``max_length``.

        Only the first (time) axis is ever padded, whatever the number of
        dimensions — including 2-D arrays with a single feature column.

        Args:
            sequence: Array-like with time along the first axis.

        Returns:
            Tuple ``(padded, mask)`` of float32 tensors. ``padded`` has
            ``max_length`` entries along the first axis; ``mask`` has shape
            ``(max_length,)`` with 1 for real steps and 0 for padding.
        """
        sequence = np.asarray(sequence)
        kept = min(len(sequence), self.max_length)
        padding_length = self.max_length - kept

        # Pad the time axis only; every trailing axis gets (0, 0).
        pad_width = [(0, padding_length)] + [(0, 0)] * (sequence.ndim - 1)
        padded_seq = np.pad(sequence[:kept], pad_width, mode='constant', constant_values=0)

        mask = torch.zeros(self.max_length, dtype=torch.float32)
        mask[:kept] = 1.0
        return torch.tensor(padded_seq, dtype=torch.float32), mask

In [14]:
# Create datasets
def _count_samples(batch_dir):
    """Return the number of ``X_data_<idx>.pkl`` files found in ``batch_dir``.

    Counting only the feature files keeps unrelated directory entries (editor
    checkpoints, hidden files, ...) from inflating the sample count.
    """
    return sum(
        1
        for name in os.listdir(batch_dir)
        if name.startswith("X_data_") and name.endswith(".pkl")
    )


num_batches_train = _count_samples("train_batches")
num_batches_val = _count_samples("val_batches")

train_dataset = TimeSeriesDataset(X_dir="train_batches", Y_dir="train_batches", num_samples=num_batches_train, max_length=MAX_LENGTH_TARGET)
val_dataset = TimeSeriesDataset(X_dir="val_batches", Y_dir="val_batches", num_samples=num_batches_val, max_length=MAX_LENGTH_TARGET)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
# Validation is not shuffled: the loss does not depend on the order, and a
# fixed order makes the i-th prediction correspond to sample i on every run,
# which the per-trial plots rely on for their index.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [15]:
# import matplotlib.pyplot as plt   
# # Get padded sequences
# idx = 597
# padded_sequence_1 = val_dataset[idx][1]
# padded_sequence_2 = val_dataset[idx][0]

# # Print and visualize the padded sequences
# print(f"Padded Sequence 1 Shape: {padded_sequence_1.shape}")
# print(f"Padded Sequence 2 Shape: {padded_sequence_2.shape}")

# # Visualize the first feature of the sequences
# plt.figure(figsize=(12, 6))

# plt.subplot(2, 1, 1)
# plt.plot(padded_sequence_1[:, 0], label=class_names[0])
# plt.plot(padded_sequence_1[:, 1], label=class_names[1])
# plt.plot(padded_sequence_1[:, 2], label=class_names[2])
# plt.plot(padded_sequence_1[:, 3], label=class_names[3])
# plt.plot(padded_sequence_1[:, 4], label=class_names[4])
# plt.legend()
# plt.title('Padded Sequence 1')

# plt.subplot(2, 1, 2)
# plt.plot(padded_sequence_2[:, 4], label='Padded Sequence 2 - Feature 1')
# plt.legend()
# plt.title('Padded Sequence 2')

# plt.tight_layout()
# plt.show()

In [16]:
# for X_batch, Y_batch in train_loader:
#     print(X_batch.shape, Y_batch.shape)
#     pass

In [17]:
# Read the model's input length, input channel count and number of output
# channels off one training batch instead of hard-coding them.
first_batch = next(iter(train_loader))
batch_X, batch_Y = first_batch[0], first_batch[1]

length = batch_X.shape[1]
num_channel = batch_X.shape[2]
output_channels = batch_Y.shape[-1]

# Training

In [18]:
def save_model(model, path):
    """Save the model's ``state_dict`` (not the module object) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)

def load_model(model, path, map_location="cpu"):
    """Load a saved ``state_dict`` from ``path`` into ``model`` in place.

    Args:
        model: Module whose architecture matches the saved weights.
        path: File written by ``torch.save(model.state_dict(), path)``.
        map_location: Device the stored tensors are mapped to while reading.
            The default ``"cpu"`` lets a checkpoint saved on a GPU be opened
            on a machine without one; ``load_state_dict`` then copies the
            values into the model's parameters on whatever device they are on.

    Returns:
        The same ``model`` object, with the loaded weights.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict contains.
    state_dict = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    return model

In [19]:
class FocalLoss(nn.Module):
    """Focal loss built on element-wise binary cross-entropy with logits.

    Per element: ``(1 - pt) ** gamma * BCE`` with ``pt = exp(-BCE)``,
    optionally weighted by ``alpha`` on positive targets and ``1 - alpha`` on
    negative targets.

    Args:
        alpha: ``None`` for no weighting, otherwise a tensor or any sequence
            convertible to one (list, NumPy array) that broadcasts against the
            targets, e.g. one weight per class on the last axis.
        gamma: Focusing exponent.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced element-wise loss.
    """

    def __init__(self, alpha=None, gamma=2, reduction='mean'):
        super().__init__()
        if alpha is not None:
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        # A non-persistent buffer follows the module through .to(device)
        # without adding an entry to its state_dict.
        self.register_buffer('alpha', alpha, persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss of ``inputs`` (logits) against ``targets``.

        Args:
            inputs: Logits.
            targets: Float targets of the same shape as ``inputs``.

        Returns:
            A scalar for ``'mean'``/``'sum'``, otherwise a tensor shaped like
            ``inputs``.
        """
        BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = (1 - pt) ** self.gamma * BCE_loss

        if self.alpha is not None:
            # Use a local, device-matched copy so a forward pass never
            # modifies the module's own state.
            alpha = self.alpha.to(inputs.device)
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            F_loss = alpha_t * F_loss

        if self.reduction == 'mean':
            return F_loss.mean()
        if self.reduction == 'sum':
            return F_loss.sum()
        return F_loss


In [20]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100, save_dir='model_checkpoints'):
    """Train ``model`` with mixed precision, validating after every epoch.

    The loss is computed on non-padded time steps only (``mask == 1``). When
    the model returns a list of outputs (deep supervision), the loss is the
    average of the per-output losses. Train and validation losses are logged
    to wandb once per epoch; the weights with the lowest validation loss and
    the final weights are written to ``save_dir``.

    Args:
        model: Network mapping a batch ``X`` to logits, or to a list of
            logits tensors that all share the batch and time axes of the mask.
        train_loader: Yields ``(X, Y, mask)`` batches for training.
        val_loader: Yields ``(X, Y, mask)`` batches for validation.
        criterion: Loss called as ``criterion(logits, targets)`` on the valid
            time steps; it is expected to return a scalar.
        optimizer: Optimizer over ``model.parameters()``.
        num_epochs: Number of passes over ``train_loader``.
        save_dir: Directory for the checkpoints (created if missing).

    Returns:
        The trained ``model`` (final weights, not necessarily the best ones).
    """
    def masked_loss(outputs, targets, mask):
        """Criterion on valid time steps, averaged over the model's outputs."""
        valid = mask == 1
        heads = outputs if isinstance(outputs, list) else [outputs]
        valid_targets = targets[valid]
        # The criterion already averages over the selected elements, so the
        # result must not be divided by the number of valid steps again.
        return sum(criterion(head[valid], valid_targets) for head in heads) / len(heads)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"  # CUDA AMP is meaningless on CPU
    model.to(device)
    scaler = GradScaler(enabled=use_amp)

    best_val_loss = float('inf')
    os.makedirs(save_dir, exist_ok=True)

    # The context manager closes the bar even if training raises.
    with tqdm(total=num_epochs, desc="Training model", unit="epoch") as pbar:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            for X_batch, Y_batch, mask in train_loader:
                X_batch = X_batch.to(device)
                Y_batch = Y_batch.to(device)
                mask = mask.to(device)

                optimizer.zero_grad()

                with autocast(enabled=use_amp):
                    outputs = model(X_batch)
                    loss = masked_loss(outputs, Y_batch, mask)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                # Weight by batch size so the epoch value is a per-sample mean.
                running_loss += loss.item() * X_batch.size(0)

            epoch_loss = running_loss / len(train_loader.dataset)

            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for X_batch, Y_batch, mask in val_loader:
                    X_batch = X_batch.to(device)
                    Y_batch = Y_batch.to(device)
                    mask = mask.to(device)

                    with autocast(enabled=use_amp):
                        outputs = model(X_batch)
                        loss = masked_loss(outputs, Y_batch, mask)

                    val_loss += loss.item() * X_batch.size(0)

            val_loss /= len(val_loader.dataset)

            # Log metrics to wandb
            wandb.log({'train_loss': epoch_loss, 'val_loss': val_loss}, step=epoch)

            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_path = os.path.join(save_dir, PREFIX+'best_model_checkpoint.pth')
                torch.save(model.state_dict(), best_model_path)

            pbar.set_postfix({'Loss': f'{epoch_loss:.8f}', 'Val Loss': f'{val_loss:.8f}'})
            pbar.update(1)

    # Save the last model
    last_model_path = os.path.join(save_dir, PREFIX+'last_model.pth')
    torch.save(model.state_dict(), last_model_path)

    print(f'Finished Training. Best validation loss: {best_val_loss:.8f}')

    return model


In [21]:
# Define the model, the loss function and the optimizer.
model = MultiResUNet.UNet(
    length=length,
    model_depth=config.model_depth,
    num_channel=num_channel,
    model_width=config.model_width,
    kernel_size=config.kernel_size,
    problem_type=config.problem_type,
    output_channels=output_channels,
    ds=config.ds,
    ae=config.ae,
    feature_number=config.feature_number,
    is_transconv=config.is_transconv,
)

# Focal loss with one alpha weight per class, listed in the order
# ['ramp ascent', 'ramp descent', 'stair ascent', 'stair descent', 'walk'].
# (Alternative kept for reference: criterion = torch.nn.BCEWithLogitsLoss())
alpha = np.array([0.1, 0.3, 0.1, 0.3, 0.001])
alpha_tensor = torch.tensor(alpha, dtype=torch.float32)
criterion = FocalLoss(alpha=alpha_tensor, gamma=2)

optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

## Wandb Single

In [22]:
# Start a wandb run in the 'RT5307' project and attach the hyperparameter
# namespace to it, so the logged losses can be traced back to their settings.
wandb.init(project='RT5307', config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjojaebeom[0m ([33mjaebeom[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [23]:
# Train the model; checkpoints (best validation loss and final weights) are
# written to SAVE_DIR.
trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=EPOCHS, save_dir=SAVE_DIR)

  return F.conv1d(input, weight, bias, self.stride,


OutOfMemoryError: CUDA out of memory. Tried to allocate 66.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 65.56 MiB is free. Process 30200 has 443.44 MiB memory in use. Process 173530 has 7.73 GiB memory in use. Process 178451 has 4.31 GiB memory in use. Including non-PyTorch memory, this process has 10.38 GiB memory in use. Of the allocated memory 9.80 GiB is allocated by PyTorch, and 102.22 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Eval

In [None]:
def predict(model, data_loader, criterion):
    """Run ``model`` over ``data_loader`` and collect per-sample predictions.

    Padded time steps (``mask == 0``) are excluded from the loss and from all
    returned arrays. For models returning a list of outputs, only the last
    one is used.

    Args:
        model: Network producing logits of shape (batch, time, classes), or a
            list of such tensors.
        data_loader: Yields ``(X, Y, mask)`` batches; ``Y`` is one-hot.
        criterion: Loss called as ``criterion(logits, targets)`` on the valid
            time steps.

    Returns:
        Tuple ``(all_preds, all_labels, all_probabilities, avg_loss)``: three
        lists with one array per sample (predicted class indices, true class
        indices, softmax probabilities over the valid steps) and the mean
        loss per sample.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_preds, all_labels, all_probabilities = [], [], []
    running_loss = 0.0

    with torch.no_grad():
        for X_batch, Y_batch, mask in data_loader:
            X_batch, Y_batch, mask = X_batch.to(device), Y_batch.to(device), mask.to(device)
            valid = mask == 1

            with autocast():
                outputs = model(X_batch)
                if isinstance(outputs, list):
                    # Deep supervision: evaluate the final output only.
                    outputs = outputs[-1]

                batch_loss = criterion(outputs[valid], Y_batch[valid])
                running_loss += batch_loss.item() * X_batch.size(0)

                # Class probabilities along the last (class) axis.
                probs = torch.softmax(outputs, dim=2).cpu().numpy()
                valid_np = valid.cpu().numpy()

                # Strip the padding sample by sample; lengths differ.
                for j in range(probs.shape[0]):
                    sample_probs = probs[j, valid_np[j]]
                    sample_labels = torch.argmax(Y_batch[j, valid[j]], dim=1).cpu().numpy()

                    all_probabilities.append(sample_probs)
                    all_preds.append(np.argmax(sample_probs, axis=1))
                    all_labels.append(sample_labels)

    avg_loss = running_loss / len(data_loader.dataset)

    return all_preds, all_labels, all_probabilities, avg_loss

In [None]:
def plot_confusion_matrix(true_labels, pred_labels, class_names, save_dir):
    """Save a confusion-matrix heatmap as ``confusion_matrix.png`` in ``save_dir``.

    Args:
        true_labels: Flat sequence of true class indices.
        pred_labels: Flat sequence of predicted class indices.
        class_names: Tick labels; entry ``i`` names class index ``i``, so the
            labels are expected to be integers in ``range(len(class_names))``.
        save_dir: Output directory (created if missing).
    """
    # Fix the label set explicitly: otherwise a class that occurs in neither
    # array is dropped and the matrix no longer lines up with class_names.
    class_indices = list(range(len(class_names)))
    cm = confusion_matrix(true_labels, pred_labels, labels=class_indices)

    os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')
        fig.savefig(os.path.join(save_dir, 'confusion_matrix.png'))
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

In [None]:
def calculate_accuracy(true_labels, pred_labels):
    """Return scikit-learn's ``accuracy_score`` of ``pred_labels`` against ``true_labels``."""
    return accuracy_score(true_labels, pred_labels)

In [None]:
def plot_probabilities(true_labels, pred_labels, probabilities, class_names, save_dir, idx):
    """Plot per-class probability traces for one trial and save them as PNG.

    One subplot per class shows the predicted probability together with the
    one-hot true and predicted labels over time. The figure is written to
    ``test_{idx}_probabilities.png`` in ``save_dir``.

    All drawing, saving and closing goes through the figure created here
    rather than pyplot's "current figure", because this function is called
    from several threads at once.

    Args:
        true_labels: True class index per time step.
        pred_labels: Predicted class index per time step.
        probabilities: Array of shape (time steps, classes).
        class_names: One name per class, in class-index order.
        save_dir: Output directory (created if missing).
        idx: Trial index used in the title and the file name.
    """
    num_classes = len(class_names)
    time_steps = probabilities.shape[0]
    steps = np.arange(time_steps)

    # One-hot encode the true and predicted label sequences.
    true_labels_one_hot = np.zeros((time_steps, num_classes))
    true_labels_one_hot[np.arange(len(true_labels)), np.asarray(true_labels, dtype=int)] = 1

    pred_labels_one_hot = np.zeros((time_steps, num_classes))
    pred_labels_one_hot[np.arange(len(pred_labels)), np.asarray(pred_labels, dtype=int)] = 1

    color_prob = '#4A4A4A'  # Dark Gray
    color_true = '#00BFFF'  # Deep Sky Blue
    color_pred = '#F08080'  # Light Coral

    # exist_ok avoids a failure when two threads create the directory at once.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'test_{idx}_probabilities.png')

    fig, axes = plt.subplots(num_classes, 1, figsize=(10, num_classes * 2), sharex=True)
    try:
        if num_classes == 1:
            axes = [axes]

        for i, class_name in enumerate(class_names):
            ax = axes[i]
            ax.plot(steps, probabilities[:, i], label='Probability', alpha=0.6, color=color_prob)
            ax.fill_between(steps, 0, probabilities[:, i], alpha=0.2, color=color_prob)
            ax.plot(steps, true_labels_one_hot[:, i], linestyle='dashed', label='True', alpha=0.6, color=color_true)
            ax.fill_between(steps, 0, true_labels_one_hot[:, i], alpha=0.2, color=color_true)
            ax.plot(steps, pred_labels_one_hot[:, i], linestyle='dotted', label='Predicted', alpha=0.6, color=color_pred)
            ax.fill_between(steps, 0, pred_labels_one_hot[:, i], alpha=0.2, color=color_pred)

            ax.set_ylabel('Probability', fontsize=14)
            ax.set_ylim(0, 1)
            ax.set_title(class_name, fontsize=18)
            ax.legend(fontsize=14)

        axes[-1].set_xlabel('Time Steps', fontsize=14)

        fig.suptitle(f'{idx}th Result', fontsize=24, y=0.99, x=0.85)
        fig.tight_layout(rect=[0, 0, 1, 1.02])
        fig.savefig(save_path, dpi=300)
    finally:
        # Close this specific figure, also when plotting fails.
        plt.close(fig)

In [None]:
def plot_probabilities_for_all_trials(true_labels, pred_labels, probabilities, class_names, save_dir, max_workers=8):
    """Create one probability plot per trial using a thread pool.

    A failing plot is reported together with its trial index and does not
    stop the remaining plots.

    Args:
        true_labels: One sequence of true class indices per trial.
        pred_labels: One sequence of predicted class indices per trial.
        probabilities: One (time steps, classes) array per trial.
        class_names: One name per class, in class-index order.
        save_dir: Directory the PNG files are written to.
        max_workers: Number of plotting threads. Matplotlib does not
            guarantee thread safety; pass 1 to plot strictly one at a time.
    """
    total_plots = len(probabilities)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        with tqdm(total=total_plots, desc="Plotting probabilities", unit="plot") as progress_bar:
            # Remember which trial each future belongs to for error reports.
            future_to_idx = {
                executor.submit(
                    plot_probabilities,
                    true_labels[idx], pred_labels[idx], probabilities[idx],
                    class_names, save_dir, idx,
                ): idx
                for idx in range(total_plots)
            }

            for future in as_completed(future_to_idx):
                try:
                    future.result()
                except Exception as e:
                    print(f"Error occurred while plotting trial {future_to_idx[future]}: {str(e)}")
                progress_bar.update(1)

In [None]:
# Rebuild the network with the same hyperparameters used for training so the
# saved state_dict matches its layers.
model = MultiResUNet.UNet(
    length=length,
    model_depth=config.model_depth,
    num_channel=num_channel,
    model_width=config.model_width,
    kernel_size=config.kernel_size,
    problem_type=config.problem_type,
    output_channels=output_channels,
    ds=config.ds,
    ae=config.ae,
    feature_number=config.feature_number,
    is_transconv=config.is_transconv,
)

# NOTE(review): evaluation uses plain BCE-with-logits while training used the
# focal loss, so the reported loss is not comparable to the logged val_loss.
criterion = torch.nn.BCEWithLogitsLoss()

best_checkpoint_path = os.path.join(SAVE_DIR, PREFIX+'best_model_checkpoint.pth')
loaded_model = load_model(model, best_checkpoint_path)

In [None]:
# Evaluate on the validation split.
data_loader = val_loader

# `model` and `loaded_model` are the same object: load_model loads the
# weights in place and returns its argument.
pred_labels, true_labels, probabilities, avg_loss = predict(model, data_loader, criterion)

In [None]:
# Accuracy over all valid time steps of all trials, pooled together.
flat_true_labels = np.concatenate(true_labels).flatten()
flat_pred_labels = np.concatenate(pred_labels).flatten()
accuracy = calculate_accuracy(flat_true_labels, flat_pred_labels)
print(f"Avg Loss: {avg_loss:.8f}, Accuracy: {accuracy:.8f}")

In [None]:
# Confusion matrix over the pooled time steps of all trials, saved into SAVE_DIR.
plot_confusion_matrix(np.concatenate(true_labels).flatten(), np.concatenate(pred_labels).flatten(), class_names, SAVE_DIR)

In [None]:
# One probability figure per trial, saved into SAVE_DIR.
plot_probabilities_for_all_trials(true_labels, pred_labels, probabilities, class_names, SAVE_DIR)