In [1]:
# Standard library
import glob
import json
from pathlib import Path

# Pytorch
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50

# Others
import cv2
import numpy as np
from tqdm.notebook import tqdm
from PIL import Image

# Run inference on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Rebuild the training-time architecture: a ResNet-50 whose default classification
# head is replaced by a 4-output linear head. No pretrained weights are requested
# here; all parameters come from the checkpoint below.
CHECKPOINT_PATH = 'epoch-160_loss_0.00020.pt'

model = resnet50()
model.fc = nn.Linear(model.fc.in_features, 4)  # in_features is 2048 for ResNet-50

# Load onto the CPU first so the checkpoint opens on machines without a GPU,
# then move the model to whichever device was selected above.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints you trust. If this one holds just tensors and plain
# containers, consider passing weights_only=True.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

In [3]:
class SATDataset(Dataset):
    """Inference-only dataset yielding normalized image tensors read from disk.

    Item ``i`` is the image at ``paths[i]``, processed in four steps:
    1. Read with OpenCV.
    2. Converted from BGR to RGB channel order.
    3. Turned into a float tensor by ``transforms.ToTensor``.
    4. Normalized with the ImageNet channel mean/std used by torchvision backbones.

    No labels are returned.
    """

    def __init__(self, paths):
        """
        Args:
            paths: Sequence of image file paths (``str`` or path-like);
                item ``i`` is read from ``paths[i]``.
        """
        self.paths = paths
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index: int):
        """Load, convert and normalize the image at ``paths[index]``.

        Returns:
            The normalized image tensor produced by ``ToTensor`` + ``Normalize``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file (missing, unreadable,
                or not a decodable image).
        """
        img_path = self.paths[index]
        # str() lets callers pass pathlib.Path objects; older OpenCV builds only accept str.
        roi = cv2.imread(str(img_path))
        # cv2.imread signals failure by returning None rather than raising. Without this
        # check the failure would surface as an opaque cv2.error inside cvtColor, and
        # with num_workers > 0 it would come from inside a worker process.
        if roi is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
        roi = self.to_tensor(roi)
        roi = self.normalize(roi)
        return roi

# Test tiles are named by an integer index (e.g. ``17.png``). Sort numerically so
# ``10.png`` follows ``9.png``. Path(...).stem is used instead of splitting on '/'
# because glob returns backslash-separated paths on Windows.
img_paths = sorted(glob.glob('test_dataset_test/*.png'), key=lambda x: int(Path(x).stem))
# Fail here with a clear message rather than later with an empty-concatenate error.
assert img_paths, "No .png files found in test_dataset_test/"

BATCH_SIZE = 2
# NOTE(review): on platforms that spawn DataLoader workers (Windows, macOS), a Dataset
# class defined inside a notebook may fail to pickle. Set NUM_WORKERS = 0 there if
# loading errors out or hangs.
NUM_WORKERS = 4
dataset = SATDataset(img_paths)
# shuffle=False keeps predictions in the same order as img_paths; the export cell
# pairs them back up by position.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [4]:
# Scale factor that maps the model's first two outputs to pixel coordinates.
# NOTE(review): presumably the side length (px) of the reference map the tiles were
# cut from, with outputs normalized to [0, 1]. Confirm against the training code.
COORD_SCALE = 10496

center_batches = []
angle_batches = []

model.eval()
with torch.no_grad():
    for batch in tqdm(dataloader):
        batch = batch.to(device)
        # Call the module rather than .forward() so registered hooks still run;
        # .detach() is unnecessary inside no_grad().
        result = model(batch).cpu().numpy()
        # Columns 0-1: (x, y) center, rescaled to pixels.
        center_batches.append(result[:, :2] * COORD_SCALE)
        # Columns 2-3 are decoded as arctan2(col2, col3). Presumably the model
        # predicts (sin θ, cos θ); verify against training targets.
        degrees = np.rad2deg(np.arctan2(result[:, 2], result[:, 3]))
        # Round *before* wrapping. Rounding after `% 360` turns values in
        # [359.5, 360), e.g. an angle of -0.4°, into 360, which is outside [0, 360).
        angle_batches.append((np.round(degrees) % 360).astype(int))

centers = np.concatenate(center_batches)
angles = np.concatenate(angle_batches)

  0%|          | 0/200 [00:00<?, ?it/s]

In [32]:
def rotate(p, origin=(0, 0), degrees=0):
    """Rotate 2-D point(s) ``p`` about ``origin`` by ``degrees``.

    Args:
        p: A single point ``(x, y)`` or an ``(N, 2)`` array-like of points.
        origin: Pivot point ``(x, y)``; defaults to the coordinate origin.
        degrees: Rotation angle in degrees.

    The matrix used is [[cos, -sin], [sin, cos]], i.e. counter-clockwise in a y-up
    frame. In image coordinates (y pointing down) a positive angle therefore appears
    clockwise on screen.

    Returns:
        A float array of shape ``(2,)`` for a single point, or ``(N, 2)`` for several.
        Singleton dimensions are squeezed.
    """
    theta = np.deg2rad(degrees)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])

    pivot_col = np.atleast_2d(origin).T   # 2 x 1
    points_col = np.atleast_2d(p).T       # 2 x N, one column per point

    # Translate to the pivot, rotate, translate back (column-vector form).
    rotated_col = rotation @ (points_col - pivot_col) + pivot_col
    return np.squeeze(rotated_col.T)

# Build one record per tile: the four corners of a square of side 2 * HALF_SIZE
# centred on the predicted center, rotated by the predicted angle and rounded to
# integer pixels. The key order below is the key order of each output record.
results = []
HALF_SIZE = 512  # NOTE(review): presumably half the tile side length in pixels; confirm
corner_keys = ('left_top', 'right_top', 'left_bottom', 'right_bottom')
corner_offsets = HALF_SIZE * np.array([[-1, -1], [1, -1], [-1, 1], [1, 1]])

for center, angle in zip(centers, angles):
    bbox = np.round(rotate(center + corner_offsets, center, angle)).astype(int)
    record = {key: corner.tolist() for key, corner in zip(corner_keys, bbox)}
    record['angle'] = angle.item()  # numpy scalar -> plain int for JSON
    results.append(record)

In [33]:
# Write one JSON file per test image, named after the image's stem (17.png -> preds/17.json).
PREDS_DIR = Path('preds')
# open() does not create missing directories; make sure the target exists.
PREDS_DIR.mkdir(parents=True, exist_ok=True)

# zip() would silently drop trailing items if predictions and paths ever disagreed.
assert len(results) == len(img_paths), f"{len(results)} predictions for {len(img_paths)} images"
for pred, path in zip(results, img_paths):
    # Path(...).stem is portable; splitting on '/' fails on Windows glob results.
    with open(PREDS_DIR / f'{Path(path).stem}.json', 'w', encoding='utf-8') as f:
        json.dump(pred, f)