In [None]:
%%capture
# Install torchmetrics (used for the evaluation metrics) and the kaggle CLI
# (used to download the dataset); %%capture hides pip's output.
# TODO(review): pin versions (e.g. torchmetrics==X.Y.Z) so Run-All is reproducible.
!pip install torchmetrics kaggle

In [None]:
# Colab-only: opens a browser file picker. The next cell copies `kaggle.json`
# from the working directory, so upload the Kaggle API credentials file here.
# Do not commit or share that file.
from google.colab import files
files.upload()

In [None]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 /root/.kaggle/kaggle.json

In [None]:
# Download the dataset archive with the kaggle CLI (needs the credentials set up
# in the previous cell); it is written to the current directory as romania-streetview.zip.
!kaggle datasets download -d andreisaceleanu/romania-streetview

Downloading romania-streetview.zip to /content
100% 3.59G/3.59G [02:58<00:00, 22.8MB/s]
100% 3.59G/3.59G [02:58<00:00, 21.6MB/s]


In [None]:
import zipfile

# Archive fetched by the kaggle CLI cell above; everything in it is unpacked under /content/.
file_path = '/content/romania-streetview.zip'

with zipfile.ZipFile(file_path, mode='r') as archive:
    archive.extractall(path='/content/')

In [None]:
import os
import cv2
import json
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim

from collections import defaultdict
from sklearn.metrics.pairwise import haversine_distances as hsine
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2

from enum import Enum
from math import e
from tqdm import tqdm
from torchmetrics import Accuracy, Precision, Recall, F1Score, MetricCollection
from transformers import CLIPImageProcessor, CLIPVisionModel

# Data preprocessing & loading

In [None]:
class ImageDataset(Dataset):
    """Street-view locations, one sample per directory of ``.jpg`` views.

    Each entry of ``fname_list`` is a location directory whose last three path
    components are ``<admin1>/<admin2>/<loc_idx>``. The JSON ``label_file`` maps
    ``label_data[admin1][admin2][int(loc_idx)]`` to a comma-separated coordinate
    string. The class id is the position of ``admin1`` among the sorted top-level
    keys of ``label_data``.

    Args:
        fname_list: location directories to serve.
        label_file: path to the JSON coordinates file.
        transform: applied to each CHW RGB image tensor (non-CLIP pipelines only).
        mode: 0 (or "0") concatenates views along channels -> (V*C, h, w);
            any other value stacks them on a new axis -> (C, V, h, w).
            Ignored when ``model_id`` contains "clip".
        model_id: if it contains "clip", views are preprocessed by the matching
            ``CLIPImageProcessor`` instead of ``transform``.
    """

    def __init__(self, fname_list, label_file, transform, mode=0, model_id="openai/clip-vit-base-patch32", **kwargs):
        super().__init__(**kwargs)
        self.fname_list = fname_list
        self.transform = transform
        with open(label_file, "r") as fin:
            self.label_data = json.load(fin)

        # Class index = rank of the admin1 name among the sorted label keys.
        self.name2idx = {name: idx for idx, name in enumerate(sorted(self.label_data))}
        if "clip" not in model_id:
            # Callers pass mode as int or str (e.g. "1"). Normalise it so that "0"
            # selects concatenation instead of silently falling through to stacking.
            self.image_merge = self._concat_images if int(mode) == 0 else self._stack_images
        else:
            self.processor = CLIPImageProcessor.from_pretrained(model_id)
            self.image_merge = self._clip_images

    @staticmethod
    def _read_image(fname):
        """Load ``fname`` as a CHW uint8 RGB tensor.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (``imread`` returns None).
        """
        img = cv2.imread(fname)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {fname}")
        # OpenCV decodes to BGR. The ImageNet Normalize constants and the CLIP
        # processor both expect RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return torch.from_numpy(img).permute(2, 0, 1)

    def _clip_images(self, fnames):
        """Preprocess all views with the CLIP processor (returns its ``pt`` batch)."""
        return self.processor(
            [self._read_image(elem) for elem in fnames],
            return_tensors="pt"
        )

    def _concat_images(self, fnames):
        """Transform each view and concatenate along channels -> (V*C, h, w)."""
        return torch.cat([self.transform(self._read_image(elem)) for elem in fnames], dim=0)

    def _stack_images(self, fnames):
        """Transform each view and stack on a new dim 1 -> (C, V, h, w)."""
        return torch.stack([self.transform(self._read_image(elem)) for elem in fnames], dim=1)

    def __len__(self):
        return len(self.fname_list)

    def __getitem__(self, idx):
        """Return ``(sample, class_id, geo_coords)`` for location ``idx``.

        ``sample`` is the merged views (sorted ``*jpg`` files of the directory).
        ``geo_coords`` is a float tensor of the parsed coordinate string.
        """
        curr_fname = self.fname_list[idx]
        admin1, admin2, loc_idx = curr_fname.split(os.sep)[-3:]
        geo_coords = torch.tensor(
            list(map(float, self.label_data[admin1][admin2][int(loc_idx)].split(",")))
        )
        location = os.path.abspath(curr_fname)

        fnames = [
            os.path.join(location, elem)
            for elem in sorted(os.listdir(location)) if elem.endswith("jpg")
        ]
        sample = self.image_merge(fnames)

        class_id = self.name2idx[admin1]
        return sample, class_id, geo_coords

