In [None]:
import sys
import os
# Put the parent of the current working directory at the front of sys.path.
# This has to run before the project-local imports in later cells
# (project_config, utils, models, ml, experiment_configs), which are presumably
# resolved from that directory -- TODO confirm against the repo layout.
sys.path.insert(0, os.path.abspath('..'))
from os.path import expanduser

%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv
# Load variables from a .env file into the process environment before the
# storage client is created.
# NOTE(review): presumably this is where the GCP credential settings come from -- confirm.
load_dotenv()

from google.cloud import storage
from project_config import GCP_PROJECT_NAME, DATASET_JSON_PATH

# Storage client bound to the configured project; a later cell assigns it to
# GoogleCloudFileSystem.storage_client.
gcp_client = storage.Client(project=GCP_PROJECT_NAME)

In [None]:
import os
import torch

# Release unused cached CUDA memory left over from earlier work in this kernel.
# Optional allocator tweak against CUDA out-of-memory errors (currently disabled):
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
torch.cuda.empty_cache()

# Fixed PyTorch seed so that runs are repeatable.
torch.manual_seed(13)

In [None]:
from experiment_configs.configs import lora_config, satmae_large_config_lora_methodA

# Experiment configuration used by the cells below.
# `lora_config` is used exactly as imported above; the former
# `lora_config = lora_config` line was a no-op self-assignment and is dropped.
config = satmae_large_config_lora_methodA

In [None]:
from torch.utils.data import ConcatDataset
import json
from utils.rastervision_pipeline import observation_to_scene, scene_to_training_ds, scene_to_validation_ds, scene_to_inference_ds
from utils.data_management import observation_factory, characterize_dataset
import random

from utils.rastervision_pipeline import GoogleCloudFileSystem
# Hand the storage client created in the first cell to the project's filesystem class.
GoogleCloudFileSystem.storage_client = gcp_client

# Seed Python's `random` module (a later cell uses random.sample for the hold-out).
random.seed(13)

# DATASET_JSON_PATH is resolved relative to the parent of the current working directory.
root_dir = os.getcwd()
json_rel_path = '../' + DATASET_JSON_PATH
json_abs_path = os.path.join(root_dir, json_rel_path)

# Context manager so the file handle is closed as soon as the JSON is parsed.
with open(json_abs_path, 'r') as dataset_file:
    dataset_json = json.load(dataset_file)

# Build the observations once and derive both lists from the same sequence, so
# all_scenes[i] and cluster_ids[i] always describe the same observation.
observations = list(observation_factory(dataset_json))
all_scenes = [observation_to_scene(config, observation) for observation in observations]
cluster_ids = [observation.cluster_id for observation in observations]

In [None]:
import numpy as np
import random

# Hold out one randomly chosen scene per cluster by re-labelling it with a new
# cluster id that no existing cluster uses (largest id + 1).
# NOTE: `cluster_ids` is rewritten in place, so re-running this cell without
# rebuilding `cluster_ids` first changes the split.
unique_cluster_ids = np.unique(cluster_ids)
val_cluster_id = unique_cluster_ids.max() + 1
for cid in unique_cluster_ids:
    scene_idx = [pos for pos, scene_cid in enumerate(cluster_ids) if scene_cid == cid]
    val_idx = random.sample(scene_idx, 1)[0]
    cluster_ids[val_idx] = val_cluster_id

In [None]:
# Scenes tagged with `val_cluster_id` become validation datasets, all others
# become training datasets. Validation uses a stride of half the tile size.
scene_cluster_pairs = list(zip(all_scenes, cluster_ids))
val_stride = int(config.tile_size/2)

training_datasets = [
    scene_to_training_ds(config, scene)
    for scene, cid in scene_cluster_pairs
    if cid != val_cluster_id
]
validation_datasets = [
    scene_to_inference_ds(config, scene, full_image=False, stride=val_stride)
    for scene, cid in scene_cluster_pairs
    if cid == val_cluster_id
]

In [None]:
from torch.utils.data import ConcatDataset

# Merge the per-scene datasets into one training and one validation dataset.
train_dataset_merged = ConcatDataset(training_datasets)
val_dataset_merged = ConcatDataset(validation_datasets)

In [None]:
from models.model_factory import model_factory, print_trainable_parameters
from ml.optimizer_factory import optimizer_factory
from ml.learner_factory import learner_factory
from experiment_configs.schemas import ThreeClassVariants

# The raster shape of the first training scene must have exactly three entries;
# the third one is taken as the channel count.
_, _, n_channels = training_datasets[0].scene.raster_source.shape

model = model_factory(config, n_channels=n_channels, config_lora=lora_config)
optimizer = optimizer_factory(config, model)

# For development and debugging, pass single datasets (e.g. training_datasets[0]
# and training_datasets[1]) as train_ds / valid_ds to speed things up.
output_dir = expanduser("~/sandmining-watch/out/OUTPUT_DIR")
learner = learner_factory(
    config=config,
    model=model,
    optimizer=optimizer,
    train_ds=train_dataset_merged,
    valid_ds=val_dataset_merged,
    output_dir=output_dir,
)

