In [1]:
!pip install torch torchvision matplotlib numpy

Collecting torch
  Downloading torch-2.5.1-cp311-cp311-win_amd64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.20.1-cp311-cp311-win_amd64.whl.metadata (6.2 kB)
Collecting matplotlib
  Downloading matplotlib-3.9.3-cp311-cp311-win_amd64.whl.metadata (11 kB)
Collecting numpy
  Downloading numpy-2.2.0-cp311-cp311-win_amd64.whl.metadata (60 kB)
Collecting filelock (from torch)
  Using cached filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2024.10.0-py3-none-any.whl.metadata (11 kB)
Collecting sympy==1.13.1 (from torch)
  Using cached sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting pillow!=8.3.*,>

In [2]:
!pip install pillow



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import os
from PIL import Image

class MVSTransformer(nn.Module):
    """Multi-view-stereo depth estimator built around a per-pixel cost volume.

    Pipeline implemented by ``forward``:

    1. A small convolutional stack extracts features from every view.
    2. A cost volume ``(B, D, H, W)`` is built by comparing the reference
       features with the (warped) source features at ``D`` depth hypotheses.
    3. Every pixel's ``D``-vector of costs is treated as one token, projected
       to ``num_features`` channels, refined by the transformer layers, and
       decoded back to ``D`` costs.
    4. A 3-D convolutional regulariser turns the refined costs into per-depth
       scores; a softmax over the depth axis followed by an expectation gives
       the depth map.

    NOTE: ``warp_source_to_ref`` is still a placeholder that ignores both the
    camera parameters and the depth hypothesis, so the returned depth map is
    not yet geometrically meaningful.

    Args:
        num_features: channel width of the feature extractor, the transformer
            model dimension and the hidden width of the regulariser.
        num_heads: number of attention heads; must divide ``num_features``.
        depth: number of transformer encoder layers.
        num_depths: number of depth hypotheses ``D`` in the cost volume.
    """

    def __init__(self, num_features=16, num_heads=4, depth=4, num_depths=48):
        super().__init__()

        if num_depths < 1:
            raise ValueError("num_depths must be a positive integer")
        # Single source of truth for D: the cost volume, the projector and the
        # decoder all have to agree on it.
        self.num_depths = num_depths

        # Image feature extraction: two 3x3 convolutions at full resolution
        # (padding=1 keeps H and W unchanged).
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, num_features, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            nn.LeakyReLU(),
        )

        # Transformer layers; they use the default (seq, batch, feature) layout.
        self.transformer_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=num_features, nhead=num_heads, dim_feedforward=num_features * 4
            )
            for _ in range(depth)
        ])

        # Per-pixel cost vector (D values) -> transformer token (num_features).
        self.projector = nn.Linear(num_depths, num_features)

        # Transformer token (num_features) -> refined per-pixel cost vector (D).
        # Without this inverse mapping the refined tokens cannot be reshaped
        # back into a (B, D, H, W) volume.
        self.cost_decoder = nn.Linear(num_features, num_depths)

        # Cost volume regularisation. The cost volume carries one scalar per
        # (depth, y, x) cell, so it enters the 3-D convolutions with a single
        # channel.
        self.cost_volume_regularizer = nn.Sequential(
            nn.Conv3d(1, num_features, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(num_features, 1, kernel_size=3, padding=1),
        )

        # Depth regression: softmax over the depth-hypothesis axis.
        self.depth_regression = nn.Softmax(dim=1)

    def forward(self, ref_image, src_images, camera_params):
        """Predict a depth map for ``ref_image``.

        Args:
            ref_image: tensor ``(B, 3, H, W)``.
            src_images: non-empty sequence of tensors shaped like ``ref_image``.
            camera_params: forwarded unchanged to ``warp_source_to_ref``
                (currently unused by the placeholder warp).

        Returns:
            Tensor ``(B, H, W)`` holding the expected depth-hypothesis index,
            a value in ``[0, num_depths - 1]`` (not a metric depth).

        Raises:
            ValueError: if ``src_images`` is empty or the warped source
                features do not match the reference features in shape.
        """
        ref_features = self.feature_extractor(ref_image)
        src_features = [self.feature_extractor(src) for src in src_images]

        # Fuse features from multiple views
        cost_volume = self.construct_cost_volume(ref_features, src_features, camera_params)

        # Transformer-based refinement: one token per pixel.
        # Attention cost grows with (H*W)**2, so keep the inputs small.
        B, D, H, W = cost_volume.shape
        tokens = cost_volume.flatten(2).permute(2, 0, 1)   # (H*W, B, D)
        tokens = self.projector(tokens)                    # (H*W, B, F)
        for block in self.transformer_blocks:
            tokens = block(tokens)
        refined = self.cost_decoder(tokens)                # (H*W, B, D)
        # reshape (not view): the permuted tensor is not contiguous.
        refined_cost = refined.permute(1, 2, 0).reshape(B, D, H, W)

        # Regularize cost volume and predict depth.
        # Conv3d expects (B, C, D, H, W): add a channel axis, then drop it.
        scores = self.cost_volume_regularizer(refined_cost.unsqueeze(1)).squeeze(1)
        prob_volume = self.depth_regression(scores)        # (B, D, H, W)

        return self.compute_depth(prob_volume)

    def construct_cost_volume(self, ref_features, src_features, camera_params):
        """Build the ``(B, D, H, W)`` matching-cost volume.

        For every depth hypothesis the cost is the channel-wise mean absolute
        difference between the reference features and the warped source
        features.

        Raises:
            ValueError: if ``src_features`` is empty or a warped source does
                not have the same shape as ``ref_features``.
        """
        if len(src_features) == 0:
            raise ValueError("at least one source image is required to build a cost volume")

        slices = []
        for d in range(self.num_depths):
            warped_src = self.warp_source_to_ref(src_features, camera_params, d)
            if warped_src.shape != ref_features.shape:
                raise ValueError(
                    "warped source features have shape {} but reference features have shape {}".format(
                        tuple(warped_src.shape), tuple(ref_features.shape)
                    )
                )
            slices.append(F.l1_loss(ref_features, warped_src, reduction='none').mean(dim=1))

        # Stacking instead of writing into a pre-allocated tensor keeps the
        # dtype of the features and avoids in-place writes under autograd.
        return torch.stack(slices, dim=1)

    def warp_source_to_ref(self, src_features, camera_params, depth):
        """Warp source features into the reference view at one depth hypothesis.

        Placeholder: returns the first source feature map unchanged and
        ignores ``camera_params`` and ``depth``. Replace with a differentiable
        homography warp.
        """
        return src_features[0]  # Dummy implementation

    def compute_depth(self, prob_volume):
        """Return the expectation of the depth index under ``prob_volume``.

        Args:
            prob_volume: tensor ``(B, D, H, W)`` of per-depth probabilities.

        Returns:
            Tensor ``(B, H, W)``: ``sum_d d * prob_volume[:, d]``.
        """
        D = prob_volume.size(1)
        # Same dtype as the probabilities so no implicit int64 -> float promotion.
        depth_values = torch.arange(
            D, device=prob_volume.device, dtype=prob_volume.dtype
        ).view(1, D, 1, 1)
        return torch.sum(prob_volume * depth_values, dim=1)

    def visualize_3d(self, depth_map):
        """Plot the first depth map of a ``(B, H, W)`` batch as a 3-D surface."""
        _, H, W = depth_map.shape
        surface = depth_map[0].detach().cpu().numpy()

        x, y = np.meshgrid(np.arange(0, W), np.arange(0, H))

        fig = plt.figure(figsize=(10, 7))
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(x, y, surface, cmap='viridis')
        ax.set_xlabel("x (pixels)")
        ax.set_ylabel("y (pixels)")
        ax.set_zlabel("expected depth index")
        plt.show()

# Example usage
if __name__ == "__main__":
    model = MVSTransformer()
    model.eval()  # inference only: disables dropout inside the transformer layers

    # Directory containing images
    image_dir = "ProcessedDates"  # Replace with the actual directory path
    image_prefix = "KD3001"  # Replace with the prefix of the images to load

    # The model attends over one token per pixel, so memory grows with
    # (H*W)**2; all views must also share one spatial size for the cost
    # volume. Resize every view to a small common resolution.
    image_size = (64, 64)  # (height, width); raise only if memory allows

    # os.listdir returns entries in arbitrary order; sort so that the
    # reference view (first image) is the same on every run.
    image_files = sorted(
        os.path.join(image_dir, f)
        for f in os.listdir(image_dir)
        if f.startswith(image_prefix)
    )

    # Ensure we have one reference and at least one source image
    if len(image_files) < 2:
        raise ValueError(
            "Need at least 2 images starting with {!r} in {!r}, found {}".format(
                image_prefix, image_dir, len(image_files)
            )
        )

    preprocess = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    images = []
    for path in image_files:
        # The context manager closes the file handle; convert("RGB") guarantees
        # the 3 channels the feature extractor expects (RGBA / greyscale inputs).
        with Image.open(path) as img:
            images.append(preprocess(img.convert("RGB")))

    ref_image = images[0].unsqueeze(0)  # Add batch dimension
    src_images = [img.unsqueeze(0) for img in images[1:]]

    camera_params = None  # Placeholder for camera parameters

    # No gradients are needed for a forward-only run.
    with torch.no_grad():
        depth_map = model(ref_image, src_images, camera_params)
    print("Depth map shape:", depth_map.shape)

    # Visualize the 3D reconstruction
    model.visualize_3d(depth_map)
