In [None]:
import sys
sys.path.append('./satcle')
import matplotlib.pyplot as plt
import torch
from load import get_satcle
from mpl_toolkits.basemap import Basemap
from urllib import request
import numpy as np
import pandas as pd
import io
import torch
from torch.utils.data import TensorDataset, random_split

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
import torch.nn as nn

class MLP(nn.Module):
    """Fully connected ReLU network.

    Layer stack: ``Linear(input_dim, dim_hidden) -> ReLU``, then ``num_layers``
    blocks of ``Linear(dim_hidden, dim_hidden) -> ReLU``, then a final
    ``Linear(dim_hidden, out_dims)`` with no activation.

    Args:
        input_dim: Size of the last dimension of the input.
        dim_hidden: Width of every hidden layer.
        num_layers: Number of hidden ``Linear -> ReLU`` blocks placed between
            the input layer and the output layer.
        out_dims: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, dim_hidden, num_layers, out_dims):
        super().__init__()

        layers = [nn.Linear(input_dim, dim_hidden, bias=True), nn.ReLU()]  # Input layer
        # Instantiate every hidden block separately. Multiplying a list of
        # modules (``[...] * num_layers``) repeats the *same* nn.Linear object,
        # which makes all hidden layers share one weight matrix and bias.
        for _ in range(num_layers):
            layers += [nn.Linear(dim_hidden, dim_hidden, bias=True), nn.ReLU()]  # Hidden layer
        layers.append(nn.Linear(dim_hidden, out_dims, bias=True))  # Output layer

        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.features(x)

def get_population_data(pred="temp",
                      norm_y=True,
                      norm_x=True,
                      data_path='data/downstream/labels_population.csv'
                      ):
    """Load the population labels and their coordinates as tensors.

    The CSV is read with ``pandas.read_csv``. Its first column is used as the
    regression target; the remaining columns are used as coordinates and are
    returned in reversed column order.

    Args:
        pred: Accepted for call compatibility; not used by this function.
        norm_y: If True (default), standardise the target to zero mean and
            unit standard deviation. If False, return the target as read.
        norm_x: Accepted for call compatibility; not used by this function.
        data_path: Location of the labels CSV.

    Returns:
        A tuple ``(coords, y, y_mean, y_std)``. ``coords`` and ``y`` are
        tensors. ``y_mean`` and ``y_std`` are the statistics used to
        standardise ``y``, so ``y * y_std + y_mean`` always recovers the
        target in its original units; they are 0 and 1 when ``norm_y`` is
        False.
    """
    table = np.asarray(pd.read_csv(data_path))

    y = table[:, 0].reshape(-1)
    # Reverse the coordinate columns.
    # presumably the file stores (lat, lon) and callers want (lon, lat) -- TODO confirm
    coords = table[:, 1:][:, ::-1]

    if norm_y:
        y_mean = np.mean(y)
        y_std = np.std(y)
        y = (y - y_mean) / y_std
    else:
        # Identity statistics keep the caller's inverse transform valid.
        y_mean = np.float64(0.0)
        y_std = np.float64(1.0)

    # .copy() makes the reversed view contiguous; torch.tensor cannot wrap
    # arrays with negative strides.
    return torch.tensor(coords.copy()),  torch.tensor(y), y_mean, y_std


# Set the font family in matplotlib's rcParams.
plt.rcParams['font.family'] = 'Times New Roman'
# Load the coordinates, the standardised labels, and the mean/std that were
# used to standardise them.
coords, y, y_mean, y_std = get_population_data()

fig, ax = plt.subplots(1, figsize=(18, 5))

# World map in the 'cyl' projection with coastlines and grid lines every 30 degrees.
m = Basemap(projection='cyl', resolution='c', ax=ax)
m.drawcoastlines()
m.drawmeridians(np.arange(-180, 180, 30), labels=[0,0,0,1]) 
m.drawparallels(np.arange(-90, 90, 30), labels=[1,0,0,0]) 
# Colour each point by its label mapped back to the original scale
# (y * y_std + y_mean undoes the standardisation).
# assumes coords[:, 0] is longitude and coords[:, 1] is latitude -- TODO confirm against the CSV
scatter = ax.scatter(coords[:,0], coords[:,1], c=y*y_std+y_mean, s=5)
ax.set_title('Population')

cbar = plt.colorbar(scatter, shrink=0.9)


plt.show()


In [None]:
from main_satcle import SatCLELightningModule
def get_satcle(ckpt_path, device, return_all=False):
    """Rebuild a SatCLE model from a checkpoint file and return it in eval mode.

    Note: this definition replaces the ``get_satcle`` imported from ``load``
    earlier in the notebook.

    Args:
        ckpt_path: Path passed to ``torch.load``.
        device: Device the checkpoint is mapped to and the module is moved to.
        return_all: If True, return the whole ``.model`` attribute of the
            restored Lightning module; otherwise return only its
            ``.location`` sub-module.

    Returns:
        ``lightning_module.model`` when ``return_all`` is True, else
        ``lightning_module.model.location``.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    hparams = checkpoint['hyper_parameters']

    # Remove these three entries before the hyper-parameters are passed to
    # the module constructor; missing keys are ignored.
    for unwanted in ('eval_downstream', 'air_temp_data_path', 'election_data_path'):
        hparams.pop(unwanted, None)

    module = SatCLELightningModule(**hparams)
    module = module.to(device)
    module.load_state_dict(checkpoint['state_dict'])
    module.eval()

    encoder = module.model
    return encoder if return_all else encoder.location
      
