## 数据预处理

In [1]:
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np


seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


class TongueDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_dir = os.path.join(root_dir, 'images')
        self.mask_dir = os.path.join(root_dir, 'annotations')
        self.class_names = self.load_class_names()

    def load_class_names(self):
        class_names_file = os.path.join(self.root_dir, 'class_names.txt')
        with open(class_names_file, 'r') as f:
            class_names = f.readlines()
        class_names = [name.strip() for name in class_names]
        return class_names

    def __len__(self):
        return len(os.listdir(self.image_dir))

    def __getitem__(self, idx):
        img_name = os.listdir(self.image_dir)[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(
            self.mask_dir, os.path.splitext(img_name)[0] + '.png')

        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path)

        if self.transform:
            image = self.transform(image)
            mask = transforms.Resize((1000, 1000))(mask)
            mask = torch.tensor(np.array(mask), dtype=torch.float32)

        return image, mask


data_transform = transforms.Compose([
    transforms.Resize((1000, 1000)),
    transforms.ToTensor()
])

dataset = TongueDataset(root_dir='./data/split_dataset_ultra', transform=data_transform)

print(len(dataset))
print(dataset.class_names)
print(dataset[0][0].shape)
print(dataset[0][1].shape)
print(dataset[0][1].unique())
print(dataset[0][1])

1000
['_background_', 'Tg']
torch.Size([3, 1000, 1000])
torch.Size([1000, 1000])
tensor([0., 1.])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [2]:
from torch.utils.data import random_split


aug_transform = transforms.Compose([
    transforms.RandomApply(transforms.ColorJitter(brightness=0.2, contrast=0.2,
                           saturation=0.2, hue=0.1)),
    transforms.RandomApply(transforms.GaussianBlur(3, sigma=(0.1, 2.0)), 0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225]),
])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225]),
])

# 划分数据集
train_size = int(0.6 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
train_dataloader.dataset.transform = aug_transform
for i, (img, mask) in enumerate(train_dataloader):
    print(img.shape, mask.shape)
    print(mask.unique())
    break
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)
val_dataloader.dataset.transform = transform
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)
test_dataloader.dataset.transform = transform

torch.Size([2, 3, 1000, 1000]) torch.Size([2, 1000, 1000])
tensor([0., 1.])


## 定义网络

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
import cv2


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 定义基于VGG16的FCN网络
class VGG16_FCN(nn.Module):
    def __init__(self, num_classes):
        super(VGG16_FCN, self).__init__()
        # 加载预训练的VGG16模型
        vgg16 = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT)

        # 取出VGG16的前面部分（去掉全连接层）
        self.features = vgg16.features

        # 用1x1卷积替换VGG16的全连接层
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=5, padding=2)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout2d()
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=3, padding=1)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout2d()

        # 最后的卷积层用于生成分割结果
        self.score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        self.upscore = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=64, stride=32, padding=16, bias=False)

        # 初始化上采样层权重
        self.upscore.weight.data.fill_(0)
        self.upscore.weight.data[:, :, 16, 16] = 1  # 双线性插值

    def forward(self, x):
        # 前向传播
        x = self.features(x)
        x = self.relu6(self.conv6(x))
        x = self.dropout6(x)
        x = self.relu7(self.conv7(x))
        x = self.dropout7(x)
        x = self.score_fr(x)
        x = self.upscore(x)
        return x
    

def convert_to_one_hot(labels):
    # 获取原始标签中的类别数
    num_classes = len(torch.unique(labels))

    # 创建一个大小为 (B, C, W, H) 的零张量，其中 C 是类别数
    one_hot_labels = torch.zeros(labels.size(
        0), num_classes, *labels.size()[1:])

    # 遍历每个样本的原始标签，并将相应位置的值设置为 1
    for i in range(num_classes):
        one_hot_labels[:, i, :, :] = (labels == i).float()

    return one_hot_labels


class SegmentationLoss(nn.Module):
    def __init__(self, weight_bce=0.5, weight_connectivity=0, weight_smoothness=0):
        super(SegmentationLoss, self).__init__()
        self.weight_bce = weight_bce
        self.weight_connectivity = weight_connectivity
        self.weight_smoothness = weight_smoothness
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.smooth = 1e-6

    def forward(self, logits, masks):
        # 计算 BCE Loss
        bce_loss = self.bce_loss(logits, masks)

        # 计算 Dice Loss
        probs = torch.sigmoid(logits)
        intersection = torch.sum(probs * masks)
        dice_loss = 1 - (2. * intersection + self.smooth) / \
            (torch.sum(probs) + torch.sum(masks) + self.smooth)
        # 计算连通性损失
        connectivity_loss = 0
        if self.weight_connectivity > 0:
            connectivity_loss = self.connectivity_loss(probs)

        # 计算平滑性损失
        smoothness_loss = 0
        if self.weight_smoothness > 0:
            smoothness_loss = self.smoothness_loss(probs)

        # 加权结合损失
        combined_loss = self.weight_bce * bce_loss + \
            (1 - self.weight_bce) * dice_loss + \
            self.weight_connectivity * connectivity_loss + \
            self.weight_smoothness * smoothness_loss

        return combined_loss

    # 连通性正则化
    def connectivity_loss(self, mask):
        class_indices = np.array(torch.argmax(mask, dim=1, keepdim=True).cpu(), dtype=np.uint8)
        class_indices = class_indices.reshape(class_indices.shape[0], class_indices.shape[2], class_indices.shape[3])
        contours, _ = zip(*[cv2.findContours(
            class_indice, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) for class_indice in class_indices])
        print(len(contours[0]), len(contours[1]))
        counts = [len(contour) for contour in contours]
        counts = sum(counts) 
        return counts - 2

    # 平滑性正则化
    def smoothness_loss(self, mask):
        class_indices = np.array(torch.argmax(
            mask, dim=1, keepdim=True).cpu(), dtype=np.uint8)
        class_indices = class_indices.reshape(
            class_indices.shape[0], class_indices.shape[2], class_indices.shape[3])
        contours, _ = zip(*[cv2.findContours(
            class_indice, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) for class_indice in class_indices])
        print(len(contours[0]), len(contours[1]))
        max_contour = [max(contour, key=cv2.contourArea) for contour in contours if len(contour) > 0]
        area = [cv2.contourArea(contour) for contour in max_contour]
        perimeter = [cv2.arcLength(contour, True) for contour in max_contour]
        division = [perimeter[i] * perimeter[i] * np.pi / (4 * area[i]) - 1
                    for i in range(len(area)) if area[i] > 0]
        return np.mean(division) if len(division) > 0 else 0

In [4]:
from tqdm import tqdm


