In [2]:
import sys
import os
sys.path.append('../..')
import torch
import torchvision
from src.models import ResNet18
from src.transforms import ImageResizer,make_patches
import matplotlib.pyplot as plt
import pandas as pd
from torchsummary import summary

In [3]:
# Instantiate the project's ResNet18 model configured for 3 classes (n_classes=3).
# NOTE(review): no checkpoint/state_dict is loaded in the visible cells, so the
# weights are presumably untrained here — confirm before trusting predictions.
model = ResNet18(n_classes=3)

In [4]:
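# TODO: build the dataloader consumed by `predict`. Note that ImageFolder is a
# Dataset (it requires a `root` directory) and yields (image, label) pairs; wrap
# it in torch.utils.data.DataLoader to get batches.
# dataloader = torchvision.datasets.ImageFolder()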

In [5]:
def predict(model, dataloader):
    """Run ``model`` over every batch in ``dataloader`` and collect argmax class ids.

    Args:
        model: A ``torch.nn.Module`` mapping a batch tensor to per-class scores
            of shape ``(batch, n_classes)``. It is moved to the selected device
            and switched to eval mode; both changes are made in place.
        dataloader: Iterable of batches. Each batch is either an input tensor
            or a ``(inputs, labels, ...)`` tuple/list (as yielded by a
            ``DataLoader`` over a labelled dataset); only the first element of
            such a tuple is used.

    Returns:
        Flat list of predicted class indices (NumPy integers), in iteration order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The model must live on the same device as its inputs; otherwise a CUDA
    # machine fails with a device-mismatch error on the first forward pass.
    model.to(device)
    model.eval()
    predictions = []

    with torch.no_grad():
        for batch in dataloader:
            # Labelled datasets yield (inputs, labels); keep only the inputs.
            patch = batch[0] if isinstance(batch, (tuple, list)) else batch
            outputs = model(patch.to(device))
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())

    return predictions

In [10]:
def vote(voting_type, predictions, classes_weights=None):
    """Aggregate per-sample class scores into a single winning class index.

    Args:
        voting_type: ``"soft"`` (argmax of the mean score per class),
            ``"hard"`` (majority vote over each row's argmax) or
            ``"weighted"`` (argmax of the per-class score sums, each scaled by
            its entry in ``classes_weights``).
        predictions: 2-D tensor (or nested sequence) of shape
            ``(n_samples, n_classes)`` holding per-sample class scores,
            e.g. softmax probabilities.
        classes_weights: Optional per-class multipliers of length ``n_classes``,
            used only by ``"weighted"`` voting. Defaults to uniform weights,
            which makes weighted voting equivalent to soft voting.

    Returns:
        0-dim ``torch.long`` tensor with the winning class index. Ties resolve
        to the lowest index (``torch.argmax`` returns the first maximum).

    Raises:
        ValueError: on an unknown ``voting_type``, a non-2-D or empty
            ``predictions``, or a ``classes_weights`` of the wrong length.
    """
    predictions = torch.as_tensor(predictions)
    if not predictions.is_floating_point():
        # Averaging/scaling scores requires a floating-point dtype.
        predictions = predictions.to(torch.get_default_dtype())
    if predictions.dim() != 2:
        raise ValueError(
            "predictions must be 2-D (n_samples, n_classes), "
            f"got shape {tuple(predictions.shape)}"
        )
    n_samples, n_classes = predictions.shape
    if n_samples == 0:
        raise ValueError("predictions must contain at least one row")

    if voting_type == "soft":
        return torch.argmax(predictions.mean(dim=0))

    if voting_type == "hard":
        # One-hot encode each row's top class, then count votes per class.
        one_hot = torch.zeros_like(predictions)
        one_hot.scatter_(1, predictions.argmax(dim=1, keepdim=True), 1)
        return torch.argmax(one_hot.sum(dim=0))

    if voting_type == "weighted":
        if classes_weights is None:
            # Uniform weights: the old all-zero default always returned class 0.
            weights = torch.ones(n_classes, dtype=predictions.dtype, device=predictions.device)
        else:
            # Match dtype/device so int weights or CUDA predictions work.
            weights = torch.as_tensor(
                classes_weights, dtype=predictions.dtype, device=predictions.device
            )
        if weights.shape != (n_classes,):
            raise ValueError(
                f"classes_weights must have length {n_classes}, got shape {tuple(weights.shape)}"
            )
        return torch.argmax((predictions * weights).sum(dim=0))

    raise ValueError(
        f"unknown voting_type {voting_type!r}; expected 'soft', 'hard' or 'weighted'"
    )

In [23]:
# testing: sanity-check `vote` on synthetic per-sample class probabilities.
# Rows come from a symmetric Dirichlet(1, ..., 1), so each row sums to 1 like a
# softmax output. Seeding makes the printed winners reproducible on re-run.
import torch.distributions.dirichlet as dirichlet

SEED = 0
num_rows = 100
num_classes = 3

torch.manual_seed(SEED)
samples = dirichlet.Dirichlet(torch.tensor([1.0] * num_classes)).sample((num_rows,))

soft_winner = vote("soft", samples)
hard_winner = vote("hard", samples)
weighted_winner = vote("weighted", samples, classes_weights=[2, 2, 2])

# Uniform class weights only rescale the per-class sums, so weighted voting
# must pick the same class as soft voting.
assert weighted_winner == soft_winner

print(soft_winner)
print(hard_winner)
print(weighted_winner)

tensor(1)
tensor(1)
tensor(1)
