In [19]:
import torch

from src.data.project4.dataloader import get_loaders 
from src.models.project4.models import get_model
from src.utils import set_seed, get_optimizer
from src.models.project4.losses import get_loss

import matplotlib.pyplot as plt
import matplotlib.patches as patches

from omegaconf import OmegaConf

In [20]:
torch.cuda.is_available()

True

In [21]:
args = OmegaConf.create({
    'model_name': 'efficientnet_b4',
    'region_size': 224,
    'batch_size': 1,
    'optimizer': 'Adam',
    'loss': 'BCE',
    'data_path': '/work3/s194253/02514/project4_results/data_wastedetection',
    'use_super_categories': True,
    'lr': 1e-04,
    'out': False,
    'seed': 0,
    'verbose': False,
    'percentage_to_freeze': None,
})

device = torch.device('cuda:0')
device = torch.device('cpu')

In [22]:
set_seed(args.seed)

In [23]:
# Get data loaders with applied transformations
loaders, num_classes = get_loaders(
    dataset='waste', 
    batch_size=args.batch_size, 
    seed=args.seed, 
    num_workers=1,
    img_size = (512, 512),
    region_size = (args.region_size, args.region_size),
    use_super_categories=args.use_super_categories,
    root = args.data_path,
)

id2cat = loaders['train'].dataset.id2cat

loading annotations into memory...
Done (t=0.07s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


In [24]:
loss_fun = get_loss(args.loss)
optimizer = get_optimizer(args.optimizer)

model = get_model(args.model_name, args, loss_fun, optimizer, out=args.out, num_classes=num_classes, region_size=(args.region_size, args.region_size), id2cat=loaders['train'].dataset.id2cat)

In [25]:
BASE_PATH = '/work3/s194253/02514/project4_results/logs/albertkjoller_efficientnet_crossentropy_resnet/resnet18/albertkjoller_efficientnet_crossentropy_resnet/resnet18/'


VERSION = 0
EPOCH_NAME = 'epoch=6_loss_val=0.3089.ckpt'

checkpoint_path = BASE_PATH + f'version_{VERSION}/checkpoints/' + EPOCH_NAME

model.load_from_checkpoint(checkpoint_path, loss_fun=loss_fun)
model.to(device)
print("Checkpoint loaded!")

Checkpoint loaded!


In [26]:
from src.utils import plot_SS
from torchvision.ops import box_iou, nms

In [None]:
batch_idx = 1
batch = loaders['test'].dataset.__getitem__(batch_idx)

# for each image
i = 0
(img, cat_ids, bboxes_data, pred_bboxes_data) = batch

# for each bounding box
(bboxes, regions)           = bboxes_data # - not available at this point
(pred_bboxes, pred_regions) = pred_bboxes_data

# Classify proposed regions
y_hat = model.forward(pred_regions.to(device))

# maximum probabilities
outputs = torch.nn.functional.softmax(y_hat, dim=1)
pred_prob, pred_cat = torch.max(outputs, 1)

print("pred_cat:", pred_cat)

# Applying NMS (remove redundant boxes)
keep_indices = nms(pred_bboxes.to(torch.float).to(device), pred_prob, 0.5).cpu()

# Computing AP
preds = {'boxes': pred_bboxes.cpu()[keep_indices][pred_cat.cpu()[keep_indices] != max(model.id2cat.keys())], 
        'scores': pred_prob.cpu()[keep_indices][pred_cat.cpu()[keep_indices]   != max(model.id2cat.keys())], 
        'labels': pred_cat.cpu()[keep_indices][pred_cat.cpu()[keep_indices]    != max(model.id2cat.keys())]} 

targets = {
    'boxes':  bboxes, 
    'labels': cat_ids.flatten()
}

In [58]:
preds['labels']

tensor([19,  3,  2, 17,  2, 24, 12,  3, 19, 19, 24, 12,  2, 17, 24, 17, 17,  3,
        21, 12, 17, 15, 14, 22,  2,  2,  7, 10, 24, 19,  3, 12, 24, 20, 26,  3,
        19, 27, 15, 24,  2,  2, 12, 12, 15, 13, 19,  2,  0, 19, 24, 24, 14, 19,
         3,  7, 17,  2, 18, 17, 17, 11, 10,  5, 27, 19,  2, 16, 19, 24])

In [31]:
plot_SS(
    img, 
    targets['boxes'].detach().cpu(), 
    targets['labels'].detach().cpu(), 
    preds['boxes'].detach().cpu(), 
    preds['labels'].detach().cpu(), 
    preds['scores'].detach().cpu(),
    i,
    batch_idx,
    id2cat,
    path = '/work3/s194253/02514/project4_results/predict_imgs'
)

In [None]:


fig = plt.figure()
ax = fig.add_subplot(111)

ax.imshow(img.permute(1,2,0))
for i, bbox in enumerate(preds['boxes'][preds['labels'] != 28]):
    rect = patches.Rectangle(
        (bbox[0].item(), bbox[1].item()), 
        width=(bbox[2] - bbox[0]).item(), 
        height=(bbox[3] - bbox[1]).item(), 
        linewidth=2, 
        edgecolor=f"r", 
        facecolor='none'
    )
    ax.add_patch(rect)

for i, bbox in enumerate(targets['boxes']):
    rect = patches.Rectangle(
        (bbox[0].item(), bbox[1].item()), 
        width=(bbox[2] - bbox[0]).item(), 
        height=(bbox[3] - bbox[1]).item(), 
        linewidth=2, 
        edgecolor=f"b", 
        facecolor='none'
    )
    ax.add_patch(rect)
    
plt.show()
plt.close()
fig.close()

In [None]:
checkpoint_path = '/work3/s194253/02514/project4_results/logs/albertkjoller_efficientnet/efficientnet_b4/albertkjoller_efficientnet/efficientnet_b4/version_2/checkpoints/epoch=28_val_loss=0.0000.ckpt'

model.load_from_checkpoint(checkpoint_path)