先import一些library

In [None]:
import datetime
import os
import time
import warnings
import gc
# import presets
import torch
import torch.utils.data
import torchvision
from tmp_modules import utils,transforms

# from coco_utils import get_coco
from torch import nn
from torch.optim.lr_scheduler import PolynomialLR
from torchvision.transforms import functional as F, InterpolationMode
import torchvision.transforms as transforms
try:
    from pytorch_model_summary import summary
except:
    !pip install pytorch-model-summary
    from pytorch_model_summary import summary

try:
    from torchviz import make_dot
except:
    !pip install torchviz
    from torchviz import make_dot

import transformers
try:
    import datasets
except:
    !pip install cchardet
    !pip install datasets
    import datasets
    
# from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
# from IPython.display import clear_output, display

try:
        from thop import profile
except:
        !pip install thop

import torch
from transformers import SamModel, SamProcessor

先定義一些function

In [None]:
def criterion(inputs, target):
    """Cross-entropy loss over every output head of a segmentation model.

    Args:
        inputs: Mapping from head name to logits. The "out" head is always
            read; when more than one head is present the "aux" head is
            read as well.
        target: Label tensor passed to ``cross_entropy``; entries equal to
            255 are ignored.

    Returns:
        The "out" loss when there is a single head, otherwise
        ``out + 0.5 * aux``.
    """
    per_head = {
        head: nn.functional.cross_entropy(logits, target, ignore_index=255)
        for head, logits in inputs.items()
    }
    main_loss = per_head["out"]

    # A single head means there is no auxiliary classifier to add in.
    if len(per_head) == 1:
        return main_loss

    return main_loss + 0.5 * per_head["aux"]

def KD_criterion(student_outputs, teacher_outputs, labels, temperature=1,
                 hard_weight=0.3, soft_weight=0.7):
    """Knowledge-distillation loss: weighted cross-entropy plus KL divergence.

    Args:
        student_outputs: Mapping whose "out" entry holds the student logits.
        teacher_outputs: Indexable whose first element is a 4-D tensor; only
            its ``[0, 0, :, :]`` slice is used.
        labels: Ground-truth label tensor for the cross-entropy term; entries
            equal to 255 are ignored.
        temperature: Divisor applied to both maps before the softmax.
        hard_weight: Weight of the cross-entropy (ground-truth) term.
        soft_weight: Weight of the distillation (teacher) term.

    Returns:
        ``hard_weight * cross_entropy + soft_weight * kl_divergence``.
    """
    student_logits = student_outputs["out"]

    # Supervised term against the annotated labels.
    hard_loss = nn.functional.cross_entropy(student_logits, labels, ignore_index=255)

    # Compute the distillation term on the student's device; the teacher map
    # may live elsewhere (e.g. on the CPU after post-processing).
    device = student_logits.device
    teacher_map = teacher_outputs[0][0, 0, :, :].to(device)
    student_map = student_logits[0, 0, :, :]

    # Only the first sample / first channel is compared, and the softmax is
    # taken along the last axis of that 2-D slice.
    # NOTE(review): this distils a single channel normalised along one spatial
    # axis rather than a per-pixel class distribution -- confirm it is intended.
    soft_teacher = torch.softmax(teacher_map / temperature, dim=1)
    log_soft_student = torch.log_softmax(student_map / temperature, dim=1)

    # reduction="mean" is what nn.KLDivLoss() uses by default, kept so the
    # loss scale is unchanged.
    # NOTE(review): PyTorch documents "batchmean" as the reduction matching
    # the mathematical KL definition -- consider switching.
    distillation_loss = nn.functional.kl_div(
        log_soft_student, soft_teacher, reduction="mean"
    )

    # Total loss = weighted supervised loss + weighted distillation loss.
    return hard_weight * hard_loss + soft_weight * distillation_loss

