<a href="https://colab.research.google.com/github/Husseinhhameed/CoAtXnet/blob/main/12Scence_dataset_CoAtXNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so files stored there are reachable
# through the regular filesystem for the rest of the notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install transforms3d

Collecting transforms3d
  Downloading transforms3d-0.4.2-py3-none-any.whl.metadata (2.8 kB)
Downloading transforms3d-0.4.2-py3-none-any.whl (1.4 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m55.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transforms3d
Successfully installed transforms3d-0.4.2


In [None]:
pip install einops




In [None]:
import os
import time
import random
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation, ToTensor
from torchvision.transforms.functional import to_tensor
from torchvision.utils import make_grid
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn.model_selection import KFold
from tqdm import tqdm
import transforms3d.quaternions as txq
import matplotlib.pyplot as plt
import torch.nn.functional as F
import re


In [None]:
def ModifiedDSACLoss(y_true, y_pred, weight_translation_initial=1.0, weight_quaternion_initial=1.0):
    """Pose loss mixing translation distance and quaternion angle with ratio-based weights.

    Args:
        y_true: Tensor of shape (..., 7): 3 translation values followed by a
            4-value quaternion. The quaternion is used as given (not re-normalised).
        y_pred: Tensor laid out like ``y_true``. Its quaternion part is
            L2-normalised here before it is compared with the target.
        weight_translation_initial: Base multiplier of the translation term.
        weight_quaternion_initial: Base multiplier of the rotation term.

    Returns:
        Scalar tensor ``w_t * mean_t + w_q * mean_q`` with
        ``w_t = weight_translation_initial * r``,
        ``w_q = weight_quaternion_initial / r`` and
        ``r = (mean_t + eps) / (mean_q + eps)``.
    """
    eps = 1e-8
    # Largest |dot| fed to acos. d/dx acos(x) is infinite at x == 1, so an exact
    # match would otherwise produce inf/NaN gradients. The price is that the
    # reported angle never drops below roughly 1e-3 rad.
    max_abs_dot = 1.0 - 1e-7

    # Split the true and predicted values into positions and quaternions
    true_positions, true_quaternions = torch.split(y_true, [3, 4], dim=-1)
    pred_positions, pred_quaternions = torch.split(y_pred, [3, 4], dim=-1)

    # Normalize the predicted quaternions
    pred_quaternions = F.normalize(pred_quaternions, p=2, dim=-1)

    # Translation error (L2 distance). torch.linalg.norm is used instead of
    # sqrt(sum(x ** 2)) because the latter back-propagates NaN when the
    # distance is exactly zero.
    t_error = torch.linalg.norm(true_positions - pred_positions, dim=-1)
    mean_t_error = torch.mean(t_error)

    # Quaternion angle error; abs() makes q and -q equivalent.
    abs_dot = torch.abs(torch.sum(true_quaternions * pred_quaternions, dim=-1))
    abs_dot = torch.clamp(abs_dot, max=max_abs_dot)
    angle_error = 2.0 * torch.acos(abs_dot)
    mean_q_error = torch.mean(angle_error)

    # Dynamic weights from the error ratio. eps is added on both sides so that
    # neither a zero rotation error nor a zero translation error can make the
    # ratio (or its reciprocal) infinite.
    # NOTE(review): the weights are not detached, so gradients also flow
    # through them -- confirm this is intended.
    error_ratio = (mean_t_error + eps) / (mean_q_error + eps)
    weight_translation = weight_translation_initial * error_ratio
    weight_quaternion = weight_quaternion_initial / error_ratio

    # Apply dynamic weights to the respective errors and combine them
    weighted_t_error = weight_translation * mean_t_error
    weighted_q_error = weight_quaternion * mean_q_error
    return weighted_t_error + weighted_q_error


In [None]:
def extract_frame_number(filename):
    """Return the integer that follows ``frame-`` in ``filename``, or -1 if there is none."""
    found = re.search(r'frame-(\d+)', filename)
    if found is None:
        return -1
    return int(found.group(1))

# Function to process poses
def process_poses(poses_in, mean_t, std_t, align_R, align_t, align_s):
    """Convert flattened 3x4 pose matrices into 7-value [translation, quaternion] rows.

    Args:
        poses_in: Array of shape (N, 12); each row is a 3x4 matrix flattened
            row by row, so columns 3, 7 and 11 hold the translation.
        mean_t: Value subtracted from the aligned translations.
        std_t: Value the centred translations are divided by.
        align_R: 3x3 matrix applied to both the rotation and the translation.
        align_t: Offset subtracted from the translation before rotating it.
        align_s: Scale applied to the rotated translation.

    Returns:
        Array of shape (N, 7): normalised translation in columns 0-2 and the
        quaternion returned by ``txq.mat2quat`` in columns 3-6, sign-flipped
        where needed so that its first component is non-negative.
    """
    poses_out = np.zeros((len(poses_in), 7))
    poses_out[:, 0:3] = poses_in[:, [3, 7, 11]]  # Translation components

    # Align and process rotation
    for i in range(len(poses_out)):
        R = poses_in[i].reshape((3, 4))[:3, :3]
        q = txq.mat2quat(np.dot(align_R, R))
        # Constrain to one hemisphere. Multiplying by np.sign(q[0]) would wipe
        # out the whole quaternion when q[0] == 0, so only negate when q[0] is
        # strictly negative.
        if q[0] < 0:
            q = -q
        poses_out[i, 3:] = q  # Keep the quaternion as 4D
        t = poses_out[i, :3] - align_t
        poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze()

    # Normalize translation
    poses_out[:, :3] -= mean_t
    poses_out[:, :3] /= std_t

    return poses_out

# Function to load and process poses
def load_and_process_poses(data_dir, train=True, real=False, vo_lib='orbslam'):
    """Read every ``*.pose.txt`` file in ``data_dir`` and return processed poses.

    Files are ordered by the frame number in their name. The first 12 values
    of each file are kept.

    Side effect: when ``train and not real`` the per-axis translation mean and
    standard deviation are computed from the loaded poses and written to
    ``<data_dir>/pose_stats.txt``. Otherwise they are read from that file.

    Args:
        data_dir: Directory containing the pose files.
        train: See the side effect above.
        real: See the side effect above.
        vo_lib: Accepted but not used by this function.

    Returns:
        Array of shape (N, 7) as produced by ``process_poses`` with an
        identity alignment (R = I, t = 0, s = 1).

    Raises:
        ValueError: If ``data_dir`` contains no ``*.pose.txt`` file.
    """
    pose_files = [f for f in os.listdir(data_dir) if f.endswith('.pose.txt')]
    if not pose_files:
        # Without this guard the failure surfaces later as an opaque
        # "too many indices" IndexError on an empty array.
        raise ValueError(f"No '.pose.txt' files found in {data_dir}")
    pose_files.sort(key=extract_frame_number)

    ps = []
    for filename in pose_files:
        pose_path = os.path.join(data_dir, filename)
        if os.path.exists(pose_path):
            ps.append(np.loadtxt(pose_path).flatten()[:12])
    ps = np.array(ps)

    pose_stats_filename = os.path.join(data_dir, 'pose_stats.txt')
    if train and not real:
        mean_t = np.mean(ps[:, [3, 7, 11]], axis=0)
        std_t = np.std(ps[:, [3, 7, 11]], axis=0)
        np.savetxt(pose_stats_filename, np.vstack((mean_t, std_t)), fmt='%8.7f')
    else:
        mean_t, std_t = np.loadtxt(pose_stats_filename)

    # Identity alignment: poses are only normalised, not re-registered.
    vo_stats = {'R': np.eye(3), 't': np.zeros(3), 's': 1}

    # Process and normalize poses
    processed_poses = process_poses(poses_in=ps, mean_t=mean_t, std_t=std_t,
                                    align_R=vo_stats['R'], align_t=vo_stats['t'],
                                    align_s=vo_stats['s'])
    return processed_poses

# Shared preprocessing for image arrays: array -> PIL image -> 256x256 -> tensor.
transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.Resize((256, 256)), transforms.ToTensor()]
)

