In [None]:
import os
import numpy as np
from PIL import Image
import pandas as pd

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, roc_curve, auc


from alex import alex_test
from resnet import resnet_model
from vgg import vgg


In [None]:
class CustomDataset(Dataset):
    """Image dataset laid out as one sub-directory of ``root_dir`` per class.

    Every regular file inside ``root_dir/<class_name>`` is treated as one sample.
    Its integer label is the position of ``class_name`` in ``classes``.

    Args:
        root_dir: directory containing one sub-directory per class name.
        transform: optional callable applied to each loaded PIL image.
        classes: optional ordered sequence of class directory names. Defaults
            to ``CustomDataset.DEFAULT_CLASSES``, so existing callers keep the
            same label mapping.

    Raises:
        FileNotFoundError: (from ``os.listdir``) if a class directory is missing.
    """

    # Label i corresponds to DEFAULT_CLASSES[i]; the order must match the
    # order used when the backbone checkpoints were trained.
    DEFAULT_CLASSES = ['+14.51', '-10.5', '10.51-11.5', '11.51-12.5', '12.51-13.5', '13.51-14.5']

    def __init__(self, root_dir, transform=None, classes=None):
        self.root_dir = root_dir
        self.transform = transform

        self.classes = list(self.DEFAULT_CLASSES if classes is None else classes)
        self.images = []  # file paths, parallel to self.labels
        self.labels = []  # integer class index per image

        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir order is filesystem-dependent; sort for a reproducible sample order.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # Skip nested directories etc. -- they cannot be opened as images.
                if not os.path.isfile(img_path):
                    continue
                self.images.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Number of samples found under ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and then passed through ``transform`` if one was given.
        """
        img_path = self.images[idx]
        label = self.labels[idx]
        # Close the underlying file deterministically; convert() returns a new image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

def extract_features(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader``; collect its outputs and the labels.

    The model is switched to eval mode for the pass, so BatchNorm uses its
    running statistics (and does not update them) and Dropout is disabled.
    Its previous train/eval mode is restored afterwards, even on error.

    Args:
        model: ``torch.nn.Module`` used as a feature extractor.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(features, labels)`` as numpy arrays:
        - ``features`` are the per-batch outputs joined with ``np.vstack``.
        - ``labels`` are the per-batch labels joined with ``np.hstack``, in iteration order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()

    feature_batches = []
    label_batches = []
    try:
        with torch.no_grad():
            for imgs, lbls in dataloader:
                outputs = model(imgs.to(device))
                feature_batches.append(outputs.cpu().numpy())
                # Labels are never fed to the model, so there is no need to move them to `device`.
                label_batches.append(torch.as_tensor(lbls).cpu().numpy())
    finally:
        model.train(was_training)

    if not feature_batches:
        raise ValueError("extract_features(): dataloader yielded no batches")

    return np.vstack(feature_batches), np.hstack(label_batches)

def combine_features(models, dataloader, device):
    """Concatenate the features of several models, computed on the same images.

    The dataloader is iterated exactly ONCE, and every model is applied to each
    batch. This keeps the feature columns of all models and the labels aligned
    row-by-row even when the dataloader shuffles. (Iterating once per model
    would give each model a different permutation.)

    Every model runs in eval mode under ``torch.no_grad()``; each model's previous
    train/eval mode is restored afterwards.

    Args:
        models: non-empty sequence of ``torch.nn.Module`` feature extractors.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        device: device the inputs are moved to before the forward passes.

    Returns:
        ``(features, labels)`` as numpy arrays:
        - ``features`` has shape ``(N, sum of flattened per-model feature sizes)``.
        - ``labels`` has shape ``(N, 1)``.

    Raises:
        ValueError: if ``models`` is empty or ``dataloader`` yields no batches.
    """
    models = list(models)
    if not models:
        raise ValueError("combine_features() needs at least one model")

    previous_modes = [m.training for m in models]
    for m in models:
        m.eval()

    feature_batches = []
    label_batches = []
    try:
        with torch.no_grad():
            for imgs, lbls in dataloader:
                imgs = imgs.to(device)
                # Flatten each model's output to (batch, -1) so they can be joined column-wise.
                per_model = [m(imgs).reshape(imgs.shape[0], -1).cpu().numpy() for m in models]
                feature_batches.append(np.hstack(per_model))
                label_batches.append(torch.as_tensor(lbls).cpu().numpy())
    finally:
        for m, mode in zip(models, previous_modes):
            m.train(mode)

    if not feature_batches:
        raise ValueError("combine_features(): dataloader yielded no batches")

    features = np.vstack(feature_batches)
    labels = np.expand_dims(np.concatenate(label_batches), axis=1)
    print(features.shape)
    print(labels.shape)
    return features, labels


In [None]:
def _load_checkpoint(model, weights_path, device):
    """Load the state dict stored at ``weights_path`` into ``model`` and return ``model``.

    ``map_location=device`` remaps tensors saved on another device (e.g. a GPU)
    onto ``device``, so CUDA-trained checkpoints also load on a CPU-only machine.
    """
    model.load_state_dict(torch.load(weights_path, map_location=device))
    return model


def main():
    """Stack CNN backbone features and classify them with a random forest.

    Steps:
    1. Build the train/val datasets.
    2. Load the pretrained backbones with their heads removed.
    3. Extract and concatenate their features.
    4. Fit a RandomForestClassifier on the training features.
    5. Save a confusion-matrix heatmap and print a classification report for
       the validation set.
    """
    # Raw strings: these Windows paths contain backslashes that must not be
    # interpreted as escape sequences.
    data_root = r'E:\clongz\stacking_data\xiaorongshiyan'
    project_dir = r'E:\onedrive\OneDrive - cumt.edu.cn\Python\stacking'

    data_transform = transforms.Compose([
        transforms.CenterCrop((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_data = CustomDataset(root_dir=os.path.join(data_root, 'train'), transform=data_transform)
    val_data = CustomDataset(root_dir=os.path.join(data_root, 'val'), transform=data_transform)

    # Single source of truth for class names / label indices.
    target_names = list(train_data.classes)
    num_classes = len(target_names)
    class_ids = list(range(num_classes))

    train_batch_size = 128
    val_batch_size = 32
    # os.cpu_count() may return None; cap the number of workers at 8.
    # NOTE(review): on Windows, worker processes are spawned and must be able to
    # import CustomDataset; classes defined inside a notebook may fail to
    # unpickle there. Set nw = 0 if the workers crash or hang.
    nw = min(os.cpu_count() or 1, 8)
    print('Using {} dataloader workers every process'.format(nw))
    # Feature extraction only: no shuffling is needed, and a fixed order keeps
    # the features and labels deterministic and aligned.
    train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False, num_workers=nw)
    val_loader = DataLoader(val_data, batch_size=val_batch_size, shuffle=False, num_workers=nw)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    resnet = resnet_model.resnet34(num_classes=num_classes, include_top=True)
    _load_checkpoint(resnet, os.path.join(project_dir, 'resnet', 'resnet34.pth'), device)
    # Drop the last child module (presumably the fc head) to use ResNet as a feature extractor.
    # eval(): inference must not use batch statistics or update BatchNorm running stats.
    resnet = torch.nn.Sequential(*list(resnet.children())[:-1]).to(device).eval()

    # NOTE(review): VGG16 is loaded but is not part of the current ensemble;
    # add it to `models` below to use it.
    vgg_model = vgg.vgg('vgg16', num_classes=num_classes, init_weights=True)
    _load_checkpoint(vgg_model, os.path.join(project_dir, 'vgg', 'vgg_first.pth'), device)
    vgg_model.classifier = torch.nn.Identity()
    vgg_model.to(device).eval()

    alexnet_model = alex_test.AlexNet()
    _load_checkpoint(alexnet_model, os.path.join(project_dir, 'alex', 'alexnet.pth'), device)
    alexnet_model.classifier = torch.nn.Identity()
    alexnet_model.to(device).eval()

    models = [alexnet_model, resnet]

    train_features, train_labels = combine_features(models, train_loader, device)
    val_features, val_labels = combine_features(models, val_loader, device)
    # combine_features returns labels as a column vector; scikit-learn expects 1-D targets.
    train_labels = np.ravel(train_labels)
    val_labels = np.ravel(val_labels)

    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(train_features, train_labels)
    val_preds = clf.predict(val_features)

    # Explicit labels keep the matrix num_classes x num_classes (and in sync with
    # the tick labels) even if some class is absent from val_labels/val_preds.
    cm = confusion_matrix(val_labels, val_preds, labels=class_ids)

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt="d", cmap="YlGnBu", ax=ax)
    tick_positions = np.arange(num_classes) + 0.5  # centre of each heatmap cell
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(target_names)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(target_names, rotation=0)
    ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    ax.set_xlabel("Predicted Ash")
    ax.set_ylabel("True Ash")
    ax.set_title("Confusion Matrix")
    fig.savefig(os.path.join(project_dir, 'confusion_matrix.png'), dpi=300)

    report = classification_report(val_labels, val_preds, labels=class_ids,
                                   target_names=target_names, output_dict=True)
    print(report)

if '__main__' == __name__:
    # Run the whole pipeline when executed as a script; inside a Jupyter
    # kernel __name__ is also '__main__', so running this cell runs it too.
    main()
