In [1]:
import os
import os.path as path

import pandas as pd
import torch as t
import torch.utils.data as td
from PIL import Image
from torchvision import transforms

import consts
from plant_pathology_dataset import PlantPathologyTestDataset
from torchutils import evaluate

In [2]:
# Base name shared by the saved model (./models/<name>.pkl) and the submission (./submissions/<name>.csv).
model_name: str = "baseline_model"

In [3]:
# Load the trained model. The loaded object is passed straight to `evaluate` below,
# so it is assumed to be a full pickled nn.Module (not a bare state_dict).
# SECURITY: t.load unpickles, which can execute arbitrary code — only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True and will refuse a full pickled
# module; pass weights_only=False there. If the model was saved on a GPU and this machine
# has none, map_location="cpu" is also required.
with open(f"./models/{model_name}.pkl", "rb") as f:
    model = t.load(f)

# Put dropout / batch-norm layers into inference mode before predicting.
# Idempotent, so harmless if `evaluate` also calls it.
model.eval()

In [4]:
# Image cache directory under the dataset root (the "250x250" name suggests pre-resized images).
imgroot = path.join(consts.DATAROOT, "cache", "250x250")

# Per-channel normalization statistics (three channels).
# NOTE(review): presumably computed on the training images, in RGB order — confirm these
# match the values used at training time, or predictions will be skewed.
means = (
    0.4038582976214318,
    0.5127894672998029,
    0.3129764558236694,
)
stds = (
    0.2034616086724042,
    0.18909514150453344,
    0.18761408366900625,
)

# ToTensor scales pixel values to [0, 1]; Normalize then standardizes each channel.
xform_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=means, std=stds),
]
xform = transforms.Compose(xform_steps)

testset = PlantPathologyTestDataset(imgroot, xform)

In [5]:
# Batched loader over the test set. shuffle MUST stay False: the submission cell pairs
# output row i with testset.image_ids[i], which only holds for in-order iteration.
dl = td.DataLoader(
    testset,
    shuffle=False,
    batch_size=128,
)

In [6]:
# Run inference over the whole test set. Autograd is disabled explicitly, so no graph is
# built (lower memory) and the returned outputs do not require grad. This is harmless if
# `evaluate` already does it internally.
# NOTE(review): the empty list presumably means "no metrics" (there are no test labels);
# confirm against torchutils.evaluate.
with t.no_grad():
    metrics = evaluate(model, dl, [])

In [7]:
# Convert logits to per-class probabilities. The outputs are (n_images, n_classes), so the
# class axis is dim=1. `nn.Softmax()` without `dim` picked dim=1 implicitly for 2-D input,
# but that implicit choice is deprecated and raises a warning.
all_probs = t.softmax(metrics["outputs"], dim=1)
all_probs.shape

torch.Size([10, 4])

In [9]:
# Build the submission: one row per test image, one probability column per class.
# Row i of `all_probs` corresponds to testset.image_ids[i] because the DataLoader
# iterates the dataset in order (shuffle=False).
submission_path = f"./submissions/{model_name}.csv"
image_ids = list(testset.image_ids)

# Refuse to write a truncated or misaligned submission.
assert all_probs.shape[0] == len(image_ids), (
    f"{all_probs.shape[0]} prediction rows but {len(image_ids)} test image ids"
)

# detach()/cpu() so .numpy() works whatever the device or autograd state of the outputs.
probs_np = all_probs.detach().cpu().numpy()

# Dict insertion order fixes the header: image_id,healthy,multiple_diseases,rust,scab
submission = pd.DataFrame({
    "image_id": image_ids,
    "healthy": probs_np[:, consts.HEALTHY],
    "multiple_diseases": probs_np[:, consts.MULTIPLE_DISEASES],
    "rust": probs_np[:, consts.RUST],
    "scab": probs_np[:, consts.SCAB],
})

# The submissions directory may not exist on a fresh checkout.
os.makedirs(path.dirname(submission_path), exist_ok=True)
# Same 3-decimal formatting as the previous hand-written CSV.
submission.to_csv(submission_path, index=False, float_format="%.3f")
submission.head()

In [None]:
# Sanity-check the submission file written by the previous cell by reading it back.
submission_check = pd.read_csv(f"./submissions/{model_name}.csv")

expected_columns = ["image_id", "healthy", "multiple_diseases", "rust", "scab"]
prob_columns = expected_columns[1:]

assert list(submission_check.columns) == expected_columns, list(submission_check.columns)
# Every test image must get exactly one row.
assert len(submission_check) == len(testset), (len(submission_check), len(testset))
assert submission_check["image_id"].is_unique, "duplicate image_id rows in submission"
# Softmax outputs (even after 3-decimal rounding) must lie in [0, 1].
assert ((submission_check[prob_columns] >= 0) & (submission_check[prob_columns] <= 1)).all().all()

submission_check.head()