In [15]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from utils.utils import train_model, evaluate_model_with_cm, TiffDataset
from models.video_classifier import VideoClassifier
import os

In [16]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [17]:
test_filter = lambda box_number: (box_number % 9 == 0 or box_number % 9 == 5)

In [18]:
class ReshapeTransform:
    """Split the stacked channel axis into separate band and time axes.

    Converts a ``(time_steps * bands, H, W)`` tensor into
    ``(bands, time_steps, H, W)``. The stacked axis is interpreted as
    time-major: all bands of step 0, then all bands of step 1, and so on.

    H and W are taken from the input instead of being hard-coded, so the
    transform works for any patch size.

    Args:
        bands: number of bands per time step.
        time_steps: number of time steps stacked along the channel axis.
            Defaults to 12, the value that was previously hard-coded.
    """
    def __init__(self, bands, time_steps=12):
        self.bands = bands
        self.time_steps = time_steps

    def __call__(self, x):
        """Return ``x`` viewed as ``(bands, time_steps, H, W)``.

        The result is a permuted view of ``x`` (no copy), so ``x`` must be
        contiguous, as required by ``Tensor.view``.
        """
        # Use the input's own spatial size rather than a fixed 9x9.
        height, width = x.shape[-2:]
        stacked = x.view(self.time_steps, self.bands, height, width)
        # (time, bands, H, W) -> (bands, time, H, W)
        return stacked.permute(1, 0, 2, 3)

def get_transform(bands, scale_channels_func=None):
    """Build the per-sample preprocessing pipeline.

    The pipeline applies, in order:
      1. ``transforms.ToTensor()``: (H, W, C) -> (C, H, W).
      2. A cast to float (the original note says the input is uint16;
         TODO confirm against the dataset).
      3. ``ReshapeTransform(bands)``: splits the channel axis, giving
         (bands, 12, H, W).
      4. Optionally ``scale_channels_func``, for per-channel value scaling.

    Args:
        bands: number of bands per time step, forwarded to ReshapeTransform.
        scale_channels_func: optional callable applied to the reshaped
            tensor; skipped when falsy.

    Returns:
        A ``transforms.Compose`` wrapping the steps above.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Lambda(lambda tensor: tensor.float()),
        ReshapeTransform(bands),
    ]

    # Per-channel scaling is optional and always runs last.
    if scale_channels_func:
        steps += [transforms.Lambda(scale_channels_func)]

    return transforms.Compose(steps)

# Scaling function: value ranges differ a lot between channels, so this
# brings them roughly in line with each other.
def scale_channels(x):
    """Scale selected channels of ``x`` in place and return ``x``.

    Indices 0-2 along the first axis are multiplied by 5 (B, G, R per the
    original author's note). Index 4 is multiplied by 0.5 (NDVI per the
    original author's note), but only when the first axis has more than
    four entries.
    """
    # (index along axis 0, multiplier) pairs, applied in order.
    channel_factors = [(slice(0, 3), 5)]
    if x.shape[0] > 4:
        channel_factors.append((4, 0.5))

    for index, factor in channel_factors:
        x[index] *= factor
    return x

In [19]:
large_tif_dir = '../../data/source_data/naive'  # source data directory
bands = 4  # number of bands
patch_size = 9

transform = get_transform(bands, scale_channels)

# Both splits read every region and the same label file; they differ only in
# which boxes they keep.
region_files = ("jiri_1.tif", "jiri_2.tif", "sobaek.tif")
label_csv = "../../data/label_data/species/label_mapping_sampled.csv"


def _dataset_kwargs():
    """Return the arguments shared by both datasets, with a fresh file list each call."""
    return dict(
        large_tif_dir=large_tif_dir,
        file_list=list(region_files),
        label_file=label_csv,
        patch_size=patch_size,
        transform=transform,
    )


# Training keeps every box that test_filter rejects.
train_dataset = TiffDataset(
    box_filter_fn=lambda box_number: not test_filter(box_number),
    **_dataset_kwargs(),
)

# Validation keeps exactly the boxes that test_filter accepts.
val_dataset = TiffDataset(box_filter_fn=test_filter, **_dataset_kwargs())

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [20]:
# Model configuration
stage_repeats = [2, 2, 4, 3]        # ResBlock repetitions in each stage
stage_channels = [16, 32, 64, 128]  # channel width of each stage
num_classes = 6                     # number of classes to predict

# Loss function
criterion = nn.CrossEntropyLoss()

# Build the model, then move it to the selected device
model = VideoClassifier(bands, stage_repeats, stage_channels, num_classes=num_classes)
model = model.to(device)

# The optimizer is created after the move so it tracks the on-device parameters
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)

In [21]:
num_epochs = 30

# Create the checkpoint directory before training so that a bad or unwritable
# path fails immediately rather than after hours of training.
checkpoint_dir = "./checkpoints/video_classification_enhanced"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = f"{checkpoint_dir}/cnn_{bands}_{patch_size}_{num_epochs}.pth"

try:
    # NOTE(review): patience=100 is larger than num_epochs. If `patience` is an
    # early-stopping window it can never trigger -- confirm this is intended.
    best_model_state, train_losses, val_losses = train_model(
        model, train_loader, val_loader, criterion, optimizer,
        num_epochs=num_epochs, patience=100,
    )
except KeyboardInterrupt:
    # A manual interrupt would otherwise discard all progress. Save whatever
    # weights the model holds now (assumes train_model updates `model` in
    # place -- TODO confirm), then re-raise so the cell still stops.
    torch.save(
        model.state_dict(),
        f"{checkpoint_dir}/cnn_{bands}_{patch_size}_{num_epochs}_interrupted.pth",
    )
    raise

torch.save(best_model_state, checkpoint_path)

# Evaluate using the best weights returned by train_model.
model.load_state_dict(best_model_state)

# Use the configured num_classes so evaluation stays in sync with the model.
print("\ntrain data")
evaluate_model_with_cm(model, train_loader, num_classes=num_classes)
print("\nvalidation data")
evaluate_model_with_cm(model, val_loader, num_classes=num_classes)

Epoch 1/30 - Training:   0%|          | 0/3502 [00:00<?, ?it/s]

Epoch 1/30 - Training: 100%|██████████| 3502/3502 [2:22:14<00:00,  2.44s/it]  
Epoch 1/30 - Validation: 100%|██████████| 1207/1207 [15:53<00:00,  1.27it/s]



Epoch [1/30], Train Loss: 2.1599, Train Accuracy: 29.59%, Val Loss: 1.7528, Val Accuracy: 20.36%



Epoch 2/30 - Training: 100%|██████████| 3502/3502 [2:20:46<00:00,  2.41s/it]  
Epoch 2/30 - Validation: 100%|██████████| 1207/1207 [13:47<00:00,  1.46it/s]



Epoch [2/30], Train Loss: 1.5682, Train Accuracy: 34.57%, Val Loss: 1.5861, Val Accuracy: 25.50%



Epoch 3/30 - Training: 100%|██████████| 3502/3502 [2:21:42<00:00,  2.43s/it]  
Epoch 3/30 - Validation: 100%|██████████| 1207/1207 [13:42<00:00,  1.47it/s]



Epoch [3/30], Train Loss: 1.3416, Train Accuracy: 43.69%, Val Loss: 1.2637, Val Accuracy: 37.90%



Epoch 4/30 - Training:  43%|████▎     | 1519/3502 [1:02:38<1:21:47,  2.47s/it]


KeyboardInterrupt: 