### 1. 跑 combine_channel.py，生成 output.tif

### 2. 跑这个 notebook 里面的代码

In [None]:
import random
import torch

# Fix the seeds of Python's `random` module and of PyTorch so that repeated
# runs of the notebook draw the same random numbers from those generators.
random.seed(3407)
torch.manual_seed(3407)
if torch.cuda.is_available():
    # Also seed the generators of every visible CUDA device.
    torch.cuda.manual_seed_all(3407)

In [None]:
import numpy as np
import rasterio
from tif2pngs import ROOT
import os

# Raster to tile and the project root that receives the tiles.
# NOTE(review): step 1 of this notebook mentions `output.tif` while this path
# points at `result.tif` -- confirm which file name is correct.
combined_file_path = os.path.join(ROOT, 'datasets', 'main', 'result.tif') # TODO
root_path = ROOT

# Folder that receives one .npy file per tile.
blocks_dir = os.path.join(root_path, 'dataV2', 'train', 'blocks')

with rasterio.open(combined_file_path) as src:
    print(src.meta)
    width = src.meta['width']
    height = src.meta['height']
    channels = src.meta['count']

    # np.save does not create missing parent directories, so make sure the
    # output folder exists before the first tile is written. This is done
    # only after the raster opened successfully.
    os.makedirs(blocks_dir, exist_ok=True)

    # Cut the raster into 256 x 256 tiles. The step is half a tile, so
    # neighbouring tiles overlap by 50 %.
    block_size = 256
    stride = block_size // 2
    for i in range(0, height, stride):
        for j in range(0, width, stride):
            # Window is ((row_start, row_stop), (col_start, col_stop)),
            # clipped so it never reaches past the raster bounds.
            block = src.read(window=(
                (i, min(i + block_size, height)),
                (j, min(j + block_size, width)))
            )

            # Tiles touching the bottom/right border come back smaller than
            # block_size; zero-pad them at the bottom/right to full size.
            if block.shape[1:] != (block_size, block_size):
                pad_height = block_size - block.shape[1]
                pad_width = block_size - block.shape[2]
                block = np.pad(block, ((0, 0), (0, pad_height), (0, pad_width)), mode='constant')

            # Save the tile; the file name carries its top-left (row, col).
            np.save(os.path.join(blocks_dir, f'block_{i}_{j}.npy'), block)

In [None]:
import numpy as np
import rasterio

import os

# Placeholders: fill in the raster to tile and the project root.
combined_file_path = '' # TODO
root_path = '' # TODO

# Folder that receives one .npy file per tile.
blocks_dir = os.path.join(root_path, 'dataV2', 'train', 'blocks')

with rasterio.open(combined_file_path) as src:
    print(src.meta)
    width = src.meta['width']
    height = src.meta['height']
    channels = src.meta['count']

    # np.save does not create missing parent directories, so make sure the
    # output folder exists before the first tile is written. This is done
    # only after the raster opened successfully.
    os.makedirs(blocks_dir, exist_ok=True)

    # Cut the raster into 256 x 256 tiles. The step is a quarter of a tile
    # (64 px), so neighbouring tiles overlap by 75 %.
    block_size = 256
    stride = block_size // 4
    for i in range(0, height, stride):
        for j in range(0, width, stride):
            # Window is ((row_start, row_stop), (col_start, col_stop)),
            # clipped so it never reaches past the raster bounds.
            block = src.read(window=(
                (i, min(i + block_size, height)),
                (j, min(j + block_size, width)))
            )

            # Tiles touching the bottom/right border come back smaller than
            # block_size; zero-pad them at the bottom/right to full size.
            if block.shape[1:] != (block_size, block_size):
                pad_height = block_size - block.shape[1]
                pad_width = block_size - block.shape[2]
                block = np.pad(block, ((0, 0), (0, pad_height), (0, pad_width)), mode='constant')

            # Save the tile; the file name carries its top-left (row, col).
            np.save(os.path.join(blocks_dir, f'block_{i}_{j}.npy'), block)

In [None]:
from tif2pngs import Tif2Pngs
import os


