In [6]:
import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, matthews_corrcoef, recall_score, precision_score
from sklearn.metrics import confusion_matrix, f1_score, classification_report
from sklearn.metrics import roc_curve
import torch.nn.functional as F
from torchvision import datasets, models

from collections import OrderedDict
from functools import partial
from typing import Callable, Optional
from torch import Tensor

import re
from typing import Any, List, Tuple
from collections import OrderedDict

import torch.nn.functional as F
import torch.utils.checkpoint as cp

In [3]:
loaded_datasets_info = torch.load('/root/autodl-tmp/data/saved_datasets.pth', weights_only=False)
train_dataset = loaded_datasets_info['train_dataset']
test_dataset = loaded_datasets_info['test_dataset']

In [4]:
batch_size = 100
loaded_train_dataset = DataLoader(train_dataset, batch_size = batch_size, shuffle = False)
loaded_test_dataset = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

In [7]:
def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output

class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

class ConvBNAct(nn.Module):
    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        super(ConvBNAct, self).__init__()

        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.SiLU  # alias Swish  (torch>=1.7)

        self.conv = nn.Conv2d(in_channels=in_planes,
                              out_channels=out_planes,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              groups=groups,
                              bias=False)

        self.bn = norm_layer(out_planes)
        self.act = activation_layer()

    def forward(self, x):
        result = self.conv(x)
        result = self.bn(result)
        result = self.act(result)

        return result

class SqueezeExcite(nn.Module):
    def __init__(self,
                 input_c: int,   # block input channel
                 expand_c: int,  # block expand channel
                 se_ratio: float = 0.25):
        super(SqueezeExcite, self).__init__()
        squeeze_c = int(input_c * se_ratio)
        self.conv_reduce = nn.Conv2d(expand_c, squeeze_c, 1)
        self.act1 = nn.SiLU()  # alias Swish
        self.conv_expand = nn.Conv2d(squeeze_c, expand_c, 1)
        self.act2 = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        scale = x.mean((2, 3), keepdim=True)
        scale = self.conv_reduce(scale)
        scale = self.act1(scale)
        scale = self.conv_expand(scale)
        scale = self.act2(scale)
        return scale * x