def train_model(model, train_loader, val_loader, optimizer, num_epochs=25):
    for epoch in range(num_epochs):
        # 训练模型
        model.train()
        train_running_loss = 0.0
        for i in range(3):
            if i == 0:
                criterion = SegmentationLoss()
            elif i == 1:
                criterion = SegmentationLoss(weight_connectivity=1)
            else:
                criterion = SegmentationLoss(weight_smoothness=0.5)
            for images, labels in tqdm(train_loader):
                images = images.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                # print(outputs.size())
                outputs_resized = F.interpolate(outputs, size=(images.size(
                    2), images.size(3)), mode='bilinear', align_corners=True)
                labels_for_loss = convert_to_one_hot(labels).to(device)
                loss = criterion(outputs_resized, labels_for_loss)  # 根据需要定义损失函数
                # loss = criterion(outputs, labels_for_loss)
                loss.backward()
                # if i == 2:
                train_running_loss += loss.item()
                optimizer.step()
    
            # 打印每个epoch的损失
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_running_loss/len(train_loader)}")
            torch.save(model.state_dict(), f'vgg16_fcn_optim_{epoch+1}_{i+1}.pth')

        model.eval()
        val_running_loss = 0.0
        IOU = 0
        iou_class1 = 0
        iou_class2 = 0
        criterion = SegmentationLoss()
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(images)
            outputs_resized = F.interpolate(outputs, size=(images.size(
                2), images.size(3)), mode='bilinear', align_corners=True)
            one_hot_labels = convert_to_one_hot(labels).to(device)
            loss = criterion(outputs_resized, one_hot_labels)
            # loss = criterion(outputs, labels_for_loss)
            val_running_loss += loss.item()
            outputs = torch.sigmoid(outputs_resized)  # 使用sigmoid函数将输出限制在0到1之间

            # 分别取出两个类别的输出和标签
            outputs_class1 = outputs[:, 0, :, :]  # 第一个类别的输出
            outputs_class2 = outputs[:, 1, :, :]  # 第二个类别的输出
            labels_class1 = one_hot_labels[:, 0, :, :]    # 第一个类别的标签
            labels_class2 = one_hot_labels[:, 1, :, :]    # 第二个类别的标签

            # 计算第一个类别的交并比（IOU）
            intersection_class1 = torch.sum(outputs_class1 * labels_class1)
            union_class1 = torch.sum(outputs_class1) + torch.sum(labels_class1)
            temp_iou1 = (intersection_class1 + 1e-6) / (union_class1 - intersection_class1 + 1e-6)
            iou_class1 += temp_iou1

            # 计算第二个类别的交并比（IOU）
            intersection_class2 = torch.sum(outputs_class2 * labels_class2)
            union_class2 = torch.sum(outputs_class2) + torch.sum(labels_class2)
            temp_iou2 = (intersection_class2 + 1e-6) / (union_class2 - intersection_class2 + 1e-6)
            iou_class2 += temp_iou2

            # 计算平均IOU
            IOU += (temp_iou1 + temp_iou2) / 2

            # 输出遮罩
            # outputs = torch.argmax(outputs, dim=1)
            # mask1 = outputs[0].squeeze().cpu().numpy()
            # final_mask1 = np.zeros_like(mask1)
            # final_mask1[mask1 == 1] = 255
            # final_mask1 = Image.fromarray(final_mask1.astype(np.uint8))
            # final_mask1.save('mask1.png')
            # label_mask1 = labels[0].squeeze().cpu().numpy()
            # final_label_mask1 = np.zeros_like(label_mask1)
            # final_label_mask1[label_mask1 == 1] = 255
            # final_label_mask1 = Image.fromarray(final_label_mask1.astype(np.uint8))
            # final_label_mask1.save('label_mask1.png')
            # mask2 = outputs[1].squeeze().cpu().numpy()
            # final_mask2 = np.zeros_like(mask2)
            # final_mask2[mask2 == 1] = 255
            # final_mask2 = Image.fromarray(final_mask2.astype(np.uint8))
            # final_mask2.save('mask2.png')
            # label_mask2 = labels[1].squeeze().cpu().numpy()
            # final_label_mask2 = np.zeros_like(label_mask2)
            # final_label_mask2[label_mask2 == 1] = 255
            # final_label_mask2 = Image.fromarray(final_label_mask2.astype(np.uint8))
            # final_label_mask2.save('label_mask2.png')

        
        print(
            f"Validation Loss: {val_running_loss/len(val_loader)}, mIOU: {IOU/len(val_loader)}, IOU Background: {iou_class1/len(val_loader)}, IOU Tongue: {iou_class2/len(val_loader)}")

