<a href="https://colab.research.google.com/github/Husseinhhameed/CoAtXnet/blob/main/CoAtXnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fire

In [None]:
# Colab-only step: mount Google Drive at /content/drive so files stored there
# can be read through the regular filesystem.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install transforms3d



In [None]:
pip install einops




In [None]:
import os
import time
import random
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation, ToTensor
from torchvision.transforms.functional import to_tensor
from torchvision.utils import make_grid
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn.model_selection import KFold
from tqdm import tqdm
import transforms3d.quaternions as txq
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
def ModifiedDSACLoss(y_true, y_pred, weight_translation_initial=1.0, weight_quaternion_initial=1.0):
    """Combined translation/rotation pose loss with error-ratio based weights.

    Each pose is laid out along the last dimension as either

    * 7 values: 3 translation components followed by a 4-component quaternion, or
    * 6 values: 3 translation components followed by a 3-component
      log-quaternion (the representation produced by ``qlog``), which is
      mapped back to a unit quaternion ``[cos(a), sin(a) * v / |v|]`` with
      ``a = |v|`` before the angular error is computed.

    Args:
        y_true: Ground-truth poses, shape ``(..., 6)`` or ``(..., 7)``.
        y_pred: Predicted poses, same layout as ``y_true``.
        weight_translation_initial: Base weight of the translation term.
        weight_quaternion_initial: Base weight of the rotation term.

    Returns:
        Scalar tensor: ``w_t * mean_t_error + w_q * mean_q_error`` where
        ``w_t = weight_translation_initial * ratio``,
        ``w_q = weight_quaternion_initial / ratio`` and
        ``ratio = mean_t_error / mean_q_error`` (epsilon-stabilised).
    """
    eps = 1e-8

    # Targets built from NumPy arrays arrive as float64; compute in the
    # prediction's dtype instead of silently promoting the whole loss.
    y_true = y_true.to(dtype=y_pred.dtype)

    def _split_pose(pose):
        """Return (translation, quaternion) for a 6-D or 7-D pose tensor."""
        if pose.shape[-1] == 6:
            positions, log_q = torch.split(pose, [3, 3], dim=-1)
            half_angle = log_q.norm(p=2, dim=-1, keepdim=True)
            # sin(a)/a with a guarded denominator; a zero log-quaternion maps
            # to the identity quaternion [1, 0, 0, 0].
            scale = torch.sin(half_angle) / half_angle.clamp_min(eps)
            return positions, torch.cat([torch.cos(half_angle), log_q * scale], dim=-1)
        # Any other width keeps the original behaviour (and error) of torch.split.
        return torch.split(pose, [3, 4], dim=-1)

    true_positions, true_quaternions = _split_pose(y_true)
    pred_positions, pred_quaternions = _split_pose(y_pred)

    # Normalize the predicted quaternions
    pred_quaternions = F.normalize(pred_quaternions, p=2, dim=-1)

    # Translation error (L2 distance). torch.norm has a well-defined
    # (zero) gradient at exactly zero error, unlike sqrt(sum(x ** 2)).
    t_error = torch.norm(true_positions - pred_positions, p=2, dim=-1)
    mean_t_error = torch.mean(t_error)

    # Quaternion angle error; q and -q describe the same rotation, hence abs().
    dot_product = torch.sum(true_quaternions * pred_quaternions, dim=-1)
    dot_product = torch.clamp(dot_product, -1.0, 1.0)  # keep acos inside its domain
    angle_error = 2.0 * torch.acos(torch.abs(dot_product))
    mean_q_error = torch.mean(angle_error)

    # Dynamic weights from the error ratio. Epsilon on BOTH sides keeps the
    # ratio strictly positive, so dividing by it below can no longer produce
    # inf * 0 = NaN when the translation error is exactly zero.
    error_ratio = (mean_t_error + eps) / (mean_q_error + eps)
    weight_translation = weight_translation_initial * error_ratio
    weight_quaternion = weight_quaternion_initial / error_ratio

    # Apply dynamic weights to the respective errors
    weighted_t_error = weight_translation * mean_t_error
    weighted_q_error = weight_quaternion * mean_q_error

    # Final combined weighted error
    return weighted_t_error + weighted_q_error

In [None]:

def qlog(q):
    """Return the 3-vector logarithm of a quaternion.

    The input is normalised first. Element 0 is used as the cosine of the
    half angle and elements 1: as the (sine-scaled) axis. When the axis part
    is shorter than 1e-6 a zero vector is returned.
    """
    unit = q / np.linalg.norm(q)
    scalar_part, vector_part = unit[0], unit[1:]
    vector_norm = np.linalg.norm(vector_part)

    # Guard clause: (near-)identity rotation has no well-defined axis.
    if not vector_norm > 1e-6:
        return np.zeros(3)

    half_angle = np.arctan2(vector_norm, scalar_part)
    return half_angle * vector_part / vector_norm

def process_poses(poses_in, mean_t, std_t, align_R, align_t, align_s):
    """Convert flattened 3x4 pose matrices into normalised 6-D pose vectors.

    Args:
        poses_in: Array of shape (N, 12); each row is a row-major 3x4 matrix
            whose first three columns are the rotation and whose last column
            (flat indices 3, 7, 11) is the translation.
        mean_t: Translation mean subtracted after alignment.
        std_t: Translation standard deviation used for scaling.
        align_R: 3x3 rotation applied to both rotation and translation.
        align_t: Translation offset removed before rotating.
        align_s: Scale applied to the aligned translation.

    Returns:
        Array of shape (N, 6): normalised translation in columns 0-2 and the
        log-quaternion (see ``qlog``) of the aligned rotation in columns 3-5.
    """
    poses_out = np.zeros((len(poses_in), 6))
    poses_out[:, 0:3] = poses_in[:, [3, 7, 11]]  # Translation components

    # Align and process rotation
    for i in range(len(poses_out)):
        R = poses_in[i].reshape((3, 4))[:3, :3]
        q = txq.mat2quat(np.dot(align_R, R))
        # Constrain to the hemisphere with non-negative scalar part. The
        # previous `q *= np.sign(q[0])` multiplied by 0 when q[0] == 0
        # (a 180-degree rotation), wiping the quaternion and producing NaNs
        # in qlog; flipping only on a strictly negative scalar avoids that.
        if q[0] < 0:
            q = -q
        poses_out[i, 3:] = qlog(q)
        t = poses_out[i, :3] - align_t
        poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze()

    # Normalize translation
    poses_out[:, :3] -= mean_t
    poses_out[:, :3] /= std_t

    return poses_out

def load_and_process_poses(data_dir, seqs, train=True, real=False, vo_lib='orbslam'):
    """Load the poses of several sequences and return them as one processed array.

    Args:
        data_dir: Root directory containing one sub-directory per sequence.
        seqs: Iterable of sequence directory names inside ``data_dir``.
        train: With ``real=False``, compute translation mean/std from the
            loaded poses and write them to ``<data_dir>/pose_stats.txt``.
            Otherwise the statistics are read back from that file.
        real: If True, read poses from ``<data_dir>/<vo_lib>_poses/<seq>``
            and alignment statistics from ``<seq_dir>/<vo_lib>_vo_stats.pkl``.
            If False, read ``frame-XXXXXX.pose.txt`` files from the sequence
            directory and use an identity alignment.
        vo_lib: Name used to build the file names in the ``real`` branch.

    Returns:
        Array of shape (total_frames, 6) as produced by ``process_poses``,
        with the sequences stacked in the order of ``seqs``.
    """
    # Local import: `pickle` is needed by the `real` branch but is not among
    # the notebook's imports, which made that branch fail with a NameError.
    import pickle

    ps = {}
    vo_stats = {}
    all_poses = []
    for seq in seqs:
        seq_dir = os.path.join(data_dir, seq)
        p_filenames = [n for n in os.listdir(seq_dir) if n.find('pose') >= 0]
        if real:
            pose_file = os.path.join(data_dir, '{:s}_poses'.format(vo_lib), seq)
            pss = np.loadtxt(pose_file)
            # Column 0 holds the frame index; the pose occupies columns 1-12.
            ps[seq] = pss[:, 1:13]
            vo_stats_filename = os.path.join(seq_dir, '{:s}_vo_stats.pkl'.format(vo_lib))
            # NOTE: pickle.load can execute arbitrary code; only use stats
            # files that come from a trusted source.
            with open(vo_stats_filename, 'rb') as f:
                vo_stats[seq] = pickle.load(f)
        else:
            # Frames are addressed by index 0..len(p_filenames)-1; indices
            # without a matching pose file are skipped.
            candidate_paths = (
                os.path.join(seq_dir, 'frame-{:06d}.pose.txt'.format(i))
                for i in range(len(p_filenames))
            )
            pss = [np.loadtxt(path).flatten()[:12]
                   for path in candidate_paths if os.path.exists(path)]
            ps[seq] = np.asarray(pss)
            vo_stats[seq] = {'R': np.eye(3), 't': np.zeros(3), 's': 1}

        all_poses.append(ps[seq])

    all_poses = np.vstack(all_poses)
    pose_stats_filename = os.path.join(data_dir, 'pose_stats.txt')
    if train and not real:
        mean_t = np.mean(all_poses[:, [3, 7, 11]], axis=0)
        std_t = np.std(all_poses[:, [3, 7, 11]], axis=0)
        np.savetxt(pose_stats_filename, np.vstack((mean_t, std_t)), fmt='%8.7f')
    else:
        mean_t, std_t = np.loadtxt(pose_stats_filename)

    # Process and normalize poses
    processed_poses = [
        process_poses(poses_in=ps[seq], mean_t=mean_t, std_t=std_t,
                      align_R=vo_stats[seq]['R'], align_t=vo_stats[seq]['t'],
                      align_s=vo_stats[seq]['s'])
        for seq in seqs
    ]

    return np.vstack(processed_poses)

