## 数据预处理

In [6]:
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np


seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


class TongueDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_dir = os.path.join(root_dir, 'images')
        self.mask_dir = os.path.join(root_dir, 'annotations')
        self.class_names = self.load_class_names()

    def load_class_names(self):
        class_names_file = os.path.join(self.root_dir, 'class_names.txt')
        with open(class_names_file, 'r') as f:
            class_names = f.readlines()
        class_names = [name.strip() for name in class_names]
        return class_names

    def __len__(self):
        return len(os.listdir(self.image_dir))

    def __getitem__(self, idx):
        img_name = os.listdir(self.image_dir)[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(
            self.mask_dir, os.path.splitext(img_name)[0] + '.png')

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path)

        if self.transform:
            image = self.transform(image)
            mask = transforms.Resize((1000, 1000))(mask)
            mask = torch.tensor(np.array(mask), dtype=torch.float32)

        return image, mask


data_transform = transforms.Compose([
    transforms.Resize((1000, 1000)),
    transforms.ToTensor()
])

dataset = TongueDataset(root_dir='./data/split_dataset_ultra', transform=data_transform)

print(len(dataset))
print(dataset.class_names)
print(dataset[0][0].shape)
print(dataset[0][1].shape)
print(dataset[0][1].unique())
print(dataset[0][1])

1000
['_background_', 'Tg']
torch.Size([3, 1000, 1000])
torch.Size([1000, 1000])
tensor([0., 1.])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [7]:
from torch.utils.data import random_split


aug_transform = transforms.Compose([
    transforms.RandomApply(transforms.ColorJitter(brightness=0.2, contrast=0.2,
                           saturation=0.2, hue=0.1)),
    transforms.RandomApply(transforms.GaussianBlur(3, sigma=(0.1, 2.0)), 0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225]),
])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225]),
])

# 划分数据集
train_size = int(0.6 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
train_dataloader.dataset.transform = aug_transform
for i, (img, mask) in enumerate(train_dataloader):
    print(img.shape, mask.shape)
    print(mask.unique())
    break
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)
val_dataloader.dataset.transform = transform
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)
test_dataloader.dataset.transform = transform

torch.Size([2, 3, 1000, 1000]) torch.Size([2, 1000, 1000])
tensor([0., 1.])


## 定义网络

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 定义基于VGG16的FCN网络
class VGG16_FCN(nn.Module):
    def __init__(self, num_classes):
        super(VGG16_FCN, self).__init__()
        # 加载预训练的VGG16模型
        vgg16 = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT)

        # 取出VGG16的前面部分（去掉全连接层）
        self.features = vgg16.features

        # 用1x1卷积替换VGG16的全连接层
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=5, padding=2)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout2d()
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=3, padding=1)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout2d()

        # 最后的卷积层用于生成分割结果
        self.score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        self.upscore = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=64, stride=32, padding=16, bias=False)

        # 初始化上采样层权重
        self.upscore.weight.data.fill_(0)
        self.upscore.weight.data[:, :, 16, 16] = 1  # 双线性插值

    def forward(self, x):
        # 前向传播
        x = self.features(x)
        x = self.relu6(self.conv6(x))
        x = self.dropout6(x)
        x = self.relu7(self.conv7(x))
        x = self.dropout7(x)
        x = self.score_fr(x)
        x = self.upscore(x)
        return x
    

def convert_to_one_hot(labels):
    # 获取原始标签中的类别数
    num_classes = len(torch.unique(labels))

    # 创建一个大小为 (B, C, W, H) 的零张量，其中 C 是类别数
    one_hot_labels = torch.zeros(labels.size(
        0), num_classes, *labels.size()[1:])

    # 遍历每个样本的原始标签，并将相应位置的值设置为 1
    for i in range(num_classes):
        one_hot_labels[:, i, :, :] = (labels == i).float()

    return one_hot_labels


