In [1]:
# Pytorch
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.models import resnet50

# Transformers
from transformers import AdamW, get_linear_schedule_with_warmup

# Others
import json
import glob
import cv2
import numpy as np
import random
from tqdm.notebook import tqdm
from PIL import Image
import imgaug.augmenters as iaa

# Seed every random number generator used in this notebook so runs are repeatable
RANDOM_SEED = 42
random.seed(RANDOM_SEED)                 # Python's built-in `random` module
np.random.seed(RANDOM_SEED)              # NumPy's global generator
torch.manual_seed(RANDOM_SEED)           # PyTorch's default generator
torch.cuda.manual_seed_all(RANDOM_SEED)  # PyTorch generators on every CUDA device

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def extract_region(img, center, size, angle):
    """Cut a rotated square patch of side ``size`` out of ``img``.

    A square window big enough to contain the patch at any rotation is sliced
    around ``center``, rotated by ``angle`` about its middle, and finally
    center-cropped to ``size`` x ``size``.

    Args:
        img: Image array indexed as ``img[y, x]`` (trailing axes are kept).
        center: ``(x, y)`` integer pixel coordinates of the patch center.
        size: Side length of the returned patch, in pixels.
        angle: Rotation angle in degrees, forwarded to ``cv2.getRotationMatrix2D``.

    Returns:
        The rotated patch, ``size`` pixels high and wide.

    Raises:
        AssertionError: If the window around ``center`` would extend past any
            of the four image borders.
    """
    # Half the diagonal of the output square: the window must be this large so
    # that no out-of-window pixels end up inside the final crop after rotation.
    radius = int(np.ceil(np.sqrt(size ** 2 * 2) / 2))
    assert min(center) >= radius, 'center is too close to the border'
    cx, cy = center

    # The top/left check alone is not enough: NumPy silently truncates a slice
    # that runs past the bottom/right edge, which would yield a smaller window
    # and therefore a patch of the wrong size.
    img_h, img_w = img.shape[:2]
    assert cx + radius <= img_w and cy + radius <= img_h, 'center is too close to the border'

    # Extract region from image around the center
    roi = img[cy-radius:cy+radius, cx-radius:cx+radius]

    # Rotate this region about its middle
    h, w = roi.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
    roi = cv2.warpAffine(roi, M, (w, h))

    # Center crop roi down to the requested size
    start_y = (h - size) // 2
    start_x = (w - size) // 2
    roi = roi[start_y:start_y+size, start_x:start_x+size]

    return roi

In [2]:
# Load the full satellite map from disk
sat_map = cv2.imread('original.tiff')
if sat_map is None:
    # cv2.imread reports a missing or undecodable file by returning None rather
    # than raising, so fail here with a clear message instead of inside cvtColor.
    raise FileNotFoundError("could not read image 'original.tiff'")
# OpenCV decodes to BGR channel order; convert to RGB
sat_map = cv2.cvtColor(sat_map, cv2.COLOR_BGR2RGB)

