In [38]:
from src.data.data_module import DepthEstimationDataModule
from src.utils import download_dataset
from src.models import Unet3Plus
from src.train import LightningModel,get_lr_scheduler_kwargs
import torchmetrics
from dataclasses import dataclass
import os


In [39]:
@dataclass
class Config:
    """Hyperparameters and paths for the depth-estimation training run.

    Every setting is a real dataclass field, so any of them can be overridden at
    construction time, e.g. ``Config(batch_size=8, data_dir="/tmp/data")``.
    """

    batch_size: int = 16
    num_workers: int = 2
    # Encoder identifier forwarded to Unet3Plus(encoder_name=...).
    encoder_name: str = "maxxvitv2_nano_rw_256.sw_in1k"
    decoder_attention_type: str = "scse"
    head_activation_name: str = "sigmoid"
    optimizer: str = "Adam"
    learning_rate: float = 1e-4
    accumulate_grad_batches: int = 1
    # The two settings below were previously unannotated, which made them plain
    # class attributes that Config(...) could not override. They are declared
    # last so the positional order of the fields above is unchanged.
    # Base directory the dataset lives under.
    data_dir: str = "../data"
    # Attribute name looked up on the ``torchmetrics`` module to obtain the loss.
    loss_function_name: str = "MeanAbsoluteError"


config = Config()

In [40]:
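# Skipped by default: uncomment to download the dataset into config.data_dir
# (e.g. on a fresh checkout where the data is not present yet).
# download_dataset(destination_path=config.data_dir)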

In [41]:
# The downloaded archive apparently nests the dataset inside an extra "data/"
# folder (TODO: fix zip file in gdrive). Descend into it only while the "train"
# split is not visible at config.data_dir and the nested folder exists, so
# re-running this cell is a no-op instead of appending another "data" component.
if not os.path.isdir(os.path.join(config.data_dir, "train")) and os.path.isdir(
    os.path.join(config.data_dir, "data")
):
    config.data_dir = os.path.join(config.data_dir, "data")

In [42]:
# The scheduler helper receives the training-image directory plus the batch
# settings (presumably to size the schedule by steps per epoch — confirm in
# src.train.get_lr_scheduler_kwargs).
lr_scheduler_kwargs = get_lr_scheduler_kwargs(
    data_dir=os.path.join(config.data_dir, "train", "image"),
    batch_size=config.batch_size,
    accumulate_grad_batches=config.accumulate_grad_batches,
)

In [43]:
# config.data_dir is already the dataset root: the LR-scheduler setup reads
# <config.data_dir>/train/image from it. Joining "data" onto it again (as this
# cell used to) pointed the DataModule at a different, deeper directory than the
# one the scheduler inspects, so the root is passed through unchanged.
data_module = DepthEstimationDataModule(
    data_dir=config.data_dir,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    transforms=None,
)

In [44]:
# Build the segmentation-style network used as the depth regressor; all
# architecture choices come from the Config instance.
model = Unet3Plus(
    encoder_name=config.encoder_name,
    classes=1,  # single output channel (presumably the depth map)
    activation=config.head_activation_name,
    decoder_attention_type=config.decoder_attention_type,
)

In [45]:
# Wrap the network for Lightning training. The loss is resolved by name from the
# torchmetrics module.
# NOTE(review): getattr returns the attribute itself (for the default name, the
# MeanAbsoluteError class, not an instance) — presumably LightningModel
# instantiates it; confirm in src.train.
lightning_model = LightningModel(
    model=model,
    optimizer=config.optimizer,
    learning_rate=config.learning_rate,
    loss=getattr(torchmetrics, config.loss_function_name),
    lr_scheduler_params=lr_scheduler_kwargs,
)