In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.model_selection import train_test_split
import glob
import mne
import scipy.signal as sg
from dataset import TimeSeriesDataset
from models import Model_CNN_LSTM, MultiResolutionModel, MultiResolutionLSTMFFT
import mlflow.pytorch
import torchmetrics
from losses import FocalLoss
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import io
from torchview import draw_graph
import tqdm

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
files_fpath = "./DatabaseSubjects/"
data_files = glob.glob(files_fpath + "/*.edf")
anno_files = glob.glob(files_fpath + "/HypnogramAASM_*.txt")

In [5]:
##Separete into train and test files
train_data_files, test_data_files, train_anno_files, test_anno_files = train_test_split(data_files, anno_files, test_size=0.2, random_state=42)

In [6]:
train_dataset = TimeSeriesDataset(train_data_files, train_anno_files, n_past=0, segmentation_size_sec=5)
test_dataset = TimeSeriesDataset(test_data_files, test_anno_files, n_past=0, segmentation_size_sec=5)


Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject17.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  data = mne.io.read_raw_edf(file, preload=True)
  value = np.nanmax([_prefilter_float(x) for x in values])
  data = mne.io.read_raw_edf(file, preload=True)
  value = np.nanmin([_prefilter_float(x) for x in values])


Reading 0 ... 5981999  =      0.000 ... 29909.995 secs...
Signal shape: (6, 5982000)
Signal shape: (2, 5982000)
Data shape: (8, 5982000), fs: 200, segmentation size: 1000, annotations: 5982
Data shape after padding: (5983000, 8), annotations: 5982
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject14.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6025999  =      0.000 ... 30129.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6026000)
Signal shape: (2, 6026000)
Data shape: (8, 6026000), fs: 200, segmentation size: 1000, annotations: 6026
Data shape after padding: (6027000, 8), annotations: 6026
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject2.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5909999  =      0.000 ... 29549.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 5910000)
Signal shape: (2, 5910000)
Data shape: (8, 5910000), fs: 200, segmentation size: 1000, annotations: 5910
Data shape after padding: (5911000, 8), annotations: 5910
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject12.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5767999  =      0.000 ... 28839.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 5768000)
Signal shape: (2, 5768000)
Data shape: (8, 5768000), fs: 200, segmentation size: 1000, annotations: 5768
Data shape after padding: (5769000, 8), annotations: 5768
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject8.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5819999  =      0.000 ... 29099.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 5820000)
Signal shape: (2, 5820000)
Data shape: (8, 5820000), fs: 200, segmentation size: 1000, annotations: 5820
Data shape after padding: (5821000, 8), annotations: 5820
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject6.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5983999  =      0.000 ... 29919.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 5984000)
Signal shape: (2, 5984000)
Data shape: (8, 5984000), fs: 200, segmentation size: 1000, annotations: 5984
Data shape after padding: (5985000, 8), annotations: 5984
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject3.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6051999  =      0.000 ... 30259.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6052000)
Signal shape: (2, 6052000)
Data shape: (8, 6052000), fs: 200, segmentation size: 1000, annotations: 6052
Data shape after padding: (6053000, 8), annotations: 6052
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject11.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6049999  =      0.000 ... 30249.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6050000)
Signal shape: (2, 6050000)
Data shape: (8, 6050000), fs: 200, segmentation size: 1000, annotations: 6050
Data shape after padding: (6051000, 8), annotations: 6050
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject18.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6127999  =      0.000 ... 30639.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6128000)
Signal shape: (2, 6128000)
Data shape: (8, 6128000), fs: 200, segmentation size: 1000, annotations: 6128
Data shape after padding: (6129000, 8), annotations: 6128
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject9.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6703999  =      0.000 ... 33519.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6704000)
Signal shape: (2, 6704000)
Data shape: (8, 6704000), fs: 200, segmentation size: 1000, annotations: 6704
Data shape after padding: (6705000, 8), annotations: 6704
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject13.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6667999  =      0.000 ... 33339.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6668000)
Signal shape: (2, 6668000)
Data shape: (8, 6668000), fs: 200, segmentation size: 1000, annotations: 6668
Data shape after padding: (6669000, 8), annotations: 6668
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject20.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6869999  =      0.000 ... 34349.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6870000)
Signal shape: (2, 6870000)
Data shape: (8, 6870000), fs: 200, segmentation size: 1000, annotations: 6870
Data shape after padding: (6871000, 8), annotations: 6870
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject16.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5805999  =      0.000 ... 29029.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 5806000)
Signal shape: (2, 5806000)
Data shape: (8, 5806000), fs: 200, segmentation size: 1000, annotations: 5806
Data shape after padding: (5807000, 8), annotations: 5806
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject19.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6201999  =      0.000 ... 31009.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6202000)
Signal shape: (2, 6202000)
Data shape: (8, 6202000), fs: 200, segmentation size: 1000, annotations: 6202
Data shape after padding: (6203000, 8), annotations: 6202
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject4.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6317999  =      0.000 ... 31589.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6318000)
Signal shape: (2, 6318000)
Data shape: (8, 6318000), fs: 200, segmentation size: 1000, annotations: 6318
Data shape after padding: (6319000, 8), annotations: 6318
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject15.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5039999  =      0.000 ... 25199.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 5040000)
Signal shape: (2, 5040000)
Data shape: (8, 5040000), fs: 200, segmentation size: 1000, annotations: 5040
Data shape after padding: (5041000, 8), annotations: 5040
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject1.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5767999  =      0.000 ... 28839.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 5768000)
Signal shape: (2, 5768000)
Data shape: (8, 5768000), fs: 200, segmentation size: 1000, annotations: 5768
Data shape after padding: (5769000, 8), annotations: 5768
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject7.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6071999  =      0.000 ... 30359.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6072000)
Signal shape: (2, 6072000)
Data shape: (8, 6072000), fs: 200, segmentation size: 1000, annotations: 6072
Data shape after padding: (6073000, 8), annotations: 6072
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject5.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6257999  =      0.000 ... 31289.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6258000)
Signal shape: (2, 6258000)
Data shape: (8, 6258000), fs: 200, segmentation size: 1000, annotations: 6258
Data shape after padding: (6259000, 8), annotations: 6258
Extracting EDF parameters from c:\Users\rafar\Documents\2024\Disciplina - Machine Learning Para séries Temporais\Trabalho1\DatabaseSubjects\subject10.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6195999  =      0.000 ... 30979.995 secs...


  data = mne.io.read_raw_edf(file, preload=True)
  data = mne.io.read_raw_edf(file, preload=True)


