In [None]:
# Progress marker emitted before the (slow) package imports below.
print("Script begins the run")
import argparse
import logging
import logging.config
import os
import sys
from argparse import ArgumentParser

import torch
from omegaconf import OmegaConf

from lib import config_utils, data_utils, utils
from lib.formatter import RawFormatter
from lib.logger import prepare_logger

print("Packages read")

# Command-line interface of the training script.
parser = ArgumentParser(
    description='U-TILISE: A Sequence-to-sequence Model for Cloud Removal in Optical Satellite Time Series (Training)',
    formatter_class=RawFormatter
)
# `nargs='?'` makes the positional argument optional; without it argparse always requires a value and the declared
# default would never be used.
parser.add_argument(
    'config_file', type=str, nargs='?',
    default="./configs/config_sen12mscrts_train.yaml",
    help='yaml configuration file to augment/overwrite the settings in configs/default_sen12.yaml'
)
parser.add_argument(
    '--save_dir', type=str,
    default="./results",
    help='Path to the directory where models and logs should be saved'
)
# Only the two dataset names handled further below are accepted, so that a typo is reported here instead of
# surfacing later as an undefined data loader.
parser.add_argument(
    '--dataset_name', default="allclear", type=str, choices=("allclear", "original"),
    help='Dataset used to build the data loaders: "allclear" or "original"'
)
parser.add_argument('--wandb', action='store_true', default=False, help='Use Weights & Biases instead of TensorBoard')
parser.add_argument('--wandb_project', type=str, default='utilise', help='Wandb project name')

# Unknown arguments are tolerated and discarded rather than treated as an error.
args, _ = parser.parse_known_args()

# Print a banner: the program name underlined with '=' characters
prog_name = 'U-TILISE: A Sequence-to-sequence Model for Cloud Removal in Optical Satellite Time Series (Training)'
print(f'\n{prog_name}\n{"=" * len(prog_name)}\n')

if not os.path.exists(args.config_file):
    raise FileNotFoundError(f'ERROR: Cannot find the yaml configuration file: {args.config_file}')

# Import the user configuration file
cfg_custom = config_utils.read_config(args.config_file)

if not cfg_custom:
    # A falsy result means there is nothing to train with. Exit with status 1 as before, but say why instead of
    # terminating silently (a string passed to sys.exit is printed to stderr and yields exit status 1).
    sys.exit(f'ERROR: The yaml configuration file could not be read or is empty: {args.config_file}')

# Augment/overwrite the default parameter settings with the runtime arguments given by the user
cfg_default = config_utils.read_config('configs/default_sen12.yaml')
config = OmegaConf.merge(cfg_default, cfg_custom)
# The command-line value (or its default) always replaces the output directory given in the yaml files
config.output.output_directory = args.save_dir

if args.wandb:
    config.wandb = OmegaConf.create()
    config.wandb.project = args.wandb_project

# Create the output directory. The name of the output directory is a combination of the current date, time, and an
# optional suffix.
config.output.experiment_folder = utils.create_output_directory(config)

# Set up the logger; a log file is only written if an experiment folder is set
log_file = os.path.join(config.output.experiment_folder, 'run.log') if config.output.experiment_folder else None
logger = prepare_logger('root_logger', level=logging.INFO, log_to_console=True, log_file=log_file)

# Print runtime arguments to the console
logger.info('Configuration file: %s', args.config_file)
logger.info('\nSettings\n--------\n')
config_utils.print_config(config, logger=logger)

if config.misc.random_seed is not None:
    utils.set_seed(config.misc.random_seed)

In [None]:
# Build the training and validation data loaders for the dataset selected on the command line
if args.dataset_name == "allclear":

    from dataset_wrapper import get_loader
    # NOTE(review): both loaders are created by the identical call; confirm that validating on the same loader
    # configuration as training is intended.
    train_loader = get_loader(config)
    val_loader = get_loader(config)