# Image preprocessing shared by colour and depth inputs:
# array -> PIL image -> 256x256 resize -> tensor.
transform = transforms.Compose(
    [transforms.ToPILImage(), transforms.Resize((256, 256)), transforms.ToTensor()]
)

# Dataset class
class FireDataset(Dataset):
    """Dataset yielding (colour image, depth image, processed pose) triples.

    ``root_dir`` is scanned for sequence sub-directories. Within each one,
    files ending in ``.color.png``, ``.depth.png`` and ``.pose.txt`` are
    sorted and zipped into samples. Poses are loaded for all sequences through
    ``load_and_process_poses``.
    """

    def __init__(self, root_dir, transform=None):
        """Index the samples under ``root_dir`` and load the processed poses.

        Args:
            root_dir: Directory containing one sub-directory per sequence.
            transform: Optional callable applied to both the colour image and
                the (8-bit rescaled) depth image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.seqs = [seq for seq in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, seq))]
        self.samples = self._load_samples()
        self.processed_poses = self._load_processed_poses()

        # Debugging prints
        print(f"Number of samples: {len(self.samples)}")
        print(f"Number of processed poses: {self.processed_poses.shape[0]}")

        # Ensure consistency between samples and processed poses.
        # NOTE(review): truncating both to the shorter length only trims the
        # tail; if the counts differ inside an earlier sequence, later samples
        # may be paired with poses of a different frame — confirm on the data.
        min_length = min(len(self.samples), self.processed_poses.shape[0])
        self.samples = self.samples[:min_length]
        self.processed_poses = self.processed_poses[:min_length]

    def _load_samples(self):
        """Return a list of (color_path, depth_path, pose_path) tuples."""
        samples = []
        for seq_folder in self.seqs:
            seq_path = os.path.join(self.root_dir, seq_folder)
            color_files = sorted([f for f in os.listdir(seq_path) if f.endswith('.color.png')])
            depth_files = sorted([f for f in os.listdir(seq_path) if f.endswith('.depth.png')])
            pose_files = sorted([f for f in os.listdir(seq_path) if f.endswith('.pose.txt')])

            for color_file, depth_file, pose_file in zip(color_files, depth_files, pose_files):
                samples.append((os.path.join(seq_path, color_file),
                                os.path.join(seq_path, depth_file),
                                os.path.join(seq_path, pose_file)))
        return samples

    def _load_processed_poses(self):
        """Load and process the poses of every sequence in ``self.seqs``."""
        return load_and_process_poses(self.root_dir, self.seqs)

    @staticmethod
    def _depth_to_uint8(depth_image):
        """Rescale a depth map so its maximum becomes 255 and cast to uint8.

        A map whose maximum is zero is returned as all zeros instead of being
        divided by zero (which produced NaNs before the cast).
        """
        peak = depth_image.max()
        if peak <= 0:
            return np.zeros(depth_image.shape, dtype=np.uint8)
        return (depth_image / peak * 255).astype(np.uint8)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(color_image, depth_image, pose)`` for sample ``idx``.

        Raises:
            IndexError: If ``idx`` is beyond the processed poses.
            FileNotFoundError: If an image file cannot be read.
        """
        if idx >= len(self.processed_poses):
            raise IndexError(f"Index {idx} out of bounds for processed poses of size {len(self.processed_poses)}")

        color_path, depth_path, pose_path = self.samples[idx]

        color_image = cv2.imread(color_path, cv2.IMREAD_COLOR)
        depth_image = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising.
        if color_image is None:
            raise FileNotFoundError(f"Could not read color image: {color_path}")
        if depth_image is None:
            raise FileNotFoundError(f"Could not read depth image: {depth_path}")

        # float32 so the collated target matches the dtype of the model output.
        pose_matrix = self.processed_poses[idx].astype(np.float32)

        if self.transform:
            color_image = self.transform(color_image)
            depth_image = self.transform(self._depth_to_uint8(depth_image))

        return color_image, depth_image, pose_matrix


In [None]:
def conv_3x3_bn(inp, oup, image_size, downsample=False):
    """Build a 3x3 convolution -> BatchNorm -> GELU stem block.

    ``image_size`` is accepted for signature parity with the other stage
    blocks and is not used. The convolution uses stride 2 unless
    ``downsample`` compares equal to ``False``.
    """
    if downsample == False:  # noqa: E712 - equality test kept on purpose
        stride = 1
    else:
        stride = 2

    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.GELU(),
    ]
    return nn.Sequential(*layers)