def evaluate(model, data_loader, device, num_classes):
    """Return the mean per-sample loss of ``model`` over ``data_loader``.

    The model is switched to eval mode and moved to ``device``; it is left in
    eval mode on return.

    Args:
        model: Network whose output is accepted by ``criterion``.
        data_loader: Iterable of ``(image, target)`` batches; ``target`` has a
            singleton dimension at index 1 that is squeezed away.
        device: Device on which to run the evaluation.
        num_classes: Unused; kept for interface compatibility.

    Returns:
        Average loss per sample as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()
    model.to(device)

    num_processed_samples = 0
    total_loss = 0.0
    with torch.inference_mode():
        for image, target in data_loader:
            image = image.to(device)
            # Drop the channel axis and cast to long directly on the device.
            target = target.squeeze(1).to(device=device, dtype=torch.long)

            output = model(image)
            loss = criterion(output, target)

            # The loss is a batch mean: weight it by the batch size so the
            # final division by the sample count gives a per-sample average.
            batch_size = image.shape[0]
            total_loss += loss.item() * batch_size
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            num_processed_samples += batch_size

            # Release per-batch tensors eagerly to keep GPU memory low.
            del image, target, output, loss
            gc.collect()
            torch.cuda.empty_cache()

    if num_processed_samples == 0:
        raise ValueError("evaluate() received an empty data_loader.")
    return total_loss / num_processed_samples

# Iterate over all batches with enumerate so the progress bar gets a step index.
def train_one_epoch_01(student_model, teacher_model, processor, criterion, optimizer, data_loader, lr_scheduler, device, pbar, scaler=None):
    """Train the student for one epoch with a frozen SAM teacher.

    Args:
        student_model: Model being optimised; called as ``student_model(image)``.
        teacher_model: SAM model; only run in eval mode without gradients.
        processor: SAM processor used to build the teacher inputs and to
            post-process the predicted masks.
        criterion: Callable ``(student_output, teacher_output, target) -> loss``.
        optimizer: Optimiser holding the student parameters.
        data_loader: Iterable of ``(image, target)`` batches.
        lr_scheduler: Optional scheduler stepped once per batch.
        device: Device the tensors are moved to.
        pbar: Progress bar exposing ``update(step, values=[...])``.
        scaler: Optional AMP gradient scaler; autocast is enabled only when
            it is given.

    Returns:
        Mean of the per-batch training losses.
    """
    student_model.train()
    teacher_model.eval()
    training_loss = []
    for idx, (image, target) in enumerate(data_loader):
        # NumPy copy of the mask, taken before the tensor is moved to the
        # device, used to derive the box prompt.
        target_np = np.array(target)
        image = image.to(device)
        # Ground truth from the annotation: drop the channel axis, cast to long.
        target = target.squeeze(1).to(device=device, dtype=torch.long)

        # ---- Teacher (SAM): image embeddings, then a box prompt from the mask.
        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            image_embeddings = teacher_model.get_image_embeddings(inputs["pixel_values"])

        # TODO: support additional prompt types (points, fixed boxes).
        bbox = [[[get_bounding_box(target_np)]]]
        inputs = processor(image, input_boxes=[bbox], return_tensors="pt").to(device)

        # Reuse the embeddings computed above instead of re-encoding the image.
        inputs.pop("pixel_values", None)
        inputs.update({"image_embeddings": image_embeddings})

        with torch.cuda.amp.autocast(enabled=scaler is not None):
            student_output = student_model(image)
            # The teacher is never optimised, so skip building its autograd
            # graph; this saves memory and leaves its .grad fields untouched.
            with torch.no_grad():
                sam_outputs = teacher_model(**inputs)
                # NOTE(review): this unpacks the post-processed result into
                # exactly two values -- confirm that matches what
                # post_process_masks returns for the batch size in use.
                masks, teacher_output = processor.image_processor.post_process_masks(sam_outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
            loss = criterion(student_output, teacher_output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        batch_loss = loss.item()
        pbar.update(idx, values=[("loss", batch_loss)])
        training_loss.append(batch_loss)

        # Release per-batch tensors eagerly to keep GPU memory low.
        del image, target, student_output, sam_outputs, masks, teacher_output, loss
        gc.collect()
        torch.cuda.empty_cache()

    return np.mean(np.array(training_loss))

def train(student_model, teacher_model, processor, epochs, data_loader, data_loader_valid, early_stop=0, model_pathname=True):
    """Train the student with knowledge distillation and track the losses.

    Args:
        student_model: Model to optimise.
        teacher_model: Teacher passed through to ``train_one_epoch_01``.
        processor: Processor passed through to ``train_one_epoch_01``.
        epochs: Maximum number of epochs.
        data_loader: Training batches.
        data_loader_valid: Validation batches, or ``None`` to skip validation,
            checkpointing and early stopping.
        early_stop: When > 0, stop once this many epochs have passed without
            a new best validation loss.
        model_pathname: Path for the best-weights file. Any value that is not
            a string/path (including the default) selects
            ``<cwd>/weights/segformer.pth``.

    Returns:
        ``{'loss': [...], 'val_loss': [...]}`` with one entry per epoch run.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        print(device, ":", torch.cuda.get_device_name(0))
    else:
        # get_device_name() raises without a CUDA device.
        print(device)
    gc.collect()
    torch.cuda.empty_cache()

    student_model.to(device)
    teacher_model.to(device)
    if hasattr(student_model, 'backbone'):
        params_to_optimize = [
            {"params": [p for p in student_model.backbone.parameters() if p.requires_grad]},
            {"params": [p for p in student_model.classifier.parameters() if p.requires_grad]},
        ]
    else:
        params_to_optimize = [{"params": student_model.parameters()}]

    optimizer = torch.optim.Adam(params_to_optimize)
    scaler = None        # AMP disabled; pass a torch.cuda.amp.GradScaler() to enable.
    lr_scheduler = None  # e.g. PolynomialLR(optimizer, total_iters=len(data_loader) * epochs, power=0.9)
    training_loss = []
    val_loss = []
    min_val_loss = np.inf
    min_val_epoch = 0
    best_saved = False   # True once this run has written a best checkpoint.
    start_time = time.time()

    n_batch = len(data_loader)
    pbar = tf.keras.utils.Progbar(target=n_batch, stateful_metrics=['val_loss'])

    # Resolve where the best weights are stored, creating the folder if needed.
    if isinstance(model_pathname, (str, os.PathLike)):
        model_pathname = os.fspath(model_pathname)
        weights_dir = os.path.dirname(model_pathname)
    else:
        weights_dir = os.path.join(os.getcwd(), "weights")
        model_pathname = os.path.join(weights_dir, "segformer.pth")
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)

    for epoch in range(epochs):
        print(f'Epoch {epoch+1}/{epochs}')

        tr_loss = train_one_epoch_01(student_model, teacher_model, processor, KD_criterion, optimizer, data_loader, lr_scheduler, device, pbar, scaler)
        training_loss.append(tr_loss)

        if data_loader_valid is None:
            continue

        v_loss = evaluate(student_model, data_loader_valid, device=device, num_classes=2)
        val_loss.append(v_loss)
        pbar.update(n_batch, values=[('val_loss', v_loss)])

        # Keep the parameters with the lowest validation loss.
        if v_loss < min_val_loss:
            min_val_loss = v_loss
            min_val_epoch = epoch
            torch.save(student_model.state_dict(), model_pathname)
            best_saved = True
            print(f"Saved model weights to '{model_pathname}'.")

        # Early stop once validation has not improved for `early_stop` epochs.
        if early_stop > 0 and epoch - min_val_epoch >= early_stop:
            break

    # Restore the best weights, but only if this run actually produced them;
    # otherwise a stale file from an earlier run would be loaded.
    if best_saved:
        student_model.load_state_dict(torch.load(model_pathname, map_location=device))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
    return {'loss': training_loss, 'val_loss': val_loss}

