In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from glob import glob
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import xarray as xr
import rasterio as rio
import rioxarray
import math
from sklearn.preprocessing import RobustScaler
from sklearn.decomposition import PCA
import pandas as pd
import os

import deep_snow.models
from deep_snow.dataset import norm_dict
from deep_snow.utils import calc_norm, undo_norm, calc_dowy
from tqdm import tqdm
import shutil

In [110]:
# Directory whose direct entries are collected into `files` for processing.
parent_dir = "/mnt/c/Users/JackE/uw/courses/aut24/ml_geo/final_data/subsets_v4/train"
# glob() yields matches in arbitrary (filesystem-dependent) order; sorting makes
# the processing order reproducible across runs and machines.
files = sorted(glob(f"{parent_dir}/*"))

In [None]:
# For every input file: compress the Sentinel-1 and Sentinel-2 variable groups
# with a per-file RobustScaler + PCA, then write a reduced NetCDF containing the
# terrain/target layers, the leading principal components and a constant
# day-of-water-year grid. The matching .tif is copied next to the outputs.
#
# Files are skipped (nothing written, nothing copied) when either variable
# group cannot be decomposed: a variable that is entirely NaN, or a PCA whose
# explained-variance ratio is NaN.
s1_variables = [
    "snowon_vv", "snowon_vh", "snowoff_vv", "snowoff_vh"
]
s2_variables = [
    "AOT", "B01", "B02", "B03", "B04", "B05", "B06", "B07",
    "B08", "B09", "B11", "B12", "B8A", "SCL", "WVP", "visual"
]
s1_s2_vars = [s1_variables, s2_variables]

# Input/output locations used by the loop below.
nc_out_dir = "/mnt/c/Users/JackE/uw/courses/aut24/ml_geo/jack_subsets/ncs"
tif_src_dir = "/mnt/c/Users/JackE/uw/courses/aut24/ml_geo/final_data/subsets_v4_tif/train"
tif_out_dir = "/mnt/c/Users/JackE/uw/courses/aut24/ml_geo/jack_subsets/tifs"

# Create the output folders up front so the first write cannot fail on a
# missing directory.
os.makedirs(nc_out_dir, exist_ok=True)
os.makedirs(tif_out_dir, exist_ok=True)

for file in tqdm(files):
    skip_file = False

    # The context manager closes the file handle on every exit path (including
    # `continue`), so handles do not accumulate over a long batch. Everything
    # read from `ds` is materialised with `.values` before the block ends.
    with xr.open_dataset(file) as ds:
        for i, variables in enumerate(s1_s2_vars):
            # Stack the group's variables on a trailing feature axis and flatten
            # all leading axes into one "samples" axis (row-major order).
            data_array = np.stack([ds[var].values for var in variables], axis=-1)
            n_samples = np.prod(data_array.shape[:-1])
            n_features = data_array.shape[-1]
            reshaped_data = data_array.reshape(n_samples, n_features)

            nan_mask = np.isnan(reshaped_data)
            if nan_mask.all(axis=0).any():
                # A fully-NaN variable has no mean to impute with; the NaNs would
                # survive into PCA, which rejects them with a ValueError and would
                # abort the whole batch. Skip this file instead.
                skip_file = True
                break
            if nan_mask.any():
                # Impute each missing value with the mean of its own column.
                column_means = np.nanmean(reshaped_data, axis=0)
                reshaped_data[nan_mask] = np.take(column_means, np.where(nan_mask)[1])

            # NOTE(review): scaler and PCA are re-fitted for every file, so the
            # component axes are presumably not comparable between files - confirm
            # this is intended.
            scaler = RobustScaler()
            scaled_data = scaler.fit_transform(reshaped_data)
            pca = PCA(n_components=4)
            pca_result = pca.fit_transform(scaled_data)

            if np.isnan(pca.explained_variance_ratio_).any():
                # Degenerate decomposition: skip this file.
                skip_file = True
                break

            if i == 0:
                # Sentinel-1 group: keep the first two components.
                s1_pc1, s1_pc2 = pca_result[:, 0], pca_result[:, 1]
            else:
                # Sentinel-2 group: keep the first three components.
                s2_pc1, s2_pc2, s2_pc3 = pca_result[:, 0], pca_result[:, 1], pca_result[:, 2]

        if skip_file:
            continue

        fn = os.path.split(file)[-1]
        # The fifth underscore-separated token of the file name is parsed as a date.
        # calc_dowy is a deep_snow helper; presumably it maps day-of-year to
        # day-of-water-year - verify against deep_snow.utils.
        dowy_1d = calc_dowy(pd.to_datetime(fn.split('_')[4]).dayofyear)

        # Size the constant dowy grid from the data rather than assuming 128x128,
        # so tiles of any size stay consistent with the other (x, y) variables.
        aso_sd = ds['aso_sd'].values
        dowy_array = np.full(aso_sd.shape, dowy_1d)

        # The principal components stay flattened on a "samples" dimension (same
        # row-major order as the reshape above); consumers must reshape them to
        # recover the grid.
        new_ds = xr.Dataset({
            "aso_sd": (["x", "y"], aso_sd),
            "fcf": (["x", "y"], ds["fcf"].values),
            "elevation": (["x", "y"], ds["elevation"].values),
            "tri": (["x", "y"], ds["tri"].values),
            "tpi": (["x", "y"], ds["tpi"].values),
            "latitude": (["x", "y"], ds["latitude"].values),
            "longitude": (["x", "y"], ds["longitude"].values),
            "s1_pc1": (["samples"], s1_pc1),
            "s1_pc2": (["samples"], s1_pc2),
            "s2_pc1": (["samples"], s2_pc1),
            "s2_pc2": (["samples"], s2_pc2),
            "s2_pc3": (["samples"], s2_pc3),
            "dowy": (["x", "y"], dowy_array)
        })

    new_ds.to_netcdf(f"{nc_out_dir}/{fn}")

    # Copy the GeoTIFF that shares this file's base name (text before the first '.').
    source = f"{tif_src_dir}/{fn.split('.')[0]}.tif"
    destination = f"{tif_out_dir}/{fn.split('.')[0]}.tif"
    shutil.copy(source, destination)