class PreNorm(nn.Module):
    """Apply a normalisation layer to the input before calling ``fn``."""

    def __init__(self, dim, fn, norm):
        """
        Args:
            dim: Size passed to the ``norm`` constructor.
            fn: Callable (usually a module) applied to the normalised input.
            norm: Normalisation layer class, instantiated as ``norm(dim)``.
        """
        super().__init__()
        normalizer = norm(dim)
        self.norm = normalizer
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and forward it, with any keyword arguments, to ``fn``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class SE(nn.Module):
    """Squeeze-and-excitation gate over the channel dimension.

    The gate expects inputs with ``oup`` channels; its bottleneck width is
    ``int(inp * expansion)``.
    """

    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        bottleneck = int(inp * expansion)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(oup, bottleneck, bias=False),
            nn.GELU(),
            nn.Linear(bottleneck, oup, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Scale every channel of ``x`` by a learned gate in (0, 1)."""
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        return self.net(x)


class XMBConv(nn.Module):
    """Two-stream MBConv block.

    The block receives a pair ``[x, y]`` and processes each element with its
    own, identically structured branch (``conv`` for ``x``, ``convx`` for
    ``y``). Each branch is wrapped in a BatchNorm pre-normalisation and added
    to a shortcut: the input itself, or a max-pooled and 1x1-projected input
    when ``downsample`` is set. ``image_size`` is accepted but not used.
    """

    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        if self.downsample == False:  # noqa: E712 - equality test kept on purpose
            stride = 1
        else:
            stride = 2
        hidden_dim = int(inp * expansion)

        if self.downsample:
            # Shortcut path for each stream: halve the resolution, then match channels.
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
            self.poolx = nn.MaxPool2d(3, 2, 1)
            self.projx = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        self.conv = self._build_branch(inp, oup, hidden_dim, stride, expansion)
        self.convx = self._build_branch(inp, oup, hidden_dim, stride, expansion)

    @staticmethod
    def _build_branch(inp, oup, hidden_dim, stride, expansion):
        """Create one pre-normalised convolutional branch."""
        if expansion == 1:
            layers = [
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        else:
            layers = [
                # pw (down-sampling, if any, happens in this first conv)
                nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SE(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        return PreNorm(inp, nn.Sequential(*layers), nn.BatchNorm2d)

    def forward(self, x):
        """Process the ``[x, y]`` pair and return the updated pair as a list."""
        second = x[1]
        first = x[0]

        if self.downsample:
            first_out = self.proj(self.pool(first)) + self.conv(first)
            second_out = self.projx(self.poolx(second)) + self.convx(second)
            return [first_out, second_out]

        first_out = first + self.conv(first)
        second_out = second + self.convx(second)
        return [first_out, second_out]

class XAttention(nn.Module):
    """Multi-head attention with relative position bias and a cross-stream stage.

    ``forward(x, y)`` first runs self-attention on ``x`` (queries, keys and
    values all projected from ``x``). A second attention map is then built
    from queries projected from ``x`` and keys projected from ``y`` and is
    applied to the output of the first stage. Both maps share the same
    learned relative position bias.

    Args:
        inp: Channel size of the incoming tokens.
        oup: Channel size of the output projection.
        image_size: ``(ih, iw)`` of the token grid; the sequence length must
            be ``ih * iw``.
        heads: Number of attention heads.
        dim_head: Channels per head.
        dropout: Dropout probability after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # Explicit 'ij' (row, column) indexing: this is what the index
        # arithmetic below assumes, and passing it avoids the deprecation
        # warning torch.meshgrid emits when the argument is omitted.
        coords = torch.meshgrid(
            torch.arange(self.ih), torch.arange(self.iw), indexing='ij')
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        # Shift offsets to be non-negative, then fold (row, col) offsets into
        # a single index into the bias table.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)
        self.to_kx = nn.Linear(inp, inner_dim * 1, bias=False)
        self.to_qx = nn.Linear(inp, inner_dim * 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, y):
        """Attend over ``x`` and re-weight the result using keys from ``y``.

        Args:
            x: Tokens of the primary stream, shape ``(b, ih * iw, inp)``.
            y: Tokens of the other stream, same shape as ``x``.

        Returns:
            Tensor of shape ``(b, ih * iw, oup)`` (``heads * dim_head``
            channels when no output projection is used).
        """
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        kx = self.to_kx(y)
        qx = self.to_qx(x)

        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        kx = rearrange(kx, 'b n (h d) -> b h n d', h=self.heads)
        qx = rearrange(qx, 'b n (h d) -> b h n d', h=self.heads)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        attn = self.attend(dots)
        out = torch.matmul(attn, v)

        # Cross-stream stage: queries from x, keys from y, applied to `out`.
        dots = torch.matmul(qx, kx.transpose(-1, -2)) * self.scale
        dots = dots + relative_bias
        attn = self.attend(dots)
        out = torch.matmul(attn, out)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class XTransformer(nn.Module):
    """Two-stream transformer block with cross attention between the streams.

    The block receives a pair ``[x, y]`` of feature maps. Each stream is
    layer-normalised as a token sequence, attended with ``XAttention`` (the
    other stream is passed as the second argument), added to a shortcut and
    followed by a residual feed-forward layer. With ``downsample`` set, the
    attention input is max-pooled and the shortcut is max-pooled and
    projected with a 1x1 convolution. ``image_size`` is the ``(ih, iw)`` of
    the block's output feature map.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        hidden_dim = int(inp * 4)

        self.ih, self.iw = image_size
        self.downsample = downsample

        self.layer_norm = nn.LayerNorm(inp)
        self.layer_normx = nn.LayerNorm(inp)

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
            self.pool1x = nn.MaxPool2d(3, 2, 1)
            self.pool2x = nn.MaxPool2d(3, 2, 1)
            self.projx = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        # Sub-modules are created in the order attn, ff, attnx, ffx.
        self.attn = XAttention(inp, oup, image_size, heads, dim_head, dropout)
        self.ff = self._as_map_layer(FeedForward(oup, hidden_dim, dropout), oup)
        self.attnx = XAttention(inp, oup, image_size, heads, dim_head, dropout)
        self.ffx = self._as_map_layer(FeedForward(oup, hidden_dim, dropout), oup)

    def _as_map_layer(self, feed_forward, dim):
        """Wrap a token-wise feed-forward so it consumes and returns feature maps."""
        return nn.Sequential(
            Rearrange('b c ih iw -> b (ih iw) c'),
            PreNorm(dim, feed_forward, nn.LayerNorm),
            Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        )

    def forward(self, x):
        """Process the ``[x, y]`` pair and return the updated pair as a list."""
        second = x[1]
        first = x[0]

        # Attention operates on the pooled maps when down-sampling.
        if self.downsample:
            first_src = self.pool2(first)
        else:
            first_src = first
        first_tokens = self.layer_norm(rearrange(first_src, 'b c ih iw -> b (ih iw) c'))

        if self.downsample:
            second_src = self.pool2x(second)
        else:
            second_src = second
        second_tokens = self.layer_normx(rearrange(second_src, 'b c ih iw -> b (ih iw) c'))

        first_attn = rearrange(self.attn(first_tokens, second_tokens),
                               'b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        second_attn = rearrange(self.attnx(second_tokens, first_tokens),
                                'b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)

        if self.downsample:
            first_out = self.proj(self.pool1(first)) + first_attn
            second_out = self.projx(self.pool1x(second)) + second_attn
        else:
            first_out = first + first_attn
            second_out = second + second_attn

        first_out = first_out + self.ff(first_out)
        second_out = second_out + self.ffx(second_out)
        return [first_out, second_out]


class CoAtXNet(nn.Module):
    """Two-stream CoAtNet-style network regressing a 6-D output.

    Two inputs (``x`` with ``in_channels`` and ``y`` with ``aux_channels``)
    pass through separate convolutional stems and then through four shared
    two-stream stages built from ``XMBConv`` ('C') or ``XTransformer`` ('T')
    blocks. Both streams are average-pooled, concatenated and fed to a
    three-layer MLP head with 6 outputs.

    Args:
        image_size: ``(height, width)`` of the inputs.
        in_channels: Channels of the first input.
        aux_channels: Channels of the second input.
        num_blocks: Five block counts, one for the stem and each stage.
        channels: Five channel widths, one for the stem and each stage.
        block_types: Four entries ('C' or 'T') selecting the block of stages 1-4.
    """

    def __init__(self, image_size, in_channels, aux_channels, num_blocks, channels, block_types=('C', 'C', 'T', 'T')):
        super().__init__()
        ih, iw = image_size
        block = {'C': XMBConv, 'T': XTransformer}

        self.s0 = self._make_layer(
            conv_3x3_bn, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
        self.s0x = self._make_layer(
            conv_3x3_bn, aux_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))

        self.s1 = self._make_layer(
            block[block_types[0]], channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        self.s2 = self._make_layer(
            block[block_types[1]], channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
        self.s3 = self._make_layer(
            block[block_types[2]], channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
        self.s4 = self._make_layer(
            block[block_types[3]], channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))

        # Pool over the full final feature map in BOTH dimensions. A single
        # `ih // 32` kernel left extra columns for non-square inputs, which
        # `view(-1, C)` in forward() then folded into the batch dimension.
        self.pool = nn.AvgPool2d((ih // 32, iw // 32), 1)
        hidden_dim = 1024
        self.fc1 = nn.Linear(channels[-1] * 2, hidden_dim)
        self.relu = nn.ReLU()  # Activation function between dense layers
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)  # Another dense layer
        self.fc3 = nn.Linear(hidden_dim // 2, 6)  # Final output layer adjusted to target size

    def forward(self, x, y):
        """Return a ``(batch, 6)`` tensor for the input pair ``(x, y)``."""
        x = self.s0(x)
        y = self.s0x(y)

        xy = self.s1([x, y])
        xy = self.s2(xy)
        xy = self.s3(xy)
        xy = self.s4(xy)

        x = xy[0]
        y = xy[1]

        x = self.pool(x).view(-1, x.shape[1])
        y = self.pool(y).view(-1, y.shape[1])

        x = torch.cat((x, y), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x

    def _make_layer(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks; only the first changes channels and down-samples."""
        layers = []
        for i in range(depth):
            if i == 0:
                layers.append(block(inp, oup, image_size, downsample=True))
            else:
                layers.append(block(oup, oup, image_size))
        return nn.Sequential(*layers)



def coatxnet_0():
    """Build the CoAtXNet-0 variant for 256x256 inputs (3-channel + 1-channel)."""
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 3, 5, 2],        # L
        channels=[64, 96, 192, 384, 768],  # D
    )


def coatxnet_1():
    """Build the CoAtXNet-1 variant for 256x256 inputs (3-channel + 1-channel)."""
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 6, 14, 2],       # L
        channels=[64, 96, 192, 384, 768],  # D
    )


def coatxnet_2():
    """Build the CoAtXNet-2 variant for 256x256 inputs (3-channel + 1-channel)."""
    # NOTE(review): the last width, 1026, looks like it may be a typo for 1024;
    # confirm before changing, since it alters the parameter count.
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 6, 14, 2],          # L
        channels=[128, 128, 256, 512, 1026],  # D
    )


def coatxnet_3():
    """Build the CoAtXNet-3 variant for 256x256 inputs (3-channel + 1-channel)."""
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 6, 14, 2],          # L
        channels=[192, 192, 384, 768, 1536],  # D
    )


def coatxnet_4():
    """Build the CoAtXNet-4 variant for 256x256 inputs (3-channel + 1-channel)."""
    return CoAtXNet(
        image_size=(256, 256),
        in_channels=3,
        aux_channels=1,
        num_blocks=[2, 2, 12, 28, 2],         # L
        channels=[192, 192, 384, 768, 1536],  # D
    )