In [5]:
num_classes = 2  # Background and tongue
model = VGG16_FCN(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
train_model(model, train_dataloader, val_dataloader, optimizer)


100%|██████████| 300/300 [05:08<00:00,  1.03s/it]


Epoch 25/25, Loss: 0.003267749783893426


  0%|          | 0/300 [00:00<?, ?it/s]

1 1


  0%|          | 1/300 [00:01<05:23,  1.08s/it]

1 1


  1%|          | 2/300 [00:02<05:19,  1.07s/it]

1 1


  1%|          | 3/300 [00:03<05:19,  1.07s/it]

1 1


  1%|▏         | 4/300 [00:04<05:14,  1.06s/it]

1 2


  2%|▏         | 5/300 [00:05<05:10,  1.05s/it]

1 1


  2%|▏         | 6/300 [00:06<05:08,  1.05s/it]

1 1


  2%|▏         | 7/300 [00:07<05:06,  1.05s/it]

1 1


  3%|▎         | 8/300 [00:08<05:06,  1.05s/it]

1 1


  3%|▎         | 9/300 [00:09<05:04,  1.05s/it]

1 1


  3%|▎         | 10/300 [00:10<05:07,  1.06s/it]

2 1


  4%|▎         | 11/300 [00:11<05:05,  1.06s/it]

1 1


  4%|▍         | 12/300 [00:12<05:03,  1.05s/it]

1 1


  4%|▍         | 13/300 [00:13<05:05,  1.06s/it]

1 1


  5%|▍         | 14/300 [00:14<05:00,  1.05s/it]

1 1


  5%|▌         | 15/300 [00:15<04:55,  1.04s/it]

1 2


  5%|▌         | 16/300 [00:16<04:52,  1.03s/it]

1 1


  6%|▌         | 17/300 [00:17<04:50,  1.03s/it]

1 1


  6%|▌         | 18/300 [00:18<04:51,  1.03s/it]

1 1


  6%|▋         | 19/300 [00:19<04:49,  1.03s/it]

1 1


  7%|▋         | 20/300 [00:20<04:47,  1.03s/it]

1 1


  7%|▋         | 21/300 [00:21<04:45,  1.02s/it]

1 1


  7%|▋         | 22/300 [00:22<04:45,  1.03s/it]

1 1


  8%|▊         | 23/300 [00:23<04:43,  1.03s/it]

1 2


  8%|▊         | 24/300 [00:25<04:43,  1.03s/it]

1 1


  8%|▊         | 25/300 [00:26<04:41,  1.02s/it]

1 1


  9%|▊         | 26/300 [00:27<04:38,  1.02s/it]

1 1


  9%|▉         | 27/300 [00:28<04:34,  1.01s/it]

3 1


  9%|▉         | 28/300 [00:28<04:30,  1.00it/s]

1 1


 10%|▉         | 29/300 [00:30<04:36,  1.02s/it]

2 1


 10%|█         | 30/300 [00:31<04:35,  1.02s/it]

1 1


 10%|█         | 31/300 [00:32<04:36,  1.03s/it]

2 1


 11%|█         | 32/300 [00:33<04:35,  1.03s/it]

1 1


 11%|█         | 33/300 [00:34<04:35,  1.03s/it]

1 2


 11%|█▏        | 34/300 [00:35<04:32,  1.02s/it]

1 1


 12%|█▏        | 35/300 [00:36<04:29,  1.02s/it]

1 1


 12%|█▏        | 36/300 [00:37<04:27,  1.01s/it]

1 1


 12%|█▏        | 37/300 [00:38<04:23,  1.00s/it]

1 1


 13%|█▎        | 38/300 [00:39<04:22,  1.00s/it]

1 1


 13%|█▎        | 39/300 [00:40<04:23,  1.01s/it]

1 2


 13%|█▎        | 40/300 [00:41<04:25,  1.02s/it]

1 1


 14%|█▎        | 41/300 [00:42<04:22,  1.01s/it]

1 1


 14%|█▍        | 42/300 [00:43<04:21,  1.01s/it]

2 1


 14%|█▍        | 43/300 [00:44<04:18,  1.01s/it]

1 1


 15%|█▍        | 44/300 [00:45<04:18,  1.01s/it]

1 1


 15%|█▌        | 45/300 [00:46<04:17,  1.01s/it]

1 1


 15%|█▌        | 46/300 [00:47<04:16,  1.01s/it]

2 1


 16%|█▌        | 47/300 [00:48<04:16,  1.01s/it]

1 1


 16%|█▌        | 48/300 [00:49<04:15,  1.01s/it]

1 1


 16%|█▋        | 49/300 [00:50<04:15,  1.02s/it]

1 1


 17%|█▋        | 50/300 [00:51<04:13,  1.01s/it]

1 2


 17%|█▋        | 51/300 [00:52<04:11,  1.01s/it]

1 1


 17%|█▋        | 52/300 [00:53<04:12,  1.02s/it]

1 2


 18%|█▊        | 53/300 [00:54<04:09,  1.01s/it]

1 1


 18%|█▊        | 54/300 [00:55<04:09,  1.01s/it]

1 1


 18%|█▊        | 55/300 [00:56<04:06,  1.01s/it]

1 1


 19%|█▊        | 56/300 [00:57<04:04,  1.00s/it]

1 1


 19%|█▉        | 57/300 [00:58<04:04,  1.01s/it]

1 1


 19%|█▉        | 58/300 [00:59<04:02,  1.00s/it]

1 2


 20%|█▉        | 59/300 [01:00<04:04,  1.01s/it]

1 1


 20%|██        | 60/300 [01:01<04:02,  1.01s/it]

4 1


 20%|██        | 61/300 [01:02<04:00,  1.01s/it]

1 1


 21%|██        | 62/300 [01:03<03:57,  1.00it/s]

1 1


 21%|██        | 63/300 [01:04<03:56,  1.00it/s]

1 1


 21%|██▏       | 64/300 [01:05<03:55,  1.00it/s]

1 1


 22%|██▏       | 65/300 [01:06<03:54,  1.00it/s]

1 1


 22%|██▏       | 66/300 [01:07<03:55,  1.01s/it]

1 1


 22%|██▏       | 67/300 [01:08<03:53,  1.00s/it]

1 1


 23%|██▎       | 68/300 [01:09<03:52,  1.00s/it]

1 1


 23%|██▎       | 69/300 [01:10<03:50,  1.00it/s]

1 1


 23%|██▎       | 70/300 [01:11<03:50,  1.00s/it]

1 1


 24%|██▎       | 71/300 [01:12<03:50,  1.00s/it]

1 1


 24%|██▍       | 72/300 [01:13<03:47,  1.00it/s]

1 1


 24%|██▍       | 73/300 [01:14<03:45,  1.00it/s]

1 1


 25%|██▍       | 74/300 [01:15<03:44,  1.01it/s]

1 5


 25%|██▌       | 75/300 [01:16<03:46,  1.01s/it]

3 2


 25%|██▌       | 76/300 [01:17<03:46,  1.01s/it]

1 1


 26%|██▌       | 77/300 [01:18<03:43,  1.00s/it]

1 1


 26%|██▌       | 78/300 [01:19<03:43,  1.01s/it]

1 4


 26%|██▋       | 79/300 [01:20<03:42,  1.01s/it]

1 1


 27%|██▋       | 80/300 [01:21<03:43,  1.02s/it]

1 1


 27%|██▋       | 81/300 [01:22<03:39,  1.00s/it]

1 1


 27%|██▋       | 82/300 [01:23<03:41,  1.01s/it]

1 4


 28%|██▊       | 83/300 [01:24<03:39,  1.01s/it]

1 1


 28%|██▊       | 84/300 [01:25<03:40,  1.02s/it]

1 1


 28%|██▊       | 85/300 [01:26<03:36,  1.01s/it]

6 4


 29%|██▊       | 86/300 [01:27<03:35,  1.01s/it]

1 1


 29%|██▉       | 87/300 [01:28<03:35,  1.01s/it]

1 1


 29%|██▉       | 88/300 [01:29<03:35,  1.01s/it]

6 4


 30%|██▉       | 89/300 [01:30<03:33,  1.01s/it]

2 1


 30%|███       | 90/300 [01:31<03:33,  1.02s/it]

8 8


 30%|███       | 91/300 [01:32<03:32,  1.02s/it]

29 21


 31%|███       | 92/300 [01:33<03:30,  1.01s/it]

10 11


 31%|███       | 93/300 [01:34<03:30,  1.02s/it]

33 18


 31%|███▏      | 94/300 [01:35<03:29,  1.02s/it]

29 32


 32%|███▏      | 95/300 [01:36<03:28,  1.02s/it]

35 41


 32%|███▏      | 96/300 [01:37<03:31,  1.04s/it]

35 21


 32%|███▏      | 97/300 [01:38<03:28,  1.03s/it]

31 55


 33%|███▎      | 98/300 [01:39<03:27,  1.03s/it]

31 68


 33%|███▎      | 99/300 [01:40<03:23,  1.01s/it]

40 103


 33%|███▎      | 100/300 [01:41<03:24,  1.02s/it]

63 56


 34%|███▎      | 101/300 [01:42<03:26,  1.04s/it]

31 44


 34%|███▍      | 102/300 [01:43<03:25,  1.04s/it]

27 139


 34%|███▍      | 103/300 [01:45<03:25,  1.04s/it]

49 35


 35%|███▍      | 104/300 [01:46<03:24,  1.04s/it]

43 35


 35%|███▌      | 105/300 [01:47<03:23,  1.04s/it]

76 63


 35%|███▌      | 106/300 [01:48<03:23,  1.05s/it]

46 20


 36%|███▌      | 107/300 [01:49<03:21,  1.04s/it]

18 38


 36%|███▌      | 108/300 [01:50<03:19,  1.04s/it]

22 50


 36%|███▋      | 109/300 [01:51<03:18,  1.04s/it]

44 29


 37%|███▋      | 110/300 [01:52<03:16,  1.03s/it]

34 31


 37%|███▋      | 111/300 [01:53<03:12,  1.02s/it]

14 21


 37%|███▋      | 112/300 [01:54<03:14,  1.03s/it]

54 41


 38%|███▊      | 113/300 [01:55<03:14,  1.04s/it]

35 20


 38%|███▊      | 114/300 [01:56<03:11,  1.03s/it]

15 20


 38%|███▊      | 115/300 [01:57<03:08,  1.02s/it]

15 33


 39%|███▊      | 116/300 [01:58<03:10,  1.04s/it]

29 23


 39%|███▉      | 117/300 [01:59<03:07,  1.02s/it]

33 32


 39%|███▉      | 118/300 [02:00<03:05,  1.02s/it]

21 129


 40%|███▉      | 119/300 [02:01<03:04,  1.02s/it]

24 46


 40%|████      | 120/300 [02:02<03:04,  1.02s/it]

85 40


 40%|████      | 121/300 [02:03<03:03,  1.02s/it]

81 87


 41%|████      | 122/300 [02:04<03:00,  1.02s/it]

18 58


 41%|████      | 123/300 [02:05<02:59,  1.01s/it]

51 55


 41%|████▏     | 124/300 [02:06<02:57,  1.01s/it]

40 142


 42%|████▏     | 125/300 [02:07<02:57,  1.01s/it]

65 36


 42%|████▏     | 126/300 [02:08<02:55,  1.01s/it]

42 38


 42%|████▏     | 127/300 [02:09<02:53,  1.00s/it]

42 38


 43%|████▎     | 128/300 [02:10<02:54,  1.02s/it]

36 26


 43%|████▎     | 129/300 [02:11<02:53,  1.01s/it]

28 29


 43%|████▎     | 130/300 [02:12<02:54,  1.03s/it]

28 38


 44%|████▎     | 131/300 [02:13<02:53,  1.03s/it]

25 46


 44%|████▍     | 132/300 [02:14<02:52,  1.03s/it]

66 19


 44%|████▍     | 133/300 [02:15<02:52,  1.03s/it]

50 27


 45%|████▍     | 134/300 [02:16<02:48,  1.01s/it]

25 7


 45%|████▌     | 135/300 [02:17<02:47,  1.01s/it]

39 88


 45%|████▌     | 136/300 [02:18<02:46,  1.02s/it]

28 35


 46%|████▌     | 137/300 [02:19<02:45,  1.02s/it]

13 16


 46%|████▌     | 138/300 [02:20<02:44,  1.01s/it]

16 45


 46%|████▋     | 139/300 [02:21<02:42,  1.01s/it]

13 10


 47%|████▋     | 140/300 [02:22<02:43,  1.02s/it]

8 15


 47%|████▋     | 141/300 [02:23<02:42,  1.02s/it]

15 11


 47%|████▋     | 142/300 [02:24<02:40,  1.02s/it]

35 35


 48%|████▊     | 143/300 [02:25<02:40,  1.02s/it]

62 52


 48%|████▊     | 144/300 [02:26<02:38,  1.02s/it]

50 39


 48%|████▊     | 145/300 [02:27<02:37,  1.02s/it]

85 52


 49%|████▊     | 146/300 [02:28<02:35,  1.01s/it]

134 117


 49%|████▉     | 147/300 [02:29<02:35,  1.02s/it]

81 104


 49%|████▉     | 148/300 [02:31<02:36,  1.03s/it]

34 57


 50%|████▉     | 149/300 [02:32<02:33,  1.02s/it]

9 16


 50%|█████     | 150/300 [02:33<02:33,  1.02s/it]

6 51


 50%|█████     | 151/300 [02:34<02:33,  1.03s/it]

20 7


 51%|█████     | 152/300 [02:35<02:33,  1.04s/it]

11 16


 51%|█████     | 153/300 [02:36<02:33,  1.04s/it]

13 30


 51%|█████▏    | 154/300 [02:37<02:31,  1.04s/it]

58 13


 52%|█████▏    | 155/300 [02:38<02:29,  1.03s/it]

17 28


 52%|█████▏    | 156/300 [02:39<02:28,  1.03s/it]

55 30


 52%|█████▏    | 157/300 [02:40<02:27,  1.03s/it]

19 42


 53%|█████▎    | 158/300 [02:41<02:25,  1.02s/it]

28 16


 53%|█████▎    | 159/300 [02:42<02:21,  1.01s/it]

37 49


 53%|█████▎    | 160/300 [02:43<02:20,  1.01s/it]

34 48


 54%|█████▎    | 161/300 [02:44<02:19,  1.01s/it]

22 14


 54%|█████▍    | 162/300 [02:45<02:18,  1.01s/it]

28 21


 54%|█████▍    | 163/300 [02:46<02:17,  1.00s/it]

20 23


 55%|█████▍    | 164/300 [02:47<02:15,  1.00it/s]

36 23


 55%|█████▌    | 165/300 [02:48<02:15,  1.00s/it]

19 10


 55%|█████▌    | 166/300 [02:49<02:14,  1.00s/it]

8 23


 56%|█████▌    | 167/300 [02:50<02:13,  1.01s/it]

34 23


 56%|█████▌    | 168/300 [02:51<02:12,  1.00s/it]

14 7


 56%|█████▋    | 169/300 [02:52<02:11,  1.00s/it]

12 50


 57%|█████▋    | 170/300 [02:53<02:11,  1.01s/it]

6 6


 57%|█████▋    | 171/300 [02:54<02:11,  1.02s/it]

43 20


 57%|█████▋    | 172/300 [02:55<02:10,  1.02s/it]

9 15


 58%|█████▊    | 173/300 [02:56<02:08,  1.01s/it]

20 9


 58%|█████▊    | 174/300 [02:57<02:08,  1.02s/it]

3 8


 58%|█████▊    | 175/300 [02:58<02:05,  1.01s/it]

7 17


 59%|█████▊    | 176/300 [02:59<02:04,  1.00s/it]

9 17


 59%|█████▉    | 177/300 [03:00<02:03,  1.00s/it]

15 8


 59%|█████▉    | 178/300 [03:01<02:01,  1.00it/s]

5 8


 60%|█████▉    | 179/300 [03:02<02:00,  1.00it/s]

4 23


 60%|██████    | 180/300 [03:03<01:59,  1.00it/s]

7 10


 60%|██████    | 181/300 [03:04<01:58,  1.00it/s]

5 11


 61%|██████    | 182/300 [03:05<01:57,  1.00it/s]

19 3


 61%|██████    | 183/300 [03:06<01:57,  1.00s/it]

13 16


 61%|██████▏   | 184/300 [03:07<01:55,  1.00it/s]

14 31


 62%|██████▏   | 185/300 [03:08<01:54,  1.00it/s]

29 9


 62%|██████▏   | 186/300 [03:09<01:53,  1.00it/s]

10 17


 62%|██████▏   | 187/300 [03:10<01:53,  1.01s/it]

11 28


 63%|██████▎   | 188/300 [03:11<01:53,  1.01s/it]

22 24


 63%|██████▎   | 189/300 [03:12<01:52,  1.01s/it]

19 13


 63%|██████▎   | 190/300 [03:13<01:51,  1.01s/it]

5 11


 64%|██████▎   | 191/300 [03:14<01:49,  1.00s/it]

5 8


 64%|██████▍   | 192/300 [03:15<01:47,  1.00it/s]

13 3


 64%|██████▍   | 193/300 [03:16<01:47,  1.00s/it]

19 11


 65%|██████▍   | 194/300 [03:17<01:45,  1.00it/s]

8 8


 65%|██████▌   | 195/300 [03:18<01:45,  1.00s/it]

14 11


 65%|██████▌   | 196/300 [03:19<01:44,  1.01s/it]

9 4


 66%|██████▌   | 197/300 [03:20<01:44,  1.01s/it]

11 2


 66%|██████▌   | 198/300 [03:21<01:41,  1.00it/s]

15 6


 66%|██████▋   | 199/300 [03:22<01:41,  1.00s/it]

8 5


 67%|██████▋   | 200/300 [03:23<01:39,  1.01it/s]

7 1


 67%|██████▋   | 201/300 [03:24<01:38,  1.01it/s]

9 5


 67%|██████▋   | 202/300 [03:25<01:37,  1.01it/s]

8 5


 68%|██████▊   | 203/300 [03:26<01:36,  1.01it/s]

12 4


 68%|██████▊   | 204/300 [03:27<01:34,  1.01it/s]

9 7


 68%|██████▊   | 205/300 [03:28<01:34,  1.01it/s]

24 1


 69%|██████▊   | 206/300 [03:29<01:33,  1.01it/s]

35 8


 69%|██████▉   | 207/300 [03:30<01:32,  1.01it/s]

64 15


 69%|██████▉   | 208/300 [03:31<01:31,  1.01it/s]

22 15


 70%|██████▉   | 209/300 [03:32<01:30,  1.00it/s]

15 20


 70%|███████   | 210/300 [03:33<01:29,  1.00it/s]

21 29


 70%|███████   | 211/300 [03:34<01:29,  1.00s/it]

26 19


 71%|███████   | 212/300 [03:35<01:28,  1.00s/it]

24 35


 71%|███████   | 213/300 [03:36<01:27,  1.00s/it]

30 12


 71%|███████▏  | 214/300 [03:37<01:25,  1.00it/s]

26 30


 72%|███████▏  | 215/300 [03:38<01:25,  1.00s/it]

20 14


 72%|███████▏  | 216/300 [03:39<01:24,  1.01s/it]

27 19


 72%|███████▏  | 217/300 [03:40<01:24,  1.01s/it]

9 20


 73%|███████▎  | 218/300 [03:41<01:22,  1.01s/it]

9 12


 73%|███████▎  | 219/300 [03:42<01:21,  1.01s/it]

7 15


 73%|███████▎  | 220/300 [03:43<01:21,  1.02s/it]

5 8


 74%|███████▎  | 221/300 [03:44<01:21,  1.03s/it]

5 3


 74%|███████▍  | 222/300 [03:45<01:20,  1.03s/it]

9 4


 74%|███████▍  | 223/300 [03:46<01:19,  1.03s/it]

36 9


 75%|███████▍  | 224/300 [03:47<01:18,  1.03s/it]

3 1


 75%|███████▌  | 225/300 [03:48<01:17,  1.04s/it]

5 38


 75%|███████▌  | 226/300 [03:49<01:15,  1.02s/it]

8 2


 76%|███████▌  | 227/300 [03:50<01:14,  1.02s/it]

4 8


 76%|███████▌  | 228/300 [03:51<01:13,  1.03s/it]

6 6


 76%|███████▋  | 229/300 [03:52<01:12,  1.02s/it]

4 5


 77%|███████▋  | 230/300 [03:53<01:11,  1.03s/it]

7 9


 77%|███████▋  | 231/300 [03:54<01:10,  1.03s/it]

21 9


 77%|███████▋  | 232/300 [03:55<01:09,  1.02s/it]

7 16


 78%|███████▊  | 233/300 [03:56<01:07,  1.01s/it]

6 5


 78%|███████▊  | 234/300 [03:57<01:07,  1.02s/it]

16 4


 78%|███████▊  | 235/300 [03:58<01:06,  1.02s/it]

11 8


 79%|███████▊  | 236/300 [03:59<01:05,  1.02s/it]

16 47


 79%|███████▉  | 237/300 [04:00<01:04,  1.02s/it]

7 10


 79%|███████▉  | 238/300 [04:01<01:03,  1.02s/it]

15 3


 80%|███████▉  | 239/300 [04:02<01:02,  1.02s/it]

1 3


 80%|████████  | 240/300 [04:03<01:01,  1.03s/it]

2 1


 80%|████████  | 241/300 [04:04<01:00,  1.02s/it]

2 4


 81%|████████  | 242/300 [04:05<00:58,  1.02s/it]

4 2


 81%|████████  | 243/300 [04:06<00:57,  1.01s/it]

3 1


 81%|████████▏ | 244/300 [04:07<00:56,  1.01s/it]

5 3


 82%|████████▏ | 245/300 [04:08<00:55,  1.01s/it]

1 7


 82%|████████▏ | 246/300 [04:09<00:54,  1.01s/it]

3 1


 82%|████████▏ | 247/300 [04:10<00:53,  1.02s/it]

2 1


 83%|████████▎ | 248/300 [04:11<00:52,  1.01s/it]

6 2


 83%|████████▎ | 249/300 [04:13<00:51,  1.01s/it]

1 2


 83%|████████▎ | 250/300 [04:14<00:51,  1.02s/it]

6 6


 84%|████████▎ | 251/300 [04:15<00:49,  1.01s/it]

3 3


 84%|████████▍ | 252/300 [04:16<00:48,  1.01s/it]

2 3


 84%|████████▍ | 253/300 [04:17<00:47,  1.01s/it]

11 1


 85%|████████▍ | 254/300 [04:18<00:46,  1.00s/it]

10 2


 85%|████████▌ | 255/300 [04:19<00:45,  1.01s/it]

2 10


 85%|████████▌ | 256/300 [04:20<00:44,  1.00s/it]

12 3


 86%|████████▌ | 257/300 [04:21<00:43,  1.00s/it]

3 3


 86%|████████▌ | 258/300 [04:22<00:42,  1.01s/it]

2 3


 86%|████████▋ | 259/300 [04:23<00:40,  1.01it/s]

4 6


 87%|████████▋ | 260/300 [04:24<00:39,  1.00it/s]

8 1


 87%|████████▋ | 261/300 [04:25<00:39,  1.00s/it]

1 6


 87%|████████▋ | 262/300 [04:26<00:37,  1.00it/s]

3 1


 88%|████████▊ | 263/300 [04:27<00:37,  1.00s/it]

10 1


 88%|████████▊ | 264/300 [04:28<00:36,  1.00s/it]

3 14


 88%|████████▊ | 265/300 [04:29<00:34,  1.00it/s]

2 1


 89%|████████▊ | 266/300 [04:30<00:34,  1.00s/it]

2 20


 89%|████████▉ | 267/300 [04:31<00:33,  1.00s/it]

2 1


 89%|████████▉ | 268/300 [04:32<00:32,  1.00s/it]

4 2


 90%|████████▉ | 269/300 [04:33<00:31,  1.01s/it]

7 1


 90%|█████████ | 270/300 [04:34<00:30,  1.01s/it]

3 3


 90%|█████████ | 271/300 [04:35<00:29,  1.01s/it]

1 3


 91%|█████████ | 272/300 [04:36<00:28,  1.01s/it]

1 1


 91%|█████████ | 273/300 [04:37<00:27,  1.00s/it]

2 6


 91%|█████████▏| 274/300 [04:38<00:26,  1.01s/it]

3 18


 92%|█████████▏| 275/300 [04:39<00:24,  1.00it/s]

9 21


 92%|█████████▏| 276/300 [04:40<00:23,  1.01it/s]

1 3


 92%|█████████▏| 277/300 [04:41<00:23,  1.00s/it]

7 1


 93%|█████████▎| 278/300 [04:42<00:21,  1.00it/s]

6 2


 93%|█████████▎| 279/300 [04:43<00:20,  1.00it/s]

1 1


 93%|█████████▎| 280/300 [04:44<00:19,  1.00it/s]

1 2


 94%|█████████▎| 281/300 [04:45<00:18,  1.01it/s]

14 2


 94%|█████████▍| 282/300 [04:46<00:18,  1.00s/it]

1 2


 94%|█████████▍| 283/300 [04:47<00:17,  1.00s/it]

5 18


 95%|█████████▍| 284/300 [04:48<00:16,  1.00s/it]

2 17


 95%|█████████▌| 285/300 [04:49<00:15,  1.00s/it]

3 1


 95%|█████████▌| 286/300 [04:50<00:14,  1.00s/it]

2 5


 96%|█████████▌| 287/300 [04:51<00:13,  1.01s/it]

4 8


 96%|█████████▌| 288/300 [04:52<00:12,  1.02s/it]

1 4


 96%|█████████▋| 289/300 [04:53<00:11,  1.01s/it]

2 2


 97%|█████████▋| 290/300 [04:54<00:10,  1.00s/it]

1 1


 97%|█████████▋| 291/300 [04:55<00:09,  1.00s/it]

10 3


 97%|█████████▋| 292/300 [04:56<00:08,  1.00s/it]

2 1


 98%|█████████▊| 293/300 [04:57<00:07,  1.01s/it]

3 2


 98%|█████████▊| 294/300 [04:58<00:06,  1.02s/it]

1 1


 98%|█████████▊| 295/300 [04:59<00:05,  1.01s/it]

6 3


 99%|█████████▊| 296/300 [05:00<00:04,  1.01s/it]

2 2


 99%|█████████▉| 297/300 [05:01<00:03,  1.01s/it]

9 1


 99%|█████████▉| 298/300 [05:02<00:02,  1.01s/it]

1 4


100%|█████████▉| 299/300 [05:03<00:01,  1.02s/it]

2 3


100%|██████████| 300/300 [05:04<00:00,  1.01s/it]


Epoch 25/25, Loss: 26.327596945232944


  0%|          | 0/300 [00:00<?, ?it/s]

2 3


  0%|          | 1/300 [00:01<05:22,  1.08s/it]

1 19


  1%|          | 2/300 [00:02<05:05,  1.03s/it]

1 2


  1%|          | 3/300 [00:03<04:58,  1.01s/it]

5 1


  1%|▏         | 4/300 [00:04<04:56,  1.00s/it]

2 6


  2%|▏         | 5/300 [00:05<04:54,  1.00it/s]

2 1


  2%|▏         | 6/300 [00:06<04:52,  1.00it/s]

3 2


  2%|▏         | 7/300 [00:07<04:52,  1.00it/s]

3 3


  3%|▎         | 8/300 [00:08<04:51,  1.00it/s]

2 2


  3%|▎         | 9/300 [00:09<04:51,  1.00s/it]

1 2


  3%|▎         | 10/300 [00:10<04:50,  1.00s/it]

10 1


  4%|▎         | 11/300 [00:11<04:48,  1.00it/s]

2 5


  4%|▍         | 12/300 [00:12<04:49,  1.01s/it]

5 4


  4%|▍         | 13/300 [00:13<04:47,  1.00s/it]

1 3


  5%|▍         | 14/300 [00:14<04:48,  1.01s/it]

3 2


  5%|▌         | 15/300 [00:15<04:50,  1.02s/it]

1 3


  5%|▌         | 16/300 [00:16<04:49,  1.02s/it]

6 3


  6%|▌         | 17/300 [00:17<04:43,  1.00s/it]

3 1


  6%|▌         | 18/300 [00:18<04:41,  1.00it/s]

1 5


  6%|▋         | 19/300 [00:19<04:41,  1.00s/it]

1 3


  7%|▋         | 20/300 [00:20<04:43,  1.01s/it]

4 4


  7%|▋         | 21/300 [00:21<04:39,  1.00s/it]

5 31


  7%|▋         | 22/300 [00:22<04:37,  1.00it/s]

87 1


  8%|▊         | 23/300 [00:23<04:36,  1.00it/s]

3 2


  8%|▊         | 24/300 [00:24<04:37,  1.01s/it]

2 7


  8%|▊         | 25/300 [00:25<04:39,  1.02s/it]

19 25


  9%|▊         | 26/300 [00:26<04:39,  1.02s/it]

2 5


  9%|▉         | 27/300 [00:27<04:39,  1.02s/it]

3 10


  9%|▉         | 28/300 [00:28<04:38,  1.02s/it]

14 8


 10%|▉         | 29/300 [00:29<04:37,  1.02s/it]

6 26


 10%|█         | 30/300 [00:30<04:37,  1.03s/it]

2 39


 10%|█         | 31/300 [00:31<04:33,  1.02s/it]

9 4


 11%|█         | 32/300 [00:32<04:31,  1.01s/it]

28 13


 11%|█         | 33/300 [00:33<04:32,  1.02s/it]

14 10


 11%|█▏        | 34/300 [00:34<04:31,  1.02s/it]

3 4


 12%|█▏        | 35/300 [00:35<04:28,  1.01s/it]

22 8


 12%|█▏        | 36/300 [00:36<04:28,  1.02s/it]

6 7


 12%|█▏        | 37/300 [00:37<04:26,  1.01s/it]

3 3


 13%|█▎        | 38/300 [00:38<04:26,  1.02s/it]

18 8


 13%|█▎        | 39/300 [00:39<04:24,  1.01s/it]

4 5


 13%|█▎        | 40/300 [00:40<04:24,  1.02s/it]

5 5


 14%|█▎        | 41/300 [00:41<04:24,  1.02s/it]

3 8


 14%|█▍        | 42/300 [00:42<04:21,  1.01s/it]

1 6


 14%|█▍        | 43/300 [00:43<04:18,  1.01s/it]

17 9


 15%|█▍        | 44/300 [00:44<04:21,  1.02s/it]

1 9


 15%|█▌        | 45/300 [00:45<04:21,  1.02s/it]

2 6


 15%|█▌        | 46/300 [00:46<04:20,  1.02s/it]

11 47


 16%|█▌        | 47/300 [00:47<04:18,  1.02s/it]

4 5


 16%|█▌        | 48/300 [00:48<04:17,  1.02s/it]

4 6


 16%|█▋        | 49/300 [00:49<04:14,  1.01s/it]

7 7


 17%|█▋        | 50/300 [00:50<04:14,  1.02s/it]

7 1


 17%|█▋        | 51/300 [00:51<04:12,  1.01s/it]

26 6


 17%|█▋        | 52/300 [00:52<04:10,  1.01s/it]

5 13


 18%|█▊        | 53/300 [00:53<04:09,  1.01s/it]

6 6


 18%|█▊        | 54/300 [00:54<04:08,  1.01s/it]

2 6


 18%|█▊        | 55/300 [00:55<04:11,  1.03s/it]

10 5


 19%|█▊        | 56/300 [00:56<04:09,  1.02s/it]

5 9


 19%|█▉        | 57/300 [00:57<04:09,  1.03s/it]

4 7


 19%|█▉        | 58/300 [00:58<04:07,  1.02s/it]

3 3


 20%|█▉        | 59/300 [00:59<04:09,  1.03s/it]

11 2


 20%|██        | 60/300 [01:00<04:07,  1.03s/it]

3 5


 20%|██        | 61/300 [01:01<04:08,  1.04s/it]

10 3


 21%|██        | 62/300 [01:02<04:05,  1.03s/it]

1 2


 21%|██        | 63/300 [01:03<04:04,  1.03s/it]

4 42


 21%|██▏       | 64/300 [01:04<04:01,  1.02s/it]

1 16


 22%|██▏       | 65/300 [01:06<03:59,  1.02s/it]

10 2


 22%|██▏       | 66/300 [01:07<03:57,  1.02s/it]

8 62


 22%|██▏       | 67/300 [01:08<03:57,  1.02s/it]

1 3


 23%|██▎       | 68/300 [01:09<03:55,  1.02s/it]

2 1


 23%|██▎       | 69/300 [01:10<03:52,  1.01s/it]

5 1


 23%|██▎       | 70/300 [01:11<03:50,  1.00s/it]

2 6


 24%|██▎       | 71/300 [01:12<03:49,  1.00s/it]

2 1


 24%|██▍       | 72/300 [01:12<03:46,  1.00it/s]

3 7


 24%|██▍       | 73/300 [01:14<03:47,  1.00s/it]

3 5


 25%|██▍       | 74/300 [01:15<03:50,  1.02s/it]

3 2


 25%|██▌       | 75/300 [01:16<03:48,  1.01s/it]

14 1


 25%|██▌       | 76/300 [01:17<03:45,  1.01s/it]

3 3


 26%|██▌       | 77/300 [01:18<03:44,  1.01s/it]

4 2


 26%|██▌       | 78/300 [01:19<03:42,  1.00s/it]

7 37


 26%|██▋       | 79/300 [01:20<03:40,  1.00it/s]

13 8


 27%|██▋       | 80/300 [01:21<03:38,  1.01it/s]

11 44


 27%|██▋       | 81/300 [01:22<03:38,  1.00it/s]

8 28


 27%|██▋       | 82/300 [01:23<03:37,  1.00it/s]

32 14


 28%|██▊       | 83/300 [01:24<03:36,  1.00it/s]

32 11


 28%|██▊       | 84/300 [01:25<03:35,  1.00it/s]

10 19


 28%|██▊       | 85/300 [01:26<03:36,  1.01s/it]

23 4


 29%|██▊       | 86/300 [01:27<03:37,  1.02s/it]

5 27


 29%|██▉       | 87/300 [01:28<03:35,  1.01s/it]

3 3


 29%|██▉       | 88/300 [01:29<03:34,  1.01s/it]

7 4


 30%|██▉       | 89/300 [01:30<03:32,  1.01s/it]

3 2


 30%|███       | 90/300 [01:31<03:32,  1.01s/it]

5 11


 30%|███       | 91/300 [01:32<03:30,  1.01s/it]

6 6


 31%|███       | 92/300 [01:33<03:28,  1.00s/it]

8 9


 31%|███       | 93/300 [01:34<03:27,  1.00s/it]

3 7


 31%|███▏      | 94/300 [01:35<03:29,  1.02s/it]

10 13


 32%|███▏      | 95/300 [01:36<03:27,  1.01s/it]

9 16


 32%|███▏      | 96/300 [01:37<03:26,  1.01s/it]

6 7


 32%|███▏      | 97/300 [01:38<03:28,  1.03s/it]

10 25


 33%|███▎      | 98/300 [01:39<03:26,  1.02s/it]

20 22


 33%|███▎      | 99/300 [01:40<03:23,  1.01s/it]

72 41


 33%|███▎      | 100/300 [01:41<03:22,  1.01s/it]

72 97


 34%|███▎      | 101/300 [01:42<03:22,  1.02s/it]

63 82


 34%|███▍      | 102/300 [01:43<03:20,  1.01s/it]

123 34


 34%|███▍      | 103/300 [01:44<03:19,  1.01s/it]

8 62


 35%|███▍      | 104/300 [01:45<03:18,  1.01s/it]

21 17


 35%|███▌      | 105/300 [01:46<03:17,  1.01s/it]

17 5


 35%|███▌      | 106/300 [01:47<03:15,  1.01s/it]

6 9


 36%|███▌      | 107/300 [01:48<03:13,  1.00s/it]

1 34


 36%|███▌      | 108/300 [01:49<03:13,  1.01s/it]

12 9


 36%|███▋      | 109/300 [01:50<03:11,  1.00s/it]

8 11


 37%|███▋      | 110/300 [01:51<03:10,  1.00s/it]

16 7


 37%|███▋      | 111/300 [01:52<03:11,  1.01s/it]

59 20


 37%|███▋      | 112/300 [01:53<03:08,  1.01s/it]

17 20


 38%|███▊      | 113/300 [01:54<03:06,  1.00it/s]

16 11


 38%|███▊      | 114/300 [01:55<03:05,  1.00it/s]

12 16


 38%|███▊      | 115/300 [01:56<03:04,  1.00it/s]

32 6


 39%|███▊      | 116/300 [01:57<03:04,  1.00s/it]

9 61


 39%|███▉      | 117/300 [01:58<03:03,  1.00s/it]

62 13


 39%|███▉      | 118/300 [01:59<03:02,  1.00s/it]

84 4


 40%|███▉      | 119/300 [02:00<03:00,  1.00it/s]

76 29


 40%|████      | 120/300 [02:01<03:00,  1.00s/it]

14 20


 40%|████      | 121/300 [02:02<02:58,  1.00it/s]

47 21


 41%|████      | 122/300 [02:03<02:59,  1.01s/it]

32 17


 41%|████      | 123/300 [02:04<02:56,  1.00it/s]

18 18


 41%|████▏     | 124/300 [02:05<02:55,  1.00it/s]

15 37


 42%|████▏     | 125/300 [02:06<02:54,  1.00it/s]

29 31


 42%|████▏     | 126/300 [02:07<02:54,  1.00s/it]

35 9


 42%|████▏     | 127/300 [02:08<02:52,  1.00it/s]

10 5


 43%|████▎     | 128/300 [02:09<02:51,  1.00it/s]

16 5


 43%|████▎     | 129/300 [02:10<02:51,  1.00s/it]

14 7


 43%|████▎     | 130/300 [02:11<02:50,  1.00s/it]

15 21


 44%|████▎     | 131/300 [02:12<02:49,  1.00s/it]

45 10


 44%|████▍     | 132/300 [02:13<02:48,  1.01s/it]

70 14


 44%|████▍     | 133/300 [02:14<02:48,  1.01s/it]

2 11


 45%|████▍     | 134/300 [02:15<02:47,  1.01s/it]

28 32


 45%|████▌     | 135/300 [02:16<02:45,  1.00s/it]

10 11


 45%|████▌     | 136/300 [02:17<02:44,  1.00s/it]

23 14


 46%|████▌     | 137/300 [02:18<02:41,  1.01it/s]

17 9


 46%|████▌     | 138/300 [02:19<02:42,  1.00s/it]

20 17


 46%|████▋     | 139/300 [02:20<02:40,  1.00it/s]

55 9


 47%|████▋     | 140/300 [02:21<02:38,  1.01it/s]

50 13


 47%|████▋     | 141/300 [02:22<02:36,  1.01it/s]

14 15


 47%|████▋     | 142/300 [02:23<02:36,  1.01it/s]

52 16


 48%|████▊     | 143/300 [02:24<02:36,  1.00it/s]

11 15


 48%|████▊     | 144/300 [02:25<02:36,  1.00s/it]

2 63


 48%|████▊     | 145/300 [02:26<02:35,  1.01s/it]

11 22


 49%|████▊     | 146/300 [02:27<02:32,  1.01it/s]

8 7


 49%|████▉     | 147/300 [02:28<02:32,  1.00it/s]

8 16


 49%|████▉     | 148/300 [02:29<02:33,  1.01s/it]

19 7


 50%|████▉     | 149/300 [02:30<02:33,  1.02s/it]

5 13


 50%|█████     | 150/300 [02:31<02:34,  1.03s/it]

6 38


 50%|█████     | 151/300 [02:32<02:32,  1.02s/it]

14 14


 51%|█████     | 152/300 [02:33<02:30,  1.02s/it]

10 8


 51%|█████     | 153/300 [02:34<02:30,  1.03s/it]

6 10


 51%|█████▏    | 154/300 [02:35<02:31,  1.04s/it]

2 9


 52%|█████▏    | 155/300 [02:36<02:29,  1.03s/it]

4 13


 52%|█████▏    | 156/300 [02:37<02:27,  1.03s/it]

3 4


 52%|█████▏    | 157/300 [02:38<02:25,  1.02s/it]

14 8


 53%|█████▎    | 158/300 [02:39<02:24,  1.02s/it]

1 5


 53%|█████▎    | 159/300 [02:40<02:24,  1.02s/it]

6 5


 53%|█████▎    | 160/300 [02:41<02:21,  1.01s/it]

5 9


 54%|█████▎    | 161/300 [02:42<02:19,  1.00s/it]

2 10


 54%|█████▍    | 162/300 [02:43<02:18,  1.01s/it]

6 26


 54%|█████▍    | 163/300 [02:44<02:16,  1.01it/s]

6 6


 55%|█████▍    | 164/300 [02:45<02:14,  1.01it/s]

2 6


 55%|█████▌    | 165/300 [02:46<02:12,  1.02it/s]

10 2


 55%|█████▌    | 166/300 [02:47<02:12,  1.01it/s]

7 4


 56%|█████▌    | 167/300 [02:48<02:11,  1.01it/s]

1 3


 56%|█████▌    | 168/300 [02:49<02:10,  1.01it/s]

6 12


 56%|█████▋    | 169/300 [02:50<02:10,  1.00it/s]

11 3


 57%|█████▋    | 170/300 [02:51<02:09,  1.00it/s]

9 2


 57%|█████▋    | 171/300 [02:52<02:09,  1.00s/it]

5 1


 57%|█████▋    | 172/300 [02:53<02:09,  1.01s/it]

9 4


 58%|█████▊    | 173/300 [02:54<02:07,  1.00s/it]

4 10


 58%|█████▊    | 174/300 [02:55<02:06,  1.00s/it]

5 7


 58%|█████▊    | 175/300 [02:56<02:04,  1.00it/s]

4 4


 59%|█████▊    | 176/300 [02:57<02:03,  1.00it/s]

6 5


 59%|█████▉    | 177/300 [02:58<02:02,  1.00it/s]

4 6


 59%|█████▉    | 178/300 [02:59<02:02,  1.01s/it]

4 10


 60%|█████▉    | 179/300 [03:00<02:00,  1.01it/s]

5 4


 60%|██████    | 180/300 [03:01<01:58,  1.01it/s]

1 5


 60%|██████    | 181/300 [03:02<01:58,  1.01it/s]

1 3


 61%|██████    | 182/300 [03:03<01:57,  1.00it/s]

2 7


 61%|██████    | 183/300 [03:04<01:56,  1.00it/s]

2 5


 61%|██████▏   | 184/300 [03:05<01:56,  1.00s/it]

4 1


 62%|██████▏   | 185/300 [03:06<01:54,  1.00it/s]

2 7


 62%|██████▏   | 186/300 [03:07<01:56,  1.02s/it]

3 2


 62%|██████▏   | 187/300 [03:08<01:55,  1.02s/it]

3 3


 63%|██████▎   | 188/300 [03:09<01:55,  1.03s/it]

1 5


 63%|██████▎   | 189/300 [03:10<01:55,  1.04s/it]

2 3


 63%|██████▎   | 190/300 [03:11<01:55,  1.05s/it]

1 4


 64%|██████▎   | 191/300 [03:12<01:53,  1.04s/it]

1 2


 64%|██████▍   | 192/300 [03:13<01:53,  1.05s/it]

8 2


 64%|██████▍   | 193/300 [03:14<01:50,  1.04s/it]

3 1


 65%|██████▍   | 194/300 [03:15<01:49,  1.04s/it]

2 3
1 2


 65%|██████▌   | 196/300 [03:16<01:12,  1.43it/s]

2 2


 66%|██████▌   | 197/300 [03:17<01:19,  1.29it/s]

1 1


 66%|██████▌   | 198/300 [03:18<01:25,  1.19it/s]

1 1


 66%|██████▋   | 199/300 [03:19<01:29,  1.13it/s]

2 5


 67%|██████▋   | 200/300 [03:20<01:31,  1.09it/s]

1 5


 67%|██████▋   | 201/300 [03:21<01:32,  1.07it/s]

2 3


 67%|██████▋   | 202/300 [03:22<01:33,  1.05it/s]

1 4


 68%|██████▊   | 203/300 [03:23<01:34,  1.03it/s]

1 2


 68%|██████▊   | 204/300 [03:24<01:33,  1.02it/s]

2 1


 68%|██████▊   | 205/300 [03:25<01:32,  1.02it/s]

1 2


 69%|██████▊   | 206/300 [03:26<01:33,  1.01it/s]

1 4


 69%|██████▉   | 207/300 [03:27<01:33,  1.00s/it]

6 2


 69%|██████▉   | 208/300 [03:28<01:32,  1.01s/it]

5 3


 70%|██████▉   | 209/300 [03:29<01:31,  1.01s/it]

1 1


 70%|███████   | 210/300 [03:30<01:30,  1.00s/it]

4 2


 70%|███████   | 211/300 [03:31<01:29,  1.01s/it]

4 1


 71%|███████   | 212/300 [03:32<01:28,  1.01s/it]

1 1


 71%|███████   | 213/300 [03:33<01:28,  1.02s/it]

3 2


 71%|███████▏  | 214/300 [03:34<01:27,  1.02s/it]

13 3


 72%|███████▏  | 215/300 [03:35<01:25,  1.01s/it]

1 6


 72%|███████▏  | 216/300 [03:36<01:24,  1.01s/it]

1 3


 72%|███████▏  | 217/300 [03:37<01:23,  1.01s/it]

2 4


 73%|███████▎  | 218/300 [03:38<01:21,  1.00it/s]

1 1


 73%|███████▎  | 219/300 [03:39<01:20,  1.01it/s]

5 4


 73%|███████▎  | 220/300 [03:40<01:19,  1.01it/s]

13 1


 74%|███████▎  | 221/300 [03:41<01:18,  1.01it/s]

2 1


 74%|███████▍  | 222/300 [03:42<01:17,  1.00it/s]

4 2


 74%|███████▍  | 223/300 [03:43<01:17,  1.00s/it]

3 3


 75%|███████▍  | 224/300 [03:44<01:16,  1.01s/it]

3 5


 75%|███████▌  | 225/300 [03:45<01:15,  1.01s/it]

2 2


 75%|███████▌  | 226/300 [03:46<01:15,  1.01s/it]

2 9


 76%|███████▌  | 227/300 [03:47<01:14,  1.02s/it]

10 3


 76%|███████▌  | 228/300 [03:48<01:12,  1.00s/it]

2 1


 76%|███████▋  | 229/300 [03:49<01:11,  1.01s/it]

1 4


 77%|███████▋  | 230/300 [03:50<01:10,  1.01s/it]

5 2


 77%|███████▋  | 231/300 [03:51<01:09,  1.01s/it]

1 1


 77%|███████▋  | 232/300 [03:52<01:09,  1.02s/it]

2 3


 78%|███████▊  | 233/300 [03:53<01:07,  1.01s/it]

1 2


 78%|███████▊  | 234/300 [03:54<01:05,  1.00it/s]

8 2


 78%|███████▊  | 235/300 [03:55<01:05,  1.00s/it]

2 19


 79%|███████▊  | 236/300 [03:56<01:04,  1.00s/it]

3 6


 79%|███████▉  | 237/300 [03:57<01:03,  1.01s/it]

1 3


 79%|███████▉  | 238/300 [03:58<01:02,  1.01s/it]

1 1


 80%|███████▉  | 239/300 [03:59<01:01,  1.01s/it]

9 3


 80%|████████  | 240/300 [04:00<01:00,  1.01s/it]

2 1


 80%|████████  | 241/300 [04:01<00:59,  1.01s/it]

2 3


 81%|████████  | 242/300 [04:02<00:58,  1.01s/it]

2 13


 81%|████████  | 243/300 [04:03<00:57,  1.02s/it]

2 1


 81%|████████▏ | 244/300 [04:04<00:56,  1.01s/it]

1 2


 82%|████████▏ | 245/300 [04:05<00:55,  1.00s/it]

4 10


 82%|████████▏ | 246/300 [04:06<00:54,  1.01s/it]

5 1


 82%|████████▏ | 247/300 [04:07<00:53,  1.01s/it]

16 6


 83%|████████▎ | 248/300 [04:08<00:52,  1.00s/it]

2 1


 83%|████████▎ | 249/300 [04:09<00:51,  1.00s/it]

3 1


 83%|████████▎ | 250/300 [04:10<00:50,  1.01s/it]

2 1


 84%|████████▎ | 251/300 [04:11<00:49,  1.01s/it]

2 3


 84%|████████▍ | 252/300 [04:12<00:48,  1.00s/it]

1 1


 84%|████████▍ | 253/300 [04:13<00:47,  1.01s/it]

4 1


 85%|████████▍ | 254/300 [04:14<00:46,  1.01s/it]

6 4


 85%|████████▌ | 255/300 [04:15<00:45,  1.01s/it]

2 6


 85%|████████▌ | 256/300 [04:16<00:44,  1.01s/it]

4 5


 86%|████████▌ | 257/300 [04:17<00:43,  1.01s/it]

4 1


 86%|████████▌ | 258/300 [04:19<00:42,  1.02s/it]

2 3


 86%|████████▋ | 259/300 [04:20<00:41,  1.02s/it]

2 2


 87%|████████▋ | 260/300 [04:21<00:41,  1.03s/it]

2 3


 87%|████████▋ | 261/300 [04:22<00:40,  1.04s/it]

4 7


 87%|████████▋ | 262/300 [04:23<00:39,  1.04s/it]

4 3


 88%|████████▊ | 263/300 [04:24<00:38,  1.04s/it]

2 4


 88%|████████▊ | 264/300 [04:25<00:37,  1.03s/it]

2 2


 88%|████████▊ | 265/300 [04:26<00:35,  1.03s/it]

1 2


 89%|████████▊ | 266/300 [04:27<00:35,  1.03s/it]

2 2


 89%|████████▉ | 267/300 [04:28<00:34,  1.04s/it]

4 2


 89%|████████▉ | 268/300 [04:29<00:32,  1.02s/it]

2 1


 90%|████████▉ | 269/300 [04:30<00:31,  1.03s/it]

5 1


 90%|█████████ | 270/300 [04:31<00:31,  1.04s/it]

1 1


 90%|█████████ | 271/300 [04:32<00:30,  1.04s/it]

2 1


 91%|█████████ | 272/300 [04:33<00:28,  1.03s/it]

1 1


 91%|█████████ | 273/300 [04:34<00:28,  1.04s/it]

2 6


 91%|█████████▏| 274/300 [04:35<00:26,  1.04s/it]

2 1


 92%|█████████▏| 275/300 [04:36<00:25,  1.03s/it]

2 1


 92%|█████████▏| 276/300 [04:37<00:24,  1.03s/it]

4 1


 92%|█████████▏| 277/300 [04:38<00:23,  1.03s/it]

1 1


 93%|█████████▎| 278/300 [04:39<00:22,  1.03s/it]

1 1


 93%|█████████▎| 279/300 [04:40<00:21,  1.03s/it]

1 4


 93%|█████████▎| 280/300 [04:41<00:20,  1.02s/it]

1 1


 94%|█████████▎| 281/300 [04:42<00:19,  1.01s/it]

1 2


 94%|█████████▍| 282/300 [04:43<00:18,  1.01s/it]

1 1


 94%|█████████▍| 283/300 [04:44<00:17,  1.01s/it]

4 4


 95%|█████████▍| 284/300 [04:45<00:15,  1.00it/s]

5 1


 95%|█████████▌| 285/300 [04:46<00:15,  1.00s/it]

10 4


 95%|█████████▌| 286/300 [04:47<00:14,  1.00s/it]

2 1


 96%|█████████▌| 287/300 [04:48<00:13,  1.01s/it]

2 1


 96%|█████████▌| 288/300 [04:49<00:12,  1.00s/it]

2 1


 96%|█████████▋| 289/300 [04:50<00:11,  1.00s/it]

3 5


 97%|█████████▋| 290/300 [04:51<00:10,  1.00s/it]

1 1


 97%|█████████▋| 291/300 [04:52<00:09,  1.01s/it]

3 1


 97%|█████████▋| 292/300 [04:53<00:08,  1.01s/it]

1 1


 98%|█████████▊| 293/300 [04:54<00:07,  1.00s/it]

1 1


 98%|█████████▊| 294/300 [04:55<00:06,  1.01s/it]

2 1


 98%|█████████▊| 295/300 [04:56<00:05,  1.01s/it]

5 2


 99%|█████████▊| 296/300 [04:57<00:04,  1.02s/it]

1 1


 99%|█████████▉| 297/300 [04:58<00:03,  1.02s/it]

2 1


 99%|█████████▉| 298/300 [04:59<00:02,  1.02s/it]

4 1


100%|█████████▉| 299/300 [05:00<00:01,  1.03s/it]

1 3


100%|██████████| 300/300 [05:01<00:00,  1.01s/it]


Epoch 25/25, Loss: 33.4996495734322


100%|██████████| 100/100 [01:09<00:00,  1.44it/s]

Validation Loss: 0.008441486041992902, mIOU: 0.9845870137214661, IOU Background: 0.9923500418663025, IOU Tongue: 0.9768244028091431





In [None]:
import torch
import torch.nn as nn
from torchvision import models
import numpy as np
from PIL import Image
from torchvision import transforms


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 定义基于VGG16的FCN网络
class VGG16_FCN(nn.Module):
    def __init__(self, num_classes):
        super(VGG16_FCN, self).__init__()
        # 加载预训练的VGG16模型
        vgg16 = models.vgg16(
            weights=models.VGG16_Weights.DEFAULT)

        # 取出VGG16的前面部分（去掉全连接层）
        self.features = vgg16.features

        # 用1x1卷积替换VGG16的全连接层
        self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)
        self.relu6 = nn.ReLU(inplace=True)
        self.dropout6 = nn.Dropout2d()
        self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        self.relu7 = nn.ReLU(inplace=True)
        self.dropout7 = nn.Dropout2d()

        # 最后的卷积层用于生成分割结果
        self.score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        self.upscore = nn.ConvTranspose2d(
            num_classes, num_classes, kernel_size=64, stride=32, padding=16, bias=False)

        # 初始化上采样层权重
        self.upscore.weight.data.fill_(0)
        self.upscore.weight.data[:, :, 16, 16] = 1  # 双线性插值

    def forward(self, x):
        # 前向传播
        x = self.features(x)
        x = self.relu6(self.conv6(x))
        x = self.dropout6(x)
        x = self.relu7(self.conv7(x))
        x = self.dropout7(x)
        x = self.score_fr(x)
        x = self.upscore(x)
        return x  