In [3]:
class SATDatasetFirst(Dataset):
    """Dataset of image crops stored on disk together with JSON annotations.

    Images are read from ``train/img/*.png`` and annotations from
    ``train/json/*.json``; the two sorted file lists are paired by position.
    Every annotation provides ``left_top``, ``right_bottom`` and ``angle``.

    Each item is ``(roi, label)``: ``roi`` is the normalized image tensor and
    ``label`` is a float32 tensor holding the two center coordinates divided
    by ``MAP_SIZE`` followed by ``sin(angle)`` and ``cos(angle)`` (``angle`` is
    interpreted in degrees).
    """

    # Divisor applied to the center coordinates.
    # NOTE(review): presumably the side length of the source map in pixels - confirm.
    MAP_SIZE = 10496

    def __init__(self):
        """Index the image files and load all annotations into memory.

        Raises:
            ValueError: If the number of images differs from the number of
                annotation files, since positional pairing would then attach
                labels to the wrong images.
        """
        self.img_paths = sorted(glob.glob('train/img/*.png'))
        json_paths = sorted(glob.glob('train/json/*.json'))
        if len(self.img_paths) != len(json_paths):
            raise ValueError(
                f'found {len(self.img_paths)} images but {len(json_paths)} annotation files; '
                'they are paired by sorted position and must match one to one'
            )

        self.labels = []
        for path in json_paths:
            with open(path, 'r') as f:
                data = json.load(f)
            left_top = np.array(data['left_top'])
            right_bottom = np.array(data['right_bottom'])
            # Midpoint of the annotated box
            center = left_top + ((right_bottom - left_top) / 2)
            angle = data['angle']
            self.labels.append([center, angle])
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        """Return the number of image files found."""
        return len(self.img_paths)

    def __getitem__(self, index: int):
        """Load image ``index`` and build its 4-value regression target."""
        img_path = self.img_paths[index]
        roi = cv2.imread(img_path)
        roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
        roi = self.to_tensor(roi)
        roi = self.normalize(roi)

        center, angle = self.labels[index]
        center = torch.tensor(center / self.MAP_SIZE)
        # Encode the angle as (sin, cos) so the target is continuous across 0/360 degrees
        angle = torch.tensor([np.sin(np.deg2rad(angle)), np.cos(np.deg2rad(angle))])
        label = torch.cat([center, angle]).to(torch.float32)

        return roi, label

class SATDatasetSecond(Dataset):
    """Dataset of crops sampled on the fly from the satellite map.

    Every access draws a random center and a random whole-degree angle, cuts
    the matching rotated ``CROP_SIZE`` patch out of ``sat_map`` with
    ``extract_region``, applies the cloud augmentation and returns
    ``(roi, label)``. ``label`` is a float32 tensor holding the center
    coordinates divided by ``MAP_SIZE`` followed by ``sin(angle)`` and
    ``cos(angle)``. The ``index`` argument does not influence the sample.
    """

    # Divisor applied to the center coordinates.
    # NOTE(review): presumably the side length of sat_map in pixels - confirm.
    MAP_SIZE = 10496
    # Side length of the extracted patch, in pixels
    CROP_SIZE = 1024
    # Distance kept from the map borders when sampling centers; equals
    # ceil(sqrt(2 * CROP_SIZE**2) / 2), the window radius extract_region needs.
    MARGIN = 725
    # Nominal number of samples reported by __len__
    NUM_SAMPLES = 3000

    def __init__(self, sat_map):
        """Store the map and build the augmentation / preprocessing transforms."""
        self.sat_map = sat_map
        self.aug = iaa.Clouds()
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        """Return the fixed nominal dataset size."""
        return self.NUM_SAMPLES

    def __getitem__(self, index: int):
        """Sample a random rotated crop and build its 4-value regression target."""
        # Centers are drawn from [MARGIN, MAP_SIZE - MARGIN), i.e. [725, 9771)
        low, high = self.MARGIN, self.MAP_SIZE - self.MARGIN
        center = np.array([np.random.randint(low, high), np.random.randint(low, high)])
        # randint's upper bound is exclusive: 360 is required to cover 0..359 degrees
        angle = np.random.randint(0, 360)

        roi = extract_region(self.sat_map, center, self.CROP_SIZE, angle)
        roi = self.aug(image=roi)
        roi = self.to_tensor(roi)
        roi = self.normalize(roi)

        center = torch.tensor(center / self.MAP_SIZE)
        # Encode the angle as (sin, cos) so the target is continuous across 0/360 degrees
        angle = torch.tensor([np.sin(np.deg2rad(angle)), np.cos(np.deg2rad(angle))])
        label = torch.cat([center, angle]).to(torch.float32)

        return roi, label

In [4]:
# Optimisation hyper-parameters
LEARNING_RATE = 1E-4
EPOCHS = 220

