In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from math import ceil

# Per-image preprocessing: PIL image -> float tensor scaled to [0, 1] (ToTensor),
# then per-channel normalisation with the statistics below.
# These are the standard ImageNet channel mean/std.
# NOTE(review): confirm they match the preprocessing EarthVisionModel was trained with.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

from model import EarthVisionModel

def unpickle(file):
    """Load and return the single object pickled in ``file``.

    ``encoding='bytes'`` makes 8-bit strings from Python 2 pickles come back as
    ``bytes`` objects instead of being decoded (so, e.g., keys of a Python 2
    pickled dict are ``bytes``).

    Security: unpickling can execute arbitrary code -- only call this on files
    you trust.
    """
    import pickle

    with open(file, mode="rb") as stream:
        return pickle.load(stream, encoding="bytes")

# Label metadata pickled in ./meta; the class indices in `buildings` refer to its entries.
# Not read by any later cell shown here.
# NOTE(review): presumably a CIFAR-style meta file (label-name lists) -- verify.
label_meta = unpickle(file="./meta")

# Hand-picked class indices (into the label names of `label_meta`) that count as
# constructed objects. Kept in ascending order so entries are easy to audit.
buildings = {
    5, 6, 9, 12, 13, 16, 17, 18, 19, 23,
    31, 33, 37, 38, 39, 48, 49, 55, 56, 58,
    59, 60, 63, 68, 69, 71, 72, 73, 75, 76,
    81, 84, 85, 90, 95, 96, 97,
}


In [2]:
# Build the network, switch it to evaluation mode, and load the trained weights.
# map_location="cpu": inference in this notebook runs on CPU tensors only, and
# without it a checkpoint saved from a CUDA device fails to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust
# (on torch >= 1.13, weights_only=True restricts loading to tensors/primitive containers).
model = EarthVisionModel()
model.eval()                                            # inference mode (changes dropout/batch-norm behaviour, if the model has them)
model.load_state_dict(torch.load('state_dict.pt', map_location="cpu"))

<All keys matched successfully>

In [3]:
# Render a FLAG_ROWS x FLAG_COLS "flag": pixel (row, col) shows test image
# number row * FLAG_COLS + col, and is white when the model assigns that image
# to one of the `buildings` classes, black otherwise.
FLAG_ROWS, FLAG_COLS = 32, 256          # flag height x width (one test image per pixel)
BUILDING_VALUE, OTHER_VALUE = 100, 1    # pixel values: building -> white, anything else -> black

flag = np.zeros((FLAG_ROWS, FLAG_COLS))

with torch.no_grad():                                       # inference only: skip autograd bookkeeping
    for row in range(FLAG_ROWS):
        for col in range(FLAG_COLS):
            index = row * FLAG_COLS + col                   # test images are numbered row-major
            # `with` closes the file handle. convert("RGB") guarantees the 3 channels that
            # the 3-element Normalize in `preprocess` needs (grayscale/CMYK JPEGs would crash it).
            with Image.open(f"test_X/{index}.jpg") as image:
                img = preprocess(image.convert("RGB"))
            output = model(img.unsqueeze(0))                # add a batch dimension of size 1
            # softmax is monotonic, so it would not change the argmax -- take it directly
            predicted = output.argmax(dim=1).item()
            flag[row, col] = BUILDING_VALUE if predicted in buildings else OTHER_VALUE

# Pin the colour range: with autoscaling, a flag whose pixels are all equal
# (every image a building, or none) would be drawn entirely black.
fig, ax = plt.subplots()
ax.imshow(flag, cmap="gray", vmin=OTHER_VALUE, vmax=BUILDING_VALUE)
ax.set(title="Test images classified as buildings (white)", xlabel="column", ylabel="row")
plt.show()