In [1]:
import os

os.chdir('../')
os.getcwd()

'/home/shivansh/repos/assembly_glovebox_dataset'

In [2]:
import torch
import torch.nn.functional as F

from training.lightning_trainers.lightning_model import LitModel
from training.lightning_trainers.mobilesam_model import MobileSamLitModel
from training.lightning_trainers.bisenetv2_model import BiSeNetV2Model

from lightning.pytorch import loggers as pl_loggers
from training.dataloaders.datamodule import AssemblyDataModule
import lightning.pytorch as pl

import time
import os

In [3]:
test_dropout = False


def get_path(model_type):
    if model_type == "unet":
        path_1 = "./lightning_logs/version_73/checkpoints/epoch=1011-step=8096.ckpt"
        model_1 = LitModel.load_from_checkpoint(path_1, map_location=torch.device('cuda:0'), test_dropout=test_dropout)

        path_2 = "./lightning_logs/version_76/checkpoints/epoch=1856-step=29712.ckpt"
        model_2 = LitModel.load_from_checkpoint(path_2, map_location=torch.device('cuda:0'), test_dropout=test_dropout)

        path_3 = "./lightning_logs/version_78/checkpoints/epoch=3223-step=51584.ckpt"
        model_3 = LitModel.load_from_checkpoint(path_3, map_location=torch.device('cuda:0'), test_dropout=test_dropout)

    if model_type == "bisenetv2":
        path_1="./lightning_logs/version_79/checkpoints/epoch=3373-step=26992.ckpt"
        model_1 = BiSeNetV2Model.load_from_checkpoint(path_1, map_location=torch.device('cuda:1'), test_dropout=test_dropout)

        
        path_2="./lightning_logs/version_77/checkpoints/epoch=1083-step=8672.ckpt"
        model_2 = BiSeNetV2Model.load_from_checkpoint(path_2, map_location=torch.device('cuda:1'), test_dropout=test_dropout)

        path_3="./lightning_logs/version_72/checkpoints/epoch=1440-step=11528.ckpt"
        model_3 = BiSeNetV2Model.load_from_checkpoint(path_3, map_location=torch.device('cuda:1'), test_dropout=test_dropout)

    if model_type == "mobilesam":
        
        path_1="./lightning_logs/version_80/checkpoints/epoch=3027-step=24224.ckpt"
        model_1 = MobileSamLitModel.load_from_checkpoint(path_1, map_location=torch.device('cuda:2'), test_dropout=test_dropout)
        
        path_2="./lightning_logs/version_75/checkpoints/epoch=1849-step=29600.ckpt"
        model_2 = MobileSamLitModel.load_from_checkpoint(path_2, map_location=torch.device('cuda:2'), test_dropout=test_dropout)

        path_3="./lightning_logs/version_69/checkpoints/epoch=1312-step=10704.ckpt"
        model_3 = MobileSamLitModel.load_from_checkpoint(path_3, map_location=torch.device('cuda:2'), test_dropout=test_dropout)

    return [model_1, model_2, model_3]

# 1. Warm up the gpu

TODO: 
- Edit back CUDA device change in get_prediction in mobilesam_model.py

# 2. Load an image to use from dataset

In [4]:
image_0 = torch.randn(1, 3, 256, 256).cuda(device='cuda:0')
image_1 = torch.randn(2, 3, 256, 256).cuda(device='cuda:1')
image_2 = torch.randn(1, 3, 256, 256).cuda(device='cuda:2')

In [5]:
from training.dataloaders.datamodule import AssemblyDataModule


In [6]:
batch_size = 64

data = {
    "fit_query": {
        "participants": [
            "Test_Subject_1",
            "Test_Subject_2",
            "Test_Subject_3",
            "Test_Subject_4",
            "Test_Subject_6",
            "Test_Subject_7",
            "Test_Subject_9",
            "Test_Subject_10",
            "Test_Subject_11",
            "Test_Subject_12"
        ],
        "distribution": ["id"],
        "task": ["J", "TB"],
        "view": ["Top_View", "Side_View"]
    },
    "test_query": {
        "participants": [
            "Test_Subject_1",
            "Test_Subject_2",
            "Test_Subject_3",
            "Test_Subject_4",
            "Test_Subject_6",
            "Test_Subject_7",
            "Test_Subject_9",
            "Test_Subject_10",
            "Test_Subject_11",
            "Test_Subject_12"
        ],
        "distribution": ["id"],
        "task": ["J", "TB"],
        "view": ["Top_View", "Side_View"]
    },
}


id_data = data.copy()
id_data['test_query']['distribution'] = ["id"]
id_data['test_query']['participants'] = ["Test_Subject_2"]

