## 数据预处理

In [1]:
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np


seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


class TongueDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_dir = os.path.join(root_dir, 'images')
        self.mask_dir = os.path.join(root_dir, 'annotations')
        self.class_names = self.load_class_names()

    def load_class_names(self):
        class_names_file = os.path.join(self.root_dir, 'class_names.txt')
        with open(class_names_file, 'r') as f:
            class_names = f.readlines()
        class_names = [name.strip() for name in class_names]
        return class_names

    def __len__(self):
        return len(os.listdir(self.image_dir))

    def __getitem__(self, idx):
        img_name = os.listdir(self.image_dir)[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(
            self.mask_dir, os.path.splitext(img_name)[0] + '.png')

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path)

        if self.transform:
            image = self.transform(image)
            mask = transforms.Resize((32, 32))(mask)
            mask = torch.tensor(np.array(mask), dtype=torch.float32)

        return image, mask


data_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
])

dataset = TongueDataset(root_dir='./data/split_dataset_ultra', transform=data_transform)

print(len(dataset))
print(dataset.class_names)
print(dataset[0][0].shape)
print(dataset[0][1].shape)
print(dataset[0][1].unique())
print(dataset[0][1])

1000
['_background_', 'Tg']
torch.Size([3, 32, 32])
torch.Size([32, 32])
tensor([0., 1.])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [2]:
from torch.utils.data import random_split


aug_transform = transforms.Compose([
    transforms.RandomApply(transforms.ColorJitter(brightness=0.2, contrast=0.2,
                           saturation=0.2, hue=0.1)),
    transforms.RandomApply(transforms.GaussianBlur(3, sigma=(0.1, 2.0)), 0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225]),
])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225]),
])

# 划分数据集
train_size = int(0.6 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
train_dataloader.dataset.transform = aug_transform
for i, (img, mask) in enumerate(train_dataloader):
    print(img.shape, mask.shape)
    print(mask.unique())
    break
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)
val_dataloader.dataset.transform = transform
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)
test_dataloader.dataset.transform = transform

torch.Size([2, 3, 32, 32]) torch.Size([2, 32, 32])
tensor([0., 1.])


## 定义网络

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import numpy as np
import torch.nn.functional as F
import sys


sys.setrecursionlimit(100000)  # 将默认的递归深度修改
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 定义基于VGG16的FCN网络
class VGG16_FCN(nn.Module):
    def __init__(self, num_classes):
        super(VGG16_FCN, self).__init__()
        # 加载预训练的VGG16模型
        vgg16 = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT)

        # 取出VGG16的前面部分（去掉全连接层）
        self.features = vgg16.features

        # 用1x1卷积替换VGG16的全连接层
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=5, padding=2)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout2d()
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=3, padding=1)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout2d()

        # 最后的卷积层用于生成分割结果
        self.score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        self.upscore = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=64, stride=32, padding=16, bias=False)

        # 初始化上采样层权重
        self.upscore.weight.data.fill_(0)
        self.upscore.weight.data[:, :, 16, 16] = 1  # 双线性插值

    def forward(self, x):
        # 前向传播
        x = self.features(x)
        x = self.relu6(self.conv6(x))
        x = self.dropout6(x)
        x = self.relu7(self.conv7(x))
        x = self.dropout7(x)
        x = self.score_fr(x)
        x = self.upscore(x)
        return x
    

def convert_to_one_hot(labels):
    # 获取原始标签中的类别数
    num_classes = len(torch.unique(labels))

    # 创建一个大小为 (B, C, W, H) 的零张量，其中 C 是类别数
    one_hot_labels = torch.zeros(labels.size(
        0), num_classes, *labels.size()[1:])

    # 遍历每个样本的原始标签，并将相应位置的值设置为 1
    for i in range(num_classes):
        one_hot_labels[:, i, :, :] = (labels == i).float()

    return one_hot_labels


