In [1]:
from dataset import *
from train import *

In [2]:
import collections
import os
import subprocess
import sys

import pandas as pd
import torch
import torch.optim as optim
import torch.optim as optim

In [3]:
# Set parameters
# EuroSAT variant to load: 'RGB' (RGB bands only) or 'ALL' (all bands).
EuroSat_Type = 'ALL'

In [4]:
args = str(sys.argv)
target_task = args[1]
algorithm = args[2]
algorithm = "bandit"


In [5]:
# Target country for this run; replaces any value set from the command line above
# (another option used before: "Lietuva").
target_task = "France"

In [6]:
from pathlib import Path

# Directory for checkpoints, evaluation CSVs and training logs (created if missing).
output_path = Path("derived_data")
output_path.mkdir(exist_ok = True, parents = True)

In [7]:
# Download URL and extracted image folder for each EuroSAT variant.
EUROSAT_SOURCES = {
    'RGB': ('http://madm.dfki.de/files/sentinel/EuroSAT.zip', '2750/'),  # Just RGB Bands
    'ALL': ('http://madm.dfki.de/files/sentinel/EuroSATallBands.zip',
            'ds/images/remote_sensing/otherDatasets/sentinel_2/tif/'),  # All bands
}
if EuroSat_Type not in EUROSAT_SOURCES:
    # Fail here rather than leaving `root` undefined until the dataset cell.
    raise ValueError(f"EuroSat_Type must be one of {sorted(EUROSAT_SOURCES)}, got {EuroSat_Type!r}")

url, root = EUROSAT_SOURCES[EuroSat_Type]
if not os.path.exists(root):
    # This can be long... check=True surfaces wget/unzip failures instead of ignoring them.
    archive = os.path.basename(url)
    if not os.path.exists(archive):
        subprocess.run(['wget', '-O', archive, url], check=True)
    # -o: overwrite without an interactive prompt; -q: keep the cell output short.
    subprocess.run(['unzip', '-o', '-q', archive], check=True)
    if not os.path.exists(root):
        raise FileNotFoundError(f"{archive} was extracted but {root!r} does not exist")
download_ON = True

In [8]:
# Per-image metadata table; prepare_input_data reads its "country" column.
geo_df = pd.read_csv(filepath_or_buffer="metadata.csv")

In [9]:
# Image dataset with one class per sub-folder of `root`; files are read with `iloader`.
data = torchvision.datasets.DatasetFolder(root=root, loader=iloader, transform=None, extensions='tif')
# DatasetFolder already stores every sample's class index in `.targets`. Reading it avoids
# decoding every image once just to collect labels, which iterating `data` did.
labels = list(data.targets)

In [10]:
def check_labels(input_data, labels):
    """
    Restrict every split to the classes that occur in all of them.

    Parameters
    ----------
    input_data : dict
        Must hold the index lists "idx_train", "idx_val", "idx_test" and "idx_source".
    labels : sequence
        Class label for each index used in those lists.

    Returns
    -------
    dict
        The same ``input_data`` object, with each index list filtered in place (order kept).
    """
    split_keys = ("idx_train", "idx_val", "idx_test", "idx_source")

    shared = None
    for key in split_keys:
        seen = {labels[idx] for idx in input_data[key]}
        shared = seen if shared is None else shared & seen

    for key in split_keys:
        input_data[key] = [idx for idx in input_data[key] if labels[idx] in shared]

    return input_data

