In [1]:
import os

# 1. Clone the NEU_Seg repository into the current working directory.
# NOTE(review): `git clone` refuses to clone into an existing non-empty
# directory, so re-running this cell prints an error and leaves the existing
# checkout untouched.
!git clone https://github.com/DHW-Master/NEU_Seg.git

# 2. データのパスを確認
# NEU_Seg内には 'images' フォルダに画像が含まれています
dataset_root = "NEU_Seg/images"

if os.path.exists(dataset_root):
    print("Dataset successfully downloaded.")
    # フォルダ内のファイル数を表示
    files = [f for f in os.listdir(dataset_root) if f.endswith('.jpg')]
    print(f"Total images found: {len(files)}")
    # ファイル名の例を表示（Cr, In, Pa, PS, Sc, RSなどの略称が含まれます）
    # 例: Cr_1.jpg (Cracks: ひび割れ)
    print("Example filenames:", files[:5])
else:
    print("Failed to download the dataset.")

Cloning into 'NEU_Seg'...
remote: Enumerating objects: 8947, done.[K
remote: Counting objects: 100% (2/2), done.[K
remote: Compressing objects: 100% (2/2), done.[K
remote: Total 8947 (delta 0), reused 0 (delta 0), pack-reused 8945 (from 1)[K
Receiving objects: 100% (8947/8947), 35.08 MiB | 11.53 MiB/s, done.
Dataset successfully downloaded.
Total images found: 0
Example filenames: []


```
├──./datasets/
│    ├── Dataset_name
│    │    ├── train
│    │    │    ├── images
│    │    │    │    ├── 000001.jpg
│    │    │    │    ├── ...
│    │    │    ├── masks
│    │    │    │    ├── 000001.png
│    │    │    │    ├── ...
│    │    ├── test
│    │    │    ├── images
│    │    │    │    ├── 000002.jpg
│    │    │    │    ├── ...
│    │    │    ├── masks
│    │    │    │    ├── 000002.png
│    │    │    │    ├── ...
```

In [16]:
# Delete any previously organized copy so the next cell rebuilds it from scratch.
# NOTE(review): this absolute path only matches the relative
# "./datasets/NEU_Seg_Custom" used in the next cell when the working
# directory is /content (the Colab default) — confirm in other environments.
!rm -rf /content/datasets

In [17]:
import os
import shutil
from glob import glob

# 1. Config: define the new directory layout.
base_dir = "./datasets/NEU_Seg_Custom"
# Every split ("train", "test") gets an images/ and a masks/ sub-folder.
structure = [f"{split}/{kind}" for split in ("train", "test") for kind in ("images", "masks")]

for sub_dir in structure:
    target_dir = os.path.join(base_dir, sub_dir)
    os.makedirs(target_dir, exist_ok=True)

# 2. Source locations inside the cloned repository.
# Relative paths so they match where the first cell cloned the repo
# (`git clone` into the current working directory, checked there as
# "NEU_Seg/images"). With the working directory at /content — presumably
# the Colab default — these resolve to the former /content/NEU_Seg/... paths.
src_img_dir = os.path.join("NEU_Seg", "images")
src_ann_dir = os.path.join("NEU_Seg", "annotations")

def _find_mask(ann_root, split_name, img_name):
    """Return the path of the ``.png`` mask belonging to ``img_name``, or ``None``.

    The annotation folder of the image's own split is checked first, so that a
    same-named mask from a different split is never paired with this image.
    The remaining locations are the fallbacks the lookup always used.
    """
    mask_name = os.path.splitext(img_name)[0] + ".png"
    candidate_dirs = [
        os.path.join(ann_root, split_name),
        ann_root,
        os.path.join(ann_root, "test"),
        os.path.join(ann_root, "train"),
        os.path.join(ann_root, "training"),
    ]
    # dict.fromkeys drops the repeated split folder while preserving order.
    for mask_dir in dict.fromkeys(candidate_dirs):
        mask_path = os.path.join(mask_dir, mask_name)
        if os.path.exists(mask_path):
            return mask_path
    return None


def _copy_pairs(pairs, dest_dir):
    """Copy (image, mask) pairs into ``dest_dir/images`` and ``dest_dir/masks``."""
    img_dst = os.path.join(dest_dir, "images")
    mask_dst = os.path.join(dest_dir, "masks")
    os.makedirs(img_dst, exist_ok=True)
    os.makedirs(mask_dst, exist_ok=True)
    for img_src, mask_src in pairs:
        shutil.copy(img_src, os.path.join(img_dst, os.path.basename(img_src)))
        shutil.copy(mask_src, os.path.join(mask_dst, os.path.basename(mask_src)))


def organize_neu_data(train_ratio=0.8, img_root=None, ann_root=None, out_root=None):
    """Copy image/mask pairs from the cloned NEU_Seg repo into ``out_root/{train,test}``.

    Each source split folder of the repository ("test", "training") is scanned
    separately. Only images with a matching ``.png`` mask are kept, and those
    pairs are re-split ``train_ratio : 1 - train_ratio`` in sorted filename order.

    NOTE(review): part of the repository's own "test" folder is therefore copied
    into ``train``, so results are not comparable with the official split. Files
    from both source folders land in the same destination folders, so identical
    filenames would silently overwrite each other.

    Args:
        train_ratio: fraction (0..1) of each source folder's pairs copied to ``train``.
        img_root: image root holding the split sub-folders (default: ``src_img_dir``).
        ann_root: annotation root (default: ``src_ann_dir``).
        out_root: destination root (default: ``base_dir``).

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    img_root = src_img_dir if img_root is None else img_root
    ann_root = src_ann_dir if ann_root is None else ann_root
    out_root = base_dir if out_root is None else out_root

    for dir_source in ("test", "training"):  # sub-folder names in the original repo
        all_images = sorted(glob(os.path.join(img_root, dir_source, "*.jpg")))

        print(f"\nScanning directory: {dir_source}")

        # Keep only images for which a mask exists.
        valid_pairs = []
        for img_path in all_images:
            mask_path = _find_mask(ann_root, dir_source, os.path.basename(img_path))
            if mask_path is not None:
                valid_pairs.append((img_path, mask_path))

        print(f"Total images found: {len(all_images)}")
        print(f"Valid pairs with masks: {len(valid_pairs)}")

        if not valid_pairs:
            continue

        split_idx = int(len(valid_pairs) * train_ratio)
        train_pairs = valid_pairs[:split_idx]
        test_pairs = valid_pairs[split_idx:]

        print(f"Copying {len(train_pairs)} pairs to Train...")
        _copy_pairs(train_pairs, os.path.join(out_root, "train"))

        print(f"Copying {len(test_pairs)} pairs to Test...")
        _copy_pairs(test_pairs, os.path.join(out_root, "test"))

    print("\nData organization complete!")

# Run the reorganization with the paths and base_dir defined above.
organize_neu_data()


Scanning directory: test
Total images found: 840
Valid pairs with masks: 840
Copying 672 pairs to Train...
Copying 168 pairs to Test...

Scanning directory: training
Total images found: 3630
Valid pairs with masks: 3630
Copying 2904 pairs to Train...
Copying 726 pairs to Test...

Data organization complete!


In [18]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# 1. Dataset definition (image-only variant).
# NOTE(review): `NEUSegDataset` is defined again further down in this cell
# (image + mask version); that later definition replaces this one, so this
# class is never used by the training code. Consider renaming or removing it.
class NEUSegDataset(Dataset):
    """Unlabelled dataset yielding single grayscale images from one folder."""

    def __init__(self, img_dir, transform=None):
        """
        Args:
            img_dir: directory whose ``.jpg`` / ``.png`` files are loaded.
            transform: optional callable applied to each PIL image.
        """
        self.img_dir = img_dir
        # Sorted so that index -> file is deterministic (os.listdir order is arbitrary).
        self.img_names = sorted(f for f in os.listdir(img_dir) if f.endswith(('.jpg', '.png')))
        self.transform = transform

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load image ``idx`` as grayscale ("L") and apply ``transform`` if set."""
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        image = Image.open(img_path).convert("L")  # grayscale
        if self.transform:
            image = self.transform(image)
        return image

# 2. RPCANet model (deep-unfolded RPCA, simple ISTA-style variant)
class RPCANet(nn.Module):
    """Unrolled ISTA-style decomposition of an image ``M`` into ``L + S``.

    Each of ``layers`` iterations applies

        S <- soft_threshold(S + eta[i] * W_i(M - S), theta[i])

    with a learnable single-channel 3x3 convolution ``W_i``. The forward pass
    returns ``(L, S)`` where ``L = M - S``.

    NOTE(review): the training cells below use ``RPCANet9``, not this class.
    """

    def __init__(self, layers=5):
        super().__init__()
        self.layers = layers
        # Per-layer learnable step sizes and thresholds (thresholds pass through softplus).
        self.eta = nn.Parameter(0.1 * torch.ones(layers))
        self.theta = nn.Parameter(0.01 * torch.ones(layers))

        # Convolutional "transform" per layer, so each update uses a 3x3 neighbourhood.
        self.conv_W = nn.ModuleList(
            nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False) for _ in range(layers)
        )
        # Start every filter as a 3x3 mean filter.
        for conv_layer in self.conv_W:
            nn.init.constant_(conv_layer.weight, 1.0 / 9.0)

    def soft_threshold(self, x, theta):
        """Shrink ``x`` toward zero by ``softplus(theta)`` (a strictly positive amount)."""
        shrink = F.softplus(theta)
        magnitude = torch.relu(torch.abs(x) - shrink)
        return torch.sign(x) * magnitude

    def forward(self, M):
        sparse = torch.zeros_like(M)
        for k, conv_layer in enumerate(self.conv_W):
            # Unrolled gradient step on the residual M - S, then shrinkage.
            step = self.eta[k] * conv_layer(M - sparse)
            sparse = self.soft_threshold(sparse + step, self.theta[k])
        low_rank = M - sparse
        return low_rank, sparse

class NEUSegDataset(Data.Dataset):
    """Paired (image, binary mask) dataset over ``base_dir/<mode>/{images,masks}``.

    Images are ``.jpg`` files; the mask for ``name.jpg`` is ``masks/name.png``.

    Args:
        base_dir: root produced by the data-organization cell.
        mode: split sub-folder, e.g. ``'train'`` or ``'test'``.
        base_size: images and masks are resized to ``base_size x base_size``.
    """

    def __init__(self, base_dir, mode='train', base_size=256):
        self.img_dir = os.path.join(base_dir, mode, 'images')
        self.mask_dir = os.path.join(base_dir, mode, 'masks')
        self.img_names = sorted([f for f in os.listdir(self.img_dir) if f.endswith('.jpg')])
        self.base_size = base_size
        self.transform = transforms.Compose([
            transforms.Resize((base_size, base_size)),
            transforms.ToTensor(),
        ])
        # Masks are label maps: resize with nearest-neighbour so no intermediate
        # values are invented along region boundaries (the image transform's
        # default interpolation would blend labels).
        self.mask_transform = transforms.Compose([
            transforms.Resize((base_size, base_size),
                              interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        """Return ``(image, mask)`` as float tensors; ``mask`` contains only 0/1."""
        name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, name)
        mask_path = os.path.join(self.mask_dir, name.replace('.jpg', '.png'))

        image = Image.open(img_path).convert('L')
        mask = Image.open(mask_path).convert('L')  # masks are read as grayscale as well

        image = self.transform(image)
        mask = self.mask_transform(mask)

        # Binarize (for SoftIoU): any non-zero label is foreground. ToTensor scales
        # 8-bit values into [0, 1], so a "> 0.5" threshold would drop small label
        # values (e.g. 1..3 -> 1/255..3/255) if the masks store class indices.
        # NOTE(review): mask encoding (0/255 vs class indices) is not shown here.
        mask = (mask > 0).float()

        return image, mask

    def __len__(self):
        return len(self.img_names)




In [19]:

class Avg_ChannelAttention_n(nn.Module):
    """Per-channel gate with BatchNorm after each 1x1 convolution.

    Global average pool -> 1x1 conv (C -> C // r) -> BN -> ReLU -> 1x1 conv
    (C // r -> C) -> BN -> Sigmoid. For a (bz, C, h, w) input the gate has
    shape (bz, C, 1, 1) with values in (0, 1).

    NOTE(review): in training mode BatchNorm over a 1x1 map needs more than one
    sample per batch; PyTorch rejects a batch of size 1 here.
    """

    def __init__(self, channels, r=4):
        super().__init__()
        hidden = channels // r
        self.avg_channel = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),                                     # (bz, C, h, w) -> (bz, C, 1, 1)
            nn.Conv2d(channels, hidden, kernel_size=1, stride=1, padding=0),  # -> (bz, C/r, 1, 1)
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1, stride=1, padding=0),  # -> (bz, C, 1, 1)
            nn.BatchNorm2d(channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.avg_channel(x)
        return gate



class Avg_ChannelAttention(nn.Module):
    """Squeeze-and-excitation style per-channel gate (no BatchNorm).

    Global average pool -> 1x1 conv (C -> C // r) -> ReLU -> 1x1 conv
    (C // r -> C) -> Sigmoid. For a (bz, C, h, w) input the gate has shape
    (bz, C, 1, 1) with values in (0, 1).
    """

    def __init__(self, channels, r=4):
        super().__init__()
        hidden = channels // r
        self.avg_channel = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),                                     # (bz, C, h, w) -> (bz, C, 1, 1)
            nn.Conv2d(channels, hidden, kernel_size=1, stride=1, padding=0),  # -> (bz, C/r, 1, 1)
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1, stride=1, padding=0),  # -> (bz, C, 1, 1)
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.avg_channel(x)
        return gate


class AttnContrastLayer(nn.Module):
    """Attention-weighted local contrast: ``theta(x) * center(x) - conv(x)``.

    ``center`` is a 1x1 convolution whose weights are the spatial sums of
    ``self.conv``'s kernels; ``theta`` is the per-channel sigmoid gate from
    ``Avg_ChannelAttention``.
    """

    def __init__(self, channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):
        conv = self.conv
        surround = conv(x)
        gate = self.attn(x)
        # Collapse each kxk kernel to its sum -> a 1x1 kernel with the same (out, in) shape.
        center_weight = conv.weight.sum(2).sum(2).unsqueeze(-1).unsqueeze(-1)
        center = F.conv2d(input=x, weight=center_weight, bias=conv.bias, stride=conv.stride,
                          padding=0, groups=conv.groups)
        return gate * center - surround



class AttnContrastLayer_n(nn.Module):
    """Variant of the attention-contrast layer whose gate uses ``Avg_ChannelAttention_n``.

    Output: ``theta(x) * center(x) - conv(x)``, where ``center`` is a 1x1
    convolution built from the spatially summed kernels of ``self.conv``.
    """

    def __init__(self, channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=False):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding,
                           dilation=dilation, groups=groups, bias=bias)
        self.conv = nn.Conv2d(channels, channels, **conv_kwargs)
        self.attn = Avg_ChannelAttention_n(channels)

    def forward(self, x):
        surround = self.conv(x)
        gate = self.attn(x)

        # kxk kernels summed over their spatial dims, reshaped to 1x1 kernels.
        center_kernel = self.conv.weight.sum(2).sum(2)[..., None, None]
        center = F.conv2d(input=x, weight=center_kernel, bias=self.conv.bias,
                          stride=self.conv.stride, padding=0, groups=self.conv.groups)

        return gate * center - surround

class AttnContrastLayer_d(nn.Module):
    """Dilated contrast layer: ``center(x) - theta(x) * conv(x)``.

    Unlike the non-dilated layers, the attention gate here scales the
    (dilated) surround term instead of the center term.

    Args:
        channels: input/output channels.
        kernel_size, stride, dilation, groups, bias: forwarded to ``self.conv``.
        padding: padding of ``self.conv``. ``None`` (default) picks
            ``dilation * (kernel_size - 1) // 2`` per dimension. For odd kernels
            this keeps the dilated conv output the same size as the 1x1
            ``center`` term it is subtracted from. The former default of 1, with
            ``dilation=2`` and ``kernel_size=3``, shrank the conv output by 2 px
            per dimension and made the final subtraction fail with a shape mismatch.
    """

    def __init__(self, channels, kernel_size=3, stride=1, padding=None, dilation=2, groups=1, bias=False):
        super(AttnContrastLayer_d, self).__init__()

        if padding is None:
            ks = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
            dl = dilation if isinstance(dilation, tuple) else (dilation, dilation)
            padding = tuple(d * (k - 1) // 2 for k, d in zip(ks, dl))

        self.conv = nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)

        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):

        out_normal = self.conv(x)

        theta = self.attn(x)

        # Sum each kxk kernel into a 1x1 kernel (the "center" response).
        kernel_w1 = self.conv.weight.sum(2).sum(2)
        kernel_w2 = kernel_w1[:, :, None, None]
        out_center = F.conv2d(input=x, weight=kernel_w2, bias=self.conv.bias, stride=self.conv.stride,
                              padding=0, groups=self.conv.groups)

        return out_center - theta * out_normal

