In [1]:
import os
import gc
import cv2
import time
import json
import torch
import torchvision
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import TensorDataset, DataLoader, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
class MetricsCalculator(object):
    """Accumulate binary-classification results batch by batch and derive metrics.

    Class 1 is treated as the positive class and class 0 as the negative class.
    Confusion-matrix counts are kept as running totals; the one-hot targets and
    predicted probabilities of every sample are stored so that the Brier score
    can be computed over everything seen so far.
    """

    def __init__(self):
        # Running confusion-matrix counts.
        self.TP = 0
        self.FP = 0
        self.FN = 0
        self.TN = 0
        # Per-sample history (one entry per sample) used by the Brier score.
        self.y_trues = []
        self.y_pred_onehot = []
        self.y_preds_proba = []

    def update(self, y_true, y_pred, y_pred_onehot=None, y_pred_proba=None):
        """Add one batch of results.

        Args:
            y_true: true class indices (0 or 1), any shape; flattened internally.
            y_pred: predicted class indices (0 or 1), same number of elements.
            y_pred_onehot: one-hot target array reshaped to (-1, 2). Despite the
                name it is used as the *target* in the Brier score. When omitted
                it is derived from ``y_true``.
            y_pred_proba: predicted class probabilities, reshaped to (-1, 2).
                Required; it only has a default so that ``y_pred_onehot`` can be
                optional without changing the positional order.

        Raises:
            TypeError: if ``y_pred_proba`` is not supplied.
        """
        if y_pred_proba is None:
            raise TypeError("update() missing required argument: 'y_pred_proba'")

        y_true = np.array(y_true).reshape(-1, 1)
        y_pred = np.array(y_pred).reshape(-1, 1)
        if y_pred_onehot is None:
            # Column 0 flags class 0, column 1 flags class 1.
            y_pred_onehot = np.concatenate((y_true == 0, y_true == 1), axis=1).astype(np.int64)
        else:
            y_pred_onehot = np.array(y_pred_onehot).reshape(-1, 2)
        y_pred_proba = np.array(y_pred_proba).reshape(-1, 2)

        self.y_trues.extend(y_true)
        self.y_pred_onehot.extend(y_pred_onehot)
        self.y_preds_proba.extend(y_pred_proba)

        # In y_true, 1 denotes the positive class and 0 the negative class.
        self.TP += np.sum((y_true == 1) & (y_pred == 1))
        self.FN += np.sum((y_true == 1) & (y_pred == 0))
        self.FP += np.sum((y_true == 0) & (y_pred == 1))
        self.TN += np.sum((y_true == 0) & (y_pred == 0))

    def calculate_brier_score(self):
        """Return ``(BS, BSS)`` over all accumulated samples.

        BS is the mean squared difference between column 0 of the stored one-hot
        targets and column 0 of the stored probabilities. BSS compares BS with
        the score of always predicting the mean of that target column; it is 0
        when that reference score is 0. Both values are NaN when no samples
        have been accumulated.
        """
        if len(self.y_preds_proba) == 0:
            # Nothing accumulated: the scores are undefined.
            return float('nan'), float('nan')

        y_true = np.array(self.y_pred_onehot)
        y_pred_proba = np.array(self.y_preds_proba)

        # Brier score (column 0 only).
        BS = np.mean((y_true - y_pred_proba) ** 2, axis=0)[0]

        # Brier skill score relative to the constant mean-of-targets forecast.
        y_mean = np.mean(y_true, axis=0)[0]
        reference_bs = np.mean((y_true[:, 0] - y_mean) ** 2, axis=0)
        BSS = 1 - BS / reference_bs if reference_bs != 0 else 0

        return BS, BSS

    def calculate_metrics(self):
        """Return a dict with the confusion counts and all derived metrics.

        Ratios whose denominator is zero are reported as 0.
        """
        total = self.TP + self.FN + self.FP + self.TN

        Accuracy = (self.TP + self.TN) / total if total > 0 else 0
        Precision = self.TP / (self.TP + self.FP) if (self.TP + self.FP) > 0 else 0
        Recall = self.TP / (self.TP + self.FN) if (self.TP + self.FN) > 0 else 0
        # FAR here is FP / (FP + TP), i.e. 1 - Precision when defined.
        FAR = self.FP / (self.FP + self.TP) if (self.FP + self.TP) > 0 else 0
        # The whole expression (not just the false-positive rate) falls back to 0
        # when there are no negative samples.
        TSS = (Recall - (self.FP / (self.FP + self.TN))) if (self.FP + self.TN) > 0 else 0
        # The small epsilon keeps the division defined when the denominator is 0.
        HSS = (2 * (self.TP * self.TN - self.FP * self.FN)) / ((self.TP + self.FN) * (self.FN + self.TN) + (self.TP + self.FP) * (self.FP + self.TN) + 1e-5)

        BS, BSS = self.calculate_brier_score()

        metrics = {
            'TP': self.TP,
            'FP': self.FP,
            'FN': self.FN,
            'TN': self.TN,
            'Accuracy': Accuracy,
            'Precision': Precision,
            'Recall': Recall,
            'FAR': FAR,
            'TSS': TSS,
            'HSS': HSS,
            'Brier Score (BS)': BS,
            'Brier Skill Score (BSS)': BSS
        }
        return metrics

