In [None]:
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import pandas as pd
import torch
import random
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split, KFold
import numpy as np
import torch.optim as optim
from pprint import pprint
from tqdm import tqdm

from MatrixVectorizer import MatrixVectorizer
from preprocessing import antivectorize_df
from model import GSRNet, Discriminator
from train import train_gan, test_gan
from utils import track_memory, compute_degree_matrix_normalization_batch_numpy, get_parser, evaluate, plot_metrics_fold, LR_size, HR_size

In [None]:
# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) with one value.
random_seed = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(random_seed)

# Device preference order: CUDA, then Apple MPS, then CPU.
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    device = torch.device("cuda")
    print("CUDA is available. Using GPU.")
    # Seed the current GPU and every other visible GPU.
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
else:
    # The MPS check only runs when CUDA is unavailable.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

# Import Data

In [None]:
# One-time preprocessing, kept commented out: reads the raw CSVs, converts them
# with antivectorize_df, and writes the .npy files that the next cell loads.
# Uncomment and run once if those .npy files do not exist yet.
# A_LR_train = pd.read_csv("../data/lr_train.csv")
# A_HR_train = pd.read_csv("../data/hr_train.csv")
# A_LR_test = pd.read_csv("../data/lr_test.csv")

# np.save('A_LR_train_matrix.npy', antivectorize_df(A_LR_train, LR_size))
# np.save('A_HR_train_matrix.npy', antivectorize_df(A_HR_train, HR_size))
# np.save('A_LR_test_matrix.npy', antivectorize_df(A_LR_test, LR_size))

In [None]:
# Load the cached matrices from the working directory (train LR, train HR, test LR).
A_LR_train_matrix, A_HR_train_matrix, A_LR_test_matrix = (
    np.load(npy_path)
    for npy_path in (
        "A_LR_train_matrix.npy",
        "A_HR_train_matrix.npy",
        "A_LR_test_matrix.npy",
    )
)

# Sanity check: print one shape per line.
print(
    A_LR_train_matrix.shape,
    A_HR_train_matrix.shape,
    A_LR_test_matrix.shape,
    sep="\n",
)

# Parameters

In [None]:
# Parse an empty argument list so that `args` carries only the parser's
# defaults (nothing is read from the notebook kernel's own command line).
parser = get_parser()
args = parser.parse_args([])

# Show the resulting configuration.
pprint(vars(args))

In [None]:
# Inputs (X) and targets (Y) for training. These are aliases of the loaded
# matrices, not copies. To experiment with synthetic or custom data, replace
# the right-hand side, e.g.:
#   X = np.random.normal(0, 0.5, (167, 160, 160))
#   Y = np.random.normal(0, 0.5, (167, 288, 288))
X, Y = A_LR_train_matrix, A_HR_train_matrix
print(X.shape, Y.shape, sep="\n")

In [None]:
# Apply the project's degree-matrix normalization to the training inputs and
# to the test inputs.
# NOTE(review): both names are reassigned from their own previous value, so
# running this cell a second time normalizes already-normalized data. Re-run
# the loading cells first if this cell needs to be re-executed.
normalize_batch = compute_degree_matrix_normalization_batch_numpy
X = normalize_batch(X)
A_LR_test_matrix = normalize_batch(A_LR_test_matrix)
print(X.shape)

# K-Fold Cross Validation

In [None]:
# K-fold cross-validation: for every fold, build a fresh generator and
# discriminator, train them with train_gan, and keep both the returned model
# and the fold's data split for the evaluation cells that follow.
cv = KFold(n_splits=args.splits, random_state=random_seed, shuffle=True)

best_model_fold_list = []  # model returned by train_gan, one entry per fold
data_fold_list = []  # (train_adj, held_out_adj, train_gt, held_out_gt) per fold