elif args.dataset_name == "original":
    config_file_train = 'configs/demo.yaml'
    config_file_test = 'configs/config_earthnet2021_test_simulation.yaml'

    def get_dataloader_testdata(
        config_file_train: str,
        config_file_test: str,
        run_mode: str = 'test'
    ) -> torch.utils.data.DataLoader:
        """
        Build a data loader from a training configuration overlaid with test-specific data settings.

        The `data` section (and, if present, the `mask` section) of the test configuration overwrites the
        corresponding entries of the training configuration, and `misc.run_mode` is set to `run_mode`.

        Args:
            config_file_train:  str, path to the configuration file used during training.
            config_file_test:   str, path to the configuration file with the test-specific settings.
            run_mode:           str, value stored in `config.misc.run_mode` and passed as dataset phase.

        Returns:
            torch.utils.data.DataLoader with batch size 1, no shuffling, 8 workers, and no dropped batches.

        Raises:
            FileNotFoundError: if one of the two configuration files does not exist.
        """

        if not os.path.isfile(config_file_train):
            raise FileNotFoundError(f'Cannot find the configuration file used during training: {config_file_train}\n')

        if not os.path.isfile(config_file_test):
            raise FileNotFoundError(f'Cannot find the test configuration file: {config_file_test}\n')

        # Read the configuration file used during training
        config = config_utils.read_config(config_file_train)

        # Merge generic data settings (used during training) with test-specific data settings
        config_testdata = config_utils.read_config(config_file_test)
        config.data.update(config_testdata.data)
        if 'mask' in config_testdata:
            config.mask.update(config_testdata.mask)
        config.misc.run_mode = run_mode

        # Get the data loader
        dset = data_utils.get_dataset(config, phase=run_mode)
        dataloader = torch.utils.data.DataLoader(
            dataset=dset, batch_size=1, shuffle=False, num_workers=8, drop_last=False
        )

        return dataloader

    # NOTE(review): both loaders use the same files and the default run mode 'test'; confirm this is intended.
    train_loader = get_dataloader_testdata(config_file_train, config_file_test)
    val_loader = get_dataloader_testdata(config_file_train, config_file_test)

else:
    # Fail here with a clear message instead of a NameError on `train_loader` further below
    raise ValueError(
        f"Unknown dataset name: {args.dataset_name!r}. Choose among ['allclear', 'original']."
    )

# ------------------------------------------------- Data loaders ------------------------------------------------- #
# logger.info('\nInitialize data loader (training set)...')
# train_loader = data_utils.get_dataloader(
#     config, phase='train', pin_memory=config.misc.pin_memory, drop_last=True, logger=logger
# )
# logger.info('Initialize data loader (validation set)...\n')
# val_loader = data_utils.get_dataloader(
#     config, phase='val', pin_memory=config.misc.pin_memory, drop_last=False, logger=logger
# )

# Report the size of both datasets and whether the training dataset flags variable sequence lengths
logger.info('Number of training samples: %d', len(train_loader.dataset))
logger.info('Number of validation samples: %d', len(val_loader.dataset))
logger.info('Variable sequence lengths: %r\n', train_loader.dataset.variable_seq_length)

# ----------------------------------------- Prepare the output directory ----------------------------------------- #
# Creates `<experiment_folder>/checkpoints` and dumps the runtime configuration to `<experiment_folder>/config.yaml`.
logger.info('\nPrepare output folders and files\n--------------------------------\n')

# Save the path of the checkpoint directory
# (stored in the config before the config is written below, so the written config includes it)
config.output.checkpoint_dir = os.path.join(config.output.experiment_folder, 'checkpoints')
# `exist_ok=True`: an already existing checkpoint directory is not an error
os.makedirs(config.output.checkpoint_dir, exist_ok=True)
logger.info('Model weights will be stored in: %s\n', config.output.checkpoint_dir)

# Write the runtime configuration to file
config_file = os.path.join(config.output.experiment_folder, 'config.yaml')
config_utils.write_config(config, config_file)

# ----------------------------------------------- Define the model ----------------------------------------------- #
# Builds the model selected by `config.method.model_type` and records its parameters in the experiment folder.
logger.info('\nModel Architecture\n------------------\n')
logger.info('Architecture: %s', config.method.model_type)

# The model input dimension is taken from the training dataset's `num_channels` attribute
input_dim = train_loader.dataset.num_channels
# `get_model` returns the model together with the arguments used to build it (written to file below)
model, args_model = utils.get_model(config, input_dim, logger)
logger.info('Number of trainable parameters: %d\n', utils.count_model_parameters(model))

# Log model parameters to file
# (`config_file` is reused here; it now points to model_config.yaml, keyed by the model type)
config_file = os.path.join(config.output.experiment_folder, 'model_config.yaml')
config_utils.write_config(OmegaConf.create({config.method.model_type: args_model}), config_file)

# Write model architecture to txt file
# (optional, controlled by `config.output.plot_model_txt`)
if config.output.plot_model_txt:
    file = os.path.join(config.output.experiment_folder, 'model_parameters.txt')
    logger.info('Writing model architecture to file: %s\n', file)
    # Batch size, sequence length, input dimension, and image size are passed along with the model;
    # presumably they describe the input used to summarize the model -- TODO confirm against the helper.
    utils.write_model_structure_to_file(
        file, model, config.training_settings.batch_size, train_loader.dataset.seq_length, input_dim,
        train_loader.dataset.image_size
    )