# Split the label raster into PNG mask tiles under dataV2/train/masks.
# NOTE(review): Tif2Pngs is defined outside this notebook; stride=64 is
# presumably the tiling step in pixels -- confirm against tif2pngs.py.
mask_file_path = os.path.join(root_path, 'datasets', 'standard.tif')
mask_output_dir = os.path.join(root_path, 'dataV2', 'train', 'masks')
tif2pngs = Tif2Pngs(mask_file_path, mask_output_dir, stride=64)
tif2pngs.process_tif()

In [None]:
import os
import numpy as np
from tqdm import tqdm
from albumentations import Compose, HorizontalFlip, VerticalFlip, ShiftScaleRotate, RandomResizedCrop
from albumentations.pytorch import ToTensorV2
from PIL import Image
from tif2pngs import ROOT

# Augmentation pipeline applied jointly to an image block and its mask.
# Each random step is given p=0.5.
augmentation_steps = [
    HorizontalFlip(p=0.5),  # random horizontal flip
    VerticalFlip(p=0.5),  # random vertical flip
    # Random affine jitter: shift, scale and rotate.
    ShiftScaleRotate(
        shift_limit=0.1,
        scale_limit=0.2,
        rotate_limit=45,
        p=0.5,
    ),
    # Random crop followed by a resize back to 256 x 256.
    RandomResizedCrop(
        height=256,
        width=256,
        scale=(0.6, 1.0),
        p=0.5,
    ),
    # Blur augmentations, currently disabled:
    # OneOf([
    #     MotionBlur(p=0.2),
    #     MedianBlur(blur_limit=3, p=0.1),
    #     Blur(blur_limit=3, p=0.1),
    # ], p=0.5),
    # Last step: the code consuming this pipeline calls `.numpy()` on the
    # results, i.e. it expects tensors from here on.
    ToTensorV2(),
]
transform = Compose(augmentation_steps)

# Augmentation of a single block/mask pair
def augment_npy_images(image_path, mask_path, save_dir, transform):
    """Augment one image block together with its mask and save both.

    Args:
        image_path: Path to a ``.npy`` block; the array is treated as
            (C, H, W) and converted to (H, W, C) before augmentation.
        mask_path: Path to the matching mask image (read with PIL).
        save_dir: Directory whose ``blocks`` and ``masks`` sub-folders
            receive the outputs. Both sub-folders must already exist.
        transform: Callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries,
            each exposing ``.numpy()``.

    Returns:
        None. Pairs whose mask holds fewer than two distinct values are
        skipped without writing anything.
    """
    # Read the mask first: it decides whether the pair is used at all, so
    # the block is only loaded when it is actually needed.
    with Image.open(mask_path) as mask_file:
        mask = np.array(mask_file)

    # Only keep masks that contain at least two different classes.
    if len(np.unique(mask)) < 2:
        return

    # Load the block and move channels last: (C, H, W) -> (H, W, C).
    image = np.moveaxis(np.load(image_path), 0, -1)

    # Apply the same augmentation to image and mask.
    augmented = transform(image=image, mask=mask)
    transformed_image = augmented['image']
    transformed_mask = augmented['mask']

    # Derive the output names from the file stems so that the "_aug" suffix
    # is always added and a source file can never be overwritten, even when
    # save_dir is the source directory.
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    mask_stem = os.path.splitext(os.path.basename(mask_path))[0]
    save_image_path = os.path.join(save_dir, 'blocks', f'{image_stem}_aug.npy')
    save_mask_path = os.path.join(save_dir, 'masks', f'{mask_stem}_aug.png')

    # Image goes out as .npy, mask as .png.
    np.save(save_image_path, transformed_image.numpy())
    Image.fromarray(transformed_mask.numpy()).save(save_mask_path)

