In [50]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
from tqdm import tqdm
from efficientnet_pytorch import EfficientNet
device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

In [51]:
SIZE = (224, 224)

transforms_test = transforms.Compose([
    transforms.Resize(SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [52]:
true_preds = {}
for filename in os.listdir("test_src/fire"):
    true_preds[filename] = 1
for filename in os.listdir("test_src/not_fire"):
    true_preds[filename] = 0

In [65]:
class TestDataset(Dataset):
    
    def __init__(self, images, path, transform):
        self.images = images
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(path + image_path)
        image = self.transform(image)
        return image, image_path

path = 'test_imgs/'
test_imgs = os.listdir(path) 
test_data = TestDataset(test_imgs, path, transforms_test)

test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    drop_last=False,
)

In [66]:
def init_model(model_name):
    
    path = 'saved_models/' + model_name
    checkpoint = torch.load(path)
    if model_name.startswith('resnet152'):
        model = models.resnet152()
    if model_name.startswith('resnet101'):
        model = models.resnet101()
    if model_name.startswith('resnet50'):
        model = models.resnet50()
    if model_name.startswith('resnext50'):
        model = models.resnext50_32x4d()
    if model_name.startswith('resnext101'):
        model = models.resnext101_32x8d()
    if model_name.startswith('wide_resnet50'):
        model = models.wide_resnet50_2()
    if model_name.startswith('wide_resnet101'):
        model = models.models.wide_resnet101_2()
    
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    return model

In [67]:
def inference(model):
    
    predictions = {}
    model.eval()
    curr_correct = 0
    for data in test_loader:
        inputs, image_path = data
        inputs = inputs.to(device)
        output = model(inputs)
        _, preds = torch.max(output, 1)
        predictions[image_path[0]] = 0 if preds.cpu().detach().numpy() else 1
        
    return predictions

In [69]:
def save_submission(predictions, sub_name):
    df = pd.DataFrame(predictions.values(), index=predictions)
    df.to_csv('submissions/' + sub_name)
    return df

In [70]:
os.listdir('saved_models')

['resnet101-224-97.25.pth',
 'resnext50_32-224-96.75.pth',
 'resnet50-224-97.5.pth',
 'resnext101_32-224-96.25.pth',
 'wide_resnet50-224-96.25.pth',
 'resnet101-224-96.5.pth',
 'resnext50_32-224-96.5.pth',
 'resnet101-224-97.0.pth',
 'resnet152-224-96.75.pth',
 'resnext50_32-224-96.25.pth',
 'resnet101-224-96.75.pth',
 'wide_resnet50-224-96.5.pth',
 'resnext50_32-224-97.25.pth',
 'wide_resnet101-224-96.5.pth',
 'wide_resnet50-224-97.0.pth',
 'resnet152-224-95.25.pth',
 'resnet152-224-96.5.pth',
 'resnet50-224-96.75.pth',
 'resnet50-224-96.5.pth',
 'model_efficient_b7_9475.pth']

In [71]:
model_name = 'resnet50-224-97.5.pth'
model = init_model(model_name)
prediction = inference(model)
save_submission(prediction, model_name + ".csv")

Unnamed: 0,0
25.jpg,1
762.jpg,1
313.jpg,1
262.jpg,1
203.jpg,1
...,...
449.jpg,0
835.jpg,0
367.jpg,1
343.jpg,0