class SegmentationLoss(nn.Module):
    def __init__(self, weight_bce=0.5, weight_connectivity=0, weight_smoothness=0):
        super(SegmentationLoss, self).__init__()
        self.weight_bce = weight_bce
        self.weight_connectivity = weight_connectivity
        self.weight_smoothness = weight_smoothness
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.smooth = 1e-6

    def forward(self, logits, masks):
        # 计算 BCE Loss
        bce_loss = self.bce_loss(logits, masks)

        # 计算 Dice Loss
        probs = torch.sigmoid(logits)
        intersection = torch.sum(probs * masks)
        dice_loss = 1 - (2. * intersection + self.smooth) / \
            (torch.sum(probs) + torch.sum(masks) + self.smooth)
        
        if self.weight_connectivity != 0:
            connectivity_loss = torch.tensor(0, dtype=torch.float32)
            images = torch.argmax(probs, dim=1)
            # masks = torch.argmax(masks, dim=1)
            for i in range(images.size(0)):
                zero_count_image, one_count_image, _, _ = self.connected_components(
                    images[i])
                # zero_count_target, one_count_target, _, _ = self.connected_components(
                #     masks[i])
                connectivity_loss += (zero_count_image - 1) ** 2 + \
                    (one_count_image - 1) ** 2
            connectivity_loss /= images.size(0)
        else:
            connectivity_loss = torch.tensor(0, dtype=torch.float32)

        if self.weight_smoothness != 0:
            smoothness_loss = torch.tensor(0, dtype=torch.float32)
            images = torch.argmax(probs, dim=1)
            # masks = torch.argmax(masks, dim=1)
            pi = torch.tensor(np.pi)
            for i in range(images.size(0)):
                _, _, _, area_perimeter1_image = self.connected_components(
                    images[i])
                # _, _, _, area_perimeter1_target = self.connected_components(
                #     masks[i])
                if area_perimeter1_image is None:
                    smoothness_loss += 1
                    continue
                max_area_index_image = torch.argmax(area_perimeter1_image[:, 0])
                # max_area_index_target = torch.argmax(area_perimeter1_target[:, 0])
                perimeter_area_ratio_image = area_perimeter1_image[max_area_index_image][1] ** 2 * pi / \
                    (4 * area_perimeter1_image[max_area_index_image][0])
                # perimeter_area_ratio_target = area_perimeter1_target[max_area_index_target][1] ** 2 * pi / \
                #     (4 * area_perimeter1_target[max_area_index_target][0])
                smoothness_loss += (perimeter_area_ratio_image - 1) ** 2
            smoothness_loss /= images.size(0)
        else:
            smoothness_loss = torch.tensor(0, dtype=torch.float32)

        # 加权结合损失
        combined_loss = self.weight_bce * bce_loss + \
            (1 - self.weight_bce) * dice_loss + \
            self.weight_connectivity * connectivity_loss + \
            self.weight_smoothness * smoothness_loss

        return combined_loss
    
    def flood_fill(self, image, x, y, target_color, visited, area_perimeter):
        visited.add((x, y))
        area_perimeter[0][0] += 1
        area_perimeter[0][1] += 4

        if x > 0 and image[x - 1][y] == target_color:
            area_perimeter[0][1] -= 1
            if (x - 1, y) not in visited:
                self.flood_fill(image, x - 1, y, target_color,
                                visited, area_perimeter)
        if x < image.size(0) - 1 and image[x + 1][y] == target_color:
            area_perimeter[0][1] -= 1
            if (x + 1, y) not in visited:
                self.flood_fill(image, x + 1, y, target_color,
                                visited, area_perimeter)
        if y > 0 and image[x][y - 1] == target_color:
            area_perimeter[0][1] -= 1
            if (x, y - 1) not in visited:
                self.flood_fill(image, x, y - 1, target_color,
                                visited, area_perimeter)
        if y < image.size(1) - 1 and image[x][y + 1] == target_color:
            area_perimeter[0][1] -= 1
            if (x, y + 1) not in visited:
                self.flood_fill(image, x, y + 1, target_color,
                                visited, area_perimeter)

    def connected_components(self, image):
        visited = set()
        zero_count = torch.tensor(0, dtype=torch.float32)
        one_count = torch.tensor(0, dtype=torch.float32)
        area_perimeter0 = None
        area_perimeter1 = None

        for i in range(image.size(0)):
            for j in range(image.size(1)):
                if (i, j) not in visited:
                    if image[i][j] == 0:
                        zero_count += 1
                        temp = torch.zeros((1, 2), dtype=torch.float32)
                        self.flood_fill(image, i, j, 0, visited, temp)
                        if area_perimeter0 is None:
                            area_perimeter0 = temp.clone()
                        else:
                            area_perimeter0 = torch.cat(
                                (area_perimeter0, temp), dim=0)  # 添加到张量
                    elif image[i][j] == 1:
                        one_count += 1
                        temp = torch.zeros((1, 2), dtype=torch.float32)
                        self.flood_fill(image, i, j, 1, visited, temp)
                        if area_perimeter1 is None:
                            area_perimeter1 = temp.clone()
                        else:
                            area_perimeter1 = torch.cat(
                                (area_perimeter1, temp), dim=0)  # 添加到张量

        # 跳过初始全零项
        return zero_count, one_count, area_perimeter0, area_perimeter1

