## Steps

So note that we have two different sampling methods, "UAR" and "POP", and the labels we have for our 7 tasks are split between them. Thus, we must train on two different image datasets at once.

Our multitask model will also have 7 heads, so we must edit our ResNet18 to accomplish that

We should also save the model parameters for future use.

## Creating our training and test dataset

In [238]:
import io as python_io
import time
import math
from tqdm import tqdm
import sklearn.metrics
import torch
from loguru import logger
from torchvision import transforms
from torch import nn
from torchvision import models
import numpy as np
from pathlib import Path
import pickle
from mosaiks import config as c
from mosaiks import transforms as m_transforms
from mosaiks.featurization import RemoteSensingSubgridDataset
from mosaiks.utils import io, spatial
from mosaiks.solve import data_parser as parse

In [239]:
import pandas as pd
from torchinfo import summary
from tensorboardX import SummaryWriter

In [240]:
tasks_UAR = ["treecover", "elevation", "population",]
tasks_POP = ["nightlights", "income", "roads", "housing",]


tasks = [
    "treecover",
    "elevation",
    "population",
    "nightlights",
    "income",
    "roads",
    "housing",
]

data_home = Path(c.data_dir) / "raw" / "imagery"
data_home_UAR = data_home / "CONTUS_UAR"
data_home_POP = data_home / "CONTUS_POP"

In [241]:
def grab_labels(task):
    c_local = io.get_filepaths(c, task)
    c_app = getattr(c_local, task)
    Y = io.get_Y(c_local, c_app["colname"])
    lons, lats = spatial.ids_to_ll(
        Y.index,
        c.grid_dir,
        c_local.grid["area"],
        c_local.images["zoom_level"],
        c_local.images["n_pixels"],
    )
    latlons = np.vstack((np.array(lats), np.array(lons))).T.astype("float64")
    ids, Y, latlons = m_transforms.dropna_and_transform(
        Y.index.values, Y.values, latlons, c_app
    )
    return Y, latlons, ids

In [242]:
def split_train_test(ids, Y, ratio=0.8):
    seed = 0
    r = np.random.RandomState(seed=seed)
    
    n = ids.shape[0]
    
    test_n = round((1 - ratio) * n)
    train_n = n - test_n
    
    shuffled_idx = r.choice(n, n, replace=False)
    train_idx = shuffled_idx[:train_n]
    test_idx = shuffled_idx[train_n:]
    
    return ids[train_idx], Y[train_idx], ids[test_idx], Y[test_idx]

In [243]:
dfs_UAR = pd.DataFrame()
for task in tasks_UAR:
    Y_task, ll_task, ids_task = grab_labels(task)

    Y_and_ids = np.vstack([Y_task, ids_task]).T
    
    df = pd.DataFrame(Y_and_ids, columns=[task, "id"])
    df = df.set_index("id")
    if dfs_UAR.empty:
        dfs_UAR = df
    else:
        dfs_UAR = dfs_UAR.merge(df, how='outer', on='id')
        
display(dfs_UAR)
    
dfs_POP = pd.DataFrame()
for task in tasks_POP:
    Y_task, ll_task, ids_task = grab_labels(task)

    Y_and_ids = np.vstack([Y_task, ids_task]).T
    
    df = pd.DataFrame(Y_and_ids, columns=[task, "id"])
    df = df.set_index("id")
    if dfs_POP.empty:
        dfs_POP = df
    else:
        dfs_POP = dfs_POP.merge(df, how='outer', on='id')
        
display(dfs_POP)

Unnamed: 0_level_0,treecover,elevation,population
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1000105,91.223158,1462.463364,
10001067,11.715897,1919.119465,2.234753
10001080,0.0,1814.329174,0.385035
10001219,0.162632,2079.739485,0.010085
1000122,89.316842,1544.912699,
...,...,...,...
99976,96.171579,156.416732,
999914,0.003158,1534.612079,1.980177
999942,0.0,1524.94773,
999949,0.0,1430.260728,


Unnamed: 0_level_0,nightlights,income,roads,housing
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1000114,0.0,53277.708883,0.0,
10001457,2.793769,85920.0,5856.211968,5.003458
10001458,2.374133,85920.0,3708.122387,5.144845
10001459,2.813775,62022.434843,5458.893879,5.091291
10001463,1.613503,61151.985279,2711.385912,5.042022
...,...,...,...,...
999962,3.164506,74250.611331,6009.25,
999963,2.678835,73374.52862,5856.974,
999964,3.302048,73351.713585,8132.549,
999965,3.07826,,6581.161,