def get_geocell_centroids(train_paths, label_data, num_classes):
    """Mean (lat, lon) of the training locations of each class.

    Class indices follow the sorted top-level keys of ``label_data``. Each path's
    last three components are ``<county>/<city>/<idx>``, and
    ``label_data[county][city][int(idx)]`` is a "lat,lon" string.

    Args:
        train_paths: location paths of the training split.
        label_data: parsed locations JSON.
        num_classes: number of rows in the result.

    Returns:
        float64 array of shape (num_classes, 2) holding per-class mean (lat, lon).

    Raises:
        ValueError: if any class in ``range(num_classes)`` has no training
            location. Its centroid would otherwise be NaN and silently turn the
            haversine-weighted loss into NaN.
    """
    name2idx = {name: idx for idx, name in enumerate(sorted(label_data))}
    sums = np.zeros((num_classes, 2), dtype=np.float64)
    counts = np.zeros(num_classes, dtype=np.int64)

    for path in train_paths:
        county, city, loc_idx = path.split(os.sep)[-3:]
        lat, lon = map(float, label_data[county][city][int(loc_idx)].split(","))
        class_idx = name2idx[county]
        # Accumulate in float64; the former float32 cast discarded coordinate precision.
        sums[class_idx] += (lat, lon)
        counts[class_idx] += 1

    empty = np.flatnonzero(counts == 0)
    if empty.size:
        raise ValueError(
            f"No training locations for class indices {empty.tolist()}; cannot compute centroids"
        )
    return sums / counts[:, None]


