In [1]:
%reload_ext autoreload
%autoreload 2

import os
import yaml
import argparse
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

import numpy as np
import cv2
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

# SAM imports
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from segment_anything.utils.transforms import ResizeLongestSide

from ship_detector.scripts.utils import load_config
from ship_detector.scripts.train_sam import ShipSAMDataset, SAMShipSegmentation, collate_fn_sam
from ship_detector.scripts.train_vit import ViTShipClassifier

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
warnings.filterwarnings('ignore')
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


In [3]:
config_path = 'configs/sam.yaml'
manifest_path = 'data/airbus-ship-detection/train_ship_segmentations_v2.csv'
output_dir = 'outputs/sam_finetune'

In [4]:
config = load_config(config_path)

In [5]:
Path(output_dir).mkdir(parents=True, exist_ok=True)

In [6]:
config['vit_integration'].keys()

dict_keys(['use_vit_prompts', 'vit_checkpoint', 'confidence_weight', 'heatmap_threshold'])

In [7]:
config['data']['patch_size'], config['vit_integration']['use_vit_prompts']

(1024, False)

In [8]:
sam_transform = ResizeLongestSide(config['data']['patch_size'])

In [9]:
manifest_df = pd.read_csv(manifest_path)
manifest_df['has_ship'] = manifest_df['EncodedPixels'].notna().astype(int)
manifest_df['patch_path'] = manifest_df['ImageId'].apply(
    lambda x: f"data/airbus-ship-detection/train_v2/{x}"
)

In [10]:
train_df = manifest_df.sample(frac=0.8, random_state=42).reset_index(drop=True)
val_df = manifest_df.drop(train_df.index).reset_index(drop=True)

In [11]:
train_dataset = ShipSAMDataset(manifest_df=train_df, sam_transform=sam_transform, patch_size=config['data']['patch_size'], use_vit_prompts=False)
val_dataset = ShipSAMDataset(manifest_df=val_df, sam_transform=sam_transform, patch_size=config['data']['patch_size'], use_vit_prompts=False)

INFO:ship_detector.scripts.train_sam:Dataset initialized with 65324 ship patches
INFO:ship_detector.scripts.train_sam:Dataset initialized with 16366 ship patches


In [12]:
callbacks = [
    ModelCheckpoint(
        monitor='val_iou',
        dirpath=os.path.join(output_dir, 'checkpoints'),
        filename='sam-{epoch:02d}-{val_iou:.4f}',
        save_top_k=3,
        mode='max',
    ),
    EarlyStopping(
        monitor='val_iou',
        patience=5,
        mode='max',
        verbose=True,
    ),
]
train_logger = TensorBoardLogger(
    save_dir=output_dir,
    name='logs',
    version='sam_finetune',
)

In [19]:
model = SAMShipSegmentation(config)

In [14]:
trainer = pl.Trainer(
    max_epochs=config['finetune']['num_epochs'],
    accelerator='gpu',
    devices=1,
    callbacks=callbacks,
    logger=train_logger,
    log_every_n_steps=10,
    accumulate_grad_batches=4,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [16]:
train_loader = DataLoader(
    train_dataset,
    batch_size=config['finetune']['batch_size'],
    shuffle=True,
    num_workers=config['data']['num_workers'],
    collate_fn=collate_fn_sam
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config['finetune']['batch_size'],
    shuffle=False,
    num_workers=config['data']['num_workers'],
    collate_fn=collate_fn_sam
)

In [None]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type      | Params | Mode 
-------------------------------------------------
0 | sam        | Sam       | 641 M  | eval 
1 | focal_loss | FocalLoss | 0      | train
2 | dice_loss  | DiceLoss  | 0      | train
3 | iou_loss   | IoULoss   | 0      | train
-------------------------------------------------
4.1 M     Trainable params
637 M     Non-trainable params
641 M     Total params
2,564.362 Total estimated model params size (MB)
3         Modules in train mode
438       Modules in eval mode


                                                                   

In [19]:
sam = sam_model_registry[config['model']['checkpoint_type']](checkpoint=config['model']['checkpoint_path'])

In [20]:
help(sam.prompt_encoder)

Help on PromptEncoder in module segment_anything.modeling.prompt_encoder object:

class PromptEncoder(torch.nn.modules.module.Module)
 |  PromptEncoder(embed_dim: int, image_embedding_size: Tuple[int, int], input_image_size: Tuple[int, int], mask_in_chans: int, activation: Type[torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>) -> None
 |
 |  Method resolution order:
 |      PromptEncoder
 |      torch.nn.modules.module.Module
 |      builtins.object
 |
 |  Methods defined here:
 |
 |  __init__(self, embed_dim: int, image_embedding_size: Tuple[int, int], input_image_size: Tuple[int, int], mask_in_chans: int, activation: Type[torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.GELU'>) -> None
 |      Encodes prompts for input to SAM's mask decoder.
 |
 |      Arguments:
 |        embed_dim (int): The prompts' embedding dimension
 |        image_embedding_size (tuple(int, int)): The spatial size of the
 |          image embedding, as (H, W).
 