# Dataset class
class FireDataset(Dataset):
    """Color/depth frames of one directory paired with their processed poses.

    The directory is expected to contain ``*.color.jpg``, ``*.depth.png`` and
    ``*.pose.txt`` files whose names carry a ``frame-<number>`` tag. Item ``i``
    is the i-th frame in frame-number order.

    Args:
        root_dir: Directory holding the frame files.
        transform: Optional callable applied to the color image and to the
            depth image after the latter has been rescaled to uint8.

    Raises:
        ValueError: If color/depth/pose files at the same sorted position carry
            different frame numbers, or if the number of samples differs from
            the number of processed poses.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = self._load_samples()
        self.processed_poses = self._load_processed_poses()

        print(f"Number of samples: {len(self.samples)}")
        print(f"Number of processed poses: {self.processed_poses.shape[0]}")

        if len(self.samples) != self.processed_poses.shape[0]:
            raise ValueError("Mismatch between number of samples and processed poses.")

    def _load_samples(self):
        """Return ``(color_path, depth_path, pose_path)`` tuples in frame order."""
        names = os.listdir(self.root_dir)

        def frames_with(suffix):
            return sorted((f for f in names if f.endswith(suffix)),
                          key=extract_frame_number)

        color_files = frames_with('.color.jpg')
        depth_files = frames_with('.depth.png')
        pose_files = frames_with('.pose.txt')

        samples = []
        for color_file, depth_file, pose_file in zip(color_files, depth_files, pose_files):
            # The three lists are sorted independently, so a single missing
            # file would silently pair every later frame with the wrong
            # depth map or pose. Fail loudly instead.
            frame_ids = {extract_frame_number(f)
                         for f in (color_file, depth_file, pose_file)}
            if len(frame_ids) != 1:
                raise ValueError(
                    f"Misaligned frame files in {self.root_dir}: "
                    f"{color_file}, {depth_file}, {pose_file}")
            samples.append((os.path.join(self.root_dir, color_file),
                            os.path.join(self.root_dir, depth_file),
                            os.path.join(self.root_dir, pose_file)))
        return samples

    def _load_processed_poses(self):
        """Load and normalise every pose of ``root_dir`` (see ``load_and_process_poses``)."""
        return load_and_process_poses(self.root_dir)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(color_image, depth_image, pose)`` for frame ``idx``.

        Without a transform the raw arrays read by OpenCV are returned. With a
        transform the depth map is first rescaled by its own maximum to the
        0-255 uint8 range, then both images go through the transform.

        Raises:
            FileNotFoundError: If OpenCV cannot read the color or depth file.
        """
        color_path, depth_path, _ = self.samples[idx]

        # Read and process color image (OpenCV loads BGR; convert to RGB)
        color_image = cv2.imread(color_path, cv2.IMREAD_COLOR)
        if color_image is None:
            raise FileNotFoundError(f"Color image not found at {color_path}")
        color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)

        # Read depth image with its stored bit depth
        depth_image = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        if depth_image is None:
            raise FileNotFoundError(f"Depth image not found at {depth_path}")

        # Get the corresponding pose
        pose_matrix = self.processed_poses[idx]

        if self.transform:
            color_image = self.transform(color_image)

            # Normalize depth image and convert to uint8
            max_depth = depth_image.max()
            if max_depth > 0:
                depth_image = (depth_image / max_depth * 255).astype(np.uint8)
            else:
                depth_image = np.zeros_like(depth_image, dtype=np.uint8)
            depth_image = self.transform(depth_image)

        return color_image, depth_image, pose_matrix

