In [1]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import train_test_split, KFold
from torchvision import transforms
from PIL import Image
from transformers import SegformerForSemanticSegmentation, SegformerConfig, SegformerImageProcessor
import tkinter as tk
from tkinter import filedialog
import tensorflow as tf
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score, precision_score, recall_score, f1_score, accuracy_score
from transformers import SamModel, SamProcessor
from torch import nn
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MySegFormer_0409(nn.Module):
    """SegFormer segmentation wrapper whose logits are resized to the input resolution.

    Loads ``nvidia/mit-{backbone}`` via ``SegformerForSemanticSegmentation.from_pretrained``
    with a classifier head sized for ``num_classes`` (``ignore_mismatched_sizes=True`` lets
    the head be re-initialised instead of erroring on the size mismatch).

    Args:
        num_classes: Number of segmentation classes.
        backbone: MiT variant suffix, e.g. ``"b0"`` or ``"b1"``.
        id2label: Optional mapping class-id -> label name; defaults to ``{i: str(i)}``.
    """

    def __init__(self, num_classes, backbone="b1", id2label=None):
        super().__init__()
        self.num_classes = num_classes
        if id2label is not None:
            self.id2label = id2label
        else:
            self.id2label = {i: str(i) for i in range(self.num_classes)}
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-{backbone}",
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id={v: k for k, v in self.id2label.items()},
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Return ``{'out': logits}`` with logits resized to ``x``'s spatial size (H, W).

        ``x`` is assumed to be a (batch, channels, H, W) tensor -- TODO confirm the channel
        count matches the backbone's expected input.
        """
        y = self.segformer(x)
        # The decode head's logits are not at input resolution; resize them to x's H x W.
        y = nn.functional.interpolate(
            y.logits, size=x.shape[-2:], mode="bilinear", align_corners=False, antialias=True
        )
        return {'out': y}


class MySegFormer_0604(MySegFormer_0409):
    """Same model as ``MySegFormer_0409`` but defaulting to the ``b0`` backbone.

    Previously a line-for-line copy of ``MySegFormer_0409``; subclassing keeps the two in sync
    while preserving this class's constructor signature and ``state_dict`` key layout.

    NOTE (translated from the original comment): the Hugging Face Segformer installed in the
    author's conda environment has been modified locally (``modeling_segformer.py``).
    """

    def __init__(self, num_classes, backbone="b0", id2label=None):
        super().__init__(num_classes, backbone=backbone, id2label=id2label)

# Student model: SegFormer 0409.
# MySegFormer_0409 defaults to the "b1" backbone, so the previous `model_name = "nvidia/mit-b0"`
# did not describe the model actually built. Make the backbone explicit and derive the name from it.
BACKBONE = "b1"
model_name = f"nvidia/mit-{BACKBONE}"
num_classes = 2
model_segformer = MySegFormer_0409(num_classes, backbone=BACKBONE)

Position_Embedding_0628_teacher_0.01


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b1 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
import cv2
import shutil

# Allow PIL to load truncated image files instead of raising on them.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Student checkpoint passed to KD_inference.
# Previously used checkpoints:
#   weights_KD_segformer_0418test_from0_60\segformer_data_size_350.pth
#   C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/weights/weights_KD_segformer_0628/weights_KD_segformer_0628_30/segformer_data_size_300.pth
#   C:/天_11157065/git/RipplesDetection/ar0DB/inference/weight/weights_KD_segformer_0628_30/segformer_data_size_300.pth
weight_dir = "C:/天_11157065/git/RipplesDetection/ar0DB/inference/segformer_data_size_300.pth"

def transform_image(image, size=(1024, 1024)):
    """Convert a PIL image into the float tensor fed to the segmentation model.

    The image is converted to 3-channel RGB, resized to ``size`` and turned into a
    (3, H, W) float tensor with values in [0, 1] by ``transforms.ToTensor``.

    Args:
        image: An opened PIL image (it may still be lazily loaded).
        size: Target (height, width) for ``transforms.Resize``; defaults to (1024, 1024).

    Returns:
        The tensor, or ``None`` if decoding the image raises an ``OSError``.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    try:
        # Force RGB: palette, grayscale or RGBA PNGs would otherwise yield 1 or 4 channels,
        # which an RGB-pretrained backbone presumably cannot consume. convert() also forces
        # PIL's lazy decode, so corrupt files fail here and are reported below.
        # (An optional cv2.medianBlur pre-filter was tried earlier and is disabled.)
        rgb_image = image.convert("RGB")
        return transform(rgb_image)
    except OSError as e:  # IOError is an alias of OSError
        print(f"Error: - {e}")
        return None