# --------------------------------------------------- Training --------------------------------------------------- #
# Logs the software versions, creates optimizer and scheduler, re-seeds, and runs the training.
logger.info('\nPrepare training\n----------------\n')
logger.info('Python version: %s', sys.version)
logger.info('Torch version: %s', torch.__version__)
logger.info('CUDA version: %s\n', torch.version.cuda)

# Get optimizer and learning rate scheduler
optimizer = utils.get_optimizer(config, model, logger)
scheduler = utils.get_scheduler(config, optimizer, logger)

# The seed was already set right after the configuration was loaded; it is set again here, after the data
# loaders, model, and optimizer have been created (presumably so that training starts from a defined random
# state -- TODO confirm).
if config.misc.random_seed is not None:
    utils.set_seed(config.misc.random_seed)

# Initialize the trainer and start training
trainer = utils.get_trainer(config, train_loader, val_loader, model, optimizer, scheduler)
trainer.train()

In [None]:
# Fetch the first batch of the training loader for inspection. `next(iter(...))` raises StopIteration on an
# empty loader instead of silently leaving `batch` undefined (or stale from an earlier run of this cell).
batch = next(iter(train_loader))

In [None]:
# Summarize the fetched batch: list its keys, then print each entry either by shape or by value
print(f"keys: {batch.keys()}")
for field in ("x", "y", "masks"):
    # entries reported by their shape
    print(f"{field}: {batch[field].shape}")
for field in ("position_days", "days", "sample_index", "c_index_rgb", "c_index_nir"):
    # entries reported by their value
    print(f"{field}: {batch[field]}")
print(f"cloud_mask: {batch['cloud_mask'].shape}")

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Each figure shows element [0][2] of one batch entry (presumably sample 0, time step 2 -- TODO confirm), moved to
# channels-last layout, multiplied by 5, and clipped to [0, 1] before display.

# Input image: channels 2, 1, 0 of the last axis are displayed (presumably to obtain RGB order -- TODO confirm)
x = torch.clip(batch["x"][0][2].permute(1, 2, 0) * 5, 0, 1)
plt.figure(), plt.imshow(x[:, :, (2, 1, 0)])

# First channel of the "masks" entry
x = torch.clip(batch["masks"][0][2].permute(1, 2, 0) * 5, 0, 1)
plt.figure(), plt.imshow(x[:, :, 0])

# First channel of the "cloud_mask" entry
x = torch.clip(batch["cloud_mask"][0][2].permute(1, 2, 0) * 5, 0, 1)
plt.figure(), plt.imshow(x[:, :, 0])

In [None]:
import sys, os
# Make the AllClear code base importable. The checkout location depends on whether the current working directory
# contains "ck696".
sys.path.append(
    "/share/hariharan/ck696/allclear"
    if "ck696" in os.getcwd()
    else "/share/hariharan/cloud_removal/allclear"
)

from dataset.dataloader_v1 import CRDataset
from torch.utils.data import DataLoader, Dataset

class CRDatasetWrapper(Dataset):
    """
    Adapt the samples of a wrapped dataset to the dictionary layout consumed by the trainer.

    Every sample of the wrapped dataset must be a mapping with the keys `input_images`, `target`,
    `input_cld_shdw`, `target_cld_shdw` (4-D tensors) and `time_differences`. The first two axes of every 4-D
    tensor are swapped (presumably (C, T, H, W) -> (T, C, H, W) -- TODO confirm against the wrapped dataset).

    Args:
        original_dataset:  dataset to wrap; must support `len()` and integer indexing.
        max_samples:       int or None, upper bound on the reported dataset length. The default (2376) keeps the
                           previously hard-coded length; `None` exposes the entire wrapped dataset. The reported
                           length never exceeds the length of the wrapped dataset.
    """

    def __init__(self, original_dataset, max_samples=2376):
        self.original_dataset = original_dataset
        self.max_samples = max_samples

    def __len__(self):
        available = len(self.original_dataset)
        if self.max_samples is None:
            return available
        # Clamp to the wrapped dataset so that a sampler can never request an out-of-range index
        return min(self.max_samples, available)

    @staticmethod
    def _swap_leading_axes(tensor):
        """Swap the first two axes of a 4-D tensor."""
        return tensor.permute(1, 0, 2, 3)

    @classmethod
    def _collapse_to_single_mask(cls, tensor):
        """Swap the first two axes, then take the maximum over axis 1 (kept with size 1)."""
        return cls._swap_leading_axes(tensor).max(dim=1, keepdim=True).values

    def __getitem__(self, idx):
        batch = self.original_dataset[idx]

        return {"x": self._swap_leading_axes(batch["input_images"]),
                "y": self._swap_leading_axes(batch["target"]),
                "masks": self._collapse_to_single_mask(batch["input_cld_shdw"]),
                # Both temporal entries are filled from the same source
                "position_days": batch["time_differences"],
                "days": batch["time_differences"],
                # NOTE(review): constant placeholder, not the actual sample index `idx` -- confirm that nothing
                # downstream relies on distinct values.
                "sample_index": 0,
                "c_index_rgb": torch.Tensor([0, 1, 2]),
                "c_index_nir": torch.Tensor([3]),
                "cloud_mask": self._collapse_to_single_mask(batch["target_cld_shdw"])
               }