In [4]:
from tqdm import tqdm


def train_model(model, train_loader, val_loader, optimizer, num_epochs=15):
    for epoch in range(num_epochs):
        # 训练模型
        model.train()
        train_running_loss = 0.0
        for i in range(3):
            if i == 0:
                criterion = SegmentationLoss()
            elif i == 1:
                criterion = SegmentationLoss(weight_connectivity=0.5)
            else:
                criterion = SegmentationLoss(weight_smoothness=0.5)
            for images, labels in tqdm(train_loader):
                images = images.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                # print(outputs.size())
                outputs_resized = F.interpolate(outputs, size=(images.size(
                    2), images.size(3)), mode='bilinear', align_corners=True)
                labels_for_loss = convert_to_one_hot(labels).to(device)
                loss = criterion(outputs_resized, labels_for_loss)  # 根据需要定义损失函数
                # loss = criterion(outputs, labels_for_loss)
                loss.backward()
                # if i == 2:
                train_running_loss += loss.item()
                optimizer.step()
    
            # 打印每个epoch的损失
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_running_loss/len(train_loader)}")
            torch.save(model.state_dict(), f'vgg16_fcn_optim_{epoch+1}_{i+1}.pth')

        model.eval()
        val_running_loss = 0.0
        IOU = 0
        iou_class1 = 0
        iou_class2 = 0
        criterion = SegmentationLoss()
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(images)
            outputs_resized = F.interpolate(outputs, size=(images.size(
                2), images.size(3)), mode='bilinear', align_corners=True)
            one_hot_labels = convert_to_one_hot(labels).to(device)
            loss = criterion(outputs_resized, one_hot_labels)
            # loss = criterion(outputs, labels_for_loss)
            val_running_loss += loss.item()
            outputs = torch.sigmoid(outputs_resized)  # 使用sigmoid函数将输出限制在0到1之间

            # 分别取出两个类别的输出和标签
            outputs_class1 = outputs[:, 0, :, :]  # 第一个类别的输出
            outputs_class2 = outputs[:, 1, :, :]  # 第二个类别的输出
            labels_class1 = one_hot_labels[:, 0, :, :]    # 第一个类别的标签
            labels_class2 = one_hot_labels[:, 1, :, :]    # 第二个类别的标签

            # 计算第一个类别的交并比（IOU）
            intersection_class1 = torch.sum(outputs_class1 * labels_class1)
            union_class1 = torch.sum(outputs_class1) + torch.sum(labels_class1)
            temp_iou1 = (intersection_class1 + 1e-6) / (union_class1 - intersection_class1 + 1e-6)
            iou_class1 += temp_iou1

            # 计算第二个类别的交并比（IOU）
            intersection_class2 = torch.sum(outputs_class2 * labels_class2)
            union_class2 = torch.sum(outputs_class2) + torch.sum(labels_class2)
            temp_iou2 = (intersection_class2 + 1e-6) / (union_class2 - intersection_class2 + 1e-6)
            iou_class2 += temp_iou2

            # 计算平均IOU
            IOU += (temp_iou1 + temp_iou2) / 2

            # 输出遮罩
            # outputs = torch.argmax(outputs, dim=1)
            # mask1 = outputs[0].squeeze().cpu().numpy()
            # final_mask1 = np.zeros_like(mask1)
            # final_mask1[mask1 == 1] = 255
            # final_mask1 = Image.fromarray(final_mask1.astype(np.uint8))
            # final_mask1.save('mask1.png')
            # label_mask1 = labels[0].squeeze().cpu().numpy()
            # final_label_mask1 = np.zeros_like(label_mask1)
            # final_label_mask1[label_mask1 == 1] = 255
            # final_label_mask1 = Image.fromarray(final_label_mask1.astype(np.uint8))
            # final_label_mask1.save('label_mask1.png')
            # mask2 = outputs[1].squeeze().cpu().numpy()
            # final_mask2 = np.zeros_like(mask2)
            # final_mask2[mask2 == 1] = 255
            # final_mask2 = Image.fromarray(final_mask2.astype(np.uint8))
            # final_mask2.save('mask2.png')
            # label_mask2 = labels[1].squeeze().cpu().numpy()
            # final_label_mask2 = np.zeros_like(label_mask2)
            # final_label_mask2[label_mask2 == 1] = 255
            # final_label_mask2 = Image.fromarray(final_label_mask2.astype(np.uint8))
            # final_label_mask2.save('label_mask2.png')

        
        print(
            f"Validation Loss: {val_running_loss/len(val_loader)}, mIOU: {IOU/len(val_loader)}, IOU Background: {iou_class1/len(val_loader)}, IOU Tongue: {iou_class2/len(val_loader)}")

