# Prepare source domain dataloaders and experts

In [1]:
# Imports and run configuration; the Shifts weather code is added to sys.path below
import sys
import pandas as pd
import torch
from torch import nn
import numpy as np
from torch import optim
import tqdm
LOAD_EXPERTS = True
DEVICE = 'cpu'

sys.path.append('../shifts/weather/')

In [2]:


# Load the canonical Shifts weather partitions. NOTE: the directory name is
# spelled 'paritioned' in these paths; keep it in sync with the folder on disk.
df_train, df_dev_in, df_dev_out = (
    pd.read_csv(csv_path)
    for csv_path in (
        'canonical-paritioned-dataset/shifts_canonical_train.csv',
        'canonical-paritioned-dataset/shifts_canonical_dev_in.csv',
        'canonical-paritioned-dataset/shifts_canonical_dev_out.csv',
    )
)
# Concatenate the dev_in and dev_out splits; index labels are kept from the
# source frames (pd.concat default), so they may repeat.
df_dev = pd.concat([df_dev_in, df_dev_out])

# Climate domains that act as training source domains.
domains_train = df_train['climate'].unique()

In [4]:
# Distinct climate domains present in the training split.
df_train['climate'].unique()

array(['dry', 'mild temperate', 'tropical'], dtype=object)

In [3]:
 categorical_cols = [
                  'cmc_available',
                  'gfs_available',
                  'gfs_soil_temperature_available',
                  'wrf_available'
            ]

In [10]:
# Does the training frame contain any missing value at all?
df_train.isna().any().any()

True

In [6]:
# Restrict to the 'dry' domain's feature columns and check whether NaNs
# remain once the categorical availability flags are excluded.
d = df_train.loc[df_train['climate'] == 'dry'].iloc[:, 6:].copy()
d.drop(columns=categorical_cols).isna().any().any()

True