def showSegmentationResult(model, dataset, num_images=10):
    """Plot image, ground truth, prediction and overlay for the first samples.

    Args:
        model: Segmentation model returning a mapping with an "out" logits
            entry; it is put in eval mode and moved to the available device.
        dataset: Indexable of ``(image, target)`` pairs with ``__len__``.
        num_images: Maximum number of samples (rows) to draw.

    Returns:
        None. Shows a matplotlib figure and prints the last sample's raw
        output and predicted labels.
    """
    n_shown = min(len(dataset), num_images)
    if n_shown <= 0:
        # Nothing to draw; avoids touching undefined loop variables below.
        print("No images to display.")
        return

    # Size the grid by the rows actually drawn so short datasets leave no gaps.
    plt.figure(figsize=(16, n_shown * 5))
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for row in range(n_shown):
        image, target = dataset[row]
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(torch.unsqueeze(image, 0).to(device))
        labels = torch.squeeze(torch.argmax(outputs['out'].cpu(), dim=1)).numpy()

        # Original Image
        plt.subplot(n_shown, 4, row * 4 + 1)
        plt.imshow(np.transpose(torch.squeeze(image).cpu().numpy(), (1, 2, 0)))
        plt.axis(False)
        plt.title('Original Image')

        # Ground Truth Label
        plt.subplot(n_shown, 4, row * 4 + 2)
        plt.imshow(target.permute(1, 2, 0).cpu().numpy())
        plt.axis(False)
        plt.title('Ground Truth Label')

        # Model Prediction
        plt.subplot(n_shown, 4, row * 4 + 3)
        plt.imshow(labels)
        plt.axis(False)
        plt.title('Model Prediction')

        # Overlay Prediction on Original Image: saturate channel 0 wherever
        # the predicted class is 1.
        overlay = image.clone()
        overlay[0, torch.from_numpy(labels == 1)] = 1
        plt.subplot(n_shown, 4, row * 4 + 4)
        plt.imshow(np.transpose(overlay.squeeze().cpu().numpy(), (1, 2, 0)))
        plt.axis(False)
        plt.title('Overlay Prediction on Original Image')

    plt.show()

    # Debug dump of the last processed sample.
    print("outputs['out'] size = ", outputs['out'].size())
    print(outputs['out'])
    print("labels = ", labels)

