In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [2]:
def list_files_only(directory):
    """Return the names of the regular files directly inside ``directory``, sorted.

    Sub-directories are skipped and names are returned without the directory prefix.
    The result is sorted because ``os.listdir`` yields an arbitrary, platform-dependent
    order, while callers index this list positionally (dataset sample ``idx`` and the
    position-based numbering of saved predictions); sorting makes that mapping reproducible.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )

In [3]:
class SoilSegmentationDataset(Dataset):
    """Paired image / label-mask dataset for soil segmentation.

    Every regular file in ``image_dir`` is one sample. Its label is expected in
    ``label_dir`` under the *same file name*. Both are loaded as single-channel ('L')
    PIL images and returned as an ``(image, label)`` pair.
    """

    def __init__(self, image_dir, label_dir, transform=None, target_transform=None):
        """
        Args:
            image_dir (string): Directory with all the images.
            label_dir (string): Directory with the label masks; each mask must share its
                image's file name.
            transform (callable, optional): Transform applied to the image. It is also
                applied to the label when ``target_transform`` is None (original behaviour).
            target_transform (callable, optional): Separate transform for the label. Use
                it when ``transform`` contains steps that are wrong for masks (e.g. a
                non-nearest resize or intensity normalisation).
        """
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.target_transform = target_transform
        self.images = list_files_only(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        file_name = self.images[idx]
        img_path = os.path.join(self.image_dir, file_name)
        label_path = os.path.join(self.label_dir, file_name)

        # The with-blocks close the file handles promptly; convert() returns a new, fully
        # loaded image, so the results stay valid after the files are closed.
        with Image.open(img_path) as img:
            image = img.convert('L')  # Convert to grayscale
        with Image.open(label_path) as lbl:
            label = lbl.convert('L')

        if self.transform:
            image = self.transform(image)
        # Fall back to the image transform so existing callers see unchanged behaviour.
        label_tf = self.target_transform if self.target_transform is not None else self.transform
        if label_tf:
            label = label_tf(label)

        return image, label

In [4]:
# Preprocessing: ToTensor turns each grayscale ('L') PIL image into a float tensor scaled to [0, 1].
# NOTE: SoilSegmentationDataset applies this same transform to the label masks as well, so any
# resizing added here must not interpolate label values (or the labels need their own transform).
DATA_ROOT = 'e:/3.Experimental_Data/DL_Data_raw'
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = SoilSegmentationDataset(
    f'{DATA_ROOT}/images/',
    f'{DATA_ROOT}/labels/',
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

In [5]:
class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 Conv2d -> BatchNorm2d -> ReLU.

    The first stage maps ``in_channels`` to ``out_channels`` and the second keeps
    ``out_channels``. ``padding=1`` keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Same construction order as a hand-written Sequential, so parameter
        # initialisation and state_dict keys (double_conv.0 ... double_conv.5) match.
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return out

class UNet(nn.Module):
    """Three-level U-Net built from DoubleConv blocks.

    Args:
        in_channels: channels of the input image (default 1, matching the grayscale dataset).
        out_channels: number of output maps (default 1).

    ``forward`` returns raw logits (``conv_last`` has no activation) with the same spatial
    size as the input, for any H and W, including sizes not divisible by 8.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.dconv_down1 = DoubleConv(in_channels, 64)
        self.dconv_down2 = DoubleConv(64, 128)
        self.dconv_down3 = DoubleConv(128, 256)
        self.dconv_down4 = DoubleConv(256, 512)

        self.maxpool = nn.MaxPool2d(2)

        # Decoder inputs are upsampled features concatenated with the matching skip connection.
        self.dconv_up3 = DoubleConv(256 + 512, 256)
        self.dconv_up2 = DoubleConv(128 + 256, 128)
        self.dconv_up1 = DoubleConv(64 + 128, 64)

        self.conv_last = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _upsample_to(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s spatial size so the two can be concatenated."""
        # A fixed scale_factor=2 breaks on odd sizes (MaxPool2d floors them); targeting the
        # skip's exact size fixes that and is identical for even sizes with align_corners=True.
        return F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        conv2 = self.dconv_down2(self.maxpool(conv1))
        conv3 = self.dconv_down3(self.maxpool(conv2))
        x = self.dconv_down4(self.maxpool(conv3))

        x = self.dconv_up3(torch.cat([self._upsample_to(x, conv3), conv3], dim=1))
        x = self.dconv_up2(torch.cat([self._upsample_to(x, conv2), conv2], dim=1))
        x = self.dconv_up1(torch.cat([self._upsample_to(x, conv1), conv1], dim=1))

        return self.conv_last(x)

In [6]:
# Seed torch's global RNG so weight initialisation is reproducible across kernel restarts
# (the shuffling DataLoader also draws from this generator by default when iterated).
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when available; fall back to CPU instead of crashing in .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model on the chosen device
model = UNet().to(device)

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# BCEWithLogitsLoss applies the sigmoid itself, matching UNet's raw-logit output
# (conv_last is not followed by an activation).
criterion = nn.BCEWithLogitsLoss()

In [7]:
from tqdm import tqdm  # for displaying the progress bar

def train_model(model, train_loader, criterion, optimizer, num_epochs=25, device=None):
    """Train ``model`` in place and report the sample-weighted mean loss of each epoch.

    Args:
        model: network to optimise; it is switched to training mode.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: ``criterion(outputs, labels)``; assumed to return a per-batch mean
            (e.g. the default ``reduction='mean'``), so it is re-weighted by batch size.
        optimizer: optimiser already bound to ``model.parameters()``.
        num_epochs: number of passes over ``train_loader``.
        device: where each batch is moved. Defaults to the device of the model's first
            parameter (CPU if it has none), so the loop works with or without CUDA.

    Returns:
        list[float]: mean training loss for each epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples during an epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()  # Set model to training mode
    epoch_losses = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        n_samples = 0

        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            n_samples += batch_size

        if n_samples == 0:
            raise ValueError('train_loader yielded no samples')
        # Divide by the samples actually seen (not len(dataset)): this stays correct with
        # drop_last=True or a sampler that visits only part of the dataset.
        epoch_loss = running_loss / n_samples
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    return epoch_losses

# Train for NUM_EPOCHS full passes over train_loader; the return value is captured rather than echoed.
NUM_EPOCHS = 50
loss_history = train_model(model, train_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)

100%|██████████| 3/3 [00:02<00:00,  1.28it/s]


Epoch 1/50, Loss: 0.7567


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 2/50, Loss: 0.6283


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 3/50, Loss: 0.5534


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 4/50, Loss: 0.4965


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 5/50, Loss: 0.4537


100%|██████████| 3/3 [00:01<00:00,  1.64it/s]


Epoch 6/50, Loss: 0.4219


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 7/50, Loss: 0.3980


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 8/50, Loss: 0.3818


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 9/50, Loss: 0.3724


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 10/50, Loss: 0.3626


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 11/50, Loss: 0.3556


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 12/50, Loss: 0.3508


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 13/50, Loss: 0.3461


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 14/50, Loss: 0.3422


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 15/50, Loss: 0.3393


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 16/50, Loss: 0.3363


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 17/50, Loss: 0.3347


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 18/50, Loss: 0.3317


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 19/50, Loss: 0.3315


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 20/50, Loss: 0.3383


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 21/50, Loss: 0.3429


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 22/50, Loss: 0.3322


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 23/50, Loss: 0.3290


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 24/50, Loss: 0.3250


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 25/50, Loss: 0.3230


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 26/50, Loss: 0.3204


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 27/50, Loss: 0.3176


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 28/50, Loss: 0.3157


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 29/50, Loss: 0.3137


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 30/50, Loss: 0.3120


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 31/50, Loss: 0.3099


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 32/50, Loss: 0.3080


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 33/50, Loss: 0.3062


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 34/50, Loss: 0.3045


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 35/50, Loss: 0.3029


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 36/50, Loss: 0.3012


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 37/50, Loss: 0.2999


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 38/50, Loss: 0.2981


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 39/50, Loss: 0.2969


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 40/50, Loss: 0.2953


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 41/50, Loss: 0.2937


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 42/50, Loss: 0.2924


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 43/50, Loss: 0.2911


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 44/50, Loss: 0.2899


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 45/50, Loss: 0.2882


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 46/50, Loss: 0.2867


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 47/50, Loss: 0.2854


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]