In [17]:
def prepare_input_data(geo_df, target_task = "France", labels = None, train_size = 640, val_size = 160, test_size = 160):
    """
    Split the metadata into source countries and train/val/test splits of the target country.

    Parameters
    ----------
    geo_df : pandas.DataFrame
        Image metadata with a "country" column; rows whose country is missing are dropped.
    target_task : str
        Country whose images form the target train / validation / test splits.
    labels : sequence or None
        Class label per position (the same positions the "idx_*" lists hold).
        When None, the target rows are split by random sampling of fixed sizes.
        Otherwise they are split 70/15/15 with ``train_test_split`` and then restricted
        to classes present in every split via ``check_labels``.
    train_size, val_size, test_size : int
        Split sizes, used only when ``labels`` is None.

    Returns
    -------
    dict
        Keys "source_task", "target_task", "data_dict", "idx_source", "idx_target",
        "source_dict", "idx_train", "idx_val", "idx_test".
    """
    geo_dict = geo_df.to_dict()

    # Rows with a known country, in metadata order. Every "idx_*" list below is a position
    # into this filtered order.
    # NOTE(review): those positions are later used to index the image dataset and `labels`,
    # which assumes metadata rows align one-to-one with dataset order; TODO confirm.
    keep_rows = [row for row, country in geo_dict["country"].items() if str(country) != "nan"]
    countries = {geo_dict["country"][row] for row in keep_rows}

    input_data = {
        # Use `{target_task}`, not `set(target_task)` (a set of characters), so the target is
        # really excluded. Sorted so the source order is reproducible across interpreter runs.
        "source_task": sorted(countries - {target_task}),
        "target_task": target_task,
    }

    # Every metadata column, restricted to the kept rows.
    input_data["data_dict"] = {k: [geo_dict[k][row] for row in keep_rows] for k in geo_dict}

    # Split positions into source (any other country) and target.
    kept_countries = input_data["data_dict"]["country"]
    input_data["idx_source"] = [i for i, c in enumerate(kept_countries) if c != target_task]
    input_data["idx_target"] = [i for i, c in enumerate(kept_countries) if c == target_task]

    # Every metadata column, restricted to the source rows (aligned with idx_source).
    input_data["source_dict"] = {
        k: [input_data["data_dict"][k][i] for i in input_data["idx_source"]] for k in geo_dict
    }

    if labels is None:
        # Fixed-size random splits. Each split is drawn from what the previous ones left, so
        # train / test / val are disjoint (val used to be sampled from the test indices).
        input_data["idx_train"] = random.sample(input_data["idx_target"], train_size)
        taken = set(input_data["idx_train"])
        remaining = [i for i in input_data["idx_target"] if i not in taken]
        input_data["idx_test"] = random.sample(remaining, test_size)
        taken.update(input_data["idx_test"])
        remaining = [i for i in remaining if i not in taken]
        input_data["idx_val"] = random.sample(remaining, val_size)
    else:
        target_labels = [labels[i] for i in input_data["idx_target"]]
        input_data["idx_train"], idx_rest, _, rest_labels = train_test_split(
            input_data["idx_target"], target_labels, test_size = .3, random_state = 0, shuffle = True)
        input_data["idx_val"], input_data["idx_test"], _, _ = train_test_split(
            idx_rest, rest_labels, test_size = .5, random_state = 0, shuffle = True)
        input_data = check_labels(input_data, labels)
    return input_data


In [23]:
input_data = prepare_input_data(geo_df, target_task, labels = labels)

In [24]:
# Seed every RNG the notebook draws from.
# NOTE: prepare_input_data has already run above, so its `random`-based sampling (labels=None
# path) is not covered by these seeds. Move this cell before it for a fully seeded run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed seeds the CPU generator and all CUDA devices. torch.cuda.manual_seed
# alone left CPU-side torch randomness (e.g. DataLoader shuffling) unseeded.
torch.manual_seed(SEED)

In [25]:
# prepare data ---
# Notebook-level target loaders for interactive inspection. bandit_selection builds its own,
# so these do not feed the training run below.
target_val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(data, input_data["idx_val"]),
                                                batch_size = 16, shuffle = True, num_workers = 0)
target_train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(data, input_data["idx_train"]),
                                                  batch_size = 16, shuffle = True, num_workers = 0)
target_test_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(data, input_data["idx_test"]),
                                                 batch_size = 16, shuffle = True, num_workers = 0)

# initialize hyperparameters ---
bandit_selects = [None]
# One fresh list per source country. dict.fromkeys(keys, [1]) gives every key the *same* list
# object, so an in-place update for one country would appear under all of them.
alpha = {task: [1] for task in input_data["source_task"]}
beta = {task: [1] for task in input_data["source_task"]}
pi = {task: [0] for task in input_data["source_task"]}


In [26]:
def _save_bandit_outputs(output_path, input_data, algorithm, net, accs, train_log, alpha, beta, pi):
    """
    Write the checkpoint, accuracy history, training log and (bandit only) priors.

    Files go under ``output_path`` and are named "<target_task>_<algorithm>" followed by
    ".pt", "_evaluation.csv", "train_log.csv", "alpha.csv", "beta.csv" or "pi.csv".
    """
    output_path = Path(output_path)
    prefix = input_data["target_task"] + "_" + algorithm
    torch.save(net.state_dict(), output_path / (prefix + ".pt"))
    # NOTE(review): `accs` is passed for both data arguments, unchanged from the original.
    # Confirm what save_output's third parameter is meant to receive.
    save_output(output_path / (prefix + "_evaluation.csv"), accs, accs)
    log_df = pd.concat([pd.DataFrame(record) for record in train_log])
    log_df.to_csv(output_path / (prefix + "train_log.csv"))
    if algorithm == "bandit":
        pd.DataFrame.from_dict(alpha).to_csv(output_path / (prefix + "alpha.csv"))
        pd.DataFrame.from_dict(beta).to_csv(output_path / (prefix + "beta.csv"))
        pd.DataFrame.from_dict(pi).to_csv(output_path / (prefix + "pi.csv"))