In [5]:
num_classes = 2  # Background and tongue
model = VGG16_FCN(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
train_model(model, train_dataloader, val_dataloader, optimizer)


100%|██████████| 300/300 [02:30<00:00,  1.99it/s]


Epoch 1/15, Loss: 0.34826369126637774


100%|██████████| 300/300 [05:14<00:00,  1.05s/it]


Epoch 1/15, Loss: 0.6576357787847519


100%|██████████| 300/300 [07:33<00:00,  1.51s/it]


Epoch 1/15, Loss: 125.0053027099371


100%|██████████| 100/100 [01:29<00:00,  1.11it/s]


Validation Loss: 0.28452641382813454, mIOU: 0.5641860961914062, IOU Background: 0.752784788608551, IOU Tongue: 0.37558725476264954


100%|██████████| 300/300 [03:25<00:00,  1.46it/s]


Epoch 2/15, Loss: 0.30176383929948014


100%|██████████| 300/300 [07:53<00:00,  1.58s/it]


Epoch 2/15, Loss: 0.615066617478927


100%|██████████| 300/300 [06:16<00:00,  1.26s/it]


Epoch 2/15, Loss: 124.79949548100431


100%|██████████| 100/100 [00:48<00:00,  2.06it/s]


Validation Loss: 0.2845044917613268, mIOU: 0.5636608600616455, IOU Background: 0.7528432011604309, IOU Tongue: 0.37447860836982727


100%|██████████| 300/300 [02:33<00:00,  1.96it/s]


Epoch 3/15, Loss: 0.2981033211946487


100%|██████████| 300/300 [06:49<00:00,  1.37s/it]


Epoch 3/15, Loss: 0.6000181845078866


100%|██████████| 300/300 [07:38<00:00,  1.53s/it]


Epoch 3/15, Loss: 130.6844944265733


100%|██████████| 100/100 [01:13<00:00,  1.36it/s]


Validation Loss: 0.2813735644519329, mIOU: 0.5690014362335205, IOU Background: 0.7586777806282043, IOU Tongue: 0.37932533025741577


100%|██████████| 300/300 [03:27<00:00,  1.45it/s]


Epoch 4/15, Loss: 0.2975315910826127


100%|██████████| 300/300 [07:37<00:00,  1.53s/it]


Epoch 4/15, Loss: 0.5949972349653642


100%|██████████| 300/300 [06:01<00:00,  1.20s/it]


Epoch 4/15, Loss: 123.205424633647


100%|██████████| 100/100 [00:49<00:00,  2.04it/s]


Validation Loss: 0.2830453377217054, mIOU: 0.5667007565498352, IOU Background: 0.7543577551841736, IOU Tongue: 0.3790437877178192


100%|██████████| 300/300 [02:24<00:00,  2.08it/s]


Epoch 5/15, Loss: 0.2974910165121158


KeyboardInterrupt: 

: 

In [None]:
import torch
import torch.nn as nn
from torchvision import models
import numpy as np
from PIL import Image
from torchvision import transforms


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 定义基于VGG16的FCN网络
class VGG16_FCN(nn.Module):
    def __init__(self, num_classes):
        super(VGG16_FCN, self).__init__()
        # 加载预训练的VGG16模型
        vgg16 = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT)

        # 取出VGG16的前面部分（去掉全连接层）
        self.features = vgg16.features

        # 用1x1卷积替换VGG16的全连接层
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout2d()
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout2d()

        # 最后的卷积层用于生成分割结果
        self.score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        self.upscore = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=64, stride=32, padding=16, bias=False)

        # 初始化上采样层权重
        self.upscore.weight.data.fill_(0)
        self.upscore.weight.data[:, :, 16, 16] = 1  # 双线性插值

    def forward(self, x):
        # 前向传播
        x = self.features(x)
        x = self.relu6(self.conv6(x))
        x = self.dropout6(x)
        x = self.relu7(self.conv7(x))
        x = self.dropout7(x)
        x = self.score_fr(x)
        x = self.upscore(x)
        return x  


