In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import os
import cv2
import numpy as np
from tqdm import tqdm
from glob import glob
import torch
from collections import defaultdict
from sklearn.metrics import auc, roc_curve
import matplotlib.pyplot as plt
from pathlib import Path
from statistics import mean

from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.functional import threshold, normalize
import albumentations as A
from albumentations.pytorch import ToTensorV2
from segment_anything.utils.transforms import ResizeLongestSide
from segment_anything import SamPredictor, sam_model_registry

from Module.dataset.WaterDataset import TrainDataset, ValDataset
from Module.utils.metric import calculate_accuracy
from Module.trainer.full_train import train_one_epoch
from Module.trainer.validation import val_one_epoch
from Module.utils.text_writer import TextWriter

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


## Model

In [3]:
model_type = 'vit_b'
checkpoint = None

if model_type=="vit_h":
    checkpoint = '../Weights/sam_vit_h_4b8939.pth'
elif model_type=="vit_l":
    checkpoint = '../Weights/sam_vit_l_0b3195.pth' 
elif model_type=="vit_b":
    checkpoint = '../Weights/sam_vit_b_01ec64.pth' 
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)
sam_model.train();

In [5]:
enc_img_size = sam_model.image_encoder.img_size
print(enc_img_size)

1024


## Datasets

In [6]:
bs=2
desired_size=(1024, 1024)

In [7]:
train_transform = A.Compose([
    A.RandomBrightnessContrast(p=1, brightness_limit=(-0.5,0.2)),
    A.RandomShadow(p=1),
    A.RandomRotate90(p=1),
    A.RandomGridShuffle(p=0.7)
])

### Train datasets

### Train dataset

In [8]:
train_img_dir = r"D:\WaterSegmentation\Datasets\DANU_WS_v2\train\images"

train_gt_dir = r"D:\WaterSegmentation\Datasets\DANU_WS_v2\train\masks"

In [9]:
img_dir = r"D:\WaterSegmentation\Datasets\AY_frames_241115\02_dataset_format\images"
gt_dir = r"D:\WaterSegmentation\Datasets\AY_frames_241115\02_dataset_format\masks"

In [10]:
train_img_dir = img_dir

train_gt_dir = gt_dir

In [9]:
train_dataset = TrainDataset(train_img_dir, 
                             train_gt_dir, 
                             transform=train_transform)


In [10]:
train_loader=DataLoader(train_dataset, 
                        shuffle=True, 
                        batch_size=bs, 
                        num_workers=0)

### Validation dataset

In [14]:
val_img_dir = r"D:\WaterSegmentation\Datasets\DANU_WS_v1\valid\images"
val_gt_dir = r"D:\WaterSegmentation\Datasets\DANU_WS_v1\valid\masks"

In [15]:
val_dataset = ValDataset(val_img_dir, 
                         val_gt_dir)


In [16]:
val_loader=DataLoader(val_dataset, 
                        shuffle=False, 
                        batch_size=1, 
                        num_workers=0)

## Fine tuning

### Setting

In [17]:
optimizer = torch.optim.Adam(sam_model.parameters(), 
                             lr=1e-5, 
                             weight_decay=0)

loss_fn = torch.nn.BCEWithLogitsLoss() #torch.nn.MSELoss()
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define batch size

num_epochs = 30
nw = 5 # warm up epochs
save_dir_name = "241118_SAM_ViT_b_dataset_V2_ft_v2"
model_name = "best.pth"

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
                                              lr_lambda=lambda epoch: epoch / nw  if epoch < nw else 1)

In [18]:
save_dir = os.path.join('../runs',save_dir_name)
if not os.path.exists(save_dir): 
    os.makedirs(save_dir) 
    
log_path = os.path.join(save_dir, "train_logs.txt")

In [None]:
best_val_loss = float('inf')

recoder = TextWriter(log_path)

for epoch in range(num_epochs):
    epoch_train_loss, epoch_train_accuracy = train_one_epoch(model=sam_model,
                                                            data_loader=train_loader,
                                                            optimizer=optimizer,
                                                            loss_fn=loss_fn,
                                                             device=device)
    train_txt1 = f'[{epoch}] Mean training loss: {epoch_train_loss}'
    train_txt2 = f'[{epoch}] Mean training accuracy: {epoch_train_accuracy}'
    print(train_txt1)
    print(train_txt2)
    recoder.add_line(train_txt1+"\n")
    recoder.add_line(train_txt2+"\n")
    
    epoch_val_loss, epoch_val_accuracy = val_one_epoch(model=sam_model,
                                                       data_loader=val_loader,
                                                       loss_fn=loss_fn,
                                                       device=device)
    # Validation loop
    # Calculate mean validation loss for the current epoch
    val_txt1 = f'[{epoch}] Mean validation loss: {epoch_val_loss}'
    print(val_txt1)

    # Calculate mean validation accuracy for the current epoch
    val_txt2 = f'[{epoch}] Mean validation accuracy: {epoch_val_accuracy}'
    print(val_txt2)
    
    recoder.add_line(val_txt1+"\n")
    recoder.add_line(val_txt2+"\n")
    # Save the model checkpoint if the validation accuracy improves
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(sam_model.state_dict(), os.path.join(save_dir, model_name))

    # Clear GPU cache after each epoch
    scheduler.step()
    torch.cuda.empty_cache()


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [21:25<00:00,  1.02s/it]


[0] Mean training loss: 0.7099401380334581
[0] Mean training accuracy: 0.6989813804626465


100%|████████████████████████████████████████████████████████████████████████████████| 359/359 [01:21<00:00,  4.39it/s]


[0] Mean validation loss: 0.6781796523289428
[0] Mean validation accuracy: 0.7554335315247432


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [20:33<00:00,  1.02it/s]


[1] Mean training loss: 0.6949200826977927
[1] Mean training accuracy: 0.7161010253997077


100%|████████████████████████████████████████████████████████████████████████████████| 359/359 [01:22<00:00,  4.33it/s]


[1] Mean validation loss: 0.6931471824645996
[1] Mean validation accuracy: 0.6902136523743526


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [20:36<00:00,  1.02it/s]


[2] Mean training loss: 0.6893020647149237
[2] Mean training accuracy: 0.730247780254909


100%|████████████████████████████████████████████████████████████████████████████████| 359/359 [01:23<00:00,  4.32it/s]


[2] Mean validation loss: 0.6650617015561021
[2] Mean validation accuracy: 0.7687299723080606


100%|██████████████████████████████████████████████████████████████████████████████| 1260/1260 [20:30<00:00,  1.02it/s]


[3] Mean training loss: 0.6694524266180538
[3] Mean training accuracy: 0.7908324211362808


100%|████████████████████████████████████████████████████████████████████████████████| 359/359 [01:23<00:00,  4.27it/s]


[3] Mean validation loss: 0.6331521552419264
[3] Mean validation accuracy: 0.858461799727841


 29%|███████████████████████                                                        | 367/1260 [05:55<14:56,  1.00s/it]