In [None]:
%reload_ext autoreload
%autoreload 2

from ship_detector.scripts.train_vit_efficient import (create_efficient_data_loaders,
                                                       EfficientViTClassifier,
                                                       MemoryMonitor)

import tensorboard as tb
# %load_ext tensorboard

import os
import cv2
import yaml
import timm
import pandas as pd
import torch
import torch.nn as nn
import torch_tb_profiler as tbp
from pathlib import Path
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping

In [None]:
# # tb.notebook.start(args_string='--logdir ./logs --bind_all')
# %tb --logdir ./logs --port 6006 --host localhost

In [None]:
# Locations used by this run. All three are relative paths, so they resolve
# against the kernel's current working directory.
config_path = "configs/vit_efficient.yaml"  # YAML file parsed into `config` below
manifest_path = "data/airbus-ship-detection/train_ship_segmentations_v2.csv"  # handed to the data-loader factory
output_dir = "outputs/efficient"  # created below; checkpoints are written underneath it

In [None]:
# Parse the YAML experiment configuration into a plain Python object.
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

# Seed the RNGs through Lightning; fall back to 42 when the config has no `seed`.
seed = config.get('seed', 42)
pl.seed_everything(seed)

# Make sure the output directory (and any missing parents) exists before
# anything tries to write into it.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
# Turn on the cuDNN benchmark flag.
torch.backends.cudnn.benchmark = True

In [None]:
# Create the monitor once; it is also passed to the data-loader factory later.
memory_monitor = MemoryMonitor()

# Print every reported usage figure with two decimals, labelled as GB.
memory_usage = memory_monitor.get_memory_usage()
for label, amount in memory_usage.items():
    print(f"{label}: {amount:.2f} GB")

In [None]:
# Build the training and validation loaders from the manifest file, the parsed
# config and the shared memory monitor.
train_loader, val_loader = create_efficient_data_loaders(
    manifest_path=manifest_path, config=config, memory_monitor=memory_monitor
)

In [None]:
# Instantiate the classifier from the parsed config.
model = EfficientViTClassifier(config)

In [None]:
# Inspect the `use_streaming` entry of the config's `data` section; as the
# cell's last expression its value is shown as the cell output.
config['data']['use_streaming']

In [None]:
# Checkpointing: ranked by `val_acc` (higher is better), keeping the top two
# plus the most recent checkpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(output_dir, 'lora/checkpoints'),
    filename='vit-{epoch:02d}-{val_acc:.3f}',
    monitor='val_acc',
    mode='max',
    save_top_k=2,
    save_last=True,
)

# Early stopping: watches `val_loss` (lower is better) with the patience taken
# from the training config.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=config['training']['early_stopping_patience'],
    mode='min',
    verbose=True,
    strict=False,
    check_finite=True,
)

# Record the learning rate once per epoch.
lr_monitor_callback = LearningRateMonitor(logging_interval='epoch')

callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor_callback]

In [None]:
# TensorBoard logger for the run. The log directory is derived from
# `output_dir` (as the checkpoint directory is) instead of repeating the
# literal path, so changing `output_dir` relocates the logs together with
# the checkpoints.
logger = TensorBoardLogger(
    save_dir=os.path.join(output_dir, 'tb_logs'),
    name='efficientvit',
)

In [None]:
# Training options all live under the config's `training` section.
training_cfg = config['training']

# Decide the accelerator once so dependent defaults can follow it.
use_gpu = torch.cuda.is_available()

# Single-device trainer. `max_epochs` is required in the config; every other
# option falls back to the default given here when the config omits it.
trainer = pl.Trainer(
    max_epochs=training_cfg['max_epochs'],
    accelerator='gpu' if use_gpu else 'cpu',
    devices=1,
    callbacks=callbacks,
    logger=logger,
    accumulate_grad_batches=training_cfg.get('accumulate_grad_batches', 1),
    gradient_clip_val=training_cfg.get('gradient_clip_val', 1.0),
    # fp16 mixed precision is a GPU feature, so the fallback is 16 only when a
    # GPU is used and full 32-bit precision on CPU. An explicit `precision`
    # in the config always wins.
    precision=training_cfg.get('precision', 16 if use_gpu else 32),
    log_every_n_steps=10,
    val_check_interval=training_cfg.get('val_check_interval', 1.0),
    limit_train_batches=training_cfg.get('limit_train_batches', 1.0),
    limit_val_batches=training_cfg.get('limit_val_batches', 1.0),
)

In [None]:
# Start training: fit `model` using the training loader and the validation loader.
trainer.fit(model, train_loader, val_loader)