In [244]:
(ids_train_UAR, Y_train_UAR, ids_test_UAR, Y_test_UAR) = split_train_test(dfs_UAR.index.to_numpy(), dfs_UAR.loc[:, dfs_UAR.columns != 'id'].to_numpy(dtype='float32'))
(ids_train_POP, Y_train_POP, ids_test_POP, Y_test_POP) = split_train_test(dfs_POP.index.to_numpy(), dfs_POP.loc[:, dfs_POP.columns != 'id'].to_numpy(dtype='float32'))

print(Y_train_UAR.shape)
print(Y_train_POP.shape)
print(Y_train_UAR.dtype)
print(Y_train_POP.dtype)

(80000, 3)
(80000, 4)
float32
float32


In [245]:
def transform_img_inputs():
    out = [transforms.ToPILImage(), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    return transforms.Compose(out)

In [246]:
def get_dataloader(data_home, Y, ids, batch_size=16, shuffle=True, num_workers=4):
    transform = transform_img_inputs()
    r_grid = RemoteSensingSubgridDataset(data_home, Y, ids, transform=transform)
    return torch.utils.data.DataLoader(r_grid, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

## Model

In [261]:
class MultiTaskModel(nn.Module):
    def __init__(self):
        super(MultiTaskModel, self).__init__()
        #shared part
        self.resnet18 = models.resnet18(pretrained=True)
        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Identity()
        
        self.sampling = nn.ModuleList()
        self.sampling.add_module('UAR', nn.Linear(num_ftrs, len(tasks_UAR)))
        self.sampling.add_module('POP', nn.Linear(num_ftrs, len(tasks_POP)))

    def forward(self, X, sampling):
        # shared part
        resnet_output = self.resnet18(X)

        # sampling specific parts
        if sampling == 'UAR':
            return self.sampling.UAR(resnet_output)
        elif sampling == 'POP':
            return self.sampling.POP(resnet_output)

In [262]:
class MultiTaskModelWrapper(nn.Module):
    def __init__(self):
        super(MultiTaskModelWrapper, self).__init__()
        self.log_vars = nn.Parameter(torch.zeros((len(tasks))))
        
    def forward(self, outputs, labels, criterion, sampling):
        mask = torch.isnan(labels)
        outputs = outputs * mask
        labels = torch.nan_to_num(labels, nan=0.0)
        if sampling == 'UAR':
            loss_treecover = criterion(outputs[:,0], labels[:, 0])
            precision_treecover = torch.exp(-self.log_vars[0])
            loss_treecover = precision_treecover * loss_treecover + self.log_vars[0]
            
            loss_elevation = criterion(outputs[:,1], labels[:, 1])
            precision_elevation = torch.exp(-self.log_vars[1])
            loss_elevation = precision_elevation * loss_elevation + self.log_vars[1]
            
            loss_population = criterion(outputs[:,2], labels[:, 2])
            precision_population = torch.exp(-self.log_vars[2])
            loss_population = precision_population * loss_population + self.log_vars[2]
            
            return loss_treecover + loss_elevation + loss_population
        elif sampling == "POP":
            loss_nightlights = criterion(outputs[:,0], labels[:, 0])
            precision_nightlights = torch.exp(-self.log_vars[3])
            loss_nightlights = precision_nightlights * loss_nightlights + self.log_vars[3]
            
            loss_income = criterion(outputs[:,1], labels[:, 1])
            precision_income = torch.exp(-self.log_vars[4])
            loss_income = precision_income * loss_income + self.log_vars[4]
            
            loss_roads = criterion(outputs[:,2], labels[:, 2])
            precision_roads = torch.exp(-self.log_vars[5])
            loss_roads = precision_roads * loss_roads + self.log_vars[5]
            
            loss_housing = criterion(outputs[:,3], labels[:, 3])
            precision_housing = torch.exp(-self.log_vars[6])
            loss_housing = precision_housing * loss_housing + self.log_vars[6]
            
            return loss_nightlights + loss_income + loss_roads + loss_housing

In [263]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = MultiTaskModel().to(device)

loss_ft = MultiTaskModelWrapper().to(device)

In [264]:
print(summary(model_ft, input_size=(32, 3, 224, 224), sampling="POP"))
print(summary(loss_ft))

Layer (type:depth-idx)                        Output Shape              Param #
MultiTaskModel                                --                        --
├─ModuleList: 1-1                             --                        --
├─ResNet: 1-2                                 [32, 512]                 --
│    └─Conv2d: 2-1                            [32, 64, 112, 112]        9,408
│    └─BatchNorm2d: 2-2                       [32, 64, 112, 112]        128
│    └─ReLU: 2-3                              [32, 64, 112, 112]        --
│    └─MaxPool2d: 2-4                         [32, 64, 56, 56]          --
│    └─Sequential: 2-5                        [32, 64, 56, 56]          --
│    │    └─BasicBlock: 3-1                   [32, 64, 56, 56]          73,984
│    │    └─BasicBlock: 3-2                   [32, 64, 56, 56]          73,984
│    └─Sequential: 2-6                        [32, 128, 28, 28]         --
│    │    └─BasicBlock: 3-3                   [32, 128, 28, 28]         230,144
│  

In [258]:
def train_model(
    model,
    loss_model,
    criterion,
    train_dataloaders,
    test_dataloaders,
    optimizer,
    scheduler,
    mean_UAR,
    std_UAR,
    mean_POP,
    std_POP,
    num_epochs=50,
    log_loc="./pytorch.logs",
    save_dir=Path(c.data_dir)/"int"/"deep_models",
):
    since = time.time()
    summary_writer = SummaryWriter(Path(log_loc)/"1234")
    global_step = 0
    
    preds = {}
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.debug("Using torch.device: {}".format(device))
    for epoch in range(num_epochs):
        logger.debug("Epoch {}/{}".format(epoch + 1, num_epochs))
        logger.debug("-" * 10)
        
        for phase in ["train", "test"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            all_labels = {sample : [] for sample in tasks}
            all_predictions = {sample : [] for sample in tasks}
            all_ids = {sample : [] for sample in tasks}

            counter = 0
            lr = optimizer.param_groups[0]["lr"]
            summary_writer.add_scalar(
                tag="learning_rate", scalar_value=lr, global_step=global_step
            )

            dataloaders = train_dataloaders if phase == "train" else test_dataloaders

            num_batches = len(dataloaders["UAR"])
            logger.debug("Total batches: {}".format(num_batches))
            end_time = time.time()
            debug_time = time.time()
            
            for data_UAR, data_POP in tqdm(zip(dataloaders["UAR"], dataloaders["POP"]), total=num_batches):
                for sample in ["UAR", "POP"]:
                    if sample == "UAR":
                        ids, inputs, labels = data_UAR
                    elif sample == "POP":
                        ids, inputs, labels = data_POP

                    counter += 1
                    global_step += 1

                    if sample == "UAR":
                        for i in range(len(tasks_UAR)):
                            all_labels[tasks_UAR[i]] += list(np.vstack(labels.numpy()[:,i]))
                            all_ids[tasks_UAR[i]] += list(ids)
                    elif sample == "POP":
                        for i in range(len(tasks_POP)):
                            all_labels[tasks_POP[i]] += list(np.vstack(labels.numpy()[:,i]))
                            all_ids[tasks_POP[i]] += list(ids)

                    inputs = inputs.float()
                    labels = labels.float()
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad()

                    with torch.set_grad_enabled(phase == "train"):
                        outputs = model.forward(inputs, sample)

                        if sample == "UAR":
                            for i in range(len(tasks_UAR)):
                                all_predictions[tasks_UAR[i]] += list(outputs.detach().cpu().numpy()[:,i])
                        elif sample == "POP":
                            for i in range(len(tasks_POP)):
                                all_predictions[tasks_POP[i]] += list(outputs.detach().cpu().numpy()[:,i])

                        loss = loss_model.forward(outputs, labels, criterion, sample)
                        if phase == "train":
                            loss.backward()
                            optimizer.step()
                            summary_writer.add_scalar(
                                tag="train_loss",
                                scalar_value=loss.item(),
                                global_step=global_step,
                            )
                        else:
                            summary_writer.add_scalar(
                                tag="val_loss",
                                scalar_value=loss.item(),
                                global_step=global_step,
                            )
                yeet = 100
                if counter % yeet == 0:
                    logger.debug("Time for {} batches: {}".format(yeet, time.time() - debug_time))
                    debug_time = time.time()
                
                # Testin some stuff here
                samp = "treecover"
                std = std_UAR
                mean = mean_UAR
                
                temp_labels = np.array(all_labels[samp])
                temp_pred = np.array(all_predictions[samp])
                
                temp_labels *= std[0]
                temp_labels += mean[0]
                
                temp_pred *= std[0]
                temp_labels += mean[0]
                
                r2_score = sklearn.metrics.r2_score(temp_labels, temp_predictions)
                
                logger.debug(
                    "Epoch {0} Phase {1} of {2} complete, Aggregate R2 Score of {3}: {4}".format(
                        epoch, phase, sample, samp, r2_score
                    )
                )
    
    time_elapsed = time.time() - since
    logger.debug(
        "Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60
        )
    )
    
    return model
            

In [259]:
def run(
    ids_train_UAR,
    Y_train_UAR,
    ids_test_UAR,
    Y_test_UAR,
    ids_train_POP,
    Y_train_POP,
    ids_test_POP,
    Y_test_POP,
    model,
    data_home_UAR,
    data_home_POP,
    loss,
    num_epochs=25,
    initial_lr=0.001,
    log_loc="./pytorch.logs",
    save_dir=Path(c.data_dir)/"int"/"deep_models",
    batch_size=32,
):
    mean_UAR = np.nanmean(Y_train_UAR, axis=0)
    std_UAR = np.nanstd(Y_train_UAR, axis=0)
    Y_train_UAR = (Y_train_UAR - mean_UAR) / std_UAR
    Y_test_UAR = (Y_test_UAR - mean_UAR) / std_UAR
    
    mean_POP = np.nanmean(Y_train_POP, axis=0)
    std_POP = np.nanstd(Y_train_POP, axis=0)
    Y_train_POP = (Y_train_POP - mean_POP) / std_POP
    Y_test_POP = (Y_test_POP - mean_POP) / std_POP
    
    train_dataloaders = {}
    test_dataloaders = {}
    
    train_dataloaders["UAR"] = get_dataloader(
        data_home_UAR,
        Y_train_UAR,
        ids_train_UAR,
        batch_size=batch_size)
    
    test_dataloaders["UAR"] = get_dataloader(
        data_home_UAR,
        Y_test_UAR,
        ids_test_UAR,
        batch_size=batch_size)
    
    train_dataloaders["POP"] = get_dataloader(
        data_home_POP,
        Y_train_POP,
        ids_train_POP,
        batch_size=batch_size)
    
    train_dataloaders["POP"] = get_dataloader(
        data_home_POP,
        Y_train_POP,
        ids_train_POP,
        batch_size=batch_size)
    
    if loss == "mse":
        criterion = torch.nn.MSELoss()
    else:
        criterion = torch.nn.L1Loss()
        
    loss_model = MultiTaskModelWrapper()
        
    optimizer_ft = torch.optim.SGD(list(model.parameters()) + list(loss_model.parameters()), lr=initial_lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_ft, milestones=[10], gamma=0.5)
    
    if torch.cuda.is_available():
        model.cuda()
    
    return train_model(
        model,
        loss_model,
        criterion,
        train_dataloaders,
        test_dataloaders,
        optimizer_ft,
        scheduler,
        mean_UAR,
        std_UAR,
        mean_POP,
        std_POP,
        num_epochs=num_epochs,
        log_loc=log_loc,
        save_dir=save_dir,
    )

In [260]:
trained_model = run(ids_train_UAR,
    Y_train_UAR,
    ids_test_UAR,
    Y_test_UAR,
    ids_train_POP,
    Y_train_POP,
    ids_test_POP,
    Y_test_POP,
    MultiTaskModel(),
    data_home_UAR,
    data_home_POP,
    'mse',
    num_epochs=1)

2022-05-24 11:30:09.106 | DEBUG    | __main__:train_model:23 - Using torch.device: cuda:0
2022-05-24 11:30:09.107 | DEBUG    | __main__:train_model:25 - Epoch 1/1
2022-05-24 11:30:09.108 | DEBUG    | __main__:train_model:26 - ----------
2022-05-24 11:30:09.113 | DEBUG    | __main__:train_model:46 - Total batches: 2500
0it [00:03, ?it/s]


UnboundLocalError: local variable 'loss_treecover' referenced before assignment

In [None]:
# try save params
model_param_path = code_dir + "/cs230/temp_trained_params.pt"
torch.save(trained_model.state_dict(), model_param_path)

# now try load
loaded_model = MultiTaskModel()
loaded_model.load_state_dict(torch.load(model_param_path))