In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Main

The preferred way to train or test the model is from the command line:
```
uv run run.py mode=train
```

Alternatively, you can run everything from this Jupyter notebook.

In [None]:
from omegaconf import OmegaConf
from chest_segment.dataset import (
    ChestDataset,
    ChestDatasetConfig,
    get_dataloaders,
)
from chest_segment.transforms import (
    get_mask_transforms,
    get_image_transforms,
    ChestMaskTransformsConfig,
    ChestImageTransformsConfig,
    get_all_transforms,
    ChestAllTransformsConfig,
)
from chest_segment.models import get_model_from_config
from chest_segment.utils import get_optimizer, get_loss, get_metrics
from chest_segment.train import train
from chest_segment.test import test, evaluate, visualize

# Config

You can change any parameters you want in the config.

In [None]:
# Load the project configuration. The path is relative, so the notebook must
# be run from a working directory that contains `config/config.yaml`.
cfg = OmegaConf.load("config/config.yaml")
# Print the loaded configuration as YAML so the effective settings are visible
# in the notebook output.
print(OmegaConf.to_yaml(cfg))

# Dataset

In [11]:
# Unpack the `dataset` section of the config into a ChestDatasetConfig and
# build the dataset from it. Both names are reused by the train, test and
# evaluate cells below.
dataset_config = ChestDatasetConfig(**cfg.dataset)
dataset = ChestDataset(dataset_config)

# Train

Logs are also saved to the `logs` directory, and you can view them with TensorBoard.

In [12]:
# Build the training-time transforms. Each kind (image, mask, joint) is
# optional: when its section in `cfg.train` is empty or missing, None is
# passed to the dataset instead of a transform.
train_cfg = cfg.train

train_image_transforms = None
if train_cfg.image_transforms:
    train_image_transforms = get_image_transforms(
        ChestImageTransformsConfig(**train_cfg.image_transforms)
    )

train_mask_transforms = None
if train_cfg.mask_transforms:
    train_mask_transforms = get_mask_transforms(
        ChestMaskTransformsConfig(**train_cfg.mask_transforms)
    )

train_all_transforms = None
if train_cfg.all_transforms:
    train_all_transforms = get_all_transforms(
        ChestAllTransformsConfig(**train_cfg.all_transforms)
    )

# Install the training transforms on the shared dataset object.
dataset.set_transforms(
    image_transforms=train_image_transforms,
    mask_transforms=train_mask_transforms,
    all_transforms=train_all_transforms,
)

# Create the three loaders from the dataset and its config.
loaders = get_dataloaders(dataset, dataset_config)
train_loader, val_loader, test_loader = loaders

# Instantiate the model, loss, metrics and optimizer from the config.
model = get_model_from_config(cfg)
criterion = get_loss(cfg)
metrics = get_metrics(cfg)
optimizer = get_optimizer(model, cfg)

In [None]:
# Run training with the objects built in the previous cell; the `train`
# section of the config supplies the training settings.
train(
    model, train_loader, val_loader, criterion, optimizer, metrics, cfg.train
)

# Test

In [14]:
# Switch the shared dataset over to the test-time transforms.
#
# Each transform kind is optional: when its section in `cfg.test` is empty or
# missing, None is passed as the transform itself. The conditional must wrap
# the whole `get_*_transforms(...)` call (as in the training cell); placing it
# inside the call would invoke the factory with None instead of a config.
dataset.set_transforms(
    image_transforms=get_image_transforms(
        ChestImageTransformsConfig(**cfg.test.image_transforms)
    )
    if cfg.test.image_transforms
    else None,
    mask_transforms=get_mask_transforms(
        ChestMaskTransformsConfig(**cfg.test.mask_transforms)
    )
    if cfg.test.mask_transforms
    else None,
    all_transforms=get_all_transforms(
        ChestAllTransformsConfig(**cfg.test.all_transforms)
    )
    if cfg.test.all_transforms
    else None,
)
# Rebuild the loaders now that the dataset carries the test transforms.
# NOTE: this rebinds `train_loader` and `val_loader` as well, so re-run the
# training setup cell before training again.
train_loader, val_loader, test_loader = get_dataloaders(
    dataset, dataset_config
)

In [None]:
# Run the model over the test loader using the same loss and metrics as in
# training; the `test` section of the config supplies the test settings.
test(model, test_loader, criterion, metrics, cfg.test)

# Evaluate

In [None]:
# Build a dataset for the "test" split, reusing the existing dataset's config.
test_dataset = dataset.from_split(split="test", config=dataset.config)
# Inspect the first five samples.
# NOTE(review): assumes the test split holds at least 5 items — confirm.
idxes = list(range(5))
# Run the model on the selected samples, then visualize the predictions for
# the same indices.
preds = evaluate(model, test_dataset, idxes=idxes, device=cfg.device)
visualize(preds, test_dataset, idxes=idxes)