print_trainable_parameters(learner.model)

In [None]:
# Experiment tracking is switched off here; uncomment the next line to start a
# wandb run before training (presumably Weights & Biases -- confirm in the learner).
# learner.initialize_wandb_run()
# Train for a single epoch.
learner.train(epochs=1)

In [None]:
from utils.rastervision_pipeline import scene_to_training_ds, scene_to_validation_ds
from torch.utils.data import ConcatDataset
from sklearn.model_selection import GroupKFold, LeavePGroupsOut
import numpy as np

from models.model_factory import model_factory, print_trainable_parameters
from ml.optimizer_factory import optimizer_factory
from ml.learner_factory import learner_factory
from experiment_configs.schemas import ThreeClassVariants

import wandb
import gc



class CrossValidator:
    """Trains one model per cross-validation split, with splits built from scene cluster ids.

    Attributes set by the constructor:
        scenes: the scenes, indexed by the integer positions stored in `splits`.
        cluster_ids: numpy array holding one cluster id per scene (same order as `scenes`).
        splits: list of (train_indices, val_indices) pairs of scene positions.
        split_groups: list of (train_cluster_ids, val_cluster_ids) pairs, parallel to `splits`.
        num_splits: number of entries in `splits`.
    """

    def __init__(self, scenes, cluster_ids, split_groups=None, num_splits=None, size_validation_group=None) -> None:
        """
        Build the splits from exactly one of the three strategies below.

        split_groups is for manually assigning splits. Should be a list (number of splits in length)
        containing pairs of training and validation cluster ids, i.e. [([1, 2], [3]), ([4, 5], [6])].

        num_splits is the number of splits for a GroupKFold over the cluster ids.

        size_validation_group is the number of clusters held out for validation in each split
        (LeavePGroupsOut); the remaining clusters form the training set.

        Raises AssertionError if zero or more than one strategy is given, or if a cluster id in
        split_groups does not occur in cluster_ids.
        """
        # Count the supplied options explicitly: a chained XOR (a ^ b ^ c) is also
        # true when all three are set, which would let an ambiguous call through.
        num_strategies = sum(
            option is not None for option in (split_groups, num_splits, size_validation_group)
        )
        assert num_strategies == 1, "Only one of splits, num_splits, size_validation_group should not be None"

        self.scenes = scenes
        self.cluster_ids = np.array(cluster_ids)
        self.splits = None
        self.split_groups = split_groups
        self.num_splits = num_splits

        if self.split_groups is not None:
            # Manual assignment: translate each list of cluster ids into scene positions.
            self.splits = [
                (self._scene_indices(split[0], "Training"), self._scene_indices(split[1], "Validation"))
                for split in self.split_groups
            ]
        else:
            if self.num_splits is not None:
                splitter = GroupKFold(self.num_splits)
            else:   # size_validation_group is not None
                splitter = LeavePGroupsOut(n_groups=size_validation_group)
            self.splits = list(splitter.split(np.arange(len(self.cluster_ids)), groups=self.cluster_ids))
            # Record which cluster ids ended up on each side of every split.
            self.split_groups = [
                (np.unique(self.cluster_ids[train_idx]), np.unique(self.cluster_ids[val_idx]))
                for train_idx, val_idx in self.splits
            ]
        self.num_splits = len(self.splits)

    def _scene_indices(self, cids, role):
        """Return the positions of all scenes whose cluster id is in `cids`, in the order `cids` lists them.

        `role` ("Training" / "Validation") only appears in the assertion message raised
        when a cluster id does not occur in `self.cluster_ids`.
        """
        indices = []
        for cid in cids:
            assert cid in self.cluster_ids, f"{role} Cluster {cid} not in the available clusters"
            indices.extend(np.flatnonzero(self.cluster_ids == cid).tolist())
        # Explicit integer dtype so an empty cluster list still yields an index array.
        return np.array(indices, dtype=int)

    def _train(self, model_config, lora_config, train_split, val_split, num_epochs, wandb_group_name, run_name, model_weights_output_folder):
        """Build datasets, model, optimizer and learner for one split and train it.

        train_split / val_split are iterables of positions into `self.scenes`.
        Weights go to `<model_weights_output_folder>/<run_name>` (with `~` expanded), or
        output_dir is None when no folder is given. `wandb_group_name` is only referenced
        by the commented-out wandb call. Everything is configured from `model_config`
        rather than from a notebook-level global.
        """
        import os  # local import: only needed to join the output path

        train_ds = [scene_to_training_ds(model_config, self.scenes[sid]) for sid in train_split]
        valid_ds = [
            scene_to_inference_ds(model_config, self.scenes[sid], full_image=False, stride=int(model_config.tile_size/2))
            for sid in val_split
        ]
        train_ds_merged = ConcatDataset(train_ds)
        valid_ds_merged = ConcatDataset(valid_ds)

        torch.cuda.empty_cache()

        model = optimizer = learner = None
        try:
            _, _, n_channels = train_ds[0].scene.raster_source.shape
            model = model_factory(
                model_config,
                n_channels=n_channels,
                config_lora=lora_config
            )

            # os.path.join gives the same path as plain concatenation when the folder ends
            # with a separator, and still produces a valid path when it does not.
            if model_weights_output_folder is not None:
                output_dir = expanduser(os.path.join(model_weights_output_folder, run_name))
            else:
                output_dir = None

            optimizer = optimizer_factory(model_config, model)
            learner = learner_factory(
                config=model_config,
                model=model,
                optimizer=optimizer,
                train_ds=train_ds_merged,  # for development and debugging, use a single dataset to speed up
                valid_ds=valid_ds_merged,  # for development and debugging, use a single dataset to speed up
                output_dir=output_dir,
            )

            print_trainable_parameters(learner.model)
            # learner.initialize_wandb_run(run_name=run_name, group=wandb_group_name)
            learner.train(epochs=num_epochs)
            # wandb.finish()
        finally:
            # Drop the references and clear caches even when training raised, so a
            # failed split does not keep its model alive while the next one runs.
            del learner, optimizer, model
            del train_ds, valid_ds, train_ds_merged, valid_ds_merged
            gc.collect()
            torch.cuda.empty_cache()

    def cross_val(self, model_config, num_epochs, lora_config=None, wandb_group_name=None, model_weights_output_folder=None):
        """Train one model per entry in `self.splits`, one after the other.

        Each run is named "val_cids_" followed by that split's validation cluster ids
        joined with "_"; the name is taken from the matching entry of `self.split_groups`.
        """
        # The run name comes from split_groups[i]; if the two lists were edited
        # independently the names would silently describe the wrong split.
        assert len(self.splits) == len(self.split_groups), "splits and split_groups must have the same length"
        for i, (train_split, valid_split) in enumerate(self.splits):
            val_cids = self.split_groups[i][1]
            run_name = "val_cids_" + "_".join(str(cid) for cid in val_cids)
            self._train(model_config,
                        lora_config,
                        train_split,
                        valid_split,
                        num_epochs,
                        wandb_group_name,
                        run_name,
                        model_weights_output_folder)