def get_segmentation_mask(model, image_path):

    model.eval()

    image = Image.open(image_path).convert('RGB')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225])
    ])
    image = transform(image).unsqueeze(0)
    print(image.size())
    image = image.to(device)
    with torch.no_grad():
        output = model(image)
    output = torch.sigmoid(output)
    output = torch.argmax(output, dim=1)
    mask = output.squeeze().cpu().numpy()
    return mask

# Get segmentation mask for a sample image

image_path = './data/test/label_data/test1.jpg'  # Replace with your image path
# 载入模型权重
model = VGG16_FCN(num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.load_state_dict(torch.load('model.pth'))

mask = get_segmentation_mask(model, image_path)


# Create the final mask where tongue region is white and background is black

final_mask = np.zeros_like(mask)

final_mask[mask == 0] = 255
final_mask = Image.fromarray(final_mask.astype(np.uint8))

# 叠加原图
image = Image.open(image_path).convert('RGB')
image = image.resize(final_mask.size)
image.paste(final_mask, (0, 0), final_mask)
image

In [None]:
import numpy as np
import torch

# 生成随机的形状为 (2, 2, 3, 3) 的张量，值在 0 到 1 之间
random_tensor = np.random.rand(2, 2, 3, 3)

print(random_tensor)

# 将 NumPy 数组转换为 PyTorch 张量
random_tensor = torch.tensor(random_tensor, dtype=torch.float32)
print(torch.sigmoid(random_tensor))
print(torch.softmax(random_tensor, dim=1))
print(torch.argmax(random_tensor, dim=1, keepdim=True))