def get_segmentation_mask(model, image_path):

    model.eval()

    image = Image.open(image_path).convert('RGB')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225])
    ])
    image = transform(image).unsqueeze(0)
    print(image.size())
    image = image.to(device)
    with torch.no_grad():
        output = model(image)
    output = torch.sigmoid(output)
    output = torch.argmax(output, dim=1)
    mask = output.squeeze().cpu().numpy()
    return mask

# Get segmentation mask for a sample image

image_path = './data/test/label_data/test1.jpg'  # Replace with your image path
# 载入模型权重
model = VGG16_FCN(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.load_state_dict(torch.load('model.pth'))

mask = get_segmentation_mask(model, image_path)


# Create the final mask where tongue region is white and background is black

final_mask = np.zeros_like(mask)

final_mask[mask == 0] = 255
final_mask = Image.fromarray(final_mask.astype(np.uint8))

# 叠加原图
image = Image.open(image_path).convert('RGB')
image = image.resize(final_mask.size)
image.paste(final_mask, (0, 0), final_mask)
image

In [None]:
import numpy as np
import torch

# 生成随机的形状为 (2, 2, 3, 3) 的张量，值在 0 到 1 之间
random_tensor = np.random.rand(2, 2, 3, 3)

print(random_tensor)

# 将 NumPy 数组转换为 PyTorch 张量
random_tensor = torch.tensor(random_tensor, dtype=torch.float32)
print(torch.sigmoid(random_tensor))
print(torch.softmax(random_tensor, dim=1))
print(torch.argmax(random_tensor, dim=1, keepdim=True))