In [18]:
from preprocessing import *
from sklearn.model_selection import KFold
import argparse
from model import *
from train import test
import torch.optim as optim

from MatrixVectorizer import *


In [19]:
# Read the vectorised training connectomes from CSV into numpy arrays
# (each CSV row is later expanded into one matrix).
lr_data_path = '../data/lr_train.csv'
hr_data_path = '../data/hr_train.csv'

lr_train_data = np.loadtxt(lr_data_path, delimiter=',')
hr_train_data = np.loadtxt(hr_data_path, delimiter=',')

# Clean both arrays in place, low-resolution first: negative entries are
# clamped to zero, then NaN/inf values are replaced by finite numbers.
for _features in (lr_train_data, hr_train_data):
    _features[_features < 0] = 0
    np.nan_to_num(_features, copy=False)
del _features  # do not leave the loop helper in the notebook namespace

# Expand every feature row with MatrixVectorizer.anti_vectorize, using
# size 160 for low-resolution and 260 for high-resolution data.
# NOTE(review): presumably yields 160x160 / 260x260 adjacency matrices - confirm
# against MatrixVectorizer.
lr_train_data_vectorized = np.array(
    [MatrixVectorizer.anti_vectorize(subject_row, 160) for subject_row in lr_train_data]
)
hr_train_data_vectorized = np.array(
    [MatrixVectorizer.anti_vectorize(subject_row, 260) for subject_row in hr_train_data]
)


In [20]:

# Inputs for training are the low-resolution matrices; the targets are the
# matching high-resolution matrices.
subjects_adj = lr_train_data_vectorized
subjects_labels = hr_train_data_vectorized

In [21]:
# Cross-validation and optimisation settings.
num_splt = 3       # number of cross-validation folds
epochs = 10        # training epochs per call to train()
lr = 5e-5          # optimizer learning rate
lmbda = 5000       # weight applied to one of the loss terms in train()

# Model dimensions.
lr_dim = 160       # low-resolution size (matches the 160 used when anti-vectorising)
hr_dim = 320       # NOTE(review): equals 260 + 2 * padding; presumably the padded HR size - confirm
hidden_dim = 512
padding = 30       # amount handed to pad_HR_adj in train()

# Bundle everything the model/training code reads into one namespace object.
args = argparse.Namespace(
    epochs=epochs,
    lr=lr,
    lmbda=lmbda,
    lr_dim=lr_dim,
    hr_dim=hr_dim,
    hidden_dim=hidden_dim,
    padding=padding,
)


In [22]:
# Shuffled K-fold splitter with a fixed seed so the folds are reproducible.
# The fold count comes from `num_splt` in the configuration cell instead of a
# second hard-coded literal, so the two cannot drift apart.
cv = KFold(n_splits=num_splt, random_state=42, shuffle=True)

In [23]:
# Alternative setting tried earlier: ks = [0]
# `ks` is passed to GSRNet together with the hyper-parameters in `args`.
# NOTE(review): presumably per-level pooling ratios of the network - confirm in model.py.
ks = [0.7, 0.3]
model = GSRNet(ks, args)

In [24]:
# L1 (mean absolute error) criterion shared by every loss term below.
criterion = nn.L1Loss()

