In [36]:
import torch
from segment_anything import SamPredictor, sam_model_registry
from torch.nn.functional import threshold, normalize
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import cv2
import os
import numpy as np
import random
import pandas as pd
from tqdm import tqdm

### 기본설정

In [37]:
# Central experiment configuration: data locations, reproducibility settings,
# early stopping, and per-phase (train / inference) loader and optimiser knobs.
config = dict(
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    base_path='C:/SWdacon/minseo/data/',  # data directory (absolute, machine-specific) - change per machine
    train_data='train.csv',  # training CSV file name inside base_path
    test_data='test20.csv',  # test CSV file name inside base_path
    seed=42,
    valid_size=0.3,  # fraction of the training CSV held out for validation
    early_stopping=3,  # consecutive epochs without a new best validation loss before stopping
    train=dict(
        batch_size=4,
        num_workers=1,
        epochs=5,
        lr=1e-4,
        wd=0,  # weight decay
    ),
    inference=dict(
        batch_size=4,
        num_workers=1,
        threshold=0.35,  # presumably the cutoff for binarising predicted masks at inference - confirm
    ),
)

In [38]:
# Load the SAM ViT-H backbone from a local checkpoint (path is machine-specific).
sam_model = sam_model_registry['vit_h'](checkpoint='C:/SAM/segment-anything/checkpoints/sam_vit_h_4b8939.pth')
# Optimiser over the mask decoder only; lr and weight decay come from config so
# they stay in sync with the values training() uses.
optimizer = torch.optim.Adam(
    sam_model.mask_decoder.parameters(),
    lr=config['train']['lr'],
    weight_decay=config['train']['wd'],
)
loss_fn = torch.nn.MSELoss()  # NOTE(review): training() builds its own BCEWithLogitsLoss; this one appears unused
# Single source of truth for the compute device (same expression as config['device']).
device = config['device']

In [39]:
# 시드 고정 함수
def fix_seed(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices). ``PYTHONHASHSEED`` is also set; it does not change hashing in the
    already-running interpreter, but is inherited by subprocesses started later.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without CUDA
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN choose kernels by timing, which can differ run to
    # run and undermines `deterministic`; keep it off for reproducible results.
    torch.backends.cudnn.benchmark = False

# RLE 디코딩 함수
def rle_decode(mask_rle, shape):
    """Decode a run-length-encoded string into a 2-D uint8 mask.

    Args:
        mask_rle: whitespace-separated "start length start length ..." tokens;
            starts are 1-based positions in the row-major flattened image.
        shape: (height, width) of the mask to produce.

    Returns:
        np.uint8 array of ``shape`` with 1 inside every run and 0 elsewhere.
    """
    tokens = mask_rle.split()
    run_starts = np.asarray(tokens[0::2], dtype=int) - 1  # convert to 0-based offsets
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    return flat.reshape(shape)

