In [102]:
import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
import torchvision.transforms.functional as TF
import torch.nn as nn
import os

# Compute device used for the training batches in this notebook (CPU only).
device = torch.device("cpu")

In [103]:
from vgg19 import VGGUNET19

# Build the network directly on the notebook-wide `device` and load the checkpoint onto
# that same device, so changing `device` in one place is enough.
model = VGGUNET19().to(device)
checkpoint_path = "VGGUnet19_Segmentation_best.pth.tar"
# weights_only=True restricts unpickling to tensors/primitive containers (no arbitrary objects).
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [104]:
# Augmentation pipeline (albumentations).
# NOTE(review): not currently used -- the `dataset` cell passes transform=None. The flips and
# rotation move the image content while CustomBBoxLoss uses fixed box coordinates, so
# enabling this pipeline as-is would misalign the bbox targets.
train_transforms = A.Compose(
    transforms=[
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(p=0.5),
        # ImageNet per-channel mean / std
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        A.Resize(512, 512),
    ]
)

In [105]:
# Define the CustomBBoxLoss class
class CustomBBoxLoss(nn.Module):
    """Mean-squared error between the prediction inside bounding boxes and a target color.

    For each ``(bbox, color)`` pair, the region of ``pred`` covered by the box is compared
    with a constant tensor filled with the color normalized to ``[0, 1]``. The per-box
    losses are averaged.
    """

    def __init__(self):
        super(CustomBBoxLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def normalize_color(self, color):
        """
        Normalize an RGBA color from [0, 255] to [0, 1].
        Args:
            color: Tuple (R, G, B, A) in [0, 255]
        Returns:
            Normalized tuple in [0, 1]
        """
        return tuple(c / 255.0 for c in color)

    def forward(self, pred, bbox_target_list):
        """
        Arguments:
            pred: Tensor of shape (..., C, H, W), e.g. the model's (1, C, H, W) output.
                The channel count C is read from dimension -3.
            bbox_target_list: Non-empty list of (bbox, color) pairs.
                Format: [(((x1, y1), (x2, y2)), (R, G, B, A)), ...]
                x indexes columns and y indexes rows; x2 and y2 are exclusive.
                Colors are in [0, 255]; only the first C components are used.

        Returns:
            loss: Scalar tensor, the mean of the per-box MSE losses.

        Raises:
            ValueError: if bbox_target_list is empty, pred has fewer than 3 dims,
                a box does not overlap pred, or a color has fewer components than
                pred has channels.
        """
        if not bbox_target_list:
            raise ValueError("bbox_target_list must contain at least one (bbox, color) pair")
        if pred.dim() < 3:
            raise ValueError(f"pred must have shape (..., C, H, W); got {tuple(pred.shape)}")

        # Channels are at dim -3; dim 0 is the batch dimension of a (1, C, H, W) prediction.
        num_channels, img_height, img_width = pred.shape[-3:]

        loss = 0.0
        for bbox, target_color in bbox_target_list:
            (x1, y1), (x2, y2) = bbox

            # Clip the box to the prediction. Pixels outside the image are not model
            # outputs, so they must not be zero-padded into the loss.
            top, bottom = max(0, y1), min(img_height, y2)
            left, right = max(0, x1), min(img_width, x2)
            if bottom <= top or right <= left:
                raise ValueError(
                    f"bounding box {bbox} does not overlap the {img_height}x{img_width} prediction"
                )
            pred_region = pred[..., top:bottom, left:right]  # shape: (..., C, h, w)

            # Normalize the target color to [0, 1]
            target_color = self.normalize_color(target_color)
            if len(target_color) < num_channels:
                raise ValueError(
                    f"color {target_color} has fewer components than the "
                    f"{num_channels} prediction channels"
                )

            target_tensor = torch.tensor(
                target_color[:num_channels],  # one color component per channel
                dtype=pred.dtype,
                device=pred.device,
            ).view(num_channels, 1, 1)
            # Broadcast the per-channel color over the box and any leading batch dims.
            target_tensor = target_tensor.expand_as(pred_region)

            # Compute MSE loss for the current bounding box
            loss += self.mse_loss(pred_region, target_tensor)

        return loss / len(bbox_target_list)


In [106]:

class PlanDataset(Dataset):
    """Images whose file names end in ``Plan.jpg`` from a single directory.

    Each item is a float32 tensor of shape (C, H, W) decoded as RGB. With
    ``transform=None`` the pixel values stay in [0, 255]. Otherwise ``transform`` is
    called as ``transform(image=array)`` and must return a dict with an ``'image'``
    key (the albumentations convention).

    NOTE(review): generate() feeds images through transforms.ToTensor, which scales
    them to [0, 1], whereas this dataset yields [0, 255] when transform is None --
    confirm which range the checkpoint expects.
    """

    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        # Sorted so the index -> file mapping does not depend on os.listdir's arbitrary order.
        self.img_files = sorted(
            img_file for img_file in os.listdir(directory) if img_file.endswith('Plan.jpg')
        )

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and return it as a (C, H, W) float32 tensor."""
        path = os.path.join(self.directory, self.img_files[index])

        # The context manager closes the file; convert('RGB') produces a decoded copy.
        with Image.open(path) as img:
            plan = np.array(img.convert('RGB')).astype(np.float32)

        if self.transform is not None:
            plan = self.transform(image=plan)['image']

        # (H, W, C) -> (C, H, W), copied into a contiguous buffer for torch.
        return torch.from_numpy(np.ascontiguousarray(plan.transpose((2, 0, 1))))

In [107]:
# Training set: every *Plan.jpg in ./train, no transform applied (transform defaults to None).
dataset = PlanDataset('train')


In [108]:
# Shuffled loader with a seeded generator so the batch order is reproducible across
# kernel restarts. Change the seed to get a different order.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

In [109]:
# Bounding-box targets: each entry pairs ((x1, y1), (x2, y2)) pixel corners
# with an RGBA target color in [0, 255].
bbox_target_list = [
    (
        ((264, 414), (317, 448)),  # box corners
        (254, 81, 80, 255),        # target RGBA
    ),
]

In [110]:
criterion = CustomBBoxLoss()

# Adam over all model parameters; the betas shown are PyTorch's defaults, kept explicit.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
)

In [111]:
def train_start(num_epochs=1):
    """Fine-tune the global ``model`` on ``train_dataloader`` with the bbox loss.

    Uses the notebook globals ``model``, ``optimizer``, ``criterion``,
    ``train_dataloader``, ``bbox_target_list`` and ``device``. Prints the loss
    of every batch and leaves the model in train mode.

    Args:
        num_epochs: Number of passes over ``train_dataloader`` (default 1).
    """
    model.train()
    for epoch in range(num_epochs):
        for idx, img_batch in enumerate(train_dataloader):
            img_batch = img_batch.to(device)  # the DataLoader already added the batch dimension

            optimizer.zero_grad()
            pred = model(img_batch)
            loss = criterion(pred, bbox_target_list)
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch + 1}, Batch {idx + 1}, Loss: {loss.item()}")
        
        

In [112]:
# Run the fine-tuning loop; prints one loss line per batch.
train_start()

Epoch 1, Batch 1, Loss: 0.08223015815019608


In [113]:
# Visualize output
def visualize_output(predictions, from_tensor=True):
    """Color a 2-D map of integer class labels and return it as an RGB PIL image.

    Args:
        predictions: When ``from_tensor`` is True, a tensor whose leading size-1 dims
            squeeze away to (H, W), e.g. a (1, 1, H, W) model output. Otherwise an
            (H, W) array of labels.
        from_tensor: Round the tensor to the nearest integer label and convert it to numpy.

    Returns:
        PIL image built from an (H, W, 3) uint8 array. Labels 0-3 get the palette
        colors below; any other value stays black.
    """
    palette = {
        0: (0, 0, 0),
        1: (255, 80, 80),
        2: (80, 80, 255),
        3: (255, 255, 255),
    }

    label_map = predictions
    if from_tensor:
        # .type(torch.LongTensor) gives a CPU int64 tensor, so .numpy() is safe afterwards.
        label_map = torch.round(label_map).type(torch.LongTensor)
        label_map = label_map.squeeze(0).squeeze(0).numpy()

    rows, cols = label_map.shape
    rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    for label_value in palette:
        rgb[label_map == label_value] = palette[label_value]

    return Image.fromarray(rgb)


# Generate function
def generate(image, model, device=torch.device('cpu')):
    """Run ``model`` on one image and return the colored prediction as a PIL image.

    Args:
        image: PIL image (converted to RGB so it has 3 channels, like PlanDataset
            output), or anything ``transforms.ToTensor`` accepts.
        model: The network to run. ``model.to(device)`` is applied, which moves an
            nn.Module in place. Its train/eval mode is restored before returning.
        device: Device to run inference on (default CPU).

    Returns:
        The PIL image produced by ``visualize_output`` for the model output.
    """
    model = model.to(device)

    if isinstance(image, Image.Image):
        # Drop alpha / expand greyscale so the input has 3 channels like the training images.
        image = image.convert('RGB')

    # NOTE(review): ToTensor scales pixels to [0, 1], while PlanDataset with transform=None
    # yields raw float pixels in [0, 255] for fine-tuning -- confirm which range the
    # checkpoint expects.
    transform = transforms.Compose([transforms.ToTensor()])
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Switch to eval() so any BatchNorm/Dropout layers use inference behavior.
    # train_start() leaves the model in train mode, so restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            output = model(input_tensor)
    finally:
        model.train(was_training)

    return visualize_output(output)

In [114]:
# Load the plan as RGB (3 channels, matching PlanDataset). The context manager closes the file;
# convert() returns an independent, fully decoded image.
with Image.open('6_Plan.jpg') as plan_img:
    image = plan_img.convert('RGB')

In [115]:
# Run inference on the loaded plan and color the predicted labels.
output_image = generate(model=model, image=image)

In [116]:
# Display inline in the notebook via the cell's last expression
# (Image.show() would open an external viewer instead).
output_image