In [3]:
def data_preprocess(data_path, chunk_size=1000):
    """Load a CSV of image sequences and return model-ready tensors.

    The CSV is read without a header. Column 3 holds a class letter and
    columns 4..16387 hold the 128*128 pixel values of one frame. Every 40
    consecutive rows form one sample; the sample's label is the label of its
    last (40th) row.

    Each frame is resized to 224x224, divided by 4000, and copied into 3
    identical channels. Of the 40 frames, every second frame of the last 32 is
    kept (16 frames per sample).

    Args:
        data_path: path of the CSV file.
        chunk_size: number of CSV rows parsed per pandas chunk.

    Returns:
        (images, labels): a float32 tensor of shape (N, 3, 16, 224, 224) and a
        long tensor of shape (N,), with labels N/C -> 0 and M/X -> 1.

    Raises:
        ValueError: if a label outside N/C/M/X is found, or if the number of
            rows is not a multiple of 40.
    """
    frames_per_sample = 40
    label_mapping = {'N': 0, 'C': 0, 'M': 1, 'X': 1}

    labels_list = []
    images_list = []

    # Process data in chunks to save memory
    for chunk in pd.read_csv(data_path, header=None, usecols=[3] + list(range(4, 16388)), chunksize=chunk_size):
        raw_labels = chunk.iloc[:, 0]
        # Series.map leaves NaN for letters missing from the mapping, so an
        # unknown class is detected instead of failing later in a numeric cast.
        mapped_labels = raw_labels.map(label_mapping)
        if mapped_labels.isna().any():
            unknown = sorted(set(raw_labels[mapped_labels.isna()].astype(str)))
            raise ValueError(f'Unknown label(s) {unknown} in {data_path}')
        labels_chunk = mapped_labels.values.astype('float32').reshape(-1, 1)

        images_chunk = chunk.iloc[:, 1:].values.astype('float32').reshape(-1, 128, 128)

        # Resize only; channel replication and normalisation are deferred until
        # after frame selection so that far fewer frames have to be held.
        images_resized = np.array([cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR) for img in images_chunk])

        labels_list.append(labels_chunk)
        images_list.append(images_resized)

        # Clear memory
        del chunk, raw_labels, mapped_labels, labels_chunk, images_chunk, images_resized
        gc.collect()

    all_labels = np.vstack(labels_list)
    all_frames = np.vstack(images_list)  # (total_rows, 224, 224)

    # Clear memory
    del labels_list, images_list
    gc.collect()

    if all_frames.shape[0] % frames_per_sample != 0:
        raise ValueError(
            f'{data_path}: {all_frames.shape[0]} rows is not a multiple of {frames_per_sample}'
        )

    # The label of the last frame of each sample is the sample label.
    labels = all_labels[frames_per_sample - 1::frames_per_sample, ...].reshape((-1, ))

    # Keep every second frame of the last 32, then normalise.
    # NOTE(review): the meaning of the 4000 scale factor is not visible here.
    clips = all_frames.reshape((-1, frames_per_sample, 224, 224))[:, -32::2, ...]
    clips = clips / np.float32(4000.0)
    del all_frames
    gc.collect()

    # Replicate the single channel 3 times -> (N, 3, 16, 224, 224).
    images = np.stack([clips] * 3, axis=1)
    del clips
    gc.collect()

    return torch.tensor(images), torch.tensor(labels, dtype=torch.long)

