WE BROKE SOMETHING IN THE TRAINING OF THE POINT FINDER.


In [None]:
# import statements
import torch
import utils
import config
import matplotlib.pyplot as plt
from torchvision.models import resnet34
from torch.utils.data import DataLoader
from utils import normalize_brightness, histogram_eq_global
from scipy.fft import fft2, ifft2, fftshift
import cv2

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Side length, in pixels, of the square heatmaps built by the dataset; the
# training loop also uses size*size to normalise the logged loss.
# NOTE(review): filter_image resizes to a hard-coded 256, so this value
# presumably has to stay 256 unless both are changed together -- confirm.
size = 256
import torch.nn as nn

In [None]:
def filter_image(img_path, radius, out_size=256):
    """Load an image and stack it with a high-pass-filtered extra channel.

    The image is read from disk, converted to RGB in [0, 1], histogram
    equalized, and one channel of the equalized image is high-pass filtered
    in the frequency domain (all frequencies closer than ``radius`` bins to
    DC are zeroed). The RGB image and the filtered channel are resized and
    concatenated along the last axis.

    Args:
        img_path: Path of the image file to load.
        radius: Radius, in frequency bins, of the low-frequency disc removed
            from the spectrum.
        out_size: Side length of the square output. Defaults to 256, the
            value that was previously hard-coded.

    Returns:
        A float32 tensor of shape (out_size, out_size, 4): three RGB channels
        followed by the filtered channel (channels-last).

    Raises:
        FileNotFoundError: If ``img_path`` does not exist.
        ValueError: If OpenCV cannot decode the file.
    """
    # Check if the image file exists
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"Image file not found: {img_path}")

    # cv2.imread signals an unreadable file by returning None, not by raising
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Unable to read the image file: {img_path}")

    # OpenCV loads BGR; convert to RGB and scale to [0, 1]
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    # Apply histogram equalization (the helper is fed uint8 data)
    # assumes histogram_eq_global returns an HxWx3 array in 0..255 -- TODO confirm
    img_eq = histogram_eq_global((img_rgb * 255).astype(np.uint8)) / 255.0

    # NOTE(review): the original comment called this the "V channel", but the
    # input handed to histogram_eq_global is RGB-ordered, so index 2 is the
    # blue channel unless the helper returns HSV -- confirm which is intended.
    v_channel = img_eq[:, :, 2]
    img_f = fft2(v_channel)

    # Circular mask centred on (height // 2, width // 2): 0 inside the disc
    # (low frequencies), 1 outside.
    height, width = v_channel.shape
    y, x = np.ogrid[:height, :width]
    center_y, center_x = height // 2, width // 2
    mask = np.ones((height, width), dtype=np.float32)
    mask[(y - center_y)**2 + (x - center_x)**2 < radius**2] = 0

    # fft2 puts DC at index (0, 0). ifftshift moves the mask centre n // 2 to
    # index 0 on every axis; fftshift only does so when n is even and is off
    # by one bin for odd-sized images.
    img_f_filtered = img_f * np.fft.ifftshift(mask)

    # Back to the spatial domain; keep the magnitude only
    f_ed = np.abs(ifft2(img_f_filtered))

    # Resize the original RGB image and the filtered channel to the output size
    target_shape = (out_size, out_size)
    img_rgb_resized = cv2.resize(img_rgb, target_shape)
    f_ed_resized = cv2.resize(f_ed, target_shape)

    # Stack RGB + filtered channel along the last axis -> (H, W, 4)
    combined_image = np.concatenate([img_rgb_resized, f_ed_resized[:, :, np.newaxis]], axis=-1)

    # Convert to a PyTorch tensor
    combined_tensor = torch.tensor(combined_image, dtype=torch.float32)

    return combined_tensor


In [None]:
import os
import json
import torch
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

