In [9]:
%reload_ext autoreload
%autoreload 2

import os
import sys
import cv2
import timm
import torch
import transformers
import numpy as np
import pandas as pd
import ruamel.yaml as yaml
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from ship_detector.scripts.prepare_data import rle_decode, tile_geotiff, process_masks, create_synthetic_test
from ship_detector.scripts.train_vit import ShipPatchDataset, ResNetShipClassifier, ViTShipClassifier, get_augmentation_transforms, create_data_loader


In [3]:
config_path = "configs/resnet.yaml"
manifest_path = "data/airbus-ship-detection/train_ship_segmentations_v2.csv"
output_dir = "outputs"

In [4]:
with open(config_path, 'rb') as f:
    config = yaml.YAML(typ='rt').load(f)
pl.seed_everything(config['data']['random_seed'])
Path(output_dir).mkdir(parents=True, exist_ok=True)

Seed set to 42


In [7]:
train_loader, val_loader = create_data_loader(manifest_path, config)

Training samples: 185378
  - With ships: 65378 (35.3%)
Validation samples: 46345
  - With ships: 16345 (35.3%)


In [12]:
if 'vit' in config['model']['name']:
    pre_name = 'vit'
    model = VitShipClassifier(config)
else:
    pre_name = 'resnet'
    model = ResNetShipClassifier(config)

In [None]:
# torch.save(model.state_dict(), f'model_states/pretrained/resnet50.pth')

In [16]:
callbacks = [
        ModelCheckpoint(
            dirpath=os.path.join(output_dir, 'checkpoints'),
            filename=pre_name + '-{epoch:02d}-{val_acc:.3f}',
            monitor='val_loss',
            mode='max',
            save_top_k=3,
            save_last=True
        ),
        EarlyStopping(
            monitor='val_loss',
            patience=config['training']['early_stopping_patience'],
            mode='min',
        ),
        LearningRateMonitor(logging_interval='epoch')
    ]
logger = TensorBoardLogger(
        save_dir=output_dir,
        name=pre_name + '_logs',
    )

In [17]:
trainer = pl.Trainer(
        max_epochs=config['training']['max_epochs'],
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=10,
        deterministic=True,
        precision=config['training'].get('precision', 32)
    )

/Users/lin.yang/Documents/GitHub/ship-detector/.venv/lib/python3.12/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
/Users/lin.yang/Documents/GitHub/ship-detector/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:508: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/lin.yang/Documents/GitHub/ship-detector/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [18]:
trainer.fit(model, train_loader, val_loader)

/Users/lin.yang/Documents/GitHub/ship-detector/.venv/lib/python3.12/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | model     | ResNet            | 23.5 M | train
1 | criterion | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------
2.0 K     Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.040    Total estimated model params size (MB)
218       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Epoch 0:   9%|▉         | 134/1449 [9:10:14<89:59:43,  0.00it/s, v_num=0]  


Detected KeyboardInterrupt, attempting graceful shutdown ...


: 