#### Data Preprocessing

In [1]:
from boxsers.preprocessing import savgol_smoothing, cosmic_filter

def preprocessing_method(x):
    """Clean spectra before they are fed to the model.

    The input is passed through two steps, in this order:

    1. ``cosmic_filter`` with ``ks=3`` (a median filter that removes
       cosmic-ray spikes from the spectrum(s)).
    2. ``savgol_smoothing`` with the arguments ``7, p=3, degree=0``.

    Args:
        x: Spectrum or spectra accepted by the ``boxsers`` preprocessing
            helpers (assumed to be a NumPy array -- TODO confirm).

    Returns:
        The despiked and smoothed spectra returned by ``savgol_smoothing``.
    """
    # Step 1: remove cosmic-ray spikes.
    despiked = cosmic_filter(x, ks=3)
    # Step 2: smooth the despiked spectra.
    return savgol_smoothing(despiked, 7, p=3, degree=0)

2024-08-06 12:16:59.039451: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-06 12:16:59.039477: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-06 12:16:59.040157: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


#### Loading Dataset

In [2]:
import numpy as np

##### Train Dataset

In [3]:
# Locations of the fine-tuning spectra and their labels.
X_fn, y_fn = (
    "./data/Bacteria-ID/X_finetune.npy",
    "./data/Bacteria-ID/y_finetune.npy",
)
y_train_raw = np.load(y_fn)
# NOTE: despite the "_raw" suffix, X_train_raw holds the *preprocessed*
# spectra from here on; the name is kept because later cells reference it.
X_train_raw = preprocessing_method(np.load(X_fn))

##### Test Dataset

In [4]:
# Locations of the held-out test spectra and their labels.  Dedicated *_fn
# names are used (as in the train cell) so that X_test / y_test only ever
# refer to arrays, never to path strings.
X_test_fn = "./data/Bacteria-ID/X_test.npy"
y_test_fn = "./data/Bacteria-ID/y_test.npy"
y_test = np.load(y_test_fn)
# Apply the same preprocessing as for the training spectra.
X_test = preprocessing_method(np.load(X_test_fn))

#### Model Setting

In [5]:
from model.Variant_LeNet_without_linear import Variant_LeNet_without_linear
from model.Variant_LeNet import Variant_LeNet
from tqdm import tqdm
import torch
import pandas as pd
import torch.nn as nn
from torch.autograd import Variable
from functools import partial
from deep_SLDA import slda_loss, SLDA
from imblearn.metrics import specificity_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import torch.optim as optim
import seaborn as sns
from sklearn.metrics import confusion_matrix
import math
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    auc,
    roc_curve,
)
from plot import plot_ROC_curve, plot_heatmap, plot_loss_metrics, plot_metrics, plot_antibiotic_groupings

# Number of target classes and mini-batch size handed to the dataloaders.
n_classes = 30
batch_size = 3000

# Per-run accumulators; one entry is appended for every run of the
# training/evaluation loop further below.
antibiotic_accuracy = []
train_avg_accuracy = []
val_avg_accuracy = []
avg_accuracy = []
avg_recall = []
avg_specificity = []
avg_precision = []
avg_f1 = []
avg_roc = []
# Confusion-matrix placeholder, sized from n_classes instead of a literal so
# the two cannot drift apart.
C = np.zeros((n_classes, n_classes))

