In [2]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import *
from train import *
from data import *
from constants import *
import math


def inference(model, dataloader: DataLoader, outfile, device):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    Args:
        model: the network to evaluate; it is switched to eval mode here.
        dataloader: yields ``(images, labels)`` pairs; the labels are ignored.
        outfile: unused; kept only so existing call sites keep working.
        device: device each image batch is moved to before the forward pass.

    Returns:
        The per-batch model outputs concatenated along dim 0, in the order the
        loader yields samples (row ``i`` belongs to the ``i``-th sample), or
        ``None`` if the loader yields no batches.
    """
    model.eval()
    batch_outputs = []
    # Inference never needs autograd; without no_grad every batch would keep its
    # graph alive whenever any model parameter still has requires_grad=True.
    with torch.no_grad():
        for imgs, _ in tqdm(dataloader):
            imgs = imgs.to(device)
            batch_outputs.append(model(imgs))  # presumably [bsz, num_cls] -- confirm
    if not batch_outputs:
        return None
    # One concatenation at the end instead of re-copying the growing result per batch.
    return torch.cat(batch_outputs, dim=0)


# ---- Configuration -----------------------------------------------------------------
# NOTE(review): `transform` is bound to `inference_transform` but nothing below uses it;
# the test loader is built with `default_transform`. Confirm which transform the model
# should see at test time -- left unchanged here so predictions are not silently altered.
transform = inference_transform
BATCH_SIZE = 256
# Cluster-specific project root shared by the test-set and checkpoint paths below.
PROJECT_ROOT = "/mnt/slurm_home/pzzhao/acad_projects/AI6102_proj"

# ---- Test data ---------------------------------------------------------------------
# shuffle=False keeps a fixed sample order; later cells pair prediction rows with
# image filenames purely by position, so that order must not change.
dl_tt = get_dataloader(
    f"{PROJECT_ROOT}/test_dataset",
    BATCH_SIZE,
    shuffle=False,
    transform=default_transform,
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# ---- Model -------------------------------------------------------------------------
model_path = (
    f"{PROJECT_ROOT}/AI6102Project_ImageClassification/model_ckpt/resnet50_baseline/"
    "resnet50_2024-03-31_23-15_valLoss:0.862225.pth"
)
# The checkpoint holds a whole pickled nn.Module rather than a state_dict (the
# alternative would be ResNetClassifier(num_classes=NUM_CLASSES).load_state_dict(...)).
# Unpickling can execute arbitrary code, so only load checkpoints you trust.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects whole-module
# pickles; pass weights_only=False explicitly if this line fails after an upgrade.
model_weights = torch.load(model_path, map_location=torch.device("cpu"))
model = model_weights
model.requires_grad_(False)  # freeze every parameter; inference needs no gradients
model = model.to(device)
model.eval()

# Model outputs for every test image, rows in loader order (softmax is applied later).
ret_pred = inference(model, dl_tt, None, device)


cuda:0


100%|██████████| 510/510 [08:24<00:00,  1.01it/s]


In [3]:
import pandas as pd
import glob
import os
import torch.nn.functional as F

In [4]:

def list_class_names(root):
    """Return the sorted names of the non-hidden sub-directories of ``root``.

    Only directories count as classes: a stray file in ``root`` would otherwise
    be reported as a class and desynchronise the prediction columns. ``glob``'s
    ``*`` skips dot-prefixed entries. Returns ``[]`` if ``root`` does not exist.
    """
    return sorted(
        os.path.basename(path)
        for path in glob.glob(os.path.join(root, "*"))
        if os.path.isdir(path)
    )


# Class names = training-set sub-directories, sorted alphabetically.
# NOTE(review): assumes the training loader assigned class indices in this same sorted
# order (as torchvision's ImageFolder does) -- confirm against get_dataloader.
parent_folder = "/mnt/slurm_home/pzzhao/acad_projects/AI6102_proj/dataset/train"
subfolder_names = list_class_names(parent_folder)

In [5]:
# Test-image filenames in plain lexicographic order ("1.jpg" < "10.jpg" < "100.jpg").
# NOTE(review): prediction rows are matched to these names purely by position, so this
# assumes the test loader (shuffle=False) yields images in this same sorted order and
# that the folder holds nothing but the test images -- confirm.
test_folder = "/mnt/slurm_home/pzzhao/acad_projects/AI6102_proj/test_dataset/test/"
image_names = list(os.listdir(test_folder))
image_names.sort()

In [6]:
# Turn the model outputs into per-class probabilities (softmax over dim 1) as a NumPy array.
predictions_np = F.softmax(ret_pred, dim=1).detach().cpu().numpy()
# Fail fast with a readable message when the class list or the image list is out of
# step with the prediction matrix, instead of a generic pandas length-mismatch error.
assert predictions_np.shape[1] == len(subfolder_names), (
    f"{predictions_np.shape[1]} model outputs vs {len(subfolder_names)} class folders"
)
assert predictions_np.shape[0] == len(image_names), (
    f"{predictions_np.shape[0]} predictions vs {len(image_names)} test images"
)
# One column per class, plus the image filename as the first column.
df = pd.DataFrame(predictions_np, columns=subfolder_names)
df.insert(0, 'image', image_names)

In [7]:
# Preview the first 30 rows of the probability table.
df.iloc[:30]

Unnamed: 0,image,acantharia_protist,acantharia_protist_big_center,acantharia_protist_halo,amphipods,appendicularian_fritillaridae,appendicularian_s_shape,appendicularian_slight_curve,appendicularian_straight,artifacts,...,trichodesmium_tuft,trochophore_larvae,tunicate_doliolid,tunicate_doliolid_nurse,tunicate_partial,tunicate_salp,tunicate_salp_chains,unknown_blobs_and_smudges,unknown_sticks,unknown_unclassified
0,1.jpg,0.0001085782,1.48474e-10,6.747295e-08,7.872458e-07,3.28513e-07,0.0006255255,0.0001841697,0.003268667,0.0001383242,...,0.0004797008,2.663192e-10,2.922593e-05,2.074897e-05,3.594987e-06,2.520306e-06,3.914967e-07,0.006323696,0.1959405,0.00026
1,10.jpg,0.01338786,2.020615e-06,0.0002559622,0.003800535,6.845106e-07,0.0005415305,0.002278321,0.005113441,0.000135219,...,0.0006114368,1.03753e-07,4.41786e-05,2.212344e-05,1.832397e-05,7.267413e-06,5.360462e-07,0.068156,0.001703049,0.024166
2,100.jpg,1.830404e-06,1.197761e-10,1.57339e-08,1.809346e-11,4.047659e-12,1.794386e-10,2.649172e-09,7.630408e-12,1.57328e-11,...,3.85517e-09,4.241061e-12,6.405e-09,9.650774e-10,1.239446e-10,1.425681e-08,4.34913e-11,2.18673e-07,5.292073e-10,9e-06
3,1000.jpg,1.495715e-05,6.114016e-08,9.892386e-09,0.0002797813,5.685018e-07,8.875445e-06,2.222637e-06,1.442174e-06,3.825375e-08,...,1.836312e-06,1.520648e-08,7.430475e-07,3.241089e-07,1.32322e-08,1.67149e-06,4.57006e-08,0.0006727221,1.464648e-05,0.000329
4,10000.jpg,1.023699e-05,1.12921e-07,2.524817e-08,4.41723e-07,1.361943e-09,6.291485e-07,2.620985e-06,3.78366e-06,9.260813e-09,...,3.452062e-05,7.068257e-09,1.222074e-06,8.445111e-07,8.513336e-09,1.355627e-09,1.856336e-09,0.000157071,1.782556e-06,0.000149
5,100000.jpg,0.001148601,0.001451629,0.0005819164,9.414937e-06,8.677538e-07,0.0004072369,0.0004083955,0.0002293735,1.089523e-05,...,2.987174e-05,0.002135348,0.009837061,0.009066541,0.04044913,0.0001188001,0.0002113243,0.003751887,0.0006217179,0.02292
6,100001.jpg,0.0153278,0.0001376316,0.003537446,9.953244e-05,1.96045e-07,2.929238e-05,0.0001930409,3.405815e-05,7.100564e-08,...,0.000141068,6.974448e-05,0.008679834,0.006052343,0.0001102778,0.02301676,5.864248e-05,0.0001823249,4.112457e-05,0.082583
7,100002.jpg,0.001199299,4.226799e-05,7.096539e-06,0.0007768171,4.992001e-08,0.0002940171,0.0001527438,0.0003111475,1.792457e-05,...,0.001105587,2.151906e-06,0.001193378,0.004507849,0.00271853,7.890257e-05,3.159589e-05,0.04001616,0.02154914,0.037838
8,100003.jpg,0.0001793212,0.0003257158,0.0004525102,1.473753e-08,4.831446e-12,3.097357e-07,8.672565e-07,3.350078e-06,2.318723e-07,...,0.000225992,6.636952e-06,0.001480621,0.0003241975,0.0005616329,1.152997e-05,5.860618e-05,0.0001288862,1.59433e-05,0.009856
9,100004.jpg,0.04280015,2.065205e-06,1.468466e-06,4.814926e-07,4.666223e-10,1.950998e-06,4.721093e-06,1.142708e-05,0.0008476393,...,5.345318e-06,9.119943e-08,4.622084e-06,1.649759e-05,4.387001e-07,1.41524e-07,4.135537e-07,5.882248e-05,0.0002122485,0.00028


In [8]:
# Write the table ('image' column first, then one probability column per class) without the index.
df.to_csv(path_or_buf="resnet50_baseline.csv", index=False)