class StevenCustomDataset(Dataset):
    """In-memory dataset of (filtered image, keypoint heatmap) pairs.

    Every row of ``data`` that has a ``"kp-1"`` annotation is loaded eagerly
    in ``__init__``: the image goes through ``filter_image`` and the keypoint
    annotations are rendered into a Gaussian heatmap. Both tensors are stored
    on the module-level ``device``.

    Attributes are ``images`` (N, 4, imgSize, imgSize) and ``pointImages``
    (N, 3, imgSize, imgSize). Only channel 0 of ``pointImages`` carries the
    heatmap; channels 1 and 2 stay zero.
    """

    def __init__(self, data):
        """Load and preprocess every annotated row of ``data``.

        Args:
            data: pandas DataFrame with an ``"image"`` column and a ``"kp-1"``
                column holding a JSON list of keypoint annotations. Rows
                whose ``"kp-1"`` is NaN are dropped.
        """
        self.data = data.dropna(subset=["kp-1"])
        self.imgSize = size
        n_items = len(self.data)

        # zeros, not empty: only channel 0 is ever written below, and
        # torch.empty would leave arbitrary memory in channels 1 and 2 that
        # is then returned by __getitem__.
        self.pointImages = torch.zeros((n_items, 3, self.imgSize, self.imgSize), device=device)
        # Every entry of `images` is overwritten in the loop, so empty is fine.
        self.images = torch.empty((n_items, 4, self.imgSize, self.imgSize), device=device)

        for idx in tqdm(range(n_items), desc="Processing items"):
            img_path = config.DATASET_PATH + utils.get_img_path(self.data.iloc[idx]["image"])

            coordinates = self.__get_coordinates__(idx)

            # filter_image returns channels-last (H, W, 4); store channels-first
            self.images[idx] = filter_image(img_path, 20).permute(2, 0, 1).float().to(device)

            # Heatmap goes into channel 0; an empty annotation list gives zeros
            self.pointImages[idx, 0] = self._render_heatmap(coordinates).to(device)

    def _render_heatmap(self, coordinates):
        """Render keypoints into an (imgSize, imgSize) heatmap indexed [row=y, col=x].

        Args:
            coordinates: List of ``[x, y, label]`` entries with x and y given
                in percent (0-100) of the image size. The label is ignored.

        Returns:
            A float32 CPU tensor whose value at each pixel is a Gaussian of
            the distance to the nearest keypoint (1.0 exactly on a keypoint).
            All zeros when ``coordinates`` is empty.
        """
        if len(coordinates) == 0:
            # No keypoints: nothing to draw. Previously this path crashed on
            # coordinates[:, 2] because the tensor was one-dimensional.
            return torch.zeros((self.imgSize, self.imgSize))

        points = torch.tensor(coordinates, dtype=torch.float32)
        x_scaled = points[:, 0] / 100 * self.imgSize  # percent -> pixels
        y_scaled = points[:, 1] / 100 * self.imgSize

        # rows[r, c] == r (y position), cols[r, c] == c (x position)
        pixels = torch.arange(self.imgSize, dtype=torch.float32)
        rows, cols = torch.meshgrid(pixels, pixels, indexing='ij')

        # distances[k, r, c]: Euclidean distance from pixel (x=c, y=r) to keypoint k
        distances = torch.sqrt((cols.unsqueeze(0) - x_scaled.view(-1, 1, 1)) ** 2 +
                               (rows.unsqueeze(0) - y_scaled.view(-1, 1, 1)) ** 2)

        # Keep only the distance to the nearest keypoint for each pixel
        closest_distances = distances.min(dim=0).values

        # Distances are divided by 255 before the Gaussian (sigma 0.02, i.e.
        # about 5 pixels). The clamp is a safety net; the Gaussian is <= 1.
        return torch.clamp(self.gaussian(x=0.0 + closest_distances / 255, sig=0.02), max=1.0)

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.data)

    def gaussian(self, x, mu=0, sig=0.02):
        """Return the unnormalised Gaussian exp(-(x - mu)^2 / (2 sig^2))."""
        return torch.exp(-((x - mu) ** 2) / (2 * (sig ** 2)))

    def __getitem__(self, idx):
        """Return ``(image, heatmap)`` for sample ``idx``, both on ``device``."""
        img = self.images[idx]
        pointImg = self.pointImages[idx]
        return img, pointImg

    def __get_label__(self, idx):
        """Return 1 if every annotation of row ``idx`` is labelled ["Contact"], else 0.

        Non-dict entries are skipped; a row with no dict entries yields 1
        because ``all`` of an empty list is true.
        """
        img_metadata_str = str(self.data.iloc[idx].get("kp-1", "nan"))
        img_metadata = json.loads(img_metadata_str)

        contact_labels = [
            metadata.get("keypointlabels") == ["Contact"]
            for metadata in img_metadata if isinstance(metadata, dict)
        ]
        return int(all(contact_labels))

    def __get_coordinates__(self, idx):
        """Return the keypoints of row ``idx`` as a list of ``[x, y, label]``.

        ``x`` and ``y`` are taken verbatim from the annotation; ``label`` is
        0 when ``keypointlabels`` equals ``['Contact']`` and 1 otherwise.
        Returns an empty list (after printing a notice) when the row has no
        annotations.
        """
        img_metadata_str = self.data.iloc[idx]["kp-1"]
        img_metadata = json.loads(img_metadata_str) if img_metadata_str else []

        if not img_metadata:
            print(f"No metadata found for index {idx}")
            return []

        coordinates = [
            [metadata.get('x'), metadata.get('y'), 0 if metadata.get('keypointlabels') == ['Contact'] else 1]
            for metadata in img_metadata if isinstance(metadata, dict)
        ]
        return coordinates