class SegmentationLoss(nn.Module):
    def __init__(self, weight_bce=0.5, weight_connectivity=0, weight_smoothness=0):
        super(SegmentationLoss, self).__init__()
        self.weight_bce = weight_bce
        self.weight_connectivity = weight_connectivity
        self.weight_smoothness = weight_smoothness
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.smooth = 1e-6

    def forward(self, logits, masks):
        # 计算 BCE Loss
        bce_loss = self.bce_loss(logits, masks)

        # 计算 Dice Loss
        probs = torch.sigmoid(logits)
        intersection = torch.sum(probs * masks)
        dice_loss = 1 - (2. * intersection + self.smooth) / \
            (torch.sum(probs) + torch.sum(masks) + self.smooth)
        # 计算连通性损失
        connectivity_loss = 0
        if self.weight_connectivity > 0:
            connectivity_loss = self.connectivity_loss(logits)

        # 计算平滑性损失
        smoothness_loss = 0
        if self.weight_smoothness > 0:
            smoothness_loss = self.smoothness_loss(logits)

        # 加权结合损失
        combined_loss = self.weight_bce * bce_loss + \
            (1 - self.weight_bce) * dice_loss + \
            self.weight_connectivity * connectivity_loss + \
            self.weight_smoothness * smoothness_loss

        return combined_loss

    # 连通性损失函数
    def connectivity_loss(self, mask):
        class_indices = torch.tensor(torch.argmax(
            mask, dim=1, keepdim=True), dtype=torch.float32).to(device)
        connected_components = torch.unique(class_indices)
        count = 0
        for value in connected_components:
            binary_mask = (class_indices == value).float()
            _, labels = binary_mask.view(1, -1).unique(dim=1, return_inverse=True)
            count += labels.max().item() + 1
        loss = count - 1
        return loss

    # 平滑性损失函数
    def smoothness_loss(self, mask):
        class_indices = torch.tensor(torch.argmax(
            mask, dim=1, keepdim=True), dtype=torch.float32).to(device)
        sobel_x = torch.abs(F.conv2d(class_indices, torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device='cuda').unsqueeze(0).unsqueeze(0)))
        sobel_y = torch.abs(F.conv2d(class_indices, torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device='cuda').unsqueeze(0).unsqueeze(0)))
        edge_curvature = torch.mean(sobel_x + sobel_y)
        loss = edge_curvature
        return loss

In [9]:
from tqdm import tqdm


