In [1]:
from dataset import *
from train import *

In [2]:
import os
import pandas as pd

import collections

import sys

import torch
import torch.optim as optim

In [3]:
# Which EuroSAT archive the download cell works with: 'RGB' or 'ALL' (all bands).
EuroSat_Type = "ALL"

In [4]:
# Read the optional command-line arguments.
# sys.argv is already a list of strings; wrapping it in str() made args[1] and
# args[2] pick single characters out of the list's repr instead of arguments.
args = list(sys.argv)
# First argument, when one was given. target_task is assigned again in the next cell.
# NOTE(review): inside a notebook kernel sys.argv carries the kernel's own launch
# arguments, so this value is only meaningful when the code runs as a script.
target_task = args[1] if len(args) > 1 else None
# The algorithm is pinned to "bandit" whatever args[2] contains, exactly as the
# final assignment of the original cell did.
algorithm = "bandit"


In [5]:
# Task used as the target (alternative kept from the original: "Moldova") and
# the target-size setting recorded with the results.
target_task = 1
target_size = 160

In [6]:
from pathlib import Path
# Folder that receives the result files; created (with parents) if it is missing.
output_path = Path("derived_data")
output_path.mkdir(exist_ok=True, parents=True)

In [7]:
if EuroSat_Type == 'RGB':
  data_folder = '/content/sample_data/'
  #root = os.path.join(data_folder, '2750/')
  root = '2750/'
  download_ON = os.path.exists(root)

  if not download_ON:
    # This can be long...
    #os.chdir(data_folder)
    os.system('wget http://madm.dfki.de/files/sentinel/EuroSAT.zip') #Just RGB Bands
    !unzip EuroSAT.zip
    download_ON = True
elif EuroSat_Type == 'ALL':
    root = 'ds/images/remote_sensing/otherDatasets/sentinel_2/tif/'
    download_ON = os.path.exists(root)
    if not download_ON:
      os.system('wget http://madm.dfki.de/files/sentinel/EuroSATallBands.zip') #All bands
      !unzip EuroSATallBands.zip
      download_ON = True

In [8]:
# Per-sample metadata table; its columns are consumed by prepare_input_data.
geo_df = pd.read_csv(filepath_or_buffer="metadata_clustered.csv")

In [9]:
# Index the '.tif' files under `root`; each sample is read with `iloader` on access.
data = torchvision.datasets.DatasetFolder(root=root, loader=iloader, transform=None, extensions='tif')
# Class index of every sample, in dataset order. DatasetFolder keeps these in
# `targets`, and no target_transform is set here, so they equal item[1] for every
# item. Reading them directly avoids iterating the dataset, which would run the
# loader on every image only to throw the pixels away.
labels = list(data.targets)

