In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import numpy as np
from PIL import Image, UnidentifiedImageError
import os
import unet_model


# Resize every image to 256x256 and convert it to a tensor.
# NOTE(review): presumably LineDataset applies this to both the thick-line
# input and the thin-line target — confirm in unet_model.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Seed for the train/validation split. It must stay fixed across kernel
# restarts: the training cell reloads an existing checkpoint and scores it on
# `val_dataset`, which is only a genuine hold-out set if the split is the same
# one the checkpoint was trained with. An unseeded split reshuffles on every
# run and leaks training samples into validation.
SPLIT_SEED = 42

# Create dataset and dataloader
dataset = unet_model.LineDataset('thick_lines_synthetic/', 'thin_lines_synthetic/', transform=transform)
train_size = int(0.8 * len(dataset))  # 80/20 train/validation split
val_size = len(dataset) - train_size

# Split the dataset reproducibly.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoaders for both splits (only the training set is shuffled).
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False)

In [None]:
# Initialize model, loss, and optimizer
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

model = unet_model.UNet().to(device)
model_path = 'unet_line_thinning_model.pth'
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())


def evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss of `model` over `loader`.

    The model is put in eval mode for the pass, so layers with train/eval
    differences (e.g. BatchNorm/Dropout, if UNet contains any) use inference
    behaviour and do not update their statistics from validation data. The
    previous train/eval mode is restored afterwards.

    Args:
        model: the network to score.
        loader: iterable of (input, target) batches.
        criterion: loss function called as criterion(output, target).
        device: device the batches are moved to.

    Returns:
        Unweighted mean of the per-batch losses, or NaN if `loader` is empty.
    """
    was_training = model.training
    model.eval()
    batch_losses = []
    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                batch_losses.append(criterion(model(inputs), targets).item())
    finally:
        model.train(was_training)
    return float(np.mean(batch_losses)) if batch_losses else float('nan')


# Loss histories, one entry per recorded training step / validation pass.
train_loss_list = []
val_loss_list = []
# Best validation loss so far; the checkpoint is only overwritten when beaten.
best_val_loss = float('inf')
if os.path.exists(model_path):
    # Resume from the existing checkpoint; its validation loss becomes the bar
    # that new checkpoints must beat.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    best_val_loss = evaluate(model, val_loader, criterion, device)
    val_loss_list.append(best_val_loss)


# Training loop
num_epochs = 1
# Validate (and possibly checkpoint) every `eval_every` training batches. Each
# validation is a full pass over val_loader, so raising this speeds training up
# at the cost of coarser checkpointing.
eval_every = 1

for epoch in range(1, num_epochs + 1):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        train_loss = criterion(output, target)
        train_loss_list.append(train_loss.item())
        train_loss.backward()
        optimizer.step()

        # batch_idx 0 always validates, so val_loss is defined before the first print.
        if batch_idx % eval_every == 0:
            val_loss = evaluate(model, val_loader, criterion, device)
            val_loss_list.append(val_loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), model_path)

        if batch_idx % 10 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {train_loss.item():.6f} '
                  f'  validation loss: {val_loss:.6f}')

In [None]:
import matplotlib.pyplot as plt

def load_model(model_path, device='cpu'):
    """Build a UNet, load weights from `model_path`, and return it in eval mode.

    Args:
        model_path: path to a state_dict, as written by
            `torch.save(model.state_dict(), ...)` in the training cell.
        device: device (string or torch.device) to place the model on.
            Defaults to CPU, the previously hard-coded choice.

    Returns:
        The loaded `unet_model.UNet`, switched to evaluation mode.
    """
    device = torch.device(device)
    model = unet_model.UNet().to(device)
    # map_location lets a checkpoint saved on CUDA/MPS load onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()  # Set the model to evaluation mode
    return model

# Function to preprocess the input image
def preprocess_image(image_path, size=(256, 256)):
    """Load an image file as a batched single-channel tensor for the UNet.

    Args:
        image_path: path to an image file that PIL can open.
        size: (height, width) to resize to. Defaults to 256x256, the size used
            for the training transform in this notebook.

    Returns:
        A tensor of shape (1, 1, height, width): one grayscale channel plus a
        leading batch dimension.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # The context manager closes the file handle promptly. `convert` returns a
    # new, fully loaded image, so it remains usable after the `with` exits.
    with Image.open(image_path) as img:
        gray = img.convert('L')  # Convert to grayscale
    return transform(gray).unsqueeze(0)  # Add batch dimension

# Function to perform inference
@torch.no_grad()
def predict(model, image):
    """Forward `image` through `model` with autograd disabled.

    The model's raw output is returned untouched; no sigmoid or thresholding
    is applied here (see postprocess_output for that).
    """
    return model(image)

