# Main model loading / prediction

This notebook is adapted from `predict.py`. It loads a trained detection model manually, runs it over every image of a dataset, and saves the predictions to a CSV file.

In [10]:
import os
import csv
import json
import torch
import torchvision

from PIL import Image
from tqdm import tqdm

from src.model import Textboxes, ResNet, SSD
from src.dataset import CustomDataset, collate_fn
from torch.utils.data import DataLoader
from src.transform import SSDTransformer
from src.utils import generate_dboxes, Encoder

def save_predictions(predictions, dataset_name, model_name):
    """Write detection results to `predictions/<dataset_name>/<model_name>.csv`.

    The output directory is created if it does not exist yet, and an existing
    file with the same name is overwritten.

    Args:
        predictions: iterable of dicts holding the keys `image_id`,
            `category_id`, `bbox` and `score`. `bbox` is an iterable of
            numbers; each value is truncated to `int` in the CSV only.
        dataset_name: name of the sub-directory used under `predictions/`.
        model_name: base name (without extension) of the CSV file.

    Returns:
        The path of the CSV file that was written.
    """
    output_dir = os.path.join("predictions", dataset_name)
    os.makedirs(output_dir, exist_ok = True)

    filename = os.path.join(output_dir, model_name + ".csv")
    columns  = ["image_id", "category_id", "bbox", "score"]

    # newline = "" lets the csv module control line endings itself.
    with open(filename, mode = "w", newline = "") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames = columns)
        writer.writeheader()

        for prediction in predictions:
            # Work on a shallow copy: saving must not overwrite the caller's
            # float coordinates with their truncated integer version.
            row = dict(prediction)
            row['bbox'] = [int(v) for v in prediction['bbox']]
            writer.writerow(row)

    return filename

def load_data(data_path, dataset_name):
    """List the images declared in a dataset's annotation file.

    The annotation file is read from
    `<data_path>/annotations/<dataset_name>.json` and must be a JSON document
    with an `images` list whose entries carry `file_name` and `id`.

    Returns:
        A list of `(file_name, id)` tuples, in the order of the `images` list.
    """
    annotation_file = os.path.join(data_path, "annotations", dataset_name + ".json")
    with open(annotation_file) as handle:
        annotations = json.load(handle)

    samples = []
    for entry in annotations['images']:
        samples.append((entry['file_name'], entry['id']))
    return samples

def load_model(model_name, ckpt_path, backbone = 'RN512', img_size = 512, truncate = False, device = 'cpu'):
    """Build a detector, restore its weights from a checkpoint and put it in eval mode.

    Args:
        model_name: architecture name. An `SSD` is built when the name contains
            "SSD", a `Textboxes` otherwise.
        ckpt_path: checkpoint file readable by `torch.load`; it must contain a
            `model_state_dict` entry.
        backbone: identifier forwarded to `ResNet`.
        img_size: forwarded to the model constructor as `figsize`.
        truncate: forwarded unchanged to the model constructor.
        device: device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model with the checkpoint weights loaded, on `device`, in
        evaluation mode.
    """
    if "SSD" in model_name:
        model = SSD(model_name, truncate, backbone = ResNet(backbone), figsize = img_size, num_classes = 2)
    else:
        model = Textboxes(model_name, truncate, backbone = ResNet(backbone), figsize = img_size, num_classes = 2)

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here.
    checkpoint = torch.load(ckpt_path, map_location = device)

    # Drop the 'module.' prefix (presumably added by a DataParallel-style
    # wrapper at training time) from the START of each key only. A plain
    # str.replace would also mangle keys that merely contain the substring,
    # e.g. 'submodule.weight' -> 'subweight'.
    prefix = 'module.'
    model_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k) : v
        for k, v in checkpoint["model_state_dict"].items()
    }
    model.load_state_dict(model_state_dict)
    model.to(device)
    model.eval()

    return model


## Model loading

In [11]:
# --- Dataset and checkpoint locations ------------------------------------
data_path    = "/storage/medical_text_images"
dataset_name = ""
ckpt_path    = "SSD512_R152.pth"

# --- Model definition ----------------------------------------------------
model_name = "SSD"
backbone   = "ResNet512"
img_size   = 512
truncate   = False

# Threshold handed to `encoder.decode_batch` in the prediction cell.
nms_threshold = 0.2

# Run on the first GPU when one is available, on the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Fail early, before any model is built, when the checkpoint is missing.
assert os.path.exists(ckpt_path), 'Checkpoint not found !'

# Default boxes are shared by the encoder (decoding of the raw model output)
# and by the transformer (image pre-processing, here in validation mode).
boxes       = generate_dboxes(model_name, truncate, img_size)
encoder     = Encoder(boxes)
transformer = SSDTransformer(boxes, (img_size, img_size), val = True)

model = load_model(
    model_name,
    ckpt_path,
    backbone = backbone,
    img_size = img_size,
    truncate = truncate,
    device   = device,
)

In [12]:
# Print the layer structure of the loaded model.
print(model)

SSD(
  (feature_extractor): ResNet(
    (feature_extractor): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=T

## Prediction on dataset

In [None]:
# (file_name, image_id) pairs declared in the dataset's annotation file.
samples = load_data(data_path, dataset_name)

# NOTE(review): `category_ids` was used below without ever being defined in
# this notebook. It is rebuilt here from the `categories` list of the same
# annotation file, assuming predicted label k (1-based, 0 presumably being the
# background) maps to the k-th category id in ascending order. TODO confirm
# against the label mapping used at training time / in `predict.py`.
with open(os.path.join(data_path, "annotations", dataset_name + ".json")) as json_file:
    category_ids = sorted(category["id"] for category in json.load(json_file)["categories"])

predictions = []
print("processing samples")
for filepath, img_id in tqdm(samples):
    # The context manager closes the image file as soon as the RGB copy exists.
    with Image.open(os.path.join(data_path, dataset_name, filepath)) as raw_img:
        img = raw_img.convert("RGB")
    # Original size, needed to scale the predicted boxes back to pixels.
    width, height = img.size

    # The transformer expects boxes / labels as well; dummy tensors are passed
    # because only the transformed image is needed at prediction time.
    img, _, _, _ = transformer(img, None, torch.zeros(1, 4), torch.zeros(1))
    # A no-op on CPU-only machines, where `device` is already the CPU.
    img = img.to(device)

    with torch.no_grad():
        # Single-image batch -> raw localisation / classification outputs.
        ploc, plabel = model(img.unsqueeze(0))
        ploc, plabel = ploc.float(), plabel.float()
        result = encoder.decode_batch(ploc, plabel, nms_threshold, 200)[0]

    loc, label, prob = [r.cpu().numpy() for r in result]

    for loc_, label_, prob_ in zip(loc, label, prob):
        # `loc_` is treated as (xmin, ymin, xmax, ymax) relative to the image
        # size; it is converted to an absolute [x, y, width, height] box.
        xmin = loc_[0] * width
        ymin = loc_[1] * height
        w    = (loc_[2] - loc_[0]) * width
        h    = (loc_[3] - loc_[1]) * height
        predictions.append({
            "image_id"    : img_id,
            "category_id" : category_ids[label_ - 1],
            "bbox"        : [xmin, ymin, w, h],
            "score"       : prob_,
        })

# `save_predictions` requires the dataset and model names to build the CSV path.
save_predictions(predictions, dataset_name, model_name)