In [None]:
import argparse
import json
import os
from pathlib import Path

import albumentations as A
import wandb
from torch.utils.data import DataLoader

from tracenet import get_train_transform, get_valid_transform, collate_fn
from tracenet.datasets import FilamentDetection
from tracenet.models.criterion import Criterion
from tracenet.models.detr import build_model
from tracenet.models.matcher import HungarianMatcher
from tracenet.utils import get_model_name
from tracenet.utils.train import train

### Specify parameters

In [None]:
# Dataset root: must contain `<split>/img/` images and matching `<split>/gt/` CSV annotations.
data_dir = 'mtdata'
# Base output directory. Plain string: the former `rf` prefix had no placeholders and did nothing.
# NOTE(review): despite the `.pth` name, this path is used as a directory root (a per-run
# sub-directory is created inside it) — consider renaming it, e.g. 'models/mt'.
model_path = 'models/mt/checkpoint.pth'
# Forwarded to FilamentDetection as `maxsize` (presumably a cap on image size — confirm).
maxsize = 630

In [None]:
# Names of the split sub-directories inside `data_dir`.
train_dir, val_dir = 'train', 'val'

# Weights & Biases: project name, and whether to log online (False -> offline mode).
wandb_project = 'Test'
log_progress = False

In [None]:
# All run settings in one place. A later cell logs this dict to wandb and writes it to
# config.json, so the key order here is also the order in the saved file.
config = {
    # training loop / optimizer
    'epochs': 20,
    'batch_size': 2,
    'lr': 1e-4,
    'weight_decay': 5e-4,
    # learning-rate schedule (presumably reduce-on-plateau factor/patience — confirm in train())
    'factor': 0.1,
    'patience': 2,
    # outputs and logging
    'model_path': model_path,
    'log_progress': log_progress,
    'wandb_project': wandb_project,
    # data locations
    'data_dir': data_dir,
    'train_dirname': train_dir,
    'val_dirname': val_dir,
    # loss weighting, image size and number of points per object
    'bbox_loss_coef': 5,
    'maxsize': maxsize,
    'n_points': 10,
}
config

### Initialize the wandb run and save the training configuration

In [None]:
# `config` arrives from the previous cell as a plain dict; convert it to a Namespace so the rest
# of the notebook can use attribute access. The isinstance check keeps this cell re-runnable:
# on a second run `config` is already a Namespace (so `**config` would raise TypeError), and
# `model_path` already carries the per-run suffix added below, which must not be nested again.
if isinstance(config, dict):
    config = argparse.Namespace(**config)
    model_root = config.model_path  # base output directory, before any per-run suffix
else:
    config.model_path = model_root  # re-run: drop the suffix added by the previous execution

if config.log_progress:
    # Online logging needs an API key. Use one already present in the environment; otherwise
    # read it from ~/.wandb_api_key (rather than a hard-coded user home directory). strip()
    # removes the trailing newline key files usually end with, which would corrupt the key.
    if 'WANDB_API_KEY' not in os.environ:
        os.environ['WANDB_API_KEY'] = (Path.home() / '.wandb_api_key').read_text().strip()
else:
    # Offline mode: wandb records the run locally without syncing it.
    os.environ['WANDB_MODE'] = 'offline'

wandb.init(project=config.wandb_project, config=vars(config))

# Give this run its own sub-directory under the base model path.
# (wandb.init runs first on purpose: the name is derived after the run exists.)
config.model_path = os.path.join(config.model_path, get_model_name(config.log_progress))

# Save training parameters into the run directory.
os.makedirs(config.model_path, exist_ok=True)
with open(os.path.join(config.model_path, 'config.json'), 'w') as f:
    json.dump(vars(config), f, indent=4)
          


### Set up data loaders

In [None]:
# Build one dataset per split. Everything is read from `config` (not the individual parameter
# globals), so what gets loaded always matches what was saved to config.json.
path = Path(config.data_dir)

ds = []
for dset, transform in zip([config.train_dirname, config.val_dirname],
                           [get_train_transform, get_valid_transform]):
    img_dir = path / dset / 'img'
    # Sorted for a deterministic order; hidden files such as .DS_Store are not images.
    files = sorted(fn for fn in os.listdir(img_dir) if not fn.startswith('.'))
    assert files, f'No images found in {img_dir}'
    ds.append(
        FilamentDetection(
            [img_dir / fn for fn in files],
            # Ground truth shares the image's stem with a .csv extension. with_suffix swaps only
            # the final extension; str.replace('.tif', '.csv') would also rewrite '.tif' inside
            # the name and turn 'x.tiff' into 'x.csvf'.
            [path / dset / 'gt' / Path(fn).with_suffix('.csv') for fn in files],
            transforms=transform(keypoint_params=A.KeypointParams(format='xy',
                                                                  label_fields=['point_labels'],
                                                                  remove_invisible=False,
                                                                  angle_in_degrees=True)),
            maxsize=config.maxsize, n_points=config.n_points
        )
    )
ds_train, ds_val = ds

# Worker count defaults to batch_size as before; add `num_workers` to config to decouple them.
num_workers = getattr(config, 'num_workers', config.batch_size)
dl_train = DataLoader(ds_train, shuffle=True, collate_fn=collate_fn,
                      batch_size=config.batch_size, num_workers=num_workers)
dl_val = DataLoader(ds_val, shuffle=False, collate_fn=collate_fn,
                    batch_size=config.batch_size, num_workers=num_workers)

### Set up model and loss function

In [None]:
# One object class. Model and criterion must agree on the class count, so it is named once.
# (Assumes Criterion's first positional argument is the class count, mirroring build_model.)
n_classes = 1
model = build_model(n_classes=n_classes, n_points=config.n_points, pretrained=True)
# NOTE(review): config.bbox_loss_coef is not passed here — confirm that Criterion or train()
# actually reads it, otherwise the setting has no effect.
loss_function = Criterion(n_classes, HungarianMatcher(), losses=['labels', 'boxes', 'cardinality'])

### Train

In [None]:
%%time
# Train on dl_train and validate on dl_val; %%time (must stay the cell's first line) reports the
# wall-clock cost of the whole run. `config` is passed through to tracenet's training loop.
# log_tensorboard=True presumably also writes TensorBoard logs — confirm in tracenet.utils.train.
train(dl_train, dl_val, model, loss_function, config=config, log_tensorboard=True)

In [None]:
# Close the wandb run opened by wandb.init earlier in the notebook.
wandb.finish()