In [1]:
import sys, os
import torch, wandb
import torch.nn as nn
from torch.utils.data import DataLoader
# Put the parent of the current working directory on the import path
# (presumably the repository root that holds the project packages imported
# in the next cell -- verify against the repo layout).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
project_root = os.path.abspath(os.path.join(os.curdir, '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
from configs import unet_convnextv2_hard_aug_config as config
from models.unet_convnextv2 import Unet
from datasets.depth_dataset import DepthDataset
from utils.train_utils import train_model

  from .autonotebook import tqdm as notebook_tqdm
  original_init(self, **validated_kwargs)


In [12]:
# Seed torch's global RNG so that the random train/validation split below
# is reproducible from run to run.
torch.manual_seed(config.random_seed+1)

train_data_dir = os.path.join(config.dataset_path, 'train/train')
train_list_file = os.path.join(config.dataset_path, 'train_list.txt')

# Labelled training data, read through the training transform.
train_full_dataset = DepthDataset(
    data_dir=train_data_dir,
    list_file=train_list_file,
    transform=config.transform_train,
    target_transform=config.target_transform,
    has_gt=True,
    use_albumentations=True)

# Test dataset: no ground truth is available for it.
test_dataset = DepthDataset(
    data_dir=os.path.join(config.dataset_path, 'test/test'),
    list_file=os.path.join(config.dataset_path, 'test_list.txt'),
    transform=config.transform_val,
    has_gt=False,
    use_albumentations=True)

# Split the labelled data into train and validation parts.
total_size = len(train_full_dataset)
train_size = int((1-config.val_part) * total_size)
val_size = total_size - train_size

train_dataset, val_split = torch.utils.data.random_split(
    train_full_dataset, [train_size, val_size]
)

# random_split returns Subset objects whose __getitem__ delegates to the
# wrapped dataset. Assigning `.transform` on a Subset is never read, so the
# validation samples would still go through config.transform_train.
# Instead, build a second view of the same files that uses the validation
# transform and index it with the indices chosen by the split.
# It is created after random_split so the RNG state seen by the split is the
# same as before this fix.
# NOTE(review): assumes DepthDataset orders samples deterministically from
# list_file, so indices refer to the same samples in both instances, and that
# transform_val is valid together with has_gt=True -- confirm in
# datasets/depth_dataset.py.
val_full_dataset = DepthDataset(
    data_dir=train_data_dir,
    list_file=train_list_file,
    transform=config.transform_val,
    target_transform=config.target_transform,
    has_gt=True,
    use_albumentations=True)
val_dataset = torch.utils.data.Subset(val_full_dataset, val_split.indices)

# Create the data loaders. Worker count and pinned memory are shared by all
# three, so they are collected once here.
common_loader_kwargs = dict(
    num_workers=config.num_workers,
    pin_memory=True,
)

# Training loader: reshuffled every epoch, incomplete final batch dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.train_bs,
    shuffle=True,
    drop_last=True,
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only keep workers alive when there are any.
    persistent_workers=config.num_workers > 0,
    **common_loader_kwargs,
)

# Validation and test loaders iterate in a fixed order and keep every sample.
val_loader = DataLoader(
    val_dataset,
    batch_size=config.val_bs,
    shuffle=False,
    **common_loader_kwargs,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config.val_bs,
    shuffle=False,
    **common_loader_kwargs,
)

print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")

Train size: 20375, Validation size: 3596, Test size: 650


In [4]:
# Instantiate the network through the factory exposed by the config module.
model = config.model()
# model = nn.DataParallel(model)  # multi-GPU wrapper, currently disabled

# config.optimizer is called with the model's parameters and returns the
# optimizer instance that is later handed to train_model.
optimizer = config.optimizer(model.parameters())
print(f"Using device: {config.device}")
# Checkpoint path note (was a bare, non-Python line in this cell that would
# raise NameError when the cell is re-run):
# CIL/data/convnextv2_hardaug_mixedloss_cont/best_model_28.pt

Using device: cuda:3


In [7]:
# Restore the weights saved by the earlier run of this experiment.
exp_name = "convnextv2_hardaug_mixedloss_cont"
# NOTE(review): the data root is a hardcoded absolute path on one machine; it
# looks like it could be config.dataset_path -- confirm and switch if so.
checkpoint_path = os.path.join("/home/v.lomtev/CIL/data", exp_name, "best_model_28.pt")
# map_location="cpu": without it torch.load tries to put every tensor back on
# the device it was saved from, which fails on machines lacking that GPU.
# load_state_dict then copies the values into the model's existing parameters
# on whatever device they already live on.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

<All keys matched successfully>

In [None]:
# Name of the experiment; joined onto config.dataset_path to form exp_path.
exp_name = "convnextv2_hardaug_mixedloss_cont"
print("Starting training...")

# The wandb run is used as a context manager so it is closed when training
# returns or raises.
wandb_run = wandb.init(project="CIL", save_code=True, notes=config.WANDB_NOTES)
with wandb_run:
    run_dir = os.path.join(config.dataset_path, exp_name)
    # train_model's return value replaces `model` for the cells that follow.
    model = train_model(
        model, train_loader, val_loader,
        config.loss, optimizer, config.epochs, config.device,
        exp_path=run_dir,
        mask_indicator=config.additional_params["MASK_INDICATOR"],
    )

<All keys matched successfully>

In [8]:
import utils.train_utils as tu

In [10]:
# Experiment identifier; the cells below join it onto config.dataset_path to
# form the exp_path passed to the evaluation/prediction helpers.
exp_name = 'convnextv2_hardaug_mixedloss_cont'

In [13]:
import importlib
from utils.train_utils import evaluate_model

# Reload utils.train_utils so the call below uses the module's current source.
# The call deliberately goes through `tu`: the `evaluate_model` name imported
# above was bound before the reload.
importlib.reload(tu)

val_exp_path = os.path.join(config.dataset_path, exp_name)
tu.evaluate_model(model, val_loader, config.device, exp_path=val_exp_path)

Evaluating:  42%|████████████████████████████████████████████████████████████████▎                                                                                         | 94/225 [00:40<00:56,  2.32it/s]


KeyboardInterrupt: 

In [None]:
import importlib
from utils.train_utils import evaluate_model

# Same validation evaluation as the previous cell, whose recorded output ends
# in a KeyboardInterrupt. The module is reloaded first so the call uses the
# current source of utils.train_utils.
importlib.reload(tu)

rerun_exp_path = os.path.join(config.dataset_path, exp_name)
tu.evaluate_model(model, val_loader, config.device, exp_path=rerun_exp_path)

In [14]:
# Reload utils.train_utils, then call its generate_test_predictions helper on
# the unlabeled test loader.
# NOTE(review): outputs are presumably written under exp_path -- confirm in
# utils/train_utils.py.
importlib.reload(tu)

predictions_exp_path = os.path.join(config.dataset_path, exp_name)
tu.generate_test_predictions(
    model,
    test_loader,
    config.device,
    exp_path=predictions_exp_path,
)

Generating Test Predictions: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [00:09<00:00,  4.48it/s]


In [15]:
# Reload utils.train_utils, then call its visualize_test_predictions helper
# with the test loader. The recorded run took about 4m40s for 41 batches.
importlib.reload(tu)

visualization_exp_path = os.path.join(config.dataset_path, exp_name)
tu.visualize_test_predictions(
    model,
    test_loader,
    config.device,
    exp_path=visualization_exp_path,
)

Visualizing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [04:40<00:00,  6.84s/it]