# RLE 인코딩 함수
def rle_encode(mask):
    """Run-length-encode a binary mask as "start length start length ...".

    Starts are 1-based positions in the row-major flattened mask; an all-zero
    mask encodes to the empty string.

    Args:
        mask: array-like of 0/1 values (any shape).

    Returns:
        Space-separated string of alternating run starts and run lengths.
    """
    # Zero-pad both ends so every run of ones has a detectable rising and falling edge.
    padded = np.concatenate([[0], mask.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Edges alternate start/end; turn each end into a length.
    edges[1::2] -= edges[0::2]
    return ' '.join(map(str, edges))

### dataset

In [40]:
from segment_anything.utils.transforms import ResizeLongestSide
from collections import defaultdict

In [41]:
class CustomDataset(Dataset):
    """Image (+ optional RLE mask) dataset that yields SAM-ready CPU tensors.

    Each image is resized so its longest side equals the SAM image-encoder input
    size. It is then normalised with SAM's pixel mean/std and zero-padded on the
    bottom/right to a square, mirroring ``sam_model.preprocess``.

    Everything is computed on the CPU, with no per-item batch dimension. This means:
      * the dataset does not depend on which device ``sam_model`` currently lives on;
      * it can run inside DataLoader worker processes;
      * default collation produces ``(B, 3, S, S)`` images and ``(B, 1, H, W)`` masks.
    Move batches to the compute device in the training loop.

    Args:
        img_paths: pandas Series of image file paths (indexed positionally with ``.iloc``).
        mask_rles: pandas Series of RLE mask strings aligned with ``img_paths``;
            required unless ``infer`` is True.
        infer: if True, items carry no ground-truth mask.

    Items:
        ``infer=True``  -> dict with 'image', 'input_size', 'original_image_size'.
        ``infer=False`` -> (that dict, float32 ground-truth mask of shape (1, H, W)).
    """

    def __init__(self, img_paths, mask_rles=None, infer=False):
        self.img_paths = img_paths
        self.mask_rles = mask_rles
        self.infer = infer
        self.img_size = sam_model.image_encoder.img_size
        self.transform = ResizeLongestSide(self.img_size)
        # CPU snapshots of SAM's normalisation buffers, so a later sam_model.to(cuda)
        # does not affect this dataset. Assumed to broadcast over (3, H, W) as in
        # SAM's own preprocess - confirm if the segment_anything version changes.
        self.pixel_mean = sam_model.pixel_mean.detach().cpu().clone()
        self.pixel_std = sam_model.pixel_std.detach().cpu().clone()

    def __len__(self):
        return len(self.img_paths)

    def _prepare_image(self, image):
        """Resize, normalise and pad an RGB (H, W, 3) array into SAM's (3, S, S) input.

        Returns the CPU float tensor and the resized (pre-padding) (h, w) size.
        """
        resized = self.transform.apply_image(image)
        chw = torch.as_tensor(resized).permute(2, 0, 1).contiguous()
        input_size = tuple(chw.shape[-2:])
        normalised = (chw - self.pixel_mean) / self.pixel_std
        pad_h = self.img_size - input_size[0]
        pad_w = self.img_size - input_size[1]
        padded = nn.functional.pad(normalised, (0, pad_w, 0, pad_h))
        return padded, input_size

    def __getitem__(self, idx):
        img_path = self.img_paths.iloc[idx]
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals unreadable/missing files by returning None.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        input_image, input_size = self._prepare_image(image)
        transformed_data = {
            'image': input_image,
            'input_size': input_size,
            'original_image_size': image.shape[:2],
        }
        if self.infer:
            return transformed_data

        # NOTE(review): the RLE is decoded row-major over (H, W), as rle_decode's reshape
        # implies - confirm this matches how the competition encoded the masks.
        mask = rle_decode(self.mask_rles.iloc[idx], image.shape[:2])
        # (1, H, W) so a collated batch lines up with the (B, 1, H, W) decoder output.
        gt_binary_mask = (torch.from_numpy(mask) > 0).float().unsqueeze(0)
        return transformed_data, gt_binary_mask

In [42]:
# Load the training CSV and hold out config['valid_size'] of it for validation
# (seeded split so the partition is reproducible).
train_df = pd.read_csv("{}/{}".format(config['base_path'], config['train_data']))
train, val = train_test_split(
    train_df,
    test_size=config['valid_size'],
    random_state=config['seed'],
)
print("train: ", len(train), "   valid: ", len(val))

train:  4998    valid:  2142


In [43]:
# Seed first so the shuffled training order is reproducible.
fix_seed(config['seed'])

train_dataset = CustomDataset(img_paths=train['img_path'], mask_rles=train['mask_rle'])
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config['train']['batch_size'],
    shuffle=True,
    num_workers=config['train']['num_workers'],
)

valid_dataset = CustomDataset(img_paths=val['img_path'], mask_rles=val['mask_rle'])
# Validation is only scored, never learned from, so there is no reason to shuffle.
# A fixed order keeps batch composition (including the last partial batch) and
# therefore the mean-of-batch-means validation loss stable across epochs.
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=config['train']['batch_size'],
    shuffle=False,
    num_workers=config['train']['num_workers'],
)

### train / validation

In [44]:
def validation(config, model, criterion, valid_loader):
    """Return the mean per-batch validation loss of a prompt-free SAM model.

    For each batch:
      * the image encoder, the prompt encoder (with no prompts) and the mask
        decoder are run;
      * the decoder's low-resolution logits are upsampled to the original image
        size with ``model.postprocess_masks``;
      * the logits are scored against the ground-truth masks.
    Gradients are disabled throughout, and the model is left in eval mode.

    Args:
        config: experiment config; ``config['device']`` is the compute device.
        model: SAM model (must expose ``image_encoder``, ``prompt_encoder``,
            ``mask_decoder`` and ``postprocess_masks``).
        criterion: loss taking (mask logits, binary float targets), e.g.
            ``torch.nn.BCEWithLogitsLoss``.
        valid_loader: DataLoader over ``CustomDataset`` yielding
            ``(transformed_datas, gt_masks)``.

    Returns:
        Sum of per-batch losses divided by ``len(valid_loader)``.
    """
    device = config['device']
    model.eval()
    valid_loss = 0.0

    with torch.no_grad():
        for transformed_datas, gt_masks in tqdm(valid_loader):
            input_image = transformed_datas['image'].to(device)
            if input_image.dim() == 5:
                # Items that keep their own batch dim collate to (B, 1, 3, S, S).
                input_image = input_image.flatten(0, 1)
            # Default collation turns each (h, w) tuple into [tensor(B), tensor(B)];
            # postprocess_masks needs plain ints. This assumes every sample in a batch
            # shares one size (needed anyway for the GT masks to be stacked).
            input_size = tuple(int(torch.as_tensor(s).reshape(-1)[0]) for s in transformed_datas['input_size'])
            original_image_size = tuple(
                int(torch.as_tensor(s).reshape(-1)[0]) for s in transformed_datas['original_image_size']
            )
            gt_masks = gt_masks.to(device)
            # Accept (B, 1, H, W) or (B, 1, 1, H, W) targets; match the (B, 1, H, W) logits.
            gt_masks = gt_masks.reshape(gt_masks.shape[0], 1, *gt_masks.shape[-2:])

            image_embedding = model.image_encoder(input_image)
            sparse_embeddings, dense_embeddings = model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
            )
            low_res_masks, _ = model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            mask_logits = model.postprocess_masks(low_res_masks, input_size, original_image_size)
            # Raw logits: a logits-based criterion applies the sigmoid itself.
            loss = criterion(mask_logits, gt_masks)
            valid_loss += loss.item()

    return valid_loss / len(valid_loader)