def train_model(model, train_loader, val_loader, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        # 训练模型
        model.train()
        train_running_loss = 0.0
        for i in range(3):
            if i == 0:
                criterion = SegmentationLoss()
            elif i == 1:
                criterion = SegmentationLoss(weight_connectivity=2)
            else:
                criterion = SegmentationLoss(weight_smoothness=2)
            for images, labels in tqdm(train_loader):
                images = images.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                # print(outputs.size())
                outputs_resized = F.interpolate(outputs, size=(images.size(
                    2), images.size(3)), mode='bilinear', align_corners=True)
                labels_for_loss = convert_to_one_hot(labels).to(device)
                loss = criterion(outputs_resized, labels_for_loss)  # 根据需要定义损失函数
                # loss = criterion(outputs, labels_for_loss)
                loss.backward()
                # if i == 2:
                train_running_loss += loss.item()
                optimizer.step()
    
            # 打印每个epoch的损失
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_running_loss/len(train_loader)}")
            # torch.save(model.state_dict(), f'vgg16_fcn_{epoch+1}_{i+1}.pth')

        model.eval()
        val_running_loss = 0.0
        IOU = 0
        iou_class1 = 0
        iou_class2 = 0
        criterion = SegmentationLoss()
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(images)
            outputs_resized = F.interpolate(outputs, size=(images.size(
                2), images.size(3)), mode='bilinear', align_corners=True)
            one_hot_labels = convert_to_one_hot(labels).to(device)
            loss = criterion(outputs_resized, one_hot_labels)
            # loss = criterion(outputs, labels_for_loss)
            val_running_loss += loss.item()
            outputs = torch.sigmoid(outputs_resized)  # 使用sigmoid函数将输出限制在0到1之间

            # 分别取出两个类别的输出和标签
            outputs_class1 = outputs[:, 0, :, :]  # 第一个类别的输出
            outputs_class2 = outputs[:, 1, :, :]  # 第二个类别的输出
            labels_class1 = one_hot_labels[:, 0, :, :]    # 第一个类别的标签
            labels_class2 = one_hot_labels[:, 1, :, :]    # 第二个类别的标签

            # 计算第一个类别的交并比（IOU）
            intersection_class1 = torch.sum(outputs_class1 * labels_class1)
            union_class1 = torch.sum(outputs_class1) + torch.sum(labels_class1)
            temp_iou1 = (intersection_class1 + 1e-6) / (union_class1 - intersection_class1 + 1e-6)
            iou_class1 += temp_iou1

            # 计算第二个类别的交并比（IOU）
            intersection_class2 = torch.sum(outputs_class2 * labels_class2)
            union_class2 = torch.sum(outputs_class2) + torch.sum(labels_class2)
            temp_iou2 = (intersection_class2 + 1e-6) / (union_class2 - intersection_class2 + 1e-6)
            iou_class2 += temp_iou2

            # 计算平均IOU
            IOU += (temp_iou1 + temp_iou2) / 2

            # 输出遮罩
            outputs = torch.argmax(outputs, dim=1)
            mask1 = outputs[0].squeeze().cpu().numpy()
            final_mask1 = np.zeros_like(mask1)
            final_mask1[mask1 == 1] = 255
            final_mask1 = Image.fromarray(final_mask1.astype(np.uint8))
            final_mask1.save('mask1.png')
            label_mask1 = labels[0].squeeze().cpu().numpy()
            final_label_mask1 = np.zeros_like(label_mask1)
            final_label_mask1[label_mask1 == 1] = 255
            final_label_mask1 = Image.fromarray(final_label_mask1.astype(np.uint8))
            final_label_mask1.save('label_mask1.png')
            mask2 = outputs[1].squeeze().cpu().numpy()
            final_mask2 = np.zeros_like(mask2)
            final_mask2[mask2 == 1] = 255
            final_mask2 = Image.fromarray(final_mask2.astype(np.uint8))
            final_mask2.save('mask2.png')
            label_mask2 = labels[1].squeeze().cpu().numpy()
            final_label_mask2 = np.zeros_like(label_mask2)
            final_label_mask2[label_mask2 == 1] = 255
            final_label_mask2 = Image.fromarray(final_label_mask2.astype(np.uint8))
            final_label_mask2.save('label_mask2.png')

        
        print(
            f"Validation Loss: {val_running_loss/len(val_loader)}, mIOU: {IOU/len(val_loader)}, IOU Background: {iou_class1/len(val_loader)}, IOU Tongue: {iou_class2/len(val_loader)}")