In [None]:
class features_net(nn.Module):
    """Convolutional network mapping a 4-channel image to a 3-channel map.

    Three stages are applied in sequence: a small stem that reduces the
    4 input channels to 3, a truncated ResNet-34 used as feature extractor,
    and a decoder that upsamples by a total factor of 16 (2 * 2 * 4) while
    shrinking the channel count from 256 down to 3.
    """

    def __init__(self):
        super().__init__()

        # Stem: 4 -> 3 channels so the result can be fed to the ResNet.
        # The max-pool has no padding, so it trims one pixel from each border.
        stem_layers = [
            nn.Conv2d(in_channels=4, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=1),
        ]
        self.preparation_for_resnet = nn.Sequential(*stem_layers)

        # Backbone: ResNet-34 (default weights argument) without its last
        # three child modules. The decoder below expects 256 channels here.
        backbone = resnet34()
        self.resnet = nn.Sequential(*list(backbone.children())[:-3])

        # Decoder: transposed conv, then Conv-BN-ReLU pairs interleaved with
        # bilinear upsampling steps.
        decoder_layers = [
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1),
            *self._conv_bn_relu(128, 64),
            *self._conv_bn_relu(64, 32),

            nn.UpsamplingBilinear2d(scale_factor=2),
            *self._conv_bn_relu(32, 32),
            *self._conv_bn_relu(32, 32),

            nn.UpsamplingBilinear2d(scale_factor=2),
            *self._conv_bn_relu(32, 16),
            *self._conv_bn_relu(16, 8),

            nn.UpsamplingBilinear2d(scale_factor=4),
            *self._conv_bn_relu(8, 3),
            *self._conv_bn_relu(3, 3),
        ]
        self.geo_net = nn.Sequential(*decoder_layers)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return [3x3 same-padding Conv2d, BatchNorm2d, in-place ReLU] as a list."""
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def forward(self, img):
        """Run stem, backbone and decoder on ``img`` and return the result.

        Args:
            img: Batch of 4-channel images, (N, 4, H, W).

        Returns:
            A 3-channel map; non-negative because the last layer is a ReLU.
        """
        stem_out = self.preparation_for_resnet(img)
        features = self.resnet(stem_out)
        return self.geo_net(features)

class geom_loss(torch.nn.Module):
    """Summed absolute error between channel 0 of prediction and target."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the L1 norm of ``pred[:, 0] - target[:, 0]`` as a scalar tensor.

        Only channel 0 is supervised; the remaining channels of both inputs
        are ignored. The value is a sum, not a mean, so it grows with batch
        size and resolution.
        """
        residual = pred[:, 0, :, :] - target[:, 0, :, :]
        return torch.norm(residual, p=1)

In [None]:
from torchvision import transforms


csv_file = config.EXPORT3
# Explicit check instead of `assert`, which is stripped under `python -O`
if not os.path.exists(csv_file):
    raise FileNotFoundError(f"Annotation CSV not found: {csv_file}")
df = pd.read_csv(csv_file)

# Reproducible 80/20 train/test split
train = df.sample(frac=0.8, random_state=200)
test = df.drop(train.index)

batchSize = 32

# The dataset loads and preprocesses every image eagerly in its constructor
trainDataset = StevenCustomDataset(data=train)
# NOTE(review): shuffle=False keeps the preview samples below identical
# across epochs, but training batches are then always the same -- confirm
# that this is intended.
train_loader = DataLoader(trainDataset, batch_size=batchSize, shuffle=False)
model = features_net().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
geoLoss = geom_loss()

epochs = 20

# Per-batch loss divided by (pixels per image * samples in the batch)
lossMemory = []

# Positions inside the first batch that are visualised every third epoch
preview_indices = (8, 11, 25, 12, 7)

for epoch in range(epochs):
    for bindex, (images, pointImages) in enumerate(tqdm(train_loader, desc="Training batches")):
        optimizer.zero_grad()
        pred = model(images)
        loss = geoLoss(pred, pointImages)
        loss.backward()
        optimizer.step()

        # .item() stores a plain float rather than a tensor
        lossMemory.append(loss.item() / (size * size * len(images)))

        if epoch % 3 == 0 and bindex == 0:
            # Columns: input RGB | target heatmap | prediction | loss curve
            rgb = images[:, 0:3, :, :]
            fig, axs = plt.subplots(len(preview_indices), 4)
            for row, sample in enumerate(preview_indices):
                # A first batch shorter than 26 samples used to raise
                # IndexError here; skip the rows that do not exist.
                if sample >= len(images):
                    continue
                axs[row, 0].imshow(rgb[sample].permute(1, 2, 0).cpu().numpy())
                axs[row, 1].imshow(pointImages[sample].permute(1, 2, 0).cpu().numpy())
                axs[row, 2].imshow(pred[sample].permute(1, 2, 0).detach().cpu().numpy())
            axs[-1, 3].plot(range(len(lossMemory)), lossMemory)
            plt.show()
        


In [None]:
test_dataset = StevenCustomDataset(data=test)
test_loader = DataLoader(test_dataset, batch_size=5, shuffle=False)
model = model.to(device)
# Inference mode: BatchNorm must use its running statistics, not the
# statistics of each 5-sample test batch.
model.eval()

# Parameters of the max-pooling based peak search
kernel_size = 25
stride = 1
padding = (kernel_size - 1) // 2  # Keeps the pooled map the same size as the input
peak_threshold = 0.2  # Minimum heatmap value for a local maximum to be marked

with torch.no_grad():
    for bindex, (images, pointImages) in enumerate(tqdm(test_loader, desc="Testing batches")):
        pred = model(images)
        # Computed for inspection while debugging; not reported anywhere yet
        loss = geoLoss(pred, pointImages)

        # Visualise the second sample of each batch; fall back to the first
        # one when the final batch holds a single sample.
        sample = min(1, len(images) - 1)
        heatmap = pred[sample, 0]  # Channel 0 is the only supervised output

        # A pixel is a local maximum when it equals the maximum of its
        # kernel_size x kernel_size neighbourhood.
        pooled = torch.nn.functional.max_pool2d(
            heatmap[None, None], kernel_size=kernel_size, stride=stride, padding=padding
        )[0, 0]
        local_maxima_mask = (heatmap == pooled) & (heatmap > peak_threshold)
        local_maxima_indices = torch.nonzero(local_maxima_mask).cpu().numpy()

        # Channels-last RGB copy of the input for display (values in [0, 1])
        image = images[sample].permute(1, 2, 0).cpu().numpy()[:, :, 0:3]

        # Mark every detected peak as a pure red pixel. The image is a float
        # in [0, 1], so the previous value of 255 was out of range.
        for row, col in local_maxima_indices:
            image[row, col] = (1.0, 0.0, 0.0)

        fig, axs = plt.subplots(1, 2, figsize=(12, 10))
        axs[0].imshow(image)
        # Repeat the single channel three times so it renders as grey RGB
        axs[1].imshow(heatmap.repeat(3, 1, 1).permute(1, 2, 0).cpu().numpy())
        axs[1].set_title("Point Image")

        plt.tight_layout()
        plt.show()


model = model.cpu()
#torch.save(model.state_dict(), "models/gaussian_points_finder.pth")