def count_parameters(model):
    """Return the total number of elements in the trainable parameters of ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


if __name__ == '__main__':
    # Smoke test: push one random colour/depth pair through every variant and
    # print the output shape together with the trainable parameter count.
    img = torch.randn(1, 3, 256, 256)
    aux = torch.randn(1, 1, 256, 256)

    for build_variant in (coatxnet_0, coatxnet_1, coatxnet_2, coatxnet_3, coatxnet_4):
        net = build_variant()
        out = net(img, aux)
        print(out.shape, count_parameters(net))


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([1, 6]) 39113334
torch.Size([1, 6]) 73447686
torch.Size([1, 6]) 120817534
torch.Size([1, 6]) 249078022
torch.Size([1, 6]) 432611814


In [None]:
# 5-fold cross-validated training of coatxnet_0 on the dataset stored under
# `root_dir`. Every fold builds a fresh model, optimizer and LR scheduler and
# trains for `num_epochs` epochs; no checkpoint is written by this cell.
if __name__ == '__main__':
    root_dir = '/content/drive/MyDrive/fire'
    dataset = FireDataset(root_dir=root_dir, transform=transform)

    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    num_epochs = 100  # Define the number of epochs

    # Check for GPU availability
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f'Fold {fold + 1}')

        # Both loaders wrap the same dataset; the samplers restrict each one
        # to the indices of its split.
        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)

        train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler, num_workers=4)
        val_loader = DataLoader(dataset, batch_size=16, sampler=val_sampler, num_workers=4)

        model = coatxnet_0().to(device)  # Move model to the selected device
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        # Multiply the learning rate by 0.1 once the monitored value has not
        # improved for more than `patience` epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, factor=0.1)

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            train_loss = 0.0  # sum of per-batch losses for this epoch
            for color_images, depth_images, poses in tqdm(train_loader):
                color_images, depth_images, poses = color_images.to(device), depth_images.to(device), poses.to(device)  # Move data to the selected device
                optimizer.zero_grad()
                outputs = model(color_images, depth_images)
                # Note the argument order: targets first, predictions second.
                loss = ModifiedDSACLoss(poses, outputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # Validation loop
            model.eval()
            val_loss = 0.0  # sum of per-batch losses for this epoch
            with torch.no_grad():
                for color_images, depth_images, poses in val_loader:
                    color_images, depth_images, poses = color_images.to(device), depth_images.to(device), poses.to(device)  # Move data to the selected device
                    outputs = model(color_images, depth_images)
                    loss = ModifiedDSACLoss(poses, outputs)
                    val_loss += loss.item()

            # Step the scheduler (on the summed validation loss; the values
            # printed below are per-batch means)
            scheduler.step(val_loss)

            print(f'Epoch {epoch + 1}, Train Loss: {train_loss / len(train_loader)}, Val Loss: {val_loss / len(val_loader)}')



Number of samples: 4003
Number of processed poses: 4000
Using device: cuda
Fold 1


100%|██████████| 200/200 [01:57<00:00,  1.71it/s]


Epoch 1, Train Loss: 1.993081813863203, Val Loss: 1.342734135745442


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 2, Train Loss: 0.9658291913098104, Val Loss: 0.8676206565434796


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 3, Train Loss: 0.6989103474238176, Val Loss: 0.6251437482967169


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 4, Train Loss: 0.5218564901809246, Val Loss: 0.5353583551580716


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 5, Train Loss: 0.47359679377628566, Val Loss: 0.3963516964535389


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 6, Train Loss: 0.4002213074400416, Val Loss: 0.3980103263006832


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 7, Train Loss: 0.3714294430822656, Val Loss: 0.31901823638348614


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 8, Train Loss: 0.37567142263951275, Val Loss: 0.41692652170206235


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 9, Train Loss: 0.3490380869266771, Val Loss: 0.3640785355383345


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 10, Train Loss: 0.3196897188636759, Val Loss: 0.35211216782080534


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 11, Train Loss: 0.2924918472776177, Val Loss: 0.36402951486870466


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 12, Train Loss: 0.2852795043122164, Val Loss: 0.29195322906530025


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 13, Train Loss: 0.2906443618856737, Val Loss: 0.2704663847746235


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 14, Train Loss: 0.2764799658818994, Val Loss: 0.29301771710626917


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 15, Train Loss: 0.25589874241065264, Val Loss: 0.3417427917358372


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 16, Train Loss: 0.24686881488549342, Val Loss: 0.28533759292339755


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 17, Train Loss: 0.23762950547745323, Val Loss: 0.2537030939050054


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 18, Train Loss: 0.23330379756050834, Val Loss: 0.2934001698328361


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 19, Train Loss: 0.2517409859986225, Val Loss: 0.2737914588499213


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 20, Train Loss: 0.23013957228681828, Val Loss: 0.29402469039640006


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 21, Train Loss: 0.2359200574850113, Val Loss: 0.2199113802446532


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 22, Train Loss: 0.22428912503266538, Val Loss: 0.25954228140969


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 23, Train Loss: 0.21367057965250866, Val Loss: 0.21425162465194791


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 24, Train Loss: 0.20972587798832426, Val Loss: 0.2388325921145366


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 25, Train Loss: 0.2201044695925851, Val Loss: 0.20651749247688514


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 26, Train Loss: 0.20449698102541883, Val Loss: 0.20830642411844968


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 27, Train Loss: 0.1854488479469725, Val Loss: 0.2296237209787755


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 28, Train Loss: 0.19679306952951425, Val Loss: 0.2627188230917013


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 29, Train Loss: 0.18635946726460573, Val Loss: 0.2267239670553307


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 30, Train Loss: 0.19041061615504365, Val Loss: 0.21375600145215812


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 31, Train Loss: 0.10609335090851339, Val Loss: 0.10304507703529485


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 32, Train Loss: 0.07590008388287368, Val Loss: 0.10393898522327094


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 33, Train Loss: 0.06849639585851264, Val Loss: 0.09298749940944873


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 34, Train Loss: 0.06320208568505167, Val Loss: 0.10071684801213739


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 35, Train Loss: 0.06516647013169892, Val Loss: 0.08934462650201025


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 36, Train Loss: 0.058511080965204625, Val Loss: 0.08968720868974644


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 37, Train Loss: 0.05586276073274659, Val Loss: 0.08942123949896948


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 38, Train Loss: 0.05123415451348163, Val Loss: 0.08537135678727908


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 39, Train Loss: 0.053600711801501694, Val Loss: 0.0914419219188082


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 40, Train Loss: 0.050306716097115525, Val Loss: 0.08810283674984634


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 41, Train Loss: 0.05138737884284542, Val Loss: 0.09413731636673596


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 42, Train Loss: 0.049518169076566226, Val Loss: 0.09221198389882351


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 43, Train Loss: 0.047928958023687816, Val Loss: 0.08957258236612368


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 44, Train Loss: 0.040228569963272405, Val Loss: 0.0819921978770416


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 45, Train Loss: 0.03664603921198504, Val Loss: 0.08177976414612113


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 46, Train Loss: 0.036029112830419104, Val Loss: 0.08108790803523096


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 47, Train Loss: 0.036356574069241, Val Loss: 0.08590253675021442


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 48, Train Loss: 0.03531358744276075, Val Loss: 0.08444169286490705


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 49, Train Loss: 0.03473574723663344, Val Loss: 0.08304508347640883


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 50, Train Loss: 0.03521336427091646, Val Loss: 0.08441893229803071


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 51, Train Loss: 0.034095362612645654, Val Loss: 0.0921290936057079


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 52, Train Loss: 0.03420557604355843, Val Loss: 0.08769997796257727


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 53, Train Loss: 0.033861859940441476, Val Loss: 0.0883238790022642


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 54, Train Loss: 0.03321892847650239, Val Loss: 0.08636812211821603


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 55, Train Loss: 0.03304571201169126, Val Loss: 0.08688399399052896


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 56, Train Loss: 0.03250617886805878, Val Loss: 0.08848238178189258


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 57, Train Loss: 0.032950720485830524, Val Loss: 0.08525164992022022


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 58, Train Loss: 0.03328098749833187, Val Loss: 0.08387427329635161


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 59, Train Loss: 0.03277448180586397, Val Loss: 0.08261084336711381


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 60, Train Loss: 0.033032194197522784, Val Loss: 0.08497118037378437


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 61, Train Loss: 0.03268516572174336, Val Loss: 0.08559771737193528


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 62, Train Loss: 0.0323632280385949, Val Loss: 0.08491846867818338


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 63, Train Loss: 0.0332257199608498, Val Loss: 0.08325519242233949


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 64, Train Loss: 0.03243675097173406, Val Loss: 0.0863402939974001


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 65, Train Loss: 0.03287939306210186, Val Loss: 0.08562019890632727


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 66, Train Loss: 0.033142621614470784, Val Loss: 0.08274469177555285


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 67, Train Loss: 0.03277209845026597, Val Loss: 0.08252571427305382


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 68, Train Loss: 0.03212446193995691, Val Loss: 0.08396444337845976


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 69, Train Loss: 0.03251794536220134, Val Loss: 0.0939030239655982


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 70, Train Loss: 0.03240819663988994, Val Loss: 0.08409225073513431


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 71, Train Loss: 0.03336278001725261, Val Loss: 0.08156803157646608


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 72, Train Loss: 0.032415017988924134, Val Loss: 0.08736801877397067


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 73, Train Loss: 0.03256668335260484, Val Loss: 0.08532750975007886


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 74, Train Loss: 0.032734897776716794, Val Loss: 0.08462067024701143


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 75, Train Loss: 0.03274727325242893, Val Loss: 0.08610068802337123


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 76, Train Loss: 0.03291786886134636, Val Loss: 0.08688459577987866


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 77, Train Loss: 0.03265877287095932, Val Loss: 0.08658923176238435


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 78, Train Loss: 0.03279096791466047, Val Loss: 0.08505740182855465


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 79, Train Loss: 0.03242267338098398, Val Loss: 0.08551828428382896


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 80, Train Loss: 0.03295913569043028, Val Loss: 0.08551890742097956


100%|██████████| 200/200 [01:58<00:00,  1.69it/s]


Epoch 81, Train Loss: 0.03285468637969745, Val Loss: 0.08354877143805188


  8%|▊         | 16/200 [00:11<02:08,  1.43it/s]


KeyboardInterrupt: 

## Chess

In [None]:
# Colab-only: mount Google Drive at /content/drive so files stored there can be
# read by path from later cells.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# %pip targets the running kernel's environment; pinned to the version this cell last installed.
%pip install transforms3d==0.4.2

Collecting transforms3d
  Downloading transforms3d-0.4.2-py3-none-any.whl (1.4 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/1.4 MB[0m [31m4.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.4/1.4 MB[0m [31m5.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━[0m [32m0.7/1.4 MB[0m [31m6.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━[0m [32m1.0/1.4 MB[0m [31m7.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.4/1.4 MB[0m [31m8.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Installing coll

In [None]:
# %pip targets the running kernel's environment; pinned to the version this cell last installed.
%pip install einops==0.8.0

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [None]:
import os
import time
import random
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation, ToTensor
from torchvision.transforms.functional import to_tensor
from torchvision.utils import make_grid
from einops import rearrange
from einops.layers.torch import Rearrange
from sklearn.model_selection import KFold
from tqdm import tqdm
import transforms3d.quaternions as txq
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
def ModifiedDSACLoss(y_true, y_pred, weight_translation_initial=1.0, weight_quaternion_initial=1.0):
    """Pose loss combining translation and log-quaternion errors with ratio-based weights.

    The last dimension of both inputs must hold six values: the first three are
    treated as the translation, the last three as the log-quaternion.

    Args:
        y_true: Ground-truth poses, shape ``(..., 6)``.
        y_pred: Predicted poses, same shape as ``y_true``.
        weight_translation_initial: Base weight of the translation term.
        weight_quaternion_initial: Base weight of the rotation term.

    Returns:
        Scalar tensor: ``w_t * mean_t_error + w_q * mean_q_error`` where
        ``w_t = weight_translation_initial * ratio``,
        ``w_q = weight_quaternion_initial / ratio`` and
        ``ratio = mean_t_error / mean_q_error`` (both divisions epsilon-guarded).
    """
    eps = 1e-8  # keeps both divisions below finite

    # Split the true and predicted values into positions and logarithmic quaternions
    true_positions, true_log_quaternions = torch.split(y_true, [3, 3], dim=-1)
    pred_positions, pred_log_quaternions = torch.split(y_pred, [3, 3], dim=-1)

    # Mean per-sample L2 distance between true and predicted translations
    t_error = torch.sqrt(torch.sum((true_positions - pred_positions) ** 2, dim=-1))
    mean_t_error = torch.mean(t_error)

    # Mean per-sample L2 distance between true and predicted log-quaternions
    q_error = torch.sqrt(torch.sum((true_log_quaternions - pred_log_quaternions) ** 2, dim=-1))
    mean_q_error = torch.mean(q_error)

    # Dynamic weight adjustment based on the error ratio.
    error_ratio = mean_t_error / (mean_q_error + eps)
    weight_translation = weight_translation_initial * error_ratio
    # Without the epsilon a zero translation error gives ratio == 0 and the
    # rotation weight becomes inf (and the loss NaN when the rotation error is 0 too).
    weight_quaternion = weight_quaternion_initial / (error_ratio + eps)

    # Apply dynamic weights to the respective errors and combine them
    weighted_t_error = weight_translation * mean_t_error
    weighted_q_error = weight_quaternion * mean_q_error

    return weighted_t_error + weighted_q_error


In [None]:

def qlog(q):
    """Return the 3-vector logarithm of a quaternion given as ``[w, x, y, z]``.

    The input is scaled to unit length first. The result is the vector part
    rescaled so that its length equals ``arctan2(|vector part|, w)``. When the
    vector part's norm is not above 1e-6 a zero vector is returned instead.
    """
    unit_q = q / np.linalg.norm(q)
    scalar_part = unit_q[0]
    vector_part = unit_q[1:]
    vector_norm = np.linalg.norm(vector_part)

    if vector_norm > 1e-6:
        half_angle = np.arctan2(vector_norm, scalar_part)
        return half_angle * vector_part / vector_norm

    # (Near-)identity rotation: the direction is undefined, so use zeros.
    return np.zeros(3)

def process_poses(poses_in, mean_t, std_t, align_R, align_t, align_s):
    """Convert flattened 3x4 pose rows into normalised 6-D targets.

    Args:
        poses_in: Array of shape ``(N, 12)``; each row is reshaped to a 3x4
            matrix whose left 3x3 block is used as the rotation and whose last
            column (flat indices 3, 7, 11) is used as the translation.
        mean_t: Value subtracted from the aligned translations.
        std_t: Value the centred translations are divided by.
        align_R: 3x3 matrix applied to both rotation and translation.
        align_t: Offset subtracted from the translation before ``align_R``.
        align_s: Scale applied to the aligned translation.

    Returns:
        Array of shape ``(N, 6)``: columns 0-2 are the normalised translation,
        columns 3-5 the log-quaternion of the aligned rotation.
    """
    poses_out = np.zeros((len(poses_in), 6))
    poses_out[:, 0:3] = poses_in[:, [3, 7, 11]]  # Translation components

    # Align and process rotation
    for i in range(len(poses_out)):
        R = poses_in[i].reshape((3, 4))[:3, :3]
        q = txq.mat2quat(np.dot(align_R, R))
        # Constrain to the w >= 0 hemisphere. An explicit flip is used instead of
        # `q *= np.sign(q[0])`: np.sign(0) is 0, which would wipe out the whole
        # quaternion when w == 0 and make qlog() divide by a zero norm.
        if q[0] < 0:
            q = -q
        poses_out[i, 3:] = qlog(q)
        t = poses_out[i, :3] - align_t
        poses_out[i, :3] = align_s * np.dot(align_R, t[:, np.newaxis]).squeeze()

    # Normalize translation
    poses_out[:, :3] -= mean_t
    poses_out[:, :3] /= std_t

    return poses_out

def load_and_process_poses(data_dir, seqs, train=True, real=False, vo_lib='orbslam'):
    """Load the poses of several sequences and return them as normalised 6-D targets.

    Args:
        data_dir: Directory containing one sub-directory per sequence.
        seqs: Sequence directory names, processed in the given order.
        train: With ``real=False``, ``True`` computes the translation mean/std
            from the loaded poses and writes them to ``pose_stats.txt`` in
            ``data_dir``; otherwise the statistics are read from that file.
        real: If ``True``, poses come from ``<data_dir>/<vo_lib>_poses/<seq>``
            and alignment parameters from ``<seq_dir>/<vo_lib>_vo_stats.pkl``.
            If ``False``, poses come from the ``frame-XXXXXX.pose.txt`` files
            and the alignment is the identity.
        vo_lib: Name used to build the two ``real=True`` file names.

    Returns:
        Array of shape ``(total_frames, 6)`` with all sequences stacked in the
        order of ``seqs`` (see ``process_poses`` for the column layout).
    """
    # Local import: pickle is only needed for real=True and is not imported in
    # the notebook's import cell, so the original code raised NameError there.
    import pickle

    ps = {}
    vo_stats = {}
    all_poses = []
    for seq in seqs:
        seq_dir = os.path.join(data_dir, seq)
        if real:
            pose_file = os.path.join(data_dir, '{:s}_poses'.format(vo_lib), seq)
            pss = np.loadtxt(pose_file)
            # Column 0 appears to be a frame index (not needed here);
            # columns 1..12 hold the flattened 3x4 pose.
            ps[seq] = pss[:, 1:13]
            vo_stats_filename = os.path.join(seq_dir, '{:s}_vo_stats.pkl'.format(vo_lib))
            # SECURITY: pickle.load can execute arbitrary code; only use stats
            # files from a trusted source.
            with open(vo_stats_filename, 'rb') as f:
                vo_stats[seq] = pickle.load(f)
        else:
            # Upper bound on the frame index: number of directory entries whose
            # name contains 'pose'. Indices without a pose file are skipped.
            n_candidates = sum(1 for name in os.listdir(seq_dir) if 'pose' in name)
            candidate_paths = (
                os.path.join(seq_dir, 'frame-{:06d}.pose.txt'.format(i))
                for i in range(n_candidates)
            )
            pss = [np.loadtxt(path).flatten()[:12]
                   for path in candidate_paths if os.path.exists(path)]
            ps[seq] = np.asarray(pss)
            vo_stats[seq] = {'R': np.eye(3), 't': np.zeros(3), 's': 1}

        all_poses.append(ps[seq])

    all_poses = np.vstack(all_poses)
    pose_stats_filename = os.path.join(data_dir, 'pose_stats.txt')
    if train and not real:
        translations = all_poses[:, [3, 7, 11]]
        mean_t = np.mean(translations, axis=0)
        std_t = np.std(translations, axis=0)
        np.savetxt(pose_stats_filename, np.vstack((mean_t, std_t)), fmt='%8.7f')
    else:
        mean_t, std_t = np.loadtxt(pose_stats_filename)

    # Process and normalize poses, one sequence at a time
    processed_poses = [
        process_poses(poses_in=ps[seq], mean_t=mean_t, std_t=std_t,
                      align_R=vo_stats[seq]['R'], align_t=vo_stats[seq]['t'],
                      align_s=vo_stats[seq]['s'])
        for seq in seqs
    ]

    return np.vstack(processed_poses)

# Image preprocessing pipeline: ndarray -> PIL image -> 256x256 resize -> tensor.
transform = Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    ToTensor(),
])

# Dataset class
class FireDataset(Dataset):
    """RGB-D frames paired with 6-D pose targets, read from a directory of sequences.

    Every sub-directory of ``root_dir`` is treated as one sequence containing
    ``*.color.png``, ``*.depth.png`` and ``*.pose.txt`` files. Each item is a
    ``(color_image, depth_image, pose)`` triple, where ``pose`` is the row of
    the array returned by ``load_and_process_poses`` at the same index.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Directory whose sub-directories are the sequences.
            transform: Optional callable applied to the colour image and to the
                depth image (after the depth image is rescaled to uint8).
        """
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so the sample order, and therefore any index-based split such
        # as KFold, does not depend on the arbitrary order of os.listdir.
        self.seqs = sorted(seq for seq in os.listdir(root_dir)
                           if os.path.isdir(os.path.join(root_dir, seq)))
        self.samples = self._load_samples()
        self.processed_poses = self._load_processed_poses()

        # Debugging prints
        print(f"Number of samples: {len(self.samples)}")
        print(f"Number of processed poses: {self.processed_poses.shape[0]}")

        # Ensure consistency between samples and processed poses.
        # NOTE(review): this only trims the tail. If the counts differ because
        # of an extra or missing file in the middle of a sequence, later samples
        # may be paired with the wrong pose - confirm against the data.
        min_length = min(len(self.samples), self.processed_poses.shape[0])
        self.samples = self.samples[:min_length]
        self.processed_poses = self.processed_poses[:min_length]

    def _load_samples(self):
        """Return ``(color_path, depth_path, pose_path)`` tuples for all sequences.

        Within a sequence the three file lists are sorted by name and zipped,
        so the shortest list determines how many samples the sequence yields.
        """
        samples = []
        for seq_folder in self.seqs:
            seq_path = os.path.join(self.root_dir, seq_folder)
            names = os.listdir(seq_path)
            color_files = sorted(f for f in names if f.endswith('.color.png'))
            depth_files = sorted(f for f in names if f.endswith('.depth.png'))
            pose_files = sorted(f for f in names if f.endswith('.pose.txt'))

            for color_file, depth_file, pose_file in zip(color_files, depth_files, pose_files):
                samples.append((os.path.join(seq_path, color_file),
                                os.path.join(seq_path, depth_file),
                                os.path.join(seq_path, pose_file)))
        return samples

    def _load_processed_poses(self):
        """Load the pose targets for all sequences via ``load_and_process_poses``."""
        return load_and_process_poses(self.root_dir, self.seqs)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(color_image, depth_image, pose)`` for sample ``idx``.

        Raises:
            IndexError: If ``idx`` is beyond the processed poses.
            OSError: If OpenCV cannot read the colour or depth image.
        """
        if idx >= len(self.processed_poses):
            raise IndexError(f"Index {idx} out of bounds for processed poses of size {len(self.processed_poses)}")

        color_path, depth_path, pose_path = self.samples[idx]

        color_image = cv2.imread(color_path, cv2.IMREAD_COLOR)
        depth_image = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None instead of raising.
        if color_image is None:
            raise OSError(f"Could not read color image: {color_path}")
        if depth_image is None:
            raise OSError(f"Could not read depth image: {depth_path}")
        pose_matrix = self.processed_poses[idx]

        if self.transform:
            color_image = self.transform(color_image)
            # Rescale depth by its own maximum to 0..255; an all-zero depth map
            # stays zero instead of dividing by zero.
            depth_max = depth_image.max()
            if depth_max > 0:
                depth_image = (depth_image / depth_max * 255).astype(np.uint8)
            else:
                depth_image = np.zeros(depth_image.shape, dtype=np.uint8)
            depth_image = self.transform(depth_image)

        return color_image, depth_image, pose_matrix


