In [None]:
import numpy as np
from sklearn.model_selection import KFold
import random
import os
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets,transforms,models
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset

from sklearn.metrics import confusion_matrix
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc, roc_auc_score
from sklearn import metrics
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve

from cnn_finetune import make_model
from sklearn.preprocessing import label_binarize
from itertools import cycle
import seaborn as sns

from sklearn.metrics import auc

from efficientnet_pytorch import EfficientNet

import shutil  


# Expose only the GPU with index 1 to CUDA for this process
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
# Select the compute device: the first visible CUDA GPU when available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the pretrained EfficientNet-b0 with a 2-class output
model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=2)

# Fine-tuning: mark every parameter as trainable
for param in model.parameters():
    param.requires_grad = True

# Move the network onto the selected device
model = model.to(device)

In [None]:
# Dataset setup

# Preprocessing / augmentation applied to the training images
transform = transforms.Compose([
    # Cut a random 224x224 patch out of the image (RandomCrop does not resize)
    transforms.RandomCrop((224, 224)),
    # Randomly flip the image horizontally
    transforms.RandomHorizontalFlip(),
    # Randomly perturb brightness and contrast
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    # Convert to grayscale while keeping 3 output channels
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    # Normalize the pixel values channel-wise
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Preprocessing applied to the test images.
# A deterministic centre crop is used instead of a random crop so that every
# evaluation pass sees exactly the same pixels and the reported metrics are
# reproducible.
# NOTE(review): CenterCrop pads images smaller than 224x224 instead of raising
# like RandomCrop does - confirm all test images are at least 224x224.
transform_2 = transforms.Compose([
    transforms.CenterCrop((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Batch size (powers of two are customary)
batch_size = 32

train_dataset = torchvision.datasets.ImageFolder(
    root='/home/yamaguchi/最終/Ki67値クラス分類データセット/train',
    transform=transform,
)
test_dataset = torchvision.datasets.ImageFolder(
    root='/home/yamaguchi/最終/Ki67値クラス分類データセット/test',
    transform=transform_2,
)

# The test loader must not shuffle: test() maps batch positions back to file paths
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False)

In [None]:
# Training function definition

def train(device, model, optimizer, criterion, cv_train_dataloader, cv_valid_dataloader,
          num_epochs=100, patience=10):
    """Train ``model`` with early stopping driven by the validation loss.

    After every epoch the validation loss is compared with the loss of the
    previous epoch. Training stops once the loss has increased for ``patience``
    consecutive epochs, or after ``num_epochs`` epochs, whichever comes first.

    Args:
        device: torch device each batch is moved to.
        model: network to optimise; it is updated in place.
        optimizer: optimizer already bound to ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        cv_train_dataloader: iterable of ``(inputs, labels)`` training batches.
        cv_valid_dataloader: iterable of ``(inputs, labels)`` validation batches,
            forwarded to ``validation``.
        num_epochs: maximum number of epochs (default 100).
        patience: number of consecutive epochs with a rising validation loss
            that triggers early stopping (default 10).

    Returns:
        The same ``model`` object, in the state of the last executed epoch.
    """
    # Start from +inf so the very first validation loss can never be counted
    # as a regression, whatever its magnitude.
    the_last_loss = float('inf')
    trigger_times = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_num = 0
        total_num = 0
        batch_count = 0

        for data, target in cv_train_dataloader:
            inputs, labels = data.to(device), target.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Bookkeeping for the per-epoch report (detached: no gradient needed)
            predicted = torch.max(outputs.detach(), 1)[1]
            correct_num += (predicted == labels).sum().item()
            total_num += data.shape[0]
            running_loss += loss.item()
            batch_count += 1

        # max(..., 1) keeps the report from dividing by zero on an empty loader
        print('epoch:%d loss: %.3f acc: %.3f' %
              (epoch + 1, running_loss / max(batch_count, 1), correct_num * 100 / max(total_num, 1)))

        # Early stopping
        the_current_loss = validation(model, device, cv_valid_dataloader, criterion)
        print('The current loss:', the_current_loss)

        if the_current_loss > the_last_loss:
            trigger_times += 1
            print('trigger times:', trigger_times)

            if trigger_times >= patience:
                print('Early stopping!\nStart to test process.')
                return model

        else:
            print('trigger times: 0')
            trigger_times = 0

        the_last_loss = the_current_loss

    return model

In [None]:
# Validation function definition

def validation(model, device, cv_valid_dataloader, criterion):
    """Return the mean per-batch loss of ``model`` over ``cv_valid_dataloader``.

    The model is switched to evaluation mode and gradients are disabled while
    the batches are processed. Each batch contributes equally to the mean,
    regardless of its size.

    Args:
        model: network to evaluate.
        device: torch device each batch is moved to.
        cv_valid_dataloader: sized iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.

    Returns:
        Sum of the batch losses divided by ``len(cv_valid_dataloader)``.
    """
    model.eval()
    loss_total = 0

    with torch.no_grad():
        for batch_inputs, batch_targets in cv_valid_dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            batch_loss = criterion(model(batch_inputs), batch_targets)
            loss_total += batch_loss.item()

    num_batches = len(cv_valid_dataloader)
    return loss_total / num_batches

In [None]:
# Test function definition

def test(device, model, test_dataloader, base_destination_folder="/home/yamaguchi/特定フォルダ"):
    """Evaluate ``model`` on ``test_dataloader``, report metrics and export misclassified images.

    Prints the confusion matrix, accuracy, precision, recall, F1 and ROC AUC
    (class index 1 is the positive class), draws the ROC curve and a
    confusion-matrix heatmap, and copies every misclassified image into
    ``base_destination_folder/<true class>/<predicted class>/``.

    Args:
        device: torch device each batch is moved to before the forward pass.
        model: classifier whose output has one score column per class
            (column 1 is used as the positive-class score).
        test_dataloader: loader yielding ``(images, labels)`` batches in
            dataset order (it must not shuffle). When a sample is misclassified
            its path is read from ``test_dataloader.dataset.samples[k][0]``.
        base_destination_folder: root directory that receives the copies of the
            misclassified images.

    Returns:
        None.
    """
    model.eval()

    predicts_list = []
    labels_list = []
    scores_list = []

    # Bookkeeping used to locate and export the misclassified images
    wrong_predictions = []
    wrong_predicted_labels = []
    true_labels_for_wrong_predictions = []

    # Built once instead of once per batch
    softmax = nn.Softmax(dim=1)
    # Position of the current batch's first sample inside the dataset
    sample_offset = 0

    with torch.no_grad():
        for data, target in test_dataloader:
            inputs, labels = data.to(device), target.to(device)

            outputs = model(inputs)
            probs = softmax(outputs)
            _, predicted = torch.max(outputs, 1)

            wrong_indices = (predicted != labels).nonzero(as_tuple=True)[0].tolist()
            for idx in wrong_indices:
                wrong_file_path = test_dataloader.dataset.samples[sample_offset + idx][0]
                wrong_predictions.append(wrong_file_path)
                # Label the model predicted
                wrong_predicted_labels.append(predicted[idx].item())
                # Ground-truth label
                true_labels_for_wrong_predictions.append(labels[idx].item())
            sample_offset += data.shape[0]

            labels_list.append(labels.cpu().numpy())
            predicts_list.append(predicted.cpu().numpy())
            scores_list.append(probs.cpu().numpy())

    labels = np.concatenate(labels_list)
    predicted = np.concatenate(predicts_list)
    scores = np.concatenate(scores_list)

    fpr, tpr, _ = roc_curve(labels, scores[:, 1])
    roc_auc = auc(fpr, tpr)

    C = confusion_matrix(labels, predicted)
    ac = accuracy_score(labels, predicted)
    pre = precision_score(labels, predicted, pos_label=1)
    re = recall_score(labels, predicted, pos_label=1)
    f1 = f1_score(labels, predicted, pos_label=1)

    AUC = roc_auc_score(labels, scores[:, 1])

    print(C)
    print("\n")
    print(f"test accuracy : {ac:.3f}")
    print(f"test precision : {pre:.3f}")
    print(f"test recall : {re:.3f}")
    print(f"test F1 : {f1:.3f}")
    print(f"AUC : {AUC:.3f}")

    print("AUC for the positive class: {:.3f}".format(roc_auc))

    # ROC curve, drawn on its own fresh figure
    plt.figure()
    plt.plot(fpr, tpr,
             label='ROC curve (area = {:.3g})'.format(roc_auc),
             color='navy', linestyle=':', linewidth=4)
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")

    # Class names used for the heatmap ticks and for the export folders.
    # NOTE(review): this assumes label index 0 is 'Low' and index 1 is 'High';
    # verify against the dataset's class-to-index mapping.
    class_labels = ['Low', 'High']

    # Confusion-matrix heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(C, annot=True, cmap='Blues', fmt='g', xticklabels=class_labels, yticklabels=class_labels, annot_kws={'size': 20})
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()

    # Export the misclassified images, laid out as
    # base_destination_folder/<true label>/<predicted label>/<file name>
    os.makedirs(base_destination_folder, exist_ok=True)

    for file_path, true_label, wrong_label in zip(wrong_predictions, true_labels_for_wrong_predictions, wrong_predicted_labels):
        destination_folder = os.path.join(base_destination_folder, class_labels[true_label], class_labels[wrong_label])
        os.makedirs(destination_folder, exist_ok=True)

        # Copy the image into that folder, keeping its file name
        dest_file_path = os.path.join(destination_folder, os.path.basename(file_path))
        shutil.copy(file_path, dest_file_path)


In [None]:
# k-fold cross-validation (10 folds).
# Shuffling is enabled because an unshuffled KFold hands out contiguous index
# blocks, and an ImageFolder dataset lists its samples class by class; without
# shuffling a validation fold can be heavily skewed towards a single class.
# The fixed seed keeps the fold assignment reproducible.
kf = KFold(n_splits=10, shuffle=True, random_state=42)

In [None]:
# main function
def main():
    """Run the k-fold cross-validation: train, save and test one model per fold.

    Uses the module-level ``kf`` splitter, ``train_dataset`` and
    ``test_dataloader``. For every fold a fresh pretrained EfficientNet-b0 is
    created, trained on the fold's training indices with early stopping on the
    fold's validation indices, saved to ``sample_<fold>.pt`` and evaluated on
    the test set.
    """
    for _fold, (train_index, valid_index) in enumerate(kf.split(np.arange(len(train_dataset)))):

        # Re-create the model so that no weights carry over between folds
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=2)

        for param in model.parameters():
            param.requires_grad = True
        model = model.to(device)

        batch_size = 32
        # Loss function
        criterion = nn.CrossEntropyLoss()
        # Optimizer
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

        # Split the training set for this fold. The loaders must wrap the
        # Subsets, not the full dataset; otherwise every fold trains and
        # validates on all samples and the validation loss is meaningless.
        # NOTE(review): both subsets share train_dataset.transform, so the
        # validation images also go through the random training augmentation.
        cv_train_dataset = Subset(train_dataset, train_index)
        cv_train_dataloader = DataLoader(cv_train_dataset, batch_size, shuffle=True)
        cv_valid_dataset = Subset(train_dataset, valid_index)
        cv_valid_dataloader = DataLoader(cv_valid_dataset, batch_size, shuffle=False)

        print('Fold {}------------------------------------------------------------------------------'.format(_fold+1))

        model = train(device, model, optimizer, criterion, cv_train_dataloader, cv_valid_dataloader)
        # Save the trained parameters of this fold
        torch.save(model.state_dict(), 'sample_' + str(_fold) + '.pt')
        # Evaluate this fold's model on the held-out test set
        test(device, model, test_dataloader)


if __name__ == '__main__':
    main()