Signal shape: (6, 6196000)
Signal shape: (2, 6196000)
Data shape: (8, 6196000), fs: 200, segmentation size: 1000, annotations: 6196
Data shape after padding: (6197000, 8), annotations: 6196


In [7]:
class_name_alias = {
    0: 'dont_care',
    1: 'N3',
    2: 'N2',
    3: 'N1',
    4: 'REM',
    5: 'A',
}

In [8]:
# Assuming `labels` is a list or NumPy array containing the class labels for all samples
class_counts = np.bincount(train_dataset.annotations)  # Count occurrences of each class
class_weights = 1.0 / class_counts  # Inverse of frequencies
class_weights[0] = 0

# Assign weights to each sample based on its class
sample_weights = np.array([class_weights[label] for label in train_dataset.annotations])

In [9]:
from torch.utils.data import WeightedRandomSampler, DataLoader

# Convert weights to a tensor
sample_weights = torch.DoubleTensor(sample_weights)

# Create the sampler
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # Total samples to draw per epoch
    replacement=True  # Allow replacement for oversampling,
)

In [None]:
# Model initialization
#model = Model_CNN_LSTM(n_sensors=8, num_classes=6, fs=200, time_frame=5,)
#model = MultiResolutionModel(n_sensors=8, num_classes=6, fs=200, time_frame=5,)
model = MultiResolutionLSTMFFT(num_channels=8, num_classes=6, fs=200, time_frame=5,)
num_classes=6
criterion = FocalLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=0.01)
model.to(device)