satclip_path = '***'

# Restore the pretrained encoder returned by get_satcle and embed every coordinate once.
model = get_satcle(satclip_path, device=device)
model.eval()
with torch.no_grad():
    x = model(coords.double().to(device)).detach().cpu()

print(coords.shape)
print(x.shape)

# Keep coordinates, embeddings and labels aligned row by row.
dataset = TensorDataset(coords, x, y)

# 70 % train / 10 % validation / remainder test.
train_size = int(0.7 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Seed the split so that Restart & Run All reproduces the same subsets
# (an unseeded random_split changes the membership on every run).
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_set, val_set, test_set = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Indexing a TensorDataset with a list of indices returns one tensor per
# stored tensor, i.e. (coords, embeddings, labels) for that subset.
coords_train, x_train, y_train = dataset[train_set.indices]
coords_test, x_test, y_test = dataset[test_set.indices]
coords_val, x_val, y_val = dataset[val_set.indices]


def _new_world_map(figsize=(18, 5)):
    """Create a figure with a global 'cyl' Basemap, coastlines and a 30-degree grid.

    Returns:
        The ``(fig, ax, m)`` triple: matplotlib figure, its axes, and the
        Basemap drawn on those axes.
    """
    fig, ax = plt.subplots(1, figsize=figsize)
    m = Basemap(projection='cyl', resolution='c', ax=ax)
    m.drawcoastlines()
    m.drawmeridians(np.arange(-180, 180, 30), labels=[0,0,0,1])
    m.drawparallels(np.arange(-90, 90, 30), labels=[1,0,0,0])
    return fig, ax, m


# Where each subset lies on the map. The validation subset is drawn as well,
# otherwise about a tenth of the samples would be missing from this figure.
fig, ax, m = _new_world_map()
ax.scatter(coords_train[:,0], coords_train[:,1], c='green', s=2, label='Training', alpha=0.5)
ax.scatter(coords_val[:,0], coords_val[:,1], c='orange', s=2, label='Validation', alpha=0.5)
ax.scatter(coords_test[:,0], coords_test[:,1], c='lightblue', s=2, label='Testing', alpha=0.5)
ax.legend()
ax.set_title('Train-Val-Test Split')

# Training labels mapped back to their original scale.
fig, ax, m = _new_world_map()
scatter = ax.scatter(coords_train[:,0], coords_train[:,1],
                     c=y_train*y_std+y_mean, s=5)
ax.set_title('train set')

cbar = plt.colorbar(scatter, shrink=0.9)

plt.show()

# Test labels mapped back to their original scale.
fig, ax, m = _new_world_map()
scatter = ax.scatter(coords_test[:,0], coords_test[:,1],
                     c=y_test*y_std+y_mean, s=5)
ax.set_title('test set')

cbar = plt.colorbar(scatter, shrink=0.9)

plt.show()

In [None]:

# Regression head trained on the frozen location embeddings.
pred_model = MLP(input_dim=256, dim_hidden=64, num_layers=2, out_dims=1).float().to(device)
criterion = nn.MSELoss()
mae_criterion = nn.L1Loss()
optimizer = torch.optim.Adam(pred_model.parameters(), lr=0.001)

# Training is full-batch, so convert and move every split to the device once
# instead of repeating the conversion on every epoch.
x_train_dev, y_train_dev = x_train.float().to(device), y_train.float().to(device)
x_val_dev, y_val_dev = x_val.float().to(device), y_val.float().to(device)
x_test_dev, y_test_dev = x_test.float().to(device), y_test.float().to(device)

losses = []
epochs = 50000
patience = 100 
best_loss = float('inf')
best_state = None
stop_counter = 0

for epoch in range(epochs):
    pred_model.train()
    optimizer.zero_grad()
    y_pred = pred_model(x_train_dev)
    loss = criterion(y_pred.reshape(-1), y_train_dev)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Early stopping monitors the held-out validation split; the training
    # loss cannot reveal over-fitting.
    pred_model.eval()
    with torch.no_grad():
        val_loss = criterion(pred_model(x_val_dev).reshape(-1), y_val_dev).item()

    if (epoch + 1) % 250 == 0:
        print(f"Epoch {epoch + 1}, MSE Loss: {loss.item():.4f}")
        print(f"Epoch {epoch + 1}, MAE Loss: {mae_criterion(y_pred.detach().reshape(-1), y_train_dev).item():.4f}")
        print(f"Epoch {epoch + 1}, Val MSE Loss: {val_loss:.4f}")

    if val_loss < best_loss:
        best_loss = val_loss
        # Snapshot the best weights so they can be restored after stopping.
        best_state = {name: tensor.detach().clone() for name, tensor in pred_model.state_dict().items()}
        stop_counter = 0
    else:
        stop_counter += 1
    if stop_counter >= patience:
        print("Early stopping triggered")
        break

# Evaluate the weights that achieved the lowest validation loss.
if best_state is not None:
    pred_model.load_state_dict(best_state)

# Put the regression head itself into eval mode (not the location encoder `model`).
pred_model.eval()
with torch.no_grad():
    y_pred_test = pred_model(x_test_dev)

print(f'Test MSE loss: {criterion(y_pred_test.reshape(-1), y_test_dev).item()}')
print(f'Test MAE loss: {mae_criterion(y_pred_test.reshape(-1), y_test_dev).item()}')

# Side-by-side maps: ground-truth test labels (left) and model predictions (right).
fig, ax = plt.subplots(1, 2, figsize=(10, 3))

# Left panel: test labels mapped back to their original scale.
m = Basemap(projection='cyl', resolution='c', ax=ax[0])
m.drawcoastlines()
sc0 = ax[0].scatter(coords_test[:,0], coords_test[:,1], 
                    c=y_test*y_std+y_mean, 
                    s=5)
ax[0].set_title('GT')
fig.colorbar(sc0, ax=ax[0], orientation='vertical', shrink=0.65)

# Right panel: predictions mapped back to the original scale with the same
# mean/std, then moved to the CPU and converted to NumPy for matplotlib.
# NOTE(review): each panel has its own colorbar and no shared vmin/vmax is set,
# so the two colour scales are independent -- confirm this is intended.
m = Basemap(projection='cyl', resolution='c', ax=ax[1])
m.drawcoastlines()
sc1 = ax[1].scatter(coords_test[:,0].cpu().numpy(), coords_test[:,1].cpu().numpy(), 
                    c=(y_pred_test*y_std+y_mean).reshape(-1).cpu().numpy(), s=5)
ax[1].set_title('Predicted')
fig.colorbar(sc1, ax=ax[1], orientation='vertical', shrink=0.65)

plt.show()