In [45]:
def training(config, model, train_loader, valid_loader):
    """Fine-tune SAM's mask decoder with early stopping on validation loss.

    The image and prompt encoders run under ``torch.no_grad()``. Only
    ``model.mask_decoder`` parameters are optimised (Adam, with ``lr`` and ``wd``
    taken from ``config['train']``). After each epoch the model is scored with
    ``validation``. Training stops once ``config['early_stopping']`` consecutive
    epochs pass without a new minimum. The decoder weights from the best epoch
    are then restored.

    Args:
        config: experiment config (device, train hyper-parameters, early stopping).
        model: SAM model to fine-tune; moved to ``config['device']``.
        train_loader: DataLoader over ``CustomDataset`` for training.
        valid_loader: DataLoader over ``CustomDataset`` for validation.

    Returns:
        ``model`` with its mask decoder set to the best-validation weights.
        If no epoch improved on the initial ``inf`` (e.g. NaN losses), the final
        weights are kept.
    """
    device = config['device']
    model = model.to(device)
    es_count = 0
    min_val_loss = float('inf')
    best_epoch = None  # 1-based, matching the per-epoch log lines
    best_decoder_state = None

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(
        model.mask_decoder.parameters(),
        lr=config['train']['lr'],
        weight_decay=config['train']['wd'],
    )

    for epoch in range(config['train']['epochs']):
        model.train()
        epoch_loss = 0.0
        for transformed_datas, gt_masks in tqdm(train_loader):
            input_image = transformed_datas['image'].to(device)
            if input_image.dim() == 5:
                # Items that keep their own batch dim collate to (B, 1, 3, S, S).
                input_image = input_image.flatten(0, 1)
            # Collated (h, w) tuples arrive as [tensor(B), tensor(B)]; postprocess_masks needs
            # ints. Assumes one shared size per batch (required for GT masks to stack).
            input_size = tuple(int(torch.as_tensor(s).reshape(-1)[0]) for s in transformed_datas['input_size'])
            original_image_size = tuple(
                int(torch.as_tensor(s).reshape(-1)[0]) for s in transformed_datas['original_image_size']
            )
            gt_masks = gt_masks.to(device)
            gt_masks = gt_masks.reshape(gt_masks.shape[0], 1, *gt_masks.shape[-2:])

            # Frozen encoders: no activations kept for backprop.
            with torch.no_grad():
                image_embedding = model.image_encoder(input_image)
                sparse_embeddings, dense_embeddings = model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                )
            low_res_masks, _ = model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            mask_logits = model.postprocess_masks(low_res_masks, input_size, original_image_size)
            # Feed raw logits. BCEWithLogitsLoss applies the sigmoid itself. Normalising over
            # the size-1 mask channel instead squashes values to ~{0, 1} with zero gradient,
            # so the decoder would never update.
            loss = criterion(mask_logits, gt_masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        val_loss = validation(config, model, criterion, valid_loader)
        es_count += 1
        if min_val_loss > val_loss:
            es_count = 0
            min_val_loss = val_loss
            best_epoch = epoch + 1
            # Copy (not alias) the weights so later epochs cannot overwrite the best snapshot.
            # Only the decoder is trained, so only its state needs saving; keep it on CPU.
            best_decoder_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.mask_decoder.state_dict().items()
            }
            print(f"Epoch [{epoch + 1}] New Minimum Valid Loss!")

        train_loss = epoch_loss / max(len(train_loader), 1)
        print(f'Epoch {epoch+1}, Train Loss: {train_loss}', 'Valid Loss:', val_loss, 'ES Count:', es_count)
        print("------------------------------------------------------------------------------------")

        if es_count == config['early_stopping']:
            print(f"EARLY STOPPING COUNT: {config['early_stopping']} BEST EPOCH: {best_epoch}")
            break
    else:
        # Loop ran all epochs without triggering early stopping.
        print(f"EARLY STOPPING COUNT에 도달하지 않았습니다! \nEARLY STOPPING COUNT: {config['early_stopping']} BEST EPOCH: {best_epoch}")

    if best_decoder_state is not None:
        model.mask_decoder.load_state_dict(best_decoder_state)
    return model