In [6]:
class Solver:
    """Fine-tune a feature-extractor network with an SLDA loss and evaluate it.

    The network produces features; an ``SLDA`` layer turns each batch of
    features into projection directions and eigenvalues (exact semantics are
    defined in ``deep_SLDA`` -- not visible here), the eigenvalues drive the
    loss, and a scikit-learn ``LinearDiscriminantAnalysis`` fitted on the
    projected features makes the class predictions.

    Args:
        dataloaders: Mapping with the keys ``"train"``, ``"val"`` and
            ``"test"``; each value yields ``(inputs, targets)`` batches.
        model: ``torch.nn.Module`` used as the feature extractor.
        model_path: File path the best checkpoint is written to / read from.
        device: Torch device (e.g. ``"cuda"`` or ``"cpu"``) used for the
            network and all batches.
        n_classes: Number of target classes.
        finetune_model: Loaded checkpoint dict; its ``"state_dict"`` entry
            initialises ``model`` (non-strictly).
    """

    def __init__(
        self,
        dataloaders,
        model,
        model_path,
        device,
        n_classes,
        finetune_model
    ):
        self.dataloaders = dataloaders
        self.device = device
        self.net = model
        # strict=False: keys that do not match between the checkpoint and
        # ``model`` are ignored instead of raising.
        self.net.load_state_dict(finetune_model["state_dict"], strict=False)
        self.net = self.net.to(self.device)

        self.criterion = partial(
            slda_loss,
            n_classes=n_classes,
        )

        self.optimizer = optim.Adam(self.net.parameters(), lr=1e-4, betas=(0.5, 0.999))
        self.model_path = model_path
        self.n_classes = n_classes
        self.slda_layer = SLDA(self.n_classes)

    def iterate(self, epoch, phase, scheduler=None):
        """Run one pass over ``self.dataloaders[phase]``.

        In the ``"train"`` phase, every batch refits the SLDA directions and
        the LDA classifier (stored as ``self.dirs`` / ``self.clf``) and takes
        one optimizer step.  In any other phase the most recently stored
        ``self.dirs`` / ``self.clf`` are reused, so a train pass must have
        run before.

        Args:
            epoch: Current epoch index (unused; kept for interface
                compatibility).
            phase: Key into ``self.dataloaders``; ``"train"`` enables
                training behaviour.
            scheduler: Unused; kept for interface compatibility.

        Returns:
            ``(avg_loss, total_acc)``: mean loss per batch and the fraction
            of correctly classified samples over the whole pass.
        """
        is_train = phase == "train"
        if is_train:
            self.net.train()
        else:
            self.net.eval()

        dataloader = self.dataloaders[phase]
        total_loss = 0
        correct = 0
        total = 0
        n_batches = 0

        for inputs, targets in dataloader:
            # Plain tensors are sufficient; torch.autograd.Variable is a
            # deprecated no-op wrapper.
            inputs = inputs.to(self.device)
            targets = targets.long().to(self.device)

            feas = self.net(inputs)

            if is_train:
                dirs, range_eigenvalue, null_eigenvalue = self.slda_layer.fit(feas, targets, phase)
                projected = torch.matmul(feas, dirs.T).detach().cpu().numpy()
                # The classifier is refitted from scratch on every training
                # batch; the last fit of the epoch is what validation uses.
                self.clf = LinearDiscriminantAnalysis()
                self.clf.fit(projected, targets.cpu().numpy())
                self.dirs = dirs
            else:
                # Outside the train phase ``fit`` returns only the two
                # eigenvalue terms and the stored directions are reused.
                range_eigenvalue, null_eigenvalue = self.slda_layer.fit(feas, targets, phase)
                projected = torch.matmul(feas, self.dirs.T).detach().cpu().numpy()

            outputs = torch.from_numpy(self.clf.predict(projected)).to(self.device)
            loss = self.criterion(range_eigenvalue, null_eigenvalue)

            if is_train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            total_loss += loss.item()
            total += targets.size(0)
            n_batches += 1
            correct += outputs.eq(targets).cpu().sum().item()

        avg_loss = total_loss / n_batches
        total_acc = correct / total

        return avg_loss, total_acc

    def train(self, epochs):
        """Train for ``epochs`` epochs, checkpointing the best validation epoch.

        After every epoch the learning-rate scheduler is stepped and a
        validation pass is run.  Whenever validation accuracy improves (and
        always on the first epoch) a checkpoint holding the epoch index,
        validation loss, SLDA directions, fitted classifier and network
        weights is written to ``self.model_path``.

        Args:
            epochs: Number of epochs to run.

        Returns:
            ``(train_acc, best_acc, useful_stuff)``: training accuracy of the
            last epoch, best validation accuracy seen, and a dict with the
            per-epoch loss/accuracy histories.
        """
        best_acc = 0
        # Defined up front so the return statement is valid even for epochs == 0.
        train_acc = 0

        useful_stuff = {
            "training_loss": [],
            "validation_loss": [],
            "train_metrics": [],
            "validation_metrics": [],
        }

        def lr_lambda(epoch):
            # Multiply the base LR by 0.9 for every completed block of 10 epochs.
            return 0.9 ** (epoch // 10) if epoch >= 10 else 1.0

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda)

        for epoch in tqdm(range(epochs)):

            train_loss, train_acc = self.iterate(epoch, "train")
            useful_stuff["training_loss"].append(train_loss)
            useful_stuff["train_metrics"].append(train_acc)

            self.scheduler.step()

            with torch.no_grad():
                val_loss, val_acc = self.iterate(epoch, "val")
                useful_stuff["validation_loss"].append(val_loss)
                useful_stuff["validation_metrics"].append(val_acc)

            if val_acc > best_acc or epoch == 0:
                best_acc = val_acc
                checkpoint = {
                    "epoch": epoch,
                    "val_loss": val_loss,
                    "dirs": self.dirs,
                    "clf": self.clf,
                    "state_dict": self.net.state_dict(),
                }
                torch.save(checkpoint, self.model_path)

        return train_acc, best_acc, useful_stuff

    def test_iterate(self, epoch, phase):
        """Predict every sample of ``self.dataloaders[phase]``.

        Uses the currently stored ``self.dirs`` and ``self.clf``.

        Args:
            epoch: Unused; kept for interface compatibility.
            phase: Key into ``self.dataloaders``.

        Returns:
            ``(y_pred, y_true, y_pred_prob)`` as NumPy arrays: flat predicted
            labels, flat true labels, and a 2-D array with one row of class
            probabilities per sample.
        """
        self.net.eval()
        dataloader = self.dataloaders[phase]
        y_pred = []
        y_true = []
        y_pred_prob = []
        with torch.no_grad():
            for inputs, targets in dataloader:
                # Honour the configured device rather than forcing CUDA.
                inputs = inputs.to(self.device)
                targets = targets.to(self.device).long()

                feas = self.net(inputs)
                projected = torch.matmul(feas, self.dirs.T).detach().cpu().numpy()

                y_pred.append(np.asarray(self.clf.predict(projected)).ravel())
                y_true.append(targets.cpu().numpy())
                y_pred_prob.append(np.asarray(self.clf.predict_proba(projected)))

        y_pred_prob = np.concatenate(y_pred_prob)
        y_pred = np.concatenate(y_pred)
        y_true = np.concatenate(y_true).flatten()

        return (
            np.array(y_pred).flatten(),
            y_true,
            # One row per sample; the column count follows the classifier
            # output instead of a hard-coded (3000, 30).
            np.array(y_pred_prob).reshape(len(y_true), -1),
        )

    def test(self):
        """Reload the best checkpoint and evaluate it on the ``"test"`` loader.

        Prints the overall and per-class accuracy.

        Returns:
            ``(confusion_matrix, y_true, y_pred, accuracy, y_pred_prob)``.
        """
        # NOTE: the checkpoint pickles a scikit-learn estimator next to the
        # tensors, so it needs full unpickling (PyTorch >= 2.6 requires
        # ``weights_only=False`` for that).  Only load checkpoints you trust.
        checkpoint = torch.load(self.model_path, map_location=self.device)
        epoch = checkpoint["epoch"]
        val_loss = checkpoint["val_loss"]
        self.dirs = checkpoint["dirs"]
        self.clf = checkpoint["clf"]
        self.net.load_state_dict(checkpoint["state_dict"])
        print("load model at epoch {}, with val loss: {:.3f}".format(epoch, val_loss))
        y_pred, y_true, y_pred_prob = self.test_iterate(epoch, "test")
        print("total", accuracy_score(y_true, y_pred))
        for i in range(self.n_classes):
            idx = y_true == i
            print("class", i, accuracy_score(y_true[idx], y_pred[idx]))

        return (
            confusion_matrix(y_true, y_pred),
            y_true,
            y_pred,
            accuracy_score(y_true, y_pred),
            y_pred_prob,
        )