# enumerate(start=1) supplies the 1-based fold number directly instead of a
# hand-maintained counter that has to be incremented at the end of the body.
for fold, (train_index, test_index) in enumerate(cv.split(X), start=1):

    print(f"----- Fold {fold} -----")

    # Slice inputs (X) and targets (Y) with this fold's index arrays.
    subjects_adj, test_adj = X[train_index], X[test_index]
    subjects_ground_truth, test_ground_truth = Y[train_index], Y[test_index]
    data_fold_list.append(
        (subjects_adj, test_adj, subjects_ground_truth, test_ground_truth)
    )

    # Fresh networks and optimizers per fold so no weights carry over.
    netG = GSRNet(args).to(device)
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr)

    netD = Discriminator(args).to(device)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr)

    track_memory()
    # GAN model; the held-out split is handed to train_gan as well.
    return_model = train_gan(
        netG,
        optimizerG,
        netD,
        optimizerD,
        subjects_adj,
        subjects_ground_truth,
        args,
        test_adj=test_adj,
        test_ground_truth=test_ground_truth,
    )
    track_memory()

    # Score the returned model on the held-out split and on its training split.
    test_mae = test_gan(return_model, test_adj, test_ground_truth, args)
    train_mae = test_gan(return_model, subjects_adj, subjects_ground_truth, args)
    print(f"Train MAE: {train_mae:.6f}, Val MAE: {test_mae:.6f}")
    best_model_fold_list.append(return_model)

In [None]:
# Passed through to evaluate() as its cal_graph argument.
CAL_GRAPH = False

# One evaluate() result per fold, computed on that fold's held-out split.
res_list = []

for fold_idx in range(args.splits):
    fold_data = data_fold_list[fold_idx]
    test_adjs, gt_matrices = fold_data[1], fold_data[3]

    model = best_model_fold_list[fold_idx]
    model.eval()

    # Buffer with the same shape as the ground truth, filled row by row.
    pred_matrices = np.zeros(gt_matrices.shape)
    with torch.no_grad():
        for row, lr_adj in enumerate(test_adjs):
            # The first element of the model output is taken as the prediction.
            raw_pred = model(torch.from_numpy(lr_adj))[0]
            # Restrict predictions to [0, 1] before copying them to the CPU buffer.
            pred_matrices[row] = torch.clamp(raw_pred, min=0.0, max=1.0).cpu()

    fold_result = evaluate(pred_matrices, gt_matrices, cal_graph=CAL_GRAPH)
    res_list.append(fold_result)

pd.DataFrame(res_list)

In [None]:
# Hand the per-fold evaluation results collected in res_list to the project's
# plotting helper.
plot_metrics_fold(res_list)

In [None]:
# Export each fold's held-out predictions to predictions_fold_<k>.csv as a
# flat (ID, Predicted) table.
for i in range(args.splits):
    # Only the held-out inputs are needed from the stored fold tuple.
    test_adjs = data_fold_list[i][1]
    model = best_model_fold_list[i]
    model.eval()

    output_pred_list = []
    with torch.no_grad():
        for test_adj in tqdm(test_adjs):
            output_pred = model(torch.from_numpy(test_adj))[0]
            # Clamp to [0, 1] exactly as the evaluation and final-submission
            # cells do, so the exported values match the predictions that were
            # actually scored.
            output_pred = torch.clamp(output_pred, min=0.0, max=1.0).cpu()
            output_pred_list.append(MatrixVectorizer.vectorize(output_pred).tolist())

    # Concatenate all subjects' vectors into one long 1-D array.
    output_pred_1d = np.stack(output_pred_list, axis=0).flatten()

    df = pd.DataFrame(
        {
            # 1-based row ids; built without a comprehension variable that
            # would shadow the fold index `i`.
            "ID": list(range(1, len(output_pred_1d) + 1)),
            "Predicted": output_pred_1d.tolist(),
        }
    )

    df.to_csv("predictions_fold_" + str(i + 1) + ".csv", index=False)

# Final Model

In [None]:
# Train the final model on a stratified 90/10 train/validation split.
#
# Stratification labels come from clustering the raw HR training table:
# PCA reduces it first, then a Gaussian mixture assigns each row a cluster.
A_HR_train = pd.read_csv("../data/hr_train.csv")

# A float n_components keeps as many components as are needed to explain that
# fraction of the variance (here 99%).
pca = PCA(n_components=0.99, whiten=False)
A_HR_train_pca = pca.fit_transform(A_HR_train)
print(f"HR Train PCA shape: {A_HR_train_pca.shape}")

