In [73]:
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
import torch.nn as nn

# All computation in this notebook is pinned to the CPU.
device = torch.device("cpu")

In [74]:
from vgg19 import VGGUNET19
# Restore the trained weights into a freshly constructed network.
checkpoint_path = "VGGUnet19_Segmentation_best.pth.tar"
model = VGGUNET19()

# Tensors are mapped onto the CPU while loading; the weights live under the
# "model_state_dict" key of the checkpoint.
checkpoint = torch.load(
    checkpoint_path,
    map_location=torch.device('cpu'),
    weights_only=True,
)
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [75]:
# Augmentation + normalization pipeline for training images.
#
# PlanDataset passes this pipeline float32 arrays that are already scaled to
# [0, 1], so Normalize must be told the pixel range is 1.0. With its default
# max_pixel_value=255.0 the mean/std below would be multiplied by 255 and the
# normalized image would collapse to roughly constant negative values.
train_transforms = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(p=0.5),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=1.0,
    ),
    A.Resize(512, 512),
])

In [76]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import numpy as np

class PlanDataset(Dataset):
    """Dataset over the ``*Plan.jpg`` images found in a single directory.

    Items are returned as tensors laid out (C, H, W). Pixel values are scaled
    to [0, 1] as float32 before the optional transform is applied.
    """

    def __init__(self, directory, transform=None):
        """
        Args:
            directory (str): Path to the directory containing plan images.
            transform (callable, optional): Called as ``transform(image=array)``
                and expected to return a mapping with an ``'image'`` entry.
        """
        self.directory = directory
        self.transform = transform

        # Collect all files ending with "Plan.jpg". os.listdir returns entries
        # in arbitrary order, so sort to make index -> file reproducible
        # across runs and platforms.
        self.img_files = sorted(
            img_file
            for img_file in os.listdir(directory)
            if img_file.endswith('Plan.jpg')
        )

    def __len__(self):
        """Return the number of plan images found in the directory."""
        return len(self.img_files)

    def __getitem__(self, index):
        """
        Args:
            index (int): Position of the image in the sorted file list.

        Returns:
            torch.Tensor: The (optionally transformed) plan image, (C, H, W).
        """
        img_path = os.path.join(self.directory, self.img_files[index])

        # Open inside a context manager so the file handle is released right
        # away; convert('RGB') loads the pixel data before the file is closed.
        with Image.open(img_path) as img:
            plan = np.array(img.convert('RGB')).astype(np.float32) / 255.0  # scale to [0, 1]

        # Apply transformations if specified
        if self.transform is not None:
            plan = self.transform(image=plan)['image']

        # (H, W, C) -> (C, H, W). Make the result contiguous first:
        # torch.from_numpy rejects arrays with negative strides, which a
        # transform returning a flipped view could produce.
        plan = np.ascontiguousarray(plan.transpose((2, 0, 1)))
        return torch.from_numpy(plan)

In [77]:
# Dataset over the 'train' folder; no transform pipeline is attached here.
dataset = PlanDataset("train", transform=None)


In [78]:
# Element-wise MSE: reduction='none' keeps one loss value per element so the
# caller chooses how to aggregate.
criterion = nn.MSELoss(reduction='none')


def loss_fn(pred, mask, alpha=0.01):
    """Return the mean squared error between ``pred`` and ``mask``.

    Args:
        pred: Predicted tensor.
        mask: Target tensor, compared element-wise with ``pred``.
        alpha: Currently unused by the computation.

    Returns:
        A scalar tensor: the mean of the per-element squared errors.
    """
    return criterion(pred, mask).mean()

# Adam over every model parameter, with a 1e-4 learning rate and explicit
# (0.9, 0.999) moment coefficients.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))

In [79]:
### Dummy Data
__batch_size = 1
__in_channels = 3
__width = 512
__height = 512

dummy_input = torch.randn((__batch_size, __in_channels, __height, __width))

# Shape smoke test only. Run it in eval mode and without autograd so that the
# random input neither builds a gradient graph nor (if the model contains
# BatchNorm layers) overwrites the running statistics just loaded from the
# checkpoint. The previous train/eval mode is restored afterwards.
_was_training = model.training
model.eval()
try:
    with torch.no_grad():
        dummy_out = model(dummy_input)
finally:
    model.train(_was_training)

print(f'Model input size: {dummy_input.shape}')
print(f'Model output size: {dummy_out.shape}')

Model input size: torch.Size([1, 3, 512, 512])
Model output size: torch.Size([1, 1, 512, 512])


In [80]:
# Take the first sample and prepend a batch axis so it can be fed to the model.
img = dataset[0].unsqueeze(0)
img.shape

torch.Size([1, 3, 512, 512])

In [83]:
def train_start():
    """Run one epoch over ``dataset``, one image per optimizer step.

    Uses the notebook-level ``model``, ``dataset``, ``optimizer``,
    ``criterion`` and ``device``. Each sample is moved to ``device`` and given
    a leading batch axis before the forward pass.
    """
    model.to(device)

    for _ in range(1):
        model.train()
        for sample in dataset:
            batch = sample.to(device).unsqueeze(0)

            optimizer.zero_grad()
            pred_mask = model(batch)

            # NOTE(review): the prediction is compared against itself, so this
            # loss is identically zero and carries no training signal. A real
            # target is needed here, but the dataset currently yields images
            # only -- confirm where the ground-truth masks should come from.
            loss = criterion(pred_mask, pred_mask).mean()

            loss.backward()
            optimizer.step()
        
        

In [84]:
# Kick off the training loop defined in train_start().
train_start()