def select_folder():
    """Ask the user to pick a directory with a native Tk dialog.

    Returns:
        Whatever ``filedialog.askdirectory`` returns: the chosen path, or an empty value
        if the dialog is cancelled.
    """
    root = tk.Tk()
    root.withdraw()  # hide the empty main window; only the dialog should appear
    try:
        parent_folder = filedialog.askdirectory(title="選擇影像資料夾")
    finally:
        # Without this every call leaves a hidden Tk root alive in the kernel process.
        root.destroy()
    return parent_folder

def apply_mask(mask, mask_path):
    """Set every prediction outside a region-of-interest image to class 0 (background).

    Args:
        mask: 2-D array of predicted class ids with shape (H, W); modified in place.
        mask_path: Path to an image read as 8-bit grayscale. Its zero-valued pixels mark the
            excluded area. If its size differs from ``mask`` it is resized (nearest-neighbour)
            to match.

    Returns:
        The same ``mask`` array, after masking.
    """
    # Context manager closes the file; convert() forces the decode before it is closed.
    with Image.open(mask_path) as mask_image:
        roi_image = mask_image.convert('L')
    height, width = mask.shape[:2]
    if roi_image.size != (width, height):
        # PIL sizes are (width, height); NEAREST keeps the 0 / non-0 boundary crisp.
        roi_image = roi_image.resize((width, height), Image.NEAREST)
    mask_array = np.asarray(roi_image)
    # Masking: pixels hidden by the ROI image become 0 (the background class).
    mask[mask_array == 0] = 0
    return mask

def move_png_files(data_dir_path):
    """Move every PNG file directly inside ``data_dir_path`` into its ``raw_image`` subfolder.

    The extension check is case-insensitive (``.png`` and ``.PNG``), and only regular files
    are moved; sub-directories are left in place. ``raw_image`` is created if missing.

    Args:
        data_dir_path: Folder containing the extracted frames.
    """
    raw_image_dir = os.path.join(data_dir_path, 'raw_image')
    os.makedirs(raw_image_dir, exist_ok=True)

    for filename in os.listdir(data_dir_path):
        src = os.path.join(data_dir_path, filename)
        if not filename.lower().endswith('.png') or not os.path.isfile(src):
            continue
        shutil.move(src, os.path.join(raw_image_dir, filename))

def _make_overlay(image_tensor, mask):
    """Return a PIL image of ``image_tensor`` with pixels where ``mask == 1`` painted red.

    ``image_tensor`` is the (1, 3, H, W) float tensor in [0, 1] produced by transform_image
    (batch dim added); ``mask`` is an (H, W) array of class ids.
    """
    overlay = image_tensor.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    # Pixel values are in [0, 1], so full red is 1.0. Writing 255 here (as before) became
    # 65025 after the *255 below and overflowed uint8, painting the mask nearly black.
    overlay[mask == 1, 0] = 1.0
    return Image.fromarray((overlay * 255).astype(np.uint8))


