In [1]:
%reload_ext autoreload
%autoreload 2

import os
import sys
import cv2
import timm
import torch
import transformers
import numpy as np
import pandas as pd
import ruamel.yaml as yaml
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from ship_detector.scripts.prepare_data import rle_decode, tile_geotiff, process_masks, create_synthetic_test
from ship_detector.scripts.train_vit import ShipPatchDataset, ViTShipClassifier, get_augmentation_transforms, create_data_loader

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
config_path = "configs/vit.yaml"
manifest_path = "data/airbus-ship-detection/train_ship_segmentations_v2.csv"

output_dir = "outputs"

In [22]:
with open(config_path, 'rb') as f:
    config = yaml.YAML(typ='rt').load(f)
pl.seed_everything(config['data']['random_seed'])
Path(output_dir).mkdir(parents=True, exist_ok=True)

Seed set to 42


In [23]:
# train_loader, val_loader = create_data_loader(manifest_path, config)
df_mani = pd.read_csv(manifest_path)
df_mani['has_ship'] = df_mani['EncodedPixels'].apply(lambda x: 0 if pd.isna(x) else 1)
df_mani['patch_path'] = df_mani['ImageId'].apply(lambda x: f"data/airbus-ship-detection/train_v2/{x}")
train_df, val_df = train_test_split(
    df_mani,
    test_size=config['data']['val_split'],
    random_state=config['data']['random_seed'],
    stratify=df_mani['has_ship']
)

In [24]:
df_mani.shape, train_df.shape, val_df.shape

((231723, 4), (185378, 4), (46345, 4))

In [6]:
df_mani

Unnamed: 0,ImageId,EncodedPixels,has_ship,patch_path
0,00003e153.jpg,,0,data/airbus-ship-detection/train_v2/00003e153.jpg
1,0001124c7.jpg,,0,data/airbus-ship-detection/train_v2/0001124c7.jpg
2,000155de5.jpg,264661 17 265429 33 266197 33 266965 33 267733...,1,data/airbus-ship-detection/train_v2/000155de5.jpg
3,000194a2d.jpg,360486 1 361252 4 362019 5 362785 8 363552 10 ...,1,data/airbus-ship-detection/train_v2/000194a2d.jpg
4,000194a2d.jpg,51834 9 52602 9 53370 9 54138 9 54906 9 55674 ...,1,data/airbus-ship-detection/train_v2/000194a2d.jpg
...,...,...,...,...
231718,fffedbb6b.jpg,,0,data/airbus-ship-detection/train_v2/fffedbb6b.jpg
231719,ffff2aa57.jpg,,0,data/airbus-ship-detection/train_v2/ffff2aa57.jpg
231720,ffff6e525.jpg,,0,data/airbus-ship-detection/train_v2/ffff6e525.jpg
231721,ffffc50b4.jpg,,0,data/airbus-ship-detection/train_v2/ffffc50b4.jpg


In [25]:
# Get transform
aug_config = config['augmentation']

train_transforms = [
    transforms.RandomHorizontalFlip(p=aug_config.get('hflip_prob', 0.5)),
    transforms.RandomVerticalFlip(p=aug_config.get('vflip_prob', 0.5)),
]

if aug_config.get('rotation', False):
    train_transforms.append(transforms.RandomRotation(degrees=90))

if aug_config.get('color_jitter', False):
    train_transforms.append(
        transforms.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1
        )
    )
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

train_transforms.extend([
    transforms.ToTensor(),
    normalize
])

val_transforms = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

# train_transforms, val_transforms = transforms.Compose(train_transforms), transforms.Compose(val_transforms)

In [26]:
config

{'model': {'name': 'vit_base_patch16_224', 'pretrained': True, 'freeze_backbone_epochs': 3, 'download': False, 'local_weights_path': 'model_states', 'preprocessing_method': 'adaptive'}, 'training': {'batch_size': 128, 'max_epochs': 40, 'early_stopping_patience': 7, 'pos_weight': 4.0, 'use_weighted_sampler': True, 'precision': 16, 'small_ship_focus': True, 'mixup_alpha': 0.2}, 'optimizer': {'name': 'adamw', 'lr': 8e-05, 'weight_decay': 0.02}, 'scheduler': {'name': 'cosine', 'T_max': 30, 'eta_min': 1e-07, 'warmup_epochs': 5}, 'augmentation': {'hflip_prob': 0.5, 'vflip_prob': 0.5, 'rotation': True, 'color_jitter': True, 'gaussian_blur_prob': 0.1, 'brightness_range': [0.8, 1.2], 'contrast_range': [0.9, 1.1], 'ship_aware_crop': True, 'preserv_small_objects': True}, 'data': {'val_split': 0.2, 'num_workers': 4, 'random_seed': 42, 'target_size': 224, 'interpolation': 'area', 'apply_sharpening': True, 'sharpening_strength': 0.3, 'min_ship_pixel_ratio': 0.001, 'max_background_ratio': 0.95}, 'eva

In [27]:
# Create datasets
train_dataset = ShipPatchDataset(config=config, manifest_df=train_df, transform=train_transforms, is_training=True)
val_dataset = ShipPatchDataset(config=config, manifest_df=val_df, transform=val_transforms, is_training=False)

In [28]:
# handling imbalance class with weighted sampling
if config['training'].get('use_weighted_sampler', False):
    train_labels = train_df['has_ship'].values
    class_counts = np.bincount(train_labels)
    class_weights = 1.0 / class_counts
    sample_weights = class_weights[train_labels]
    
    sampler =WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_dataset),
        replacement=True
    )
    shuffle = False
else:
    sampler = None
    shuffle = True

In [29]:
train_loader = DataLoader(
    train_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=shuffle,
    sampler=sampler,
    num_workers=config['data']['num_workers'],
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    num_workers=config['data']['num_workers'],
    pin_memory=True,
)

In [37]:
# model = ViTShipClassifier(config)
# torch.save(model.state_dict(), f'model_states/pretrained/vit_base_patch16_224.pth')
config['model']['pretrained'] = False
model = ViTShipClassifier(config)
if config['model']['pretrained'] == False:
    state_dict = torch.load(f"model_states/pretrained/{config['model']['name']}.pth", map_location='cpu')
    model.load_state_dict(state_dict)

In [None]:
vit_checkpoint = "outputs/checkpoints/vit-epoch=04-val_acc=0.962.ckpt"
# checkpoint = torch.load(vit_checkpoint, map_location='cpu', weights_only=False)
# model.load_state_dict(checkpoint['state_dict'])

In [39]:
model.to('cuda' if torch.cuda.is_available() else 'cpu')

ViTShipClassifier(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
       

In [40]:
callbacks = [
    ModelCheckpoint(
        dirpath=os.path.join(output_dir, 'checkpoints'),
        filename='vit-{epoch:02d}-{val_acc:.3f}',
        monitor='val_acc',
        mode='max',
        save_top_k=3,
        save_last=True,
    ),
    EarlyStopping(
        monitor='val_loss',
        patience=config['training']['early_stopping_patience'],
        mode='min'
    ),
    LearningRateMonitor(logging_interval='epoch')
]
logger = TensorBoardLogger(
    save_dir=output_dir,
    name='vit_logs'
)

In [41]:
trainer = pl.Trainer(
    max_epochs=config['training']['max_epochs'],
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=callbacks,
    logger=logger,
    log_every_n_steps=10,
    deterministic=True,
    precision=config['training'].get('precision', 32)
)


`precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!

Using 16bit Automatic Mixed Precision (AMP)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, train_loader, val_loader, ckpt_path=vit_checkpoint)