# Example script for inference

In [1]:
import os
import multiprocessing
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import models, transforms

# Display the number of CPUs reported for this machine (shown as the cell output).
multiprocessing.cpu_count()

12

In [2]:
class EsthDataset(torch.utils.data.Dataset):
    """Dataset that loads images from a list of file paths for inference.

    Each item is a pair ``(image, index)``. ``index`` is the position of the
    image inside ``paths`` so the caller can map a prediction back to its file.

    Args:
        paths: Indexable collection of image file paths.
        transforms: Optional callable applied to the loaded PIL image
            (e.g. a ``torchvision.transforms.Compose``).
    """

    def __init__(self, paths, transforms=None):
        self.paths = paths
        self.transform = transforms

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load image ``index`` as 3-channel RGB, resized to 500x500 if needed.

        Returns:
            tuple: ``(image, index_tensor)`` where ``image`` is the (optionally
            transformed) image and ``index_tensor`` is ``index`` stored as a
            0-d float32 tensor.
        """
        # The context manager releases the file handle; convert() loads the
        # pixel data and returns an independent image, so it stays usable.
        # Forcing RGB keeps RGBA / grayscale / palette files compatible with
        # the 3-channel normalisation and model used downstream.
        with Image.open(self.paths[index]) as raw_image:
            image = raw_image.convert("RGB")

        if image.size[0] != 500 or image.size[1] != 500:
            print(f"{self.paths[index]} is the wrong size : {image.size}. It has been resized to (500,500) but tha accuracy of the score cannot be guaranteed")
            image = image.resize((500,500))

        if self.transform is not None:
            image = self.transform(image)

        # The index is returned as a tensor so that the default collate
        # function can batch it; float32 represents integers exactly up to 2**24.
        return image, torch.tensor(index, dtype=torch.float32)

In [3]:
# Configuration: compute device, model weights and data locations.

# Select the compute device: prefer CUDA, then Apple MPS, otherwise CPU.
use_cuda = torch.cuda.is_available()
# `torch.backends.mps` does not exist on older PyTorch builds, hence the guard.
_mps_backend = getattr(torch.backends, "mps", None)
use_mps = _mps_backend is not None and _mps_backend.is_available()
if use_cuda:
    device = "cuda"
elif use_mps:
    device = "mps"
else:
    device = "cpu"


# Path to the weights of the saved model we want to use
weights_path = './saved_models/1000epochs_survey/Chckpt_ResNet50_-11339.0159.pt'


# Paths to the input image folder and to the output CSV file
# (alternative input folder: '../data/BIG_FILES/ggstreet/png/')
in_path = '../data/BIG_FILES/ggstreet/png_good/'
# (alternative output file: '../results/inference_ggstreet_facade.csv')
out_path = '../results/inference_ggstreet.csv'

In [4]:
def get_model_for_eval(path_to_weights):
    """Build the ResNet-50 regressor and load saved weights for inference.

    The stock ResNet-50 classification head is replaced by
    ``Dropout(0.5) -> Linear(2048, 1)`` so the network outputs a single score.

    Args:
        path_to_weights: Path to a file containing the model ``state_dict``.

    Returns:
        The model in evaluation mode, with its parameters on the CPU. The
        caller is responsible for moving it to the desired device.
    """
    model = models.resnet50()

    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(2048, 1),
    )

    # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
    # machine that does not have that device.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path_to_weights, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    return model

def predict_batch(paths, path_to_weights, output_path, batch_size=16):
    """Score every image in ``paths`` and write the results to a CSV file.

    Args:
        paths: Indexable collection of image file paths.
        path_to_weights: Path to the saved model ``state_dict``.
        output_path: Destination CSV file with columns ``image_name`` and
            ``predicted_score``. Missing parent directories are created.
        batch_size: Number of images per forward pass (default 16).

    Returns:
        None. The predictions are written to ``output_path``.
    """
    # Create the output directory up front so that a missing folder fails
    # (or is fixed) before the long inference run rather than after it.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Per-channel mean / std passed to Normalize
    # (presumably computed on the training set -- TODO confirm).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.392, 0.298, 0.203], [0.192, 0.167, 0.140]),
    ])

    model = get_model_for_eval(path_to_weights)
    model.to(device)
    model.eval()

    dataset = EsthDataset(paths, transforms=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Collect per-batch results and concatenate once at the end; calling
    # torch.cat inside the loop would re-copy all previous results each time.
    index_chunks = []
    score_chunks = []

    with torch.no_grad():
        for images, indices in tqdm(loader):
            outputs = model(images.to(device))
            score_chunks.append(outputs.detach().cpu().reshape(-1))
            # Indices are only bookkeeping, so they never need to leave the CPU.
            index_chunks.append(indices.reshape(-1))

    if score_chunks:
        y_pred = torch.cat(score_chunks).numpy()
        im_indices = torch.cat(index_chunks).numpy().astype(int)
    else:
        # No images: still produce a well-formed (header-only) CSV.
        y_pred = np.array([], dtype=np.float32)
        im_indices = np.array([], dtype=int)

    df = pd.DataFrame({"image_name":[paths[i] for i in im_indices],"predicted_score":y_pred})
    df.to_csv(output_path, sep=",", index=False)

In [5]:
# Recursively collect every image file under `in_path`.
# The extension check is case-insensitive so files such as `photo.PNG` or
# `photo.JPG` are not silently skipped, and the list is sorted because
# os.walk gives no ordering guarantee -- sorting makes the CSV row order
# reproducible between runs.
files = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(in_path)
    for filename in filenames
    if os.path.splitext(filename)[1].lower() in ('.png', '.jpg', '.jpeg')
)

In [6]:
# Run inference on every collected image and write the scores to `out_path`.
predict_batch(
    paths=files,
    path_to_weights=weights_path,
    output_path=out_path,
)

100%|███████████████████████████████████| 22198/22198 [2:30:29<00:00,  2.46it/s]
