In [1]:
import Inference
import csv
import Augmentation
import Model
import torch
import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from Dataset import _read_image
from Metrics import find_threshold

# Number of label classes. Shared by the validation split and the model head so the
# two calls below cannot silently disagree.
NUM_CLASSES = 28

# Device: prefer the first GPU, but fall back to CPU instead of failing later at
# model.to(device) on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Paths. dataset_path must keep its trailing slash: sub-paths are built by plain
# string concatenation below and in later sections.
dataset_path = '/mnt/train-data1/howard/cvfinal/human-protein-atlas-image-classification/'
model_path = '/mnt/train-data1/howard/cvfinal/model/v5/epoch14.pth'

# Evaluation transform, reused for both validation and test images.
eval_transform = Augmentation.get_eval_transform()

# Only the validation half of the train/val split is used here (to tune per-class
# decision thresholds); the training half is discarded.
_, validation_dataset = Dataset.get_train_val_dataset(dataset_path + 'train.csv',
                                                      NUM_CLASSES,
                                                      img_folder=dataset_path + 'train/',
                                                      transform=eval_transform)
validation_data_loader = DataLoader(validation_dataset, batch_size=16, shuffle=False, num_workers=36)

# Model: build/load weights, switch to inference mode, move to the chosen device.
# NOTE(review): presumably Model.get_model returns a torch.nn.Module -- confirm.
model = Model.get_model(NUM_CLASSES, model_path)
model.eval()
model.to(device)

# Find per-class decision thresholds from the model's validation-set predictions.
# Only the images are fed to the model; the loader's second element is discarded.
# NOTE(review): find_threshold is called with predictions only. If it is meant to
# tune thresholds against ground truth, collect the loader's labels and pass them
# too -- confirm against Metrics.find_threshold's signature.
with torch.no_grad():
    # Raw per-batch model outputs, moved to CPU so GPU memory is not held across batches.
    batch_outputs = [model(img.to(device)).cpu() for img, _ in tqdm(validation_data_loader)]
    # Concatenate all batches and map raw outputs to per-class probabilities.
    y_pred = torch.sigmoid(torch.cat(batch_outputs))
    threshold = find_threshold(y_pred)

for i, thr in enumerate(threshold):
    print(i, thr)

# Predict labels for every test image listed in sample_submission.csv and write
# output.csv: a header "Id,Predicted", then one line per image containing the image
# id and the predicted class indices separated by spaces.
test_csv = dataset_path + 'sample_submission.csv'
# Accumulate lines in a list and join once at the end; repeated `str +=` inside the
# loop is quadratic in the number of rows.
output_lines = ['Id,Predicted\n']
# no_grad: pure inference, so no autograd graph needs to be built per image.
with open(test_csv, newline='') as fp, torch.no_grad():
    rows = csv.reader(fp)
    next(rows, None)  # skip the header row (no-op on an empty file)
    for row in tqdm(rows):
        if not row:  # csv.reader yields [] for blank lines; row[0] would raise
            continue
        image_id = row[0]
        img = eval_transform(_read_image(image_id, dataset_path + 'test/'))
        label = Inference.inference(model, img, device=device, threshold=threshold)
        output_lines.append(image_id + ',' + ' '.join(str(cls) for cls in label) + '\n')

output_str = ''.join(output_lines)
with open('output.csv', 'w') as fp:
    fp.write(output_str)

100%|██████████| 292/292 [00:35<00:00,  8.33it/s]
5it [00:00, 49.00it/s]

0 0.324
1 0.057999999999999996
2 0.264
3 0.406
4 0.74
5 0.314
6 0.448
7 0.09999999999999999
8 0.838
9 0.5539999999999999
10 0.8019999999999999
11 0.401
12 0.422
13 0.132
14 0.128
15 0.05
16 0.05199999999999999
17 0.09699999999999999
18 0.12100000000000001
19 0.311
20 0.05499999999999999
21 0.16799999999999998
22 0.311
23 0.628
24 0.323
25 0.433
26 0.185
27 0.41300000000000003


11703it [04:04, 47.90it/s]
