In [23]:
img_size = 256
num_epochs = 100
patch_size = 32
contrastive_batch_size = 256
batch_size = 8
classes = 2
learning_rate = 1e-3

data_dir = '/home/louis/Documents/project/PatchCL-MedSeg/0_data_dataset_voc_950_kidney/'
output_dir = '/home/louis/Documents/project/pixel-contrastive-segmentation/dataset/splits/kidney/'

In [24]:
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from pixel_level_contrastive_learning import PixelCL

In [25]:
from dataset_kidney import BaseDatasets  

In [26]:
import numpy as np
print(np.__version__)

1.26.4


In [33]:
# Define transformations if needed
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

IMG_folder_path = data_dir 
msk_folder_path = data_dir
# Load file lists
with open(os.path.join(output_dir, "1-3", "labeled.txt"), 'r') as file:
    labeled_files = [line.strip().split(' ') for line in file.readlines()]
with open(os.path.join(output_dir, "1-3", "unlabeled.txt"), 'r') as file:
    unlabeled_files = [line.strip() for line in file.readlines()]
with open(os.path.join(output_dir, "val.txt"), 'r') as file:
    val_files = [line.strip().split(' ') for line in file.readlines()]

# # Create datasets and dataloaders
labeled_dataset = BaseDatasets(labeled_files, IMG_folder_path, msk_folder_path, transform)
unlabeled_dataset = BaseDatasets(unlabeled_files, IMG_folder_path, transform=transform)
val_dataset = BaseDatasets(val_files, IMG_folder_path, msk_folder_path, transform)

labeled_dataloader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)
unlabeled_dataloader = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


print('===========================================================')
print('number of labeled_dataset: ', len(labeled_dataset))
print('number of unlabeled_dataset: ', len(unlabeled_dataset))
print('number of val_dataset: ', len(val_dataset))
print('===========================================================')

number of labeled_dataset:  285
number of unlabeled_dataset:  570
number of val_dataset:  95


In [16]:
import torch
from torch import nn
from torchvision import models
from torchvision.models import ResNet50_Weights

# Define the segmentation model based on ResNet
class ResNetSegmentation(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # Base ResNet model, without the final fully connected layer
        original_resnet = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(original_resnet.children())[:-2])
        
        # Additional layers for segmentation
        self.conv1x1 = nn.Conv2d(2048, num_classes, 1)
        self.upsample = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=False)

    def forward(self, x):
        x = self.backbone(x)
        x = self.conv1x1(x)
        x = self.upsample(x)
        return x

In [17]:
num_classes = 2  # 例如 VOC2012 数据集的类别数
model = ResNetSegmentation(num_classes).cuda()

# 检查模型输出
dummy_input = torch.randn(1, 3, 256, 256).cuda()  # 假设输入是 256x256 的图像
dummy_output = model(dummy_input)
print("Output shape:", dummy_output.shape)  # 应该是 (1, num_classes, 256, 256)



Output shape: torch.Size([1, 2, 256, 256])


In [18]:
# Use the unmodified ResNet-50 for PixelCL
original_resnet = models.resnet50(pretrained=True).cuda()
learner = PixelCL(
    original_resnet,
    image_size=256,
    hidden_layer_pixel='layer4',
    hidden_layer_instance=-2,
    projection_size=256,
    projection_hidden_size=2048,
    moving_average_decay=0.99,
    ppm_num_layers=1,
    ppm_gamma=2,
    distance_thres=0.7,
    similarity_temperature=0.3,
    alpha=1.0,
    use_pixpro=True,
    cutout_ratio_range=(0.6, 0.8)
).cuda()

In [19]:
opt = torch.optim.Adam(list(model.parameters()) + list(learner.parameters()), lr=1e-4)

In [20]:
# Training function including both supervised and unsupervised losses
def train_epoch(model, learner, labeled_loader, optimizer, criterion, epoch):
    model.train()
    total_supervised_loss = 0
    total_contrastive_loss = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model.to(device)
    learner.to(device)

    for imgs, masks in tqdm(labeled_loader, desc=f"Epoch {epoch+1}", leave=False):
        imgs, masks = imgs.to(device), masks.to(device)

        optimizer.zero_grad()

        # Forward pass for segmentation
        outputs = model(imgs)
        supervised_loss = criterion(outputs, masks)
        total_supervised_loss += supervised_loss.item()

        # Calculate the unsupervised PixelCL loss
        contrast_loss = learner(imgs)
        total_contrastive_loss += contrast_loss.item()

        # Combine losses and backpropagate
        loss = supervised_loss + contrast_loss
        loss.backward()
        optimizer.step()

        # Update the moving average for the target encoder in PixelCL
        learner.update_moving_average()

    avg_supervised_loss = total_supervised_loss / len(labeled_loader)
    avg_contrastive_loss = total_contrastive_loss / len(labeled_loader)
    print(f"Epoch [{epoch+1}], Supervised Loss: {avg_supervised_loss:.4f}, Contrastive Loss: {avg_contrastive_loss:.4f}")