In [7]:
from datasets_spectrum import spectral_dataloader
from config import ORDER, STRAINS

# Fixed seed so the train/validation split below is reproducible.
np.random.seed(42)

# Hold out a fraction p_val of the fine-tuning samples for validation.
p_val = 0.1
# Derive the sample count from the loaded data instead of hard-coding it.
n_samples = len(X_train_raw)
n_val = int(n_samples * p_val)
idx_tr = list(range(n_samples))
np.random.shuffle(idx_tr)
# The first n_val shuffled indices form the validation set, the rest train.
idx_val = idx_tr[:n_val]
idx_tr = idx_tr[n_val:]

fold_index = 1
n_folds = 5
# Fall back to the CPU when no GPU is present instead of failing on "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): every "fold" reuses the same idx_tr / idx_val split, so these
# are repeated runs on one split rather than a k-fold cross-validation --
# confirm that this is intended.
for _fold in range(n_folds):

    print("fold: ", fold_index)
    print("train size: ", len(idx_tr))
    print("validation size: ", len(idx_val))
    print("test size: ", len(y_test))

    dl_tr = spectral_dataloader(
        X_train_raw, y_train_raw, idxs=idx_tr, batch_size=batch_size, shuffle=True
    )
    dl_val = spectral_dataloader(
        X_train_raw, y_train_raw, idxs=idx_val, batch_size=batch_size, shuffle=False
    )
    dl_test = spectral_dataloader(X_test, y_test, batch_size=batch_size, shuffle=False)

    dataloaders = {"train": dl_tr, "val": dl_val, "test": dl_test}
    model = Variant_LeNet_without_linear(in_channels=1)

    # Every run starts from the same pre-trained checkpoint and writes its
    # own fine-tuned checkpoint.
    resnet_filename = "best_variant_lenet_model_1.pt"
    finetune_model = torch.load(resnet_filename, map_location=device)
    model_path = f"best_finetune_variant_lenet_model_{fold_index}.pt"
    solver = Solver(
        dataloaders, model, model_path, device, n_classes, finetune_model
    )

    train_accuracy, val_accuracy, useful_stuff = solver.train(500)
    C, y_true, y_pred, test_accuracy, y_pred_prob = solver.test()
    train_avg_accuracy.append(train_accuracy)
    val_avg_accuracy.append(val_accuracy)
    avg_accuracy.append(np.round(test_accuracy,4))

    # One-vs-rest ROC curve / AUC per class.  A dedicated loop variable is
    # used so the outer loop variable is not shadowed, and the targets come
    # from y_true, which is what y_pred_prob is aligned with.
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for class_idx in range(np.unique(y_true).shape[0]):
        fpr[class_idx], tpr[class_idx], _thresholds = roc_curve(
            y_true == class_idx, y_pred_prob[:, class_idx]
        )
        roc_auc[class_idx] = auc(fpr[class_idx], tpr[class_idx])
    valid_aucs = [
        v
        for v in roc_auc.values()
        if isinstance(v, (int, float)) and not math.isnan(v)
    ]
    # Mean AUC over the classes with a defined AUC; NaN (rather than an
    # unbound or stale value from a previous run) when there is none.
    auc_score = sum(valid_aucs) / len(valid_aucs) if valid_aucs else float("nan")
    avg_roc.append(auc_score)

    # Row-normalised confusion matrix (in percent) in the ORDER given by config.
    cm = confusion_matrix(y_true, y_pred, labels=ORDER)
    sns.set_context("talk", rc={"font": "Helvetica", "font.size": 12})
    cm = 100 * cm / cm.sum(axis=1)[:,np.newaxis]

    # NOTE(review): recall is macro-averaged while the other aggregate
    # metrics are weighted -- confirm that the mix is intended.
    avg_recall.append(recall_score(y_true, y_pred, average='macro', zero_division=0).round(4))
    avg_specificity.append(specificity_score(y_true, y_pred, average='weighted').round(4))
    avg_precision.append(precision_score(y_true, y_pred, average='weighted', zero_division=0).round(4))
    avg_f1.append(f1_score(y_true, y_pred, average='weighted', zero_division=0).round(4))

    # Per-class metrics table for this run.
    df = pd.DataFrame(
        {
            "Accuracy": [np.round(accuracy_score(y_true, y_pred), 4)],
            "Recall": [
                recall_score(y_true, y_pred, average=None, zero_division=0).round(4)
            ],
            "Specificity": [specificity_score(y_true, y_pred, average=None).round(4)],
            "Precision": [
                precision_score(y_true, y_pred, average=None, zero_division=0).round(4)
            ],
            "F1 Score": [
                f1_score(y_true, y_pred, average=None, zero_division=0).round(4)
            ],
        }
    )
    print(df.transpose())

    plot_ROC_curve("variant_lenet", y_true, y_test, y_pred_prob, fold_index=fold_index)
    plot_heatmap("variant_lenet", cm, fold_index=fold_index)
    plot_metrics(training_results=useful_stuff, fold_index=fold_index, fold_name="variant_lenet")
    plot_loss_metrics(training_results=useful_stuff, fold_index=fold_index, fold_name="variant_lenet")
    acc = plot_antibiotic_groupings("variant_lenet", y_true, y_pred, fold_index=fold_index)
    antibiotic_accuracy.append(acc)
    fold_index += 1

