In [1]:
import torch
import pandas as pd
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import os
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from collections import Counter
from transformers import MaskFormerFeatureExtractor,MaskFormerForInstanceSegmentation


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "use_huggingface/final_checkpoint"
# feature_extractor = SegformerImageProcessor.from_pretrained(model_name)
# model = SegformerForSemanticSegmentation.from_pretrained(model_name).to(device)

feature_extractor = MaskFormerFeatureExtractor.from_pretrained(model_name)
# inputs = feature_extractor(images=image, return_tensors="pt")
model = MaskFormerForInstanceSegmentation.from_pretrained(model_name).to(device)

label_df = pd.read_csv("test_label.csv")


Backbone maskformer-swin is not a supported model and may not be compatible with MaskFormer. Supported model types: resnet,swin


In [3]:
def find_mode(lst):
    counter = Counter(lst)
    max_count = max(counter.values())
    mode_values = [item for item, count in counter.items() if count == max_count]
    return mode_values

In [4]:
def get_pred(img_paths):
    images = [Image.open(img_path) for img_path in img_paths]
    # print(len(images[0].getbands()))

    batch_size = 4
    results = {"max0_1":0,"max0_2":0,"max0_3":0,"max1_1":0,"max1_2":0,"max1_3":0}

    # for i in tqdm(range(0, len(images), batch_size),total=int(len(images)//batch_size)):
    for i in range(0, len(images), batch_size):
        batch_paths = images[i:i+batch_size]
        # images[0].show()
        inputs = feature_extractor(images=batch_paths, return_tensors="pt").to(device)
        outputs = model(**inputs)
        # logits = outputs.logits
        target_sizes = [(image.size[0], image.size[1]) for image in batch_paths]

        pred_segs = feature_extractor.post_process_semantic_segmentation(outputs,target_sizes=target_sizes)

        # upsampled_logits = F.interpolate(
        #     logits,
        #     size=images[0].size[::-1],  # Assuming all images are of the same size
        #     mode='bilinear',
        #     align_corners=False
        # )

        # pred_segs = upsampled_logits.argmax(dim=1)
        # break
        for pred_seg in pred_segs:
            flat_tensor = pred_seg.view(-1)
            
            # Count the occurrences of each class
            unique, counts = torch.unique(flat_tensor, return_counts=True)
            counts_dict = dict(zip(unique.cpu().numpy(), counts.cpu().numpy()))
            
            # Get the counts of classes 0, 1, and 2
            zeros = counts_dict.get(0, 0)
            ones = counts_dict.get(1, 0)
            twos = counts_dict.get(2, 0)
            # print(counts_dict)
            
            # Determine the result based on counts of classes 1 and 2
            # if ones > 0 or twos > 0:
            #     if ones > twos:
            #         results.append(0)
            #     else:
            #         results.append(1)
            if ones > results["max0_1"]:
                results["max0_3"] = results["max0_2"]
                results["max0_2"] = results["max0_1"]
                results["max0_1"] = ones
            elif ones > results["max0_2"] and ones != results["max0_1"]:
                results["max0_3"] = results["max0_2"]
                results["max0_2"] = ones
            elif ones > results["max0_3"] and ones != results["max0_1"] and ones != results["max0_2"]:
                results["max0_3"] = ones

            if twos > results["max1_1"]:
                results["max1_3"] = results["max1_2"]
                results["max1_2"] = results["max1_1"]
                results["max1_1"] = twos
            elif twos > results["max1_2"] and twos != results["max1_1"]:
                results["max1_3"] = results["max1_2"]
                results["max1_2"] = twos
            elif twos > results["max1_3"] and twos != results["max1_1"] and twos != results["max1_2"]:
                results["max1_3"] = twos
        # break
    return results

In [6]:
# print(img_paths)
_image_name = "9v"  
img_paths = os.listdir(os.path.join("test_data_origin",_image_name))
img_paths = [os.path.join("test_data_origin",_image_name,img_name) for img_name in img_paths]
images = [Image.open(img_path) for img_path in img_paths]
print(len(images[0].getbands()))

batch_size = 4
results = {"max0_1":0,"max0_2":0,"max0_3":0,"max1_1":0,"max1_2":0,"max1_3":0}