In [4]:
# Seed the random number generators used in this notebook.
np.random.seed(42)
# NOTE(review): this seed is 40 while the others are 42 — confirm it is intentional.
torch.manual_seed(40)
torch.cuda.manual_seed_all(42)
# Use the first GPU when available; fall back to the CPU so the notebook does
# not crash on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Per-class weights (class 0, class 1) passed to the cross-entropy loss.
weight = torch.tensor([0.569, 4.120], dtype=torch.float32, device=device)

In [None]:
# Validation is run after every `batch_val` training batches.
batch_val = 55

# Output folders for checkpoints and loss curves.
model_folder = 'interval_32/model_BSS/MViT_Sift/'
loss_folder = 'interval_32/Loss_BSS/MViT_Sift/'
os.makedirs(model_folder, exist_ok=True)
os.makedirs(loss_folder, exist_ok=True)

for dataset_id in range(10):
    # Best validation Brier skill score seen so far for this dataset split.
    best_BSS = -100
    Train_Loss = []
    Val_Loss = []
    # NOTE(review): the training CSV path has no 'Sift/' component while the
    # validation path does — confirm this is intentional.
    train_data, train_label = data_preprocess(f'DATA/feature/group9_Data2_image/{dataset_id}Train.csv')
    val_data, val_label = data_preprocess(f'DATA/feature/group9_Data2_image/Sift/{dataset_id}Val.csv')
    torch.cuda.empty_cache()
    gc.collect()

    train_dataset = TensorDataset(train_data, train_label)
    val_dataset = TensorDataset(val_data, val_label)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=True, pin_memory=True)

    # Pretrained MViT-v2-S with its final linear layer replaced by a 2-class head.
    model_weights = torchvision.models.video.MViT_V2_S_Weights.DEFAULT
    model = torchvision.models.video.mvit_v2_s(weights=model_weights).to(device)
    model.head[1] = torch.nn.Linear(768, 2).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.15)
    # Linear warm-up of the learning rate over the first 15 epochs.
    scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.001, end_factor=1.0, total_iters=15)

    for i in range(30):
        batch_id = 0
        total_loss = 0
        model.train()

        for img, label in train_loader:
            batch_id += 1

            pred = model(img.to(device))
            loss = F.cross_entropy(pred, label.to(device), weight=weight)
            total_loss += loss.item()

            opt.zero_grad()
            loss.backward()
            opt.step()

            if batch_id % batch_val == 0:
                print(f'Epoch: {i+1}')
                print(f'Train_Loss: {total_loss / batch_val}')
                Train_Loss.append(total_loss / batch_val)
                total_loss = 0

                val_total_loss = 0
                val_batch_id = 0
                calculator = MetricsCalculator()
                model.eval()

                with torch.no_grad():
                    for val_img, val_lbl in val_loader:
                        val_batch_id += 1
                        val_pred = model(val_img.to(device))
                        val_loss = F.cross_entropy(val_pred, val_lbl.to(device), weight=weight)
                        val_total_loss += val_loss.item()
                        pred_label = torch.argmax(val_pred, dim=-1)
                        y_pred_proba = F.softmax(val_pred, dim=-1)
                        # The one-hot *true* label is the target of the Brier
                        # score; MetricsCalculator.update requires it.
                        calculator.update(
                            y_true=val_lbl.detach().cpu().numpy(),
                            y_pred=pred_label.detach().cpu().numpy(),
                            y_pred_onehot=F.one_hot(val_lbl, 2).detach().cpu().numpy(),
                            y_pred_proba=y_pred_proba.detach().cpu().numpy(),
                        )

                # Switch back to training mode: validation happens mid-epoch, and
                # without this the remaining batches would run in eval mode.
                model.train()

                val_loss_avg = val_total_loss / val_batch_id
                print(f'Val_Loss: {val_loss_avg}')
                Val_Loss.append(val_loss_avg)

                metric = calculator.calculate_metrics()
                # Keep the checkpoint with the best validation BSS.
                if best_BSS < metric['Brier Skill Score (BSS)']:
                    best_BSS = metric['Brier Skill Score (BSS)']
                    torch.save(model.state_dict(), model_folder + f'MViT{dataset_id}.pt')
                print(metric)
                print('Mean BSS: {}   Best_BSS:{}\n'.format(metric['Brier Skill Score (BSS)'], best_BSS))

                # Clear CUDA cache and collect garbage
                torch.cuda.empty_cache()
                gc.collect()

        scheduler.step()

    np.save(loss_folder + f'Train_Loss_{dataset_id}.npy', np.array(Train_Loss))
    np.save(loss_folder + f'Val_Loss_{dataset_id}.npy', np.array(Val_Loss))

    # Clear CUDA cache and collect garbage after each dataset
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
# Evaluate the saved checkpoint of every dataset split on its test CSV and
# write the resulting metrics to one JSON file per split.
metric_folder = 'interval_32/Metrics_BSS/MViT_Sift/'
os.makedirs(metric_folder, exist_ok=True)