In [None]:
# size_validation_group=1 selects the LeavePGroupsOut strategy with one cluster held out per split.
# NOTE(review): if the validation hold-out cell above was run, `cluster_ids` was modified in
# place and still contains the extra `val_cluster_id` group here -- confirm this is intended.
cv = CrossValidator(all_scenes, cluster_ids, size_validation_group=1)

In [None]:
# Drop the cluster-id pairs of the first two splits (presumably already trained -- confirm).
# NOTE: not idempotent -- every re-run drops two more entries. cross_val looks up
# split_groups[i] for the i-th entry of cv.splits, so this must stay in step with the
# cv.splits slice in the next cell. cv.num_splits is not updated by this.
cv.split_groups = cv.split_groups[2:]
cv.split_groups

In [None]:
# Drop the first two splits so that cross_val starts at the third one.
# NOTE: not idempotent -- every re-run drops two more entries. Must be applied the same
# number of times as the cv.split_groups slice in the previous cell, otherwise run names
# no longer match the splits being trained.
cv.splits = cv.splits[2:]
cv.splits

In [None]:
# One training run per remaining split, one epoch each.
cv.cross_val(
    model_config=config,
    num_epochs=1,
    lora_config=lora_config,
    wandb_group_name="SatMAE Large LoRA Method A",
    model_weights_output_folder="~/sandmining-watch/out/CrossValTest/",
)

In [None]:
# Rebuild the validation datasets of the first split currently held in `cv.splits`.
first_val_split = cv.splits[0][1]
valid_ds = [
    scene_to_inference_ds(config, cv.scenes[sid], full_image=False, stride=int(config.tile_size/2))
    for sid in first_val_split
]

In [None]:
# Concatenate the per-scene validation datasets; the bare name on the last line
# makes the notebook display the resulting object.
val_ds_merged = ConcatDataset(valid_ds)
val_ds_merged

In [None]:
from torch.utils.data import DataLoader

In [None]:
# DataLoader over the merged validation datasets; the next cell iterates it to
# inspect batch shapes.
num_workers = 10
val_dl = DataLoader(
    val_ds_merged,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=num_workers,
    # persistent_workers may only be enabled when worker processes are used
    persistent_workers=num_workers > 0,
    # each worker switches torch.multiprocessing to the "file_system" sharing strategy
    worker_init_fn=lambda x: torch.multiprocessing.set_sharing_strategy("file_system") if num_workers > 0 else None,
    pin_memory=True)

In [None]:
# Iterate the loader once and print the shapes of both elements of every batch.
for i, (x, y) in enumerate(val_dl):
    print(f"{i}_x: {x.shape}", f"{i}_y: {y.shape}", sep="\t\t")