def get_bounding_box(ground_truth_map):
    """Return a randomly enlarged ``[x_min, y_min, x_max, y_max]`` box around a mask.

    Args:
        ground_truth_map: 4-D array; only the ``[0, 0, :, :]`` slice is used
            and every value > 0 counts as foreground.

    Returns:
        List of four Python ints. Each side of the tight foreground box is
        pushed outwards by a random 0-19 pixels and clamped to the slice
        bounds. When the mask has no foreground, the box covering the whole
        slice (``[0, 0, W, H]``) is returned.
    """
    ground_truth_map = ground_truth_map[0, 0, :, :]
    H, W = ground_truth_map.shape

    # Tight box from the foreground pixel coordinates.
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if len(x_indices) == 0:
        # Empty mask: np.min/np.max would raise, so fall back to the full frame.
        return [0, 0, int(W), int(H)]

    x_min, x_max = int(np.min(x_indices)), int(np.max(x_indices))
    y_min, y_max = int(np.min(y_indices)), int(np.max(y_indices))

    # Add perturbation to the box coordinates, clamped to the slice.
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(int(W), x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(int(H), y_max + np.random.randint(0, 20))

    return [x_min, y_min, x_max, y_max]

定義Segformer，直接用包好的

In [None]:
import torch.nn as nn
from transformers import SegformerForSemanticSegmentation

class SegFormer(nn.Module):
    """SegFormer wrapper that returns full-resolution logits under the "out" key.

    Args:
        num_classes: Number of segmentation labels.
        backbone: Suffix of the ``nvidia/mit-<backbone>`` checkpoint to load.
        id2label: Optional mapping from class index to label name; defaults
            to the stringified indices.
    """

    def __init__(self, num_classes, backbone="b1", id2label=None):
        super().__init__()
        self.num_classes = num_classes
        if id2label is None:
            id2label = {index: str(index) for index in range(self.num_classes)}
        self.id2label = id2label

        # Inverse mapping handed to the underlying model.
        label2id = {label: index for index, label in self.id2label.items()}
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-{backbone}",
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=label2id,
        )

    def forward(self, x):
        """Run the model and resize its logits to the spatial size of ``x``."""
        logits = self.segformer(x).logits
        input_size = x.shape[-2:]
        resized = nn.functional.interpolate(
            logits,
            size=input_size,
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        return {'out': resized}

class SegFormer01(nn.Module):
    """SegFormer wrapper built from a configuration only (no pretrained weights).

    The original code stored the ``SegformerForSemanticSegmentation`` class
    itself in ``self.segformer``, so ``forward`` could never run a model. An
    actual model instance is now created from the backbone's configuration.
    NOTE(review): the weights are randomly initialised here, presumably so a
    trained ``state_dict`` can be loaded afterwards -- confirm the intent.

    Args:
        num_classes: Number of segmentation labels.
        backbone: Suffix of the ``nvidia/mit-<backbone>`` configuration to use.
        id2label: Optional mapping from class index to label name; defaults
            to the stringified indices.
    """

    def __init__(self, num_classes, backbone="b1", id2label=None):
        super().__init__()
        self.num_classes = num_classes
        if id2label is not None:
            self.id2label = id2label
        else:
            self.id2label = {i: str(i) for i in range(self.num_classes)}

        # Reach the configuration class through the model class so that no
        # additional import is required.
        config = SegformerForSemanticSegmentation.config_class.from_pretrained(
            f"nvidia/mit-{backbone}",
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id={v: k for k, v in self.id2label.items()},
        )
        self.segformer = SegformerForSemanticSegmentation(config)

    def forward(self, x):
        """Run the model and resize its logits to the spatial size of ``x``."""
        y = self.segformer(x)
        y = nn.functional.interpolate(y.logits, size=x.shape[-2:], mode="bilinear", align_corners=False, antialias=True)
        return {'out': y}
    
num_classes = 2

# Image preprocessing: convert to a tensor, then resize to 128x128.
segformer_transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((128, 128), antialias=True),
])

