<a href="https://colab.research.google.com/github/LiJiEGG/sam2_app/blob/main/notebooks/sam2_cucumber_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/facebookresearch/sam2.git

In [None]:
%cd ./sam2/

In [None]:
!cd ./checkpoints && ./download_ckpts.sh

In [None]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Weights for the SAM 2.1 "hiera small" variant.
# NOTE(review): the absolute /content/... path presumably relies on the repo
# having been cloned into Colab's default working directory — confirm when
# running elsewhere.
sam2_checkpoint = "/content/sam2/checkpoints/sam2.1_hiera_small.pt"
# Model config matching the checkpoint above (path is relative, unlike the checkpoint).
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
# Build the model on the GPU; device is hard-coded to "cuda".
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
# Wrap the model in the image predictor; later cells reach the model via `predictor.model`.
predictor = SAM2ImagePredictor(sam2_model)

In [None]:
import numpy as np
import cv2
import os

def read_data(data_dir):
    """Collect per-sample image paths from a directory of sample folders.

    Every immediate sub-directory of ``data_dir`` is treated as one sample
    that holds ``rgb.png``, ``depth.png``, ``ir.png`` and ``label.png``.
    Paths are only assembled; the files themselves are not checked.

    Args:
        data_dir: Directory whose sub-directories are the samples.

    Returns:
        A list of dicts with the keys ``"rgb"``, ``"depth"``, ``"ir"`` and
        ``"label"``, each mapping to a file path. Entries are ordered by
        sample directory name.
    """
    file_names = {
        "rgb": "rgb.png",
        "depth": "depth.png",
        "ir": "ir.png",
        "label": "label.png",
    }
    samples = []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for name in sorted(os.listdir(data_dir)):
        sample_dir = os.path.join(data_dir, name)
        # Stray files (e.g. .DS_Store, README) are not samples.
        if not os.path.isdir(sample_dir):
            continue
        samples.append(
            {key: os.path.join(sample_dir, fname) for key, fname in file_names.items()}
        )
    return samples

def preprocess(image):
    """Return ``image`` unchanged.

    Placeholder hook for image preprocessing steps such as resizing or
    normalisation; no transformation is applied yet.
    """
    return image

In [None]:
import torch

# Fine-tuning loop using CUDA automatic mixed precision (AMP).
# NOTE(review): `num_epochs`, `data_loader` and `loss_function` are not defined
# in the visible cells (a later cell builds a loader named `dataloader`) —
# confirm where they come from before running this cell.
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
# torch.amp.GradScaler("cuda") is the supported spelling of the deprecated
# torch.cuda.amp.GradScaler().
scaler = torch.amp.GradScaler("cuda")

for epoch in range(num_epochs):
    for batch in data_loader:
        rgb = preprocess(batch['rgb'])
        depth = preprocess(batch['depth'])
        ir = preprocess(batch['ir'])
        label = batch['label']

        # Forward pass under autocast (replaces the deprecated
        # torch.cuda.amp.autocast()).
        with torch.amp.autocast("cuda"):
            # NOTE(review): assumes `predictor` can be called directly with the
            # three modalities — TODO confirm against SAM2ImagePredictor's API.
            output = predictor(rgb, depth, ir)
            loss = loss_function(output, label)

        # Backward pass: the loss is scaled before backward(); the scaler then
        # performs the optimizer step and updates its scale factor.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

In [None]:
import torch
import os
import cv2
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader

class MultiModalCucumberDataset(Dataset):
    """Dataset yielding RGB, depth and IR images plus a scalar label per row.

    Each CSV row supplies file paths (relative to ``data_root``) in the
    columns ``rgb``, ``depth`` and ``ir``, and the target value in the
    column ``height``.
    """

    def __init__(self, csv_file, data_root, transform=None):
        """Load the sample index.

        Args:
            csv_file (str): CSV file listing image paths and labels.
            data_root (str): Root directory the CSV paths are joined onto.
            transform (callable, optional): Applied to the RGB tensor only,
                e.g. normalisation or augmentation.
        """
        self.data_info = pd.read_csv(csv_file)
        self.data_root = data_root
        self.transform = transform

    def __len__(self):
        """Return the number of samples (CSV rows)."""
        return len(self.data_info)

    @staticmethod
    def _to_unit_float32(array):
        """Convert ``array`` to float32 and divide it by its maximum.

        The maximum is taken after the float32 conversion so the result is
        always float32, whatever the dtype stored on disk. When the maximum
        is zero (e.g. an all-zero image) the array is returned unscaled
        instead of dividing by zero and producing NaN/inf.
        """
        scaled = np.asarray(array).astype(np.float32)
        peak = scaled.max()
        if peak == 0:
            return scaled
        return scaled / peak

    def __getitem__(self, idx):
        """Load and normalise the sample at position ``idx``.

        Returns:
            dict: ``'rgb'``, ``'depth'``, ``'ir'`` image tensors and a
            float32 scalar ``'label'`` tensor.

        Raises:
            FileNotFoundError: If the RGB image cannot be read.
        """
        # Look the row up once instead of once per column.
        row = self.data_info.iloc[idx]
        rgb_path = os.path.join(self.data_root, row['rgb'])
        depth_path = os.path.join(self.data_root, row['depth'])
        ir_path = os.path.join(self.data_root, row['ir'])

        # cv2.imread signals failure by returning None rather than raising.
        rgb_img = cv2.imread(rgb_path)  # BGR channel order
        if rgb_img is None:
            raise FileNotFoundError(f"Could not read RGB image: {rgb_path}")
        rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)  # convert to RGB

        # Depth and IR are stored as NumPy arrays on disk.
        depth_img = np.load(depth_path)
        ir_img = np.load(ir_path)

        # Label column (e.g. plant height).
        label = row['height']

        # Normalisation: RGB by 255, depth/IR by their own maximum.
        rgb_img = rgb_img.astype(np.float32) / 255.0
        depth_img = self._to_unit_float32(depth_img)
        ir_img = self._to_unit_float32(ir_img)

        # Convert to PyTorch tensors.
        rgb_img = torch.tensor(rgb_img).permute(2, 0, 1)  # channels first: (C, H, W)
        # A leading channel axis is added; assumes the stored arrays are
        # 2-D (H, W) — TODO confirm.
        depth_img = torch.tensor(depth_img).unsqueeze(0)
        ir_img = torch.tensor(ir_img).unsqueeze(0)
        label = torch.tensor(float(label), dtype=torch.float32)

        # NOTE(review): the transform touches only the RGB tensor; a geometric
        # augmentation here would presumably misalign it with depth/IR.
        if self.transform:
            rgb_img = self.transform(rgb_img)

        return {
            'rgb': rgb_img,
            'depth': depth_img,
            'ir': ir_img,
            'label': label
        }

# Example: create the dataset
csv_file = "cucumber_dataset.csv"
data_root = "data/"
dataset = MultiModalCucumberDataset(csv_file, data_root)

# Create the DataLoader: shuffled batches of 8 samples, loaded by 4 worker processes
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)