This implementation uses PyTorch to construct a Deep Belief Network (DBN) model.

# Import Libraries

In [5]:
# in case torch is not installed
%pip install torch scikit-learn

Collecting torch
  Downloading torch-2.4.0-cp310-cp310-win_amd64.whl (197.9 MB)
Collecting scikit-learn
  Using cached scikit_learn-1.5.1-cp310-cp310-win_amd64.whl (11.0 MB)
Collecting jinja2
  Using cached jinja2-3.1.4-py3-none-any.whl (133 kB)
Collecting filelock
  Using cached filelock-3.15.4-py3-none-any.whl (16 kB)
Collecting sympy
  Using cached sympy-1.13.2-py3-none-any.whl (6.2 MB)
Collecting fsspec
  Using cached fsspec-2024.6.1-py3-none-any.whl (177 kB)
Collecting networkx
  Using cached networkx-3.3-py3-none-any.whl (1.7 MB)
Collecting numpy>=1.19.5
  Using cached numpy-2.1.0-cp310-cp310-win_amd64.whl (12.9 MB)
Collecting joblib>=1.2.0
  Using cached joblib-1.4.2-py3-none-any.whl (301 kB)
Collecting threadpoolctl>=3.1.0
  Using cached threadpoolctl-3.5.0-py3-none-any.whl (18 kB)
Collecting scipy>=1.6.0
  Using cached scipy-1.14.1-cp310-cp310-win_amd64.whl (44.8 MB)
Collecting MarkupSafe>=2.0
  Using cached MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl (17 kB)
Collecting mpmath<

You should consider upgrading via the 'c:\Users\User\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [2]:
# in case pandas was not installed
%pip install pandas

Collecting pandas
  Downloading pandas-2.2.2-cp310-cp310-win_amd64.whl (11.6 MB)
Collecting pytz>=2020.1
  Downloading pytz-2024.1-py2.py3-none-any.whl (505 kB)
Collecting tzdata>=2022.7
  Downloading tzdata-2024.1-py2.py3-none-any.whl (345 kB)
Installing collected packages: tzdata, pytz, pandas
Successfully installed pandas-2.2.2 pytz-2024.1 tzdata-2024.1
Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'c:\Users\User\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [3]:
# in case matplotlib was not installed
%pip install matplotlib

Collecting matplotlib
  Downloading matplotlib-3.9.2-cp310-cp310-win_amd64.whl (7.8 MB)
Collecting cycler>=0.10
  Downloading cycler-0.12.1-py3-none-any.whl (8.3 kB)
Collecting pillow>=8
  Downloading pillow-10.4.0-cp310-cp310-win_amd64.whl (2.6 MB)
Collecting fonttools>=4.22.0
  Downloading fonttools-4.53.1-cp310-cp310-win_amd64.whl (2.2 MB)
Collecting pyparsing>=2.3.1
  Downloading pyparsing-3.1.4-py3-none-any.whl (104 kB)
Collecting contourpy>=1.0.1
  Downloading contourpy-1.3.0-cp310-cp310-win_amd64.whl (216 kB)
Collecting kiwisolver>=1.3.1
  Downloading kiwisolver-1.4.5-cp310-cp310-win_amd64.whl (56 kB)
Installing collected packages: pyparsing, pillow, kiwisolver, fonttools, cycler, contourpy, matplotlib
Successfully installed contourpy-1.3.0 cycler-0.12.1 fonttools-4.53.1 kiwisolver-1.4.5 matplotlib-3.9.2 pillow-10.4.0 pyparsing-3.1.4
Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'c:\Users\User\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [22]:
%pip install tqdm

Collecting tqdm
  Downloading tqdm-4.66.5-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.66.5
Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'c:\Users\User\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [23]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score
import matplotlib.pyplot as plt
import gc
from PIL import Image
from tqdm import tqdm

# Configurations

In [17]:
# Paths to the train/test images, the annotation CSV, the augmented images,
# and the output locations for saved models and checkpoints.
TRAIN_DIR = './global-wheat-detection/train/'
TEST_DIR = './global-wheat-detection/test/'
TRAIN_CSV_PATH = './global-wheat-detection/train.csv'
AUG_SAVE_DIR = './global-wheat-detection/augmented_images/'
SAVE_PATH = 'models/DBN/'
CHECKPOINT_DIR = 'models/DBN/checkpoints/'

# Model configuration
EPOCHS = 10
IMG_SIZE = 256
VISIBLE_UNITS = IMG_SIZE * IMG_SIZE  # One visible unit per pixel of an IMG_SIZE x IMG_SIZE image
HIDDEN_UNITS_1 = 256  # Number of hidden units in the first RBM
HIDDEN_UNITS_2 = 128  # Number of hidden units in the second RBM
N_CLASSES = 2
BATCH_SIZE = 64  # Batch size for training
LEARNING_RATE = 0.01  # Learning rate for both pre-training and fine-tuning

# Reproducibility: seed the NumPy and PyTorch global generators so that
# DataLoader shuffling and weight initialisation repeat across runs.
# 42 matches the random_state passed to train_test_split further down.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create output directories; exist_ok=True makes this idempotent and avoids
# the check-then-create race of `if not os.path.exists(...)`.
os.makedirs(SAVE_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Device configuration: prefer CUDA when PyTorch reports it as available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data Loading

In [18]:
# Load the annotation table used for training. The recorded output below shows
# the columns image_id, x, y, w, h and source (presumably one row per bounding
# box -- confirm against how df_augment.csv was produced).
df = pd.read_csv('./df_augment.csv')
# Alternative input table, currently disabled:
# df = pd.read_csv('./df_full.csv')
# Print column dtypes, non-null counts and memory usage.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 242466 entries, 0 to 242465
Data columns (total 6 columns):
 #   Column    Non-Null Count   Dtype  
---  ------    --------------   -----  
 0   image_id  242466 non-null  object 
 1   x         242466 non-null  float64
 2   y         242466 non-null  float64
 3   w         242466 non-null  float64
 4   h         242466 non-null  float64
 5   source    242466 non-null  object 
dtypes: float64(4), object(2)
memory usage: 11.1+ MB


# Preprocessing

In [26]:
# Convert image paths and labels to tensors
def load_image_and_label(image_id, label, IMG_SIZE=IMG_SIZE):
    """Load one image as a normalized grayscale tensor together with its label.

    Args:
        image_id: Path of the image file to open (despite the name, this is
            used directly as a filesystem path).
        label: Label value(s) for the image; converted to int64, so any
            fractional part is truncated.
        IMG_SIZE: Side length in pixels the image is resized to.

    Returns:
        Tuple ``(image, label)``: a float32 tensor of shape
        ``(IMG_SIZE, IMG_SIZE)`` with values in [0, 1], and a ``torch.long``
        tensor holding the label.

    Raises:
        FileNotFoundError: if ``image_id`` does not exist on disk.
    """
    if not os.path.exists(image_id):
        raise FileNotFoundError(f"File not found: {image_id}")

    # The context manager closes the underlying file handle deterministically;
    # convert() and resize() each return a new in-memory image, so nothing
    # depends on the file after the block exits.
    with Image.open(image_id) as raw_image:
        image = raw_image.convert('L').resize((IMG_SIZE, IMG_SIZE))

    # Scale 8-bit pixel intensities to the [0, 1] range.
    image = np.array(image, dtype=np.float32) / 255.0
    label = np.array(label, dtype=np.int64)

    # from_numpy reuses the freshly created array's buffer instead of copying it.
    return torch.from_numpy(image), torch.tensor(label, dtype=torch.long)


def parse_dataframe(df, TRAIN_DIR=TRAIN_DIR, AUG_SAVE_DIR=AUG_SAVE_DIR):
    """Resolve the image path and bounding box for every row of ``df``.

    Args:
        df: DataFrame with the columns ``image_id``, ``x``, ``y``, ``w``, ``h``.
        TRAIN_DIR: Directory searched first for ``<image_id>.jpg``.
        AUG_SAVE_DIR: Directory searched when the image is not in ``TRAIN_DIR``.

    Returns:
        Tuple ``(image_paths, bboxes)``: two lists with one entry per row, in
        row order. Each bbox is a length-4 array ``[x, y, w, h]``.

    Raises:
        FileNotFoundError: if an image exists in neither directory.
    """
    # An image_id may appear on many rows; remembering its resolved path avoids
    # repeating the filesystem checks for every one of them.
    resolved_paths = {}

    # Pull the columns out once instead of building a Series per row.
    image_ids = df['image_id'].tolist()
    box_rows = df[['x', 'y', 'w', 'h']].to_numpy()

    image_paths = []
    bboxes = []
    rows = zip(image_ids, box_rows)
    for image_id, bbox in tqdm(rows, total=len(df), desc='Parsing dataframe and image path'):
        image_path = resolved_paths.get(image_id)
        if image_path is None:
            ori_image_path = os.path.join(TRAIN_DIR, f'{image_id}.jpg')
            augmented_image_path = os.path.join(AUG_SAVE_DIR, f'{image_id}.jpg')

            # Prefer the original image; fall back to the augmented copy.
            if os.path.exists(ori_image_path):
                image_path = ori_image_path
            elif os.path.exists(augmented_image_path):
                image_path = augmented_image_path
            else:
                raise FileNotFoundError(f"Image not found for ID: {image_id}")
            resolved_paths[image_id] = image_path

        image_paths.append(image_path)
        bboxes.append(bbox)

    return image_paths, bboxes

# Create PyTorch datasets
def create_dataset(image_paths, labels):
    """Load every image/label pair into memory and wrap them in a TensorDataset.

    Args:
        image_paths: Iterable of image file paths.
        labels: Iterable of labels, aligned with ``image_paths``.

    Returns:
        ``TensorDataset`` whose first tensor stacks all images and whose second
        tensor stacks all labels along a new leading dimension.

    NOTE: every image is decoded and kept in memory at once. The recorded run
    of this notebook ended with an out-of-memory error in this cell, so the
    full table may need to be subsampled or loaded lazily.
    """
    images = []
    targets = []
    for img_path, label in zip(image_paths, labels):
        image, target = load_image_and_label(img_path, label)
        images.append(image)
        targets.append(target)

    images_tensor = torch.stack(images)
    # torch.tensor() cannot build a tensor from a list of multi-element tensors
    # (each label here is a 4-value box); torch.stack handles those as well as
    # scalar (0-d) labels.
    labels_tensor = torch.stack(targets)
    return TensorDataset(images_tensor, labels_tensor)

# Load and preprocess the data
image_paths, labels = parse_dataframe(df)

# Split the dataset: 60% train, then the remaining 40% is halved into
# 20% test and 20% validation. random_state is fixed for repeatable splits.
# NOTE(review): the split is per row of df; if an image_id appears on several
# rows, the same image can land in more than one split -- confirm this is intended.
train_paths, test_paths, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.4, random_state=42)
test_paths, val_paths, test_labels, val_labels = train_test_split(test_paths, test_labels, test_size=0.5, random_state=42)

# Create PyTorch Datasets
train_dataset = create_dataset(train_paths, train_labels)
val_dataset = create_dataset(val_paths, val_labels)
test_dataset = create_dataset(test_paths, test_labels)

# Create DataLoaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Inspect the first batch of each DataLoader. The loop variables use their own
# names so that `labels` (the full label list from parse_dataframe) is not
# overwritten by a single batch.
for split_loader in (train_loader, val_loader, test_loader):
    for batch_images, batch_labels in split_loader:
        print(batch_images.shape, batch_labels.shape, len(split_loader))
        break

Parsing dataframe and image path:   0%|          | 0/242466 [00:00<?, ?it/s]

Parsing dataframe and image path: 100%|██████████| 242466/242466 [02:18<00:00, 1745.65it/s]


RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 262144 bytes.

# Model Building

In [None]:
import matplotlib.pyplot as plt

def plot_metrics(history):
    """Plot loss and accuracy curves side by side.

    Args:
        history: Mapping with the per-epoch lists ``'loss'`` and ``'accuracy'``.
            If the optional lists ``'val_loss'`` / ``'val_accuracy'`` are
            present and non-empty, they are drawn as dashed lines on the same
            axes.
    """
    def _plot_series(ax, values, fmt, label):
        # Each series gets its own x-range so lists of different lengths still plot.
        ax.plot(range(1, len(values) + 1), values, fmt, label=label)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))

    # Plot loss
    _plot_series(loss_ax, history['loss'], 'r-', 'Training Loss')
    if history.get('val_loss'):
        _plot_series(loss_ax, history['val_loss'], 'r--', 'Validation Loss')
    loss_ax.set_title('Model Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Plot accuracy
    _plot_series(acc_ax, history['accuracy'], 'b-', 'Training Accuracy')
    if history.get('val_accuracy'):
        _plot_series(acc_ax, history['val_accuracy'], 'b--', 'Validation Accuracy')
    acc_ax.set_title('Model Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    fig.tight_layout()
    plt.show()

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, jaccard_score

# Precision Metric
def precision_metric(y_true, y_pred):
    """Return the precision of ``y_pred`` against ``y_true`` using 'weighted' averaging."""
    weighted_precision = precision_score(y_true, y_pred, average='weighted')
    return weighted_precision

# Recall Metric
def recall_metric(y_true, y_pred):
    """Return the recall of ``y_pred`` against ``y_true`` using 'weighted' averaging."""
    weighted_recall = recall_score(y_true, y_pred, average='weighted')
    return weighted_recall

# F1 Score Metric
def f1_metric(y_true, y_pred):
    """Return the F1 score of ``y_pred`` against ``y_true`` using 'weighted' averaging."""
    weighted_f1 = f1_score(y_true, y_pred, average='weighted')
    return weighted_f1

# Intersection over Union (IoU) Metric
def iou_metric(y_true, y_pred):
    """Return the Jaccard score of ``y_pred`` against ``y_true`` using 'weighted' averaging.

    This is computed on the predicted and true labels, not on bounding-box
    geometry.
    """
    weighted_jaccard = jaccard_score(y_true, y_pred, average='weighted')
    return weighted_jaccard

# DBN_1

In [None]:
class DBNModel:
    """Builds a two-RBM DBN classifier and bundles training, evaluation and persistence.

    ``compile_model`` must be called before ``train`` or ``evaluate`` because it
    creates the optimizer and the loss function they use.

    NOTE(review): ``RBM`` and ``DBN`` are not defined in the visible part of
    this notebook; they are presumably ``torch.nn.Module`` subclasses provided
    elsewhere -- confirm they are in scope before this cell runs.
    """

    def __init__(self, visible_units, hidden_units_1, hidden_units_2, n_classes):
        """Store the layer sizes and build the underlying model.

        Args:
            visible_units: Size of the input (visible) layer of the first RBM.
            hidden_units_1: Hidden units of the first RBM.
            hidden_units_2: Hidden units of the second RBM.
            n_classes: Number of output classes.
        """
        self.visible_units = visible_units
        self.hidden_units_1 = hidden_units_1
        self.hidden_units_2 = hidden_units_2
        self.n_classes = n_classes
        # build_model() reads the attributes above, so it must run last.
        self.model = self.build_model()

    def build_model(self):
        """Create the two RBMs and stack them into a DBN; returns the DBN."""
        # Initialize RBMs; the second RBM's visible layer is the first one's hidden layer.
        rbm1 = RBM(visible_units=self.visible_units, hidden_units=self.hidden_units_1)
        rbm2 = RBM(visible_units=self.hidden_units_1, hidden_units=self.hidden_units_2)

        # Stack RBMs to form a DBN
        dbn = DBN(rbm_layers=[rbm1, rbm2], n_classes=self.n_classes)
        return dbn

    def compile_model(self, learning_rate=0.01):
        """Create the Adam optimizer and cross-entropy loss used by train/evaluate."""
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()

    def train(self, train_loader, epochs=10, checkpoint_dir=None, checkpoint_name='dbn_model.pth'):
        """Train the model and return the per-epoch loss/accuracy history.

        Args:
            train_loader: DataLoader yielding ``(data, target)`` batches.
            epochs: Number of passes over ``train_loader``.
            checkpoint_dir: If given, the state_dict is saved there whenever the
                epoch's average training loss improves on the best so far.
            checkpoint_name: File name of that checkpoint inside
                ``checkpoint_dir``. Pass a distinct name per model to keep
                several models from overwriting each other's checkpoint.

        Returns:
            Dict with the lists ``'loss'`` (mean per-batch loss) and
            ``'accuracy'`` (fraction of correct predictions), one entry per epoch.
        """
        history = {'loss': [], 'accuracy': []}
        best_loss = float('inf')

        if checkpoint_dir:
            # Make sure the target directory exists before the first save.
            os.makedirs(checkpoint_dir, exist_ok=True)

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            correct = 0

            for data, target in train_loader:
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target.view_as(pred)).sum().item()

            avg_loss = total_loss / len(train_loader)
            accuracy = correct / len(train_loader.dataset)
            history['loss'].append(avg_loss)
            history['accuracy'].append(accuracy)

            print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')

            # Save the model if it has the best loss
            if checkpoint_dir and avg_loss < best_loss:
                best_loss = avg_loss
                self.save(os.path.join(checkpoint_dir, checkpoint_name))

        return history

    def evaluate(self, test_loader):
        """Evaluate the model on ``test_loader`` and return a dict of metrics.

        Returns:
            Dict with ``'loss'`` (mean per-batch loss), ``'accuracy'``,
            ``'precision'``, ``'recall'``, ``'f1_score'`` and ``'iou'``.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        y_true = []
        y_pred = []

        with torch.no_grad():
            for data, target in test_loader:
                output = self.model(data)
                loss = self.criterion(output, target)
                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target.view_as(pred)).sum().item()

                # Collect flat lists of plain ints so y_true and y_pred have
                # the same 1-D shape when passed to the metric functions.
                y_true.extend(target.reshape(-1).cpu().tolist())
                y_pred.extend(pred.reshape(-1).cpu().tolist())

        avg_loss = total_loss / len(test_loader)
        accuracy = correct / len(test_loader.dataset)
        precision = precision_metric(y_true, y_pred)
        recall = recall_metric(y_true, y_pred)
        f1 = f1_metric(y_true, y_pred)
        iou = iou_metric(y_true, y_pred)

        print(f'Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}, IoU: {iou:.4f}')

        return {
            'loss': avg_loss,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'iou': iou
        }

    def save(self, path):
        """Write the model's state_dict (parameters only, not the object) to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load a state_dict from ``path`` into the model and switch to eval mode."""
        # map_location='cpu' lets a checkpoint written on a GPU machine load on
        # a CPU-only one; load_state_dict then copies the values into the
        # model's own parameters. weights_only=True restricts unpickling to
        # tensor data.
        state_dict = torch.load(path, map_location='cpu', weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()

# Initialize the DBNModel class with the layer sizes from the configuration cell.
# NOTE(review): DBNModel.build_model() uses RBM and DBN, which are not defined in
# the visible part of this notebook -- confirm they are in scope.
dbn_model = DBNModel(visible_units=VISIBLE_UNITS, hidden_units_1=HIDDEN_UNITS_1, hidden_units_2=HIDDEN_UNITS_2, n_classes=N_CLASSES)

# Compile the model (creates the Adam optimizer and cross-entropy loss)
dbn_model.compile_model(learning_rate=LEARNING_RATE)

# Print the wrapped model object to show its structure
print(dbn_model.model)

In [None]:
# Train the DBN model; the per-epoch loss/accuracy history is kept in `history`.
# The best-loss checkpoint is written to CHECKPOINT_DIR as 'dbn_model.pth'.
history = dbn_model.train(train_loader, epochs=EPOCHS, checkpoint_dir=CHECKPOINT_DIR)

# Evaluate the DBN model on the test set (returns a dict of metrics)
test_metrics = dbn_model.evaluate(test_loader)

In [None]:
# Plot the training loss and accuracy curves of the first DBN model
plot_metrics(history)

In [None]:
# Save the first DBN model's parameters (DBNModel.save writes the state_dict,
# not the whole model object)
dbn_model.save(SAVE_PATH + 'dbn_model.pth')

In [None]:
# Optional: restore the saved parameters into dbn_model (disabled)
# dbn_model.load(SAVE_PATH + 'dbn_model.pth')

# Run the garbage collector before building the next model
gc.collect()

# DBN_2

The hidden units in the RBM layers are increased to 512 and 256, respectively, which makes this model potentially more capable of capturing complex features but at the cost of increased computational requirements.

In [None]:
# Define the second DBN model: same wrapper class, larger hidden layers
dbn_model_2 = DBNModel(visible_units=VISIBLE_UNITS, 
                       hidden_units_1=512,  # Increased from 256 to 512
                       hidden_units_2=256,  # Increased from 128 to 256
                       n_classes=N_CLASSES)

# Compile the model (creates the Adam optimizer and cross-entropy loss)
dbn_model_2.compile_model(learning_rate=LEARNING_RATE)

# Print the wrapped model object to show its structure
print(dbn_model_2.model)

In [None]:
# Train the model. Checkpoints go to a model-specific sub-directory: the
# training cell for the first model already writes 'dbn_model.pth' into
# CHECKPOINT_DIR, and reusing that directory here would overwrite it.
checkpoint_dir_2 = os.path.join(CHECKPOINT_DIR, 'dbn_model_2')
os.makedirs(checkpoint_dir_2, exist_ok=True)
history_2 = dbn_model_2.train(train_loader, epochs=EPOCHS, checkpoint_dir=checkpoint_dir_2)

# Evaluate the model on the test set
test_metrics_2 = dbn_model_2.evaluate(test_loader)

In [None]:
# Plot the training loss and accuracy curves of the second DBN model
plot_metrics(history_2)

In [None]:
# Save the second model's parameters (state_dict) next to the first model's file
dbn_model_2.save(SAVE_PATH + 'dbn_model_2.pth')

In [None]:
# Optional: restore the saved parameters into dbn_model_2 (disabled)
# dbn_model_2.load(SAVE_PATH + 'dbn_model_2.pth')
# Run the garbage collector before building the next model
gc.collect()

# DBN_3

A third RBM layer with 64 hidden units is added, making the network deeper. This modification allows the DBN to potentially learn even more abstract features but also requires careful tuning and more data to avoid overfitting.

In [None]:
# Define the DBN model with an additional RBM layer
class DBNModelExtended(DBNModel):
    """DBNModel variant that stacks a third RBM on top of the first two.

    NOTE(review): ``RBM`` and ``DBN`` are not defined in the visible part of
    this notebook -- confirm they are in scope before this cell runs.
    """

    # Hidden units of the third RBM. This is a class attribute rather than a
    # constructor argument because build_model() is invoked from
    # DBNModel.__init__; override it in a subclass to change the size.
    hidden_units_3 = 64

    def build_model(self):
        """Create three chained RBMs and stack them into a DBN; returns the DBN."""
        # Initialize RBMs with an additional layer; each RBM's visible layer
        # is the previous RBM's hidden layer.
        rbm1 = RBM(visible_units=self.visible_units, hidden_units=self.hidden_units_1)
        rbm2 = RBM(visible_units=self.hidden_units_1, hidden_units=self.hidden_units_2)
        rbm3 = RBM(visible_units=self.hidden_units_2, hidden_units=self.hidden_units_3)  # New additional RBM layer

        # Stack RBMs to form a DBN
        dbn = DBN(rbm_layers=[rbm1, rbm2, rbm3], n_classes=self.n_classes)
        return dbn

# Initialize the three-RBM DBN model with the same first two layer sizes as the first model
dbn_model_3 = DBNModelExtended(visible_units=VISIBLE_UNITS, 
                               hidden_units_1=HIDDEN_UNITS_1, 
                               hidden_units_2=HIDDEN_UNITS_2, 
                               n_classes=N_CLASSES)

# Compile the model (creates the Adam optimizer and cross-entropy loss)
dbn_model_3.compile_model(learning_rate=LEARNING_RATE)

# Print the wrapped model object to show its structure
print(dbn_model_3.model)

In [None]:
# Train the model. Checkpoints go to a model-specific sub-directory: the
# training cell for the first model already writes 'dbn_model.pth' into
# CHECKPOINT_DIR, and reusing that directory here would overwrite it.
checkpoint_dir_3 = os.path.join(CHECKPOINT_DIR, 'dbn_model_3')
os.makedirs(checkpoint_dir_3, exist_ok=True)
history_3 = dbn_model_3.train(train_loader, epochs=EPOCHS, checkpoint_dir=checkpoint_dir_3)

# Evaluate the model on the test set
test_metrics_3 = dbn_model_3.evaluate(test_loader)

In [None]:
# Plot the training loss and accuracy curves of the third DBN model
plot_metrics(history_3)

In [None]:
# Save the third model's parameters (state_dict) next to the other models' files
dbn_model_3.save(SAVE_PATH + 'dbn_model_3.pth')

In [None]:
# Optional: restore the saved parameters into dbn_model_3 (disabled)
# dbn_model_3.load(SAVE_PATH + 'dbn_model_3.pth')
# Run the garbage collector
gc.collect()

# Saving History

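Note: the training histories of all three models are pickled below, but only the first DBN model's history is reloaded and re-plotted.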

In [None]:
import pickle

# Save the training history of each DBN model to its own pickle file.
# The first model's history is bound to `history` by its training cell;
# no variable named `history_1` exists in this notebook.
histories_by_model = {1: history, 2: history_2, 3: history_3}
for model_index, model_history in histories_by_model.items():
    with open(f'dbn_training_history_{model_index}.pkl', 'wb') as file:
        pickle.dump(model_history, file)

In [None]:
# Load the training history for the first DBN model.
# NOTE: pickle.load can execute arbitrary code contained in the file; only load
# history files that this notebook wrote itself.
with open('dbn_training_history_1.pkl', 'rb') as file:
    loaded_history_1 = pickle.load(file)

# The history is a plain dict of lists, so no model is needed to plot it.
# Re-creating `dbn_model` here would discard the trained first model, and a
# zero-epoch training run only produces an empty history that is immediately
# replaced, so both steps are omitted.
# `empty_history_1` is kept as an alias in case another cell refers to it.
empty_history_1 = loaded_history_1

# Plot the reloaded history
plot_metrics(loaded_history_1)

# Hyperparameter Tuning (For DBN_1)

In [None]:
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from optuna import Trial

class DBNHyperModel:
    """Optuna search space and objective for a two-RBM DBN classifier.

    NOTE(review): ``RBM`` and ``DBN`` are not defined in the visible part of
    this notebook -- confirm they are in scope before this cell runs.
    """

    def __init__(self, visible_units, n_classes):
        """Store the fixed input size and number of classes."""
        self.visible_units = visible_units
        self.n_classes = n_classes

    def build(self, trial: Trial):
        """Sample hyperparameters from ``trial`` and build a model and optimizer.

        Sampled values: ``hidden_units_1`` in [128, 512] and ``hidden_units_2``
        in [64, 256] (both in steps of 64), and ``learning_rate`` log-uniformly
        in [1e-4, 1e-2].

        Returns:
            Tuple ``(dbn, optimizer)``; the optimizer is Adam over the DBN's parameters.
        """
        # Sample hyperparameters
        hidden_units_1 = trial.suggest_int('hidden_units_1', 128, 512, step=64)
        hidden_units_2 = trial.suggest_int('hidden_units_2', 64, 256, step=64)
        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
        learning_rate = trial.suggest_float('learning_rate', 1e-4, 1e-2, log=True)

        # Build the DBN
        rbm1 = RBM(visible_units=self.visible_units, hidden_units=hidden_units_1)
        rbm2 = RBM(visible_units=hidden_units_1, hidden_units=hidden_units_2)

        dbn = DBN(rbm_layers=[rbm1, rbm2], n_classes=self.n_classes)
        optimizer = optim.Adam(dbn.parameters(), lr=learning_rate)

        return dbn, optimizer

    def objective(self, trial: Trial):
        """Train a sampled model and return its final validation loss.

        Reads the module-level ``EPOCHS``, ``train_loader`` and ``val_loader``.
        The mean per-batch validation loss is reported to ``trial`` after every
        epoch so the study can prune the trial.

        Raises:
            optuna.exceptions.TrialPruned: when the trial asks to be pruned.
        """
        # Build the model
        model, optimizer = self.build(trial)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(EPOCHS):
            # Switch back to training mode every epoch: the validation pass
            # below leaves the model in eval mode.
            model.train()
            for data, target in train_loader:
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

            # Validation
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for data, target in val_loader:
                    output = model(data)
                    loss = criterion(output, target)
                    val_loss += loss.item()
            val_loss /= len(val_loader)

            # Report intermediate objective value
            trial.report(val_loss, epoch)

            # Handle pruning (optional)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        return val_loss

# Initialize the hypermodel
hypermodel = DBNHyperModel(visible_units=VISIBLE_UNITS, n_classes=N_CLASSES)

# Create an Optuna study that minimizes the objective (validation loss)
study = optuna.create_study(direction='minimize')

# Run the search. Without this call the study has no completed trials and the
# `study.best_trial` lookup in the next cell fails. Every trial trains for
# EPOCHS epochs, so adjust N_TRIALS to the available compute budget.
N_TRIALS = 10
study.optimize(hypermodel.objective, n_trials=N_TRIALS)

In [None]:
# Learning rate scheduler used for the final training run
from torch.optim.lr_scheduler import ReduceLROnPlateau


def _evaluate_loader(model, loader, loss_fn):
    """Return ``(mean per-batch loss, accuracy)`` of ``model`` over ``loader``.

    Puts the model in eval mode and runs without gradient tracking.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            loss_sum += loss_fn(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            n_correct += pred.eq(target.view_as(pred)).sum().item()
    return loss_sum / len(loader), n_correct / len(loader.dataset)


# Retrieve the best trial of the study
best_hyperparameters = study.best_trial

# Extract the best learning rate if it was tuned
best_learning_rate = best_hyperparameters.params.get('learning_rate', 0.001)  # Default to 0.001 if not tuned

# Print the best learning rate to verify
print(f"Best learning rate: {best_learning_rate}")
print(best_hyperparameters.params)

# Build the model using the best hyperparameters.
# NOTE(review): this passes the finished trial to build(); presumably its
# suggest_* methods return the stored values -- confirm against the Optuna docs.
best_model, best_optimizer = hypermodel.build(best_hyperparameters)

# Ensure the optimizer uses the best learning rate
for param_group in best_optimizer.param_groups:
    param_group['lr'] = best_learning_rate

# Reduce the learning rate by a factor of 0.2 after 3 epochs without
# validation-loss improvement, down to a floor of 1e-5.
scheduler = ReduceLROnPlateau(best_optimizer, mode='min', factor=0.2, patience=3, min_lr=1e-5)

# Early stopping parameters
early_stopping_patience = 5
best_val_loss = float('inf')
patience_counter = 0

# Created once and reused for every batch.
criterion = nn.CrossEntropyLoss()
best_checkpoint_path = os.path.join(CHECKPOINT_DIR, 'best_dbn_model.pth')

history_best = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}

for epoch in range(EPOCHS):
    best_model.train()
    total_loss = 0
    correct = 0

    # Training loop
    for data, target in train_loader:
        best_optimizer.zero_grad()
        output = best_model(data)
        loss = criterion(output, target)
        loss.backward()
        best_optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

    # Training metrics for this epoch
    avg_train_loss = total_loss / len(train_loader)
    train_accuracy = correct / len(train_loader.dataset)

    # Validation loop
    avg_val_loss, val_accuracy = _evaluate_loader(best_model, val_loader, criterion)

    # Append to history
    history_best['loss'].append(avg_train_loss)
    history_best['val_loss'].append(avg_val_loss)
    history_best['accuracy'].append(train_accuracy)
    history_best['val_accuracy'].append(val_accuracy)

    # Print epoch summary
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Accuracy: {train_accuracy:.4f}, Val Accuracy: {val_accuracy:.4f}')

    # Early stopping and learning rate scheduling
    scheduler.step(avg_val_loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(best_model.state_dict(), best_checkpoint_path)  # Save the best model
    else:
        patience_counter += 1

    if patience_counter >= early_stopping_patience:
        print("Early stopping triggered.")
        break

# Load the best model. map_location='cpu' lets a checkpoint written on a GPU
# machine load on a CPU-only one; weights_only=True restricts unpickling to
# tensor data.
best_model.load_state_dict(torch.load(best_checkpoint_path, map_location='cpu', weights_only=True))
best_model.eval()

# Evaluate the best model on the test set. The loss is the mean per-batch loss
# (sum of batch losses divided by the number of batches), the same definition
# used for the training and validation losses above.
test_loss, test_accuracy = _evaluate_loader(best_model, test_loader, criterion)

print(f'Best model testing Loss: {test_loss:.4f}, Testing Accuracy: {test_accuracy:.4f}')

In [None]:
# Plot the loss and accuracy curves of the final run with the best hyperparameters
plot_metrics(history_best)

In [None]:
# Save the best model's parameters (state_dict) to the model output directory
torch.save(best_model.state_dict(), SAVE_PATH + 'best_dbn_model.pth')