for dataset_id in range(0, 10):
    test_data, test_label = data_preprocess(f'DATA/feature/group9_Data2_image/Sift/{dataset_id}Test.csv')
    test_dataset = TensorDataset(test_data, test_label)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    # Same architecture as in training: MViT-v2-S with a 2-class head.
    model = torchvision.models.video.mvit_v2_s().to(device)
    model.head[1] = torch.nn.Linear(768, 2).to(device)
    calculator = MetricsCalculator()
    # NOTE(review): checkpoints are read from 'interval_32/model/...' while the
    # training cell saves to 'interval_32/model_BSS/...' — confirm which
    # checkpoints are meant to be evaluated here.
    # map_location lets the checkpoint load regardless of the device it was saved from.
    param = torch.load(f'interval_32/model/MViT_Sift/MViT{dataset_id}.pt', map_location=device)
    model.load_state_dict(param)
    model.eval()

    with torch.no_grad():
        for img, label in test_loader:
            pred = model(img.to(device))
            pred_label = torch.argmax(pred, axis=-1)
            # One-hot encoding of the TRUE label: the target of the Brier score.
            true_onehot = F.one_hot(label, 2)
            pred_proba = F.softmax(pred, dim=-1)
            calculator.update(label, pred_label.detach().cpu().numpy(), true_onehot.detach().cpu().numpy(), pred_proba.detach().cpu().numpy())
    metric = calculator.calculate_metrics()
    print(metric)

    # Convert every NumPy scalar (integer or floating) to a plain Python number
    # so json.dump can serialise it.
    data_serializable = {k: v.item() if isinstance(v, np.generic) else v for k, v in metric.items()}
    with open(metric_folder + f'dataset_{dataset_id}.json', 'w', encoding='utf-8') as f:
        json.dump(data_serializable, f, ensure_ascii=True, indent=4)