In [None]:
def conv_3x3_bn(inp, oup, image_size, downsample=False):
    """Build a 3x3 convolution -> BatchNorm -> GELU stem block.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        image_size: Accepted so the signature matches the other stage blocks;
            not used here.
        downsample: When truthy the convolution uses stride 2, otherwise 1.

    Returns:
        ``nn.Sequential`` of the three layers.
    """
    stride = 2 if downsample else 1
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.GELU()
    )

class PreNorm(nn.Module):
    """Apply a normalisation layer to the input, then hand the result to ``fn``.

    Args:
        dim: Size passed to the ``norm`` constructor.
        fn: Callable invoked on the normalised input.
        norm: Normalisation layer class, instantiated as ``norm(dim)``.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Extra keyword arguments go to the wrapped callable untouched.
        normed = self.norm(x)
        return self.fn(normed, **kwargs)


class SE(nn.Module):
    """Channel gate: global average pool -> two bias-free linears -> sigmoid scale.

    Args:
        inp: Channel count the bottleneck width is derived from
            (``int(inp * expansion)``).
        oup: Channel count of the tensor this module is applied to.
        expansion: Bottleneck ratio relative to ``inp``.
    """

    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        bottleneck = int(inp * expansion)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, bottleneck, bias=False),
            nn.GELU(),
            nn.Linear(bottleneck, oup, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        # Broadcast the per-channel gate over the spatial dimensions.
        return x * gate


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Drop probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class XMBConv(nn.Module):
    """Two-stream MBConv block: the same convolutional layout applied to two inputs.

    ``conv``/``pool``/``proj`` process the first stream and
    ``convx``/``poolx``/``projx`` the second. The two streams have separate
    weights and do not exchange information inside this block.

    Args:
        inp: Input channels of each stream.
        oup: Output channels of each stream.
        image_size: Accepted so the signature matches the other stage blocks;
            not used here.
        downsample: When truthy the block halves the spatial size (stride-2
            convolution in the main path, max-pool + 1x1 projection in the
            shortcut). Otherwise an identity shortcut is used.
        expansion: Hidden width multiplier (``hidden_dim = int(inp * expansion)``).
            With ``expansion == 1`` the expansion convolution and the SE gate
            are omitted.
    """

    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        stride = 2 if self.downsample else 1
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
            self.poolx = nn.MaxPool2d(3, 2, 1)
            self.projx = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        # Both branches are created (conv first, then convx) before being
        # wrapped, so parameters are instantiated in a fixed order.
        conv = self._make_branch(inp, oup, hidden_dim, stride, expansion)
        convx = self._make_branch(inp, oup, hidden_dim, stride, expansion)
        self.conv = PreNorm(inp, conv, nn.BatchNorm2d)
        self.convx = PreNorm(inp, convx, nn.BatchNorm2d)

    @staticmethod
    def _make_branch(inp, oup, hidden_dim, stride, expansion):
        """Build the main-path layer stack used by one stream."""
        if expansion == 1:
            return nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        return nn.Sequential(
            # pw
            # down-sample in the first conv
            nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            # dw
            nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            SE(inp, hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        """Process a ``[first, second]`` pair and return ``[out1, out2]``."""
        x, y = x[0], x[1]

        if self.downsample:
            out1 = self.proj(self.pool(x)) + self.conv(x)
            out2 = self.projx(self.poolx(y)) + self.convx(y)
        else:
            out1 = x + self.conv(x)
            out2 = y + self.convx(y)
        return [out1, out2]

class XAttention(nn.Module):
    """Relative-position attention over ``x`` followed by an ``x``-to-``y`` cross stage.

    Stage 1 is multi-head self-attention on ``x`` (queries, keys and values
    all come from ``x``). Stage 2 builds a second attention map from queries
    of ``x`` (``to_qx``) and keys of ``y`` (``to_kx``) and applies it to the
    stage-1 output. Both stages add the same learned relative position bias.

    Args:
        inp: Feature size of the incoming tokens.
        oup: Feature size of the returned tokens.
        image_size: ``(ih, iw)`` grid the tokens come from; the token count
            must equal ``ih * iw`` for the position bias to broadcast.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Drop probability after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # indexing='ij' is the behaviour torch.meshgrid falls back to when the
        # argument is omitted; passing it explicitly silences the warning
        # about the upcoming requirement to specify it.
        coords = torch.meshgrid(
            torch.arange(self.ih), torch.arange(self.iw), indexing='ij')
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        # Shift offsets to be non-negative and fold (dy, dx) into one index
        # into relative_bias_table.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)
        self.to_kx = nn.Linear(inp, inner_dim * 1, bias=False)
        self.to_qx = nn.Linear(inp, inner_dim * 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, y):
        """Attend over ``x`` with help from ``y``; both are ``(b, n, inp)`` token tensors."""
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        kx = self.to_kx(y)
        qx = self.to_qx(x)

        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        kx = rearrange(kx, 'b n (h d) -> b h n d', h=self.heads)
        qx = rearrange(qx, 'b n (h d) -> b h n d', h=self.heads)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        # Stage 1: self-attention on x
        attn = self.attend(dots)
        out = torch.matmul(attn, v)

        # Stage 2: x-query / y-key attention map applied to the stage-1 output
        dots = torch.matmul(qx, kx.transpose(-1, -2)) * self.scale
        dots = dots + relative_bias
        attn = self.attend(dots)
        out = torch.matmul(attn, out)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class XTransformer(nn.Module):
    """Two-stream transformer block with mutual ``XAttention`` and per-stream MLPs.

    Args:
        inp: Input channels of each stream.
        oup: Output channels of each stream.
        image_size: ``(ih, iw)`` spatial size of the block's output maps.
        heads: Attention heads.
        dim_head: Feature size per head.
        downsample: When truthy both streams are max-pooled (stride 2) before
            attention and the shortcut goes through max-pool + 1x1 projection.
        dropout: Drop probability forwarded to the attention and MLP modules.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        self.ih, self.iw = image_size
        self.downsample = downsample
        mlp_width = int(inp * 4)

        self.layer_norm = nn.LayerNorm(inp)
        self.layer_normx = nn.LayerNorm(inp)

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
            self.pool1x = nn.MaxPool2d(3, 2, 1)
            self.pool2x = nn.MaxPool2d(3, 2, 1)
            self.projx = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        # Creation order (attn, ff, attnx, ffx) fixes the parameter order.
        self.attn = XAttention(inp, oup, image_size, heads, dim_head, dropout)
        self.ff = self._on_tokens(oup, FeedForward(oup, mlp_width, dropout))
        self.attnx = XAttention(inp, oup, image_size, heads, dim_head, dropout)
        self.ffx = self._on_tokens(oup, FeedForward(oup, mlp_width, dropout))

    def _on_tokens(self, dim, fn):
        """Wrap ``fn`` so it runs pre-normalised on flattened tokens of a feature map."""
        return nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(dim, fn, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

    def _to_map(self, tokens):
        """Reshape ``(b, ih*iw, c)`` tokens back into a ``(b, c, ih, iw)`` map."""
        return rearrange(tokens, 'b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)

    def forward(self, x):
        """Process a ``[first, second]`` pair of feature maps and return ``[out1, out2]``."""
        feat_x, feat_y = x[0], x[1]

        # Tokens are taken from the pooled maps when downsampling.
        if self.downsample:
            src_x = self.pool2(feat_x)
            src_y = self.pool2x(feat_y)
        else:
            src_x = feat_x
            src_y = feat_y

        tokens_x = self.layer_norm(rearrange(src_x, 'b c ih iw -> b (ih iw) c'))
        tokens_y = self.layer_normx(rearrange(src_y, 'b c ih iw -> b (ih iw) c'))

        # Each stream attends over itself with keys contributed by the other.
        attn_x = self._to_map(self.attn(tokens_x, tokens_y))
        attn_y = self._to_map(self.attnx(tokens_y, tokens_x))

        if self.downsample:
            out_x = self.proj(self.pool1(feat_x)) + attn_x
            out_y = self.projx(self.pool1x(feat_y)) + attn_y
        else:
            out_x = feat_x + attn_x
            out_y = feat_y + attn_y

        out_x = out_x + self.ff(out_x)
        out_y = out_y + self.ffx(out_y)
        return [out_x, out_y]


class CoAtXNet(nn.Module):
    """Two-stream network regressing a 7-value pose from an image and an auxiliary map.

    Each input goes through its own convolutional stem (``s0`` / ``s0x``).
    Stages ``s1``-``s4`` then process the pair jointly with ``XMBConv`` ('C')
    or ``XTransformer`` ('T') blocks. The two pooled feature vectors are
    concatenated and fed to a three-layer MLP head with 7 outputs.

    Args:
        image_size: ``(ih, iw)`` input resolution. Every stage halves it, so
            the final maps are ``(ih // 32, iw // 32)``.
        in_channels: Channels of the first input.
        aux_channels: Channels of the second input.
        num_blocks: Five block counts, one per stage ``s0``-``s4``.
        channels: Five output widths, one per stage ``s0``-``s4``.
        block_types: Four entries ('C' or 'T') selecting the block of
            ``s1``-``s4``.
    """

    def __init__(self, image_size, in_channels, aux_channels, num_blocks, channels,
                 block_types=('C', 'C', 'T', 'T')):
        super().__init__()
        ih, iw = image_size
        block = {'C': XMBConv, 'T': XTransformer}

        self.s0 = self._make_layer(
            conv_3x3_bn, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
        self.s0x = self._make_layer(
            conv_3x3_bn, aux_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))

        self.s1 = self._make_layer(
            block[block_types[0]], channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        self.s2 = self._make_layer(
            block[block_types[1]], channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
        self.s3 = self._make_layer(
            block[block_types[2]], channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
        self.s4 = self._make_layer(
            block[block_types[3]], channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))

        # Pool over the whole final map. The kernel follows both dimensions so
        # non-square image sizes also end up with a 1x1 output.
        self.pool = nn.AvgPool2d((ih // 32, iw // 32), 1)
        hidden_dim = 1024
        self.fc1 = nn.Linear(channels[-1] * 2, hidden_dim)
        self.relu = nn.ReLU()  # Activation function between dense layers
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)  # Another dense layer
        self.fc3 = nn.Linear(hidden_dim // 2, 7)  # Final output layer adjusted to target size

    def forward(self, x, y):
        """Return a ``(batch, 7)`` tensor for inputs ``x`` (image) and ``y`` (auxiliary map)."""
        x = self.s0(x)
        y = self.s0x(y)

        xy = self.s1([x, y])
        xy = self.s2(xy)
        xy = self.s3(xy)
        xy = self.s4(xy)

        x, y = xy[0], xy[1]

        # flatten(1) keeps the batch dimension fixed: if the pooled map were
        # not 1x1 the head fails on the feature size instead of silently
        # folding spatial positions into the batch.
        x = self.pool(x).flatten(1)
        y = self.pool(y).flatten(1)

        x = torch.cat((x, y), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x

    def _make_layer(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks; only the first changes channels and downsamples."""
        layers = []
        for i in range(depth):
            if i == 0:
                layers.append(block(inp, oup, image_size, downsample=True))
            else:
                layers.append(block(oup, oup, image_size))
        return nn.Sequential(*layers)



def coatxnet_0():
    """Build the smallest CoAtXNet variant for 256x256 inputs (3-channel image, 1-channel aux)."""
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 3, 5, 2],            # L
        channels=[64, 96, 192, 384, 768],      # D
    )