In [7]:
import torch

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over the rows of a Shifts weather DataFrame.

    Each item is ``(X, y, metadata)`` where ``X`` holds the feature columns
    (every column from position 6 onward, as a tensor), ``y`` is the row's
    ``fact_temperature`` value (0-dim tensor) and ``metadata`` is
    ``{'cat_cols': [...]}`` listing the categorical column names.

    NOTE(review): this assumes the first six CSV columns are identifiers /
    targets and include ``fact_temperature`` (so the target is not leaked
    into ``X``) -- verify against the dataset schema.

    Args:
        df: source DataFrame; must have ``climate`` and ``fact_temperature``
            columns.
        climate: climate domain to keep (rows where ``df.climate == climate``),
            or ``None`` to keep every row.
    """

    # Availability-flag columns reported to consumers via ``metadata``.
    CATEGORICAL_COLS = (
        'cmc_available',
        'gfs_available',
        'gfs_soil_temperature_available',
        'wrf_available',
    )

    def __init__(self, df, climate):
        rows = df if climate is None else df[df.climate == climate]
        self.X_source_domain = rows.iloc[:, 6:].copy()
        self.y_source_domain = rows['fact_temperature'].copy()
        self.climate = climate

        # Cache plain NumPy arrays once so __getitem__ is a cheap positional
        # lookup instead of materialising a pandas row Series per sample.
        # These caches are NOT refreshed if X_source_domain / y_source_domain
        # are edited after construction. Feature columns may contain NaN;
        # values are passed through unchanged.
        self._X_values = self.X_source_domain.to_numpy()
        self._y_values = self.y_source_domain.to_numpy()

        assert len(self.X_source_domain) == len(self.y_source_domain)

    def __len__(self):
        return len(self.y_source_domain)

    def __getitem__(self, index):
        # Positional indexing, so duplicate index labels (e.g. after
        # pd.concat) are harmless.
        X = torch.tensor(self._X_values[index])
        y = torch.tensor(self._y_values[index])
        # Fresh list per item so callers cannot mutate shared state.
        metadata = {'cat_cols': list(self.CATEGORICAL_COLS)}
        return X, y, metadata

### Create source domain loaders

In [8]:
batch_size = 64
# One training loader per source climate domain. Training batches are
# shuffled so each epoch visits a domain's rows in a random order rather than
# the fixed CSV row order; evaluation order does not matter for validation.
source_train_domains_loaders = {
    climate: torch.utils.data.DataLoader(
        Dataset(df_train, climate), batch_size=batch_size, shuffle=True
    )
    for climate in domains_train
}
val_loader = torch.utils.data.DataLoader(Dataset(df_dev, climate=None), batch_size=batch_size)

### Train or load experts

In [9]:
# Maps each training climate to its expert model.
source_domains_experts = {}
def initializeExperts(in_features=123):
    """Register a fresh linear regression expert for every training climate.

    Args:
        in_features: width of the feature vector fed to each expert.
            NOTE(review): 123 presumably matches the number of columns the
            Dataset takes from position 6 onward -- confirm against the CSV.
    """
    for climate in domains_train:
        # Build a NEW model per domain: assigning one shared instance would
        # make every expert the same parameters, so training any of them
        # would train (and overwrite) all of them.
        source_domains_experts[climate] = nn.Sequential(nn.Linear(in_features, 1)).to(DEVICE)
initializeExperts()

In [10]:
# Optimisation settings shared by every source-domain expert: number of
# epochs, Adam learning rate, Adam weight decay ('l2') and the per-epoch
# ExponentialLR factor ('decayRate').
experts_train_config = dict(
    epochs=100,
    lr=0.005,
    l2=0,
    decayRate=0.96,
)

In [11]:
# Debug helper: bind the globals `climate` and `model` to the first entry of
# source_domains_experts (loop variables stay bound after `break`).
# If the dict is empty, neither name is (re)bound.
for climate, model in source_domains_experts.items():
    break

In [12]:

def get_single_expert_prediction(batch, expert):
    """Return ``expert``'s output for ``batch``.

    The module is called (``expert(batch)``) rather than ``expert.forward``
    so that registered forward hooks run as usual.
    """
    return expert(batch)

def get_all_experts_predictions(batch, domain, experts, ):
    """Predictions of every expert on ``batch``, stacked along dim 0 (TODO).

    Intended to return ``torch.stack`` of the per-expert predictions for the
    experts in ``experts`` (a climate -> model mapping).

    Raises:
        NotImplementedError: always, including for an empty ``experts``
            mapping (previously that case fell through to ``torch.stack([])``
            and failed with an unrelated RuntimeError).
    """
    # TODO(review): decide how `domain` should be used (e.g. whether the
    # expert trained on `domain` is excluded) before implementing.
    raise NotImplementedError("get_all_experts_predictions is not implemented yet")

In [13]:
def get_validation_metrics(model, data_loader, domain=None):
    """Compute validation metrics for ``model`` over ``data_loader`` (stub).

    Placeholder: always raises ``NotImplementedError``; ``domain`` is
    currently unused.
    """
    raise NotImplementedError

In [14]:
def train_expert(model, train_loader, val_loader, experts_train_config = experts_train_config, domain=None):
    """Train one source-domain expert in place by minimising mean squared error.

    Args:
        model: ``nn.Module`` to train; assumed to already live on ``DEVICE``
            and to output one value per sample.
        train_loader: sized iterable of ``(X, y_true, metadata)`` batches,
            e.g. a ``DataLoader`` over ``Dataset``.
        val_loader: currently unused (validation not implemented yet).
        experts_train_config: dict with ``epochs``, ``lr``, ``l2`` (Adam
            weight decay) and ``decayRate`` (per-epoch ExponentialLR gamma).
            The default is the module-level config object bound when this
            function was defined.
        domain: currently unused.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=experts_train_config['lr'], weight_decay=experts_train_config['l2'])
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=experts_train_config['decayRate'])

    tot = len(train_loader)
    # Report roughly twice per epoch; max() avoids a modulo-by-zero when the
    # loader has a single batch.
    log_every = max(tot // 2, 1)

    i = 0
    window_losses = []

    for epoch in range(experts_train_config['epochs']):
        print(f"Epoch:{epoch} || Total:{tot}")
        model.train()

        for X, y_true, metadata in train_loader:
            X = X.to(device=DEVICE, dtype=torch.float32)
            # Match the prediction's dtype and shape: a float64 target breaks
            # backward against float32 outputs, and a (B,) target against a
            # (B, 1) prediction would silently broadcast the loss to (B, B).
            y_true = y_true.to(device=DEVICE, dtype=torch.float32)
            pred = model(X).reshape(y_true.shape)

            loss = criterion(pred, y_true)
            optimizer.zero_grad()  # before backward, so no stale gradients leak in
            loss.backward()
            optimizer.step()
            # MSELoss already averages over the batch; log it unscaled.
            window_losses.append(loss.item())

            if i % log_every == 0 and i != 0:
                print("Iter: {} || Loss: {:.4f} ".format(i, np.mean(window_losses)))
                window_losses = []
                # TODO: evaluate on val_loader and checkpoint the best model
                # once validation metrics are available.

            i += 1
        scheduler.step()

In [15]:
def load_expert(model):
    """Load trained weights into ``model`` (stub).

    Placeholder: always raises ``NotImplementedError``. NOTE(review): the
    signature receives no path or domain, so a real implementation needs
    some way to know which checkpoint to read.
    """
    raise NotImplementedError

In [9]:

# Either load previously trained experts or train one expert per source domain.
if LOAD_EXPERTS:
    print('Loading experts...')
    dir_path = 'trained_experts'  # NOTE(review): not yet passed to load_expert
    # Load into the expert models already registered in
    # source_domains_experts. Do not reset the dict here: the lookup below
    # would then always raise KeyError.
    for climate in domains_train:
        load_expert(source_domains_experts[climate])
else:
    print('Training experts...')
    print(experts_train_config)

    for climate in domains_train:
        train_loader = source_train_domains_loaders[climate]
        train_expert(source_domains_experts[climate], train_loader, val_loader)
        