class AtrousAttnWeight(nn.Module):
    """Thin wrapper that returns the ``Avg_ChannelAttention`` gate of its input."""

    def __init__(self, channels):
        super().__init__()
        self.attn = Avg_ChannelAttention(channels)

    def forward(self, x):
        weights = self.attn(x)
        return weights

In [20]:
import torch
import torch.nn as nn
from einops import rearrange, repeat
import math
import torch.nn.functional as F

import numpy as np

# Public names if this cell is moved into an importable module. `__all__` only
# affects `from <module> import *`; inside the notebook it has no effect.
__all__ = ['RPCANet9','RPCANet_LSTM']


class ResidualBlock(nn.Module):
    """Conv -> ReLU -> Conv with an identity skip connection.

    ``out = x + res_scale * conv2(relu(conv1(x)))``

    Args:
        in_channels, out_channels: conv channels. The skip is a plain identity,
            so they must be equal for the addition to be valid.
        kernel_size: kernel size of both convs (padding ``kernel_size // 2``).
        bias: whether the convs have a bias.
        res_scale: scale of the residual branch. It was previously stored but
            never applied; the default of 1 keeps the old behaviour.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, res_scale=1):

        super(ResidualBlock, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias)
        self.act1 = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        res = self.conv2(self.act1(self.conv1(x)))
        return res * self.res_scale + identity

class RPCANet9(nn.Module):
    """Deep-unfolded RPCA network made of ``stage_num`` decomposition stages.

    Each stage (``DecompositionModule9``) maps the current image estimate ``D``
    and sparse map ``T`` to updated ``(D, T)``; ``T`` starts at zero.

    Args:
        stage_num: number of unfolded stages.
        slayers, llayers, mlayers: hidden conv blocks in the sparse, low-rank
            and merge sub-modules of every stage.
        channel: feature width of those sub-modules.
        mode: ``'train'`` -> forward returns ``(D, T)``; any other value ->
            forward returns only ``T``. This attribute is independent of
            ``nn.Module.train()`` / ``eval()``.
    """

    def __init__(self, stage_num=6, slayers=6, llayers=3, mlayers=3, channel=32, mode='train'):
        super(RPCANet9, self).__init__()
        self.stage_num = stage_num
        self.decos = nn.ModuleList()
        self.mode = mode
        for _ in range(stage_num):
            self.decos.append(DecompositionModule9(slayers=slayers, llayers=llayers,
                                                  mlayers=mlayers, channel=channel))
        # Xavier init for every conv; BatchNorm starts as an identity affine map.
        # (Alternative: nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu'))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, D):
        # zeros_like matches D's device *and* dtype. torch.zeros(D.shape) was always
        # float32 on the CPU before being copied, which breaks half/double inputs.
        T = torch.zeros_like(D)
        for deco in self.decos:
            D, T = deco(D, T)
        if self.mode == 'train':
            return D, T
        return T

class DecompositionModule9(nn.Module):
    """One unfolded stage: low-rank update -> sparse update -> merge.

    Returns the merged image estimate and the updated sparse map ``(D, T)``.
    """

    def __init__(self, slayers=6, llayers=3, mlayers=3, channel=32):
        super().__init__()
        self.lowrank = LowrankModule9(channel=channel, layers=llayers)
        self.sparse = SparseModule9(channel=channel, layers=slayers)
        self.merge = MergeModule9(channel=channel, layers=mlayers)

    def forward(self, D, T):
        background = self.lowrank(D, T)
        target = self.sparse(D, background, T)
        merged = self.merge(background, target)
        return merged, target

class LowrankModule9(nn.Module):
    """Low-rank (background) update: ``B = x + f(x)`` with ``x = D - T``.

    ``f`` is Conv(1 -> channel) + BN + ReLU, then ``layers`` blocks of
    Conv(channel -> channel) + BN + ReLU, then a final Conv(channel -> 1).
    """

    def __init__(self, channel=32, layers=3):
        super().__init__()
        stack = []
        for in_ch in [1] + [channel] * layers:
            stack.extend([
                nn.Conv2d(in_ch, channel, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(channel),
                nn.ReLU(True),
            ])
        stack.append(nn.Conv2d(channel, 1, kernel_size=3, padding=1, stride=1))
        self.convs = nn.Sequential(*stack)

    def forward(self, D, T):
        residual = D - T
        return residual + self.convs(residual)

class SparseModule9(nn.Module):
    """Sparse (target) update: ``T = x - epsilon * g(x)`` with ``x = T + D - B``.

    ``g`` is Conv(1 -> channel) + ReLU, ``layers`` blocks of
    Conv(channel -> channel) + ReLU, and a final Conv(channel -> 1).
    ``epsilon`` is a learnable step size initialised to 0.01.
    """

    def __init__(self, channel=32, layers=6) -> None:  # was "-> object": __init__ returns None
        super(SparseModule9, self).__init__()
        convs = [nn.Conv2d(1, channel, kernel_size=3, padding=1, stride=1),
                 nn.ReLU(True)]
        for i in range(layers):
            convs.append(nn.Conv2d(channel, channel, kernel_size=3, padding=1, stride=1))
            convs.append(nn.ReLU(True))
        convs.append(nn.Conv2d(channel, 1, kernel_size=3, padding=1, stride=1))
        self.convs = nn.Sequential(*convs)
        self.epsilon = nn.Parameter(torch.tensor([0.01]), requires_grad=True)

    def forward(self, D, B, T):
        x = T + D - B
        T = x - self.epsilon * self.convs(x)
        return T

class MergeModule9(nn.Module):
    """Merge step: ``D = mapping(B + T)`` with a conv-BN-ReLU stack (1 -> channel -> 1)."""

    def __init__(self, channel=32, layers=3):
        super().__init__()

        def conv_bn_relu(in_ch):
            return [nn.Conv2d(in_ch, channel, kernel_size=3, padding=1, stride=1),
                    nn.BatchNorm2d(channel),
                    nn.ReLU(True)]

        modules = conv_bn_relu(1)
        for _ in range(layers):
            modules += conv_bn_relu(channel)
        modules.append(nn.Conv2d(channel, 1, kernel_size=3, padding=1, stride=1))
        self.mapping = nn.Sequential(*modules)

    def forward(self, B, T):
        combined = B + T
        return self.mapping(combined)


class ConvLSTM(nn.Module):
    """Convolutional LSTM cell.

    Args:
        inp_dim: channels of the input ``x``.
        oup_dim: channels of the hidden / cell state.
        kernel: kernel size of every gate convolution (int or tuple). Padding is
            ``kernel // 2``, so odd kernels preserve spatial size. It was
            previously hard-coded to 1, which is only correct for ``kernel=3``.
    """

    def __init__(self, inp_dim, oup_dim, kernel):
        super().__init__()
        pad = tuple(k // 2 for k in kernel) if isinstance(kernel, (tuple, list)) else kernel // 2

        pad_x = pad
        self.conv_xf = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x)
        self.conv_xi = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x)
        self.conv_xo = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x)
        self.conv_xj = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x)

        pad_h = pad
        self.conv_hf = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h)
        self.conv_hi = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h)
        self.conv_ho = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h)
        self.conv_hj = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h)

    def forward(self, x, h, c):
        """Run one ConvLSTM step.

        Args:
            x: input features with ``inp_dim`` channels.
            h, c: previous hidden / cell state, or both ``None`` on the first step.

        Returns:
            ``(h, h, c)``: the new hidden state twice (as output and as state),
            then the new cell state.
        """
        if h is None and c is None:
            # First step: no recurrent terms and no forget gate.
            i = torch.sigmoid(self.conv_xi(x))
            o = torch.sigmoid(self.conv_xo(x))
            j = torch.tanh(self.conv_xj(x))
            c = i * j
            # NOTE(review): later steps use o * tanh(c) but this first step uses
            # o * c. Left unchanged here to keep existing behaviour; confirm intent.
            h = o * c
        else:
            f = torch.sigmoid(self.conv_xf(x) + self.conv_hf(h))
            i = torch.sigmoid(self.conv_xi(x) + self.conv_hi(h))
            o = torch.sigmoid(self.conv_xo(x) + self.conv_ho(h))
            j = torch.tanh(self.conv_xj(x) + self.conv_hj(h))
            c = f * c + i * j
            h = o * torch.tanh(c)

        return h, h, c

class RPCANet_LSTM(nn.Module):
    """RPCANet variant whose low-rank branches pass a ConvLSTM state between stages.

    Args:
        stage_num: number of unfolded stages.
        slayers, mlayers: hidden conv blocks in the sparse / merge modules.
        channel: feature width.
        mode: ``'train'`` -> forward returns ``(D, T)``; otherwise only ``T``.
            This attribute is independent of ``nn.Module.train()`` / ``eval()``.
    """

    def __init__(self, stage_num=6, slayers=6, mlayers=3, channel=32, mode='train'):
        super(RPCANet_LSTM, self).__init__()
        self.stage_num = stage_num
        self.decos = nn.ModuleList()
        self.mode = mode
        for _ in range(stage_num):
            self.decos.append(DecompositionModule_LSTM(slayers=slayers, mlayers=mlayers, channel=channel))
        # Xavier init for every conv; BatchNorm starts as an identity affine map.
        # (Alternative: nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu'))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, D):
        # zeros_like matches D's device and dtype (torch.zeros(D.shape) was always float32).
        T = torch.zeros_like(D)
        # ConvLSTM state; (None, None) makes the first stage's ConvLSTM start it.
        h, c = None, None
        for deco in self.decos:
            D, T, h, c = deco(D, T, h, c)
        if self.mode == 'train':
            return D, T
        return T

class DecompositionModule_LSTM(nn.Module):
    """One unfolded stage whose low-rank branch carries a ConvLSTM state ``(h, c)``.

    Returns ``(D, T, h, c)``: merged image estimate, sparse map and updated state.
    """

    def __init__(self, slayers=6, mlayers=3, channel=32):
        super().__init__()
        self.lowrank = LowrankModule_LSTM(channel=channel)
        self.sparse = SparseModule_LSTM(channel=channel, layers=slayers)
        self.merge = MergeModule_LSTM(channel=channel, layers=mlayers)

    def forward(self, D, T, h, c):
        background, h, c = self.lowrank(D, T, h, c)
        target = self.sparse(D, background, T)
        merged = self.merge(background, target)
        return merged, target, h, c

class LowrankModule_LSTM(nn.Module):
    """Low-rank update with a ConvLSTM inside its feature path.

    ``x = D - T``; ``B = x + convC_1(RB_2(ConvLSTM(RB_1(conv1_C(x)))))``.
    The ConvLSTM state ``(h, c)`` is returned so the next stage can continue it.
    """

    def __init__(self, channel=32):
        super().__init__()
        self.conv1_C = nn.Sequential(
            nn.Conv2d(1, channel, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(channel),
            nn.ReLU(True),
        )
        self.RB_1 = ResidualBlock(channel, channel, 3, bias=True, res_scale=1)
        self.RB_2 = ResidualBlock(channel, channel, 3, bias=True, res_scale=1)
        self.convC_1 = nn.Conv2d(channel, 1, kernel_size=3, padding=1, stride=1)
        self.ConvLSTM = ConvLSTM(channel, channel, 3)

    def forward(self, D, T, h, c):
        x = D - T
        features = self.RB_1(self.conv1_C(x))
        features, h, c = self.ConvLSTM(features, h, c)
        correction = self.convC_1(self.RB_2(features))
        return x + correction, h, c

class SparseModule_LSTM(nn.Module):
    """Sparse update with a contrast-attention branch.

    ``x = T + D - B``; ``w = contrast(x)``; ``T = x - epsilon * g(x + w)``,
    where ``g`` is the conv-ReLU stack ``self.convs`` and ``epsilon`` is a
    learnable step size initialised to 0.01.
    """

    def __init__(self, channel=32, layers=6) -> None:  # was "-> object": __init__ returns None
        super(SparseModule_LSTM, self).__init__()
        convs = [nn.Conv2d(1, channel, kernel_size=3, padding=1, stride=1),
                 nn.ReLU(True)]
        for i in range(layers):
            convs.append(nn.Conv2d(channel, channel, kernel_size=3, padding=1, stride=1))
            convs.append(nn.ReLU(True))
        convs.append(nn.Conv2d(channel, 1, kernel_size=3, padding=1, stride=1))
        self.convs = nn.Sequential(*convs)

        self.epsilon = nn.Parameter(torch.tensor([0.01]), requires_grad=True)
        # Large-kernel (17x17, padding 8) attention-contrast branch producing a
        # 1-channel map that is added to x before the sparse conv stack.
        self.contrast = nn.Sequential(
                nn.Conv2d(1, channel, kernel_size=3, padding=1, stride=1),
                nn.ReLU(True),
                AttnContrastLayer_n(channel, kernel_size=17, padding=8),
                nn.BatchNorm2d(channel),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(channel, 1, kernel_size=3, padding=1, stride=1)
            )

    def forward(self, D, B, T):
        x = T + D - B
        w = self.contrast(x)
        T = x - self.epsilon * self.convs(x + w)
        return T

class MergeModule_LSTM(nn.Module):
    """Merge step of the LSTM variant: ``D = mapping(B + T)``.

    ``mapping`` is Conv(1 -> channel) + BN + ReLU, ``layers`` repetitions of
    Conv(channel -> channel) + BN + ReLU, and a final Conv(channel -> 1).
    """

    def __init__(self,  channel=32, layers=3):
        super().__init__()
        blocks = []
        for in_ch in [1] + [channel] * layers:
            blocks += [
                nn.Conv2d(in_ch, channel, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(channel),
                nn.ReLU(True),
            ]
        blocks.append(nn.Conv2d(channel, 1, kernel_size=3, padding=1, stride=1))
        self.mapping = nn.Sequential(*blocks)

    def forward(self, B, T):
        return self.mapping(B + T)

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 8
lr = 1e-4
num_epochs = 50
sparse_weight = 0.1  # weight of the L1 penalty on the sparse map T
seed = 0             # fixes weight init and shuffling order
torch.manual_seed(seed)

# ---- Data loader ----
# NEUSegDataset resizes (base_size=256) and converts to tensors itself, so no
# extra transform is passed. The path is relative, like `base_dir` in the
# data-organization cell.
train_dataset = NEUSegDataset(
    base_dir="./datasets/NEU_Seg_Custom",
    mode="train",
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# ---- Model ----
model = RPCANet9(stage_num=6)  # adjust the number of unfolded stages as needed
model.to(device)

# ---- Loss and optimizer ----
recon_loss_fn = nn.MSELoss()  # L2 reconstruction error
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# ---- Training loop ----
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    # Each batch is (images, masks); this unsupervised objective ignores the masks.
    for inputs, _ in train_loader:
        inputs = inputs.to(device)

        # With mode='train', RPCANet9 returns (D, T): D is the merged
        # reconstruction of the input (merge(B + T)) and T is the sparse map.
        # D already includes T, so it is compared to the input directly; the
        # former `low_rank + sparse` added T a second time.
        recon, sparse = model(inputs)

        loss_recon = recon_loss_fn(recon, inputs)
        loss_sparse = torch.mean(torch.abs(sparse))  # L1 sparsity regularizer
        loss = loss_recon + sparse_weight * loss_sparse

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.6f}")

    # Save a checkpoint every epoch
    checkpoint_path = os.path.join(checkpoint_dir, f"rpcanet_anomaly_epoch{epoch+1}.pth")
    torch.save(model.state_dict(), checkpoint_path)

print("Training Completed.")



In [None]:
# 5. Inference and visualization on the held-out split
test_dataset = NEUSegDataset(base_dir="./datasets/NEU_Seg_Custom", mode="test")
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model.eval()
# Each batch is an (images, masks) pair from NEUSegDataset, so it must be
# unpacked: calling .to() on the batch list itself raises AttributeError.
test_images, test_masks = next(iter(test_loader))
test_images = test_images.to(device)
with torch.no_grad():
    # model.mode is still 'train' (eval() does not change it), so this returns (D, T).
    recon, S = model(test_images)
# RPCANet9 does not return the low-rank term B directly; D - T is used as a proxy.
# NOTE(review): D = merge(B + T) is a learned mapping, so D - T only approximates B.
L = recon - S

# Show the first sample of the batch
idx = 0
panels = [
    (test_images[idx], "Original (Metal Surface)", "gray"),
    (L[idx], "Low-Rank (Background, D - T)", "gray"),
    (torch.abs(S[idx]), "Sparse (Detected Defect)", "hot"),
    (test_masks[idx], "Ground-Truth Mask", "gray"),
]
fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (tensor, title, cmap) in zip(axes, panels):
    ax.imshow(tensor.cpu().squeeze(), cmap=cmap)
    ax.set_title(title)
    ax.axis("off")
plt.show()