In [10]:
num_classes = 2  # Background and tongue
model = VGG16_FCN(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
train_model(model, train_dataloader, val_dataloader, optimizer)


100%|██████████| 300/300 [06:24<00:00,  1.28s/it]


Epoch 1/10, Loss: 0.15614316184694568


  class_indices = torch.tensor(torch.argmax(
100%|██████████| 300/300 [06:02<00:00,  1.21s/it]


Epoch 1/10, Loss: 6.2018163971044125


  class_indices = torch.tensor(torch.argmax(
100%|██████████| 300/300 [05:04<00:00,  1.02s/it]


Epoch 1/10, Loss: 6.269814861584455


100%|██████████| 100/100 [01:16<00:00,  1.31it/s]


Validation Loss: 0.025639284355565905, mIOU: 0.94784015417099, IOU Background: 0.9738190770149231, IOU Tongue: 0.9218613505363464


100%|██████████| 300/300 [05:06<00:00,  1.02s/it]


Epoch 2/10, Loss: 0.022381475743216774


100%|██████████| 300/300 [05:02<00:00,  1.01s/it]


Epoch 2/10, Loss: 6.033679944650891


100%|██████████| 300/300 [05:05<00:00,  1.02s/it]


Epoch 2/10, Loss: 6.089551434946867


100%|██████████| 100/100 [01:11<00:00,  1.40it/s]


Validation Loss: 0.016394456662237645, mIOU: 0.9663571119308472, IOU Background: 0.9835390448570251, IOU Tongue: 0.9491748809814453


100%|██████████| 300/300 [05:01<00:00,  1.00s/it]


Epoch 3/10, Loss: 0.025447364879461625


100%|██████████| 300/300 [05:19<00:00,  1.07s/it]


Epoch 3/10, Loss: 6.038619524074408


100%|██████████| 300/300 [05:20<00:00,  1.07s/it]


Epoch 3/10, Loss: 6.08311036701159


100%|██████████| 100/100 [01:11<00:00,  1.40it/s]


Validation Loss: 0.007345855028834194, mIOU: 0.9862232804298401, IOU Background: 0.9932237863540649, IOU Tongue: 0.9792233109474182


100%|██████████| 300/300 [05:03<00:00,  1.01s/it]


Epoch 4/10, Loss: 0.006819150667482366


100%|██████████| 300/300 [05:50<00:00,  1.17s/it]


Epoch 4/10, Loss: 6.01509184302995


100%|██████████| 300/300 [06:18<00:00,  1.26s/it]


Epoch 4/10, Loss: 6.07816922235225


100%|██████████| 100/100 [01:39<00:00,  1.00it/s]


Validation Loss: 0.021899611265398564, mIOU: 0.965167224407196, IOU Background: 0.982033371925354, IOU Tongue: 0.9483011364936829


100%|██████████| 300/300 [06:10<00:00,  1.24s/it]


Epoch 5/10, Loss: 0.011047916888880233


100%|██████████| 300/300 [05:24<00:00,  1.08s/it]


Epoch 5/10, Loss: 6.018418585935918


100%|██████████| 300/300 [05:03<00:00,  1.01s/it]


Epoch 5/10, Loss: 6.058987837537813


100%|██████████| 100/100 [01:12<00:00,  1.37it/s]


Validation Loss: 0.0073505320237018165, mIOU: 0.9870427250862122, IOU Background: 0.9935647249221802, IOU Tongue: 0.9805207848548889


100%|██████████| 300/300 [05:00<00:00,  1.00s/it]


Epoch 6/10, Loss: 0.005762331720131139


100%|██████████| 300/300 [05:04<00:00,  1.01s/it]


Epoch 6/10, Loss: 6.011235408222613


100%|██████████| 300/300 [05:00<00:00,  1.00s/it]


Epoch 6/10, Loss: 6.050891201947816


100%|██████████| 100/100 [01:11<00:00,  1.39it/s]


Validation Loss: 0.006209837822243571, mIOU: 0.9890509843826294, IOU Background: 0.9946458339691162, IOU Tongue: 0.9834563136100769


100%|██████████| 300/300 [05:32<00:00,  1.11s/it]


Epoch 7/10, Loss: 0.00481397620945548


100%|██████████| 300/300 [05:16<00:00,  1.05s/it]


Epoch 7/10, Loss: 6.009516108752384


100%|██████████| 300/300 [05:15<00:00,  1.05s/it]


Epoch 7/10, Loss: 6.0486799290427005


100%|██████████| 100/100 [01:11<00:00,  1.40it/s]


Validation Loss: 0.005712169385515153, mIOU: 0.9904630184173584, IOU Background: 0.9952514171600342, IOU Tongue: 0.9856749773025513


100%|██████████| 300/300 [04:58<00:00,  1.00it/s]


Epoch 8/10, Loss: 0.004458037110355993


100%|██████████| 300/300 [05:01<00:00,  1.00s/it]


Epoch 8/10, Loss: 6.058027822864242


100%|██████████| 300/300 [05:03<00:00,  1.01s/it]


Epoch 8/10, Loss: 6.127323911051887


100%|██████████| 100/100 [01:11<00:00,  1.39it/s]


Validation Loss: 0.010523166446946562, mIOU: 0.9818862676620483, IOU Background: 0.9909308552742004, IOU Tongue: 0.9728416800498962


100%|██████████| 300/300 [05:01<00:00,  1.01s/it]


Epoch 9/10, Loss: 0.009229930222500116


100%|██████████| 300/300 [05:54<00:00,  1.18s/it]


Epoch 9/10, Loss: 6.018838938057888


100%|██████████| 300/300 [05:44<00:00,  1.15s/it]


Epoch 9/10, Loss: 6.059588197897344


100%|██████████| 100/100 [01:41<00:00,  1.02s/it]


Validation Loss: 0.006989772343076766, mIOU: 0.9877044558525085, IOU Background: 0.9939173460006714, IOU Tongue: 0.9814914464950562


 15%|█▌        | 45/300 [00:49<04:18,  1.01s/it]

In [None]:
import torch
import torch.nn as nn
from torchvision import models
import numpy as np
from PIL import Image
from torchvision import transforms


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 定义基于VGG16的FCN网络
class VGG16_FCN(nn.Module):
    def __init__(self, num_classes):
        super(VGG16_FCN, self).__init__()
        # 加载预训练的VGG16模型
        vgg16 = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT)

        # 取出VGG16的前面部分（去掉全连接层）
        self.features = vgg16.features

        # 用1x1卷积替换VGG16的全连接层
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout2d()
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout2d()

        # 最后的卷积层用于生成分割结果
        self.score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        self.upscore = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=64, stride=32, padding=16, bias=False)

        # 初始化上采样层权重
        self.upscore.weight.data.fill_(0)
        self.upscore.weight.data[:, :, 16, 16] = 1  # 双线性插值

    def forward(self, x):
        # 前向传播
        x = self.features(x)
        x = self.relu6(self.conv6(x))
        x = self.dropout6(x)
        x = self.relu7(self.conv7(x))
        x = self.dropout7(x)
        x = self.score_fr(x)
        x = self.upscore(x)
        return x  


def get_segmentation_mask(model, image_path):

    model.eval()

    image = Image.open(image_path).convert('RGB')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225])
    ])
    image = transform(image).unsqueeze(0)
    print(image.size())
    image = image.to(device)
    with torch.no_grad():
        output = model(image)
    output = torch.sigmoid(output)
    output = torch.argmax(output, dim=1)
    mask = output.squeeze().cpu().numpy()
    return mask

