In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import os

import torch
import albumentations as A
import datetime
import json

from torch.utils.data import DataLoader
from Module.dataset.WaterDataset import TrainDataset, ValDataset
from Module.sam import SamPredictor, sam_model_registry
from Module.utils.util import get_device, setting2json
from Module.trainer.full_train import train_one_epoch
from Module.trainer.fine_tuning import tune_one_epoch
from Module.trainer.validation import val_one_epoch
from Module.utils.text_writer import TextWriter

  from .autonotebook import tqdm as notebook_tqdm
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)


# Setting

In [3]:
setting = {}

lr = 1e-5
bs = 2
N_epoch = 30
device = get_device()

## Selelct "fine_tunning" or "full_train"
train_type = "fine_tunning"

setting["train"] = {}
setting["train"]["train_type"] = train_type
setting["train"]["learning_rate"] = lr
setting["train"]["batch_size"] = bs
setting["train"]["epochs"] = N_epoch
setting["device"] = device.type

# Model

## Load pre-trained model

In [4]:
model_type = 'vit_b'
checkpoint = None

if model_type=="vit_h":
    checkpoint = '../Weights/sam_vit_h_4b8939.pth'
elif model_type=="vit_l":
    checkpoint = '../Weights/sam_vit_l_0b3195.pth' 
elif model_type=="vit_b":
    checkpoint = '../Weights/sam_vit_b_01ec64.pth' 
    

In [5]:
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)
sam_model.train();

  state_dict = torch.load(f)


In [6]:
enc_img_size = sam_model.image_encoder.img_size

setting["model"] = {}
setting["model"]["type"] = model_type
setting["model"]["checkpoint"] = checkpoint
setting["model"]["enc_img_size"] = enc_img_size

# Dataset

In [7]:
dataset_name = "Danu_WS_v2"
desired_size=(enc_img_size, enc_img_size)

## Train dataset


### Transform

In [8]:
darkness_transform = A.RandomBrightnessContrast(
    brightness_limit=(-0.5, -0.2),  # 밝기 감소 범위
    contrast_limit=0.1,  # 대비 조정 (옵션)
    p=0.5
)

noise_transform = A.GaussNoise(var_limit=(10.0, 50.0), 
                               p=0.5)
night_color_transform = A.OneOf([
    A.HueSaturationValue(hue_shift_limit=0, 
                         sat_shift_limit=-30, 
                         val_shift_limit=-50, 
                         p=0.7),  # 어둡고 낮은 채도
    A.RGBShift(r_shift_limit=-20, 
               g_shift_limit=-20, 
               b_shift_limit=30, 
               p=0.3)  # 밤 색조
], p=1.0)

contrast_transform = A.RandomBrightnessContrast(
    brightness_limit=(-0.4, 0.2), 
    contrast_limit=(-0.2, 0.2),
    p=0.7
)

In [9]:
train_transform = A.Compose([
    darkness_transform,       # 어둡게 하기
    noise_transform,          # 노이즈 추가
    night_color_transform,    # 밤 색상 변환
    contrast_transform,       # 대비 및 색상 조정
    A.RandomShadow(p=0.5),
    A.RandomRotate90(p=1),
    A.RandomGridShuffle(p=0.5)
])

### build train dataset

In [10]:
train_img_dir = r"D:\000_Datasets\20.Water segmentation\DANU_WS_v2\train\images"
train_gt_dir = r"D:\000_Datasets\20.Water segmentation\DANU_WS_v2\train\masks"

In [11]:
train_dataset = TrainDataset(train_img_dir, 
                             train_gt_dir, 
                             desired_size=desired_size,
                             transform=train_transform)


In [12]:
train_loader=DataLoader(train_dataset, 
                        shuffle=True, 
                        batch_size=bs, 
                        num_workers=0)

In [13]:
setting["Train_dataset"] = {}
setting["Train_dataset"]["dataset_name"] = dataset_name
setting["Train_dataset"]["Directory"] = train_img_dir
setting["Train_dataset"]["shuffle"] = True,
setting["Train_dataset"]["desired_size"] = desired_size
setting["Train_dataset"]["Augmentation"] = A.to_dict(train_transform)

### Validation dataset

In [14]:
val_img_dir = r"D:\000_Datasets\20.Water segmentation\DANU_WS_v2\valid\images"
val_gt_dir = r"D:\000_Datasets\20.Water segmentation\DANU_WS_v2\valid\masks"

In [15]:
val_dataset = ValDataset(val_img_dir, 
                         val_gt_dir,
                         desired_size=desired_size)


In [16]:
val_loader=DataLoader(val_dataset, 
                        shuffle=False, 
                        batch_size=1, 
                        num_workers=0)

