In [1]:
import torch
from parksim.intent_predict.cnnV2.network import SimpleCNN, RegularizedCNN, SmallRegularizedCNN
from parksim.intent_predict.cnnV2.utils import CNNDataset, CNNGroupedDataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

import os
from datetime import datetime
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Use seaborn's "darkgrid" theme for all figures drawn later in the notebook.
sns.set_theme(style="darkgrid")

In [2]:

def get_predictions(model_path, dji_num, device=None, data_dir="../data",
                    batch_size=32, num_workers=12):
    """Evaluate a saved SmallRegularizedCNN on one DJI recording and report top-k accuracy.

    Each sample of ``CNNGroupedDataset`` is unpacked as
    ``(img_feature, non_spatial_feature, labels)`` and treated as one decision;
    ``labels.shape[0]`` is taken as the number of candidate options and
    ``argmax(labels)`` as the index of the correct option. A decision is a top-k
    hit when the correct option is among the k highest-scoring options.

    Args:
        model_path: Path to a saved ``state_dict`` for ``SmallRegularizedCNN``.
        dji_num: Recording identifier; data is read from ``{data_dir}/DJI_{dji_num}``.
        device: Torch device to run on. ``None`` uses the notebook-level ``DEVICE``.
        data_dir: Directory containing the ``DJI_*`` recordings.
        batch_size: Number of samples fetched per DataLoader batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(top_1, top_3, top_5)`` accuracies as fractions of ``len(dataset)``;
        all ``0.0`` when the dataset is empty.
    """
    device = DEVICE if device is None else device

    model = SmallRegularizedCNN()
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()

    dataset = CNNGroupedDataset(f"{data_dir}/DJI_{dji_num}", input_transform=transforms.ToTensor())
    # The loop below iterates individual samples inside each batch, so batches must
    # stay plain lists of samples; the default collate would stack/transpose them.
    # ``list`` is a builtin, so it remains picklable for worker processes.
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                            collate_fn=list)

    data_size = len(dataset)
    # Count integer hits and divide once at the end (summing 1/data_size drifts).
    hits = {1: 0, 3: 0, 5: 0}
    with torch.no_grad():  # inference only: skip building the autograd graph
        for batch in tqdm(dataloader):
            for img_feature, non_spatial_feature, labels in batch:
                img_feature = img_feature.to(device)
                non_spatial_feature = non_spatial_feature.to(device)
                num_options = labels.shape[0]
                label = int(torch.argmax(labels))

                # Presumably one score per option, possibly with a trailing singleton
                # dim (labels were unsqueezed to (N, 1) to match) -- TODO confirm.
                # Flatten so topk ranks across options instead of along a size-1 dim.
                # Sigmoid is monotonic, so ranking raw logits gives the same order.
                scores = model(img_feature, non_spatial_feature).flatten()
                top_indices = torch.topk(scores, min(5, num_options)).indices.tolist()
                for k in hits:
                    if label in top_indices[:k]:
                        hits[k] += 1

    if data_size:
        top_1, top_3, top_5 = (hits[k] / data_size for k in (1, 3, 5))
    else:
        top_1 = top_3 = top_5 = 0.0

    print(f"Top 1 Accuracy: {top_1}")
    print(f"Top 3 Accuracy: {top_3}")
    print(f"Top 5 Accuracy: {top_5}")
    return top_1, top_3, top_5

NameError: name 'main' is not defined