In [21]:
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    train_epoch(model, learner, labeled_dataloader, opt, nn.CrossEntropyLoss(), epoch)

                                                        

Epoch [1], Supervised Loss: 0.1139, Contrastive Loss: 1.6439


                                                        

Epoch [2], Supervised Loss: 0.0367, Contrastive Loss: 1.0215


                                                        

Epoch [3], Supervised Loss: 0.0352, Contrastive Loss: 0.9271


                                                        

Epoch [4], Supervised Loss: 0.0340, Contrastive Loss: 0.8646


                                                        

Epoch [5], Supervised Loss: 0.0333, Contrastive Loss: 0.7586


                                                        

Epoch [6], Supervised Loss: 0.0326, Contrastive Loss: 0.5874


                                                        

Epoch [7], Supervised Loss: 0.0323, Contrastive Loss: 0.5377


                                                        

Epoch [8], Supervised Loss: 0.0321, Contrastive Loss: 0.3693


                                                        

Epoch [9], Supervised Loss: 0.0317, Contrastive Loss: 0.1773


                                                         

Epoch [10], Supervised Loss: 0.0317, Contrastive Loss: 0.2566




In [22]:
# Save the improved segmentation model
torch.save(model.state_dict(), 'improved-resnet-segmentation.pt')

# Test

In [55]:
import torch
from tqdm import tqdm
import numpy as np

def calculate_iou(pred, target, n_classes):
    ious = []
    # Convert softmax predictions to class indexes if not already
    pred = torch.argmax(pred, dim=1) if pred.shape[1] > 1 else pred.squeeze(1)
    
    if target.shape[1] > 1:
        # Assuming target is one-hot encoded
        target = torch.argmax(target, dim=1)
    
    # Flatten the arrays to calculate IoU on a per-pixel basis
    pred = pred.view(-1)
    target = target.view(-1)

    for cls in range(n_classes):
        pred_inds = (pred == cls)
        target_inds = (target == cls)
        
        # Calculate Intersection and Union
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        
        if union == 0:
            ious.append(float('nan'))  # Avoid division by zero
        else:
            ious.append(intersection / union)
    
    return np.nanmean(ious)

def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor):
    SMOOTH = 1e-6
    
    # Convert output probabilities to binary predictions
    outputs = torch.sigmoid(outputs)  # Assuming outputs are logits from the model; remove if already probabilities
    outputs = (outputs > 0.5).float()  # Threshold the probabilities to create binary predictions
    
    # Flatten the tensors to simplify the intersection/union computation
    outputs = outputs.view(outputs.shape[0], -1)
    labels = labels.view(labels.shape[0], -1)
    
    # Compute the intersection and union
    intersection = (outputs * labels).sum(1)
    union = (outputs + labels).sum(1) - intersection
    
    # Compute IoU and handle cases where the union is 0
    iou = (intersection + SMOOTH) / (union + SMOOTH)
    
    # You can threshold IoU values here if needed (e.g., for metric computations)
    return iou.mean()


def validate_model(model, val_loader, criterion, device, num_classes):
    model.eval()
    total_loss = 0
    total_iou = 0
    count = 0
    
    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc="Validating", leave=False):
            imgs, masks = imgs.to(device), masks.to(device)
            
            # Forward pass
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            total_loss += loss.item()
            
            # Calculate IoU - ensure outputs and masks are compatible
            iou = iou_pytorch(outputs, masks)
            total_iou += iou.item()  # Sum IoU for averaging
            count += 1
    
    avg_loss = total_loss / len(val_loader)
    avg_iou = total_iou / count if count > 0 else 0
    return avg_loss, avg_iou


In [56]:
# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 计算验证集上的平均损失和 IoU
val_loss, val_iou = validate_model(model, val_dataloader, nn.CrossEntropyLoss(), device, num_classes)

print(f"Validation Loss: {val_loss:.4f}")
print(f"Mean IoU: {val_iou:.4f}")

                                                           

Validation Loss: 0.0342
Mean IoU: 0.9811