fold:  1
train size:  2700
validation size:  300
test size:  3000


100%|██████████| 500/500 [04:21<00:00,  1.91it/s]

load model at epoch 16, with val loss: 0.224





total 0.8533333333333334
class 0 0.96
class 1 1.0
class 2 0.67
class 3 1.0
class 4 0.66
class 5 1.0
class 6 0.43
class 7 0.87
class 8 0.54
class 9 0.7
class 10 0.92
class 11 0.37
class 12 0.97
class 13 0.81
class 14 1.0
class 15 0.86
class 16 0.88
class 17 0.7
class 18 0.94
class 19 0.99
class 20 1.0
class 21 0.92
class 22 0.89
class 23 0.87
class 24 0.88
class 25 0.91
class 26 0.95
class 27 0.99
class 28 0.96
class 29 0.96
                                                             0
Accuracy                                                0.8533
Recall       [0.96, 1.0, 0.67, 1.0, 0.66, 1.0, 0.43, 0.87, ...
Specificity  [1.0, 0.9986, 0.9893, 0.9966, 0.9928, 0.999, 0...
Precision    [1.0, 0.9615, 0.6837, 0.9091, 0.7586, 0.9709, ...
F1 Score     [0.9796, 0.9804, 0.6768, 0.9524, 0.7059, 0.985...
Accuracy: 97.3%
fold:  2
train size:  2700
validation size:  300
test size:  3000


100%|██████████| 500/500 [04:25<00:00,  1.88it/s]

load model at epoch 11, with val loss: 0.211





total 0.8573333333333333
class 0 0.95
class 1 1.0
class 2 0.7
class 3 1.0
class 4 0.66
class 5 1.0
class 6 0.5
class 7 0.82
class 8 0.61
class 9 0.67
class 10 0.92
class 11 0.39
class 12 0.97
class 13 0.79
class 14 1.0
class 15 0.79
class 16 0.95
class 17 0.76
class 18 0.94
class 19 0.99
class 20 0.99
class 21 0.94
class 22 0.89
class 23 0.88
class 24 0.86
class 25 0.91
class 26 0.92
class 27 0.99
class 28 0.96
class 29 0.97
                                                             0
Accuracy                                                0.8573
Recall       [0.95, 1.0, 0.7, 1.0, 0.66, 1.0, 0.5, 0.82, 0....
Specificity  [1.0, 0.9983, 0.9879, 0.9972, 0.99, 0.999, 0.9...
Precision    [1.0, 0.9524, 0.6667, 0.9259, 0.6947, 0.9709, ...
F1 Score     [0.9744, 0.9756, 0.6829, 0.9615, 0.6769, 0.985...
Accuracy: 97.3%
fold:  3
train size:  2700
validation size:  300
test size:  3000


100%|██████████| 500/500 [04:31<00:00,  1.84it/s]

load model at epoch 11, with val loss: 0.204





total 0.8536666666666667
class 0 0.97
class 1 1.0
class 2 0.64
class 3 1.0
class 4 0.66
class 5 1.0
class 6 0.45
class 7 0.84
class 8 0.56
class 9 0.7
class 10 0.92
class 11 0.38
class 12 0.97
class 13 0.81
class 14 1.0
class 15 0.87
class 16 0.89
class 17 0.72
class 18 0.95
class 19 0.98
class 20 1.0
class 21 0.94
class 22 0.87
class 23 0.87
class 24 0.83
class 25 0.91
class 26 0.93
class 27 1.0
class 28 0.98
class 29 0.97
                                                             0
Accuracy                                                0.8537
Recall       [0.97, 1.0, 0.64, 1.0, 0.66, 1.0, 0.45, 0.84, ...
Specificity  [1.0, 0.999, 0.989, 0.9969, 0.9928, 0.999, 0.9...
Precision    [1.0, 0.9709, 0.6667, 0.9174, 0.7586, 0.9709, ...
F1 Score     [0.9848, 0.9852, 0.6531, 0.9569, 0.7059, 0.985...
Accuracy: 97.3%
fold:  4
train size:  2700
validation size:  300
test size:  3000


100%|██████████| 500/500 [04:32<00:00,  1.84it/s]

load model at epoch 37, with val loss: 0.183





total 0.8563333333333333
class 0 0.97
class 1 1.0
class 2 0.62
class 3 1.0
class 4 0.65
class 5 1.0
class 6 0.44
class 7 0.84
class 8 0.63
class 9 0.69
class 10 0.92
class 11 0.4
class 12 0.97
class 13 0.81
class 14 1.0
class 15 0.83
class 16 0.91
class 17 0.68
class 18 0.94
class 19 0.99
class 20 1.0
class 21 0.94
class 22 0.88
class 23 0.88
class 24 0.88
class 25 0.92
class 26 0.94
class 27 1.0
class 28 0.98
class 29 0.98
                                                             0
Accuracy                                                0.8563
Recall       [0.97, 1.0, 0.62, 1.0, 0.65, 1.0, 0.44, 0.84, ...
Specificity  [1.0, 0.999, 0.989, 0.9972, 0.9931, 0.999, 0.9...
Precision    [1.0, 0.9709, 0.6596, 0.9259, 0.7647, 0.9709, ...
F1 Score     [0.9848, 0.9852, 0.6392, 0.9615, 0.7027, 0.985...
Accuracy: 97.3%
fold:  5
train size:  2700
validation size:  300
test size:  3000


100%|██████████| 500/500 [04:34<00:00,  1.82it/s]

load model at epoch 43, with val loss: 0.209





total 0.861
class 0 0.97
class 1 1.0
class 2 0.65
class 3 1.0
class 4 0.65
class 5 1.0
class 6 0.44
class 7 0.87
class 8 0.63
class 9 0.71
class 10 0.91
class 11 0.39
class 12 0.99
class 13 0.85
class 14 1.0
class 15 0.81
class 16 0.94
class 17 0.71
class 18 0.94
class 19 0.99
class 20 0.99
class 21 0.94
class 22 0.88
class 23 0.88
class 24 0.89
class 25 0.92
class 26 0.94
class 27 1.0
class 28 0.97
class 29 0.97
                                                             0
Accuracy                                                 0.861
Recall       [0.97, 1.0, 0.65, 1.0, 0.65, 1.0, 0.44, 0.87, ...
Specificity  [1.0, 0.999, 0.99, 0.9966, 0.9934, 0.999, 0.99...
Precision    [1.0, 0.9709, 0.6915, 0.9091, 0.7738, 0.9709, ...
F1 Score     [0.9848, 0.9852, 0.6701, 0.9524, 0.7065, 0.985...
Accuracy: 97.4%


In [8]:
# Raw per-run values first ...
print(avg_accuracy)
for metric_name, per_run_values in (
    ('Recall', avg_recall),
    ('Specificity', avg_specificity),
    ('Precision', avg_precision),
    ('F1', avg_f1),
):
    print(metric_name, per_run_values)

# ... then mean / standard deviation across runs, rounded to 4 decimals.
rounded_summaries = (
    ("train mean:", "train std:", train_avg_accuracy),
    ("val mean:", "val std:", val_avg_accuracy),
    ("test mean:", "test std:", avg_accuracy),
    ("recall mean:", "recall std:", avg_recall),
    ("Specificity mean:", "Specificity std:", avg_specificity),
    ("Precision mean:", "Precision std:", avg_precision),
    ("F1 mean:", "F1 std:", avg_f1),
    ("auc mean:", "auc std:", avg_roc),
)
for mean_label, std_label, per_run_values in rounded_summaries:
    print(mean_label, round(np.mean(per_run_values),4))
    print(std_label, round(np.std(per_run_values),4))

# The antibiotic-grouping accuracy is reported unrounded.
print("antibiotic mean:", np.mean(antibiotic_accuracy))
print("antibiotic std:", np.std(antibiotic_accuracy))

[0.8533, 0.8573, 0.8537, 0.8563, 0.861]
Recall [0.8533, 0.8573, 0.8537, 0.8563, 0.861]
Specificity [0.9949, 0.9951, 0.995, 0.995, 0.9952]
Precision [0.862, 0.8643, 0.8607, 0.8636, 0.8685]
F1 [0.8492, 0.8542, 0.8499, 0.8527, 0.8572]
train mean: 1.0
train std: 0.0
val mean: 0.9413
val std: 0.0016
test mean: 0.8563
test std: 0.0028
recall mean: 0.8563
recall std: 0.0028
Specificity mean: 0.995
Specificity std: 0.0001
Precision mean: 0.8638
Precision std: 0.0027
F1 mean: 0.8526
F1 std: 0.0029
auc mean: 0.9937
auc std: 0.0002
antibiotic mean: 0.9730666666666666
antibiotic std: 0.000489897948556624