gm = GaussianMixture(n_components=5, random_state=random_seed)
A_HR_train_label = gm.fit_predict(A_HR_train_pca)
# Print (label, count) pairs to show how balanced the clusters are.
unique, counts = np.unique(A_HR_train_label, return_counts=True)
print(np.asarray((unique, counts)).T)

# Reload the cached matrices so this cell does not depend on what earlier
# cells did to X / Y, then normalize the inputs once.
X = np.load("A_LR_train_matrix.npy")
y = np.load("A_HR_train_matrix.npy")

X = compute_degree_matrix_normalization_batch_numpy(X)

# Flatten each matrix to one row per sample for the split ...
n_sample = X.shape[0]
X_train, X_val, y_train, y_val = train_test_split(
    X.reshape(n_sample, -1),
    y.reshape(n_sample, -1),
    test_size=0.10,
    random_state=random_seed,
    stratify=A_HR_train_label,
)

# ... and restore the square matrix shape afterwards.
X_train = X_train.reshape(-1, LR_size, LR_size)
X_val = X_val.reshape(-1, LR_size, LR_size)
y_train = y_train.reshape(-1, HR_size, HR_size)
y_val = y_val.reshape(-1, HR_size, HR_size)

print("Train size:", len(X_train))
print("Val size:", len(X_val))

# Fresh generator/discriminator and optimizers for the final run.
# (Leftover print("here") debugging statements were removed.)
netG = GSRNet(args).to(device)
optimizerG = optim.Adam(netG.parameters(), lr=args.lr)

netD = Discriminator(args).to(device)
optimizerD = optim.Adam(netD.parameters(), lr=args.lr)

track_memory()
# GAN model; the validation split is handed to train_gan as well.
final_model = train_gan(
    netG,
    optimizerG,
    netD,
    optimizerD,
    X_train,
    y_train,
    args,
    test_adj=X_val,
    test_ground_truth=y_val,
)
track_memory()

In [None]:
# Print the configuration the final model was trained with.
pprint(vars(args))

In [None]:
# Evaluate the final model on its training split and then on its validation
# split, clamping predictions to [0, 1] before scoring.
final_model.eval()
pred_train_matrices = np.zeros(y_train.shape)
pred_val_matrices = np.zeros(y_val.shape)

# (label, inputs, targets, prediction buffer) for each split, in report order.
split_specs = (
    ("Train", X_train, y_train, pred_train_matrices),
    ("Val", X_val, y_val, pred_val_matrices),
)

with torch.no_grad():
    for split_label, split_inputs, split_targets, split_preds in split_specs:
        for row, lr_adj in enumerate(split_inputs):
            # The first element of the model output is taken as the prediction.
            raw_pred = final_model(torch.from_numpy(lr_adj))[0]
            split_preds[row] = torch.clamp(raw_pred, min=0.0, max=1.0).cpu()

        print(split_label)
        evaluate(split_preds, split_targets)

In [None]:
# Run the final model over every test input and collect the vectorized,
# clamped predictions as plain Python lists.
final_model.eval()
output_pred_list = []
with torch.no_grad():
    for lr_adj in tqdm(A_LR_test_matrix):
        # torch.Tensor(...) builds a default-dtype tensor from the NumPy slice.
        hr_pred = final_model(torch.Tensor(lr_adj))[0]
        hr_pred = torch.clamp(hr_pred, min=0.0, max=1.0).cpu()
        output_pred_list.append(MatrixVectorizer.vectorize(hr_pred).tolist())

In [None]:
# Stack the per-subject vectors (along the first axis) and flatten them into
# the single 1-D array used for the submission.
output_pred_stack = np.stack(output_pred_list)
output_pred_1d = output_pred_stack.flatten()

# Guard against a malformed submission: the flattened array must have exactly
# this many entries.
# NOTE(review): 4007136 is a magic number — confirm how it is derived.
assert output_pred_1d.ndim == 1 and output_pred_1d.size == 4007136

In [None]:
# Two-column submission table: 1-based row id and the predicted value.
n_rows = len(output_pred_1d)
df = pd.DataFrame(
    {
        "ID": list(range(1, n_rows + 1)),
        "Predicted": output_pred_1d.tolist(),
    }
)

df

In [None]:
# Write the submission table to final_model.csv in the working directory,
# omitting the DataFrame index column.
df.to_csv("final_model.csv", index=False)