import json
# with open('/share/hariharan/cloud_removal/metadata/v3/s2s_tx6_v1.json') as f:
#     metadata = json.load(f)
    
# for i in range(len(metadata)):
#     for j in range(6):
#         # metadata[f"{i}"]["target"][0][1] = "/share/hariharan/cloud_removal/MultiSensor/dataset_30k_v4/" + metadata[f"{i}"]["target"][0][1].split("dataset_30k_v4")[1]
#         metadata[f"{i}"]["s2_toa"][j][1] = "/share/hariharan/cloud_removal/MultiSensor/dataset_30k_v4/" + metadata[f"{i}"]["s2_toa"][j][1].split("dataset_30k_v4")[1]
#     try:
#         for j in range(6):
#             # metadata[f"{i}"]["target"][0][1] = "/share/hariharan/cloud_removal/MultiSensor/dataset_30k_v4/" + metadata[f"{i}"]["target"][0][1].split("dataset_30k_v4")[1]
#             metadata[f"{i}"]["s1"][j][1] = "/share/hariharan/cloud_removal/MultiSensor/dataset_30k_v4/" + metadata[f"{i}"]["s1"][j][1].split("dataset_30k_v4")[1]
#     except:
#         pass
# save metadata to json
# file_name = '/share/hariharan/cloud_removal/metadata/v3/s2s_tx6_v1_bh.json'
# with open(file_name, 'w') as f:
#     json.dump(metadata, f)

import json
# Load the dataset metadata; the encoding is given explicitly so that parsing does not depend on the locale
with open('/share/hariharan/cloud_removal/metadata/v3/s2s_tx6_v1_bh.json', encoding='utf-8') as f:
    metadata = json.load(f)

# All-ones tensor handed to the dataset as `clds_shdws`
# (presumably a bank of synthetic cloud/shadow masks -- TODO confirm against CRDataset)
clds_shdws = torch.ones(1000, 2, 256, 256)

train_data = CRDataset(metadata,
                    selected_rois="all",
                    main_sensor="s2_toa",
                    aux_sensors=[],
                    aux_data=["cld_shdw"],
                    format="stp",
                    target="s2s",
                    clds_shdws=clds_shdws,
                    tx=6,
                    s2_toa_channels=[4,3,2,8]
                    )
# wrapped_train_data = CRDatasetWrapper(train_data)
# phase_loader = DataLoader(train_data, batch_size=2, shuffle=True, num_workers=2, pin_memory=True)

# Wrap the dataset so that its samples use the key layout of the trainer, then fetch one batch for inspection
wrapped_train_data = CRDatasetWrapper(train_data)
phase_loader = DataLoader(wrapped_train_data, batch_size=2, shuffle=True, num_workers=2, pin_memory=True)
# `next(iter(...))` raises StopIteration on an empty loader instead of silently leaving `ac_batch` undefined
ac_batch = next(iter(phase_loader))

In [None]:
# Draw a fresh first batch from the (shuffled) loader. `next(iter(...))` raises StopIteration on an empty loader
# instead of silently keeping the previously fetched `ac_batch`.
ac_batch = next(iter(phase_loader))

In [None]:
# Summarize the batch: entries with a one-dimensional shape are printed by value, all others by shape
for key, value in ac_batch.items():
    summary = value if len(value.shape) == 1 else value.shape
    print(f"{key}: {summary}")

In [None]:
# Show the first three channels of element [0][5] of the "x" entry (presumably sample 0, time step 5 -- TODO
# confirm), in channels-last layout, multiplied by 5 and clipped to [0, 1]
x = torch.clip(ac_batch["x"][0][5].permute(1, 2, 0) * 5, 0, 1)
plt.figure(), plt.imshow(x[:, :, :3])

In [None]:
# Re-create the trainer with the model, optimizer, and scheduler defined in the earlier cells, this time passing
# `phase_loader` as both the training and the validation loader, and start training.
# NOTE(review): model/optimizer/scheduler objects are reused as they are, i.e., including any state left by a
# previous `trainer.train()` call -- confirm this is intended.
trainer = utils.get_trainer(config, phase_loader, phase_loader, model, optimizer, scheduler)
trainer.train()