In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import random
from joblib import Parallel, delayed
import os
from typing import List, Union

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.nn.functional import threshold, normalize

from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

ImportError: libopenh264.so.5: cannot open shared object file: No such file or directory

In [None]:
# Central run configuration; every later cell reads its settings from this dict.
config = {
    # Use the GPU when one is visible to torch, otherwise fall back to CPU.
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    # NOTE(review): this is an absolute, machine-specific path; point it at your own data directory.
    'base_path': '/home/jo/Desktop/swdacon/open20',
    'train_data': 'zeroto20.csv', # file name of the training CSV inside base_path
    'test_data': 'test20.csv', # file name of the test CSV inside base_path
    'seed': 42,  # passed to fix_seed() and train_test_split()
    'valid_size': 0.1,  # fraction of the training CSV held out for validation
    'early_stopping': 15,  # epochs without validation-loss improvement before stopping
    'scheduler': True,  # step ReduceLROnPlateau on the validation loss when True
    'train' : {
       'batch_size' : 1,
       'num_workers': 8,
       'epochs': 100,
       'lr': 0.0005,
    },
    'inference' : {
       'batch_size' : 8,
       'num_workers': 8,
       'threshold': 0.35,  # not referenced in the cells shown here
    },
}
# Albumentations pipelines keyed by dataset split.
# Normalisation and tensor conversion are intentionally not part of these
# pipelines (A.Normalize / ToTensorV2 are left out), so each pipeline returns
# NumPy arrays.
custom_transform = {
    # Training: random 224x224 crop, always applied.
    'train': A.Compose([A.RandomCrop(224, 224, p=1.0)]),
    # Validation: fixed centre 224x224 crop, always applied.
    'valid': A.Compose([A.CenterCrop(224, 224, p=1.0)]),
    # Test: empty pipeline, images pass through unchanged.
    'test': A.Compose([]),
}

In [None]:
# Display the selected compute device as the cell output.
config['device']

In [None]:
# Seed-fixing helper
def fix_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU and
    all CUDA devices), and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # subprocesses); hashing in the already-running kernel is unchanged.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    # benchmark=True lets cuDNN auto-tune and pick different kernels between
    # runs, which defeats the deterministic flag above; it must be off.
    torch.backends.cudnn.benchmark = False  # type: ignore

# RLE decoding helper
def rle_decode(mask_rle: Union[str, int], shape=(224, 224)) -> np.ndarray:
    '''
    Decode a run-length-encoded mask into a binary image.

    mask_rle: run-length string formatted as "start length start length ..."
              with 1-based starts over the flattened (row-major) image.
              The sentinel -1 (int or the string "-1"), None, NaN and the
              empty string all denote "no mask".
    shape: (height, width) of the array to return.
    Returns a uint8 numpy array of the given shape, 1 - mask, 0 - background.
    '''
    if not isinstance(mask_rle, str):
        # -1 is the "no mask" sentinel; NaN (x != x) and None are what pandas
        # yields for empty CSV cells. Any other non-string falls through and
        # raises AttributeError on .split(), as before.
        if mask_rle is None or mask_rle == -1 or mask_rle != mask_rle:
            return np.zeros(shape, dtype=np.uint8)

    s = mask_rle.split()
    if not s or s == ['-1']:
        # Same dtype as the regular path so callers never see a float mask.
        return np.zeros(shape, dtype=np.uint8)

    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1  # convert 1-based starts to 0-based indices
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)

