<a href="https://colab.research.google.com/github/Husseinhhameed/CoHAtNet/blob/main/CoHAtNet_with_Homography_loss_function.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CoHAtNet with the 7-Scenes dataset (homography loss)

In [None]:
# Mount Google Drive at /content/drive so the dataset stored there can be read
# (the training cell reads it from /content/drive/MyDrive/fire).
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


Install the required packages:
1. `transforms3d` for rotation-matrix / quaternion conversions (`mat2quat`)
2. `einops` for tensor rearranging (`rearrange`, `Rearrange`)
3. `kornia` for geometry conversions (`quaternion_to_rotation_matrix`)

In [None]:
!pip install transforms3d

Collecting transforms3d
  Downloading transforms3d-0.4.2-py3-none-any.whl.metadata (2.8 kB)
Downloading transforms3d-0.4.2-py3-none-any.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transforms3d
Successfully installed transforms3d-0.4.2


In [None]:
pip install einops



In [None]:
pip install kornia

Collecting kornia
  Downloading kornia-0.7.3-py2.py3-none-any.whl.metadata (7.7 kB)
Collecting kornia-rs>=0.1.0 (from kornia)
  Downloading kornia_rs-0.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Downloading kornia-0.7.3-py2.py3-none-any.whl (833 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m833.3/833.3 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading kornia_rs-0.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: kornia-rs, kornia
Successfully installed kornia-0.7.3 kornia-rs-0.1.5


In [None]:
# Import all necessary libraries and modules for data processing, model creation, and training

import os
import time
import random
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation, ToTensor
from torchvision.transforms.functional import to_tensor
from torchvision.utils import make_grid
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn.model_selection import KFold
from tqdm import tqdm
import transforms3d.quaternions as txq
import matplotlib.pyplot as plt
from kornia.geometry.conversions import quaternion_to_rotation_matrix
import pickle

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [None]:
# Define the Global Homography Loss Function
class GlobalHomographyLoss(torch.nn.Module):
    """Global homography loss between estimated and ground-truth camera poses.

    The per-sample error is the trace of ``A + B * B_weight + C / C_weight``, where A, B, C
    come from ``compute_ABC`` and the two weights depend only on the depth range:

        B_weight = log(xmin / xmax) / (xmax - xmin)
        C_weight = xmin * xmax

    The returned loss is the mean of those traces over the batch.

    Args:
        xmin: minimum distance of observations across all frames.
        xmax: maximum distance of observations across all frames.
        device: device on which the constant tensors are created.
    """

    def __init__(self, xmin, xmax, device='cpu'):
        super().__init__()

        near = torch.tensor(xmin, dtype=torch.float32, device=device)
        far = torch.tensor(xmax, dtype=torch.float32, device=device)

        # Scalar weights applied to B and C respectively.
        self.B_weight = torch.log(near / far) / (far - near)
        self.C_weight = near * far

        # Normal of the plane inducing the homographies, in the ground-truth camera frame.
        self.c_n = torch.tensor([0, 0, -1], dtype=torch.float32, device=device).view(3, 1)

        # (3, 3) identity on the target device, handed to compute_ABC.
        self.eye = torch.eye(3, device=device)

    def forward(self, batch):
        """Return the scalar loss for ``batch``.

        ``batch`` must provide the keys 'w_t_c', 'c_R_w', 'w_t_chat' and 'chat_R_w' with the
        shapes documented by ``compute_ABC``.
        """
        A, B, C = compute_ABC(
            batch['w_t_c'], batch['c_R_w'],
            batch['w_t_chat'], batch['chat_R_w'],
            self.c_n, self.eye,
        )
        weighted = A + B * self.B_weight + C / self.C_weight
        per_sample_trace = weighted.diagonal(dim1=1, dim2=2).sum(dim=1)
        return per_sample_trace.mean()

In [None]:
def compute_ABC(w_t_c, c_R_w, w_t_chat, chat_R_w, c_n, eye):
    """Compute the A, B and C matrices of the homography loss.

    Args:
        w_t_c, w_t_chat: ground-truth / estimated translations, shape (batch_size, 3, 1).
        c_R_w, chat_R_w: ground-truth / estimated rotations, shape (batch_size, 3, 3).
        c_n: plane normal, shape (3, 1).
        eye: (3, 3) identity matrix on the proper device.

    Returns:
        Tuple (A, B, C) of (batch_size, 3, 3) float32 tensors. All inputs are cast to
        float32 first.
    """
    gt_t, gt_R, est_t, est_R, normal, identity = (
        tensor.float() for tensor in (w_t_c, c_R_w, w_t_chat, chat_R_w, c_n, eye)
    )

    # Translation and rotation of the ground-truth camera relative to the estimated one.
    rel_t = est_R @ (gt_t - est_t)
    rel_R = est_R @ gt_R.transpose(1, 2)

    rot_residual = identity - rel_R
    normal_outer = normal @ rel_t.transpose(1, 2)
    mixed = normal_outer @ rot_residual

    A = rot_residual @ rot_residual.transpose(1, 2)
    B = mixed + mixed.transpose(1, 2)
    C = normal_outer @ normal_outer.transpose(1, 2)
    return A, B, C


In [None]:
def convert_to_transformation_matrix(translation, quaternion):
    """
    Converts a translation vector and quaternion to a 4x4 transformation matrix.

    Args:
        translation: translation(s) of shape (..., 3); a tensor or any array-like.
        quaternion: quaternion tensor(s) of shape (..., 4), in the coefficient order expected
            by ``kornia.geometry.conversions.quaternion_to_rotation_matrix``
            (w, x, y, z for kornia 0.7 -- re-check if kornia is upgraded).

    Returns:
        torch.Tensor of shape (..., 4, 4): [[R, t], [0, 0, 0, 1]]. A single (4,) quaternion
        still yields a single (4, 4) matrix. The result is created on the device and with
        the dtype of the rotation matrix, instead of always being a CPU float32 tensor.
    """
    R = quaternion_to_rotation_matrix(quaternion)  # (..., 3, 3)
    translation = torch.as_tensor(translation, dtype=R.dtype, device=R.device)

    # Identity start gives the homogeneous bottom row [0, 0, 0, 1]; expand + clone gives one
    # writable 4x4 per leading batch index.
    T = torch.eye(4, dtype=R.dtype, device=R.device).expand(*R.shape[:-2], 4, 4).clone()
    T[..., :3, :3] = R
    T[..., :3, 3] = translation
    return T


In [None]:
def load_and_process_poses(data_dir, seqs, train=True, real=False, vo_lib='orbslam'):
    """
    Loads and processes pose data from the given directory.

    Args:
        data_dir (str): The root directory containing the dataset.
        seqs (list): List of sequences to load poses from.
        train (bool): If True (and ``real`` is False), compute the translation mean/std and
            save them to ``<data_dir>/pose_stats.txt``; otherwise load them from that file.
        real (bool): If True, read poses from ``<data_dir>/<vo_lib>_poses/<seq>`` and the
            alignment statistics from ``<seq>/<vo_lib>_vo_stats.pkl``; otherwise read the
            per-frame ``frame-XXXXXX.pose.txt`` files and use an identity alignment.
        vo_lib (str): Visual odometry library used ('orbslam', 'libviso2', etc.).

    Returns:
        np.ndarray: (N, 7) processed and normalized poses of all sequences, in ``seqs`` order.

    Raises:
        ValueError: if ``seqs`` is empty (there is nothing to stack).
    """
    ps = {}        # seq -> (N, 12) flattened 3x4 [R | t] poses
    vo_stats = {}  # seq -> {'R', 't', 's'} alignment passed to process_poses

    for seq in seqs:
        seq_dir = os.path.join(data_dir, seq)

        if real:
            pose_file = os.path.join(data_dir, f'{vo_lib}_poses', seq)
            # Column 0 is the frame index, columns 1..12 the flattened pose. ndmin=2 keeps a
            # single-line file two-dimensional so the column slice below still works.
            pss = np.loadtxt(pose_file, ndmin=2)
            ps[seq] = pss[:, 1:13]
            vo_stats_filename = os.path.join(seq_dir, f'{vo_lib}_vo_stats.pkl')
            # SECURITY: pickle.load can execute arbitrary code; only load trusted stats files.
            with open(vo_stats_filename, 'rb') as f:
                vo_stats[seq] = pickle.load(f)
        else:
            num_pose_files = sum('pose' in name for name in os.listdir(seq_dir))
            pose_paths = [os.path.join(seq_dir, f'frame-{i:06d}.pose.txt') for i in range(num_pose_files)]
            pss = [np.loadtxt(path).flatten()[:12] for path in pose_paths if os.path.exists(path)]
            # A sequence without pose files must still be (0, 12); np.asarray([]) would be (0,)
            # and break np.vstack and process_poses.
            ps[seq] = np.asarray(pss) if pss else np.empty((0, 12))
            vo_stats[seq] = {'R': np.eye(3), 't': np.zeros(3), 's': 1}  # identity alignment

    all_poses = np.vstack([ps[seq] for seq in seqs])
    pose_stats_filename = os.path.join(data_dir, 'pose_stats.txt')

    # Translation statistics used for normalization: computed on training data, loaded otherwise.
    if train and not real:
        translations = all_poses[:, [3, 7, 11]]
        mean_t = np.mean(translations, axis=0)
        std_t = np.std(translations, axis=0)
        np.savetxt(pose_stats_filename, np.vstack((mean_t, std_t)), fmt='%8.7f')
    else:
        mean_t, std_t = np.loadtxt(pose_stats_filename)

    return np.vstack([
        process_poses(poses_in=ps[seq], mean_t=mean_t, std_t=std_t,
                      align_R=vo_stats[seq]['R'], align_t=vo_stats[seq]['t'],
                      align_s=vo_stats[seq]['s'])
        for seq in seqs
    ])


def process_poses(poses_in, mean_t, std_t, align_R, align_t, align_s):
    """
    Aligns and normalizes poses.

    Each input row is a flattened 3x4 [R | t] matrix (12 values). Each output row is
    [tx, ty, tz, q0, q1, q2, q3], where the quaternion is ``txq.mat2quat(align_R @ R)``
    (transforms3d returns it w-first) flipped so that q0 >= 0. The translation is
    ``align_s * align_R @ (t - align_t)``, standardized with ``mean_t`` / ``std_t``.

    Args:
        poses_in (np.ndarray): (N, 12) input poses to process.
        mean_t (np.ndarray): Mean translation for normalization.
        std_t (np.ndarray): Standard deviation of translation for normalization. Zero
            entries are treated as 1 so a constant axis maps to 0 instead of NaN/inf.
        align_R (np.ndarray): (3, 3) rotation matrix for alignment.
        align_t (np.ndarray): (3,) translation vector for alignment.
        align_s (float): Scaling factor for alignment.

    Returns:
        np.ndarray: (N, 7) processed poses.
    """
    poses_in = np.asarray(poses_in, dtype=float)
    align_R = np.asarray(align_R, dtype=float)
    num_poses = len(poses_in)
    poses_out = np.zeros((num_poses, 7))

    # View each flattened row as its 3x4 [R | t] matrix.
    transforms_3x4 = poses_in.reshape(num_poses, 3, 4)

    for i, R in enumerate(transforms_3x4[:, :, :3]):
        q = txq.mat2quat(align_R @ R)
        # Canonicalize to the q0 >= 0 hemisphere. The previous `q *= np.sign(q[0])` zeroed
        # the whole quaternion when q0 == 0 (e.g. 180-degree rotations); flip only if negative.
        if q[0] < 0:
            q = -q
        poses_out[i, 3:] = q

    # Vectorized form of align_s * align_R @ (t - align_t) for every row.
    shifted_t = transforms_3x4[:, :, 3] - align_t
    poses_out[:, :3] = align_s * (shifted_t @ align_R.T)

    # Normalize translation; guard against a zero standard deviation.
    std_t = np.asarray(std_t, dtype=float)
    safe_std = np.where(std_t > 0, std_t, 1.0)
    poses_out[:, :3] = (poses_out[:, :3] - mean_t) / safe_std

    return poses_out

In [None]:
# Image preprocessing shared by the colour image and the uint8-rescaled depth map:
# numpy array -> PIL image -> 256x256 resize -> float tensor with values in [0, 1].
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),  # the cohatnet_* factories are built for 256x256 inputs
        transforms.ToTensor(),
    ]
)

## Dataset

In [None]:
import os
import json
import numpy as np
import cv2
from torch.utils.data import Dataset
from tqdm import tqdm
import random

class FireDataset(Dataset):
    """Dataset of (colour image, depth image, pose) samples from a 7-Scenes style scene folder.

    ``root_dir`` holds one sub-directory per sequence containing ``*.color.png``,
    ``*.depth.png`` and ``*.pose.txt`` files. Poses are loaded through
    ``load_and_process_poses`` and returned as the 7-value vectors it produces.

    ``global_xmin`` / ``global_xmax`` (used by the homography loss) are the
    ``xmin_percentile`` / ``xmax_percentile`` percentiles of the valid depth values of up to
    100 randomly chosen depth maps. They are cached in ``<root_dir>/depth_stats.json``. A
    cached file is trusted as-is; delete it to recompute the statistics.

    NOTE(review): the statistics are in raw depth-PNG units (millimetres for 7-Scenes --
    confirm), while pose translations are standardized; check which units
    GlobalHomographyLoss expects.

    Args:
        root_dir (str): scene directory.
        xmin_percentile (float): percentile in [0, 1] used for ``global_xmin``.
        xmax_percentile (float): percentile in [0, 1] used for ``global_xmax``.
        transform (callable, optional): applied to the RGB colour image and to the depth
            map after it has been rescaled to uint8.
    """

    # Depth value marking pixels without a measurement (65535 per the 7-Scenes
    # documentation; verify for other datasets). Zero-depth pixels are also treated as invalid.
    INVALID_DEPTH = 65535

    def __init__(self, root_dir, xmin_percentile=0.025, xmax_percentile=0.975, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.xmin_percentile = xmin_percentile
        self.xmax_percentile = xmax_percentile

        # Sorted so sample indices (and therefore KFold splits) do not depend on the
        # filesystem's listing order. Samples and poses are both built from this list.
        self.seqs = sorted(seq for seq in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, seq)))
        self.samples = self._load_samples()

        stats_file = os.path.join(root_dir, 'depth_stats.json')
        self.global_xmin, self.global_xmax = self._load_or_compute_depth_stats(stats_file)

        self.processed_poses = self._load_processed_poses()

        # Keep samples and poses the same length; extra entries on either side are dropped.
        min_length = min(len(self.samples), self.processed_poses.shape[0])
        self.samples = self.samples[:min_length]
        self.processed_poses = self.processed_poses[:min_length]

    def _load_samples(self):
        """
        Return (color_path, depth_path, pose_path) triples, sequence by sequence, pairing the
        files of each kind in sorted file-name order (zip truncates to the shortest kind).
        """
        samples = []
        for seq_folder in self.seqs:
            seq_path = os.path.join(self.root_dir, seq_folder)
            names = sorted(os.listdir(seq_path))
            color_files = [n for n in names if n.endswith('.color.png')]
            depth_files = [n for n in names if n.endswith('.depth.png')]
            pose_files = [n for n in names if n.endswith('.pose.txt')]

            for color_file, depth_file, pose_file in zip(color_files, depth_files, pose_files):
                samples.append((os.path.join(seq_path, color_file),
                                os.path.join(seq_path, depth_file),
                                os.path.join(seq_path, pose_file)))
        return samples

    def _load_or_compute_depth_stats(self, stats_file):
        """Return (global_xmin, global_xmax) from ``stats_file``, computing and caching them if absent.

        Raises:
            ValueError: if no valid depth value is found in the sampled depth maps.
        """
        if os.path.exists(stats_file):
            with open(stats_file, 'r') as f:
                stats = json.load(f)
            return stats['global_xmin'], stats['global_xmax']

        global_depths = self._compute_global_depths()
        if global_depths.size == 0:
            raise ValueError(f"No valid depth values found under {self.root_dir!r}; "
                             "cannot compute depth statistics.")
        global_xmin = float(np.percentile(global_depths, self.xmin_percentile * 100))
        global_xmax = float(np.percentile(global_depths, self.xmax_percentile * 100))
        with open(stats_file, 'w') as f:
            json.dump({'global_xmin': global_xmin, 'global_xmax': global_xmax}, f)
        return global_xmin, global_xmax

    def _compute_global_depths(self):
        """
        Return a 1-D array of the valid depth values of up to 100 randomly chosen depth maps.
        """
        num_samples = len(self.samples)
        sample_indices = random.sample(range(num_samples), min(100, num_samples))
        chunks = []
        for idx in tqdm(sample_indices, desc="Computing global depths"):
            _, depth_path, _ = self.samples[idx]
            depth_image = self._read_image(depth_path, cv2.IMREAD_UNCHANGED)
            chunks.append(self._valid_depths(depth_image))
        # One concatenate instead of extending a Python list with millions of numpy scalars.
        return np.concatenate(chunks) if chunks else np.empty(0)

    @classmethod
    def _valid_depths(cls, depth_image):
        """Return the depth values that are neither 0 nor ``INVALID_DEPTH``, as a 1-D array."""
        depth = np.asarray(depth_image)
        return depth[(depth > 0) & (depth != cls.INVALID_DEPTH)].ravel()

    @staticmethod
    def _depth_to_uint8(depth_image):
        """Scale a depth map by its maximum to the 0..255 uint8 range; an all-zero map stays zero.

        NOTE(review): the maximum includes INVALID_DEPTH pixels when present, which compresses
        the valid range -- confirm this is intended.
        """
        depth = np.asarray(depth_image, dtype=np.float64)
        peak = depth.max() if depth.size else 0.0
        if peak <= 0:
            return np.zeros(depth.shape, dtype=np.uint8)
        return (depth / peak * 255).astype(np.uint8)

    @staticmethod
    def _read_image(path, flags):
        """``cv2.imread`` that raises FileNotFoundError instead of returning None for unreadable files."""
        image = cv2.imread(path, flags)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def _load_processed_poses(self):
        """
        Loads and processes pose data from the given directory.
        """
        return load_and_process_poses(self.root_dir, self.seqs)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if idx >= len(self.processed_poses):
            raise IndexError(f"Index {idx} out of bounds for processed poses of size {len(self.processed_poses)}")

        color_path, depth_path, _ = self.samples[idx]

        # OpenCV decodes colour images as BGR, while ToPILImage treats 3-channel arrays as RGB.
        color_image = cv2.cvtColor(self._read_image(color_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        depth_image = self._read_image(depth_path, cv2.IMREAD_UNCHANGED)
        pose_matrix = self.processed_poses[idx]

        if self.transform:
            color_image = self.transform(color_image)
            depth_image = self.transform(self._depth_to_uint8(depth_image))

        return color_image, depth_image, pose_matrix


## CoHAtNet Model

In [None]:
from einops import rearrange
from einops.layers.torch import Rearrange


def conv_3x3_bn(inp, oup, image_size, downsample=False):
    """3x3 convolution -> BatchNorm -> GELU stem block.

    ``image_size`` is not used; it is accepted so this function has the same signature as
    the other blocks built by ``CoHAtNet._make_layer``. A downsampling block uses stride 2.
    """
    stride = 1 if downsample == False else 2  # noqa: E712 -- mirrors the other blocks
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.GELU(),
    ]
    return nn.Sequential(*layers)


class PreNorm(nn.Module):
    """Normalize the input with ``norm(dim)``, then apply ``fn`` to the result.

    Keyword arguments passed to ``forward`` are forwarded to ``fn`` unchanged.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class SE(nn.Module):
    """Squeeze-and-excitation gate for (b, c, h, w) feature maps.

    Each channel is multiplied by a sigmoid gate computed from its global average. The
    gate MLP maps ``oup`` channels to ``int(inp * expansion)`` and back, so the input must
    have ``oup`` channels; the bottleneck width is derived from ``inp``.
    """

    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        squeezed = int(inp * expansion)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, squeezed, bias=False),
            nn.GELU(),
            nn.Linear(squeezed, oup, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate


class FeedForward(nn.Module):
    """Token-wise MLP: Linear(dim -> hidden_dim) -> GELU -> Dropout -> Linear(hidden_dim -> dim) -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            expand, nn.GELU(), nn.Dropout(dropout),
            contract, nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class MBConv(nn.Module):
    """Inverted-bottleneck convolution block with a residual shortcut.

    Residual branch: ``PreNorm(BatchNorm2d)`` followed by the convolution stack. The
    shortcut is the identity (which requires ``inp == oup``) or, when ``downsample`` is
    truthy, MaxPool2d(3, 2, 1) followed by a 1x1 conv from ``inp`` to ``oup``.

    Args:
        inp, oup: input / output channel counts.
        image_size: unused; accepted for the shared block signature of ``CoHAtNet._make_layer``.
        downsample: halve the spatial resolution (stride 2) when truthy.
        expansion: hidden width multiplier, ``hidden_dim = int(inp * expansion)``.
    """

    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        stride = 1 if self.downsample == False else 2  # noqa: E712
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        if expansion == 1:
            stack = self._depthwise_stack(hidden_dim, oup, stride)
        else:
            stack = self._expanded_stack(inp, hidden_dim, oup, stride)
        self.conv = PreNorm(inp, stack, nn.BatchNorm2d)

    @staticmethod
    def _depthwise_stack(hidden_dim, oup, stride):
        """expansion == 1: strided 3x3 depthwise conv, then a linear 1x1 projection to ``oup``."""
        return nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    @staticmethod
    def _expanded_stack(inp, hidden_dim, oup, stride):
        """expansion != 1: strided 1x1 expansion, 3x3 depthwise conv, SE gate, linear 1x1 projection."""
        return nn.Sequential(
            # downsampling happens in this first (pointwise) conv
            nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            SE(inp, hidden_dim),
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        shortcut = self.proj(self.pool(x)) if self.downsample else x
        return shortcut + self.conv(x)

class Attention(nn.Module):
    """Multi-head self-attention over the ih*iw tokens of a feature map, with a learned
    relative-position bias (one scalar per head for every relative (row, col) offset).

    Args:
        inp: channels of the input tokens.
        oup: channels of the output tokens (used by the output projection).
        image_size: (ih, iw); the input must contain exactly ih*iw tokens.
        heads, dim_head: number of heads and width of each head.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is skipped only for one head whose width equals inp.
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias: one row per possible (drow, dcol) offset
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # indexing='ij' is the historical default made explicit; omitting it makes PyTorch
        # emit a deprecation UserWarning.
        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)), indexing='ij')
        coords = torch.flatten(torch.stack(coords), 1)              # (2, N): row/col of each token
        relative_coords = coords[:, :, None] - coords[:, None, :]   # (2, N, N) pairwise offsets

        # Shift offsets to be non-negative, then fold (row, col) into one table index.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)  # (N*N, 1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """x: (b, n, inp) tokens with n == ih * iw; returns (b, n, oup) tokens."""
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class HAttention(nn.Module):
    """Hybrid attention: queries and keys come from the input tokens, values from an MBConv feature map.

    Uses the same learned relative-position bias as ``Attention``.

    Args:
        inp: channels of the query/key input tokens.
        oup: channels of the value feature map and of the output.
        image_size: (ih, iw); the inputs must contain exactly ih*iw tokens / pixels.
        heads, dim_head: number of heads and per-head width of the query/key projection.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is skipped only for one head whose width equals inp.
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias: one row per possible (drow, dcol) offset
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # indexing='ij' is the historical default made explicit; omitting it makes PyTorch
        # emit a deprecation UserWarning.
        coords = torch.meshgrid((torch.arange(self.ih), torch.arange(self.iw)), indexing='ij')
        coords = torch.flatten(torch.stack(coords), 1)              # (2, N): row/col of each token
        relative_coords = coords[:, :, None] - coords[:, None, :]   # (2, N, N) pairwise offsets

        # Shift offsets to be non-negative, then fold (row, col) into one table index.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)  # (N*N, 1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        # Only queries and keys are projected; values are taken from the MBConv branch.
        self.to_qk = nn.Linear(inp, inner_dim * 2, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(oup, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, mbconv):
        """
        Args:
            x: (b, n, inp) normalized query/key tokens, n == ih * iw.
            mbconv: (b, c, ih, iw) value feature map. c is split across ``heads`` (so it must
                be divisible by ``heads``) and must equal ``oup`` when the projection is used.

        Returns:
            (b, n, c) attended tokens (after ``to_out``).
        """
        qk = self.to_qk(x).chunk(2, dim=-1)
        q, k = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qk)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        v = rearrange(mbconv, 'b c ih iw -> b (ih iw) c')
        v = rearrange(v, 'b n (h c) -> b h n c', h=self.heads)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class Transformer(nn.Module):
    """Pre-norm transformer block on (b, c, ih, iw) feature maps ('T' in CoHAtNet.block_types).

    The spatial grid is flattened into tokens for relative-position self-attention and a
    feed-forward MLP. Each is wrapped in LayerNorm and followed by a residual connection.
    With ``downsample`` the attention input is max-pooled (stride 2), and the shortcut is
    pooled and projected from ``inp`` to ``oup`` channels with a 1x1 conv.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)

        self.ih, self.iw = image_size
        self.downsample = downsample

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)  # shortcut branch
            self.pool2 = nn.MaxPool2d(3, 2, 1)  # attention branch
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        attention = Attention(inp, oup, image_size, heads, dim_head, dropout)
        mlp = FeedForward(oup, hidden_dim, dropout)

        self.attn = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(inp, attention, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw),
        )
        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, mlp, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw),
        )

    def forward(self, x):
        if not self.downsample:
            x = x + self.attn(x)
        else:
            x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
        return x + self.ff(x)

class HTransformer(nn.Module):
    """Hybrid attention block ('H' in CoHAtNet.block_types).

    Queries and keys come from the LayerNorm-ed input tokens; with ``downsample`` the input is
    max-pooled first. Values come from an MBConv branch applied to the same input. The
    attention output is added to the shortcut, which is the input itself or, with
    ``downsample``, its pooled 1x1 projection. A pre-norm feed-forward MLP with its own
    residual follows.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)

        self.ih, self.iw = image_size
        self.downsample = downsample
        self.layer_norm = nn.LayerNorm(inp)
        self.MBConv = MBConv(inp, oup, image_size, downsample)

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)  # shortcut branch
            self.pool2 = nn.MaxPool2d(3, 2, 1)  # query/key branch
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        self.attn = HAttention(inp, oup, image_size, heads, dim_head, dropout)
        mlp = FeedForward(oup, hidden_dim, dropout)
        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, mlp, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw),
        )

    def forward(self, x):
        values = self.MBConv(x)

        qk_source = self.pool2(x) if self.downsample else x
        tokens = self.layer_norm(rearrange(qk_source, 'b c ih iw -> b (ih iw) c'))
        attended = self.attn(tokens, values)
        attended = rearrange(attended, 'b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)

        shortcut = self.proj(self.pool1(x)) if self.downsample else x
        out = shortcut + attended
        return out + self.ff(out)


class CoHAtNet(nn.Module):
    """CoHAtNet camera-pose regressor.

    The network has a convolutional stem (s0) followed by four stages (s1-s4). Each stage's
    block type is chosen by ``block_types``: 'C' = MBConv, 'T' = Transformer,
    'H' = HTransformer. The stage resolutions are ih//2 ... ih//32, so the input size is
    expected to be divisible by 32. Average-pooled features feed an MLP head with 7 outputs,
    matching the 7-value (translation + quaternion) poses built by ``process_poses``.

    Args:
        image_size: (ih, iw) input resolution the model is built for.
        in_channels: number of input image channels.
        num_blocks: five block counts, for s0..s4.
        channels: five channel widths, for s0..s4.
        num_classes: unused; kept for backward compatibility.
        block_types: four block-type codes, for s1..s4.
    """

    def __init__(self, image_size, in_channels, num_blocks, channels, num_classes=1000,
                 block_types=('C', 'C', 'H', 'H')):
        super().__init__()
        ih, iw = image_size
        block = {'C': MBConv, 'T': Transformer, 'H': HTransformer}

        self.s0 = self._make_layer(
            conv_3x3_bn, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
        self.s1 = self._make_layer(
            block[block_types[0]], channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        self.s2 = self._make_layer(
            block[block_types[1]], channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
        self.s3 = self._make_layer(
            block[block_types[2]], channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
        self.s4 = self._make_layer(
            block[block_types[3]], channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))

        self.pool = nn.AvgPool2d(ih // 32, 1)
        hidden_dim = 1024
        # fc1 takes 2 * channels[-1] inputs because forward() concatenates the pooled vector
        # with itself. The duplicate adds no information (it equals one Linear on
        # channels[-1] inputs) but is kept so existing checkpoints keep loading.
        self.fc1 = nn.Linear(channels[-1] * 2, hidden_dim)
        self.relu = nn.ReLU()  # shared activation between the dense layers
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, 7)  # 7 = 3 translation + 4 quaternion values

    def forward(self, x):
        for stage in (self.s0, self.s1, self.s2, self.s3, self.s4):
            x = stage(x)

        # flatten(1) keeps the batch dimension; unlike view(-1, C) it cannot silently fold
        # leftover spatial positions into extra rows if the pooled map is not 1x1.
        x = torch.flatten(self.pool(x), 1)
        x = torch.cat((x, x), dim=1)

        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

    def _make_layer(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks; only the first maps inp -> oup and downsamples."""
        layers = [
            block(inp, oup, image_size, downsample=True) if i == 0 else block(oup, oup, image_size)
            for i in range(depth)
        ]
        return nn.Sequential(*layers)


def _build_cohatnet(num_blocks, channels):
    """Build a CoHAtNet for 256x256 3-channel input with per-stage depths (L) and widths (D)."""
    return CoHAtNet((256, 256), 3, num_blocks, channels, num_classes=1000)


def cohatnet_0():
    """Smallest variant: L = [2, 2, 3, 5, 2], D = [64, 96, 192, 384, 768]."""
    return _build_cohatnet([2, 2, 3, 5, 2], [64, 96, 192, 384, 768])


def cohatnet_1():
    """L = [2, 2, 6, 14, 2], D = [64, 96, 192, 384, 768]."""
    return _build_cohatnet([2, 2, 6, 14, 2], [64, 96, 192, 384, 768])


def cohatnet_2():
    """L = [2, 2, 6, 14, 2], D = [128, 128, 256, 512, 1024]."""
    return _build_cohatnet([2, 2, 6, 14, 2], [128, 128, 256, 512, 1024])


def cohatnet_3():
    """L = [2, 2, 6, 14, 2], D = [192, 192, 384, 768, 1536]."""
    return _build_cohatnet([2, 2, 6, 14, 2], [192, 192, 384, 768, 1536])


def cohatnet_4():
    """Largest variant: L = [2, 2, 12, 28, 2], D = [192, 192, 384, 768, 1536]."""
    return _build_cohatnet([2, 2, 12, 28, 2], [192, 192, 384, 768, 1536])


def count_parameters(model):
    """Return the total number of scalar parameters of ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


if __name__ == '__main__':
    # Smoke test: build every CoHAtNet variant, run one random 256x256 RGB image through it,
    # and print the output shape together with the trainable-parameter count.
    img = torch.randn(1, 3, 256, 256)

    for factory in (cohatnet_0, cohatnet_1, cohatnet_2, cohatnet_3, cohatnet_4):
        net = factory()
        # Only the output shape is inspected, so skip autograd. This avoids storing activations
        # for backward and lowers peak memory for the larger variants.
        with torch.no_grad():
            out = net(img)
        print(out.shape, count_parameters(net))


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([1, 7]) 34337023
torch.Size([1, 7]) 62756935
torch.Size([1, 7]) 108877575
torch.Size([1, 7]) 238862151
torch.Size([1, 7]) 411054007


In [None]:
def batch_to_device(batch, device):
    """
    If `device` is not the CPU, move all tensors in `batch` to `device`.

    Args:
        batch: dict whose values may be tensors, lists/tuples containing tensors,
            or anything else. Non-tensor values (ints, strings, None, empty
            sequences, non-tensor list elements) are left untouched.
        device: a device string or `torch.device`.

    Returns:
        The same dict object. Lists are updated in place; tuples are rebuilt
        (tuples cannot be assigned into).
    """
    if torch.device(device).type == 'cpu':
        return batch

    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
        elif isinstance(value, (list, tuple)) and any(isinstance(v, torch.Tensor) for v in value):
            moved = [v.to(device) if isinstance(v, torch.Tensor) else v for v in value]
            if isinstance(value, list):
                value[:] = moved  # keep the caller's list object
            elif hasattr(value, '_make'):  # namedtuple: preserve its type
                batch[key] = value._make(moved)
            else:
                batch[key] = tuple(moved)
    return batch


## Training

In [None]:
def _homography_pose_loss(model, loss_fn, color_images, poses):
    """Run `model` on a batch and score it with the homography loss.

    The first 3 model outputs are the translation; the rest are a quaternion,
    L2-normalised before conversion to a rotation matrix. Ground-truth `poses`
    are split the same way ([:, :3] translation, [:, 3:] quaternion).
    """
    outputs = model(color_images)
    pred_translation = outputs[:, :3]
    pred_quaternion = F.normalize(outputs[:, 3:], p=2, dim=-1)

    batch = {
        'w_t_c': pred_translation.unsqueeze(-1),
        'c_R_w': quaternion_to_rotation_matrix(pred_quaternion),
        'w_t_chat': poses[:, :3].unsqueeze(-1),
        'chat_R_w': quaternion_to_rotation_matrix(poses[:, 3:]),
    }
    return loss_fn(batch)


if __name__ == '__main__':
    root_dir = '/content/drive/MyDrive/fire'

    # Check for GPU availability and define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    dataset = FireDataset(root_dir=root_dir, transform=transform)

    # Instantiate the Global Homography Loss with computed xmin and xmax
    global_loss_fn = GlobalHomographyLoss(xmin=dataset.global_xmin, xmax=dataset.global_xmax, device=device)

    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    num_epochs = 200  # Define the number of epochs

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f'Fold {fold + 1}')

        train_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(train_idx),
                                  num_workers=0, pin_memory=False)
        val_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(val_idx),
                                num_workers=0, pin_memory=False)

        model = cohatnet_0().to(device)  # Use a smaller model
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, factor=0.1)

        for epoch in range(num_epochs):
            # --- training ---
            model.train()
            train_loss = 0.0
            for color_images, _depth_images, poses in tqdm(train_loader):
                color_images, poses = color_images.to(device), poses.to(device)

                optimizer.zero_grad()
                loss = _homography_pose_loss(model, global_loss_fn, color_images, poses)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
            avg_train_loss = train_loss / len(train_loader)

            # --- validation ---
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for color_images, _depth_images, poses in val_loader:
                    color_images, poses = color_images.to(device), poses.to(device)
                    val_loss += _homography_pose_loss(model, global_loss_fn, color_images, poses).item()
            avg_val_loss = val_loss / len(val_loader)

            # Step on the mean per-batch validation loss, i.e. the same value that is printed.
            scheduler.step(avg_val_loss)

            print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}')
            torch.cuda.empty_cache()


Using device: cuda


Computing global depths: 100%|██████████| 100/100 [01:12<00:00,  1.39it/s]


Fold 1


100%|██████████| 400/400 [05:48<00:00,  1.15it/s]


Epoch 1, Train Loss: 0.28966043850407003, Val Loss: 0.061404002234339713


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 2, Train Loss: 0.05517053067684174, Val Loss: 0.020966893546283245


100%|██████████| 400/400 [02:15<00:00,  2.94it/s]


Epoch 3, Train Loss: 0.01994513839657884, Val Loss: 0.013791554421186448


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 4, Train Loss: 0.013040020476328209, Val Loss: 0.01220401294529438


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 5, Train Loss: 0.009332132398267277, Val Loss: 0.004363562890794128


100%|██████████| 400/400 [02:17<00:00,  2.91it/s]


Epoch 6, Train Loss: 0.006698154580080881, Val Loss: 0.005998032758943736


100%|██████████| 400/400 [02:16<00:00,  2.92it/s]


Epoch 7, Train Loss: 0.006755618319439236, Val Loss: 0.0075409399345517155


100%|██████████| 400/400 [02:16<00:00,  2.92it/s]


Epoch 8, Train Loss: 0.005799728282145224, Val Loss: 0.005000838788691908


100%|██████████| 400/400 [02:16<00:00,  2.92it/s]


Epoch 9, Train Loss: 0.005651691219536587, Val Loss: 0.008120902669616044


100%|██████████| 400/400 [02:17<00:00,  2.91it/s]


Epoch 10, Train Loss: 0.008296324118273333, Val Loss: 0.0036229212873149664


100%|██████████| 400/400 [02:17<00:00,  2.92it/s]


Epoch 11, Train Loss: 0.004169012782222125, Val Loss: 0.002718017071019858


100%|██████████| 400/400 [02:17<00:00,  2.92it/s]


Epoch 12, Train Loss: 0.0041259924430050885, Val Loss: 0.002445554316509515


100%|██████████| 400/400 [02:16<00:00,  2.92it/s]


Epoch 13, Train Loss: 0.004985432543035131, Val Loss: 0.0031223570613656193


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 14, Train Loss: 0.004502105439023581, Val Loss: 0.003793821844737977


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 15, Train Loss: 0.0037666001472098287, Val Loss: 0.004481421572854742


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 16, Train Loss: 0.004961377410800196, Val Loss: 0.004588748482055962


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 17, Train Loss: 0.004256206968129846, Val Loss: 0.005437889975728467


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 18, Train Loss: 0.0013946942752227187, Val Loss: 0.000838925012940308


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 19, Train Loss: 0.0007650216982801794, Val Loss: 0.0007096901026670821


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 20, Train Loss: 0.0006464843022695277, Val Loss: 0.000672351041721413


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 21, Train Loss: 0.0005473318728036247, Val Loss: 0.0006513437852845527


100%|██████████| 400/400 [02:15<00:00,  2.95it/s]


Epoch 22, Train Loss: 0.00048745564205091794, Val Loss: 0.0005564966070232913


100%|██████████| 400/400 [02:16<00:00,  2.92it/s]


Epoch 23, Train Loss: 0.00043516101588465973, Val Loss: 0.0005161964565922972


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 24, Train Loss: 0.00041815752658294516, Val Loss: 0.0005486478992679622


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 25, Train Loss: 0.0003895926962286467, Val Loss: 0.0004977232425881084


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 26, Train Loss: 0.0003642234454309801, Val Loss: 0.0005567090134718456


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 27, Train Loss: 0.0003655134683685901, Val Loss: 0.0005258312652586028


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 28, Train Loss: 0.0003332148087247333, Val Loss: 0.0005678973942121957


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 29, Train Loss: 0.0003266304897442751, Val Loss: 0.0005237789470993448


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 30, Train Loss: 0.0003431609715698869, Val Loss: 0.0004391001258045435


100%|██████████| 400/400 [02:16<00:00,  2.93it/s]


Epoch 31, Train Loss: 0.0003081712245875678, Val Loss: 0.0004081197742198128


100%|██████████| 400/400 [02:15<00:00,  2.95it/s]


Epoch 32, Train Loss: 0.0002955283303526812, Val Loss: 0.0004084293937194161


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 33, Train Loss: 0.00027446967891592067, Val Loss: 0.00039189012903079857


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 34, Train Loss: 0.00027376657224522207, Val Loss: 0.00040203152762842366


100%|██████████| 400/400 [02:15<00:00,  2.94it/s]


Epoch 35, Train Loss: 0.00028891818028569104, Val Loss: 0.0005373179825255648


100%|██████████| 400/400 [02:15<00:00,  2.95it/s]


Epoch 36, Train Loss: 0.00027408153597207273, Val Loss: 0.00048351854740758427


100%|██████████| 400/400 [02:15<00:00,  2.94it/s]


Epoch 37, Train Loss: 0.0003198873953442671, Val Loss: 0.0005239129197434523


100%|██████████| 400/400 [02:15<00:00,  2.95it/s]


Epoch 38, Train Loss: 0.00025672135960121524, Val Loss: 0.0005824770154140424


100%|██████████| 400/400 [02:16<00:00,  2.94it/s]


Epoch 39, Train Loss: 0.0001433073320367839, Val Loss: 0.0003066783979011234


100%|██████████| 400/400 [02:15<00:00,  2.94it/s]


Epoch 40, Train Loss: 0.00012046929994539823, Val Loss: 0.00030559879531210754


100%|██████████| 400/400 [02:15<00:00,  2.95it/s]


Epoch 41, Train Loss: 0.00011481007823931578, Val Loss: 0.00029702277686737943


100%|██████████| 400/400 [02:15<00:00,  2.94it/s]


Epoch 42, Train Loss: 0.00011115898300886328, Val Loss: 0.0002931000849639531


100%|██████████| 400/400 [02:15<00:00,  2.94it/s]


Epoch 43, Train Loss: 0.00010498817058760324, Val Loss: 0.000293645779020153


100%|██████████| 400/400 [02:15<00:00,  2.95it/s]


Epoch 44, Train Loss: 0.00010310187893537659, Val Loss: 0.00030562567968445365


100%|██████████| 400/400 [02:15<00:00,  2.95it/s]


Epoch 45, Train Loss: 0.00010193783834438363, Val Loss: 0.00029751417081570254


100%|██████████| 400/400 [02:15<00:00,  2.95it/s]


Epoch 46, Train Loss: 0.00010041446727882431, Val Loss: 0.00030898382028681225


100%|██████████| 400/400 [02:13<00:00,  2.99it/s]


Epoch 47, Train Loss: 9.751858067374997e-05, Val Loss: 0.0003038774505694164


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 48, Train Loss: 8.869182417583943e-05, Val Loss: 0.0002921200510900235


100%|██████████| 400/400 [02:12<00:00,  3.01it/s]


Epoch 49, Train Loss: 8.848019878769264e-05, Val Loss: 0.00028662819582677914


100%|██████████| 400/400 [02:12<00:00,  3.01it/s]


Epoch 50, Train Loss: 8.484399143526388e-05, Val Loss: 0.0002834280090610264


100%|██████████| 400/400 [02:13<00:00,  3.01it/s]


Epoch 51, Train Loss: 8.868386806625495e-05, Val Loss: 0.00029049731136183255


100%|██████████| 400/400 [02:12<00:00,  3.01it/s]


Epoch 52, Train Loss: 8.694727259353385e-05, Val Loss: 0.00028090614658140113


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 53, Train Loss: 8.539836013369494e-05, Val Loss: 0.00029709742411796467


100%|██████████| 400/400 [02:12<00:00,  3.01it/s]


Epoch 54, Train Loss: 8.669940633808437e-05, Val Loss: 0.00028075890506443105


100%|██████████| 400/400 [02:13<00:00,  3.01it/s]


Epoch 55, Train Loss: 8.366520105937525e-05, Val Loss: 0.00028438938556064387


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 56, Train Loss: 8.830341005705122e-05, Val Loss: 0.00028489213655120694


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 57, Train Loss: 8.785802533111563e-05, Val Loss: 0.00029471170477336273


100%|██████████| 400/400 [02:14<00:00,  2.97it/s]


Epoch 58, Train Loss: 8.54581623843842e-05, Val Loss: 0.00028429915648302994


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 59, Train Loss: 8.439149025434744e-05, Val Loss: 0.000279785948805511


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 60, Train Loss: 8.898781783500453e-05, Val Loss: 0.00027728431261493826


100%|██████████| 400/400 [02:13<00:00,  2.99it/s]


Epoch 61, Train Loss: 8.610055870121869e-05, Val Loss: 0.0002876526644831756


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 62, Train Loss: 8.220547785185772e-05, Val Loss: 0.0002787130964134121


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 63, Train Loss: 8.546814889996312e-05, Val Loss: 0.00029720003964030185


100%|██████████| 400/400 [02:13<00:00,  2.99it/s]


Epoch 64, Train Loss: 8.162748848917544e-05, Val Loss: 0.00029051899931801015


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 65, Train Loss: 8.335763810919161e-05, Val Loss: 0.00027929023788601624


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 66, Train Loss: 8.756708922192047e-05, Val Loss: 0.0002805597509723157


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 67, Train Loss: 8.425872676525614e-05, Val Loss: 0.00028162707094452343


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 68, Train Loss: 8.181689377124712e-05, Val Loss: 0.0002834495989372954


100%|██████████| 400/400 [02:14<00:00,  2.98it/s]


Epoch 69, Train Loss: 8.751683202717686e-05, Val Loss: 0.00028554001277370845


100%|██████████| 400/400 [02:13<00:00,  2.99it/s]


Epoch 70, Train Loss: 8.218158139243314e-05, Val Loss: 0.0002821707821203745


100%|██████████| 400/400 [02:13<00:00,  3.00it/s]


Epoch 71, Train Loss: 8.262765845756803e-05, Val Loss: 0.0002925941268040333


100%|██████████| 400/400 [02:12<00:00,  3.02it/s]


Epoch 72, Train Loss: 8.352555722467514e-05, Val Loss: 0.00028565310029080135


100%|██████████| 400/400 [02:12<00:00,  3.02it/s]


Epoch 73, Train Loss: 8.58175928624405e-05, Val Loss: 0.0002821707205293933


100%|██████████| 400/400 [02:12<00:00,  3.02it/s]


Epoch 74, Train Loss: 8.486149583404767e-05, Val Loss: 0.0002891923008428421


100%|██████████| 400/400 [02:12<00:00,  3.01it/s]


Epoch 75, Train Loss: 8.211733795178589e-05, Val Loss: 0.00028501071290520485


100%|██████████| 400/400 [02:12<00:00,  3.01it/s]


Epoch 76, Train Loss: 8.199272831461712e-05, Val Loss: 0.0002835678413975984


100%|██████████| 400/400 [02:12<00:00,  3.01it/s]


Epoch 77, Train Loss: 8.167128793502342e-05, Val Loss: 0.00028060624455974906


100%|██████████| 400/400 [02:12<00:00,  3.02it/s]


Epoch 78, Train Loss: 8.135043708534795e-05, Val Loss: 0.00028023886210576165


 28%|██▊       | 110/400 [00:36<01:36,  3.00it/s]


KeyboardInterrupt: 

# CoHAtNet with Cambridge Landmarks dataset

In [None]:
# Mount Google Drive so the dataset folders under /content/drive/MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install transforms3d

Collecting transforms3d
  Downloading transforms3d-0.4.2-py3-none-any.whl.metadata (2.8 kB)
Downloading transforms3d-0.4.2-py3-none-any.whl (1.4 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m61.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transforms3d
Successfully installed transforms3d-0.4.2


In [None]:
pip install einops



In [None]:
pip install kornia

Collecting kornia
  Downloading kornia-0.7.3-py2.py3-none-any.whl.metadata (7.7 kB)
Collecting kornia-rs>=0.1.0 (from kornia)
  Downloading kornia_rs-0.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Downloading kornia-0.7.3-py2.py3-none-any.whl (833 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m833.3/833.3 kB[0m [31m52.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading kornia_rs-0.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m82.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: kornia-rs, kornia
Successfully installed kornia-0.7.3 kornia-rs-0.1.5


In [None]:
import os
import time
import random
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation, ToTensor
from torchvision.transforms.functional import to_tensor
from torchvision.utils import make_grid
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn.model_selection import KFold
from tqdm import tqdm
import transforms3d.quaternions as txq
import matplotlib.pyplot as plt
from kornia.geometry.conversions import quaternion_to_rotation_matrix
import pickle

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


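## Compute Xmin & Xmax

Estimate the scene-depth range used by the global homography loss. Every 3D point of the SfM reconstruction (`reconstruction.nvm`) is projected into each ground-truth camera, and the 2.5th and 97.5th depth percentiles are taken per image. The minimum and maximum of these over all images give the global `xmin` and `xmax`.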

In [None]:
import torch
import os
import numpy as np
from torch.utils.data import Dataset
import cv2
from tqdm import tqdm
import kornia  # For quaternion to rotation matrix conversion

# Dataset class adjusted for your dataset structure
class VisualLandmarkDataset(Dataset):
    """Cambridge Landmarks images paired with their ground-truth poses.

    ``root_dir/GT.txt`` is read after skipping its first two (header) lines; each
    remaining line with at least 8 fields is ``<relative image path> X Y Z W P Q R``.
    Every ``.png`` one directory below ``root_dir`` that has a GT entry becomes a sample.
    Items are ``(image, pose)`` where ``pose`` is a float32 tensor of the GT values.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.gt_data = self._load_ground_truth()

        # Load samples
        self.samples = self._load_samples()

        print(f"Number of samples: {len(self.samples)}")

    def _load_ground_truth(self):
        """Return {normalised relative image path: pose ndarray} from GT.txt."""
        gt_path = os.path.join(self.root_dir, 'GT.txt')
        gt_data = {}
        with open(gt_path, 'r') as f:
            lines = f.readlines()[2:]  # Skip the first two lines (header)
            for line in lines:
                parts = line.strip().split()
                if len(parts) < 8:
                    continue  # Skip malformed lines
                image_path = parts[0]  # Image path relative to the dataset root
                pose = np.array([float(x) for x in parts[1:]])
                gt_data[os.path.normpath(image_path)] = pose
        return gt_data

    def _load_samples(self):
        """Collect (absolute image path, pose) for every PNG that has a GT entry.

        Directories and files are visited in sorted order: os.listdir order is
        filesystem-dependent, and a stable sample order keeps index-based splits
        (e.g. a seeded KFold) reproducible.
        """
        samples = []
        for seq_folder in sorted(os.listdir(self.root_dir)):
            seq_path = os.path.join(self.root_dir, seq_folder)
            if not os.path.isdir(seq_path):
                continue
            image_files = sorted(f for f in os.listdir(seq_path) if f.endswith('.png'))
            for image_file in image_files:
                image_path = os.path.join(seq_folder, image_file)
                norm_image_path = os.path.normpath(image_path)
                if norm_image_path in self.gt_data:
                    full_image_path = os.path.join(self.root_dir, image_path)
                    samples.append((full_image_path, self.gt_data[norm_image_path]))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        color_path, pose = self.samples[idx]
        # cv2.imread silently returns None on failure; fail loudly instead of
        # letting the transform crash with an unrelated error.
        if not os.path.isfile(color_path):
            raise FileNotFoundError(f"Image not found: {color_path}")
        color_image = cv2.imread(color_path, cv2.IMREAD_COLOR)
        if color_image is None:
            raise OSError(f"cv2 could not read image: {color_path}")
        if self.transform:
            color_image = self.transform(color_image)
        # pose is [X, Y, Z, W, P, Q, R] as stored in GT.txt
        return color_image, torch.tensor(pose, dtype=torch.float32)

# Function to parse reconstruction.nvm and extract 3D world points (w_P)
def parse_nvm_file(nvm_file_path):
    """Parse the 3D world points (w_P) of the first model in a VisualSFM ``.nvm`` file.

    Layout relied on, with blank lines ignored: a header line, the camera count,
    one line per camera, the point count, then one line per point whose first
    three values are its X Y Z coordinates.

    Returns:
        float32 tensor of shape (n_valid_points, 3).

    Raises:
        ValueError: if no valid 3D point is found.
    """
    with open(nvm_file_path, 'r') as file:
        # Keep (original line number, text) for non-empty lines only. Indexing the
        # raw lines with fixed offsets was off by one: it skipped the first point
        # and read the blank separator after the last one.
        entries = [(lineno, line.strip()) for lineno, line in enumerate(file) if line.strip()]

    n_views = int(entries[1][1].split()[0])  # Number of images
    points_header = 2 + n_views              # entry holding the number of 3D points
    n_points = int(entries[points_header][1].split()[0])

    first = points_header + 1
    last = min(first + n_points, len(entries))  # tolerate a truncated file

    scene_coordinates = []
    for lineno, text in entries[first:last]:
        point_data = text.split()[:3]  # Extract only the first 3 values (XYZ coordinates)
        if len(point_data) != 3:
            print(f"Skipping invalid point on line {lineno}: {text}")
            continue
        try:
            scene_coordinates.append([float(coord) for coord in point_data])
        except ValueError:
            print(f"Skipping malformed point data on line {lineno}: {text}")

    if not scene_coordinates:
        raise ValueError("No valid 3D points found in the reconstruction.nvm file.")

    # Build one contiguous array first: torch.tensor on a list of ndarrays is very slow.
    return torch.from_numpy(np.asarray(scene_coordinates, dtype=np.float32))

# Function to calculate xmin and xmax for each image in the dataset
def compute_xmin_xmax(w_P, c_R_w, w_t_c, xmin_percentile=0.025, xmax_percentile=0.975,
                      min_depth=0.2, max_depth=1000):
    """Depth percentiles of the world points as seen from one camera.

    Args:
        w_P: (N, 3) world points.
        c_R_w: (3, 3) world-to-camera rotation.
        w_t_c: (3, 1) camera position in world coordinates.
        xmin_percentile, xmax_percentile: fractions in [0, 1]; each selects the
            sorted depth at index floor(p * (n - 1)).
        min_depth, max_depth: only depths strictly inside this open interval are kept
            (defaults reproduce the previously hard-coded 0.2 / 1000 window).

    Returns:
        (xmin, xmax) as 0-d tensors.

    Raises:
        ValueError: if no point has a depth inside (min_depth, max_depth).
    """
    # Express the world points in the camera frame: c_P = c_R_w (w_P - w_t_c), shape (3, N).
    c_P = c_R_w @ (w_P.T - w_t_c)

    # Depth is the camera-frame Z coordinate.
    depths = c_P[2, :]
    valid_depths = depths[(depths > min_depth) & (depths < max_depth)]
    if valid_depths.numel() == 0:
        raise ValueError(f"No point has a depth in ({min_depth}, {max_depth}) for this camera.")

    sorted_depths = torch.sort(valid_depths).values
    last = sorted_depths.shape[0] - 1
    xmin = sorted_depths[int(xmin_percentile * last)]
    xmax = sorted_depths[int(xmax_percentile * last)]
    return xmin, xmax

# Convert quaternion to rotation matrix using Kornia
def quaternion_to_rotation_matrix(quaternion):
    """
    Converts quaternion(s) to rotation matrices using Kornia.

    Args:
        quaternion: tensor of shape (4,) or (B, 4). It is normalised before the
            conversion, so unnormalised network outputs are accepted. Components
            are passed to Kornia unchanged (the GT.txt poses store them as W P Q R).

    Returns:
        A (3, 3) matrix for a single (4,) quaternion, otherwise a (B, 3, 3) batch,
        including when B == 1, so batched callers always keep the batch dimension.
    """
    # Only a single un-batched quaternion is squeezed on the way out; squeezing
    # every B == 1 result broke batched callers that use transpose(1, 2).
    single = quaternion.ndim == 1
    if single:
        quaternion = quaternion.unsqueeze(0)  # Kornia expects a leading batch dimension

    quaternion = kornia.geometry.quaternion.normalize_quaternion(quaternion)
    rotation_matrix = kornia.geometry.conversions.quaternion_to_rotation_matrix(quaternion)

    return rotation_matrix[0] if single else rotation_matrix

# Main block for computing xmin and xmax for the dataset
if __name__ == "__main__":
    dataset_path = "/content/drive/MyDrive/KingsCollege"  # Update this with your dataset path
    nvm_file_path = os.path.join(dataset_path, 'reconstruction.nvm')

    # Load the dataset and parse the 3D world points (w_P)
    dataset = VisualLandmarkDataset(root_dir=dataset_path)
    w_P = parse_nvm_file(nvm_file_path)

    global_xmin = []
    global_xmax = []
    skipped = 0

    # Only the GT poses are needed, so read them from `dataset.samples` directly;
    # iterating the Dataset itself would decode every image from Drive for nothing.
    for _, pose in tqdm(dataset.samples, desc="Processing Images"):
        pose = torch.as_tensor(pose, dtype=torch.float32)
        translation = pose[:3].view(3, 1)
        c_R_w = quaternion_to_rotation_matrix(pose[3:])  # Kornia-based conversion

        try:
            xmin, xmax = compute_xmin_xmax(w_P, c_R_w, translation)
        except (ValueError, IndexError):
            # No reconstruction point in the valid depth range for this camera.
            skipped += 1
            continue

        global_xmin.append(xmin)
        global_xmax.append(xmax)

    if skipped:
        print(f"Skipped {skipped} images with no valid depths")
    if not global_xmin:
        raise RuntimeError("No image had reconstruction points within the valid depth range.")

    # Compute global xmin and xmax for the dataset
    global_xmin = torch.stack(global_xmin).min()
    global_xmax = torch.stack(global_xmax).max()

    print(f"Global xmin: {global_xmin.item()}")
    print(f"Global xmax: {global_xmax.item()}")


  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


Number of samples: 1563


  return torch.tensor(scene_coordinates, dtype=torch.float32)


Skipping invalid point on line 173480: 


Processing Images: 100%|██████████| 1563/1563 [31:15<00:00,  1.20s/it]

Global xmin: 2.0560340881347656
Global xmax: 130.1615447998047





## Loss Function

In [None]:
class GlobalHomographyLoss(torch.nn.Module):
    """Global homography loss for camera pose regression.

    Args:
        xmin, xmax: scene depth bounds (floats or 0-d tensors), 0 < xmin < xmax.
        device: device on which the constant weights are created.

    The constants are registered as non-persistent buffers, so they follow
    ``.to(device)`` without adding entries to ``state_dict()``.

    Raises:
        ValueError: if the depth bounds do not satisfy 0 < xmin < xmax (the weights
            would otherwise be NaN/inf).
    """

    def __init__(self, xmin, xmax, device='cpu'):
        super().__init__()
        # as_tensor avoids the copy-construct warning when xmin/xmax already are tensors.
        xmin = torch.as_tensor(xmin, dtype=torch.float32, device=device)
        xmax = torch.as_tensor(xmax, dtype=torch.float32, device=device)
        if not bool((xmin > 0) & (xmin < xmax)):
            raise ValueError(f"Expected 0 < xmin < xmax, got xmin={xmin.item()}, xmax={xmax.item()}")

        self.register_buffer('B_weight', torch.log(xmin / xmax) / (xmax - xmin), persistent=False)
        self.register_buffer('C_weight', xmin * xmax, persistent=False)
        self.register_buffer(
            'c_n', torch.tensor([0, 0, -1], dtype=torch.float32, device=device).view(3, 1), persistent=False)
        self.register_buffer('eye', torch.eye(3, device=device), persistent=False)

    def forward(self, batch):
        """Mean over the batch of trace(A + B * B_weight + C / C_weight).

        `batch` holds 'w_t_c', 'c_R_w' (prediction) and 'w_t_chat', 'chat_R_w'
        (ground truth); in the training loop these are (B, 3, 1) translations and
        (B, 3, 3) rotation matrices.
        """
        A, B, C = compute_ABC(batch['w_t_c'], batch['c_R_w'], batch['w_t_chat'], batch['chat_R_w'], self.c_n, self.eye)
        error = A + B * self.B_weight + C / self.C_weight
        error = error.diagonal(dim1=1, dim2=2).sum(dim=1).mean()
        return error

# Compute A, B, and C matrices for homography-based error
def compute_ABC(w_t_c, c_R_w, w_t_chat, chat_R_w, c_n, eye):
    """Build the A, B, C matrices of the homography-based pose error.

    With the relative pose of the predicted camera (c) w.r.t. the ground-truth
    camera (chat), R = chat_R_w c_R_w^T and t = chat_R_w (w_t_c - w_t_chat):

        A = (I - R)(I - R)^T
        B = M + M^T,  where M = (n t^T)(I - R)
        C = (n t^T)(n t^T)^T

    Rotations are (B, 3, 3) and translations (B, 3, 1); n is the (3, 1) normal
    `c_n`. All inputs are cast to float32 first.
    """
    w_t_c, c_R_w, w_t_chat, chat_R_w, c_n, eye = (
        tensor.float() for tensor in (w_t_c, c_R_w, w_t_chat, chat_R_w, c_n, eye)
    )

    rel_trans = chat_R_w @ (w_t_c - w_t_chat)
    rel_rot = chat_R_w @ c_R_w.transpose(1, 2)

    rot_residual = eye - rel_rot
    normal_term = c_n @ rel_trans.transpose(1, 2)
    cross_term = normal_term @ rot_residual

    A = rot_residual @ rot_residual.transpose(1, 2)
    B = cross_term + cross_term.transpose(1, 2)
    C = normal_term @ normal_term.transpose(1, 2)
    return A, B, C

## Data Loader

In [None]:
class VisualLandmarkDataset(Dataset):
    """Cambridge Landmarks images paired with their ground-truth poses.

    ``root_dir/GT.txt`` is read after skipping its first two (header) lines; each
    remaining line with at least 8 fields is ``<relative image path> X Y Z W P Q R``.
    Every ``.png`` one directory below ``root_dir`` that has a GT entry becomes a sample.
    Items are ``(image, pose)`` where ``pose`` is a float32 tensor [X, Y, Z, W, P, Q, R].
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.gt_data = self._load_ground_truth()
        self.samples = self._load_samples()

    def _load_ground_truth(self):
        """Return {normalised relative image path: pose ndarray} from GT.txt."""
        gt_path = os.path.join(self.root_dir, 'GT.txt')
        gt_data = {}
        with open(gt_path, 'r') as f:
            lines = f.readlines()[2:]  # Skip the first two lines (header)
            for line in lines:
                parts = line.strip().split()
                if len(parts) < 8:
                    continue
                image_path = parts[0]
                pose = np.array([float(x) for x in parts[1:]])  # Includes [X Y Z W P Q R]
                gt_data[os.path.normpath(image_path)] = pose
        return gt_data

    def _load_samples(self):
        """Collect (absolute image path, pose) for every PNG that has a GT entry.

        Directories and files are visited in sorted order: os.listdir order is
        filesystem-dependent, and a stable sample order keeps index-based splits
        (e.g. a seeded KFold) reproducible.
        """
        samples = []
        for seq_folder in sorted(os.listdir(self.root_dir)):
            seq_path = os.path.join(self.root_dir, seq_folder)
            if not os.path.isdir(seq_path):
                continue
            image_files = sorted(f for f in os.listdir(seq_path) if f.endswith('.png'))
            for image_file in image_files:
                image_path = os.path.join(seq_folder, image_file)
                norm_image_path = os.path.normpath(image_path)
                if norm_image_path in self.gt_data:
                    full_image_path = os.path.join(self.root_dir, image_path)
                    samples.append((full_image_path, self.gt_data[norm_image_path]))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if idx >= len(self.samples):
            raise IndexError(f"Index {idx} out of bounds for samples of size {len(self.samples)}")
        color_path, pose = self.samples[idx]
        # cv2.imread silently returns None on failure; fail loudly instead of
        # letting the transform crash with an unrelated error.
        if not os.path.isfile(color_path):
            raise FileNotFoundError(f"Image not found: {color_path}")
        color_image = cv2.imread(color_path, cv2.IMREAD_COLOR)
        if color_image is None:
            raise OSError(f"cv2 could not read image: {color_path}")
        if self.transform:
            color_image = self.transform(color_image)
        # pose: translation [X, Y, Z] followed by quaternion [W, P, Q, R]
        return color_image, torch.tensor(pose, dtype=torch.float32)

# Define the transformation for the images
# Image preprocessing: HWC array (e.g. from cv2.imread) -> PIL image -> 256x256 -> tensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)


## CoHAtNet Model

In [None]:
from einops import rearrange
from einops.layers.torch import Rearrange


def conv_3x3_bn(inp, oup, image_size, downsample=False):
    """3x3 Conv (no bias) -> BatchNorm2d -> GELU.

    Uses stride 2 when `downsample` is set, stride 1 otherwise; padding 1.
    `image_size` is accepted for signature parity with the other blocks and is unused.
    """
    stride = 2 if downsample != False else 1  # noqa: E712 - mirrors the original equality test
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.GELU(),
    ]
    return nn.Sequential(*layers)


class PreNorm(nn.Module):
    """Normalise the input with `norm(dim)`, then pass it to `fn`.

    Keyword arguments given to `forward` are forwarded to `fn` unchanged.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class SE(nn.Module):
    """Squeeze-and-excitation gate over the channels of a (B, C, H, W) input.

    Global-average-pools each channel, maps `oup` -> int(inp * expansion) -> `oup`
    with bias-free linears (GELU in between, sigmoid at the end), and rescales the
    input channels by the resulting gate. `oup` must equal the input's channel count.
    """

    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        squeezed = int(inp * expansion)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, squeezed, bias=False),
            nn.GELU(),
            nn.Linear(squeezed, oup, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate


class FeedForward(nn.Module):
    """Token-wise MLP: Linear(dim -> hidden_dim) -> GELU -> Dropout -> Linear(hidden_dim -> dim) -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class MBConv(nn.Module):
    """Inverted-residual (MBConv) block whose main branch sees BatchNorm2d(inp)-normalised input.

    Main branch:
      * expansion == 1: 3x3 depthwise conv (carries the stride) -> BN -> GELU
        -> 1x1 projection to `oup` -> BN.
      * otherwise: 1x1 expansion to hidden_dim = int(inp * expansion) (carries the
        stride) -> BN -> GELU -> 3x3 depthwise -> BN -> GELU -> SE -> 1x1 projection
        to `oup` -> BN.

    Shortcut:
      * downsample: MaxPool2d(3, 2, 1) then a 1x1 conv inp -> oup, added to the
        stride-2 main branch.
      * no downsample: identity, so `x + self.conv(x)` requires inp == oup.

    Args:
        inp, oup: input / output channels.
        image_size: not used by this block.
        downsample: halve the spatial size and use the projection shortcut.
        expansion: hidden-width multiplier for the expanded branch.
    """
    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        # NOTE: module creation order below fixes both parameter-init RNG
        # consumption and state_dict key order; keep it when editing.
        self.downsample = downsample
        stride = 1 if self.downsample == False else 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            # Projection shortcut: pool to the new resolution, then match channels.
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                # down-sample in the first conv
                nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # channel gate on the expanded features (squeeze width int(inp * 0.25))
                SE(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

        # Pre-normalise the branch input with BatchNorm2d over the `inp` channels.
        self.conv = PreNorm(inp, self.conv, nn.BatchNorm2d)

    def forward(self, x):
        """Residual sum of the shortcut and the pre-normalised main branch."""
        if self.downsample:
            return self.proj(self.pool(x)) + self.conv(x)
        else:
            return x + self.conv(x)


class Attention(nn.Module):
    """Multi-head self-attention over a flattened ``ih x iw`` token grid with a
    learned relative-position bias.

    Args:
        inp: feature size of each input token.
        oup: feature size of each output token (used by the output projection).
        image_size: ``(ih, iw)`` grid the tokens were flattened from. The bias is
            built for exactly ``ih * iw`` tokens, so ``forward`` requires
            ``n == ih * iw``.
        heads: number of attention heads.
        dim_head: width of each head; the q/k/v inner width is ``heads * dim_head``.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # A single head whose width already equals the input width skips the output projection.
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # One learnable bias per head for every possible (dy, dx) offset between two grid cells.
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # Explicit 'ij' (row, col) indexing reproduces the historical default and
        # silences PyTorch's meshgrid UserWarning.
        coords = torch.meshgrid(torch.arange(self.ih), torch.arange(self.iw), indexing='ij')
        coords = torch.flatten(torch.stack(coords), 1)                     # (2, ih*iw)
        relative_coords = coords[:, :, None] - coords[:, None, :]          # (2, n, n)

        # Shift offsets to be non-negative, then scale the row offset by the number
        # of possible column offsets so each (dy, dx) maps to a unique flat index.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)   # (n*n, 1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Attend over tokens ``x`` of shape ``(b, n, inp)`` with ``n == ih * iw``.

        Returns:
            ``(b, n, oup)`` when the output projection is used, otherwise
            ``(b, n, heads * dim_head)``.
        """
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs: per-head bias for every (query, key) pair.
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class HAttention(nn.Module):
    """Hybrid attention: queries and keys are projected from the input tokens, while
    the values are taken directly from a convolutional feature map (``mbconv``).

    Args:
        inp: feature size of each input token (used for the q/k projection).
        oup: channel count of the ``mbconv`` value map and of the output. It is
            split evenly across heads, so it must be divisible by ``heads``.
        image_size: ``(ih, iw)`` grid of both the tokens and the value map; the
            relative-position bias is built for exactly ``ih * iw`` positions.
        heads: number of attention heads.
        dim_head: per-head width of the queries/keys.
        dropout: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``oup`` is not divisible by ``heads``.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        if oup % heads != 0:
            # The value map's channels are reshaped as (heads, oup // heads) in forward().
            raise ValueError(f"oup ({oup}) must be divisible by heads ({heads})")
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # One learnable bias per head for every possible (dy, dx) offset between two grid cells.
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # Explicit 'ij' (row, col) indexing reproduces the historical default and
        # silences PyTorch's meshgrid UserWarning.
        coords = torch.meshgrid(torch.arange(self.ih), torch.arange(self.iw), indexing='ij')
        coords = torch.flatten(torch.stack(coords), 1)                     # (2, ih*iw)
        relative_coords = coords[:, :, None] - coords[:, None, :]          # (2, n, n)

        # Shift offsets to be non-negative, then scale the row offset by the number
        # of possible column offsets so each (dy, dx) maps to a unique flat index.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)   # (n*n, 1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        # Only queries and keys are learned projections; values come from mbconv.
        self.to_qk = nn.Linear(inp, inner_dim * 2, bias=False)

        # The attended values keep the value map's width (oup), so project oup -> oup.
        self.to_out = nn.Sequential(
            nn.Linear(oup, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, mbconv):
        """Attend over ``x`` using ``mbconv`` pixels as values.

        Args:
            x: tokens of shape ``(b, n, inp)`` with ``n == ih * iw``; source of q and k.
            mbconv: feature map of shape ``(b, c, ih, iw)``; each pixel is a value
                vector, split across heads along ``c``.

        Returns:
            ``(b, n, c)`` tokens after ``to_out``.
        """
        qk = self.to_qk(x).chunk(2, dim=-1)
        q, k = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qk)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs: per-head bias for every (query, key) pair.
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        v = rearrange(mbconv, 'b c ih iw -> b (ih iw) c')
        v = rearrange(v, 'b n (h c) -> b h n c', h=self.heads)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class Transformer(nn.Module):
    """Pre-norm transformer block operating on ``(b, c, ih, iw)`` feature maps.

    The map is flattened to tokens around each sub-layer (attention, then MLP), and
    each sub-layer is added residually. With ``downsample`` the input is max-pooled
    by 2 before attention, and the shortcut is pooled and 1x1-projected to ``oup``.

    Args:
        inp, oup: input / output channel counts.
        image_size: ``(ih, iw)`` resolution the block operates at (after any pooling).
        heads, dim_head: attention head count and per-head width.
        downsample: whether the block halves the spatial resolution.
        dropout: dropout used by the attention and MLP sub-layers.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        self.ih, self.iw = image_size
        self.downsample = downsample

        if downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        # Build the sub-layers first (attention before MLP), then wrap each in
        # flatten -> PreNorm -> unflatten.
        attention = Attention(inp, oup, image_size, heads, dim_head, dropout)
        mlp = FeedForward(oup, int(inp * 4), dropout)

        self.attn = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(inp, attention, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw),
        )
        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, mlp, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw),
        )

    def forward(self, x):
        """Apply residual attention, then a residual MLP, to ``x``."""
        if not self.downsample:
            x = x + self.attn(x)
        else:
            shortcut = self.proj(self.pool1(x))
            x = shortcut + self.attn(self.pool2(x))
        return x + self.ff(x)

class HTransformer(nn.Module):
    """Hybrid transformer block.

    Queries and keys are computed from the layer-normalized input tokens (max-pooled
    by 2 when downsampling). The values are the output of a parallel MBConv branch.
    The attended result is added to a shortcut (identity, or pooled + 1x1-projected
    when downsampling), followed by a residual pre-norm MLP.

    Args:
        inp, oup: input / output channel counts.
        image_size: ``(ih, iw)`` resolution the block outputs (after any downsampling).
        heads, dim_head: attention head count and per-head query/key width.
        downsample: whether the block halves the spatial resolution.
        dropout: dropout used by the attention and MLP sub-layers.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        self.ih, self.iw = image_size
        self.downsample = downsample

        # Creation order below fixes the parameter-initialization order under a seeded RNG.
        self.layer_norm = nn.LayerNorm(inp)
        self.MBConv = MBConv(inp, oup, image_size, downsample)

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        self.attn = HAttention(inp, oup, image_size, heads, dim_head, dropout)
        mlp = FeedForward(oup, int(inp * 4), dropout)

        self.ff = nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(oup, mlp, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw),
        )

    def forward(self, x):
        """Return ``out + ff(out)`` where ``out = shortcut(x) + attention(x, MBConv(x))``."""
        values = self.MBConv(x)

        # Queries/keys are drawn from x at the block's output resolution.
        query_map = self.pool2(x) if self.downsample else x
        tokens = self.layer_norm(rearrange(query_map, 'b c ih iw -> b (ih iw) c'))
        attended = rearrange(self.attn(tokens, values),
                             'b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)

        shortcut = self.proj(self.pool1(x)) if self.downsample else x
        out = shortcut + attended
        return out + self.ff(out)


class CoHAtNet(nn.Module):
    """Five-stage convolution/attention backbone with an MLP head that regresses a pose vector.

    Stage s0 is a ``conv_3x3_bn`` stem. Stages s1-s4 use the block types named in
    ``block_types`` ('C' = MBConv, 'T' = Transformer, 'H' = HTransformer). Stage k
    is configured for a ``(ih // 2**(k+1), iw // 2**(k+1))`` feature map, so the
    final map is ``(ih // 32, iw // 32)``. ``ih`` and ``iw`` are assumed divisible
    by 32 so that the configured sizes match the real feature maps.

    Args:
        image_size: ``(ih, iw)`` input resolution; need not be square.
        in_channels: number of input image channels.
        num_blocks: five block counts, one per stage s0..s4.
        channels: five output widths, one per stage s0..s4.
        num_classes: unused; kept so existing call sites keep working.
        block_types: four block-type letters for stages s1..s4.
        pose_dim: size of the regression output (default 7, e.g. a 3-D
            translation followed by a 4-D quaternion).
    """

    def __init__(self, image_size, in_channels, num_blocks, channels, num_classes=1000,
                 block_types=('C', 'C', 'H', 'H'), pose_dim=7):
        super().__init__()
        ih, iw = image_size
        block = {'C': MBConv, 'T': Transformer, 'H': HTransformer}

        self.s0 = self._make_layer(
            conv_3x3_bn, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
        self.s1 = self._make_layer(
            block[block_types[0]], channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        self.s2 = self._make_layer(
            block[block_types[1]], channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
        self.s3 = self._make_layer(
            block[block_types[2]], channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
        self.s4 = self._make_layer(
            block[block_types[3]], channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))

        # Pool over the whole final map along each axis separately. A square
        # ``ih // 32`` kernel would be wrong for non-square inputs.
        self.pool = nn.AvgPool2d((ih // 32, iw // 32), 1)
        hidden_dim = 1024
        self.fc1 = nn.Linear(channels[-1] * 2, hidden_dim)
        self.relu = nn.ReLU()  # shared activation between the dense layers
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, pose_dim)  # pose regression output

    def forward(self, x):
        """Map images ``(b, in_channels, ih, iw)`` to pose vectors ``(b, pose_dim)``."""
        x = self.s0(x)
        x = self.s1(x)
        x = self.s2(x)
        x = self.s3(x)
        x = self.s4(x)

        # Global average pool to (b, C, 1, 1), then drop the spatial axes.
        # flatten(1) keeps the batch axis intact, so a pooling/shape mismatch fails
        # loudly in fc1 instead of being silently folded into the batch by view(-1, C).
        x = torch.flatten(self.pool(x), 1)

        # The pooled vector is concatenated with itself because fc1 expects
        # 2 * channels[-1] inputs. This adds no information (it is equivalent to a
        # single Linear on x) but is kept so existing checkpoints still load.
        x = torch.cat((x, x), dim=1)

        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

    def _make_layer(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks: the first maps inp -> oup with downsampling, the rest oup -> oup."""
        layers = nn.ModuleList([])
        for i in range(depth):
            if i == 0:
                layers.append(block(inp, oup, image_size, downsample=True))
            else:
                layers.append(block(oup, oup, image_size))
        return nn.Sequential(*layers)


def cohatnet_0():
    """Build CoHAtNet for 256x256 RGB input with stage depths 2-2-3-5-2 and widths 64-768."""
    depths = [2, 2, 3, 5, 2]            # L: blocks per stage s0..s4
    widths = [64, 96, 192, 384, 768]    # D: output channels per stage s0..s4
    return CoHAtNet((256, 256), 3, depths, widths, num_classes=1000)


def cohatnet_1():
    """Build CoHAtNet for 256x256 RGB input with stage depths 2-2-6-14-2 and widths 64-768."""
    depths = [2, 2, 6, 14, 2]           # L: blocks per stage s0..s4
    widths = [64, 96, 192, 384, 768]    # D: output channels per stage s0..s4
    return CoHAtNet((256, 256), 3, depths, widths, num_classes=1000)


def cohatnet_2():
    """Build CoHAtNet for 256x256 RGB input with stage depths 2-2-6-14-2 and widths 128-1024."""
    depths = [2, 2, 6, 14, 2]               # L: blocks per stage s0..s4
    widths = [128, 128, 256, 512, 1024]     # D: output channels per stage s0..s4
    return CoHAtNet((256, 256), 3, depths, widths, num_classes=1000)


def cohatnet_3():
    """Build CoHAtNet for 256x256 RGB input with stage depths 2-2-6-14-2 and widths 192-1536."""
    depths = [2, 2, 6, 14, 2]               # L: blocks per stage s0..s4
    widths = [192, 192, 384, 768, 1536]     # D: output channels per stage s0..s4
    return CoHAtNet((256, 256), 3, depths, widths, num_classes=1000)


def cohatnet_4():
    """Build CoHAtNet for 256x256 RGB input with stage depths 2-2-12-28-2 and widths 192-1536."""
    depths = [2, 2, 12, 28, 2]              # L: blocks per stage s0..s4
    widths = [192, 192, 384, 768, 1536]     # D: output channels per stage s0..s4
    return CoHAtNet((256, 256), 3, depths, widths, num_classes=1000)


def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are excluded
        total += param.numel()
    return total


if __name__ == '__main__':
    # Smoke test: build every CoHAtNet variant, push one 256x256 RGB image through it,
    # and print the output shape and the trainable-parameter count.
    img = torch.randn(1, 3, 256, 256)

    for factory in (cohatnet_0, cohatnet_1, cohatnet_2, cohatnet_3, cohatnet_4):
        net = factory()
        # Only the output shape is inspected, so skip autograd bookkeeping. This avoids
        # keeping activations alive for the larger variants (hundreds of millions of params).
        with torch.no_grad():
            out = net(img)
        print(out.shape, count_parameters(net))


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([1, 7]) 34337023
torch.Size([1, 7]) 62756935
torch.Size([1, 7]) 108877575
torch.Size([1, 7]) 238862151
torch.Size([1, 7]) 411054007


## Training

In [None]:
# Training with Global Homography Loss: 5-fold cross-validation of cohatnet_0,
# reporting the best per-batch-average validation loss reached in each fold.
if __name__ == '__main__':
    root_dir = '/content/drive/MyDrive/KingsCollege'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = VisualLandmarkDataset(root_dir=root_dir, transform=transform)

    # Instantiate Global Homography Loss with precomputed xmin and xmax
    global_loss_fn = GlobalHomographyLoss(xmin=2.05, xmax=130.16, device=device)

    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    num_epochs = 100

    def _pose_loss(outputs, poses):
        """Global homography loss between 7-D network outputs and ground-truth poses.

        Columns 0-2 are the translation and columns 3: the quaternion. The predicted
        quaternion is L2-normalized before conversion to a rotation matrix.
        NOTE(review): the ground-truth quaternion is used as-is, which assumes the
        dataset stores unit quaternions. TODO confirm.
        """
        pred_translation = outputs[:, :3]
        pred_quaternion = F.normalize(outputs[:, 3:], p=2, dim=-1)
        batch = {
            'w_t_c': pred_translation.unsqueeze(-1),
            'c_R_w': quaternion_to_rotation_matrix(pred_quaternion),    # predicted rotation
            'w_t_chat': poses[:, :3].unsqueeze(-1),                     # ground-truth translation
            'chat_R_w': quaternion_to_rotation_matrix(poses[:, 3:]),    # ground-truth rotation
        }
        return global_loss_fn(batch)

    fold_results = []
    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f'Fold {fold + 1}')

        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)

        train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler, num_workers=4)
        val_loader = DataLoader(dataset, batch_size=16, sampler=val_sampler, num_workers=4)

        model = cohatnet_0().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, factor=0.1)

        best_val_loss, best_epoch = float('inf'), 0
        for epoch in range(num_epochs):
            # Training pass
            model.train()
            train_loss = 0.0
            for color_images, poses in tqdm(train_loader):
                color_images, poses = color_images.to(device), poses.to(device)
                optimizer.zero_grad()
                loss = _pose_loss(model(color_images), poses)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            avg_train_loss = train_loss / len(train_loader)

            # Validation pass
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for color_images, poses in val_loader:
                    color_images, poses = color_images.to(device), poses.to(device)
                    val_loss += _pose_loss(model(color_images), poses).item()
            avg_val_loss = val_loss / len(val_loader)

            # Step the plateau scheduler on the same averaged value that is logged.
            scheduler.step(avg_val_loss)
            if avg_val_loss < best_val_loss:
                best_val_loss, best_epoch = avg_val_loss, epoch + 1
            print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}')

        print(f'Fold {fold + 1} best Val Loss: {best_val_loss} (epoch {best_epoch})')
        fold_results.append(best_val_loss)

    print(f'Mean best Val Loss over {len(fold_results)} folds: {sum(fold_results) / len(fold_results)}')

Fold 1


100%|██████████| 79/79 [08:36<00:00,  6.54s/it]


Epoch 1, Train Loss: 4.835632733151883, Val Loss: 4.211350810527802


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 2, Train Loss: 1.572676900821396, Val Loss: 2.9313368678092955


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 3, Train Loss: 1.0168373128281365, Val Loss: 0.6819047793745995


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 4, Train Loss: 0.5966121180902554, Val Loss: 0.5787667453289032


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 5, Train Loss: 0.719934151519703, Val Loss: 0.7713652364909649


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 6, Train Loss: 0.5105404842503464, Val Loss: 0.6938246741890908


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 7, Train Loss: 0.5378586735717857, Val Loss: 0.239171876758337


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 8, Train Loss: 0.5680619420695908, Val Loss: 1.1625386759638787


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 9, Train Loss: 0.4106814161508898, Val Loss: 0.28597305417060853


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 10, Train Loss: 0.25692207535988165, Val Loss: 0.6065299823880196


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 11, Train Loss: 0.4624419523568093, Val Loss: 0.28082291334867476


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 12, Train Loss: 0.21765380343304405, Val Loss: 0.16736471578478812


100%|██████████| 79/79 [00:20<00:00,  3.87it/s]


Epoch 13, Train Loss: 0.16312933086028583, Val Loss: 0.23454580195248126


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 14, Train Loss: 0.14159141542224946, Val Loss: 0.1498720146715641


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 15, Train Loss: 0.11042192318960081, Val Loss: 0.22131731137633323


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 16, Train Loss: 0.11092270953179914, Val Loss: 0.14619202092289924


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 17, Train Loss: 0.12014019475141659, Val Loss: 0.18386728465557098


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 18, Train Loss: 0.10477848233112806, Val Loss: 0.13093410357832908


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 19, Train Loss: 0.10357132591778719, Val Loss: 0.14263443946838378


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 20, Train Loss: 0.07258059280111065, Val Loss: 0.126938870921731


100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


Epoch 21, Train Loss: 0.06122600030201145, Val Loss: 0.10365483164787292


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 22, Train Loss: 0.06136359033893935, Val Loss: 0.09086566083133221


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 23, Train Loss: 0.06006005901513221, Val Loss: 0.0801632234826684


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 24, Train Loss: 0.053779547067382666, Val Loss: 0.12270715162158012


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 25, Train Loss: 0.05011114588926865, Val Loss: 0.10093873925507069


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 26, Train Loss: 0.044287755213017706, Val Loss: 0.10475637391209602


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 27, Train Loss: 0.06037984845004504, Val Loss: 0.1534708734601736


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 28, Train Loss: 0.0853772972344975, Val Loss: 0.08372964654117823


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 29, Train Loss: 0.030196519427095787, Val Loss: 0.0722780266776681


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 30, Train Loss: 0.0264129209674046, Val Loss: 0.07779307514429093


100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


Epoch 31, Train Loss: 0.02446314529810525, Val Loss: 0.07179590873420238


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 32, Train Loss: 0.02532257559367373, Val Loss: 0.07543272124603391


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 33, Train Loss: 0.02477998139124505, Val Loss: 0.0714752845466137


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 34, Train Loss: 0.023567731150343448, Val Loss: 0.0716950163245201


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 35, Train Loss: 0.0214528289352414, Val Loss: 0.0769429586827755


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 36, Train Loss: 0.02309371192668435, Val Loss: 0.06966219237074256


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 37, Train Loss: 0.03188523413212616, Val Loss: 0.07284466437995434


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 38, Train Loss: 0.024483646407629116, Val Loss: 0.07195181399583817


100%|██████████| 79/79 [00:20<00:00,  3.77it/s]


Epoch 39, Train Loss: 0.020926999399745013, Val Loss: 0.06744818575680256


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 40, Train Loss: 0.019193864462873602, Val Loss: 0.07292891778051853


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 41, Train Loss: 0.02162107523483566, Val Loss: 0.06904572965577245


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 42, Train Loss: 0.02024312957508277, Val Loss: 0.07505382997915148


100%|██████████| 79/79 [00:20<00:00,  3.77it/s]


Epoch 43, Train Loss: 0.018205066494455066, Val Loss: 0.07239061426371336


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 44, Train Loss: 0.01760320634215693, Val Loss: 0.06890367409214378


100%|██████████| 79/79 [00:20<00:00,  3.86it/s]


Epoch 45, Train Loss: 0.021655154414474964, Val Loss: 0.06769340103492141


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 46, Train Loss: 0.021572068044797905, Val Loss: 0.07911416385322809


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 47, Train Loss: 0.01628580074050004, Val Loss: 0.07367403842508793


100%|██████████| 79/79 [00:21<00:00,  3.76it/s]


Epoch 48, Train Loss: 0.016613805507415834, Val Loss: 0.0711928667500615


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 49, Train Loss: 0.01613051567486004, Val Loss: 0.07048201030120253


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 50, Train Loss: 0.016875828934621206, Val Loss: 0.07085371073335409


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 51, Train Loss: 0.01529378851852085, Val Loss: 0.06846006028354168


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 52, Train Loss: 0.017724091664569664, Val Loss: 0.07443253006786107


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 53, Train Loss: 0.014625633194382433, Val Loss: 0.07176572214812041


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 54, Train Loss: 0.018483074452680878, Val Loss: 0.07991176322102547


100%|██████████| 79/79 [00:21<00:00,  3.76it/s]


Epoch 55, Train Loss: 0.015324291655251497, Val Loss: 0.0692306550219655


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 56, Train Loss: 0.016445355755100144, Val Loss: 0.07122256709262728


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 57, Train Loss: 0.019737895035856885, Val Loss: 0.0706227608025074


100%|██████████| 79/79 [00:20<00:00,  3.86it/s]


Epoch 58, Train Loss: 0.015696093552050334, Val Loss: 0.07164405100047588


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 59, Train Loss: 0.015127163804784606, Val Loss: 0.06702355043962598


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 60, Train Loss: 0.01717010413921332, Val Loss: 0.07465478535741568


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 61, Train Loss: 0.015420476894212675, Val Loss: 0.07536697881296277


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 62, Train Loss: 0.01585799883579529, Val Loss: 0.069304427690804


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 63, Train Loss: 0.015168643356124056, Val Loss: 0.06911253994330764


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 64, Train Loss: 0.015124036140645607, Val Loss: 0.066402546223253


100%|██████████| 79/79 [00:21<00:00,  3.75it/s]


Epoch 65, Train Loss: 0.016675963261035046, Val Loss: 0.07600725330412388


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 66, Train Loss: 0.015021062526804737, Val Loss: 0.06978756533935666


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 67, Train Loss: 0.022208596674041656, Val Loss: 0.06655892049893737


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 68, Train Loss: 0.01740117537305702, Val Loss: 0.07361576408147812


100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


Epoch 69, Train Loss: 0.015212656193283162, Val Loss: 0.07337656477466226


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 70, Train Loss: 0.024038091702740405, Val Loss: 0.07771991901099681


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 71, Train Loss: 0.01685760018734049, Val Loss: 0.07431906871497632


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 72, Train Loss: 0.014638231513149377, Val Loss: 0.07280220119282603


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 73, Train Loss: 0.015398125718288783, Val Loss: 0.071544589381665


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 74, Train Loss: 0.015011802068145215, Val Loss: 0.07265951875597239


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 75, Train Loss: 0.016695757903441598, Val Loss: 0.08107489831745625


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 76, Train Loss: 0.016794908143391338, Val Loss: 0.07337910477072


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 77, Train Loss: 0.017702955516833294, Val Loss: 0.07591576017439365


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 78, Train Loss: 0.015613104629365705, Val Loss: 0.06971172224730253


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 79, Train Loss: 0.015197464147040362, Val Loss: 0.06937008025124669


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 80, Train Loss: 0.02184190057642475, Val Loss: 0.07107782857492566


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 81, Train Loss: 0.015682744840749458, Val Loss: 0.06709403190761805


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 82, Train Loss: 0.015400761820941786, Val Loss: 0.06998985819518566


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 83, Train Loss: 0.015625430271029472, Val Loss: 0.0727739742025733


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 84, Train Loss: 0.016064146610236245, Val Loss: 0.07037071045488119


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 85, Train Loss: 0.017154627821490735, Val Loss: 0.07355728764086962


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 86, Train Loss: 0.01734916732611158, Val Loss: 0.0715199408121407


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 87, Train Loss: 0.015830826020174767, Val Loss: 0.0696901310235262


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 88, Train Loss: 0.015766976059331923, Val Loss: 0.0679551638662815


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 89, Train Loss: 0.029956351971560265, Val Loss: 0.06588344564661383


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 90, Train Loss: 0.016500571601187126, Val Loss: 0.06914085438475012


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 91, Train Loss: 0.01687917746367711, Val Loss: 0.06877485122531653


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 92, Train Loss: 0.016843125430419097, Val Loss: 0.07671053204685449


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 93, Train Loss: 0.01581092821905696, Val Loss: 0.07210729848593474


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 94, Train Loss: 0.015126268108245692, Val Loss: 0.08454523421823978


100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


Epoch 95, Train Loss: 0.01535173609286924, Val Loss: 0.06882009543478489


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 96, Train Loss: 0.017129340244433546, Val Loss: 0.07018350306898355


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 97, Train Loss: 0.01584874200222047, Val Loss: 0.07109872214496135


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 98, Train Loss: 0.014908612169372507, Val Loss: 0.06803834112361073


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 99, Train Loss: 0.015532991931408266, Val Loss: 0.06698701893910766


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 100, Train Loss: 0.017549244121094293, Val Loss: 0.07387749096378685
Fold 2


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 1, Train Loss: 4.65821190725399, Val Loss: 2.4290404200553892


100%|██████████| 79/79 [00:20<00:00,  3.86it/s]


Epoch 2, Train Loss: 1.6636644733857504, Val Loss: 2.1821664392948152


100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


Epoch 3, Train Loss: 1.1161381213725368, Val Loss: 0.6460889011621476


100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


Epoch 4, Train Loss: 0.7347933282203312, Val Loss: 0.6696714602410794


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 5, Train Loss: 1.8006245416553714, Val Loss: 1.2479082942008972


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 6, Train Loss: 0.8702359180661696, Val Loss: 0.40766422376036643


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 7, Train Loss: 0.6652618605124799, Val Loss: 0.28969397246837614


100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


Epoch 8, Train Loss: 0.3571737791541256, Val Loss: 0.354629073292017


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 9, Train Loss: 0.3080367526110214, Val Loss: 0.2707753129303455


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 10, Train Loss: 0.2939045507507988, Val Loss: 0.4594097249209881


100%|██████████| 79/79 [00:20<00:00,  3.77it/s]


Epoch 11, Train Loss: 0.2819042139792744, Val Loss: 0.24419419839978218


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 12, Train Loss: 0.2049806741408155, Val Loss: 0.17653886526823043


100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


Epoch 13, Train Loss: 0.16590441122085234, Val Loss: 0.27243184223771094


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 14, Train Loss: 0.1641909213194364, Val Loss: 0.20791743472218513


100%|██████████| 79/79 [00:20<00:00,  3.82it/s]


Epoch 15, Train Loss: 0.13546907887617243, Val Loss: 0.2268209882080555


100%|██████████| 79/79 [00:21<00:00,  3.74it/s]


Epoch 16, Train Loss: 0.11013925924331328, Val Loss: 0.14466933868825435


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 17, Train Loss: 0.11490315566711788, Val Loss: 0.399067759513855


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 18, Train Loss: 0.2649229859249501, Val Loss: 0.2122339803725481


100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


Epoch 19, Train Loss: 0.15625824604796457, Val Loss: 0.4619454324245453


100%|██████████| 79/79 [00:20<00:00,  3.77it/s]


Epoch 20, Train Loss: 0.8984631166050706, Val Loss: 0.5170629099011421


100%|██████████| 79/79 [00:21<00:00,  3.75it/s]


Epoch 21, Train Loss: 0.26439049368417716, Val Loss: 0.30974554605782034


100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


Epoch 22, Train Loss: 0.14553517145635206, Val Loss: 0.1560083631426096


100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


Epoch 23, Train Loss: 0.1395242646147933, Val Loss: 0.16265131160616875


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 24, Train Loss: 0.12061783020632176, Val Loss: 0.14109850451350212


100%|██████████| 79/79 [00:20<00:00,  3.77it/s]


Epoch 25, Train Loss: 0.10340696752448625, Val Loss: 0.15453064497560262


100%|██████████| 79/79 [00:20<00:00,  3.81it/s]


Epoch 26, Train Loss: 0.10157929898440084, Val Loss: 0.13551596459001303


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 27, Train Loss: 0.09576884228028829, Val Loss: 0.12778357677161695


100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


Epoch 28, Train Loss: 0.10424991715935213, Val Loss: 0.17008389793336393


100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


Epoch 29, Train Loss: 0.09129343526084212, Val Loss: 0.13800155483186244


 86%|████████▌ | 68/79 [00:18<00:02,  3.67it/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-378ce34f2717>", line 31, in <cell line: 2>
    for color_images, poses in tqdm(train_loader):
  File "/usr/local/lib/python3.10/dist-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1327, in _next_data
    idx, data = self._get_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1293, in _get_data
    success, data = self._try_get_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1131, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
 

TypeError: object of type 'NoneType' has no len()