for i in tqdm(range(0, len(images), batch_size),total=int(len(images)//batch_size)):
    batch_paths = images[i:i+batch_size]
    # images[0].show()
    inputs = feature_extractor(images=batch_paths, return_tensors="pt").to(device)
    outputs = model(**inputs)
    # logits = outputs.logits

    # upsampled_logits = F.interpolate(
    #     logits,
    #     size=images[0].size[::-1],  # Assuming all images are of the same size
    #     mode='bilinear',
    #     align_corners=False
    # )

    # pred_segs = upsampled_logits.argmax(dim=1)
    # class_queries_logits = outputs.class_queries_logits
    # masks_queries_logits = outputs.masks_queries_logits

    target_sizes = [(image.size[0], image.size[1]) for image in batch_paths]

    pred_segs = feature_extractor.post_process_semantic_segmentation(outputs,target_sizes=target_sizes)

    # break
    for pred_seg in pred_segs:
        flat_tensor = pred_seg.view(-1)
        
        # Count the occurrences of each class
        unique, counts = torch.unique(flat_tensor, return_counts=True)
        counts_dict = dict(zip(unique.cpu().numpy(), counts.cpu().numpy()))
        
        # Get the counts of classes 0, 1, and 2
        zeros = counts_dict.get(0, 0)
        ones = counts_dict.get(1, 0)
        twos = counts_dict.get(2, 0)
        print(counts_dict)
        
        if ones > results["max0_1"]:
            results["max0_3"] = results["max0_2"]
            results["max0_2"] = results["max0_1"]
            results["max0_1"] = ones
        elif ones > results["max0_2"] and ones != results["max0_1"]:
            results["max0_3"] = results["max0_2"]
            results["max0_2"] = ones
        elif ones > results["max0_3"] and ones != results["max0_1"] and ones != results["max0_2"]:
            results["max0_3"] = ones

        if twos > results["max1_1"]:
            results["max1_3"] = results["max1_2"]
            results["max1_2"] = results["max1_1"]
            results["max1_1"] = twos
        elif twos > results["max1_2"] and twos != results["max1_1"]:
            results["max1_3"] = results["max1_2"]
            results["max1_2"] = twos
        elif twos > results["max1_3"] and twos != results["max1_1"] and twos != results["max1_2"]:
            results["max1_3"] = twos
    # break
results

3


  0%|          | 0/17 [00:00<?, ?it/s]

  6%|▌         | 1/17 [00:00<00:04,  3.59it/s]

{0: 262144}
{0: 261973, 2: 171}
{0: 262144}
{0: 261594, 1: 550}


 12%|█▏        | 2/17 [00:00<00:03,  3.94it/s]

{0: 261168, 1: 669, 2: 307}
{0: 261736, 1: 402, 2: 6}
{0: 261267, 1: 179, 2: 698}
{0: 261709, 1: 435}


 18%|█▊        | 3/17 [00:00<00:03,  4.13it/s]

{0: 260693, 1: 1238, 2: 213}
{0: 261801, 1: 341, 2: 2}
{0: 260794, 1: 102, 2: 1248}
{0: 262123, 1: 5, 2: 16}


 24%|██▎       | 4/17 [00:00<00:03,  4.23it/s]

{0: 258019, 1: 3464, 2: 661}
{0: 260745, 1: 1393, 2: 6}
{0: 259870, 1: 695, 2: 1579}
{0: 261068, 1: 1043, 2: 33}


 29%|██▉       | 5/17 [00:01<00:02,  4.24it/s]

{0: 261379, 1: 765}
{0: 261655, 1: 489}
{0: 261713, 1: 430, 2: 1}
{0: 262144}


 35%|███▌      | 6/17 [00:01<00:02,  4.22it/s]

{0: 261479, 1: 659, 2: 6}
{0: 261114, 1: 1027, 2: 3}
{0: 261653, 1: 481, 2: 10}
{0: 262019, 1: 125}


 41%|████      | 7/17 [00:01<00:02,  4.29it/s]

{0: 261255, 1: 672, 2: 217}
{0: 261616, 1: 476, 2: 52}
{0: 260663, 1: 1226, 2: 255}
{0: 260621, 1: 964, 2: 559}


 47%|████▋     | 8/17 [00:01<00:02,  4.33it/s]

{0: 260977, 1: 1143, 2: 24}
{0: 259149, 1: 2165, 2: 830}
{0: 262144}
{0: 259777, 1: 1757, 2: 610}


 53%|█████▎    | 9/17 [00:02<00:01,  4.36it/s]

{0: 261719, 1: 239, 2: 186}
{0: 260832, 1: 1173, 2: 139}
{0: 260305, 1: 1641, 2: 198}
{0: 261762, 1: 189, 2: 193}


 59%|█████▉    | 10/17 [00:02<00:01,  4.38it/s]

{0: 260499, 1: 1531, 2: 114}
{0: 260696, 1: 1297, 2: 151}
{0: 260226, 1: 1696, 2: 222}
{0: 256620, 1: 3013, 2: 2511}


 65%|██████▍   | 11/17 [00:02<00:01,  4.38it/s]

{0: 256488, 1: 3261, 2: 2395}
{0: 262004, 1: 140}
{0: 256938, 1: 3012, 2: 2194}
{0: 257478, 1: 3065, 2: 1601}


 71%|███████   | 12/17 [00:02<00:01,  4.39it/s]

{0: 257984, 1: 2773, 2: 1387}
{0: 261179, 1: 521, 2: 444}
{0: 261311, 1: 493, 2: 340}
{0: 260725, 1: 1020, 2: 399}


 76%|███████▋  | 13/17 [00:03<00:00,  4.38it/s]

{0: 261255, 1: 490, 2: 399}
{0: 260927, 1: 650, 2: 567}
{0: 261724, 1: 342, 2: 78}
{0: 261241, 1: 903}


 82%|████████▏ | 14/17 [00:03<00:00,  4.40it/s]

{0: 260600, 1: 1339, 2: 205}
{0: 261104, 1: 1040}
{0: 260528, 1: 1599, 2: 17}
{0: 260322, 1: 1654, 2: 168}


 88%|████████▊ | 15/17 [00:03<00:00,  4.41it/s]

{0: 260005, 1: 1967, 2: 172}
{0: 259180, 1: 2312, 2: 652}
{0: 261935, 1: 122, 2: 87}
{0: 262136, 1: 8}


 94%|█████████▍| 16/17 [00:03<00:00,  4.40it/s]

{0: 262144}
{0: 262144}
{0: 260256, 1: 882, 2: 1006}
{0: 260092, 1: 2007, 2: 45}


18it [00:04,  4.38it/s]                        

{0: 261917, 1: 29, 2: 198}
{0: 260833, 1: 1024, 2: 287}
{0: 261567, 1: 558, 2: 19}
{0: 260218, 1: 246, 2: 1680}
{0: 260407, 1: 1612, 2: 125}
{0: 261353, 1: 791}
{0: 261424, 1: 720}





{'max0_1': 3464,
 'max0_2': 3261,
 'max0_3': 3065,
 'max1_1': 2511,
 'max1_2': 2395,
 'max1_3': 2194}

In [None]:
# img_list = os.listdir('res_data/origin')
# img_paths = [os.path.join('res_data/origin', img) for img in img_list]
# print(f"Total images: {len(img_paths)}")

# # Example of processing in batches
# # batch_size = 16
# predictions = []
# labels = []
# for i in tqdm(range(0, len(img_paths), batch_size),total=int(len(img_paths)//batch_size)):
#     batch_paths = img_paths[i:i+batch_size]
#     # import pdb
#     # pdb.set_trace()
#     batch_labels = []
#     for image_name in batch_paths:
#         img_name = os.path.basename(image_name)
#         img_index = int(img_name.split(".")[0].split("_")[0].replace("V", ""))
#         category_name = int(label_df[label_df["number"] == img_index]["label"].values[0])
#         batch_labels.append(category_name)
#     predictions.extend(get_pred(batch_paths))
#     labels.extend(batch_labels)

In [5]:
img_list = os.listdir("test_data_origin")
# pred_list = []
# label_list = []
# max0_list = []
# max1_list = []
label_list = []
res_list = []
for image_name in tqdm(img_list,total=len(img_list)):
# for image_name in img_list:
    img_paths = os.listdir(os.path.join("test_data_origin",image_name))
    img_paths = [os.path.join("test_data_origin",image_name,img_path) for img_path in img_paths]
    # print(img_paths)
    res = get_pred(img_paths)
    # print(res)
    # print(label_list)
    # pred_list.extend(res)
    # print(image_name)
    img_index = int(image_name.replace("V", "").replace("v",""))
    category_name = int(label_df[label_df["number"] == img_index]["label"].values[0])
    # print(category_name)
    # label_list.append(category_name)
    # print('---')
    # max0_list.append(res["max0"])
    # max1_list.append(res["max1"])
    res_list.append(res)
    label_list.append(category_name)

100%|██████████| 285/285 [15:41<00:00,  3.30s/it]


In [6]:
import pandas as pd
# res_list
data_df = pd.DataFrame(res_list)
data_df['label'] = label_list
# data_df = pd.DataFrame({"max0":max0_list,"max1":max1_list,"label":label_list})
data_df

Unnamed: 0,max0_1,max0_2,max0_3,max1_1,max1_2,max1_3,label
0,3928,3645,2973,3631,2269,2148,1
1,3408,2429,2328,3135,2421,1838,0
2,3706,3539,3025,1622,1488,1264,0
3,3253,3195,3079,4956,3449,3279,0
4,3262,3240,2550,4832,4299,3382,0
...,...,...,...,...,...,...,...
280,3564,3080,2797,1314,1076,948,0
281,2883,2813,2622,3577,3531,3111,0
282,2413,2245,2174,2701,1799,1596,0
283,3448,2971,2850,2033,1828,1682,1


In [22]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

X = data_df[["max0_1","max0_2","max0_3","max1_1","max1_2","max1_3"]]
y = data_df['label']

# 将数据分为训练集和测试集
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

X_train, y_train = X,y
X_test, y_test = X,y

smote = SMOTE(random_state=48)
X_train_smote, y_train_smote = smote.fit_resample(X_train, y_train)

# 标准化数据
scaler = StandardScaler()
X_train_smote = scaler.fit_transform(X_train_smote)
X_test = scaler.transform(X_test)

# 创建Pipeline，包含标准化和逻辑回归
# pipeline = Pipeline([
#     ('scaler', StandardScaler()),  # 标准化
#     ('logreg', LogisticRegression())  # 逻辑回归
# ])
# model_svm = SVC(probability=True, random_state=0)
model_svm = DecisionTreeClassifier(random_state=0,max_depth=5)
# model_svm = LogisticRegression(random_state=42)
model_svm.fit(X_train_smote, y_train_smote)
# model_svm.fit(X_train, y_train)

# 定义超参数搜索范围
# param_grid = {
#     'logreg__C': [0.01, 0.1, 1, 10, 100]  # 正则化强度
# }

# 使用GridSearchCV进行超参数调优
# grid_search = GridSearchCV(pipeline, param_grid, cv=5, scoring='accuracy')
# grid_search.fit(X_train_smote, y_train_smote)

# # 最优模型
# best_model = grid_search.best_estimator_

# 进行预测
# y_pred = best_model.predict(X_test)

y_pred = model_svm.predict(X_test)
# y_pred_proba = model_svm.predict_proba(X_test)[:, 1]

# 模型评价
# print("Best Parameters:", grid_search.best_params_)
print("Classification Report:")
print(classification_report(y_test, y_pred))

print("Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))

# 可视化
# def plot_decision_boundary(X, y, model):
#     h = .02  # 网格步长
#     x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
#     y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
#     xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
#                          np.arange(y_min, y_max, h))
#     Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
#     Z = Z.reshape(xx.shape)
#     plt.contourf(xx, yy, Z, alpha=0.8)
#     plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o')
#     plt.xlabel('Feature 1')
#     plt.ylabel('Feature 2')
#     plt.title('Decision Boundary')
#     plt.show()

# plot_decision_boundary(X_test, y_test, model)


Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.64      0.74       184
           1       0.55      0.81      0.66       101

    accuracy                           0.70       285
   macro avg       0.71      0.73      0.70       285
weighted avg       0.75      0.70      0.71       285

Confusion Matrix:
[[118  66]
 [ 19  82]]


In [43]:
accuracy = accuracy_score(label_list, pred_list)
precision = precision_score(label_list, pred_list)
recall = recall_score(label_list, pred_list)
f1 = f1_score(label_list, pred_list)

# 打印结果
print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")

Accuracy: 0.6421052631578947
Precision: 0.0
Recall: 0.0
F1 Score: 0.0


In [39]:
pred_list

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,


In [30]:
category_name

1