dm = AssemblyDataModule(
    test_query = id_data.get("test_query"),
    fit_query = id_data.get("fit_query"),
    batch_size=batch_size,
    img_size=256
)


ood_data = data.copy()
ood_data['test_query']['distribution'] = ["ood"]

ood_dm = AssemblyDataModule(
    test_query = ood_data.get("test_query"),
    fit_query = ood_data.get("fit_query"),
    batch_size=batch_size,
    img_size=256
)

In [7]:
dm.setup(stage="test")

ood_dm.setup(stage="test")

the query is {'participants': ['Test_Subject_2'], 'distribution': ['ood'], 'task': ['J', 'TB'], 'view': ['Top_View', 'Side_View']}
the query is {'participants': ['Test_Subject_2'], 'distribution': ['ood'], 'task': ['J', 'TB'], 'view': ['Top_View', 'Side_View']}


In [8]:
x = None

for batch in dm.test_dataloader():
    x, y = batch
    print(x.shape)
    break

torch.Size([24, 3, 256, 256])


In [9]:
ood_x = None

for ood_batch in ood_dm.test_dataloader():
    ood_x, ood_y = ood_batch
    print(ood_x.shape)
    break

torch.Size([24, 3, 256, 256])


# 3. Test Time Predictions for Each Model

### unet

In [72]:
models = get_path("unet")

ensemble_time = 0
for model in models:
    start_time = time.time()

    output = model.model(x.to(device=model.device))

    end_time = time.time()

    ensemble_time += end_time - start_time

ensemble_time / x.shape[0]

0.007290085156758626

In [70]:
start_time = time.time()

model_1 = models[1]

output = model_1.model(x.to(device=model_1.device))

end_time = time.time()

base_time =  end_time - start_time

base_time / x.shape[0]

0.0003467102845509847

ood

In [32]:
## testing ood images

models = get_path("unet")

ensemble_time = 0
for model in models:
    start_time = time.time()

    output = model.model(ood_x.to(device=model.device))

    end_time = time.time()

    ensemble_time += end_time - start_time

ensemble_time / ood_x.shape[0]

0.007293144861857097

In [34]:
start_time = time.time()

model_1 = models[1]

output = model_1.model(ood_x.to(device=model_1.device))

end_time = time.time()

base_time =  end_time - start_time

base_time / ood_x.shape[0]

0.00029810269673665363

### bisenet

In [36]:
models = get_path("bisenetv2")

ensemble_time = 0
for model in models:
    start_time = time.time()

    output = model.model(x.to(device=model.device))

    end_time = time.time()

    ensemble_time += end_time - start_time

ensemble_time / x.shape[0] # bisenet needs 2 images passed

0.001775225003560384

In [39]:
start_time = time.time()

model_1 = models[0]

output = model_1.model(x.to(device=model_1.device))

end_time = time.time()

base_time =  end_time - start_time

base_time / x.shape[0]

0.0008324285348256429

ood

In [63]:
for model in models:
    start_time = time.time()

    output = model.model(ood_x.to(device=model.device))

    end_time = time.time()

    ensemble_time += end_time - start_time

ensemble_time / ood_x.shape[0]

0.0020158588886260986

In [66]:
models = get_path("bisenetv2")

ensemble_time = 0

start_time = time.time()

model_1 = models[0]

output = model_1.model(ood_x.to(device=model_1.device))

end_time = time.time()

base_time =  end_time - start_time

base_time / ood_x.shape[0]

0.00044297178586324054

# mobilesam

In [17]:
image_2.device, model.device

(device(type='cuda', index=2), device(type='cuda', index=1))

In [51]:
models = get_path("mobilesam")

ensemble_time = 0
for model in models:
    
    start_time = time.time()

    output = model._get_preds(x.to(device=model.device))

    end_time = time.time()

    ensemble_time += end_time - start_time

ensemble_time / x.shape[0]

0.04516427715619405

In [50]:
start_time = time.time()

model_1 = models[1].to(image_2.device)

output = model_1._get_preds(x.to(device=model_1.device))

end_time = time.time()

base_time =  end_time - start_time

base_time / x.shape[0]

0.014566481113433838

ood

In [54]:
start_time = time.time()

model_1 = models[1].to(image_2.device)

output = model_1._get_preds(ood_x.to(device=model_1.device))

end_time = time.time()

base_time =  end_time - start_time

base_time / ood_x.shape[0]


0.014891356229782104

In [53]:
models = get_path("mobilesam")

ensemble_time = 0
for model in models:
    
    start_time = time.time()

    output = model._get_preds(ood_x.to(device=model.device))

    end_time = time.time()

    ensemble_time += end_time - start_time

ensemble_time / ood_x.shape[0]

0.045158336559931435