In [1]:
import wandb
# Fully-qualified W&B artifact reference, kept in one place so the model
# version used by this notebook is easy to spot and change.
MODEL_ARTIFACT = 'vincekillerz/base-confidence-estimation-v2/saved_model:v5'

# Start a W&B run, declare the model artifact as an input of that run and
# download its files; `artifact_dir` is the local directory holding them.
run = wandb.init()
artifact = run.use_artifact(MODEL_ARTIFACT, type='model')
artifact_dir = artifact.download()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvincekillerz[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   2 of 2 files downloaded.  


In [2]:
import os
import yaml 
from src.tools import check_file_path
# Read the configuration that ships with the downloaded artifact.
# The file is opened in a context manager so its handle is closed as soon as
# parsing finishes instead of being left to the garbage collector.
# NOTE(review): check_file_path presumably joins its arguments and validates
# that the resulting path exists -- confirm in src.tools.
with open(check_file_path(artifact_dir, 'config.yaml')) as config_file:
    config = yaml.safe_load(config_file)

# Location of the trained weights stored in the same artifact.
model_path = check_file_path(artifact_dir, 'trained_model.pth')

In [3]:
import torch
from src.ml_orchestrator.dataset import COTDataset
from src.ml_orchestrator.loss.loss_builder import loss_builder
from src.ml_orchestrator.transforms.transforms_builder import TransformBuilder
from torch.utils.data import DataLoader

from src.models.model_builder import model_builder

# Shortcuts into the configuration used throughout this cell.
orchestrator_config = config['ml_orchestrator']
# Defined up front so that every consumer below (model placement, checkpoint
# loading, and later cells) uses the same value.
device = orchestrator_config['device']

transform_builder = TransformBuilder(config['transforms'])

dataset_folder = check_file_path("datasets", orchestrator_config['dataset_name'])


def _build_dataset(split):
    """Create the COTDataset for one split sub-folder of `dataset_folder`.

    Fresh transform objects are built for every dataset, exactly as when each
    dataset is constructed by hand.
    """
    return COTDataset(
        confidence=config['confidence'],
        root_dir=check_file_path(dataset_folder, split),
        transform_input=transform_builder.build_transforms_inputs(),
        transform_common=transform_builder.build_transform_common(),
        config=config,
    )


def _build_loader(dataset):
    """Wrap `dataset` in a shuffled DataLoader configured from `config`."""
    return DataLoader(
        dataset,
        batch_size=orchestrator_config['batch_size'],
        shuffle=True,
        pin_memory=True,
        num_workers=orchestrator_config['num_workers'],
    )


train_dataset = _build_dataset("train")
valid_manually_labelled_dataset = _build_dataset("valid_manually_labelled")

train_loader = _build_loader(train_dataset)
valid_loader = _build_loader(valid_manually_labelled_dataset)

model = model_builder(config['model_builder'])
model.to(device)
# Load the trained weights. map_location remaps the stored tensors onto the
# configured device, so a checkpoint saved on a GPU can still be loaded on a
# CPU-only machine (or on a different GPU).
model.load_state_dict(torch.load(model_path, map_location=device))

criterion = loss_builder(config)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from src.tools import clear_folder


def process_batch(batch, name, threshold=0.25, output_size=(480, 640)):
    """Write one binary, segment-level confidence map per sample of `batch`.

    The batch is run through the module-level `model` and `criterion`. For each
    sample, every segment id found in `batch['seg']` is marked 1 when the mean
    of the confidence returned by `criterion` over that segment is strictly
    greater than `threshold`, and 0 otherwise. The resulting map is resized
    with nearest-neighbour interpolation to `output_size` and saved as
    `<dataset_folder>/<name>/confidence/<timestamp>.npy`.

    Relies on the notebook globals `model`, `criterion`, `device` and
    `dataset_folder`.

    Args:
        batch: mapping with the keys 'image', 'depth', 'confidence', 'mask',
            'seg' (tensors) and 'timestamp' (indexable per sample).
        name: dataset split sub-folder, e.g. "train".
        threshold: minimum mean confidence for a segment to be kept.
        output_size: (height, width) of the saved map.
    """
    images = batch['image'].to(device)
    depths = batch['depth'].to(device)
    target_confidence = batch['confidence'].to(device)
    masks = batch['mask'].to(device)
    segs = batch['seg'].to(device)
    timestamps = batch['timestamp']

    outputs = model(images, depths)
    # The criterion returns (loss, confidence); only the confidence is needed.
    # NOTE(review): the meaning of the trailing 0 is defined by the criterion
    # (not visible here) -- confirm.
    _, confidence = criterion(outputs, masks, target_confidence, 0)

    output_dir = os.path.join(dataset_folder, name, "confidence")
    for i in range(len(images)):
        confidence_segs = torch.zeros_like(confidence[i])
        segment_ids, pixel_counts = torch.unique(segs[i], return_counts=True)
        for segment_id, pixel_count in zip(segment_ids, pixel_counts):
            # Build the segment mask once and reuse it for reading and writing.
            segment_mask = segs[i] == segment_id
            mean_confidence = torch.sum(confidence[i][segment_mask]) / pixel_count
            if mean_confidence > threshold:
                confidence_segs[segment_mask] = 1

        # interpolate needs a leading batch dimension.
        # assumes confidence[i] is (channels, H, W) -- TODO confirm
        confidence_segs = confidence_segs.unsqueeze(0)
        confidence_segs_big = nn.functional.interpolate(
            confidence_segs, size=output_size, mode='nearest'
        )
        np.save(
            os.path.join(output_dir, f"{timestamps[i]}.npy"),
            confidence_segs_big.squeeze(0).squeeze(0).cpu().numpy(),
        )
        

# Regenerate the segment-level confidence maps of both dataset splits.
# Switch the model to evaluation behaviour first: torch modules start in
# training mode, in which layers such as dropout and batch-norm behave
# differently and would make the exported maps non-deterministic.
model.eval()
with torch.inference_mode():
    # Both output folders are emptied before any batch is processed so that no
    # stale maps from a previous run survive.
    clear_folder(os.path.join(dataset_folder,"train","confidence"))
    clear_folder(os.path.join(dataset_folder,"valid_manually_labelled","confidence"))
    for batch in train_loader:
        process_batch(batch,"train")
    for batch in valid_loader:
        process_batch(batch,"valid_manually_labelled")