def bandit_selection(data, input_data, n_epochs = 3, n_it = 2, algorithm = "bandit",iter_samples = 160,
                     lr = .01, milestones = milestones,
                     criteria = criteria, output_path = ".", n_init_epochs = 2):
    """
    Train on the target country, then repeatedly train on samples from source countries.

    Each of the ``n_it`` iterations picks ``iter_samples`` source images and trains on them
    for ``n_epochs`` epochs, then records the accuracy on the target validation split.
    With ``algorithm == "bandit"`` the images come from one country chosen by ``get_bandit``
    (sampled with replacement), and the accuracy history feeds ``update_hyper_para``.
    Otherwise they are sampled uniformly without replacement from all source images.
    The target test split is not touched, so it stays a clean hold-out set.

    Parameters
    ----------
    data : dataset indexed by the positions in ``input_data["idx_*"]``.
    input_data : dict as returned by ``prepare_input_data``.
    n_epochs : int
        Epochs passed to ``train`` in each source iteration.
    n_it : int
        Number of source-sampling iterations.
    algorithm : str
        "bandit" for bandit-driven source selection; any other value samples uniformly.
    iter_samples : int
        Source images drawn per iteration.
    lr : float
        Adam learning rate.
    milestones : list
        Milestones for ``MultiStepLR`` (gamma 0.1).
    criteria : loss passed through to ``train``.
    output_path : str, Path or None
        Output directory. Files are written at iterations 0, 10, 20, ... and after the last
        iteration. None disables saving.
    n_init_epochs : int
        Epochs of the initial target-only training.

    Returns
    -------
    tuple
        ``(net, bandit_selects, accs, alpha, beta, pi)``.
    """
    def subset_loader(indices):
        # NOTE(review): the saved traceback ("size 16" vs "size 4") suggests `train` concatenates
        # per-batch arrays assuming full batches of 16, while a split's last batch is smaller.
        # Fix in train.py (or use drop_last=True for training loaders); TODO confirm.
        return torch.utils.data.DataLoader(torch.utils.data.Subset(data, indices),
                                           batch_size = 16, shuffle = True, num_workers = 0)

    # prepare data ---
    target_train_loader = subset_loader(input_data["idx_train"])
    target_val_loader = subset_loader(input_data["idx_val"])

    # initialize hyperparameters ---
    # One fresh list per country; dict.fromkeys(keys, [1]) would share a single list object.
    train_log = []
    bandit_selects = [None]
    alpha = {task: [1] for task in input_data["source_task"]}
    beta = {task: [1] for task in input_data["source_task"]}
    pi = {task: [0] for task in input_data["source_task"]}

    # initialize model ---
    # Move the model to the GPU before building the optimizer, as PyTorch recommends.
    net = Load_model()
    if torch.cuda.is_available():
        net = net.cuda()
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones = milestones, gamma=0.1)

    net, val_acc, _, _ = train(net, target_train_loader, target_val_loader, criteria, optimizer, n_init_epochs, scheduler)

    print("Model initiated with acc ", val_acc[-1])
    accs = [val_acc[-1]]

    # train ---
    for t in range(n_it):
        if algorithm == "bandit":
            bandit_current, pi = get_bandit(input_data, alpha, beta, t, pi)
            bandit_selects.append(bandit_current)
            print("---", "At iteration ", t, ", source country is ", bandit_current, "-----\n")
            # Dataset positions of the chosen country's images. source_dict is aligned with
            # idx_source, and positions are what every other split (and the else branch) uses.
            candidates = [pos for pos, country in zip(input_data["idx_source"], input_data["source_dict"]["country"])
                          if country == bandit_current]
            current_id = random.choices(candidates, k = iter_samples)
        else:
            bandit_current = 0
            current_id = random.sample(input_data["idx_source"], k = iter_samples)

        # Train on the sampled source images (this loader previously wrapped idx_test).
        current_loader = subset_loader(current_id)
        net, val_acc, train_acc, train_losses = train(net, current_loader, target_val_loader, criteria, optimizer, n_epochs, scheduler)

        print("At iteration ", t, ", source country is ", bandit_current, ", acc is ", val_acc[-1])
        accs += [val_acc[-1]]

        # One log row per epoch of this iteration; every column must have the same length.
        train_log.append({"iter": [t] * len(val_acc),
                          "train_acc": train_acc.tolist(),
                          "val_acc": val_acc.tolist(),
                          "train_losses": train_losses.tolist()})

        if algorithm == "bandit":
            alpha, beta = update_hyper_para(alpha, beta, t, accs, bandit_current)

        # Save periodically and always after the final iteration.
        if output_path is not None and (t % 10 == 0 or t == n_it - 1):
            _save_bandit_outputs(output_path, input_data, algorithm, net, accs, train_log, alpha, beta, pi)
    return net, bandit_selects, accs, alpha, beta, pi


In [27]:
# Run the source-selection loop; the returned network is not kept in this cell.
_, bandit_selects, accs, alpha, beta, pi = bandit_selection(
    data,
    input_data,
    n_epochs = 1,
    n_it = 2,
    algorithm = algorithm,
    iter_samples = 160,
    output_path = output_path,
)

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 16 and the array at index 1 has size 4