def coatxnet_1():
    """Build CoAtXNet-1 for 256x256 inputs: same widths as variant 0, deeper s2/s3 stages."""
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 6, 14, 2],           # L
        channels=[64, 96, 192, 384, 768],      # D
    )


def coatxnet_2():
    """Build CoAtXNet-2 for 256x256 inputs (3-channel image, 1-channel aux)."""
    # NOTE(review): the last width is 1026 while every other width in these
    # factories is a multiple of 32 -- possibly meant to be 1024. Left as is
    # because changing it alters the parameter shapes.
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 6, 14, 2],           # L
        channels=[128, 128, 256, 512, 1026],   # D
    )


def coatxnet_3():
    """Build CoAtXNet-3 for 256x256 inputs (3-channel image, 1-channel aux)."""
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 6, 14, 2],           # L
        channels=[192, 192, 384, 768, 1536],   # D
    )


def coatxnet_4():
    """Build CoAtXNet-4 for 256x256 inputs: widths of variant 3 with deeper s2/s3 stages."""
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 12, 28, 2],          # L
        channels=[192, 192, 384, 768, 1536],   # D
    )


def count_parameters(model):
    """Return the total number of elements in the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


if __name__ == '__main__':
    # Smoke test: push one random image / auxiliary pair through every variant
    # and print the output shape together with the trainable parameter count.
    img = torch.randn(1, 3, 256, 256)
    aux = torch.randn(1, 1, 256, 256)

    for build in (coatxnet_0, coatxnet_1, coatxnet_2, coatxnet_3, coatxnet_4):
        net = build()
        out = net(img, aux)
        print(out.shape, count_parameters(net))


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([1, 7]) 39113847
torch.Size([1, 7]) 73448199
torch.Size([1, 7]) 120818047
torch.Size([1, 7]) 249078535
torch.Size([1, 7]) 432612327


In [None]:
# 5-fold cross-validated training of coatxnet_1 on one scene directory.
# Every fold starts from a freshly initialised model, optimizer and scheduler;
# nothing is saved to disk by this cell.
if __name__ == '__main__':
    root_dir = '/content/drive/MyDrive/12-Scense/apt1/living/data'
    dataset = FireDataset(root_dir=root_dir, transform=transform)

    # Shuffled split with a fixed seed so the fold membership is repeatable.
    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    num_epochs = 100  # Define the number of epochs

    # Check for GPU availability
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f'Fold {fold + 1}')

        # Both loaders draw from the same dataset object; the samplers restrict
        # them to the fold's train / validation indices.
        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)

        train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler, num_workers=4)
        val_loader = DataLoader(dataset, batch_size=16, sampler=val_sampler, num_workers=4)

        model = coatxnet_1().to(device)  # Move model to GPU
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        # Multiply the learning rate by 0.1 once the monitored value has not
        # improved for more than 4 consecutive epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, factor=0.1)

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            train_loss = 0.0
            for color_images, depth_images, poses in tqdm(train_loader):
                color_images, depth_images, poses = color_images.to(device), depth_images.to(device), poses.to(device)  # Move data to GPU
                optimizer.zero_grad()
                outputs = model(color_images, depth_images)
                # Argument order is (target, prediction).
                loss = ModifiedDSACLoss(poses, outputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # Validation loop
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for color_images, depth_images, poses in val_loader:
                    color_images, depth_images, poses = color_images.to(device), depth_images.to(device), poses.to(device)  # Move data to GPU
                    outputs = model(color_images, depth_images)
                    loss = ModifiedDSACLoss(poses, outputs)
                    val_loss += loss.item()

            # Step the scheduler on the summed validation loss (the value
            # printed below is the per-batch mean of the same sum).
            scheduler.step(val_loss)

            print(f'Epoch {epoch + 1}, Train Loss: {train_loss / len(train_loader)}, Val Loss: {val_loss / len(val_loader)}')



Number of samples: 1528
Number of processed poses: 1528
Using device: cuda
Fold 1


100%|██████████| 77/77 [00:56<00:00,  1.36it/s]


Epoch 1, Train Loss: 2.072328303062944, Val Loss: 2.494295763351945


100%|██████████| 77/77 [00:51<00:00,  1.48it/s]


Epoch 2, Train Loss: 1.2727857613115212, Val Loss: 1.036811010035882


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 3, Train Loss: 0.9768855484260851, Val Loss: 0.8662330734260406


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 4, Train Loss: 0.843186569762465, Val Loss: 0.9122369341924808


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 5, Train Loss: 0.7475745154278491, Val Loss: 0.5908044438338605


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 6, Train Loss: 0.639247814247892, Val Loss: 0.5550305825158682


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 7, Train Loss: 0.6055205736821808, Val Loss: 0.4744993823600585


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 8, Train Loss: 0.523979890101342, Val Loss: 0.5085714862093688


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 9, Train Loss: 0.4634256213071149, Val Loss: 0.4729018904899952


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 10, Train Loss: 0.4750901092540258, Val Loss: 0.41193860970568474


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 11, Train Loss: 0.516185639210243, Val Loss: 0.4729047711058839


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 12, Train Loss: 0.41925692052593627, Val Loss: 0.3485435430603486


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 13, Train Loss: 0.3540437111748613, Val Loss: 0.3614929090286971


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 14, Train Loss: 0.4088431346805357, Val Loss: 0.3845573378971928


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 15, Train Loss: 0.381948138947756, Val Loss: 0.3164068551686605


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 16, Train Loss: 0.328044555571106, Val Loss: 0.30799316094036966


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 17, Train Loss: 0.3252420235510637, Val Loss: 0.3563695046372598


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 18, Train Loss: 0.33265807483870025, Val Loss: 0.35995830281343666


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 19, Train Loss: 0.4014612808178525, Val Loss: 0.32209474002104244


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 20, Train Loss: 0.3327822263701255, Val Loss: 0.3453987339784173


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 21, Train Loss: 0.3110412321339518, Val Loss: 0.33246333164513203


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 22, Train Loss: 0.22169944084340354, Val Loss: 0.20814580164458102


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 23, Train Loss: 0.17775325524369323, Val Loss: 0.17787409857365097


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 24, Train Loss: 0.16014766658463414, Val Loss: 0.17045557501592853


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 25, Train Loss: 0.14842619957912437, Val Loss: 0.15677044642097632


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 26, Train Loss: 0.14456519693592862, Val Loss: 0.16773152176523992


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 27, Train Loss: 0.1439904901673825, Val Loss: 0.14344106691252584


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 28, Train Loss: 0.1337606690443588, Val Loss: 0.1497787066095297


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 29, Train Loss: 0.13093934903633125, Val Loss: 0.14236595389735268


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 30, Train Loss: 0.12712536584360098, Val Loss: 0.14723221738927084


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 31, Train Loss: 0.12016344827835677, Val Loss: 0.17122375033190068


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 32, Train Loss: 0.12263334221265082, Val Loss: 0.14480630850820522


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 33, Train Loss: 0.12048695428568124, Val Loss: 0.1408036843138683


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 34, Train Loss: 0.11115766484903705, Val Loss: 0.1389001880152764


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 35, Train Loss: 0.1147133452927346, Val Loss: 0.14136187513010132


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 36, Train Loss: 0.11165816242886337, Val Loss: 0.14739826953746593


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 37, Train Loss: 0.10515114559576516, Val Loss: 0.12759148125794229


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 38, Train Loss: 0.10138757891795437, Val Loss: 0.13904046453565952


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 39, Train Loss: 0.10245447786599446, Val Loss: 0.13258313433179522


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 40, Train Loss: 0.104509706078322, Val Loss: 0.13783220861974926


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 41, Train Loss: 0.10260272188280187, Val Loss: 0.1257864091937678


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 42, Train Loss: 0.10632021572984812, Val Loss: 0.13412370791559164


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 43, Train Loss: 0.09943270458266136, Val Loss: 0.12486940568295071


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 44, Train Loss: 0.09514083572587285, Val Loss: 0.13436339651462798


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 45, Train Loss: 0.09419031399682326, Val Loss: 0.13671603179020658


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 46, Train Loss: 0.10031795945847137, Val Loss: 0.1258221665407692


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 47, Train Loss: 0.09969867719764074, Val Loss: 0.11945489906823356


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 48, Train Loss: 0.08953750248698997, Val Loss: 0.13041351740709056


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 49, Train Loss: 0.09848669770660831, Val Loss: 0.12767382123766327


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 50, Train Loss: 0.08345185620770276, Val Loss: 0.12923133327851188


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 51, Train Loss: 0.0873342704063732, Val Loss: 0.1369958835810472


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 52, Train Loss: 0.08632880800593333, Val Loss: 0.1272938673996975


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 53, Train Loss: 0.0736090533749387, Val Loss: 0.12626089187620454


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 54, Train Loss: 0.07083755994318562, Val Loss: 0.12173106612845133


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 55, Train Loss: 0.06643919964576617, Val Loss: 0.11807227536771996


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 56, Train Loss: 0.0637227689046717, Val Loss: 0.1183331946890512


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 57, Train Loss: 0.06640171775752372, Val Loss: 0.12404548245873974


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 58, Train Loss: 0.06766977338771078, Val Loss: 0.1233881800641752


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 59, Train Loss: 0.06385319858739961, Val Loss: 0.1284548838248405


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 60, Train Loss: 0.06571405809812682, Val Loss: 0.11693665631702


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 61, Train Loss: 0.06382831861269898, Val Loss: 0.11304953518520619


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 62, Train Loss: 0.06549969618900849, Val Loss: 0.11686790680880879


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 63, Train Loss: 0.06361175109609923, Val Loss: 0.11904935129746877


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 64, Train Loss: 0.06763482371805568, Val Loss: 0.12335452215232394


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 65, Train Loss: 0.06446347522283241, Val Loss: 0.13842804821585344


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 66, Train Loss: 0.06304934860408336, Val Loss: 0.11283054354186621


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 67, Train Loss: 0.06178970711192597, Val Loss: 0.11558238843774984


100%|██████████| 77/77 [00:52<00:00,  1.47it/s]


Epoch 68, Train Loss: 0.06304435657664595, Val Loss: 0.1216596076487813


 55%|█████▍    | 42/77 [00:30<00:25,  1.40it/s]


KeyboardInterrupt: 