def get_dataloaders(db_root: str, eval_splits=(0.1, 0.1), batch_size=2, random_seed=42, mode=0, model_id="openai/clip-vit-base-patch32", **kwargs):
    """Build train/val/test DataLoaders over ``<db_root>/images`` plus class centroids.

    Locations are split per city. Each ``images/<county>/<city>`` directory
    listed in ``locations.json`` has its location directories shuffled. The first
    ``floor(n * (1 - sum(eval_splits)))`` go to train and the last
    ``floor(n * eval_splits[1])`` go to test. Val takes whatever lies in between,
    so the three splits are always disjoint.

    Args:
        db_root: dataset root containing ``images/`` and ``locations.json``.
        eval_splits: (val fraction, test fraction).
        batch_size: batch size for all three loaders.
        random_seed: seed of the generator used for the per-city shuffles.
        mode, model_id: forwarded to ``ImageDataset``.
        **kwargs: accepted for backward compatibility; ignored.

    Returns:
        ``((train_loader, val_loader, test_loader), centroids, num_classes)``.
        ``centroids`` comes from ``get_geocell_centroids`` on the train split only.
        ``num_classes`` is the number of county directories under ``images/``.
    """
    image_root = os.path.join(db_root, "images")
    counties = sorted(os.listdir(image_root))
    num_classes = len(counties)

    label_file = os.path.join(db_root, "locations.json")
    with open(label_file, "r") as fin:
        label_data = json.load(fin)

    train_percent = 1 - sum(eval_splits)
    test_percent = eval_splits[1]
    # One seeded generator for the whole split. Listings are sorted first so the
    # split does not depend on the filesystem's os.listdir order.
    rng = np.random.default_rng(random_seed)

    train_paths, val_paths, test_paths = [], [], []
    for county in tqdm(counties):
        county_dir = os.path.join(image_root, county)
        for city in sorted(os.listdir(county_dir)):
            if city not in label_data[county]:
                continue
            city_dir = os.path.join(county_dir, city)
            locations = [os.path.join(city_dir, elem) for elem in sorted(os.listdir(city_dir))]
            rng.shuffle(locations)

            n = len(locations)
            # Explicit boundaries: the old `locations[-int(n * test):]` became
            # `locations[-0:]` (the whole list) for small cities, leaking train
            # samples into test.
            test_start = n - int(n * test_percent)
            train_end = min(int(n * train_percent), test_start)
            train_paths.extend(locations[:train_end])
            val_paths.extend(locations[train_end:test_start])
            test_paths.extend(locations[test_start:])

    centroids = get_geocell_centroids(train_paths, label_data, num_classes)

    first_512 = v2.Lambda(lambd=lambda x: x[..., :512, :512])
    to_normalized_float = [
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    train_transform = v2.Compose([first_512, v2.RandomCrop((128, 128), pad_if_needed=True), *to_normalized_float])
    # Deterministic crop for evaluation so val/test metrics are repeatable.
    eval_transform = v2.Compose([first_512, v2.CenterCrop((128, 128)), *to_normalized_float])

    train_dataset = ImageDataset(fname_list=train_paths, label_file=label_file, mode=mode, model_id=model_id, transform=train_transform)
    val_dataset = ImageDataset(fname_list=val_paths, label_file=label_file, mode=mode, model_id=model_id, transform=eval_transform)
    test_dataset = ImageDataset(fname_list=test_paths, label_file=label_file, mode=mode, model_id=model_id, transform=eval_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return (train_loader, val_loader, test_loader), centroids, num_classes

In [None]:
db_root = "/content/geo_dbv2"
# mode=1 (int, not the string "1") stacks the views on a new depth axis,
# (C, V, H, W) per sample, as the Conv3d GeoFinderV3 trained below expects.
# Any model_id without "clip" selects the CNN image pipeline.
(train_loader, val_loader, test_loader), centroids, num_classes = get_dataloaders(
    db_root, batch_size=32, mode=1, model_id="cnn"
)

100%|██████████| 42/42 [00:00<00:00, 299.86it/s]


# Model architectures

In [None]:
class Pretrain(Enum):
    """Where GeoFinderV1's ResNet-50 backbone weights come from."""
    NO_PRETRAIN = 0  # keep torchvision's random initialization
    PATH = 1  # presumably a local checkpoint file — verify the consumer actually handles it
    URL = 2  # state dict downloaded via torch.hub.load_state_dict_from_url

def hook_fn(module, input, output, name, d):
    """Forward-hook body: store ``output`` in the dict ``d`` under key ``name``.

    ``module`` and ``input`` belong to the hook signature and are not used.
    An existing entry for ``name`` is overwritten.
    """
    d[name] = output

In [None]:
class GeoFinderV1(nn.Module):
    """ResNet-50 classifier preceded by a 1x1 conv mapping ``in_channels`` -> 3.

    The default ``in_channels=6`` presumably corresponds to two RGB views
    concatenated on the channel axis (ImageDataset mode 0). Confirm against the
    data pipeline in use.

    Args:
        num_classes: outputs of the final bias-free linear layer.
        in_channels: channels of the input fed to ``adapter``.
        pretrained: source of backbone weights. ``Pretrain.URL`` downloads
            ``pretrain_url`` through torch.hub. ``Pretrain.PATH`` loads a local
            state-dict file located at ``pretrain_url``. ``Pretrain.NO_PRETRAIN``
            keeps the random init.
        pretrain_url: URL or local path of a torchvision ResNet-50 state dict.
    """

    def __init__(
        self,
        num_classes=42,
        in_channels=6,
        pretrained=Pretrain.NO_PRETRAIN,
        pretrain_url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
        **kwargs
    ) -> None:

        super(GeoFinderV1, self).__init__(**kwargs)
        self.adapter = nn.Conv2d(
            in_channels=in_channels,
            out_channels=3,
            kernel_size=1,
            stride=1
        )
        self.backbone = torchvision.models.resnet50()
        # Weights are loaded before `fc` is replaced, so a stock 1000-class
        # ResNet-50 state dict matches the backbone exactly.
        if pretrained == Pretrain.URL:
            state_dict = torch.hub.load_state_dict_from_url(pretrain_url)
            self.backbone.load_state_dict(state_dict)
        elif pretrained == Pretrain.PATH:
            # Previously PATH was accepted but silently ignored (random init).
            # torch.load unpickles the file: only load checkpoints you trust.
            state_dict = torch.load(pretrain_url, map_location="cpu")
            self.backbone.load_state_dict(state_dict)

        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(
            in_features=in_features,
            out_features=num_classes,
            bias=False
        )
        # Name -> submodule lookup (built after the fc swap) used by enable_layer.
        self.layers = dict(self.backbone.named_modules())
        self.intermediates = {}

    def enable_layer(self, layer_name):
        """Capture ``layer_name``'s forward output into ``self.intermediates[layer_name]``.

        Returns:
            The hook handle; call ``.remove()`` on it to stop capturing.
        """
        return self.layers[layer_name].register_forward_hook(
            lambda module, input, output, name=layer_name, d=self.intermediates: hook_fn(module, input, output, name, d)
        )

    def forward(self, x):
        return self.backbone(self.adapter(x))

In [None]:
class GeoFinderV2(nn.Module):
    """2D-CNN classifier with late fusion over views.

    Accepts either a single image batch ``(B, C, H, W)`` or a multi-view batch
    ``(B, C, V, H, W)``. In the multi-view case each view is embedded
    independently and the embeddings are averaged before the MLP head. Shape notes
    below assume 128x128 inputs and ``features_base=32``.
    """

    def __init__(
        self,
        num_classes=42,
        in_channels=3,
        features_base = 32,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)

        widths = [features_base * mult for mult in (1, 2, 4, 8)]  # 32, 64, 128, 256
        # Each stage halves H and W: 64 -> 32 -> 16 -> 8.
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            *(self._block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths[:-1], widths[1:])),
        )

        hidden = features_base * 2
        self.head = nn.Sequential(
            nn.Linear(widths[-1], hidden),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(hidden, num_classes),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.1)."""
        conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding, bias=False,
        )
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels, affine=True), nn.LeakyReLU(0.1))

    def _get_embedding(self, x):
        """Backbone features global-average-pooled over H and W."""
        return self.backbone(x).mean(dim=(-2, -1))

    def forward(self, x):
        if x.dim() == 4:
            pooled = self._get_embedding(x)
        else:
            per_view = [self._get_embedding(view) for view in x.unbind(dim=2)]
            pooled = torch.stack(per_view, dim=-1).mean(dim=-1)
        return self.head(pooled)

In [None]:
class GeoFinderV3(nn.Module):
    """Conv3d classifier that treats the stacked views as the depth axis.

    Expects input shaped ``(B, in_channels, V, H, W)``. Features are averaged over
    depth, height and width before a 2-layer MLP head. The shape notes assume V=2,
    128x128 crops and ``features_base=32``.
    """

    def __init__(
        self,
        num_classes=42,
        in_channels=3,
        features_base = 32,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)

        fb = features_base
        # kernel depth 3 / stride 1 / padding 1 keeps V; kernel 4 / stride 2 halves H and W.
        spatial_only = ((3, 4, 4), (1, 2, 2), 1)

        self.backbone = nn.Sequential(
            nn.Conv3d(in_channels, fb, kernel_size=3, stride=1, padding=1),  # 32 x 2 x 128 x 128
            nn.LeakyReLU(0.2),
            self._block(fb, fb * 2, *spatial_only),      # 64 x 2 x 64 x 64
            self._block(fb * 2, fb * 4, *spatial_only),  # 128 x 2 x 32 x 32
            self._block(fb * 4, fb * 8, 4, 2, 1),        # 256 x 1 x 16 x 16
        )

        hidden = fb * 2
        self.head = nn.Sequential(
            nn.Linear(fb * 8, hidden),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(hidden, num_classes),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv3d (no bias) -> BatchNorm3d -> LeakyReLU(0.1)."""
        conv = nn.Conv3d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding, bias=False,
        )
        return nn.Sequential(conv, nn.BatchNorm3d(out_channels, affine=True), nn.LeakyReLU(0.1))

    def _get_embedding(self, x):
        """Backbone features averaged over the depth, height and width axes."""
        return self.backbone(x).mean(dim=(-3, -2, -1))

    def forward(self, x):
        return self.head(self._get_embedding(x))