# Label preprocessing: nearest-neighbour resize so class ids are never blended,
# then shift the ids down by one. The interpolation flag must be passed by
# keyword: the third positional argument of cv2.resize is `dst`.
segformer_target_transform = transforms.Compose([
    lambda x: torch.from_numpy(
        cv2.resize(np.asarray(x), (128, 128), interpolation=cv2.INTER_NEAREST).astype(np.int64) - 1
    ),
])

segformer_model = SegFormer(num_classes)

print(summary(segformer_model, torch.zeros((1, 3, 1024, 1024)), show_input=True, show_parent_layers=True, max_depth=1))

# Compute FLOPs for a single 1024x1024 RGB input.
flops, params = profile(segformer_model, inputs=(torch.zeros((1, 3, 1024, 1024)),), verbose=False)
print("FLOPs =", '{:,.0f}'.format(flops))

定義SAM

In [None]:
# ---------------------------------------------------------------------------
# Teacher model: SAM (Segment Anything), plus its input/output processor.
# ---------------------------------------------------------------------------
sam_checkpoint = "facebook/sam-vit-huge"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sam_model = SamModel.from_pretrained(sam_checkpoint)
sam_model = sam_model.to(device)
processor = SamProcessor.from_pretrained(sam_checkpoint)

# Summarise the model on a dummy 1024x1024 RGB batch placed on the same device.
sam_dummy_input = torch.zeros((1, 3, 1024, 1024)).to(device)
sam_summary = summary(sam_model, sam_dummy_input, show_input=True, show_parent_layers=True, max_depth=1)
print(sam_summary)
"""
# 計算FLOP
flops, params = profile(sam_model, inputs=(torch.zeros((1, 3, 1024, 1024)),), verbose=False)
print("FLOPs =", '{:,.0f}'.format(flops))
"""

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class SplashDataSet(Dataset):
    """Image / binary-mask pairs read from ``<data_dir>/images`` and ``<data_dir>/annotations``.

    Both folders are listed in sorted order and paired by position.
    NOTE(review): this assumes image and annotation file names sort into the
    same order -- confirm against the dataset's naming scheme.

    Args:
        data_dir: Folder containing the ``images`` and ``annotations`` folders.
        train_size: If truthy, keep only this many randomly chosen pairs.
        val_size: If truthy, keep only this many randomly chosen pairs out of
            those remaining after the ``train_size`` selection.
        transform: Optional callable applied to both the image and the mask.

    Raises:
        ValueError: If the two folders hold a different number of files.
    """

    def __init__(self, data_dir, train_size, val_size, transform=None):
        self.data_dir = data_dir
        self.images_dir = os.path.join(data_dir, 'images')
        self.annotations_dir = os.path.join(data_dir, 'annotations')
        # os.listdir returns names in arbitrary order; sort both listings so
        # that entry i of one folder lines up with entry i of the other.
        self.images_list = sorted(os.listdir(self.images_dir))
        self.annotations_list = sorted(os.listdir(self.annotations_dir))
        if len(self.images_list) != len(self.annotations_list):
            raise ValueError("Number of images and annotations should be the same.")

        # Randomly pick the requested number of pairs.
        if train_size:
            self._subsample(train_size)
        if val_size:
            self._subsample(val_size)

        self.transform = transform

    def _subsample(self, size):
        """Keep ``size`` randomly chosen image/annotation pairs (no repeats)."""
        chosen = np.random.choice(len(self.images_list), size, replace=False)
        self.images_list = [self.images_list[i] for i in chosen]
        self.annotations_list = [self.annotations_list[i] for i in chosen]

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, idx):
        """Return ``(image, mask)``; ``mask`` is an int tensor of 0/1 with shape (1, H, W)."""
        # Read the original image; force three channels and close the file.
        img_name = os.path.join(self.images_dir, self.images_list[idx])
        with Image.open(img_name) as img_file:
            image = img_file.convert("RGB")

        # Read the mask; copy() loads the pixels so the file can be closed.
        mask_name = os.path.join(self.annotations_dir, self.annotations_list[idx])
        with Image.open(mask_name) as mask_file:
            mask = mask_file.copy()

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Without a tensor-producing transform the mask is still an image:
        # convert it to a channel-first tensor so the steps below work.
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.array(mask))
            mask = mask.unsqueeze(0) if mask.ndim == 2 else mask.permute(2, 0, 1)

        # Convert the mask to binary 0/1 and keep only its first channel.
        mask = (mask > 0).to(torch.int)
        mask = mask[0, None, :, :]
        return image, mask

