# This notebook loads a chip's data, keeps a subset of its frequency points as measurements, and saves those measurements to a Touchstone file

In [None]:
import torch
import random
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

import torch.nn as nn
import torch.nn.functional as F

import math
import time

In [None]:
# Pick the compute device: prefer the second GPU (index 1) when one exists,
# fall back to GPU 0 on single-GPU machines, and to the CPU otherwise.
# (Hard-coding 'cuda:1' whenever CUDA is available makes every later
# `.to(device)` fail with an invalid device ordinal on a one-GPU machine.)
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

# 1. Load the data and get measurements

In [None]:
# Directory that utils.grab_chip_data reads the raw chip data from.
ROOT_PATH = "/scratch1/04703/sravula/UTAFSDataNew/new_data"

# Identifier of the chip to process (also used in the output title and file name).
CHIP_NUM = 21

In [None]:
# Project-local helper module; this notebook uses its loading, S-parameter
# conversion, index-selection and network-building helpers.
import utils

# Load everything for chip CHIP_NUM under ROOT_PATH. Later cells read the keys
# 'vf_matrix', 'gt_matrix', 'y_matrix', 'gt_freqs' and 'y_freqs' from this dict.
data_dict = utils.grab_chip_data(ROOT_PATH, CHIP_NUM)

## Convert the raw data matrices to their unique S-parameters

In [None]:
# Reduce each raw matrix (VF, ground truth, Y -- in that order) to its unique
# S-parameters using the project helper utils.matrix_to_sparams.
vf_data, gt_data, y_data = (
    utils.matrix_to_sparams(data_dict[key])
    for key in ('vf_matrix', 'gt_matrix', 'y_matrix')
)

print("VF S-parameters shape: ", vf_data.shape)
print("GT S-parameters shape: ", gt_data.shape)
print("Y S-parameters shape: ", y_data.shape)

In [None]:
# Frequency grids stored with the chip data.
# NOTE: `y_freqs` is reassigned further down to the frequencies of the
# subsampled measurements, so this value is only used for the print below.
gt_freqs, y_freqs = data_dict['gt_freqs'], data_dict['y_freqs']

print("GT frequencies shape: ", gt_freqs.shape)
print("Y frequencies shape: ", y_freqs.shape)

## Make some variables we will need

In [None]:
# Sizes taken from the ground-truth S-parameter array: its last axis indexes
# frequency points and its first axis indexes S-parameters.
N_SPARAMS = gt_data.shape[0]
N_FREQS = gt_data.shape[-1]

print("N_FREQS: ", N_FREQS)
print("N_SPARAMS: ", N_SPARAMS)

In [None]:
# Build the ground-truth tensor: flatten all leading axes of gt_data into one
# channel axis (frequency stays last), prepend a batch axis of size 1, then
# move it to `device`. Resulting shape: (1, channels, N_FREQS).
x = torch.from_numpy(gt_data)
x = x.view(-1, N_FREQS)
x = x.unsqueeze(0).to(device)

print("x shape: ", x.shape)

## Grab some measurements (subsample the frequency points)

In [None]:
# Sampling scheme handed to utils.get_inds.
# Options: random, equal, forecast, full, log, sqrt.
PROBLEM_TYPE = "sqrt"
# Requested sampling amount for utils.get_inds -- presumably the target
# fraction of the N_FREQS points to keep (TODO confirm against utils.get_inds).
M = 0.2

kept_inds, missing_inds = utils.get_inds(PROBLEM_TYPE, N_FREQS, M)

# Rebind M to the number of indices actually kept, because get_inds may not
# hit the requested amount exactly. From here on M is a count of points.
M = len(kept_inds)

print("Number of Ground Truth Frequency Points: ", N_FREQS)
print("Number of Measurements: ", M)
print("Undersampling Ratio: ", M / N_FREQS)

In [None]:
# Keep only the sampled frequency columns of x. Subsample first and clone the
# (much smaller) result: `y` still owns independent memory, but the full
# ground-truth tensor is no longer copied just to be discarded.
y = x[:, :, kept_inds].clone()
# Frequencies of the kept points. This replaces the `y_freqs` that was read
# from data_dict earlier in the notebook.
y_freqs = gt_freqs[kept_inds]

print("y shape: ", y.shape)
print("y_freqs shape: ", y_freqs.shape)

## Visualise measurements

In [None]:
# Overlay each S-parameter's ground-truth curves (lines) with the kept
# measurement points (markers). As the labels indicate, rows 2*i and 2*i + 1
# of x / y are plotted as the real and imaginary parts of S-parameter i.
plt.figure()
for i in range(N_SPARAMS):
    plt.plot(gt_freqs, x[0, 2 * i].cpu(), label=f"{i} Re")
    plt.plot(gt_freqs, x[0, 2 * i + 1].cpu(), label=f"{i} Im")
    plt.scatter(y_freqs, y[0, 2 * i].cpu())
    plt.scatter(y_freqs, y[0, 2 * i + 1].cpu())

# Only draw the legend when there are few enough curves for it to stay readable.
if N_SPARAMS <= 10:
    plt.legend()

plt.title("Ground Truth Complex Representation")
plt.xlabel("Frequency")
plt.show()

## Save measurements

In [None]:
# Output directory for the written Touchstone files. It is derived from
# CHIP_NUM so it stays in sync with the chip being processed instead of being
# pinned to one chip's folder.
SAVE_ROOT = f"/scratch1/04703/sravula/Siemens_VF_Data/case{CHIP_NUM}/"

In [None]:
# Convert the subsampled S-parameters back to matrix form with the project
# helper utils.sparams_to_matrix (presumably the inverse of matrix_to_sparams).
# NOTE(review): `y` is a torch tensor that may live on a CUDA device, while
# matrix_to_sparams above was fed arrays straight from data_dict -- confirm
# sparams_to_matrix accepts a (possibly GPU) tensor, or pass y.cpu() instead.
Y_MATRIX = utils.sparams_to_matrix(y)
print(Y_MATRIX.shape)

In [None]:
# Wrap the observation matrix and its frequencies in a network object titled
# after the chip, using the project helper utils.matrix_to_network.
Y_net = utils.matrix_to_network(Y_MATRIX, y_freqs, f"Observations for Chip {CHIP_NUM}")

In [None]:
# Sanity check: inspect the `.s` array held by the network object
# (presumably the S-parameter array if Y_net is a scikit-rf Network -- verify).
temp = Y_net.s

print(temp.shape)
print(temp.dtype)

In [None]:
# Output path: <SAVE_ROOT>case<CHIP_NUM>_<PROBLEM_TYPE>_<number of measurements>.
WRITE_PTH = f"{SAVE_ROOT}case{CHIP_NUM}_{PROBLEM_TYPE}_{M}"
print(WRITE_PTH)

In [None]:
# Write the observation network to WRITE_PTH in Touchstone format.
Y_net.write_touchstone(WRITE_PTH)