In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils import data
import torchvision
from torchvision import transforms
from skimage.io import imread
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score
from tqdm import tqdm
import random

In [14]:
class DataGenerator(data.Dataset):
    """Map-style dataset that loads images listed in a pandas DataFrame.

    Args:
        dataset: DataFrame with an ``img_path`` column (a path passed to
            ``imread``) and a ``label`` column.
        transform: callable applied to the raw array returned by ``imread``.

    Rows are addressed by *position* (``.iloc``), so the DataFrame does not
    need a 0..n-1 RangeIndex (e.g. after filtering or ``train_test_split``).
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        # One sample per DataFrame row.
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``[transformed_image, label]`` for the row at position ``index``."""
        # Positional lookup: label-based ``series[index]`` would pick the wrong
        # row (or raise KeyError) when the index is not a RangeIndex.
        label = self.dataset['label'].iloc[index]

        img_raw = imread(self.dataset['img_path'].iloc[index])
        img = self.transform(img_raw)

        return [img, label]

In [15]:
# Preprocessing applied to every raw image array read by the dataset:
# array -> PIL image, resize so the shorter side is 256 px, take the central
# 224x224 crop, convert to a float tensor, then normalise each channel with the
# standard ImageNet mean/std used by torchvision's ResNet models.
# NOTE(review): this must match the preprocessing used when the checkpoint was
# trained — confirm against the training notebook.
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

## Load Saved Model

In [16]:
# ResNet-34 architecture with randomly initialised weights; the trained
# parameters are restored from a saved checkpoint later in the notebook.
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favour of
# `weights=None`; kept as-is for compatibility with older torchvision releases.
backbone_factory = torchvision.models.resnet34
model = backbone_factory(pretrained=False)

In [17]:
# Swap the default classification layer for a single-output binary head.
# The Sigmoid maps the score to [0, 1]. Sequential names the children "0" and
# "1", so the checkpoint keys are `fc.0.weight` / `fc.0.bias`.
head_layers = [
    torch.nn.Linear(512, 1),
    torch.nn.Sigmoid(),
]
model.fc = torch.nn.Sequential(*head_layers)

In [18]:
# Restore the trained weights. `map_location="cpu"` lets a checkpoint that was
# saved from a GPU process load on any machine (without it, torch.load tries to
# recreate CUDA tensors and fails when CUDA is unavailable); the model is moved
# to the inference device later in the notebook.
model.load_state_dict(torch.load("checkpoint_final.pth", map_location="cpu"))

<All keys matched successfully>

## Predictions

### Load Test Data

In [19]:
# Directory holding the test images; must end with a path separator because
# paths are built by plain string concatenation below.
test_path = "../data/test/"

test_values = pd.read_csv("../data/sample_submission.csv")
# Build "<test_path><id>.tif" for every row. Select `id` as a Series (single
# brackets): `test_values[['id']]` is a one-column DataFrame, and assigning a
# DataFrame to a single column only works through pandas special-casing.
# `.astype(str)` guards against ids being parsed as numbers.
test_values['img_path'] = test_path + test_values['id'].astype(str) + '.tif'

In [20]:
# Dataset over the test manifest; each item is [preprocessed image tensor, label].
test_data = DataGenerator(dataset=test_values, transform=transform)

In [21]:
# Batches are yielded in row order (shuffle=False is the DataLoader default),
# so concatenated predictions line up with the rows of `test_values`.
test_loader = data.DataLoader(
    test_data,
    batch_size=64,
    shuffle=False,
)

In [26]:
# Number of mini-batches the loader yields: ceil(len(test_data) / batch_size),
# since drop_last is left at its default (False).
num_test_batches = len(test_loader)
num_test_batches

898

In [28]:
## Run inference on the GPU when available, otherwise fall back to the CPU so
## the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # inference mode: BatchNorm layers use their running statistics

DECISION_THRESHOLD = 0.5  # Sigmoid output >= threshold -> predicted class 1
predictions = []  # one (n, 1) float array of 0.0/1.0 labels per batch
total_test_batches = len(test_loader)

# A single no_grad context for the whole loop: no autograd graph is recorded.
with torch.no_grad():
    # The loader's labels are not needed to produce predictions, so they are
    # neither moved to the device nor reshaped.
    for x, _ in tqdm(test_loader, total=total_test_batches):
        out = model(x.to(device))
        predictions.append((out >= DECISION_THRESHOLD).float().cpu().numpy())

100%|██████████| 898/898 [05:39<00:00,  2.65it/s]
