In [1]:
import sys
# Directory that holds the model checkpoint (resnet101.pth is loaded from it below).
# Alternative location, kept for reference (presumably the Kaggle dataset mount - confirm):
# package_dir = "../input/aptos-models"
package_dir = './'
# Put that directory first on the module search path.
sys.path.insert(0, package_dir)
import numpy as np
import pandas as pd
import torchvision
import torch.nn as nn
from tqdm import tqdm
from PIL import Image, ImageFile
from torch.utils.data import Dataset
import torch
from torchvision import transforms
import os

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing at the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Let PIL decode image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
class RetinopathyDatasetTest(Dataset):
    """Test-time dataset that yields one transformed image per CSV row.

    Args:
        csv_file: Path to a CSV file with an ``id_code`` column; each value is
            the file stem of a ``.png`` image.
        transform: Callable applied to the loaded PIL image; its result is
            returned under the ``'image'`` key.
        image_dir: Directory containing the ``<id_code>.png`` files. Defaults
            to the location the notebook was originally written against.
    """

    def __init__(self, csv_file, transform,
                 image_dir='/storage/aptos2019-blindness-detection/test_images'):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        """Return the number of rows (images) listed in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, convert and transform the image for row ``idx``.

        Returns:
            dict: ``{'image': transformed_image}``.
        """
        img_name = os.path.join(self.image_dir, self.data.loc[idx, 'id_code'] + '.png')
        # Image.open is lazy; the context manager guarantees the file handle is
        # released, and convert('RGB') forces exactly three channels so that a
        # grayscale / palette / RGBA PNG cannot break a 3-channel Normalize.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)
        return {'image': image}

In [3]:
# Build the bare architecture only: every weight is replaced by the checkpoint
# below, so downloading ImageNet weights would be wasted work (and would fail
# on a machine without internet access).
model = torchvision.models.resnet101(pretrained=False)

# Swap the classification head for a single-output layer, sized from the
# backbone rather than a hard-coded width.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 1)

# map_location lets a checkpoint saved on a GPU be loaded onto whatever
# `device` is in use.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
checkpoint_path = os.path.join(package_dir, 'resnet101.pth')
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)

In [4]:
# Inference only: turn off gradient tracking on every weight of the network.
for weight in model.parameters():
    weight.requires_grad_(False)

# Put the module in evaluation mode; as the last expression of the cell this
# also displays the module's repr.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
# Per-channel statistics passed to Normalize.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Pre-processing applied to every test image. The random horizontal flip makes
# each pass over the dataset see a (possibly) different view of an image.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
test_transform = transforms.Compose(preprocessing_steps)

# One dataset entry per row of the sample submission file.
test_dataset = RetinopathyDatasetTest(
    csv_file='/storage/aptos2019-blindness-detection/sample_submission.csv',
    transform=test_transform,
)

#### TTA for the lazy, like me: run inference repeatedly with random horizontal flips, then average the predictions

In [None]:
batch_size = 1
n_tta = 10  # number of passes over the test set (one random flip draw per image per pass)
preds = []  # one (len(test_dataset), 1) array of raw model outputs per pass

# The loader is identical for every pass, so build it once. shuffle=False keeps
# row order aligned with the CSV.
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

for tta_round in range(n_tta):
    test_preds = np.zeros((len(test_dataset), 1))
    # No autograd bookkeeping is needed for inference.
    with torch.no_grad():
        for batch_idx, x_batch in enumerate(tqdm(test_data_loader)):
            images = x_batch["image"].to(device)
            pred = model(images)
            # Write this batch's outputs into its slot; using the actual batch
            # length keeps a short final batch correct when batch_size > 1.
            start = batch_idx * batch_size
            test_preds[start:start + len(images)] = pred.cpu().numpy().reshape(-1, 1)

    preds.append(test_preds)

100%|██████████| 1928/1928 [00:44<00:00, 43.32it/s]
100%|██████████| 1928/1928 [00:43<00:00, 44.00it/s]
100%|██████████| 1928/1928 [00:43<00:00, 43.91it/s]
100%|██████████| 1928/1928 [00:44<00:00, 43.49it/s]
100%|██████████| 1928/1928 [00:43<00:00, 43.84it/s]
100%|██████████| 1928/1928 [00:44<00:00, 43.76it/s]
 96%|█████████▌| 1853/1928 [00:42<00:01, 44.69it/s]

In [None]:
# Average the per-pass predictions. Divide by the number of passes actually
# collected rather than a fixed 10, so an interrupted or shortened TTA run
# still yields correctly scaled outputs.
if len(preds) == 0:
    raise ValueError("preds is empty - run the TTA inference cell first")
test_preds = np.mean(preds, axis=0)

In [None]:
# Cut points that turn the continuous model output into the classes 0..4:
#   x < 0.5 -> 0, [0.5, 1.5) -> 1, [1.5, 2.5) -> 2, [2.5, 3.5) -> 3, x >= 3.5 -> 4
coef = [0.5, 1.5, 2.5, 3.5]

# np.digitize applies exactly those half-open bins in one vectorized call.
# The averaged predictions in test_preds are left untouched, so this cell can
# be re-run (e.g. with different cut points) without repeating inference.
diagnosis = np.digitize(np.asarray(test_preds).ravel(), coef).astype(int)

sample = pd.read_csv("/storage/aptos2019-blindness-detection/sample_submission.csv")
# Guard against predictions and submission rows getting out of sync.
assert len(diagnosis) == len(sample), "number of predictions does not match submission rows"
sample["diagnosis"] = diagnosis  # 1-D integer column
sample.to_csv("submission.csv", index=False)

In [None]:
# Display the submission frame as the cell output for a final visual check.
sample