In [10]:
# Aggregate the per-split metric files (including BSS) and print mean / std.
read_folder ='interval_32/Metrics_BSS/MViT_Sift/'
# Keep only regular, non-hidden files: os.listdir also returns directories
# (e.g. Jupyter's .ipynb_checkpoints), which cannot be opened as JSON.
read_file = sorted(
    name for name in os.listdir(read_folder)
    if not name.startswith('.') and os.path.isfile(os.path.join(read_folder, name))
)
TSS = []
Accuracy = []
Recall = []
FAR = []
Percision = []
HSS = []
BSS = []
for r_file in read_file:
    r_file = os.path.join(read_folder, r_file)
    with open(r_file, 'r') as f:
        data = json.load(f)
    print(data)
    TSS.append(data['TSS'])
    Accuracy.append(data['Accuracy'])
    Percision.append(data['Precision'])
    Recall.append(data['Recall'])
    FAR.append(data['FAR'])
    HSS.append(data['HSS'])
    BSS.append(data['Brier Skill Score (BSS)'])
# Number of metric files that were aggregated.
print(len(Recall))
print('Accuracy:', np.mean(Accuracy), np.std(Accuracy))
print('Percision:', np.mean(Percision), np.std(Percision))
print('Recall:', np.mean(Recall), np.std(Recall))
print('FAR:', np.mean(FAR), np.std(FAR))
print('TSS:', np.mean(TSS), np.std(TSS))
print('HSS:', np.mean(HSS), np.std(HSS))
print('BSS:', np.mean(BSS), np.std(BSS))

