In [1]:
import nibabel as nib
import matplotlib.pyplot as plt

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
import numpy as np

In [20]:
# 自定义数据集类
class SpineWeb15(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        
    def __len__(self):
        return len(self.data_path)
    
    def __getitem__(self, index):
        metadata = np.load(self.data_path+str(index)+".npy")
        image = metadata[0]
        label = metadata[1]

        image_data = torch.from_numpy(image).float()
        label_data = torch.from_numpy(label).float()
        
        if self.transform:
            image_data = self.transform(image_data)
            
        return image_data, label_data



In [29]:
# U-Net structure

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        
        # 编码器部分
        self.conv1 = DoubleConv(in_channels, 64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)   #64
        self.conv2 = DoubleConv(64, 128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)   #32
        self.conv3 = DoubleConv(128, 256)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)   #16
        self.conv4 = DoubleConv(256, 512)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)   #8
        
        # 解码器部分
        self.conv5 = DoubleConv(512, 1024)
        self.upconv6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)   #16
        self.conv6 = DoubleConv(1024, 512)
        self.upconv7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)    #32
        self.conv7 = DoubleConv(512, 256)
        self.upconv8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)    #64
        self.conv8 = DoubleConv(256, 128)
        self.upconv9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)     #128 feature map size
        self.conv9 = DoubleConv(128, 64)
        
        # 输出层
        self.output = nn.Conv2d(64, out_channels, kernel_size=1)
    
    def forward(self, x):
        # 编码器
        c1 = self.conv1(x)
        p1 = self.maxpool1(c1)
        c2 = self.conv2(p1)
        p2 = self.maxpool2(c2)
        c3 = self.conv3(p2)
        p3 = self.maxpool3(c3)
        c4 = self.conv4(p3)
        p4 = self.maxpool4(c4)
        
        # 解码器
        u5 = self.conv5(p4)
        u5 = self.upconv6(u5)
        u5 = torch.cat((u5, c4), dim=1)
        c6 = self.conv6(u5)
        u6 = self.upconv7(c6)
        u6 = torch.cat((u6, c3), dim=1)
        c7 = self.conv7(u6)
        u7 = self.upconv8(c7)
        u7 = torch.cat((u7, c2), dim=1)
        c8 = self.conv8(u7)
        u8 = self.upconv9(c8)
        c9 = self.conv9(u8)
        
        # 输出层
        output = self.output(c9)
        return output

# 创建UNet模型实例
model = UNet(in_channels=1, out_channels=1)

In [30]:
# set the hyper parameters and the paths 设置超参数和路径
data_path = "./dataset/Spineweb_dataset15/processed/"

#hyper parameters set
batch_size = 16
num_epochs = 1
learning_rate = 0.001
in_channels = 1  # according to demand to adjust the number of channels 根据实际情况修改通道数
out_channels = 1  # according to demand to adjust the number of channels 根据实际情况修改通道数

# 创建数据集和数据加载器
transform = transforms.Compose([
    transforms.ToPILImage(),  # 转换为 PIL 图像对象
    transforms.Resize((128, 128)),  # 调整大小为 128x128
    transforms.ToTensor()  # 转换为张量
])
dataset = SpineWeb15(data_path=data_path,transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 创建模型和优化器
model = UNet(in_channels, out_channels)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()  # 二分类任务可以使用BCEWithLogitsLoss

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 训练模型
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for images, labels in dataloader:
        # print("images",images.shape)
        # print("labels",labels.shape)
        
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        
    # train_accuracy = 100 * correct / total
    # avg_train_loss = train_loss / len(dataloader)

    # print('Epoch [{}/{}], Average Training Loss: {:.4f}, Training Accuracy: {:.2f}%'
    #       .format(epoch + 1, num_epochs, avg_train_loss, train_accuracy))

# 接下来，你可以根据需要对模型进行评估或应用
# ...

RuntimeError: Given groups=1, weight of size [64, 128, 3, 3], expected input[16, 64, 128, 128] to have 128 channels, but got 64 channels instead

In [None]:
# 在测试集上评估模型
model.eval()
test_loss = 0
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        test_loss += criterion(outputs, labels).item()

test_accuracy = 100 * correct / total
print('Test Accuracy: {:.2f}%'.format(test_accuracy))