## Load configuration

In [None]:
# IPython shell escape: runs the `wandb login` CLI in a subprocess.
# Presumably required so that `setup_wandb(...)` later in this notebook can
# authenticate with Weights & Biases.
!wandb login

In [None]:
import os
import sys

# Make the parent of the notebook's working directory importable so the
# `src.*` imports in this notebook resolve. The membership check keeps
# re-runs of this cell from appending the same entry to sys.path repeatedly.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
from src.utils.config_loader import load_config

# Paths here are relative to the notebook's working directory; the config
# file is expected one level up (the data paths below use the same "../").
cfg = load_config("../config.yaml")

# Section shortcuts used throughout the notebook.
data_cfg = cfg["data"]
model_cfg = cfg["model"]
training_cfg = cfg["training"]

# Architecture presets keyed by size letter (n/s/m/l/x). Each entry is
# unpacked as keyword arguments into `Model(...)` when the model is built.
# `csp` holds two flags; `depth` and `width` are six-element lists —
# presumably per-stage repeat counts and channel widths (the leading 3 looks
# like the RGB input channel count); see `Model` for the exact meaning.
model_parameters = {
    "n": dict(csp=[False, True], depth=[1] * 6, width=[3, 16, 32, 64, 128, 256]),
    "s": dict(csp=[False, True], depth=[1] * 6, width=[3, 32, 64, 128, 256, 512]),
    "m": dict(csp=[True, True], depth=[1] * 6, width=[3, 64, 128, 256, 512, 512]),
    "l": dict(csp=[True, True], depth=[2] * 6, width=[3, 64, 128, 256, 512, 512]),
    "x": dict(csp=[True, True], depth=[2] * 6, width=[3, 96, 192, 384, 768, 768]),
}

## Load a data sample

In [None]:
from src.data.data_loader import get_data_loaders

In [None]:
# Build the train/validation loaders. Config paths are joined onto the parent
# directory (the notebook is assumed to run one level below the project
# root); os.path.join also leaves absolute config paths intact, which plain
# "../" + path concatenation does not.
# NOTE(review): validation images come from `val_images` when the config
# defines it and otherwise fall back to `train_images`; confirm that the val
# split's images really live in the training image directory.
train_loader, val_loader = get_data_loaders(
    os.path.join(os.pardir, data_cfg["train_parquet"]),
    os.path.join(os.pardir, data_cfg["val_parquet"]),
    os.path.join(os.pardir, data_cfg["train_images"]),
    os.path.join(os.pardir, data_cfg.get("val_images", data_cfg["train_images"])),
    # Debug override: small fixed batch size instead of training_cfg["batch_size"].
    4,
    isTest=True,  # presumably a reduced/debug loading mode — see get_data_loaders
)
print("Loaded train and validation data loaders")

In [None]:
# (train, val) loader lengths — presumably batch counts; displayed as the cell output.
tuple(map(len, (train_loader, val_loader)))

In [None]:
# `iter(...)` starts a fresh pass each time, so this takes the first batch of a new pass.
images, targets = next(iter(train_loader))
# Show how many images and targets that batch holds.
tuple(map(len, (images, targets)))

In [None]:
from src.data.visualization import visualize_comparison

# Visualize the first image of the batch fetched above together with its target.
# Kept as the cell's final expression so any returned object is displayed.
visualize_comparison(images[0], targets[0])

## Training loop

In [None]:
import torch
from src.training.train_model import train
from src.model.losses import YoloDFLQFLoss
from src.training.utils_train import get_optimizer

In [None]:
# Compute device as a plain string: CUDA when available, otherwise CPU.
# Apple MPS selection is currently disabled; if re-enabled, the documented
# availability check is torch.backends.mps.is_available().
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device - {device}")

In [None]:
import torch
from src.model.model_builder import Model

# Instantiate the medium ("m") architecture preset with the class count from
# the model config, then move it to the device selected above.
model = Model(
    **model_parameters["m"],
    num_classes=model_cfg["num_classes"],
).to(device)

In [None]:
# Optimizer and LR scheduler. get_optimizer receives these training-config
# values positionally, in exactly this order.
optimizer, scheduler = get_optimizer(
    model,
    *(
        training_cfg[key]
        for key in (
            "learning_rate",
            "weight_decay",
            "learning_rate_patience",
            "learning_rate_factor",
        )
    ),
)