In [17]:
setting["Valid_dataset"] = {}
setting["Valid_dataset"]["dataset_name"] = dataset_name
setting["Valid_dataset"]["Directory"] = val_img_dir
setting["Valid_dataset"]["desired_size"] = desired_size
setting["Valid_dataset"]["shuffle"] = False

# setting 

## Tunning-setting

### Optimizer

In [18]:
if train_type == "fine_tunning":

    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), 
                                 lr=lr, 
                                 weight_decay=0)
elif train_type == "full_train":
    optimizer = torch.optim.Adam(sam_model.parameters(), 
                                 lr=lr, 
                                 weight_decay=0)

setting["train"]["optimizer"] = "adam"

### Loss

In [19]:
loss_fn = torch.nn.BCEWithLogitsLoss() #torch.nn.MSELoss()

setting["train"]["optimizer"] = "BCEwithLogits"

### Warm-up

In [20]:
nw = 5
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
                                              lr_lambda=lambda epoch: epoch / nw  if epoch < nw else 1)

setting["train"]["warm-up_epoch"] = nw

## Setting directory

In [21]:
def check_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)
        
def get_output_dir(base_dir, model_type, dataset_type, note=""):
    
    now = datetime.datetime.now()
    formattedDate = now.strftime("%Y%m%d")[2:]
    

    save_dir_name = f"{formattedDate}_{model_type}_on_{dataset_type}_{note}" 
    
    
    weight_dir = os.path.join(base_dir, save_dir_name)
    
    fig_dir = os.path.join(weight_dir,"figures")
    
    return weight_dir, fig_dir
    

In [22]:
base_dir = "../runs/"
note = "ft_only_decoder_with_pos_and_neg_v4"
output_checkpoint_dir, figure_dir = get_output_dir(base_dir, model_type, dataset_name, note)

print(output_checkpoint_dir)

../runs/241202_vit_b_on_Danu_WS_v2_ft_only_decoder_with_pos_and_neg_v4


In [23]:
check_dir(output_checkpoint_dir)
check_dir(figure_dir)

In [24]:
output_checkpoint_path = os.path.join(output_checkpoint_dir, 
                                      "best.pth")

## Save setting

In [25]:
output_json_path = os.path.join(output_checkpoint_dir, "setting.json")
output_json = setting2json(setting)

with open(output_json_path, 'w', encoding='UTF-8') as outfile:
    json.dump(output_json, outfile, indent=4)

## Setting recoder

In [26]:
log_path = os.path.join(output_checkpoint_dir, "train_logs.txt")
recoder = TextWriter(log_path)

### Setting

In [None]:
best_val_loss = float('inf')

for epoch in range(N_epoch):
    epoch_train_loss, epoch_train_accuracy = tune_one_epoch(model=sam_model,
                                                             data_loader=train_loader,
                                                             optimizer=optimizer,
                                                             loss_fn=loss_fn,
                                                             device=device)
    train_txt1 = f'[{epoch}] Mean training loss: {epoch_train_loss}'
    train_txt2 = f'[{epoch}] Mean training accuracy: {epoch_train_accuracy}'
    print(train_txt1)
    print(train_txt2)
    recoder.add_line(train_txt1+"\n")
    recoder.add_line(train_txt2+"\n")
    
    epoch_val_loss, epoch_val_accuracy = val_one_epoch(model=sam_model,
                                                       data_loader=val_loader,
                                                       loss_fn=loss_fn,
                                                       device=device)
    # Validation loop
    # Calculate mean validation loss for the current epoch
    val_txt1 = f'[{epoch}] Mean validation loss: {epoch_val_loss}'
    print(val_txt1)

    # Calculate mean validation accuracy for the current epoch
    val_txt2 = f'[{epoch}] Mean validation accuracy: {epoch_val_accuracy}'
    print(val_txt2)
    
    recoder.add_line(val_txt1+"\n")
    recoder.add_line(val_txt2+"\n")
    # Save the model checkpoint if the validation accuracy improves
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(sam_model.state_dict(), output_checkpoint_path)

    # Clear GPU cache after each epoch
    scheduler.step()
    torch.cuda.empty_cache()


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [20:02<00:00,  1.05it/s]


[0] Mean training loss: 0.719346374415216
[0] Mean training accuracy: 0.684337610290164


100%|████████████████████████████████████████████████████████████████████████████████| 359/359 [02:35<00:00,  2.31it/s]


[0] Mean validation loss: 0.6817650120072378
[0] Mean validation accuracy: 0.7398429248990455


 23%|█████████████████▊                                                             | 285/1260 [04:39<15:56,  1.02it/s]