LSTMFormerFFT(
  (fft_lstm_block): FFTLSTMBlock(
    (fft): SlidingFFT()
    (conv1): Conv1d(1608, 512, kernel_size=(1,), stride=(1,), padding=same)
    (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(512, 1024, kernel_size=(3,), stride=(1,), padding=same)
    (bn2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (lstm): LSTM(1024, 512, batch_first=True, dropout=0.2, bidirectional=True)
    (conv3): Conv1d(1024, 512, kernel_size=(1,), stride=(1,), padding=same)
    (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv4): Conv1d(512, 256, kernel_size=(1,), stride=(1,), padding=same)
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (dropout3): Dropout(p=0.2, inplace=False)
  )
  (conv0): Conv1d(8, 32, kernel_size=(10,), stride=(1,), padding=same)
  (bn0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affin

In [11]:
accuracy = torchmetrics.Accuracy(num_classes=num_classes, average='none', task='multiclass').to(device)
precision = torchmetrics.Precision(num_classes=num_classes, average='none', task='multiclass').to(device)
recall = torchmetrics.Recall(num_classes=num_classes, average='none', task='multiclass').to(device)
f1 = torchmetrics.F1Score(num_classes=num_classes, average='none', task='multiclass').to(device)


In [12]:
def plot_confusion_matrix_and_log(predictions, labels, num_classes, exclude_label=0):
    """
    Plots a confusion matrix for the predictions, excluding the specified label, 
    and logs it to MLflow using mlflow.log_image.
    
    Args:
    - predictions (torch.Tensor or np.ndarray): The model's predicted class labels.
    - labels (torch.Tensor or np.ndarray): The true class labels.
    - num_classes (int): The total number of classes (excluding the label to be ignored).
    - exclude_label (int): The class label to exclude (default: 0).
    """
    # Filter out the samples with the 'exclude_label'
    mask = labels != exclude_label
    filtered_preds = predictions[mask]
    filtered_labels = labels[mask]

    # Debugging: Check the types and shapes
    print(f"Type of filtered_preds: {type(filtered_preds)}")
    print(f"Shape of filtered_preds: {filtered_preds.shape if isinstance(filtered_preds, torch.Tensor) else 'Not a tensor'}")
    print(f"Type of filtered_labels: {type(filtered_labels)}")
    print(f"Shape of filtered_labels: {filtered_labels.shape if isinstance(filtered_labels, torch.Tensor) else 'Not a tensor'}")

    # Convert tensors to NumPy arrays if they are still tensors
    if isinstance(filtered_preds, torch.Tensor):
        filtered_preds = filtered_preds.cpu().numpy()
    if isinstance(filtered_labels, torch.Tensor):
        filtered_labels = filtered_labels.cpu().numpy()

    # Debugging: Check if they are now NumPy arrays
    print(f"Type of filtered_preds after conversion: {type(filtered_preds)}")
    print(f"Type of filtered_labels after conversion: {type(filtered_labels)}")

    # Ensure both are arrays and not scalars
    filtered_preds = np.asarray(filtered_preds)
    filtered_labels = np.asarray(filtered_labels)

    # Ensure filtered_labels and filtered_preds are not empty and are array-like
    if filtered_preds.ndim == 0 or filtered_labels.ndim == 0:
        raise ValueError("The predictions and labels must be array-like and cannot be scalar.")
    
    # Compute the confusion matrix
    cm = confusion_matrix(filtered_labels, filtered_preds, labels=range(1, num_classes))
    
    # Create a ConfusionMatrixDisplay object
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=range(1, num_classes))
    
    # Plot the confusion matrix
    fig, ax = plt.subplots(figsize=(8, 6))
    disp.plot(cmap='Blues', values_format='d', ax=ax)
    plt.title(f'Confusion Matrix (Excluding Label {exclude_label})')
    
    # Save the plot to a BytesIO buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    
    # Log the image to MLflow
    mlflow.log_image(buf, "confusion_matrix.png")
    
    # Close the plot to free memory
    plt.close(fig)

def log_model_graph_as_svg(model, train_loader):
    """
    Generates and logs the model graph as an SVG artifact in MLflow.
    
    Args:
    - model (nn.Module): The trained model.
    - train_loader (DataLoader): The DataLoader to get a batch of input data.
    """
    # Generate the model graph visualization
    input_size = next(iter(train_loader))[0].shape  # Get the shape of the input
    
    # Use the draw_graph function with basic parameters
    try:
        model_graph = draw_graph(
            model, 
            input_size=input_size, 
            depth=2,  # Use a smaller depth for simplicity
            device='cpu',  # Adjust device if needed
            show_shapes=True,  # Show tensor shapes
            save_graph=True,  # Don't save yet, just visualize
            filename='model_graph.png',  # We will handle the output file path later
            directory='./',  # Directory for saving graph
        )

        # Check if the model graph is valid
        if model_graph is None or not model_graph.visual_graph:
            raise ValueError("Failed to generate model graph visualization. Please check the input model.")
        
        print("Graph visualization generated successfully.")

    except Exception as e:
        raise RuntimeError(f"Failed to generate model graph: {e}")
    
    # Log the generated SVG artifact to MLflow
    try:
        mlflow.log_artifact('./model_graph.png.png', "model_graph.png")
        print("Model graph has been logged as an artifact in MLflow.")
    except Exception as e:
        raise RuntimeError(f"Failed to log SVG artifact to MLflow: {e}")
    

def compute_loss(outputs, labels, criterion):
    # Mask to exclude label 0
    non_zero_mask = labels != 0
    
    # Apply the mask to outputs and labels (this effectively ignores label 0)
    masked_labels = labels[non_zero_mask]
    masked_outputs = outputs[non_zero_mask]
    
    # If there are no valid labels (i.e., all labels are 0), return a loss of 0
    if masked_labels.size(0) == 0:
        return torch.tensor(0.0, device=outputs.device)

    # Calculate the loss for the masked labels
    loss = criterion(masked_outputs, masked_labels)
    return loss

# Training loop for one epoch
# Training loop for one epoch
def train_epoch(model, dataloader, criterion, optimizer, epoch):
    model.train()
    total_loss = 0
    pbar = tqdm.tqdm(dataloader)
    for iteration, (X, y) in enumerate(pbar):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(X)
        # Apply softmax
        output = torch.nn.Softmax(-1)(output)
        loss = compute_loss(output, y, criterion)
        loss.backward()
        optimizer.step()
        
        # Log the loss for this iteration
        mlflow.log_metric("train_loss", loss.item(), step=epoch * len(dataloader) + iteration)

        total_loss += loss.item()
        pbar.set_description(f"Loss: {loss.item()}")
    
    return total_loss / len(dataloader)

# Validation loop (including metrics calculation)
def validate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            output = model(X)
            loss = compute_loss(output, y, criterion)
            total_loss += loss.item()
            all_preds.extend(output.argmax(1).cpu().numpy())
            all_labels.extend(y.cpu().numpy())

            # Update TorchMetrics with predictions and true labels
            accuracy.update(output, y)
            precision.update(output, y)
            recall.update(output, y)
            f1.update(output, y)

    # Calculate final metrics from TorchMetrics
    acc = accuracy.compute()
    prec = precision.compute()
    rec = recall.compute()
    f1_score = f1.compute()

    # Calculate confusion matrix
    #cm = confusion_matrix(all_labels, all_preds, num_classes)

    # Get class names using the alias dictionary
    class_names = [class_name_alias[i] for i in range(num_classes)]
    
    # Plot confusion matrix and save it to a file
    #plot_confusion_matrix_and_log(all_preds, all_labels, num_classes=num_classes, exclude_label=0)

    # Log confusion matrix image to MLflow
    #mlflow.log_artifact('confusion_matrix.png')

    # Reset metrics after each validation step
    accuracy.reset()
    precision.reset()
    recall.reset()
    f1.reset()

    # Return loss, accuracy, and per-class metrics
    accuracy_value = acc.mean().item()  # mean accuracy across all classes
    return total_loss / len(dataloader), accuracy_value, prec.cpu().numpy(), rec.cpu().numpy(), f1_score.cpu().numpy()

In [13]:
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [14]:
mlflow.end_run()

In [15]:
experiment_name = "complete_experiment2"
num_epochs = 8


# Set or create the experiment
mlflow.set_tracking_uri(r"sqlite:///C:\Users\rafar\Documents\mlflow\mlflow.db")
mlflow.set_experiment(experiment_name)
#
## Start the MLflow run
mlflow.start_run()
log_model_graph_as_svg(model, train_loader)
model.to(device)
for epoch in range(num_epochs):
    # Train the model for one epoch
    train_loss = train_epoch(model, train_loader, criterion, optimizer, epoch)
    
    # Validate the model and get metrics
    val_loss, val_acc, precision_per_class, recall_per_class, f1_per_class = validate(model, test_loader, criterion)
    
    # Log metrics to MLflow
    mlflow.log_metric("train_loss", train_loss, step=epoch)
    mlflow.log_metric("val_loss", val_loss, step=epoch)
    mlflow.log_metric("val_accuracy", val_acc, step=epoch)
    
    # Log per-class metrics for precision, recall, and F1 score
    for class_id in range(num_classes):
        if class_id == 0:
            continue
        class_name = class_name_alias[class_id]
        mlflow.log_metric(f"precision_class_{class_name}", precision_per_class[class_id], step=epoch)
        mlflow.log_metric(f"recall_class_{class_name}", recall_per_class[class_id], step=epoch)
        mlflow.log_metric(f"f1_class_{class_name}", f1_per_class[class_id], step=epoch)
    
    # Print the progress for the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# End the MLflow run
mlflow.end_run()

print("Training complete and metrics logged!")

  ret = func(*args, **kwargs)



Graph visualization generated successfully.
Model graph has been logged as an artifact in MLflow.


Loss: 1.0354379415512085: 100%|██████████| 6083/6083 [27:49<00:00,  3.64it/s] 


Epoch 1/8, Train Loss: 0.9074, Val Loss: 0.9605, Val Acc: 0.4996


Loss: 0.7642558813095093: 100%|██████████| 6083/6083 [25:42<00:00,  3.94it/s] 


Epoch 2/8, Train Loss: 0.7923, Val Loss: 1.1332, Val Acc: 0.5130


Loss: 0.7611589431762695: 100%|██████████| 6083/6083 [28:06<00:00,  3.61it/s]  


Epoch 3/8, Train Loss: 0.7677, Val Loss: 1.3026, Val Acc: 0.5112


Loss: 0.8497283458709717: 100%|██████████| 6083/6083 [27:51<00:00,  3.64it/s]  


Epoch 4/8, Train Loss: 0.7529, Val Loss: 1.3030, Val Acc: 0.5259


Loss: 0.8201395869255066: 100%|██████████| 6083/6083 [28:09<00:00,  3.60it/s]  


Epoch 5/8, Train Loss: 0.7409, Val Loss: 1.5525, Val Acc: 0.5034


Loss: 0.6776416897773743: 100%|██████████| 6083/6083 [28:46<00:00,  3.52it/s]  


Epoch 6/8, Train Loss: 0.7372, Val Loss: 1.2664, Val Acc: 0.5272


Loss: 0.6147096157073975: 100%|██████████| 6083/6083 [31:36<00:00,  3.21it/s]  


Epoch 7/8, Train Loss: 0.7298, Val Loss: 1.8490, Val Acc: 0.4375


Loss: 0.70076584815979: 100%|██████████| 6083/6083 [28:14<00:00,  3.59it/s]    


Epoch 8/8, Train Loss: 0.7265, Val Loss: 1.2573, Val Acc: 0.5176
Training complete and metrics logged!
