In [None]:
import os
import time
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch.utils.data as data
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision.datasets import VisionDataset
from torchvision import tv_tensors
import torchvision.transforms.v2 as v2
from torch.cuda.amp import autocast, GradScaler
from dataclasses import dataclass

from transformers import SegformerModel

from Depth2HHA_python_master.getHHA import wrap_getHHA
from SegFormer_block import Block, OverlapPatchMerging

  from .autonotebook import tqdm as notebook_tqdm


## Preprocess

In [None]:
# カラーマップ生成関数：セグメンテーションの可視化用
def colormap(N=256, normalized=False):
    def bitget(byteval, idx):
        return ((byteval & (1 << idx)) != 0)

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << 7-j)
            g = g | (bitget(c, 1) << 7-j)
            b = b | (bitget(c, 2) << 7-j)
            c = c >> 3

        cmap[i] = np.array([r, g, b])

    cmap = cmap/255 if normalized else cmap
    return cmap


# NYUv2データセット：RGB画像、セグメンテーション、深度、法線マップを提供するデータセット
class NYUv2(Dataset):
    """NYUv2 dataset

    Args:
        root (string): Root directory path.
        split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'.
        target_type (string, optional): Type of target to use, ``semantic``, ``depth``.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    cmap = colormap()
    def __init__(self,
                root: str,
                config,
                split='train',
                ):
        super().__init__()

        # データセットの基本設定
        assert(split in ('train', 'test'))
        self.root = root
        self.split = split
        self.train_idx = np.array([255, ] + list(range(13)))  # 13クラス分類用
        self.config = config

        # 画像ファイルのパスリストを作成
        img_names = os.listdir(os.path.join(self.root, self.split, 'image'))
        img_names.sort()
        images_dir = os.path.join(self.root, self.split, 'image')
        self.images = [os.path.join(images_dir, name) for name in img_names]

        label_dir = os.path.join(self.root, self.split, 'label')
        if (self.split == 'train'):
            self.targets = [os.path.join(label_dir, name) for name in img_names]

        depth_dir = os.path.join(self.root, self.split, 'depth')
        self.depths = [os.path.join(depth_dir, name) for name in img_names]

    @property
    def transform(self):
        return v2.Compose([
            v2.RandomHorizontalFlip(0.5),
            v2.RandomRotation((-10, 10)),
            v2.RandomResizedCrop(self.config.image_size),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @property
    def color_jitter(self):  # RGBデータのみに適用
        return v2.Compose([
            v2.ColorJitter()
        ])

    @property
    def test_transform(self):
        return v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, idx):
        if self.split == 'train':
            image = Image.open(self.images[idx])
            depth = Image.open(self.depths[idx])
            target = Image.open(self.targets[idx])

            image = v2.functional.pil_to_tensor(image).to(torch.float32) / 255.0
            depth = v2.functional.pil_to_tensor(depth).to(torch.float32) / 65535.0
            target = v2.functional.pil_to_tensor(target).to(torch.int32)

            # HHA encodingする
            depth = np.array(depth)
            hha_depth = wrap_getHHA(depth)

            # (B, C, H, W)の形でラップする
            image = tv_tensors.Image(torch.from_numpy(np.array(image)) / 255.0)
            depth = tv_tensors.Image(torch.from_numpy(hha_depth).permute(0, 3, 1, 2))
            target = tv_tensors.Mask(target)

            image, depth, target = self.transform(image, depth, target)
            image = self.color_jitter(image)

            return image, depth, target

        if self.split=='test':
            image = Image.open(self.images[idx])
            depth = Image.open(self.depths[idx])

            # HHA encodingする
            depth = np.array(depth)
            hha_depth = wrap_getHHA(depth)

            image = tv_tensors.Image(torch.from_numpy(np.array(image)) / 255.0)
            depth = tv_tensors.Image(torch.from_numpy(hha_depth).permute(0, 3, 1, 2))

            image = self.test_transform(image)
            depth = self.test_transform(depth)
            return image, depth

    def __len__(self):
        return len(self.images)

## Build model

RuntimeError: Given groups=1, weight of size [80, 160, 1, 1], expected input[4, 256, 1, 1] to have 160 channels, but got 256 channels instead

In [2]:
class DCMAF(nn.Module):
    def __init__(self, num_channels: int):
        super().__init__()
        """
        Discriminative Cross-Modal Attention Fusion (DCMAF) Module. 

        Args:
            num_channels (int): channels of feature map
        """
        super().__init__()
        self.rgb_process = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(num_channels, num_channels//2, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(num_channels//2, num_channels, kernel_size=1),
        )
        self.d_process = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(num_channels, num_channels//2, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(num_channels//2, num_channels, kernel_size=1),
        )

    def forward(self, rgb_map: torch.Tensor, depth_map: torch.Tensor):
        """

        Args:
            rgb_map (torch.Tensor): (B, C, H, W)
            depth_map (torch.Tensor): (B, C, H, W)

        Returns:
            torch.Tensor: (B, C, H, W)
        """
        print(f"{rgb_map.size()=}")
        print(f"{depth_map.size()=}")
        rgb_vec = self.rgb_process(rgb_map)
        depth_vec = self.d_process(depth_map)
        w = F.softmax(rgb_vec - depth_vec, dim=1)
        return w * rgb_map + (1-w) * depth_map


class MyEncoder(nn.Module):
    def __init__(self, RGB_model_name, D_model_name, dims, reductions, depths):
        super().__init__()
        self.RGB_model = SegformerModel.from_pretrained(
            RGB_model_name
        )
        self.D_model = SegformerModel.from_pretrained(
            D_model_name
        )
        self.RGB_backbone = self.RGB_model.encoder
        self.D_backbone = self.D_model.encoder

        # stage 2
        self.dcmaf2 = DCMAF(dims[1])
        self.blocks2 = nn.ModuleList([Block(dims[1], reduction=reductions[1]) for _ in range(depths[1])])
        self.patch_merge2 = OverlapPatchMerging(dims[1], dims[2], padding=1, stride=2, kernel=3)

        # stage 3
        self.dcmaf3 = DCMAF(dims[2])
        self.blocks3 = nn.ModuleList([Block(dims[2], reduction=reductions[2]) for _ in range(depths[2])])
        self.patch_merge3 = OverlapPatchMerging(dims[2], dims[3], padding=1, stride=2, kernel=3)

        # stage 4
        self.dcmaf4 = DCMAF(dims[3])
        self.blocks4 = nn.ModuleList([Block(dims[3], reduction=reductions[3]) for _ in range(depths[3])])
        self.patch_merge4 = OverlapPatchMerging(dims[3], dims[4], padding=1, stride=2, kernel=3)

        # stage 5
        # ここのpatch_mergeはfeature mapのサイズを変えないようにする
        print(f"{dims[4]=}")
        self.dcmaf5 = DCMAF(dims[4])
        self.blocks5 = nn.ModuleList([Block(dims[4], reduction=reductions[4]) for _ in range(depths[4])])

    def forward(self, RGB_image, D_image):
        outputs = []
        RGB_outputs = self.RGB_backbone(
            pixel_values=RGB_image,
            output_hidden_states=True
        )
        D_outputs = self.D_backbone(
            pixel_values=D_image,
            output_hidden_states=True
        )

        RGB_hidden_states = RGB_outputs.hidden_states
        D_hidden_states = D_outputs.hidden_states

        # stage 2
        hidden_state: torch.Tensor = self.dcmaf2(RGB_hidden_states[0], D_hidden_states[0])
        _, _, h, w = hidden_state.size()
        x = hidden_state.flatten(2).permute(0, 2, 1)
        for block in self.blocks2:
            x = block(x, h, w)
        b, l, dim = x.size()
        x = x.reshape(b, h, w, dim).permute(0, 3, 1, 2)
        outputs.append(x)
        print(f"in stage 2(before patch merge)\n{x.size()=}")
        x, h, w = self.patch_merge2(x)
        print(f"in stage 2(after patch merge)\n{x.size()=}\n{h=}\n{w=}")

        # stage 3
        hidden_state = self.dcmaf3(RGB_hidden_states[1], D_hidden_states[1])
        _, _, h, w = hidden_state.size()
        hidden_state = hidden_state.flatten(2).permute(0, 2, 1)
        print(f"{x.size()=}")
        print(f"{hidden_state.size()=}")
        x += hidden_state
        for block in self.blocks3:
            x = block(x, h, w)
        b, l, dim = x.size()
        x = x.reshape(b, h, w, dim).permute(0, 3, 1, 2)
        outputs.append(x)
        x, h, w = self.patch_merge3(x)

        # stage 4
        hidden_state = self.dcmaf4(RGB_hidden_states[2], D_hidden_states[2])
        _, _, h, w = hidden_state.size()
        hidden_state = hidden_state.flatten(2).permute(0, 2, 1)
        x += hidden_state
        for block in self.blocks4:
            x = block(x, h, w)
        b, l, dim = x.size()
        x = x.reshape(b, h, w, dim).permute(0, 3, 1, 2)
        outputs.append(x)
        x, h, w = self.patch_merge4(x)

        # stage 5
        print("stage 5")
        hidden_state = self.dcmaf5(RGB_hidden_states[3], D_hidden_states[3])
        _, _, h, w = hidden_state.size()
        hidden_state = hidden_state.flatten(2).permute(0, 2, 1)
        x += hidden_state
        for block in self.blocks5:
            x = block(x, h, w)
        b, l, dim = x.size()
        x = x.reshape(b, h, w, dim).permute(0, 3, 1, 2)
        outputs.append(x)
        return outputs


In [27]:
import gc
del encoder
gc.collect()

11231

In [None]:
encoder = MyEncoder(config.rgb_model, config.depth_model, config.dims, config.reductions, config.depths)
dummy_rgb = torch.rand((4, 3, 480, 640))
dummy_depth = torch.rand((4, 3, 480, 640))

outputs = encoder(dummy_rgb, dummy_depth)

In [8]:
for hidden_state in outputs:
    print(f"{hidden_state.size()=}")

hidden_state.size()=torch.Size([4, 32, 120, 160])
hidden_state.size()=torch.Size([4, 64, 60, 80])
hidden_state.size()=torch.Size([4, 160, 30, 40])
hidden_state.size()=torch.Size([4, 256, 15, 20])


In [3]:
class MyDecoder(nn.Module):
    def __init__(self, hid_channel, num_classes, dims, image_size: tuple[int, int]):
        # TODO
        # ここの次元の調整などはまだ見てない
        super().__init__()
        self.linear1 = nn.Conv2d(dims[1], hid_channel, 1)
        self.linear2 = nn.Conv2d(dims[2], hid_channel, 1)
        self.linear3 = nn.Conv2d(dims[3], hid_channel, 1)
        self.linear4 = nn.Conv2d(dims[4], hid_channel, 1)

        self.upsample = nn.Upsample(image_size)

        self.all_linear = nn.Conv2d(4 * hid_channel, hid_channel, 1)
        self.classify = nn.Conv2d(hid_channel, num_classes, 1)

    def forward(self, outputs: list[torch.Tensor]):
        # upsampleに渡すときはCが前
        """
        MLP onlyのdecoder

        Args:
            outputs (list[torch.Tensor]): それぞれのテンソルのサイズは(B, C, H, W)

        Returns:
            torch.Tensor: (B, H, W, C)
        """
        f1 = self.upsample(self.linear1(outputs[0]))
        f2 = self.upsample(self.linear2(outputs[1]))
        f3 = self.upsample(self.linear3(outputs[2]))
        f4 = self.upsample(self.linear4(outputs[3]))
        F = self.all_linear(torch.cat((f1, f2, f3, f4), dim=1))
        result = self.classify(F).permute(0, 2, 3, 1)
        return result


In [None]:
decoder = MyDecoder(768, 13, config.dims, (480, 640))
pred = decoder(outputs)

## Self-Supervised Learning

In [None]:
import gc
gc.collect()

7

In [None]:
Depth_model = SegformerModel.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
)
Depth_encoder = Depth_model.encoder
image = torch.Tensor(np.array(Image.open("./data/train/image/000002.png")))
print(image.size())
image = image.unsqueeze(0).permute(0, 3, 1, 2)
print(image.size())

def random_indexes(size: int):
    """
    パッチをランダムに並べ替えるためのindexを生成する関数．

    Argument
    --------
    size : int
        入力されるパッチの数（系列長Nと同じ値）．
    """
    forward_indexes = np.arange(size)  # 0からsizeまでを並べた配列を作成
    np.random.shuffle(forward_indexes)  # 生成した配列をシャッフルすることで，パッチの順番をランダムに決定
    backward_indexes = np.argsort(forward_indexes)  # 並べ替えたパッチをもとの順番に戻すためのidx

    return forward_indexes, backward_indexes


def take_indexes(sequences: torch.Tensor, indexes: torch.Tensor):
    """
    sequences: (B, C, N, patch_H, patch_W)
    indexes: (B, N)
    """
    B, C, N, patch_H, patch_W = sequences.size()
    return torch.gather(sequences, dim=2, index=indexes.reshape(B, 1, N, 1, 1).repeat(1, C, 1, patch_H, patch_W).to(torch.int).to(torch.long))


class MaskImage(nn.Module):
    def __init__(self, image_size=(480, 640), patch_size=(40, 40), mask_ratio=0.5):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.image_size = image_size
        self.patch_size = patch_size
        self.patch_num = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        self.patch_len = self.patch_num[0] * self.patch_num[1]
        self.mask = nn.Parameter(torch.randn((1, 1, 1, patch_size[0], patch_size[1])))  # (1, 1, 1, patch_H, patch_W)

    def forward(self, image: torch.Tensor):
        """
        マスクしてSSL用のインプット画像を作る

        Args:
            image (torch.Tensor): (B, C, H, W)
        """
        b, c, h, w = image.size()
        remain_N = int(self.patch_len * (1 - self.mask_ratio))
        image_patch = image.reshape((b, c, self.patch_num[0], self.patch_size[0], self.patch_num[1], self.patch_size[1]))
        print(image.size())
        image_patch = image_patch.permute(0, 1, 2, 4, 3, 5).flatten(2, 3)
        print(image.size())
        indexes = [random_indexes(self.patch_len) for _ in range(b)]
        forward_indexes = torch.Tensor(np.array([indexes[i][0] for i in range(b)]))
        backward_indexes = torch.Tensor(np.array([indexes[i][1] for i in range(b)]))
        shuffled_image_patch = take_indexes(image_patch, forward_indexes)  # (B, C, N, patch_H, patch_W)
        shuffled_image_patch = shuffled_image_patch[:, :, :remain_N, :, :]
        shuffled_image_patch = torch.cat((shuffled_image_patch, self.mask.repeat(b, c, self.patch_len - remain_N, 1, 1)), dim=2)
        masked_image_patch = take_indexes(shuffled_image_patch, backward_indexes)  # (B, C, N, patch_H, patch_W)
        masked_image = masked_image_patch.reshape(b, c, self.patch_num[0], self.patch_num[1], self.patch_size[0], self.patch_size[1])
        masked_image = masked_image.permute((0, 1, 2, 4, 3, 5)).reshape(b, c, h, w)
        return masked_image


mask_image = MaskImage()
image = image / 255
masked_image: torch.Tensor = mask_image(image)
masked_image = (masked_image * 255).squeeze().permute(1, 2, 0)
print(f"{masked_image.size()=}")
pil_img = Image.fromarray(masked_image.to(torch.uint8).detach().numpy())
pil_img.save("masked_image_000002.png")

torch.Size([480, 640, 3])
torch.Size([1, 3, 480, 640])
torch.Size([1, 3, 480, 640])
torch.Size([1, 3, 480, 640])
masked_image.size()=torch.Size([480, 640, 3])


## Train and Valid

In [None]:
# config
@dataclass
class TrainingConfig:
    # データセットパス
    dataset_root: str = "data"

    # データ関連
    batch_size: int = 32
    num_workers: int = 4

    # モデル関連
    in_channels: int = 3
    num_classes: int = 13  # NYUv2データセットの場合

    # 学習関連
    epochs: int = 100
    learning_rate: float = 0.001
    weight_decay: float = 1e-4

    # データ分割関連
    train_val_split: float = 0.8  # 訓練データの割合

    # デバイス設定
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # チェックポイント関連
    checkpoint_dir: str = "checkpoints"
    save_interval: int = 5  # エポックごとのモデル保存間隔

    # データ拡張・前処理関連
    image_size: tuple = (256, 256)
    normalize_mean: tuple = (0.485, 0.456, 0.406)  # ImageNetの標準化パラメータ
    normalize_std: tuple = (0.229, 0.224, 0.225)
    
    # 学習のハイパーパラメータ
    rgb_model: str = "nvidia/segformer-b0-finetuned-ade-512-512"
    depth_model: str = "nvidia/segformer-b0-finetuned-ade-512-512"
    dims: tuple = (3, 32, 64, 160, 256, 256)
    reductions: tuple = (64, 16, 4, 1, 1)
    depths: tuple = (3, 4, 18, 18, 3)  # mit_b3を参考にした
    decoder_hiddim: int  # mit_b3を参考にした

    def __post_init__(self):
        import os
        os.makedirs(self.checkpoint_dir, exist_ok=True)

In [None]:
config = TrainingConfig(
    dataset_root='data',
    batch_size=16,
    num_workers=4,
    learning_rate=1e-4,
    epochs=100,
    image_size=(480, 640),
    in_channels=3,

    rgb_model = "nvidia/segformer-b0-finetuned-ade-512-512",
    depth_model = "nvidia/segformer-b0-finetuned-ade-512-512",
    dims = (3, 32, 64, 160, 256, 256),
    reductions = (64, 16, 4, 1, 1),
    depths = (1, 1, 1, 1, 1),
    decoder_hiddim = 768
)

In [None]:
train_dataset = NYUv2(root=config.dataset_root, split="train", config=config)
test_dataset = NYUv2(root=config.dataset_root, split="test", config=config)
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)