def KD_inference(model, weight_dir, data_dir_path,
                 magic_mask_path="magic_mask/ar0DB_magic_mask.png"):
    """Segment every frame in ``data_dir_path`` and count pixels predicted as class 1.

    Steps:
      1. Load the ``state_dict`` at ``weight_dir`` into ``model`` (on CUDA if available).
      2. Move top-level PNG frames into ``<data_dir_path>/raw_image``.
      3. For each frame: predict a mask, optionally zero it outside the ROI image at
         ``magic_mask_path``, count class-1 pixels, and save a red overlay to
         ``<data_dir_path>/results/overlay_<filename>``.
      4. Write the counts to ``<data_dir_path>/pixel_counts_<date>.json``, where ``<date>``
         is the text after the last ``-`` of the folder name
         (``frames-20240821`` -> ``20240821``).

    Files that cannot be opened or decoded as images are reported and skipped.

    Args:
        model: Module returning ``{'out': logits}`` shaped (1, C, H, W), e.g. MySegFormer_0409.
        weight_dir: Path to a ``state_dict`` file loadable with ``torch.load``.
        data_dir_path: Folder of frames.
        magic_mask_path: ROI image passed to ``apply_mask``. A relative path resolves against
            the current working directory; pass a falsy value to disable masking.

    Returns:
        List of ``{'time': <filename>, 'pixel_number': <int>}`` dicts in filename order
        (the filenames encode the capture time, e.g. ``2024-08-21-00-00-01.png``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(weight_dir, map_location=device))
    model.to(device)
    model.eval()

    image_dir_path = os.path.join(data_dir_path, 'raw_image')
    result_dir = os.path.join(data_dir_path, 'results')
    os.makedirs(image_dir_path, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    move_png_files(data_dir_path)

    image_filename_list = sorted(
        name for name in os.listdir(image_dir_path)
        if os.path.isfile(os.path.join(image_dir_path, name))
    )

    pixel_counts = []
    for image_filename in image_filename_list:
        image_path = os.path.join(image_dir_path, image_filename)
        print("image_path = ", image_path)
        try:
            # transform_image forces the decode, so the file can be closed right after.
            with Image.open(image_path) as pil_image:
                image_tensor = transform_image(pil_image)
        except OSError as e:  # includes PIL.UnidentifiedImageError for non-image files
            print(f"Error: - {e}")
            continue
        if image_tensor is None:
            continue
        image_tensor = image_tensor.unsqueeze(0).to(device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(image_tensor)
        mask = torch.argmax(outputs['out'], dim=1).squeeze().cpu().numpy()
        if magic_mask_path:
            mask = apply_mask(mask, magic_mask_path)

        pixel_count = int(np.sum(mask == 1))
        pixel_counts.append({'time': image_filename, 'pixel_number': pixel_count})

        _make_overlay(image_tensor, mask).save(
            os.path.join(result_dir, f"overlay_{image_filename}")
        )

    # Use only the folder name so a trailing slash or '-' elsewhere in the path can't leak in.
    folder_name = os.path.basename(os.path.normpath(data_dir_path))
    date = folder_name.split('-')[-1]
    with open(os.path.join(data_dir_path, f'pixel_counts_{date}.json'), 'w', encoding='utf-8') as f:
        json.dump(pixel_counts, f, indent=4)
    return pixel_counts

# Frame folders to process, one per day from 20240821 to 20240830.
# To pick a single folder interactively instead, use: image_dir = select_folder()
IMAGE_ROOT = "C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌"
image_dirs = [f"{IMAGE_ROOT}/frames-202408{day:02d}" for day in range(21, 31)]

# Keep every folder's counts in memory (KD_inference also writes each one to
# pixel_counts_<date>.json); `pixel_counts` still holds the last folder's result as before.
pixel_counts_by_dir = {}
for image_dir in image_dirs:
    print(image_dir)
    pixel_counts = KD_inference(model_segformer, weight_dir, image_dir)
    pixel_counts_by_dir[image_dir] = pixel_counts
# Earlier runs used, e.g.:
#   C:/Users/user/Desktop/NAS_data/鱸魚/高雄黃明和/frames-20240406
#   C:/Users/user/Desktop/NAS_data/鱸魚/高雄黃明和/frames-20240407_50


C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-00-01.png
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-00-11.png
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-00-21.png
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-00-31.png
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-00-41.png
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-00-51.png
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-01-01.png
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-01-11.png
image_path =  C:/天_11157065/NAS_data/午仔魚_屏東張詳誌/屏東張詳誌/frames-20240821\raw_image\2024-08-21-00-01-21.png
image_path =  C:/天