# Augmentation of a whole dataset directory
def augment_dataset(source_dir, target_dir, transform):
    """Augment every block in ``source_dir`` and write results to ``target_dir``.

    Blocks are read from ``source_dir/blocks``. The mask of ``block_*.npy``
    is looked up in ``source_dir/masks`` under the same name with ``block``
    replaced by ``standard`` and the extension changed to ``.png``.

    Args:
        source_dir: Directory containing ``blocks`` and ``masks`` folders.
        target_dir: Directory that receives the augmented ``blocks`` and
            ``masks``; the two sub-folders are created when missing. May be
            the same directory as ``source_dir``.
        transform: Augmentation callable forwarded to ``augment_npy_images``.
    """
    os.makedirs(os.path.join(target_dir, 'blocks'), exist_ok=True)
    os.makedirs(os.path.join(target_dir, 'masks'), exist_ok=True)

    blocks_dir = os.path.join(source_dir, 'blocks')
    masks_dir = os.path.join(source_dir, 'masks')

    # Take the listing once, before anything is written, and keep only the
    # original .npy blocks: earlier augmentation outputs and stray non-.npy
    # entries are ignored. Sorting fixes the processing order, since
    # os.listdir returns names in arbitrary order and the random
    # augmentations are drawn in processing order.
    image_files = sorted(
        name for name in os.listdir(blocks_dir)
        if name.endswith('.npy') and not name.endswith('_aug.npy')
    )

    for image_file in tqdm(image_files, desc="Augmenting dataset"):
        mask_file = image_file.replace('block', 'standard').replace('.npy', '.png')
        augment_npy_images(
            os.path.join(blocks_dir, image_file),
            os.path.join(masks_dir, mask_file),
            target_dir,
            transform,
        )

# Run the augmentation on the training split. Source and target are the same
# directory, so the "_aug" files are written next to the originals.
train_source_dir = train_target_dir = os.path.join(root_path, 'dataV2', 'train')

augment_dataset(
    source_dir=train_source_dir,
    target_dir=train_target_dir,
    transform=transform,
)
print("数据增强完成")


In [None]:
from torch.utils.data import DataLoader
from SegmentationDatasetV2 import SegmentationDatasetV2

# Build the training dataset from dataV2/train and wrap it in a shuffling
# loader that yields batches of 32 samples.
train_dir = os.path.join(ROOT, 'dataV2', 'train')
train_dataset = SegmentationDatasetV2(root_dir=train_dir)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

In [None]:
from loss import FocalLoss
import torch
from segmentation_models_pytorch import Unet

# Model hyper-parameters.
# NOTE(review): 'channels' is presumably the band count of the tiled raster
# and 'classes' the number of mask classes -- confirm against the data.
# NOTE(review): the model applies a softmax activation; FocalLoss is defined
# outside this notebook, so confirm it expects probabilities, not logits.
model_config = dict(
    model=Unet,
    encoder_name='resnet34',
    classes=3,
    channels=5,
    activation='softmax',
)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Instantiate the network described by model_config and move it to the device.
model = model_config['model'](
    encoder_name=model_config['encoder_name'],
    in_channels=model_config['channels'],
    classes=model_config['classes'],
    activation=model_config['activation'],
)
model = model.to(device)

# Loss function and optimizer.
loss_fn = FocalLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Smallest loss seen so far, starting at positive infinity.
min_loss = float('inf')

import os

# TODO: change the number of training epochs here
epoch_num = 50

# torch.save does not create missing directories. Create the checkpoint
# folder up front so a missing folder cannot abort the run after the first
# epoch has already been trained.
checkpoint_dir = '../model'
os.makedirs(checkpoint_dir, exist_ok=True)

# Train the model
for epoch in range(1, epoch_num + 1):
    model.train()
    running_loss = 0.0
    batch_count = 0
    for images, masks in train_loader:
        # Move the batch to the same device as the model.
        images = images.to(device)
        masks = masks.to(device)

        # NOTE(review): the original comment says the masks hold one class
        # index per pixel; their exact shape depends on
        # SegmentationDatasetV2 -- confirm.
        outputs = model(images)
        loss = loss_fn(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batch_count += 1

    # Save a checkpoint at the end of every epoch.
    torch.save(model.state_dict(), f'{checkpoint_dir}/best_model_epoch_{epoch}.pth')

    # Report the mean loss over the epoch rather than the last batch only.
    # An empty loader yields NaN instead of a NameError.
    epoch_loss = running_loss / batch_count if batch_count else float('nan')
    print(f'Epoch {epoch}, Loss: {epoch_loss}')