# Get segmentation mask for a sample image

image_path = './data/test/label_data/test1.jpg'  # Replace with your image path
# 载入模型权重
model = VGG16_FCN(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.load_state_dict(torch.load('model.pth'))

mask = get_segmentation_mask(model, image_path)


# Create the final mask where tongue region is white and background is black

final_mask = np.zeros_like(mask)

final_mask[mask == 0] = 255
final_mask = Image.fromarray(final_mask.astype(np.uint8))

# 叠加原图
image = Image.open(image_path).convert('RGB')
image = image.resize(final_mask.size)
image.paste(final_mask, (0, 0), final_mask)
image

In [5]:
import numpy as np
import torch

# 生成随机的形状为 (2, 2, 3, 3) 的张量，值在 0 到 1 之间
random_tensor = np.random.rand(2, 2, 3, 3)

print(random_tensor)

# 将 NumPy 数组转换为 PyTorch 张量
random_tensor = torch.tensor(random_tensor, dtype=torch.float32)
print(torch.sigmoid(random_tensor))
print(torch.softmax(random_tensor, dim=1))
print(torch.argmax(random_tensor, dim=1, keepdim=True))

[[[[0.12864163 0.33627678 0.32833279]
   [0.48987763 0.58399996 0.01479922]
   [0.2404101  0.11303117 0.55008898]]

  [[0.63509489 0.93711281 0.04018763]
   [0.15410985 0.4658085  0.61758847]
   [0.79034205 0.93190566 0.43294675]]]


 [[[0.06077017 0.07451507 0.1792317 ]
   [0.86664552 0.00749065 0.53963023]
   [0.90220283 0.50632795 0.3555912 ]]

  [[0.58280628 0.51122657 0.54808611]
   [0.18869061 0.99385473 0.57085092]
   [0.98435916 0.46737864 0.85101047]]]]
tensor([[[[0.5321, 0.5833, 0.5814],
          [0.6201, 0.6420, 0.5037],
          [0.5598, 0.5282, 0.6342]],

         [[0.6536, 0.7185, 0.5100],
          [0.5385, 0.6144, 0.6497],
          [0.6879, 0.7175, 0.6066]]],


        [[[0.5152, 0.5186, 0.5447],
          [0.7040, 0.5019, 0.6317],
          [0.7114, 0.6239, 0.5880]],

         [[0.6417, 0.6251, 0.6337],
          [0.5470, 0.7298, 0.6390],
          [0.7280, 0.6148, 0.7008]]]])
tensor([[[[0.3760, 0.3542, 0.5715],
          [0.5832, 0.5295, 0.3537],
          [0.3659,