In [None]:
# Detection loss. The box/class weights come from the training config and
# default to 1.5 / 1.0 when the config omits them.
criterion = YoloDFLQFLoss(
    num_classes=model_cfg["num_classes"],
    **{
        weight: training_cfg.get(weight, default)
        for weight, default in (("lambda_box", 1.5), ("lambda_cls", 1.0))
    },
)

In [None]:
from src.training.wandb_setup import setup_wandb

wandb_config = cfg["wandb"]
# Shallow copy of the training section, passed to setup_wandb (presumably as
# the run config).
# "gpu": world_size  -- disabled: world_size is not defined in this notebook.
config = dict(training_cfg)
wandb_run = setup_wandb(config, wandb_config)

In [None]:
# Launch training. W&B logging is enabled whenever setup_wandb produced a
# truthy run object; the epoch count is a debug override.
train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    num_epochs=2,  # debug override for training_cfg["epochs"]
    device=device,
    num_classes=model_cfg["num_classes"],
    rank=0,  # presumably a single-process (non-distributed) run
    use_wandb=bool(wandb_run),
    wandb_instance=wandb_run,
    log_interval=training_cfg.get("log_interval", 10),
    checkpoint_dir=os.path.join(
        "..", training_cfg.get("checkpoint_dir", "experiments/checkpoints")
    ),
    iou_threshold=training_cfg.get("iou_threshold", 0.5),
    conf_threshold=training_cfg.get("conf_threshold", 0.25),
)

In [None]:
# Close the W&B run if one was created. The training cell treats `wandb_run`
# as optional (it may be falsy), so guard the call instead of raising
# AttributeError on None.
if wandb_run:
    wandb_run.finish()

In [None]:
# Disabled scratch cell: model summary / graph exploration kept for reference.
# NOTE(review): appears stale relative to the live cells — it imports
# `DetectionModel` and reads cfg["project"]["num_classes"], whereas the live
# code uses `Model` and model_cfg["num_classes"]. It also needs torchinfo,
# torchview and a CUDA device. Verify before re-enabling.
# ## Initialize the model
# import torch
# from src.model.model_builder import DetectionModel
# model = DetectionModel(num_classes=cfg["project"]["num_classes"]).cuda()
# from torchinfo import summary
# summary(model, input_size=(1, 3, 640, 640))
# from torchview import draw_graph

# draw_graph(model, input_size=(1, 3, 640, 640))
# # model.eval()
# # with torch.no_grad():
# #     preds = model(images.cuda())
# # preds
# # targets[0]['boxes'].cpu().numpy()
# # from src.data.visualization import visualize_comparison

# # visualize_comparison(images[0], targets[0], prediction=preds[0])

In [None]:
# Disabled scratch cell: builds each size preset from `model_parameters`,
# inspects one forward pass, and renders architecture graphs with torchview.
# NOTE(review): hard-codes num_classes=171 instead of model_cfg["num_classes"]
# and calls .cuda(), so it requires a GPU. Verify before re-enabling.
# ## Model Initialization
# import torch
# from src.model.model_builder import Model

# example_input = (1, 3, 640, 640)
# example_data = torch.rand(example_input).cuda()
# def create_model(obj_params):
#     model = Model(**obj_params, num_classes=171)
#     return model.cuda()
# def visualize_graph(model, input_data, filename: str = 'Model'):
#     from torchview import draw_graph
#     from IPython.display import Image, display

#     model_graph = draw_graph(model, input_data=input_data, graph_name=filename, save_graph=True, expand_nested=True, depth=10)
#     # img_bytes = model_graph.visual_graph.render(format='png')
#     # # display(Image(data=img_bytes, format='png'))
# model = create_model(model_parameters['n'])
# model.eval()
# output = model(example_data)
# output.shape
# preds = output.transpose(1, 2)
# preds.shape
# pred_box = preds[0][:, 0:4]
# pred_scores = preds[0][:, 4:]
# pred_box.shape, pred_scores.shape

# model = create_model(model_parameters['n'])
# visualize_graph(model, example_data, "Model - nano")
# model = create_model(model_parameters['s'])
# visualize_graph(model, example_data, "Model - small")
# model = create_model(model_parameters['m'])
# visualize_graph(model, example_data, "Model - medium")
# model = create_model(model_parameters['l'])
# visualize_graph(model, example_data, "Model - large")
# model = create_model(model_parameters['x'])
# visualize_graph(model, example_data, "Model - xlarge")