def train(model, optimizer, subjects_adj, subjects_labels, args):
  """Train `model` for `args.epochs` epochs, one optimizer step per subject.

  For every (low-res, high-res) pair the loss is the sum of three L1 terms:
    * `args.lmbda` * L1(net_outs, start_gcn_outs)
    * L1(model.layer.weights, eigenvectors of the padded HR matrix)
    * L1(model_outputs, padded HR matrix)  -- also logged as "Error"

  Args:
    model: module whose forward pass returns the 4-tuple
      (model_outputs, net_outs, start_gcn_outs, layer_outs) and which exposes
      `model.layer.weights`.
    optimizer: torch optimizer built over `model`'s parameters.
    subjects_adj: iterable of numpy low-resolution matrices (model inputs).
    subjects_labels: iterable of numpy high-resolution matrices (targets),
      aligned with `subjects_adj`.
    args: namespace providing `epochs`, `lmbda` and `padding`.

  Returns:
    list of float: mean training loss of each epoch, in order. (Previously the
    history was accumulated and then discarded; callers that ignore the return
    value are unaffected.)
  """
  all_epochs_loss = []
  model.train()

  for epoch in range(args.epochs):
    epoch_loss = []
    epoch_error = []

    # `lr_adj`/`hr_adj` rather than `lr`/`hr`, so the subject matrix is not
    # confused with the notebook-level learning rate `lr`.
    for lr_adj, hr_adj in zip(subjects_adj, subjects_labels):
      lr_adj = torch.from_numpy(lr_adj).type(torch.FloatTensor)
      hr_adj = torch.from_numpy(hr_adj).type(torch.FloatTensor)

      model_outputs, net_outs, start_gcn_outs, _layer_outs = model(lr_adj)

      # The target is padded so it can be compared with the model output.
      # NOTE(review): presumably pads each side by args.padding - confirm in preprocessing.
      padded_hr = pad_HR_adj(hr_adj, args.padding)
      # Only the eigenvectors are used; they supervise the layer weights.
      _eig_val_hr, U_hr = torch.linalg.eigh(padded_hr, UPLO='U')

      # Reconstruction term, computed once and reused for loss and logging.
      error = criterion(model_outputs, padded_hr)
      loss = (args.lmbda * criterion(net_outs, start_gcn_outs)
              + criterion(model.layer.weights, U_hr)
              + error)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      epoch_loss.append(loss.item())
      epoch_error.append(error.item())

    print("Epoch: ",epoch+1, "Loss: ", np.mean(epoch_loss), "Error: ", np.mean(epoch_error))
    all_epochs_loss.append(np.mean(epoch_loss))

  return all_epochs_loss

In [25]:
# K-fold cross-validation over subjects.
#
# A fresh model and optimizer are created for every fold. Reusing one model
# across folds (as before) means each later fold is tested on subjects the
# model already trained on in an earlier fold, and Adam's moment estimates
# carry over as well, so the reported test errors are optimistic.
# After the loop, `model` / `optimizer` refer to the last fold's instances.
# (optim.SGD(model.parameters(), lr=args.lr) was the alternative optimizer tried.)
for train_index, test_index in cv.split(subjects_adj):
    subjects_adj_train = subjects_adj[train_index]  # Get training data
    subjects_adj_test = subjects_adj[test_index]   # Get testing data
    subjects_ground_truth_train = subjects_labels[train_index]
    subjects_ground_truth_test = subjects_labels[test_index]

    model = GSRNet(ks, args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train(model, optimizer, subjects_adj_train, subjects_ground_truth_train, args)
    test(model, subjects_adj_test, subjects_ground_truth_test, args)

Epoch:  1 Loss:  167.78789940902166 Error:  99.00568245777062
Epoch:  2 Loss:  120.86639038154057 Error:  98.9898308051883
Epoch:  3 Loss:  119.41274318524769 Error:  98.98385526585791
Epoch:  4 Loss:  118.14172511441367 Error:  98.98142526085887
Epoch:  5 Loss:  116.99101482970374 Error:  98.98006225696632
Epoch:  6 Loss:  115.92167899438313 Error:  98.97875495627522
Epoch:  7 Loss:  114.92061139856067 Error:  98.97756719150182
Epoch:  8 Loss:  113.98398394244057 Error:  98.97731903208685
Epoch:  9 Loss:  113.10207540648324 Error:  98.97681726482031
Epoch:  10 Loss:  112.2802266393389 Error:  98.97660394618288
0.2599385976791382
0.20166757702827454
0.17209911346435547
0.16911724209785461
0.18098625540733337
0.16789790987968445
0.20097719132900238
0.1802479773759842
0.22633971273899078
0.1820221245288849
0.20679205656051636
0.16337832808494568
0.15906211733818054
0.2293432205915451
0.38961100578308105
0.1844160109758377
0.18851490318775177
0.1682729721069336
0.17128974199295044
0.17049