In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torchvision
from torch.autograd import Variable

In [2]:
try:
    import baseline_model
    import self_defined_dataset
except:
    !jupyter nbconvert --to script baseline_model.ipynb
    !jupyter nbconvert --to script self_defined_dataset.ipynb
finally:
    import baseline_model
    import self_defined_dataset

In [3]:
# Single GPU/CPU switch; the model-loading and inference cells branch on it.
CUDA_USAGE = torch.cuda.is_available()

In [4]:
# Evaluation-time preprocessing: resize, convert to tensor, normalize per channel.
# 3 is PIL's BICUBIC constant (kept as an int for compatibility with older torchvision).
RESIZE_INTERPOLATION = 3
# The ImageNet per-channel mean/std that torchvision documents for its pretrained models.
# NOTE(review): presumably baseline_model uses such a backbone; confirm.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

transform_val_list = []
transform_val_list.append(
    torchvision.transforms.Resize((256, 256), interpolation=RESIZE_INTERPOLATION)
)
transform_val_list.append(torchvision.transforms.ToTensor())
transform_val_list.append(torchvision.transforms.Normalize(NORM_MEAN, NORM_STD))

data_transform = torchvision.transforms.Compose(transform_val_list)

In [5]:
# Location of the APTOS 2019 test split. Set the APTOS_DATA_DIR environment variable
# to point at a local copy; the fallback is the original machine-specific path.
DATA_ROOT = os.environ.get(
    "APTOS_DATA_DIR", "/home/extension/kaggle/APTOS_2019_Blindness_Detection"
)
img_root_path = os.path.join(DATA_ROOT, "test_images")  # directory of test images
img_csv = os.path.join(DATA_ROOT, "test.csv")           # test-split CSV

In [6]:
# Dataset over the images in img_root_path listed by img_csv, preprocessed with
# data_transform. shuffle=False keeps the dataset's own order in the submission.
# NOTE(review): despite the name, this wraps the *test* split (img_csv is test.csv).
# The meaning of the third positional argument (True) is defined in
# self_defined_dataset.Blindness; TODO document it there.
train_dataset = self_defined_dataset.Blindness(
    img_root_path,
    img_csv,
    True,
    data_transform,
    eval_flag=True,
)
dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)

id_code    0005cfc8afb6
Name: 0, dtype: object


In [11]:
# Build the network, load the trained weights, switch to inference mode, and save
# the whole pickled module for the submission.
# NOTE(review): 5 is presumably the number of diagnosis classes; confirm in baseline_model.
CHECKPOINT_PATH = "/home/hdd/hdD_Git/kaggle/image_data/classification/aptos_2019_blindness_detection/models/49.pth"

model = baseline_model.Blindness(5)
if CUDA_USAGE:
    model = model.cuda()

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
# load_state_dict then copies the tensors onto whichever device the model is on.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))

# eval() must run on every device. Previously it was only called on the CUDA branch,
# so CPU inference ran in training mode (affects dropout / batch-norm layers, if any).
model.eval()

os.makedirs("./models", exist_ok=True)  # torch.save does not create missing directories
torch.save(model, "./models/submission.pth")

In [8]:
# Accumulators for the submission: image ids and their predicted class indices.
ID_CODE, DIAGNOSIS = [], []

In [9]:
# Run the model over every batch and collect (id_code, predicted class) pairs.
# The accumulators are reset here so re-running this cell does not append duplicates.
ID_CODE = []
DIAGNOSIS = []

with torch.no_grad():  # inference only: no autograd graph is needed
    for input_data, data_name in dataloader:
        if CUDA_USAGE:
            # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4;
            # a plain tensor is all the model needs.
            input_data = input_data.cuda()

        outputs = model(input_data)
        # Predicted class = index of the largest output along dim 1.
        indices = torch.max(outputs, 1).indices

        ID_CODE.extend(data_name)
        DIAGNOSIS.extend(indices.cpu().tolist())

In [10]:
# Assemble the Kaggle submission: one row per image, columns id_code and diagnosis,
# written without the pandas index.
submission_columns = dict(
    id_code=np.array(ID_CODE),
    diagnosis=np.array(DIAGNOSIS),
)
submission = pd.DataFrame(submission_columns)

submission.to_csv("submission.csv", index=False)