# Function to post-process the output
def postprocess_output(output, threshold=0.95):
    """Convert raw model logits into a binary 0/255 uint8 mask.

    Args:
        output: logits tensor. All size-1 dimensions (e.g. batch and channel)
            are squeezed away.
        threshold: probability above which a pixel is set to 255. Defaults to
            0.95, the value previously hard-coded here.

    Returns:
        A numpy uint8 array containing only 0 and 255.
    """
    probs = torch.sigmoid(output)  # Apply sigmoid to get probability map
    # detach() so tensors that still track gradients can be converted to numpy.
    probs = probs.detach().squeeze().cpu().numpy()
    return (probs > threshold).astype(np.uint8) * 255  # Threshold and convert to binary image


 # Path to your saved model
image_path = 'nature_imag_cropped.png'  # Path to a test image
#image_path = 'thick_lines/run_001.png'  # Path to a test image


# Load the model
model_path = 'unet_line_thinning_model.pth'  #shoule be the same as defined in the training block
model = load_model(model_path)

# Preprocess the image
input_image = preprocess_image(image_path)


# Perform inference
output = predict(model, input_image)
output.shape
plt.imshow( output.squeeze(0).squeeze(0),cmap='gray')

# Post-process the output
result = postprocess_output(output)
# # Save or display the result
# result_image = Image.fromarray(result)
# plt.imshow(result_image)

# result_image.save('thinned_line_result.png')
# # print("Thinned line image saved as 'thinned_line_result.png'")


In [None]:
# Preprocessed test image, with the batch and channel dimensions squeezed out.
plt.imshow(input_image.squeeze(0).squeeze(0), cmap='gray')
plt.title('Input image')
plt.show()


In [None]:
# Binary 0/255 mask produced by postprocess_output.
plt.imshow(result, cmap='gray')
plt.title('UNet output (thresholded)')
plt.show()


In [None]:
# Hand-drawn line image, preprocessed exactly like the test image.
hand_drawn_line = preprocess_image(image_path='hand_drawn_cropped.png')

In [None]:
# Hand-drawn line, with the batch and channel dimensions squeezed out.
plt.imshow(hand_drawn_line.squeeze(0).squeeze(0), cmap='gray')
plt.title('Hand-drawn line')
plt.show()


In [None]:
# Side-by-side view: input image, thresholded UNet output, hand-drawn line.
# im1/im2/im3 are reused by the figure-saving cells below.
im1 = input_image.squeeze(0).squeeze(0)
im2 = result
im3 = hand_drawn_line.squeeze(0).squeeze(0)

fig, axes = plt.subplots(1, 3, figsize=(12, 6))
for ax, (panel, panel_title) in zip(axes, [(im1, 'original'),
                                           (im2, 'unet-processed'),
                                           (im3, 'ground truth')]):
    # Pick the colormap from each panel's own dimensionality.
    ax.imshow(panel, cmap='gray' if panel.ndim == 2 else None)
    ax.set_title(panel_title)

# Adjust the layout and display the plot
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def save_three_arrays_side_by_side(im1, im2, im3, save_path, titles=None, fig_size=(18, 6), dpi=300):
    """Save three images side by side in a single figure file.

    Args:
        im1, im2, im3: arrays (or CPU tensors) accepted by `Axes.imshow`;
            2-D inputs are drawn with a gray colormap.
        save_path: output file path passed to `Figure.savefig`.
        titles: optional sequence of at least three panel titles (extra
            entries are ignored). Falsy entries leave that panel untitled.
        fig_size: figure size in inches.
        dpi: resolution of the saved image.

    Raises:
        ValueError: if `titles` is non-empty but has fewer than three entries.
    """
    images = (im1, im2, im3)
    if not titles:
        titles = (None,) * len(images)
    elif len(titles) < len(images):
        raise ValueError(f"titles needs {len(images)} entries, got {len(titles)}")

    fig, axes = plt.subplots(1, len(images), figsize=fig_size)
    try:
        for ax, im, title in zip(axes, images, titles):
            ax.imshow(im, cmap='gray' if im.ndim == 2 else None)
            if title:
                ax.set_title(title)
        fig.tight_layout()
        fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"Side-by-side image saved to {save_path}")

# Save input image, hand-drawn ground truth and thresholded UNet output,
# in that order, as one side-by-side figure.
save_three_arrays_side_by_side(
    im1,
    im3,
    im2,
    save_path='three_images_comparison.png',
    titles=['original', 'ground truth', 'unet-processed'],
)

In [None]:
# Raw (un-thresholded) model output with leading size-1 dims removed.
im4 = torch.squeeze(torch.squeeze(output, 0), 0)

In [None]:
# Same comparison, but the third panel shows the raw, unthresholded output.
save_three_arrays_side_by_side(
    im1,
    im3,
    im4,
    save_path='three_images_comparison_unfiltered.png',
    titles=['original', 'ground truth', 'unet-processed'],
)