# Data-loading settings
BATCH_SIZE = 8
NUM_WORKERS = 20

# Training data: crops stored on disk plus crops sampled from the satellite map
fs_ds = SATDatasetFirst()
sc_ds = SATDatasetSecond(sat_map)
dataset = torch.utils.data.ConcatDataset([fs_ds, sc_ds])

# Shuffled mini-batches loaded by NUM_WORKERS worker processes
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [5]:
# ResNet-50 whose classification head is replaced by a 4-output linear layer
# (the datasets' labels have 4 values: two center coordinates, sin and cos of the angle)
model = resnet50()
model.fc = nn.Linear(2048, 4)

# Restore the model weights from the checkpoint (tensors are loaded onto the CPU first)
checkpoint = torch.load('/home/vdd/MIPT/v2/checkpoints/epoch-175_loss_0.00019.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
# Honour the `device` fallback chosen at the top of the notebook instead of
# unconditionally calling .cuda(), which fails on machines without a GPU.
model = model.to(device)

# Loss and optimizer
last_checkpoint_epoch = -1  # -1 means "start from epoch 0"; overwritten below when resuming
criterion = torch.nn.MSELoss()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
total_steps = len(dataloader) * EPOCHS
# Linear warm-up over the first epoch's worth of steps, then linear decay
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps = len(dataloader),
    num_training_steps = total_steps
)

# Resume optimizer / scheduler state and the epoch counter from the checkpoint
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
last_checkpoint_epoch = checkpoint['epoch']



In [6]:
def train_epoch(model, dataloader, loss_fn, optimizer, scheduler, device, writer=None, epoch_index=0):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Network to optimise; it is switched to training mode.
        dataloader: Sized iterable of batches where ``batch[0]`` holds the
            inputs and ``batch[1]`` the targets.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the inputs and targets are moved to.
        writer: Optional object with ``add_scalar(tag, value, step)``; receives
            the loss of every batch.
        epoch_index: Index of the current epoch, used to compute the global step.

    Returns:
        Mean of the per-batch loss values.
    """
    model.train()

    num_batches = len(dataloader)
    first_step = epoch_index * num_batches  # global step index of this epoch's first batch
    batch_losses = []

    progress = tqdm(dataloader, total=num_batches, desc="Training on batches")
    for batch_index, batch in enumerate(progress):
        inputs = batch[0].to(device)
        targets = batch[1].to(device)

        # Forward pass
        loss = loss_fn(model(inputs), targets)
        loss_value = loss.item()
        batch_losses.append(loss_value)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        # Cap the gradient norm at 1.0 to guard against gradient explosion
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()   # update the weights
        scheduler.step()   # update the learning rate

        # Log the per-batch loss to tensorboard
        if writer is not None:
            writer.add_scalar('Loss/train (per batch)', loss_value, first_step + batch_index)

    return np.mean(batch_losses)

In [None]:
TENSORBOARD_DIR = '/home/vdd/MIPT/v2/tensorboard'
CHECKPOINTS_DIR = '/home/vdd/MIPT/v2/checkpoints'
! mkdir -p {CHECKPOINTS_DIR}

# Tensorboard
writer = SummaryWriter(log_dir=TENSORBOARD_DIR)

# Loop through each epoch.
for epoch in tqdm(range(last_checkpoint_epoch + 1, EPOCHS), desc="Epoch"):
    print(f'Running on epoch: {epoch}')

    # Perform one full pass over the training and validation sets
    train_loss = train_epoch(model, dataloader, criterion, optimizer, scheduler, device, writer, epoch)

    # Populate tensorboard
    writer.add_scalar('Loss/train (per epoch)', train_loss, epoch)

    # Print loss and accuracy values to see how training evolves.
    print(f'train_loss: {train_loss:.5f}\n')

    # Save checkpoint
    if epoch % 5 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }, f"{CHECKPOINTS_DIR}/epoch-{epoch}_loss_{train_loss:.5f}.pt")