In [None]:
def conv_3x3_bn(inp, oup, image_size, downsample=False):
    """Build a 3x3 convolution -> batch norm -> GELU block.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        image_size: Not used by this block.
        downsample: When it compares equal to ``False`` the convolution uses
            stride 1, otherwise stride 2.

    Returns:
        ``nn.Sequential`` of the three layers.
    """
    # `!= False` mirrors an equality test against False (not a truthiness test).
    stride = 2 if downsample != False else 1
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Sequential(conv, nn.BatchNorm2d(oup), nn.GELU())

class PreNorm(nn.Module):
    """Run a normalisation layer and feed its output to a wrapped callable.

    Args:
        dim: Argument passed to the ``norm`` factory.
        fn: Module/callable applied to the normalised input.
        norm: Factory called as ``norm(dim)`` to build the normalisation layer.
    """

    def __init__(self, dim, fn, norm):
        super().__init__()
        self.norm = norm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Extra keyword arguments are forwarded to the wrapped callable only.
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class SE(nn.Module):
    """Channel gating: global average pool -> two bias-free linear layers -> sigmoid.

    Args:
        inp: Used only to size the bottleneck, ``int(inp * expansion)``.
        oup: Channel count of the tensor passed to ``forward``.
        expansion: Bottleneck ratio relative to ``inp``.
    """

    def __init__(self, inp, oup, expansion=0.25):
        super().__init__()
        bottleneck = int(inp * expansion)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(oup, bottleneck, bias=False),
            nn.GELU(),
            nn.Linear(bottleneck, oup, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale each channel of a 4-D input by its gate value in (0, 1)."""
        batch, channels, _, _ = x.shape
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Size of the input and output features (last dimension).
        hidden_dim: Size of the intermediate features.
        dropout: Dropout probability used after each linear stage.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class XMBConv(nn.Module):
    """Two parallel MBConv-style branches, one per input stream.

    ``forward`` takes a pair ``[x, y]`` and returns a pair; ``x`` goes through
    the ``conv``/``pool``/``proj`` modules and ``y`` through the ``*x``
    counterparts. The two streams do not interact inside this block.

    Args:
        inp: Input channels of both streams.
        oup: Output channels of both streams.
        image_size: Not used by this block.
        downsample: If truthy, the residual path is max-pool + 1x1 projection
            and the first convolution of each branch uses stride 2.
        expansion: Hidden width is ``int(inp * expansion)``; ``1`` selects the
            shorter branch without the pointwise expansion and without SE.
    """

    def __init__(self, inp, oup, image_size, downsample=False, expansion=4):
        super().__init__()
        self.downsample = downsample
        # Equality test against False (not truthiness) decides the stride.
        stride = 2 if self.downsample != False else 1
        hidden_dim = int(inp * expansion)

        if self.downsample:
            self.pool = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
            self.poolx = nn.MaxPool2d(3, 2, 1)
            self.projx = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        def build_branch():
            """Create one freshly initialised convolution stack."""
            if expansion == 1:
                return nn.Sequential(
                    # dw
                    nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                              1, groups=hidden_dim, bias=False),
                    nn.BatchNorm2d(hidden_dim),
                    nn.GELU(),
                    # pw-linear
                    nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                    nn.BatchNorm2d(oup),
                )
            return nn.Sequential(
                # pw (down-sampling, if any, happens in this first conv)
                nn.Conv2d(inp, hidden_dim, 1, stride, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1,
                          groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SE(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

        # Both stacks are created before wrapping, main stream first.
        main_stack = build_branch()
        aux_stack = build_branch()
        self.conv = PreNorm(inp, main_stack, nn.BatchNorm2d)
        self.convx = PreNorm(inp, aux_stack, nn.BatchNorm2d)

    def forward(self, x):
        """Apply the residual block to each element of the pair ``x``."""
        main, aux = x[0], x[1]

        if self.downsample:
            main_out = self.proj(self.pool(main)) + self.conv(main)
            aux_out = self.projx(self.poolx(aux)) + self.convx(aux)
        else:
            main_out = main + self.conv(main)
            aux_out = aux + self.convx(aux)
        return [main_out, aux_out]

class XAttention(nn.Module):
    """Multi-head self-attention on ``x`` followed by a cross-stream re-mixing.

    ``forward(x, y)`` first computes ordinary self-attention over ``x`` with a
    learned relative position bias. A second attention map is then built from
    queries of ``x`` and keys of ``y`` (same bias) and applied to the
    self-attention output.

    Args:
        inp: Feature size of the incoming tokens.
        oup: Feature size after the output projection.
        image_size: ``(ih, iw)``; the token count must equal ``ih * iw``.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout probability after the output projection.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == inp)

        self.ih, self.iw = image_size

        self.heads = heads
        self.scale = dim_head ** -0.5

        # parameter table of relative position bias
        self.relative_bias_table = nn.Parameter(
            torch.zeros((2 * self.ih - 1) * (2 * self.iw - 1), heads))

        # indexing='ij' is what the un-annotated call defaulted to; stating it
        # explicitly avoids the torch.meshgrid warning without changing values.
        coords = torch.meshgrid(
            torch.arange(self.ih), torch.arange(self.iw), indexing='ij')
        coords = torch.flatten(torch.stack(coords), 1)
        relative_coords = coords[:, :, None] - coords[:, None, :]

        # Shift offsets to be non-negative and fold (dy, dx) into one table row.
        relative_coords[0] += self.ih - 1
        relative_coords[1] += self.iw - 1
        relative_coords[0] *= 2 * self.iw - 1
        relative_coords = rearrange(relative_coords, 'c h w -> h w c')
        relative_index = relative_coords.sum(-1).flatten().unsqueeze(1)
        self.register_buffer("relative_index", relative_index)

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(inp, inner_dim * 3, bias=False)
        self.to_kx = nn.Linear(inp, inner_dim * 1, bias=False)
        self.to_qx = nn.Linear(inp, inner_dim * 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, oup),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, y):
        """Attend over ``x`` using ``y`` as the source of the second key set.

        Args:
            x: Tokens of shape ``(b, n, inp)`` with ``n == ih * iw``.
            y: Tokens of the other stream, same shape as ``x``.

        Returns:
            Tensor of shape ``(b, n, oup)`` (``(b, n, heads * dim_head)`` when
            the output projection is the identity).
        """
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        kx = self.to_kx(y)
        qx = self.to_qx(x)

        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        kx = rearrange(kx, 'b n (h d) -> b h n d', h=self.heads)
        qx = rearrange(qx, 'b n (h d) -> b h n d', h=self.heads)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Use "gather" for more efficiency on GPUs
        relative_bias = self.relative_bias_table.gather(
            0, self.relative_index.repeat(1, self.heads))
        relative_bias = rearrange(
            relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
        dots = dots + relative_bias

        # Stage 1: self-attention within x.
        attn = self.attend(dots)
        out = torch.matmul(attn, v)

        # Stage 2: weights from x-queries against y-keys, applied to the stage-1 output.
        dots = torch.matmul(qx, kx.transpose(-1, -2)) * self.scale
        dots = dots + relative_bias
        attn = self.attend(dots)
        out = torch.matmul(attn, out)

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class XTransformer(nn.Module):
    """Transformer block over a pair of feature maps with cross-stream attention.

    ``forward`` takes ``[x, y]`` and returns a pair. Each stream is
    layer-normalised as a token sequence, attended with the other stream's
    tokens through its own ``XAttention``, added to a residual path and passed
    through a residual feed-forward stage.

    Args:
        inp: Input channels of both streams.
        oup: Output channels of both streams.
        image_size: ``(ih, iw)`` of the feature maps this block produces.
        heads: Attention heads.
        dim_head: Feature size per head.
        downsample: If truthy, inputs are max-pooled (stride 2) before
            attention and the residual path is max-pool + 1x1 projection.
        dropout: Dropout probability for the attention and feed-forward parts.
    """

    def __init__(self, inp, oup, image_size, heads=8, dim_head=32, downsample=False, dropout=0.):
        super().__init__()
        self.ih, self.iw = image_size
        self.downsample = downsample

        self.layer_norm = nn.LayerNorm(inp)
        self.layer_normx = nn.LayerNorm(inp)

        if self.downsample:
            self.pool1 = nn.MaxPool2d(3, 2, 1)
            self.pool2 = nn.MaxPool2d(3, 2, 1)
            self.proj = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
            self.pool1x = nn.MaxPool2d(3, 2, 1)
            self.pool2x = nn.MaxPool2d(3, 2, 1)
            self.projx = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)

        ff_hidden = int(inp * 4)

        def as_map_stage(feed_forward):
            """Wrap a token-wise feed-forward so it consumes and returns (b, c, ih, iw) maps."""
            return nn.Sequential(
                Rearrange('b c ih iw -> b (ih iw) c'),
                PreNorm(oup, feed_forward, nn.LayerNorm),
                Rearrange('b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
            )

        # Creation order (attn, ff, attnx, ffx) is kept so parameter
        # initialisation and registration order stay the same.
        self.attn = XAttention(inp, oup, image_size, heads, dim_head, dropout)
        self.ff = as_map_stage(FeedForward(oup, ff_hidden, dropout))
        self.attnx = XAttention(inp, oup, image_size, heads, dim_head, dropout)
        self.ffx = as_map_stage(FeedForward(oup, ff_hidden, dropout))

    def forward(self, x):
        """Process the pair ``x = [main, aux]`` and return the updated pair."""
        main, aux = x[0], x[1]

        if self.downsample:
            main_src, aux_src = self.pool2(main), self.pool2x(aux)
            main_skip = self.proj(self.pool1(main))
            aux_skip = self.projx(self.pool1x(aux))
        else:
            main_src, aux_src = main, aux
            main_skip, aux_skip = main, aux

        # Flatten each map to a token sequence and normalise it.
        main_tokens = self.layer_norm(rearrange(main_src, 'b c ih iw -> b (ih iw) c'))
        aux_tokens = self.layer_normx(rearrange(aux_src, 'b c ih iw -> b (ih iw) c'))

        # Each stream attends with the other stream's tokens as second argument.
        main_attn = rearrange(self.attn(main_tokens, aux_tokens),
                              'b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)
        aux_attn = rearrange(self.attnx(aux_tokens, main_tokens),
                             'b (ih iw) c -> b c ih iw', ih=self.ih, iw=self.iw)

        out1 = main_skip + main_attn
        out2 = aux_skip + aux_attn

        out1 = out1 + self.ff(out1)
        out2 = out2 + self.ffx(out2)
        return [out1, out2]


class CoAtXNet(nn.Module):
    """Two-stream CoAtNet-style backbone with a 6-value regression head.

    The main and auxiliary inputs each pass through their own convolutional
    stem (``s0`` / ``s0x``); stages ``s1``-``s4`` then process the pair
    jointly. Both final maps are average-pooled, concatenated and fed to a
    three-layer MLP.

    Args:
        image_size: ``(ih, iw)`` of the inputs; stage sizes are derived by
            integer division by 2, 4, 8, 16 and 32.
        in_channels: Channels of the main input.
        aux_channels: Channels of the auxiliary input.
        num_blocks: Five block counts, one per stage ``s0``..``s4``.
        channels: Five output widths, one per stage ``s0``..``s4``.
        block_types: Four entries, ``'C'`` (XMBConv) or ``'T'`` (XTransformer),
            for stages ``s1``..``s4``.
    """

    # Tuple default instead of a list so the default can never be mutated.
    def __init__(self, image_size, in_channels, aux_channels, num_blocks, channels, block_types=('C', 'C', 'T', 'T')):
        super().__init__()
        ih, iw = image_size
        block = {'C': XMBConv, 'T': XTransformer}

        self.s0 = self._make_layer(
            conv_3x3_bn, in_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))
        self.s0x = self._make_layer(
            conv_3x3_bn, aux_channels, channels[0], num_blocks[0], (ih // 2, iw // 2))

        self.s1 = self._make_layer(
            block[block_types[0]], channels[0], channels[1], num_blocks[1], (ih // 4, iw // 4))
        self.s2 = self._make_layer(
            block[block_types[1]], channels[1], channels[2], num_blocks[2], (ih // 8, iw // 8))
        self.s3 = self._make_layer(
            block[block_types[2]], channels[2], channels[3], num_blocks[3], (ih // 16, iw // 16))
        self.s4 = self._make_layer(
            block[block_types[3]], channels[3], channels[4], num_blocks[4], (ih // 32, iw // 32))

        # Kernel covers the whole final map in both dimensions, so the pooled
        # output is 1x1 for non-square inputs too (identical for square ones).
        self.pool = nn.AvgPool2d((ih // 32, iw // 32), 1)
        hidden_dim = 1024
        self.fc1 = nn.Linear(channels[-1] * 2, hidden_dim)  # *2: both streams are concatenated
        self.relu = nn.ReLU()  # Activation function between dense layers
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)  # Another dense layer
        self.fc3 = nn.Linear(hidden_dim // 2, 6)  # Final output layer adjusted to target size

    def forward(self, x, y):
        """Return a ``(batch, 6)`` prediction from main input ``x`` and auxiliary input ``y``."""
        x = self.s0(x)
        y = self.s0x(y)

        xy = self.s1([x, y])
        xy = self.s2(xy)
        xy = self.s3(xy)
        xy = self.s4(xy)

        x = xy[0]
        y = xy[1]

        x = self.pool(x).view(-1, x.shape[1])
        y = self.pool(y).view(-1, y.shape[1])

        x = torch.cat((x, y), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x

    def _make_layer(self, block, inp, oup, depth, image_size):
        """Stack ``depth`` blocks; only the first changes channels and downsamples."""
        layers = [
            block(inp, oup, image_size, downsample=True) if i == 0
            else block(oup, oup, image_size)
            for i in range(depth)
        ]
        return nn.Sequential(*layers)



def coatxnet_0():
    """Build CoAtXNet variant 0 for 256x256 inputs (3-channel main, 1-channel auxiliary)."""
    return CoAtXNet(
        (256, 256), 3, 1,
        [2, 2, 3, 5, 2],            # L: blocks per stage
        [64, 96, 192, 384, 768],    # D: channels per stage
    )


def coatxnet_1():
    """Build CoAtXNet variant 1 for 256x256 inputs (3-channel main, 1-channel auxiliary)."""
    return CoAtXNet(
        (256, 256), 3, 1,
        [2, 2, 6, 14, 2],           # L: blocks per stage
        [64, 96, 192, 384, 768],    # D: channels per stage
    )


def coatxnet_2():
    """Build CoAtXNet variant 2 for 256x256 inputs (3-channel main, 1-channel auxiliary)."""
    # NOTE(review): the last width is 1026, not 1024 - kept as-is; confirm it is intended.
    return CoAtXNet(
        (256, 256), 3, 1,
        [2, 2, 6, 14, 2],               # L: blocks per stage
        [128, 128, 256, 512, 1026],     # D: channels per stage
    )


def coatxnet_3():
    """Build CoAtXNet variant 3 for 256x256 inputs (3-channel main, 1-channel auxiliary)."""
    return CoAtXNet(
        (256, 256), 3, 1,
        [2, 2, 6, 14, 2],               # L: blocks per stage
        [192, 192, 384, 768, 1536],     # D: channels per stage
    )


def coatxnet_4():
    """Build CoAtXNet variant 4 for 256x256 inputs (3-channel main, 1-channel auxiliary)."""
    return CoAtXNet(
        (256, 256), 3, 1,
        [2, 2, 12, 28, 2],              # L: blocks per stage
        [192, 192, 384, 768, 1536],     # D: channels per stage
    )


def count_parameters(model):
    """Return the number of scalar elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


if __name__ == '__main__':
    # Smoke test: push one random 256x256 input pair through every variant and
    # print the output shape together with the trainable-parameter count.
    img = torch.randn(1, 3, 256, 256)
    aux = torch.randn(1, 1, 256, 256)

    for build_net in (coatxnet_0, coatxnet_1, coatxnet_2, coatxnet_3, coatxnet_4):
        net = build_net()
        out = net(img, aux)
        print(out.shape, count_parameters(net))


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([1, 6]) 39113334
torch.Size([1, 6]) 73447686
torch.Size([1, 6]) 120817534
torch.Size([1, 6]) 249078022
torch.Size([1, 6]) 432611814


In [None]:
# 5-fold cross-validated training of coatxnet_0 on the Chess data.
# Every fold starts from a freshly constructed model, optimizer and scheduler.
# Nothing in this cell saves model weights or the per-fold results.
if __name__ == '__main__':
    # FireDataset is reused unchanged; only the root directory points at Chess.
    root_dir = '/content/drive/MyDrive/Chess'
    dataset = FireDataset(root_dir=root_dir, transform=transform)

    # Fixed random_state makes the index split repeatable for a given dataset order.
    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    num_epochs = 100  # Define the number of epochs

    # Check for GPU availability
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f'Fold {fold + 1}')

        # Both loaders draw from the same dataset object; the samplers restrict
        # each loader to its fold's indices (visited in random order).
        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(val_idx)

        train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler, num_workers=4)
        val_loader = DataLoader(dataset, batch_size=16, sampler=val_sampler, num_workers=4)

        model = coatxnet_0().to(device)  # Move model to GPU
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        # Multiplies the learning rate by 0.1 once the monitored value has not
        # improved for more than `patience` epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, factor=0.1)

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            train_loss = 0.0  # sum of per-batch losses for this epoch
            for color_images, depth_images, poses in tqdm(train_loader):
                color_images, depth_images, poses = color_images.to(device), depth_images.to(device), poses.to(device)  # Move data to GPU
                optimizer.zero_grad()
                outputs = model(color_images, depth_images)
                # Argument order is (y_true, y_pred).
                loss = ModifiedDSACLoss(poses, outputs)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # Validation loop
            model.eval()
            val_loss = 0.0  # sum of per-batch losses over the validation fold
            with torch.no_grad():
                for color_images, depth_images, poses in val_loader:
                    color_images, depth_images, poses = color_images.to(device), depth_images.to(device), poses.to(device)  # Move data to GPU
                    outputs = model(color_images, depth_images)
                    loss = ModifiedDSACLoss(poses, outputs)
                    val_loss += loss.item()

            # Step the scheduler (on the summed validation loss, not the per-batch mean)
            scheduler.step(val_loss)

            # Reported values are per-batch means.
            print(f'Epoch {epoch + 1}, Train Loss: {train_loss / len(train_loader)}, Val Loss: {val_loss / len(val_loader)}')



Number of samples: 6003
Number of processed poses: 6000
Using device: cuda
Fold 1


100%|██████████| 300/300 [02:54<00:00,  1.72it/s]


Epoch 1, Train Loss: 1.7983068812204634, Val Loss: 1.4698871979458124


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 2, Train Loss: 0.9421943558084785, Val Loss: 0.621728616601762


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 3, Train Loss: 0.6453091052652233, Val Loss: 0.47830394957867695


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 4, Train Loss: 0.5295329828460815, Val Loss: 0.5127048591807912


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 5, Train Loss: 0.44878955994185576, Val Loss: 0.5002289101562173


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 6, Train Loss: 0.40433002019266556, Val Loss: 0.38171451737208967


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 7, Train Loss: 0.37682791260871634, Val Loss: 0.46488227879369715


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 8, Train Loss: 0.3375694711015532, Val Loss: 0.32010125273413054


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 9, Train Loss: 0.3236744967858662, Val Loss: 0.3064340718462233


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 10, Train Loss: 0.3152138344676683, Val Loss: 0.31025875052596935


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 11, Train Loss: 0.2903263163669485, Val Loss: 0.34930042338157796


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 12, Train Loss: 0.2771455389924778, Val Loss: 0.32009038828569203


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 13, Train Loss: 0.2766239274392061, Val Loss: 0.2816841714004821


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 14, Train Loss: 0.28384906457551684, Val Loss: 0.2899942185679435


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 15, Train Loss: 0.2672019974024281, Val Loss: 0.32893296217596024


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 16, Train Loss: 0.24190573395040066, Val Loss: 0.27922463266839714


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 17, Train Loss: 0.23784741596224204, Val Loss: 0.2665744481486065


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 18, Train Loss: 0.2364740354158783, Val Loss: 0.2717955419074725


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 19, Train Loss: 0.22966669940638634, Val Loss: 0.286048843583338


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 20, Train Loss: 0.2452580335200247, Val Loss: 0.24978358723181732


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 21, Train Loss: 0.22208745801344013, Val Loss: 0.28093981285650726


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 22, Train Loss: 0.22738439782185316, Val Loss: 0.2507329718137206


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 23, Train Loss: 0.2277851046847933, Val Loss: 0.2662957851023282


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 24, Train Loss: 0.2097722825143312, Val Loss: 0.18204084159957337


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 25, Train Loss: 0.20902300397036847, Val Loss: 0.2449037877469742


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 26, Train Loss: 0.20628602131351195, Val Loss: 0.21894727092730434


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 27, Train Loss: 0.19574622298211422, Val Loss: 0.20710562725831907


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 28, Train Loss: 0.2096861280418125, Val Loss: 0.22225067662837816


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 29, Train Loss: 0.19787219847965934, Val Loss: 0.1932280211849779


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 30, Train Loss: 0.11657965745547157, Val Loss: 0.12061688874799312


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 31, Train Loss: 0.08611123544110139, Val Loss: 0.10679693840827414


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 32, Train Loss: 0.07501723822659047, Val Loss: 0.09911411896525857


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 33, Train Loss: 0.06829326855165758, Val Loss: 0.10010879600071208


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 34, Train Loss: 0.06424999502313687, Val Loss: 0.09418810316499163


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 35, Train Loss: 0.05967541590167159, Val Loss: 0.10237924102935016


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 36, Train Loss: 0.0592864251222336, Val Loss: 0.0998707430578007


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 37, Train Loss: 0.056142614318329676, Val Loss: 0.09967093110245899


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 38, Train Loss: 0.05504282905962176, Val Loss: 0.09499499764996822


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 39, Train Loss: 0.052190239585114234, Val Loss: 0.09741960331192379


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 40, Train Loss: 0.04506234755249034, Val Loss: 0.09165270497516992


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 41, Train Loss: 0.04258542775758136, Val Loss: 0.09090732329613284


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 42, Train Loss: 0.0419270329048948, Val Loss: 0.09237736261048883


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 43, Train Loss: 0.0409978322385846, Val Loss: 0.09175909764266278


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 44, Train Loss: 0.04060836570261072, Val Loss: 0.0944516489216771


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 45, Train Loss: 0.04054346183201674, Val Loss: 0.09243858590243616


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 46, Train Loss: 0.03934732701928314, Val Loss: 0.09594620462516788


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 47, Train Loss: 0.038965218373618916, Val Loss: 0.09299864035541411


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 48, Train Loss: 0.038903577993509984, Val Loss: 0.0947981290370632


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 49, Train Loss: 0.03846752406118333, Val Loss: 0.0932763530198907


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 50, Train Loss: 0.03883842599549363, Val Loss: 0.09166460834471213


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 51, Train Loss: 0.03892337336567418, Val Loss: 0.09352934731650298


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 52, Train Loss: 0.03875948257969771, Val Loss: 0.09425068324497454


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 53, Train Loss: 0.03847804934818474, Val Loss: 0.09735323239567086


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 54, Train Loss: 0.038865883784084614, Val Loss: 0.09491084218989614


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 55, Train Loss: 0.038573385501272914, Val Loss: 0.09324473874595204


100%|██████████| 300/300 [03:02<00:00,  1.64it/s]


Epoch 56, Train Loss: 0.038861309975708726, Val Loss: 0.09432581787635395


100%|██████████| 300/300 [03:02<00:00,  1.64it/s]


Epoch 57, Train Loss: 0.038752527312994285, Val Loss: 0.09164700935142493


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 58, Train Loss: 0.038937592787670505, Val Loss: 0.09431615262441444


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 59, Train Loss: 0.038577241060484976, Val Loss: 0.0904695927998115


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 60, Train Loss: 0.03865823851288361, Val Loss: 0.09264367550349277


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 61, Train Loss: 0.038487190739091404, Val Loss: 0.09382611500906297


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 62, Train Loss: 0.03882631533905583, Val Loss: 0.09367172937260865


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 63, Train Loss: 0.03875736012357891, Val Loss: 0.09153322261124316


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 64, Train Loss: 0.03864737864728714, Val Loss: 0.09558250196593691


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 65, Train Loss: 0.038791682140708, Val Loss: 0.09338097344607302


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 66, Train Loss: 0.0383913537124795, Val Loss: 0.0960997957864358


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 67, Train Loss: 0.03878554185190682, Val Loss: 0.0915024559700136


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 68, Train Loss: 0.03874986984652797, Val Loss: 0.09152825627449503


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 69, Train Loss: 0.0389417280916881, Val Loss: 0.0947799199651748


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 70, Train Loss: 0.03794767185616998, Val Loss: 0.09383560374319139


100%|██████████| 300/300 [03:02<00:00,  1.64it/s]


Epoch 71, Train Loss: 0.038870321080480726, Val Loss: 0.0925121464245651


100%|██████████| 300/300 [03:02<00:00,  1.64it/s]


Epoch 72, Train Loss: 0.03816235509890276, Val Loss: 0.09293557604426486


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 73, Train Loss: 0.038292031135935586, Val Loss: 0.09286262149707225


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 74, Train Loss: 0.03916591256709015, Val Loss: 0.09293019448695947


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 75, Train Loss: 0.038803353310441185, Val Loss: 0.0932691456649693


100%|██████████| 300/300 [03:00<00:00,  1.66it/s]


Epoch 76, Train Loss: 0.038471819164891945, Val Loss: 0.09443484264506707


100%|██████████| 300/300 [03:01<00:00,  1.66it/s]


Epoch 77, Train Loss: 0.038471983438920554, Val Loss: 0.09327671034705852


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 78, Train Loss: 0.03843073467175622, Val Loss: 0.09280733350266689


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 79, Train Loss: 0.03791804730324913, Val Loss: 0.09318139773115641


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 80, Train Loss: 0.03847005410679672, Val Loss: 0.09338204968889283


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 81, Train Loss: 0.03816146563955275, Val Loss: 0.09399395356088697


100%|██████████| 300/300 [03:01<00:00,  1.65it/s]


Epoch 82, Train Loss: 0.03878569026385626, Val Loss: 0.09230499797317407


  1%|          | 3/300 [00:03<04:19,  1.14it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7af4780be560>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 43, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 27, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt: 
  1%|          | 3/300 [00:03<06:03,  1.22s/it]


KeyboardInterrupt: 