class MBConv(nn.Module):
    def __init__(self,
                 kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer: Callable[..., nn.Module]):
        super(MBConv, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")

        self.has_shortcut = (stride == 1 and input_c == out_c)

        activation_layer = nn.SiLU  # alias Swish
        expanded_c = input_c * expand_ratio

        # in EfficientNetV2, there is no condition like expansion=1 in MBConv, so conv_pw must exists
        assert expand_ratio != 1
        # Point-wise expansion
        self.expand_conv = ConvBNAct(input_c,
                                     expanded_c,
                                     kernel_size=1,
                                     norm_layer=norm_layer,
                                     activation_layer=activation_layer)

        # Depth-wise convolution
        self.dwconv = ConvBNAct(expanded_c,
                                expanded_c,
                                kernel_size=kernel_size,
                                stride=stride,
                                groups=expanded_c,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer)

        self.se = SqueezeExcite(input_c, expanded_c, se_ratio) if se_ratio > 0 else nn.Identity()

        # Point-wise linear projection
        self.project_conv = ConvBNAct(expanded_c,
                                      out_planes=out_c,
                                      kernel_size=1,
                                      norm_layer=norm_layer,
                                      activation_layer=nn.Identity)  # no activation function existed，all pass to Identity

        self.out_channels = out_c

        # only use dropout layer only when use shortcut connection
        self.drop_rate = drop_rate
        if self.has_shortcut and drop_rate > 0:
            self.dropout = DropPath(drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        result = self.expand_conv(x)
        result = self.dwconv(result)
        result = self.se(result)
        result = self.project_conv(result)

        if self.has_shortcut:
            if self.drop_rate > 0:
                result = self.dropout(result)
            result += x

        return result


class FusedMBConv(nn.Module):
    def __init__(self,
                 kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer: Callable[..., nn.Module]):
        super(FusedMBConv, self).__init__()

        assert stride in [1, 2]
        assert se_ratio == 0

        self.has_shortcut = stride == 1 and input_c == out_c
        self.drop_rate = drop_rate

        self.has_expansion = expand_ratio != 1

        activation_layer = nn.SiLU  # alias Swish
        expanded_c = input_c * expand_ratio

        # apply expand conv only when expand ratio isn't equal to 1
        if self.has_expansion:
            # Expansion convolution
            self.expand_conv = ConvBNAct(input_c,
                                         expanded_c,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         norm_layer=norm_layer,
                                         activation_layer=activation_layer)

            self.project_conv = ConvBNAct(expanded_c,
                                          out_c,
                                          kernel_size=1,
                                          norm_layer=norm_layer,
                                          activation_layer=nn.Identity)  # activation function not exists
        else:
            # when only have project_conv
            self.project_conv = ConvBNAct(input_c,
                                          out_c,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          norm_layer=norm_layer,
                                          activation_layer=activation_layer)  # activation function exists

        self.out_channels = out_c

        # use dropout layer only when use shortcut connection
        self.drop_rate = drop_rate
        if self.has_shortcut and drop_rate > 0:
            self.dropout = DropPath(drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        if self.has_expansion:
            result = self.expand_conv(x)
            result = self.project_conv(result)
        else:
            result = self.project_conv(x)

        if self.has_shortcut:
            if self.drop_rate > 0:
                result = self.dropout(result)

            result += x

        return result

class EfficientNetV2(nn.Module):
    def __init__(self,
                 model_cnf: list,
                 num_classes: int = 1000,
                 num_features: int = 1280,
                 dropout_rate: float = 0.2,
                 drop_connect_rate: float = 0.2):
        super(EfficientNetV2, self).__init__()

        for cnf in model_cnf:
            assert len(cnf) == 8

        norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)

        stem_filter_num = model_cnf[0][4]

        self.stem = ConvBNAct(1,
                              stem_filter_num,
                              kernel_size=3,
                              stride=2,
                              norm_layer=norm_layer)  # default activation function is SiLU

        total_blocks = sum([i[0] for i in model_cnf])
        block_id = 0
        blocks = []
        for cnf in model_cnf:
            repeats = cnf[0]
            op = FusedMBConv if cnf[-2] == 0 else MBConv
            for i in range(repeats):
                blocks.append(op(kernel_size=cnf[1],
                                 input_c=cnf[4] if i == 0 else cnf[5],
                                 out_c=cnf[5],
                                 expand_ratio=cnf[3],
                                 stride=cnf[2] if i == 0 else 1,
                                 se_ratio=cnf[-1],
                                 drop_rate=drop_connect_rate * block_id / total_blocks,
                                 norm_layer=norm_layer))
                block_id += 1
        self.blocks = nn.Sequential(*blocks)

        head_input_c = model_cnf[-1][-3]
        head = OrderedDict()

        head.update({"project_conv": ConvBNAct(head_input_c,
                                               num_features,
                                               kernel_size=1,
                                               norm_layer=norm_layer)})  # default activation function is SiLU

        head.update({"avgpool": nn.AdaptiveAvgPool2d(1)})
        head.update({"flatten": nn.Flatten()})

        if dropout_rate > 0:
            head.update({"dropout": nn.Dropout(p=dropout_rate, inplace=True)})
        head.update({"classifier": nn.Linear(num_features, num_classes)})

        self.head = nn.Sequential(head)

        # initial weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)
        
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.blocks(x)
        x = self.head(x)
        x = self.sigmoid(x)

        return x

def efficientnetv2_s(num_classes: int = 1000):
    """
    EfficientNetV2
    https://arxiv.org/abs/2104.00298
    """
    # train_size: 300, eval_size: 384

    # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
    model_config = [[2, 3, 1, 1, 24, 24, 0, 0],
                    [4, 3, 2, 4, 24, 48, 0, 0],
                    [4, 3, 2, 4, 48, 64, 0, 0],
                    [6, 3, 2, 4, 64, 128, 1, 0.25],
                    [9, 3, 1, 6, 128, 160, 1, 0.25],
                    [15, 3, 2, 6, 160, 256, 1, 0.25]]

    model = EfficientNetV2(model_cnf=model_config,
                           num_classes=num_classes,
                           dropout_rate=0.2)
    return model

In [8]:
device = "cuda"
model = efficientnetv2_s(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
num_epochs = 10

In [9]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_indx, (inputs, labels) in enumerate(loaded_train_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    # Print average loss for the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / (len(loaded_train_dataset) / batch_size)}")  

Epoch 1/10, Loss: 204.80816485484442
Epoch 2/10, Loss: 165.5896511276563
Epoch 3/10, Loss: 154.8189313610395
Epoch 4/10, Loss: 151.4711290796598
Epoch 5/10, Loss: 149.99125971396765
Epoch 6/10, Loss: 149.15418702363968
Epoch 7/10, Loss: 148.58495779832205
Epoch 8/10, Loss: 148.22475453217825
Epoch 9/10, Loss: 147.95411173502603
Epoch 10/10, Loss: 147.73798388242722


In [10]:
torch.save(model.state_dict(), './model_params/EfficientNet.pth')

In [11]:
model.load_state_dict(torch.load('./model_params/EfficientNet.pth'))

  model.load_state_dict(torch.load('./model_params/EfficientNet.pth'))


<All keys matched successfully>

### Metrics function definition

In [50]:
# get roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC
def calculate_multiclass_metrics(true_labels, predicted_labels, predicted_probabilities, num_classes):
    accuracy = accuracy_score(true_labels, predicted_labels)
    mcc = matthews_corrcoef(true_labels, predicted_labels)
    
    sensitivity_per_class = []
    specificity_per_class = []
    auc_per_class = []
    f1_per_class = []

    for i in range(num_classes):
        true_binary = (np.array(true_labels) == i).astype(int)
        pred_binary = (np.array(predicted_labels) == i).astype(int)

        cm = confusion_matrix(true_binary, pred_binary, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        sensitivity_per_class.append(sensitivity)
        specificity_per_class.append(specificity)

        
        auc = roc_auc_score(true_binary, predicted_probabilities[:, i]) if len(np.unique(true_binary)) > 1 else 0
        auc_per_class.append(auc)

        f1 = f1_score(true_binary, pred_binary) if len(np.unique(true_binary)) > 1 else 0
        f1_per_class.append(f1)

    avg_sensitivity = np.mean(sensitivity_per_class)
    avg_specificity = np.mean(specificity_per_class)
    avg_auc = np.mean(auc_per_class) if auc_per_class else 0
    avg_f1 = np.mean(f1_per_class)

    print(f"Average AUC: {avg_auc:.4f}")
    print(f"Average Sensitivity: {avg_sensitivity:.4f}")
    print(f"Average Specificity: {avg_specificity:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Average F1-score: {avg_f1:.4f}")
    print(f"Matthews Correlation Coefficient (MCC): {mcc:.4f}")

### Training data metrics

In [51]:
predicted_probabilities = []
true_labels = []
with torch.set_grad_enabled(False):
    for batch_indx, (inputs, labels) in enumerate(loaded_train_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)      
        outputs = model(inputs)
        predicted_probabilities.extend(outputs.tolist())
        true_labels.extend(labels.tolist())

In [52]:
true_labels_ = np.argmax(true_labels, axis=-1)
predicted_labels = np.argmax(predicted_probabilities, axis=-1)
preds = torch.tensor(predicted_probabilities)
preds = F.softmax(preds, dim=-1)

train_metrics = calculate_multiclass_metrics(true_labels_, predicted_labels, preds, num_classes=10)

Average AUC: 0.9983
Average Sensitivity: 0.9873
Average Specificity: 0.9986
Accuracy: 0.9874
Average F1-score: 0.9873
Matthews Correlation Coefficient (MCC): 0.9860


In [53]:
np.save('/root/autodl-tmp/ROC/EfficientNet/y_train_pred.npy', preds)
np.save('/root/autodl-tmp/ROC/EfficientNet/y_train.npy', true_labels)

### Testing data metrics

In [54]:
predicted_probabilities = []  
true_labels = []  
with torch.set_grad_enabled(False): 
    for batch_indx, (inputs, labels) in enumerate(loaded_test_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)    
        outputs = model(inputs)
        predicted_probabilities.extend(outputs.tolist())
        true_labels.extend(labels.tolist())

In [55]:
true_labels_ = np.argmax(true_labels, axis=-1)
predicted_labels = np.argmax(predicted_probabilities, axis=-1)
preds = torch.tensor(predicted_probabilities)
preds = F.softmax(preds, dim=-1)

test_metrics = calculate_multiclass_metrics(true_labels_, predicted_labels, preds, num_classes=10)

Average AUC: 0.9978
Average Sensitivity: 0.9766
Average Specificity: 0.9974
Accuracy: 0.9766
Average F1-score: 0.9765
Matthews Correlation Coefficient (MCC): 0.9740


In [56]:
np.save('/root/autodl-tmp/ROC/EfficientNet/y_test_pred.npy', preds)
np.save('/root/autodl-tmp/ROC/EfficientNet/y_test.npy', true_labels)