Epoch 48/50, Loss: 0.2841


100%|██████████| 3/3 [00:01<00:00,  1.63it/s]


Epoch 49/50, Loss: 0.2827


100%|██████████| 3/3 [00:01<00:00,  1.62it/s]

Epoch 50/50, Loss: 0.2814





In [12]:

def save_image(image, path, vmin=None, vmax=None):
    """Save a single-image tensor to ``path`` as a grayscale picture via ``plt.imsave``.

    Args:
        image: tensor whose non-singleton dimensions form a 2-D image (e.g. a ``(1, H, W)`` mask).
        path: output file; its parent directory is created if it does not exist.
        vmin, vmax: colour limits forwarded to ``plt.imsave``. For boolean masks they default
            to 0 and 1; for other dtypes ``None`` lets matplotlib autoscale to the data range.
    """
    # detach() lets tensors that require grad be converted outside torch.no_grad().
    array = image.detach().squeeze().cpu().numpy()
    if array.dtype == bool:
        # Pin the limits: autoscaling an all-True (or all-False) mask gives vmin == vmax,
        # which matplotlib maps to the lowest colour (black), so a fully-foreground mask would
        # be saved identically to an empty one.
        vmin = 0 if vmin is None else vmin
        vmax = 1 if vmax is None else vmax
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    plt.imsave(path, array, cmap='gray', vmin=vmin, vmax=vmax)

def test_model(model, test_loader, device='cuda'):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Turn off gradients to speed up this part
        for i, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            outputs = model(images)
            outputs = torch.sigmoid(outputs)  # Apply sigmoid to get values between 0 and 1
            outputs = outputs > 0.5  # Threshold the probabilities to create a binary mask
            
            # Save output images
            for idx, output in enumerate(outputs):
                save_path = f'e:/3.Experimental_Data/DL_Data_raw/tests_inference/002_ou_DongYing_{i*test_loader.batch_size + idx + 13635}_roi_selected.png'
                save_image(output, save_path)

            print(f'Processed batch {i+1}/{len(test_loader)}')

# Held-out test set with the same preprocessing as training. Shuffling stays off so batches
# follow the dataset's file order, which test_model relies on when numbering saved masks.
DATA_ROOT = 'e:/3.Experimental_Data/DL_Data_raw'
test_dataset = SoilSegmentationDataset(
    f'{DATA_ROOT}/tests/',
    f'{DATA_ROOT}/tests_labels/',
    transform=transform,
)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

In [13]:
# Run inference on the held-out set; test_model writes one mask image per test sample.
saved_mask_paths = test_model(model, test_loader)

Processed batch 1/5
Processed batch 2/5
Processed batch 3/5
Processed batch 4/5
Processed batch 5/5