In [None]:
# trial 
from torch.utils.data import Subset, DataLoader
import tkinter as tk
from tkinter import filedialog
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt

# Preprocessing pipeline handed to SplashDataSet: resize to 1024x1024, then
# convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize((1024, 1024)), transforms.ToTensor()]
)

def select_folder():
    """Open a folder-picker dialog and return the chosen directory.

    Returns:
        Whatever ``filedialog.askdirectory`` returns for the user's choice.
    """
    # A root window is required for the dialog; keep it hidden.
    root = tk.Tk()
    root.withdraw()
    try:
        parent_folder = filedialog.askdirectory(title="選擇資料夾")
    finally:
        # Tear the hidden root window down so it does not linger.
        root.destroy()
    return parent_folder

import copy

data_folder = select_folder()
if not data_folder:
    # Nothing was chosen in the dialog: stop before listing a bogus path.
    raise ValueError("No data folder selected.")

# Initialise KFold.
kfold = KFold(n_splits=5, shuffle=False)

# Training-set sizes to compare.
train_sizes = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
val_size = None  # no extra sub-sampling for validation

# Mean / std of the validation loss for each training-set size.
mean_val_losses = []
std_val_losses = []
mean_val_losses_KD = []
std_val_losses_KD = []

# Snapshot of the untrained student so every run starts from the same weights.
# Without it each train() call continues from the previous one, which leaks
# information across folds and between the two compared runs.
initial_student_state = copy.deepcopy(segformer_model.state_dict())

# Loop over the training-set sizes.
for train_size in train_sizes:
    val_losses = []  # validation loss of each fold
    val_losses_KD = []

    # Create the SplashDataSet for this size.
    dataset = SplashDataSet(data_dir=data_folder, train_size=train_size, val_size=val_size, transform=transform)

    # Loop over the folds.
    for fold, (train_index, val_index) in enumerate(kfold.split(dataset)):
        print(f'Fold {fold+1}')

        # Split into training and validation subsets.
        train_dataset = Subset(dataset, train_index)
        val_dataset = Subset(dataset, val_index)

        # Initialise the DataLoaders.
        train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

        # Train and evaluate.
        # NOTE(review): train() takes (..., early_stop, model_pathname) as its
        # last two parameters, so the 0 / 0.8 below are early-stopping values,
        # not a teacher ratio, and both runs use the same KD loss. The plot
        # titles further down (teacher_ratio = 0 vs 0.8) therefore do not
        # describe what is actually run; train() needs a loss-weight argument
        # for this comparison to be meaningful.
        segformer_model.load_state_dict(initial_student_state)
        lc = train(segformer_model, sam_model, processor, 10, train_loader, val_loader, early_stop=0, model_pathname=1)
        segformer_model.load_state_dict(initial_student_state)
        lc_KD = train(segformer_model, sam_model, processor, 10, train_loader, val_loader, early_stop=0.8, model_pathname=1)

        # Use the validation loss of the last epoch.
        val_loss = lc['val_loss'][-1]
        val_losses.append(val_loss)

        val_loss_KD = lc_KD['val_loss'][-1]
        val_losses_KD.append(val_loss_KD)

    # Aggregate the folds for this training-set size.
    mean_val_loss = np.mean(val_losses)
    mean_val_losses.append(mean_val_loss)
    std_val_loss = np.std(val_losses)
    std_val_losses.append(std_val_loss)

    mean_val_loss_KD = np.mean(val_losses_KD)
    mean_val_losses_KD.append(mean_val_loss_KD)
    std_val_loss_KD = np.std(val_losses_KD)
    std_val_losses_KD.append(std_val_loss_KD)