{'TP': 19, 'FP': 13, 'FN': 1, 'TN': 78, 'Accuracy': 0.8738738738738738, 'Precision': 0.59375, 'Recall': 0.95, 'FAR': 0.40625, 'TSS': 0.8071428571428572, 'HSS': 0.6540516459170712, 'Brier Score (BS)': 0.09620883224360378, 'Brier Skill Score (BSS)': 0.3486873505090975}
{'TP': 17, 'FP': 10, 'FN': 2, 'TN': 90, 'Accuracy': 0.8991596638655462, 'Precision': 0.6296296296296297, 'Recall': 0.8947368421052632, 'FAR': 0.37037037037037035, 'TSS': 0.7947368421052632, 'HSS': 0.6789568330059425, 'Brier Score (BS)': 0.07629941137148981, 'Brier Skill Score (BSS)': 0.43132843977280655}
{'TP': 15, 'FP': 5, 'FN': 3, 'TN': 84, 'Accuracy': 0.9252336448598131, 'Precision': 0.75, 'Recall': 0.8333333333333334, 'FAR': 0.25, 'TSS': 0.7771535580524345, 'HSS': 0.744172143621721, 'Brier Score (BS)': 0.14276238084689005, 'Brier Skill Score (BSS)': -0.020278713056207653}
{'TP': 18, 'FP': 5, 'FN': 4, 'TN': 77, 'Accuracy': 0.9134615384615384, 'Precision': 0.782608695652174, 'Recall': 0.8181818181818182, 'FAR': 0.2173913

In [11]:
# Aggregate the per-split metric files of this folder and print mean / std.
read_folder ='interval_32/Metrics/MViT_Sift1/'
# Keep only regular, non-hidden files: os.listdir also returns directories
# (e.g. Jupyter's .ipynb_checkpoints), which cannot be opened as JSON.
read_file = sorted(
    name for name in os.listdir(read_folder)
    if not name.startswith('.') and os.path.isfile(os.path.join(read_folder, name))
)
TSS = []
Accuracy = []
Recall = []
FAR = []
Percision = []
HSS = []
# Reset for consistency with the other aggregation cell; not filled here.
BSS = []
for r_file in read_file:
    r_file = os.path.join(read_folder, r_file)
    with open(r_file, 'r') as f:
        data = json.load(f)
    print(data)
    TSS.append(data['TSS'])
    Accuracy.append(data['Accuracy'])
    Percision.append(data['Precision'])
    Recall.append(data['Recall'])
    FAR.append(data['FAR'])
    HSS.append(data['HSS'])
# Number of metric files that were aggregated.
print(len(Recall))
print('Accuracy:', np.mean(Accuracy), np.std(Accuracy))
print('Percision:', np.mean(Percision), np.std(Percision))
print('Recall:', np.mean(Recall), np.std(Recall))
print('FAR:', np.mean(FAR), np.std(FAR))
print('TSS:', np.mean(TSS), np.std(TSS))
print('HSS:', np.mean(HSS), np.std(HSS))


{'TP': 19, 'FP': 13, 'FN': 1, 'TN': 78, 'Accuracy': 0.8738738738738738, 'Precision': 0.59375, 'Recall': 0.95, 'FAR': 0.40625, 'TSS': 0.8071428571428572, 'HSS': 0.6540516459170712}
{'TP': 17, 'FP': 10, 'FN': 2, 'TN': 90, 'Accuracy': 0.8991596638655462, 'Precision': 0.6296296296296297, 'Recall': 0.8947368421052632, 'FAR': 0.37037037037037035, 'TSS': 0.7947368421052632, 'HSS': 0.6789568330059425}
{'TP': 15, 'FP': 5, 'FN': 3, 'TN': 84, 'Accuracy': 0.9252336448598131, 'Precision': 0.75, 'Recall': 0.8333333333333334, 'FAR': 0.25, 'TSS': 0.7771535580524345, 'HSS': 0.744172143621721}
{'TP': 18, 'FP': 5, 'FN': 4, 'TN': 77, 'Accuracy': 0.9134615384615384, 'Precision': 0.782608695652174, 'Recall': 0.8181818181818182, 'FAR': 0.21739130434782608, 'TSS': 0.7572062084257207, 'HSS': 0.7448200634001634}
{'TP': 14, 'FP': 5, 'FN': 3, 'TN': 81, 'Accuracy': 0.9223300970873787, 'Precision': 0.7368421052631579, 'Recall': 0.8235294117647058, 'FAR': 0.2631578947368421, 'TSS': 0.7653898768809849, 'HSS': 0.73089

In [7]:
# Collect the TSS of every per-split metric file and show count, mean and std.
read_folder = 'interval_32/Metrics/MViT_Sift/'
# Keep only regular, non-hidden files: os.listdir also returns directories
# (e.g. Jupyter's .ipynb_checkpoints), which cannot be opened as JSON.
read_file = sorted(
    name for name in os.listdir(read_folder)
    if not name.startswith('.') and os.path.isfile(os.path.join(read_folder, name))
)
TSS = []
for r_file in read_file:
    r_file = os.path.join(read_folder, r_file)
    with open(r_file, 'r') as f:
        data = json.load(f)
    print(data)
    TSS.append(data['TSS'])
len(TSS), np.mean(TSS), np.std(TSS)

{'TP': 19, 'FP': 13, 'FN': 1, 'TN': 78, 'Accuracy': 0.8738738738738738, 'Precision': 0.59375, 'Recall': 0.95, 'FAR': 0.14285714285714285, 'TSS': 0.8071428571428572, 'HSS': 0.6540516459170712}
{'TP': 17, 'FP': 10, 'FN': 2, 'TN': 90, 'Accuracy': 0.8991596638655462, 'Precision': 0.6296296296296297, 'Recall': 0.8947368421052632, 'FAR': 0.1, 'TSS': 0.7947368421052632, 'HSS': 0.6789568330059425}
{'TP': 15, 'FP': 5, 'FN': 3, 'TN': 84, 'Accuracy': 0.9252336448598131, 'Precision': 0.75, 'Recall': 0.8333333333333334, 'FAR': 0.056179775280898875, 'TSS': 0.7771535580524345, 'HSS': 0.744172143621721}
{'TP': 18, 'FP': 5, 'FN': 4, 'TN': 77, 'Accuracy': 0.9134615384615384, 'Precision': 0.782608695652174, 'Recall': 0.8181818181818182, 'FAR': 0.06097560975609756, 'TSS': 0.7572062084257207, 'HSS': 0.7448200634001634}
{'TP': 14, 'FP': 5, 'FN': 3, 'TN': 81, 'Accuracy': 0.9223300970873787, 'Precision': 0.7368421052631579, 'Recall': 0.8235294117647058, 'FAR': 0.05813953488372093, 'TSS': 0.7653898768809849, '

(10, 0.7748656458969915, 0.055037550694810596)