# RLE encoding helper
def rle_encode(mask):
    """Run-length encode a binary mask.

    The mask is flattened and returned as a space-separated string of
    "start length" pairs with 1-based starts. A mask with no foreground
    pixels encodes to the empty string.
    """
    flat = mask.flatten()
    # Zero-pad both ends so runs touching either border still produce a
    # start and an end transition.
    padded = np.concatenate([[0], flat, [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Odd slots hold run ends; turn them into run lengths.
    boundaries[1::2] = boundaries[1::2] - boundaries[::2]
    return ' '.join(map(str, boundaries))

# Dice score helper
def dice_score(prediction: np.array, ground_truth: np.array, smooth=1e-7) -> float:
    '''
    Dice coefficient of two binary masks.

    `smooth` is added to both numerator and denominator so that two empty
    masks yield 1.0 rather than a division by zero.
    '''
    overlap = np.sum(prediction * ground_truth)
    numerator = 2.0 * overlap + smooth
    denominator = np.sum(prediction) + np.sum(ground_truth) + smooth
    return numerator / denominator

def calculate_dice_scores(validation_df, img_shape=(224, 224)) -> float:
    '''
    Mean Dice score over a dataframe of predicted / ground-truth RLE masks.

    validation_df: dataframe whose column at position 3 holds the predicted
                   RLE and whose column at position 4 holds the ground-truth
                   RLE (positional access - the column names are not checked).
    img_shape: (height, width) used to decode every RLE.
    Returns the mean Dice score as a float. Rows where both masks are empty
    are excluded; if no row is left the result is NaN.
    '''
    # Extract the mask_rle columns (by position, see docstring).
    pred_mask_rle = validation_df.iloc[:, 3]
    gt_mask_rle = validation_df.iloc[:, 4]

    def calculate_dice(pred_rle, gt_rle):
        pred_mask = rle_decode(pred_rle, img_shape)
        gt_mask = rle_decode(gt_rle, img_shape)
        if np.sum(gt_mask) > 0 or np.sum(pred_mask) > 0:
            return dice_score(pred_mask, gt_mask)
        return None  # both masks empty: the row does not contribute

    dice_scores = Parallel(n_jobs=-1)(
        delayed(calculate_dice)(pred_rle, gt_rle) for pred_rle, gt_rle in zip(pred_mask_rle, gt_mask_rle)
    )
    valid_scores = [score for score in dice_scores if score is not None]
    if not valid_scores:
        # np.mean([]) would emit a RuntimeWarning and return NaN; make that explicit.
        return float('nan')
    return float(np.mean(valid_scores))

class DiceLoss(nn.Module):
    """Soft Dice loss computed on raw logits.

    `weight` and `size_average` are accepted for signature compatibility and
    are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` for a batch of logits and binary targets.

        Args:
            inputs: Raw model outputs (logits); a sigmoid is applied here.
            targets: Ground-truth mask with the same number of elements.
            smooth: Additive smoothing for numerator and denominator.
        """
        # Comment out if your model already ends with a sigmoid (or equivalent).
        # torch.sigmoid is used because no `F` alias is imported in this file.
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors. reshape (unlike view) also
        # accepts non-contiguous inputs.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1).to(inputs.dtype)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [None]:
# Enable when running locally

# Read the training CSV from base_path and hold out a validation split;
# random_state makes the split repeatable for a given seed.
train_df = pd.read_csv(f"{config['base_path']}/{config['train_data']}")
train, val = train_test_split(train_df, test_size=config['valid_size'], random_state=config['seed'])
print("train: ", len(train), "   valid: ", len(val))

### model

In [None]:
# Key into sam_model_registry and the checkpoint file name loaded in the next cell.
model_type = 'vit_b'
checkpoint = 'sam_hq_vit_b.pth'

In [None]:
from segment_anything import SamPredictor, sam_model_registry
# Build the SAM variant selected by model_type from the local checkpoint file.
# NOTE(review): the checkpoint directory is an absolute, machine-specific path.
sam_model = sam_model_registry[model_type](checkpoint='/home/jo/Desktop/swdacon/sam-hq/checkpoints/'+checkpoint)
# Move the model to the configured device; the dataset below calls sam_model.preprocess on tensors placed there.
sam_model = sam_model.to(config['device'])

### dataloader

In [None]:
from segment_anything.utils.transforms import ResizeLongestSide
# from mobile_sam.utils.transforms import ResizeLongestSide
from collections import defaultdict

In [None]:
class CustomDataset(Dataset):
    """Dataset yielding SAM-ready images and, unless `infer` is set, masks.

    Args:
        img_paths: pandas Series of image paths; the first character of each
            path is dropped and the rest is appended to config['base_path'].
        mask_rles: pandas Series of RLE masks aligned with `img_paths`.
            Required unless `infer` is True.
        infer: When True no mask is decoded and `__getitem__` returns only
            the image dict.
        transform: Optional albumentations pipeline.

    Relies on the notebook globals `config`, `sam_model` and `rle_decode`.
    """

    def __init__(self, img_paths, mask_rles = None,infer=False,transform=None):
        self.img_paths = img_paths
        self.mask_rles = mask_rles
        self.infer = infer
        self.preprocess = ResizeLongestSide(sam_model.image_encoder.img_size)
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return `(transformed_data, gt_mask)`, or `transformed_data` alone when `infer` is True.

        `transformed_data` is a dict with keys 'image' (output of
        sam_model.preprocess), 'input_size' (H, W after the longest-side
        resize) and 'original_image_size' (H, W after the augmentation).
        """
        img_path = self.img_paths.iloc[idx]
        full_path = config['base_path'] + img_path[1:]
        image = cv2.imread(full_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {full_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        gt_mask = None
        if not self.infer:
            # Decode at the size of the image as read from disk, i.e. before
            # any cropping, so the RLE indices refer to the right pixels.
            mask_rle = self.mask_rles.iloc[idx]
            gt_mask = rle_decode(mask_rle, (image.shape[0], image.shape[1]))

        if self.transform:
            if gt_mask is None:
                image = self.transform(image=image)['image']
            else:
                # One joint call: a random crop must pick the same window for
                # the image and its mask.
                augmented = self.transform(image=image, mask=gt_mask)
                image = augmented['image']
                gt_mask = augmented['mask']

        input_image = self.preprocess.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=config['device'])
        # HWC -> 1xCxHxW
        transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :].to(config['device'])

        transformed_data = {
            'image': sam_model.preprocess(transformed_image),
            'input_size': tuple(transformed_image.shape[-2:]),
            'original_image_size': image.shape[:2],
        }

        if gt_mask is None:
            return transformed_data
        return transformed_data, gt_mask

In [None]:
fix_seed(config['seed'])

# Training split: random crops, reshuffled every epoch.
train_dataset = CustomDataset(img_paths=train['img_path'], mask_rles=train['mask_rle'],transform = custom_transform['train'])
train_dataloader = DataLoader(train_dataset, batch_size=config['train']['batch_size'], shuffle=True, num_workers=config['train']['num_workers'])

# Validation split: centre crops, iterated in a fixed order so that results
# are comparable between epochs and line up with the rows of `val`.
valid_dataset = CustomDataset(img_paths=val['img_path'], mask_rles=val['mask_rle'],transform = custom_transform['valid'])
valid_dataloader = DataLoader(valid_dataset , batch_size=config['train']['batch_size'], shuffle=False, num_workers=config['train']['num_workers'])

In [None]:
# Sanity check: fetch and print the first training sample.
print(train_dataset[0])

In [None]:
def predict_sam(input_image,input_size, original_image_size, sam):
    """Run a prompt-free forward pass of `sam` and return a thresholded mask.

    Args:
        input_image: Preprocessed image batch fed to `sam.image_encoder`.
        input_size: (H, W) passed to `sam.postprocess_masks` as the model input size.
        original_image_size: (H, W) the mask is resized to.
        sam: The model to run. Only this object is used; the function no
            longer reaches for the notebook-level `sam_model` or `config`.

    Returns:
        `normalize(threshold(upscaled_masks, 0.0, 0))`, on the device that
        `sam` produced its outputs on.

    NOTE(review): callers in this notebook pass the result to
    BCEWithLogitsLoss, but the values returned here have already been
    thresholded and normalised rather than being raw logits - confirm that
    this is intended.
    """
    image_embedding = sam.image_encoder(input_image)
    # No point, box or mask prompts: the prompt encoder returns its defaults.
    sparse_embeddings, dense_embeddings = sam.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
    )
    low_res_masks, _ = sam.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )
    upscaled_masks = sam.postprocess_masks(low_res_masks, input_size, original_image_size)
    # threshold() zeroes every value <= 0; normalize() then rescales along dim 1.
    binary_mask = normalize(threshold(upscaled_masks, 0.0, 0))
    return binary_mask

In [None]:
def validation(config, model, criterion, valid_loader, val):
    """Evaluate `model` on the validation loader.

    Args:
        config: Run configuration; only config['device'] is read here.
        model: SAM model forwarded to predict_sam().
        criterion: Loss called as criterion(output, target).
        valid_loader: DataLoader yielding (transformed_data, gt_masks).
        val: Validation dataframe. A copy receives two extra columns
            ('valid_mask_rle', 'transformed_mask_rle') which
            calculate_dice_scores() reads by position (columns 3 and 4), so
            `val` is presumably expected to have exactly three columns -
            TODO confirm.

    Returns:
        (summed per-sample loss divided by the number of batches, mean Dice score).
    """
    model.eval()
    device = config['device']
    valid_loss = 0
    result = []          # predicted masks as RLE, -1 when empty
    gt_rles = []         # augmented ground-truth masks as RLE
    val_df = val.copy()

    with torch.no_grad():
        for transformed_data, gt_masks in tqdm(valid_loader):
            for index in range(len(gt_masks)):
                input_image = transformed_data['image'][index].to(device)
                # Sizes are stored per dimension after batching (one entry for
                # H, one for W); rebuild a plain (H, W) tuple of ints for this
                # sample. Assumes the default DataLoader collation.
                input_size = tuple(int(dim[index]) for dim in transformed_data['input_size'])
                original_image_size = tuple(int(dim[index]) for dim in transformed_data['original_image_size'])
                gt_mask = gt_masks[index]

                output = predict_sam(input_image, input_size, original_image_size, model)
                # Match the prediction: (H, W) -> (1, 1, H, W), same dtype and device.
                target = gt_mask[None, None].to(device=output.device, dtype=output.dtype)
                loss = criterion(output, target)
                valid_loss += loss.item()

                # rle_encode works on NumPy arrays, so bring the mask to the CPU first.
                pred_mask = (output.detach().squeeze() > 0).cpu().numpy().astype(np.uint8)
                mask_rle = rle_encode(pred_mask)
                if mask_rle == '':  # no building pixel predicted at all -> -1
                    result.append(-1)
                else:
                    result.append(mask_rle)
                gt_rles.append(rle_encode(gt_mask.cpu().numpy()))

        val_df['valid_mask_rle'] = result
        val_df['transformed_mask_rle'] = gt_rles
        mean_dice = calculate_dice_scores(val_df)

    return valid_loss/len(valid_loader), mean_dice

In [None]:
def _restore_best_weights(model, best_state):
    """Load the snapshot `best_state` into `model` (if both exist) and return `model`."""
    if model is not None and best_state is not None:
        model.load_state_dict(best_state)
    return model


def training(config, model, train_loader, valid_loader, val):
    """Fine-tune `model` with early stopping on the validation loss.

    Args:
        config: Run configuration (device, epochs, lr, early_stopping, scheduler).
        model: SAM model; every parameter is handed to the optimizer.
        train_loader: DataLoader yielding (transformed_data, gt_masks).
        valid_loader: DataLoader passed on to validation().
        val: Validation dataframe passed on to validation().

    Returns:
        The model carrying the weights of the epoch with the lowest validation
        loss, or None if no epoch ever improved on the initial infinity.
    """
    device = config['device']
    model = model.to(device)
    es_count = 0
    min_val_loss = float('inf')
    best_model = None
    best_state = None   # CPU snapshot of the best weights
    best_epoch = 0      # stays 0 if no epoch ever improves

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=4, min_lr=1e-8, verbose=True)

    # training loop
    for epoch in range(config['train']['epochs']):
        model.train()
        epoch_loss = 0
        for transformed_data, gt_masks in tqdm(train_loader):
            for index in range(len(gt_masks)):
                input_image = transformed_data['image'][index].to(device)
                # Sizes are stored per dimension after batching (one entry for
                # H, one for W); rebuild a plain (H, W) tuple of ints for this
                # sample. Assumes the default DataLoader collation.
                input_size = tuple(int(dim[index]) for dim in transformed_data['input_size'])
                original_image_size = tuple(int(dim[index]) for dim in transformed_data['original_image_size'])
                gt_mask = gt_masks[index]

                output = predict_sam(input_image, input_size, original_image_size, model)
                # Match the prediction: (H, W) -> (1, 1, H, W), same dtype and device.
                target = gt_mask[None, None].to(device=output.device, dtype=output.dtype)
                loss = criterion(output, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

        val_loss, mean_dice = validation(config, model, criterion, valid_loader, val)
        es_count += 1
        if min_val_loss > val_loss:
            es_count = 0
            min_val_loss = val_loss
            # `best_model = model` alone would only alias the live model, whose
            # weights keep changing; snapshot the weights themselves.
            best_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
            best_model = model
            best_epoch = epoch + 1
            print(f"Epoch [{epoch + 1}] New Minimum Valid Loss!")

        if config['scheduler']:
            scheduler.step(val_loss)

        if es_count == config['early_stopping']:
            print(f'Epoch {epoch+1}, Train Loss: {(epoch_loss/len(train_loader)):6f}, Valid Loss: {val_loss:6f}, Dice Coefficient: {mean_dice:6f}, ES Count:, {es_count}')
            print(f"EARLY STOPPING COUNT에 도달했습니다! \nEARLY STOPPING COUNT: {config['early_stopping']} BEST EPOCH: {best_epoch}")
            print("***TRAINING DONE***")
            return _restore_best_weights(best_model, best_state)

        print(f'Epoch {epoch+1}, Train Loss: {(epoch_loss/len(train_loader)):6f}, Valid Loss: {val_loss:6f}, Dice Coefficient: {mean_dice:6f} ES Count:, {es_count}')
        print("------------------------------------------------------------------------------------")

    print(f"EARLY STOPPING COUNT에 도달하지 않았습니다! \nEARLY STOPPING COUNT: {config['early_stopping']} BEST EPOCH: {best_epoch}")
    print("***TRAINING DONE***")
    return _restore_best_weights(best_model, best_state)

In [None]:
# The dataset creates tensors on config['device'] inside DataLoader workers,
# so workers are started with 'spawn'. force=True keeps this cell re-runnable:
# without it a second execution raises RuntimeError because the start method
# has already been set.
torch.multiprocessing.set_start_method('spawn', force=True)
best_model = training(config, sam_model,train_dataloader, valid_dataloader, val)