In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, matthews_corrcoef, recall_score, precision_score
from sklearn.metrics import confusion_matrix, f1_score
from torch.utils.data import TensorDataset, random_split, DataLoader
import matplotlib.pyplot as plt
import math
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from collections import OrderedDict
from functools import partial
from typing import Callable, Optional
from torch import Tensor

from typing import Optional, Tuple, Union, Dict
from torch.nn import functional as F

In [2]:
from transformer import TransformerEncoder
from model_config import get_config
from model import mobile_vit_xx_small

In [3]:
loaded_datasets_info = torch.load('/root/autodl-tmp/imgs/RE-US/saved_datasets_RE-US.pth', weights_only=False)
train_dataset = loaded_datasets_info['train_dataset']
val_dataset = loaded_datasets_info['val_dataset']
test_dataset = loaded_datasets_info['test_dataset']

In [4]:
batch_size = 10
loaded_train_dataset = DataLoader(train_dataset, batch_size = batch_size, shuffle = False)
loaded_val_dataset = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)
loaded_test_dataset = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

In [5]:
device = "cuda"
model = mobile_vit_xx_small(num_classes = 1).to(device)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
num_epochs = 10

In [6]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_indx, (inputs, labels) in enumerate(loaded_train_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # scheduler.step()
        
        running_loss += loss.item()
        
    # Print average loss for the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / (len(loaded_train_dataset) / batch_size)}")  

Epoch 1/10, Loss: 6.920109689235687
Epoch 2/10, Loss: 6.810783399985387
Epoch 3/10, Loss: 6.680951301868145
Epoch 4/10, Loss: 6.511607009630937
Epoch 5/10, Loss: 6.248471163786374
Epoch 6/10, Loss: 5.8936107617158155
Epoch 7/10, Loss: 5.487457800369996
Epoch 8/10, Loss: 5.113928226324228
Epoch 9/10, Loss: 4.911083544676121
Epoch 10/10, Loss: 4.359779644470948


In [7]:
# save model
torch.save(model.state_dict(), "/root/autodl-tmp/scripts/RE-US/h.2_MobileViT/MobileViT_xx_s.pth") 

In [8]:
predicted_probabilities = []
true_labels = []
with torch.set_grad_enabled(False):
    for batch_indx, (inputs, labels) in enumerate(loaded_val_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)      
        outputs = model(inputs)
        predicted_probabilities.extend(outputs.tolist())
        true_labels.extend(labels.tolist())

In [9]:
def metrics_output(preds,labels):
    true_labels = np.array(labels)
    predicted_probs = np.array(preds)
    binary_predictions = (predicted_probs >= 0.5).astype(int)
    auc = roc_auc_score(true_labels, predicted_probs)
    conf_matrix = confusion_matrix(true_labels, binary_predictions)
    tn, fp, fn, tp = conf_matrix.ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    accuracy = accuracy_score(true_labels, binary_predictions)
    f1 = f1_score(true_labels, binary_predictions)
    mcc = matthews_corrcoef(true_labels, binary_predictions)  
    return (auc, sensitivity, specificity, accuracy, f1, mcc)

In [10]:
roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC = metrics_output(predicted_probabilities, true_labels)
print(roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC)

0.7576354679802955 0.6 0.7586206896551724 0.671875 0.6666666666666665 0.3598637460328732


In [11]:
np.save('/root/autodl-tmp/ROC/RE-US/MobileViT_xx_s/y_val_pred.npy', predicted_probabilities)
np.save('/root/autodl-tmp/ROC/RE-US/MobileViT_xx_s/y_val.npy', true_labels)

In [12]:
predicted_probabilities = []  
true_labels = []  
with torch.set_grad_enabled(False): 
    for batch_indx, (inputs, labels) in enumerate(loaded_test_dataset):
        inputs = inputs.to(device)
        labels = labels.to(device)    
        outputs = model(inputs)
        predicted_probabilities.extend(outputs.tolist())
        true_labels.extend(labels.tolist())

In [13]:
roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC = metrics_output(predicted_probabilities, true_labels)
print(roc_auc, metrics_sn, metrics_sp, metrics_ACC, metrics_F1, metrics_MCC)

0.7156250000000001 0.6 0.8 0.7 0.6666666666666665 0.408248290463863


In [14]:
np.save('/root/autodl-tmp/ROC/RE-US/MobileViT_xx_s/y_test_pred.npy', predicted_probabilities)
np.save('/root/autodl-tmp/ROC/RE-US/MobileViT_xx_s/y_test.npy', true_labels)