In [13]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, matthews_corrcoef, recall_score, precision_score
from sklearn.metrics import confusion_matrix, f1_score
from torch.utils.data import TensorDataset, random_split, DataLoader
import matplotlib.pyplot as plt
import math
import torchvision.models as models
import torch.nn.functional as F

In [14]:
loaded_datasets_info = torch.load('/root/autodl-tmp/imgs/RE-SWE/saved_datasets_RE-SWE.pth')
train_dataset = loaded_datasets_info['train_dataset']
val_dataset = loaded_datasets_info['val_dataset']
test_dataset = loaded_datasets_info['test_dataset']

  loaded_datasets_info = torch.load('/root/autodl-tmp/imgs/RE-SWE/saved_datasets_RE-SWE.pth')


In [15]:
batch_size = 10
loaded_train_dataset = DataLoader(train_dataset, batch_size = batch_size, shuffle = False)
loaded_val_dataset = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)
loaded_test_dataset = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

In [16]:
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape: b, num_channels, h, w  -->  b, groups, channels_per_group, h, w
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # channelshuffle
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x

class shuffleNet_unit(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride, groups):
        super(shuffleNet_unit, self).__init__()

        mid_channels = out_channels//4
        self.stride = stride
        if in_channels == 24:
            self.groups = 1
        else:
            self.groups = groups
        self.GConv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, groups=self.groups, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )

        self.DWConv = nn.Sequential(
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=self.stride, padding=1, groups=self.groups, bias=False),
            nn.BatchNorm2d(mid_channels)
        )

        self.GConv2 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, groups=self.groups, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        if self.stride == 2:
            self.shortcut = nn.Sequential(
                nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = self.GConv1(x)
        out = channel_shuffle(out, groups=self.groups)
        out = self.DWConv(out)
        out = self.GConv2(out)
        short = self.shortcut(x)
        if self.stride == 2:
            out = F.relu(torch.cat([out, short], dim=1))
        else:
            out = F.relu(out + short)
        return out

class ShuffleNet(nn.Module):
    def __init__(self, groups, num_layers, num_channels, num_classes=1): # 修改了num_classes
        super(ShuffleNet, self).__init__()

        self.groups = groups
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 24, 3, 2, 1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.stage2 = self.make_layers(24, num_channels[0], num_layers[0], groups)
        self.stage3 = self.make_layers(num_channels[0], num_channels[1], num_layers[1], groups)
        self.stage4 = self.make_layers(num_channels[1], num_channels[2], num_layers[2], groups)

        self.globalpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.fc = nn.Linear(num_channels[2], num_classes)
        self.sigmoid = nn.Sigmoid()
        
    def make_layers(self, in_channels, out_channels, num_layers, groups):
        layers = []
        layers.append(shuffleNet_unit(in_channels, out_channels - in_channels, 2, groups))
        in_channels = out_channels
        for i in range(num_layers - 1):
            layers.append(shuffleNet_unit(in_channels, out_channels, 1, groups))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.globalpool(x)
        x = x.view(x.size(0), -1)
        out = self.sigmoid(self.fc(x))
        return out

device = "cuda"
model = ShuffleNet(1, num_layers = [4, 8, 4], num_channels = [144, 288, 576]).to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
num_epochs = 10

In [17]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_indx, (inputs, labels) in enumerate(loaded_train_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    # Print average loss for the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / (len(loaded_train_dataset) / batch_size)}")  

Epoch 1/10, Loss: 4.237455258575769
Epoch 2/10, Loss: 3.4349097239856534
Epoch 3/10, Loss: 2.245242243202833
Epoch 4/10, Loss: 1.2278001400857017
Epoch 5/10, Loss: 0.9291945830605423
Epoch 6/10, Loss: 0.8584614727610292
Epoch 7/10, Loss: 1.0414647099185879
Epoch 8/10, Loss: 0.7072763906147044
Epoch 9/10, Loss: 0.7699138706755967
Epoch 10/10, Loss: 0.4014350141705318


In [18]:
predicted_probabilities = []
true_labels = []
with torch.set_grad_enabled(False):
    for batch_indx, (inputs, labels) in enumerate(loaded_val_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)      
        outputs = model(inputs)
        predicted_probabilities.extend(outputs.tolist())
        true_labels.extend(labels.tolist())

In [19]:
def metrics_output(preds,labels):
    true_labels = np.array(labels)
    predicted_probs = np.array(preds)
    binary_predictions = (predicted_probs >= 0.5).astype(int)
    auc = roc_auc_score(true_labels, predicted_probs)
    conf_matrix = confusion_matrix(true_labels, binary_predictions)
    tn, fp, fn, tp = conf_matrix.ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    accuracy = accuracy_score(true_labels, binary_predictions)
    f1 = f1_score(true_labels, binary_predictions)
    mcc = matthews_corrcoef(true_labels, binary_predictions)  
    return (auc, sensitivity, specificity, accuracy, f1, mcc)

In [20]:
roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC = metrics_output(predicted_probabilities, true_labels)
print(roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC)

0.9302083333333333 0.875 0.875 0.875 0.8400000000000001 0.7392960871352163


In [21]:
np.save('/root/autodl-tmp/ROC/RE-SWE/ShuffleNet/y_val_pred.npy', predicted_probabilities)
np.save('/root/autodl-tmp/ROC/RE-SWE/ShuffleNet/y_val.npy', true_labels)

In [22]:
predicted_probabilities = []  
true_labels = []  
with torch.set_grad_enabled(False): 
    for batch_indx, (inputs, labels) in enumerate(loaded_test_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)    
        outputs = model(inputs)
        predicted_probabilities.extend(outputs.tolist())
        true_labels.extend(labels.tolist())

In [23]:
roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC = metrics_output(predicted_probabilities, true_labels)
print(roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC)

0.9367007672634271 0.8043478260869565 0.9411764705882353 0.8625 0.8705882352941177 0.7373198807716764


In [24]:
np.save('/root/autodl-tmp/ROC/RE-SWE/ShuffleNet/y_test_pred.npy', predicted_probabilities)
np.save('/root/autodl-tmp/ROC/RE-SWE/ShuffleNet/y_test.npy', true_labels)