# Plot both curves.
plt.figure()
plt.plot(train_sizes, mean_val_losses, marker='o', color='black')
plt.plot(train_sizes, mean_val_losses_KD, marker='o', color='blue')
plt.title('Mean Validation Loss vs. Training Set Size')
plt.xlabel('Training Set Size')
plt.ylabel('Mean Validation Loss')
plt.grid(True)
plt.show()

In [None]:
plt.figure(figsize=(16, 6))

# One panel per run: (subplot position, means, stds, band colour, title).
panels = (
    (1, mean_val_losses_KD, std_val_losses_KD, 'skyblue',
     'Mean Validation Loss vs. Training Set Size (teacher_ratio = 0.8, temperature = 1)'),
    (2, mean_val_losses, std_val_losses, 'black',
     'Mean Validation Loss vs. Training Set Size (teacher_ratio = 0, temperature = 1)'),
)

for position, panel_means, panel_stds, band_color, panel_title in panels:
    center = np.array(panel_means)
    spread = np.array(panel_stds)

    plt.subplot(1, 2, position)
    plt.plot(train_sizes, panel_means, marker='o', color='b', label='Mean Validation Loss')
    # Shade mean +/- one standard deviation, clipping the lower edge at zero.
    plt.fill_between(train_sizes, np.maximum(0, center - spread), center + spread,
                     color=band_color, alpha=0.3)
    plt.title(panel_title)
    plt.xlabel('Training Set Size')
    plt.ylabel('Mean Validation Loss')
    plt.grid(True)
    plt.legend()

plt.tight_layout()  # keep the two panels from overlapping
plt.show()


印出來看看

In [None]:
# Visualise the student's predictions on `val_dataset` (as left by the last
# fold of the experiment loop).
print("印出來看看👀")
showSegmentationResult(segformer_model, val_dataset)

Inference

In [None]:
"""
weight_path = "weights/segformer.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.get_device_name(0))
print(device)
num_classes = 2
KD_Segformer = SegFormer01(num_classes)
""""""
# 載入已經訓練好的最佳權重
if weight_path is not None:
    print(f"Loaded model weights from '{weight_path}'.")
    KD_Segformer.load_state_dict(torch.load(weight_path))
else:
    print("Model pathname is not specified. Cannot load model weights.")
""""""
KD_Segformer.to(device)
KD_Segformer.eval()

png_files = select_folder()
for file_path in os.listdir(png_files):
    if file_path.endswith(".png"):
        # 讀一張一張 frame
        raw_image = Image.open(file_path).convert("RGB")

        with torch.no_grad():
            outputs = KD_Segformer(torch.unsqueeze(raw_image, 0).to(device))
            mask = torch.squeeze(torch.argmax(outputs['out'].cpu(), dim=1)).numpy()

        plt.imshow(mask)
        plt.axis(False)
        plt.title('Model Prediction')

        # Overlay Prediction on Original Image
        overlay = raw_image.clone()
        overlay[0, mask == 1] = 1
        plt.imshow(np.transpose(overlay.squeeze().cpu().numpy(), (1, 2, 0)))
        plt.axis(False)
        plt.title('Overlay Prediction on Original Image')
"""        

In [None]:
"""
############################################################### Teacher Model : SAM
# define SAM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Retrieve the image embeddings
# processor
inputs = processor(image, return_tensors="pt").to(device)
image_embeddings = model.get_image_embeddings(inputs["pixel_values"])

# 待加功能$$$$$
# setup prompts : target points, bounding box
input_boxes = [[[250, 300, 700, 550]]]
input_points = [[[500, 400], [600, 400]]]

# 送到processor計算遮罩
inputs = processor(image, input_boxes=[input_boxes], input_points=[input_points], return_tensors="pt").to(device)
# inputs = processor(raw_image, input_points=[input_points], return_tensors="pt").to(device)

inputs.pop("pixel_values", None)
inputs.update({"image_embeddings": image_embeddings})

with torch.no_grad():
    outputs = model(**inputs)

masks, teacher_output = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())

###############################################################
"""