In [12]:
def prepare_input_data(geo_df, target_task, group_by = "country",
                       labels = None,
                       train_size = 320, test_size = 320, target_size = 1600):
    """Describe the source/target tasks and split the target samples.

    Parameters
    ----------
    geo_df : pandas.DataFrame
        Per-sample metadata; must contain the ``group_by`` column.
    target_task
        Value of the ``group_by`` column that identifies the target task.
    group_by : str
        Name of the column whose values define the tasks.
    labels : sequence
        Class label of each sample. Required, despite the ``None`` default.
        It is indexed with row positions counted AFTER rows with a missing
        group have been dropped.
    train_size, test_size
        Forwarded to ``train_test_split``. ``train_size`` sizes both the
        training split and the validation split; ``test_size`` sizes the
        test split.
    target_size
        Accepted for backward compatibility but currently unused (the
        resampling step that consumed it is disabled).

    Returns
    -------
    dict
        Keys, in insertion order: ``source_task``, ``target_task``,
        ``data_dict`` (every column, rows with a missing group removed),
        ``idx_source`` / ``idx_target`` (row positions inside ``data_dict``),
        ``source_dict`` (source rows whose label also occurs in the target),
        ``idx_train``, ``idx_val``, ``idx_test`` (splits of ``idx_target``).

    Raises
    ------
    ValueError
        If ``labels`` is not supplied.
    """
    if labels is None:
        raise ValueError("labels must be provided: one class label per sample")

    geo_dict = geo_df.to_dict()

    # Work with column values by row position. For the default integer index this
    # matches label-based lookup, and it stays consistent for any other index.
    columns = {col: list(by_index.values()) for col, by_index in geo_dict.items()}

    # Rows whose group is missing (NaN) are excluded from everything below.
    kept_rows = [pos for pos, group in enumerate(columns[group_by]) if str(group) != "nan"]
    groups = {columns[group_by][pos] for pos in kept_rows}

    # create a dictionary for input data
    input_data = {
        "source_task": list(groups - {target_task}),
        "target_task": target_task
    }

    # all data, both source and target
    input_data["data_dict"] = {
        col: [values[pos] for pos in kept_rows] for col, values in columns.items()
    }

    # split indices for source and target
    # NOTE(review): these are positions within the filtered rows. They index
    # `labels` (and, in the caller, the dataset), which only lines up if no row
    # was dropped above or `labels` follows the filtered order; confirm.
    group_column = input_data["data_dict"][group_by]
    input_data["idx_source"] = [i for i, group in enumerate(group_column) if group != target_task]
    input_data["idx_target"] = [i for i, group in enumerate(group_column) if group == target_task]

    # Labels that occur in the target task; a set keeps the membership test cheap.
    target_labels = {labels[i] for i in input_data["idx_target"]}

    # For source data, keep only the rows whose label also occurs in the target.
    # The row selection is the same for every column, so compute it once.
    shared_label_rows = [i for i in input_data["idx_source"] if labels[i] in target_labels]
    input_data["source_dict"] = {
        col: [values[i] for i in shared_label_rows]
        for col, values in input_data["data_dict"].items()
    }

    # rewrite the source tasks, because some groups may have no label in common
    # with the target task
    input_data["source_task"] = list(set(input_data["source_dict"][group_by]))

    # split the target data into train / validation / test sets
    y_target = [labels[i] for i in input_data["idx_target"]]
    input_data["idx_train"], idx_rest, _, y_rest = train_test_split(
        input_data["idx_target"],
        y_target,
        train_size = train_size,
        random_state = 0, shuffle = True)

    # NOTE(review): the validation split is also sized by train_size; confirm
    # this is intended.
    input_data["idx_val"], input_data["idx_test"], _, _ = train_test_split(
        idx_rest,
        y_rest,
        train_size = train_size,
        test_size = test_size,
        random_state = 0, shuffle = True)

    return input_data


In [13]:
# Build the task description, with tasks defined by the "cluster" column.
input_data = prepare_input_data(
    geo_df,
    target_task,
    group_by="cluster",
    labels=labels,
    target_size=target_size,
)

In [14]:
# Seed every random number generator used by the run.
np.random.seed(0)
random.seed(0)
# torch.cuda.manual_seed only seeds the CUDA generator; torch.manual_seed is
# needed to seed torch's default (CPU) generator as well.
torch.manual_seed(0)
torch.cuda.manual_seed(0)

In [15]:
# prepare data ---

# Batching options shared by the three target loaders.
_target_loader_kwargs = {"batch_size": 16, "shuffle": True, "num_workers": 0}

# One loader per target split (validation, train, test), each over a Subset of
# `data` restricted to that split's indices.
target_val_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(data, input_data["idx_val"]),
    **_target_loader_kwargs
)
target_train_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(data, input_data["idx_train"]),
    **_target_loader_kwargs
)
target_test_loader = torch.utils.data.DataLoader(
    torch.utils.data.Subset(data, input_data["idx_test"]),
    **_target_loader_kwargs
)



# initialize hyperparameters ---

bandit_selects = [None]
# Give every source task its own list. dict.fromkeys(keys, [1]) would hand the
# SAME list object to all keys, so an in-place update for one task would show
# up under every task.
alpha = {task: [1] for task in input_data["source_task"]}
beta = {task: [1] for task in input_data["source_task"]}
pi = {task: [0] for task in input_data["source_task"]}


In [None]:
# Run bandit_selection with the settings below and unpack its six results.
net, bandit_selects, accs, alpha, beta, pi = bandit_selection(
    data,
    input_data,
    n_epochs=2,
    n_it=3,
    algorithm=algorithm,
    iter_samples=160,
    output_path=output_path,
)

  return 1/(1+np.exp(-z))


In [None]:
# Evaluate the network on the target test loader and write a one-row CSV with
# the result and the run settings into output_path.
test_performance = validation(net, target_test_loader)

test_acc_file = output_path / Path(
    "_".join([str(target_task), algorithm, str(target_size), "test_acc.csv"])
)
test_acc_df = pd.DataFrame(
    {
        "test_acc": test_performance,
        "algorithm": algorithm,
        "target_size": target_size,
        "target_task": target_task,
    },
    index=[0],
)
test_acc_df.to_csv(test_acc_file)