In [None]:
class GeoFinderV4(nn.Module):
    """CLIP vision encoder + linear head, averaging the pooled embedding over views.

    Args:
        num_classes: number of output classes.
        model_id: Hugging Face id passed to ``CLIPVisionModel.from_pretrained``.
    """

    def __init__(
        self,
        num_classes=42,
        model_id="openai/clip-vit-base-patch32",
        **kwargs
    ) -> None:
        super(GeoFinderV4, self).__init__(**kwargs)

        self.backbone = CLIPVisionModel.from_pretrained(model_id)

        # Size the head from the encoder config instead of hard-coding 768 (the
        # ViT-B value) so other CLIP checkpoints with a different width also work.
        self.head = nn.Linear(self.backbone.config.hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of multi-view samples.

        ``x["pixel_values"]`` is indexed as (B, V, 3, H, W), i.e. the collated
        output of the CLIP image processor per sample. Every view is used. The
        old ``torch.chunk(..., 2)`` kept only two views, dropping some when V > 2.
        """
        pixel_values = x["pixel_values"]
        batch_size, num_views = pixel_values.shape[:2]
        # Fold views into the batch so the encoder runs once for all of them.
        pooled = self.backbone(pixel_values=pixel_values.flatten(0, 1)).pooler_output
        embeds = pooled.reshape(batch_size, num_views, -1).mean(dim=1)
        return self.head(embeds)

# Training and evaluation

In [None]:
def _haversine_smoothed_loss(log_probs, labels, coords, centroids, tau, device):
    """Cross-entropy against geographically smoothed targets.

    For sample n and class i the target weight is
    ``exp(-(d(x_n, c_i) - d(x_n, c_{y_n})) / tau)``. Here d is the haversine
    distance in km (radius 6371), c_i are the class centroids, x_n the sample's
    coordinates and y_n its label. ``tau`` is therefore in km.

    Args:
        log_probs: (B, C) log-probabilities on ``device``.
        labels: (B,) integer class ids.
        coords: (B, 2) coordinates in degrees, same order as ``centroids``.
        centroids: (C, 2) numpy array of class centroids in degrees.
        tau: smoothing temperature.
        device: device of the returned loss.

    Returns:
        Scalar tensor: batch mean of ``sum_i -log p_i * w_i``.
    """
    coords_rad = np.deg2rad(coords.detach().cpu().numpy())
    dists = 6371 * hsine(coords_rad, np.deg2rad(centroids))  # B x C
    label_idx = labels.detach().cpu().numpy()
    # Distance to the true class centroid is already in `dists`. Index it instead
    # of building and taking the diagonal of a second B x B haversine matrix.
    true_dists = dists[np.arange(len(label_idx)), label_idx][:, np.newaxis]  # B x 1
    # Match log_probs' dtype so the loss is not silently promoted to float64.
    weights = torch.as_tensor(
        np.exp(-(dists - true_dists) / tau), dtype=log_probs.dtype, device=device
    )
    return torch.mean(torch.sum(-log_probs * weights, dim=1))


def train(model, loader, optimizer, centroids, metrics_group, tau=50, device="cpu"):
    """Run one training epoch.

    Returns:
        ``(mean batch loss, metrics_group.compute())``. The caller is responsible
        for ``metrics_group.reset()``.
    """
    epoch_loss = 0.0
    model.train()
    for inputs, labels, coords in tqdm(loader):
        optimizer.zero_grad()
        # `.to` also works for dict-like CLIP batches that implement it — TODO confirm for V4 inputs.
        inputs = inputs.to(device)
        # log_softmax is the numerically stable form of log(softmax(x) + eps).
        log_probs = F.log_softmax(model(inputs), dim=-1)  # B x C
        loss = _haversine_smoothed_loss(log_probs, labels, coords, centroids, tau, device)
        loss.backward()
        optimizer.step()

        metrics_group.update(log_probs.detach().exp(), labels.to(device))
        epoch_loss += loss.item()

    return epoch_loss / len(loader), metrics_group.compute()


def evaluate(model, loader, centroids, metrics_group, tau=50, device="cpu"):
    """Evaluate without gradients using the same loss as ``train``.

    Returns:
        ``(mean batch loss, metrics_group.compute())``. The caller is responsible
        for ``metrics_group.reset()``.
    """
    epoch_loss = 0.0
    model.eval()

    with torch.no_grad():
        for inputs, labels, coords in tqdm(loader):
            inputs = inputs.to(device)
            log_probs = F.log_softmax(model(inputs), dim=-1)  # B x C
            loss = _haversine_smoothed_loss(log_probs, labels, coords, centroids, tau, device)

            metrics_group.update(log_probs.exp(), labels.to(device))
            epoch_loss += loss.item()

    return epoch_loss / len(loader), metrics_group.compute()

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)  # reproducible weight initialization
model = GeoFinderV3(num_classes=num_classes)
# To train only the head, freeze the backbone:
# for param in model.backbone.parameters():
#     param.requires_grad = False
model = model.to(device)
optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-3, weight_decay=0.0)
epochs = 2

# The task wrappers (Precision/Recall/F1Score with task="multiclass") default to
# average="micro". For single-label multiclass that makes all three equal to
# accuracy, so use macro (per-class mean) to get scores that differ.
metrics_group = MetricCollection(
    [
        Accuracy(task="multiclass", num_classes=num_classes),
        Precision(task="multiclass", num_classes=num_classes, average="macro"),
        Recall(task="multiclass", num_classes=num_classes, average="macro"),
        F1Score(task="multiclass", num_classes=num_classes, average="macro")
    ]
).to(device)

In [None]:
best_score = None
best_state_dict = None
# Haversine smoothing temperature (km), shared by train, val and test so the
# reported losses are comparable (test previously used 50 while train/val used 100).
tau = 100

for epoch in range(epochs):
    print(f"Epoch {epoch}")
    train_loss, train_metrics = train(
        model,
        train_loader,
        optimizer,
        centroids,
        metrics_group,
        tau=tau,
        device=device
    )
    metrics_group.reset()

    valid_loss, valid_metrics = evaluate(
        model,
        val_loader,
        centroids,
        metrics_group,
        tau=tau,
        device=device
    )
    metrics_group.reset()

    valid_f1 = valid_metrics["MulticlassF1Score"].cpu().item()
    if (best_score is None) or (valid_f1 > best_score):
        best_score = valid_f1
        # state_dict() returns references to the live parameters, which later
        # optimizer steps update in place. Clone so this really is a snapshot.
        best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

    print("Train")
    print(f"Loss: {train_loss:.3f}", end=", ")
    print(", ".join([f"{k}: {v.cpu().item() * 100:.2f}" for k, v in train_metrics.items()]))

    print("Validation")
    print(f"Loss: {valid_loss:.3f}", end=", ")
    print(", ".join([f"{k}: {v.cpu().item() * 100:.2f}" for k, v in valid_metrics.items()]))
    print()

# Restore the best-validation weights (no-op if no epoch ran).
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)
test_loss, test_metrics = evaluate(
    model,
    test_loader,
    centroids,
    metrics_group,
    tau=tau,
    device=device
)
metrics_group.reset()

print("Test")
print(f"Loss: {test_loss:.3f}", end=", ")
print(", ".join([f"{k}: {v.cpu().item() * 100:.2f}" for k, v in test_metrics.items()]))

Epoch 0


100%|██████████| 735/735 [04:39<00:00,  2.63it/s]
100%|██████████| 95/95 [00:31<00:00,  3.01it/s]


Train
Loss: 31.314, MulticlassAccuracy: 2.33, MulticlassPrecision: 2.33, MulticlassRecall: 2.33, MulticlassF1Score: 2.33
Validation
Loss: 31.205, MulticlassAccuracy: 2.38, MulticlassPrecision: 2.38, MulticlassRecall: 2.38, MulticlassF1Score: 2.38

Epoch 1


  3%|▎         | 19/735 [00:07<04:37,  2.58it/s]


KeyboardInterrupt: 

In [None]:
best_score = None
best_state_dict = None
tau = 50  # haversine smoothing temperature (km) for train, val and test

for epoch in range(epochs):
    print(f"Epoch {epoch}")
    train_loss, train_metrics = train(
        model,
        train_loader,
        optimizer,
        centroids,
        metrics_group,
        tau=tau,
        device=device
    )
    metrics_group.reset()

    valid_loss, valid_metrics = evaluate(
        model,
        val_loader,
        centroids,
        metrics_group,
        tau=tau,
        device=device
    )
    metrics_group.reset()

    valid_f1 = valid_metrics["MulticlassF1Score"].cpu().item()
    if (best_score is None) or (valid_f1 > best_score):
        best_score = valid_f1
        # state_dict() returns references to the live parameters, which later
        # optimizer steps update in place. Clone so this really is a snapshot.
        best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

    print("Train")
    print(f"Loss: {train_loss:.3f}", end=", ")
    print(", ".join([f"{k}: {v.cpu().item() * 100:.2f}" for k, v in train_metrics.items()]))

    print("Validation")
    print(f"Loss: {valid_loss:.3f}", end=", ")
    print(", ".join([f"{k}: {v.cpu().item() * 100:.2f}" for k, v in valid_metrics.items()]))
    print()

# Restore the best-validation weights (no-op if no epoch ran).
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)
test_loss, test_metrics = evaluate(
    model,
    test_loader,
    centroids,
    metrics_group,
    tau=tau,
    device=device
)
metrics_group.reset()

print("Test")
print(f"Loss: {test_loss:.3f}", end=", ")
print(", ".join([f"{k}: {v.cpu().item() * 100:.2f}" for k, v in test_metrics.items()]))

Epoch 0


100%|██████████| 735/735 [11:45<00:00,  1.04it/s]
100%|██████████| 95/95 [01:30<00:00,  1.05it/s]


Train
Loss: 14.256, MulticlassAccuracy: 9.83, MulticlassPrecision: 9.83, MulticlassRecall: 9.83, MulticlassF1Score: 9.83
Validation
Loss: 14.003, MulticlassAccuracy: 12.21, MulticlassPrecision: 12.21, MulticlassRecall: 12.21, MulticlassF1Score: 12.21

Epoch 1


100%|██████████| 735/735 [11:47<00:00,  1.04it/s]
100%|██████████| 95/95 [01:32<00:00,  1.02it/s]


Train
Loss: 13.938, MulticlassAccuracy: 14.05, MulticlassPrecision: 14.05, MulticlassRecall: 14.05, MulticlassF1Score: 14.05
Validation
Loss: 13.911, MulticlassAccuracy: 15.54, MulticlassPrecision: 15.54, MulticlassRecall: 15.54, MulticlassF1Score: 15.54

Epoch 2


100%|██████████| 735/735 [11:41<00:00,  1.05it/s]
100%|██████████| 95/95 [01:29<00:00,  1.07it/s]


Train
Loss: 13.903, MulticlassAccuracy: 14.68, MulticlassPrecision: 14.68, MulticlassRecall: 14.68, MulticlassF1Score: 14.68
Validation
Loss: 13.926, MulticlassAccuracy: 13.63, MulticlassPrecision: 13.63, MulticlassRecall: 13.63, MulticlassF1Score: 13.63

Epoch 3


100%|██████████| 735/735 [11:39<00:00,  1.05it/s]
100%|██████████| 95/95 [01:32<00:00,  1.03it/s]


Train
Loss: 13.882, MulticlassAccuracy: 15.62, MulticlassPrecision: 15.62, MulticlassRecall: 15.62, MulticlassF1Score: 15.62
Validation
Loss: 13.705, MulticlassAccuracy: 15.68, MulticlassPrecision: 15.68, MulticlassRecall: 15.68, MulticlassF1Score: 15.68

Epoch 4


100%|██████████| 735/735 [11:41<00:00,  1.05it/s]
100%|██████████| 95/95 [01:31<00:00,  1.04it/s]


Train
Loss: 13.769, MulticlassAccuracy: 16.60, MulticlassPrecision: 16.60, MulticlassRecall: 16.60, MulticlassF1Score: 16.60
Validation
Loss: 14.051, MulticlassAccuracy: 14.06, MulticlassPrecision: 14.06, MulticlassRecall: 14.06, MulticlassF1Score: 14.06

Epoch 5


100%|██████████| 735/735 [11:49<00:00,  1.04it/s]
100%|██████████| 95/95 [01:37<00:00,  1.03s/it]


Train
Loss: 13.763, MulticlassAccuracy: 17.03, MulticlassPrecision: 17.03, MulticlassRecall: 17.03, MulticlassF1Score: 17.03
Validation
Loss: 13.862, MulticlassAccuracy: 14.19, MulticlassPrecision: 14.19, MulticlassRecall: 14.19, MulticlassF1Score: 14.19

Epoch 6


  2%|▏         | 18/735 [00:18<12:36,  1.05s/it]


KeyboardInterrupt: 