In [1]:
import torch
import torch.nn as nn

import random
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

import time
from tqdm import tqdm

from skrf import Network, Frequency

import utils
from models import RESNET_BACKBONE, RESNET_HEAD, MODULAR_RESNET

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


# Define and Grab Data

In [3]:
def grab_sparams(root_pth, chip_num):
    #first grab the chip
    chip_dict = utils.get_network_from_file(root_pth, chip_num)
    
    out_network = chip_dict["network"]
    out_freqs = out_network.frequency
    
    #resample to minimum length if necessary
    MIN_LEN = 1000
    
    if out_freqs.npoints < MIN_LEN:
        scale_fac = int(np.ceil(MIN_LEN / out_freqs.npoints))
        new_len = scale_fac * (out_freqs.npoints - 1) + 1 #this is smarter scaling that just divides current spacing
        
        out_network.resample(new_len)
        out_freqs = out_network.frequency
    
    #convert to unique s-parameters tensor
    out_matrix_re = out_network.s.real
    out_matrix_im = out_network.s.imag
    out_matrix = np.stack((out_matrix_re, out_matrix_im), axis=-1)

    out_sparams = utils.matrix_to_sparams(out_matrix)

    out_sparams = out_sparams.reshape(1, -1, out_freqs.npoints)

    return torch.tensor(out_sparams)

# Function to Fit a single DIP to given measurements

In [4]:
def fit_DIP(model, y, z, 
            lr, num_iter, 
            train_loss, train_reg, reg_lambda=0, 
            start_noise=None, noise_decay=None):
    
    #we can make the optmizer within the function since we don't need the stats between fits
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    
    for i in range(num_iter):
        optim.zero_grad()
        
        #get the output with or without additive noise in the input
        if (start_noise is not None) and (noise_decay is not None):
            noisy_z = z + torch.randn_like(z) * start_noise
            out = model.forward(noisy_z)
            start_noise *= noise_decay
        else:
            out = model.forward(z)
        
        #loss and regularization
        error = train_loss(out, y) 
        if reg_lambda > 0:
            reg = reg_lambda * train_reg(out)
            loss = error + reg
        else:
            loss = error

        loss.backward()
        optim.step()
    
    return model, out

# Running FNs

In [5]:
def grab_data_and_net(data_root, chip_num, measurement_spacing, num_measurements,
                      ngf, kernel_size, causal, passive, backbone):
    
    x = grab_sparams(data_root, chip_num)

    #grab the appropriate measurements
    kept_inds, missing_inds = utils.get_inds(measurement_spacing, x.shape[-1], num_measurements)

    y = torch.clone(x)[:, :, kept_inds]

    z = torch.clone(x)
    z[:, :, missing_inds] = 0

    #set up the clone network and head and make modular net
    net_head = RESNET_HEAD(nz=x.shape[1],
                           ngf_in_out=ngf,
                           nc=x.shape[1],
                           output_size=x.shape[-1],
                           kernel_size=kernel_size,
                           causal=causal,
                           passive=passive)

    backbone_clone = backbone.make_clone() 

    net = MODULAR_RESNET(backbone=backbone_clone,
                         head=net_head)
    
    return x, y, z, kept_inds, net

In [6]:
def train_step(data_root, chip_num, measurement_spacing, num_measurements, ngf,
               kernel_size, causal, passive, backbone, device, lr_inner, 
               num_iters_inner, reg_lambda_inner, start_noise_inner, noise_decay_inner):
    #sample chip
    x, y, z, kept_inds, net = grab_data_and_net(data_root=data_root, chip_num=chip_num, measurement_spacing=measurement_spacing, 
                                     num_measurements=num_measurements, ngf=ngf, kernel_size=kernel_size, 
                                     causal=causal, passive=passive, backbone=backbone)
    x = x.to(device)
    y = y.to(device)
    z = z.to(device)
    net = net.to(device)

    #set up losses and regularisations
    criterion = utils.Measurement_MSE_Loss(kept_inds=kept_inds, per_param=True, reduction="mean")
    criterion = criterion.to(device)

    regularizer = utils.Smoothing_Loss(per_param=True, reduction="mean")
    regularizer = regularizer.to(device)

    #Run DIP and get the metrics 
    updated_net, x_hat = fit_DIP(model=net, y=y, z=z, 
                                 lr=lr_inner, num_iter=num_iters_inner, 
                                 train_loss=criterion, train_reg=regularizer, reg_lambda=reg_lambda_inner, 
                                 start_noise=start_noise_inner, noise_decay=noise_decay_inner) 
    with torch.no_grad():
        test_mse = nn.MSELoss()(x_hat, x).item()
    
    return updated_net, test_mse

In [7]:
def reptile(backbone, data_root, device, measurement_spacing, num_measurements, 
            num_epochs, lr_outer, test_inds, train_inds, 
            lr_inner, num_iters_inner, reg_lambda_inner, start_noise_inner, noise_decay_inner,    
            ngf, kernel_size, causal, passive):
    
    optim = torch.optim.Adam(backbone.parameters(), lr=lr_outer)
    
    inner_test_losses = []
    outer_test_losses = []
    
    for epoch in range(num_epochs):

        print("STARTING EPOCH " + str(epoch) + "\n")

        #testing - don't update parameters, just track the metrics
        print("TESTING\n")

        for test_chip_ind in tqdm(test_inds):
            _, outer_test_mse = train_step(data_root=data_root, chip_num=test_chip_ind, 
                                           measurement_spacing=measurement_spacing, num_measurements=num_measurements, 
                                           ngf=ngf, kernel_size=kernel_size, causal=causal, passive=passive, 
                                           backbone=backbone, device=device, lr_inner=lr_inner, 
                                           num_iters_inner=num_iters_inner, reg_lambda_inner=reg_lambda_inner, 
                                           start_noise_inner=start_noise_inner, noise_decay_inner=noise_decay_inner)

            outer_test_losses.append(outer_test_mse)

            print("CHIP " + str(test_chip_ind) + " TEST MSE: " + str(outer_test_mse) + "\n")
        
        #training - update params and track metrics
        print("TRAINING\n")

        train_shuffle = np.random.permutation(train_inds)

        for train_chip_ind in tqdm(train_shuffle):
            updated_net, inner_test_mse = train_step(data_root=data_root, chip_num=train_chip_ind, 
                                           measurement_spacing=measurement_spacing, num_measurements=num_measurements, 
                                           ngf=ngf, kernel_size=kernel_size, causal=causal, passive=passive, 
                                           backbone=backbone, device=device, lr_inner=lr_inner, 
                                           num_iters_inner=num_iters_inner, reg_lambda_inner=reg_lambda_inner, 
                                           start_noise_inner=start_noise_inner, noise_decay_inner=noise_decay_inner)

            inner_test_losses.append(inner_test_mse)

            print("CHIP " + str(train_chip_ind) + " TEST MSE: " + str(inner_test_mse) + "\n")

            #update params
            new_backbone = updated_net.backbone.cpu()

            for p, new_p in zip(backbone.parameters(), new_backbone.parameters()):
                p.grad = p.data - new_p.data
            
            optim.step()
            optim.zero_grad()
    
    return backbone, inner_test_losses, outer_test_losses

# Make the running parameters

In [8]:
ROOT_PATH = "/scratch1/04703/sravula/UTAFSDataNew/new_data"

NUM_SAMPLES = 62

TEST_SET = [9, 21, 29, 38, 49, 50, 61]
TRAIN_SET = [i for i in range(NUM_SAMPLES) if i not in TEST_SET]

In [9]:
NUM_LAYERS = 8
BASE_NGF = 512
KERNEL_SIZE = 3

CAUSAL = False
PASSIVE = False

net_backbone = RESNET_BACKBONE(ngf=BASE_NGF,
                               ngf_in_out=BASE_NGF,
                               kernel_size=KERNEL_SIZE,
                               num_layers=NUM_LAYERS)

In [10]:
NUM_EPOCHS = 10
LR_OUTER = 1e-3

LR_INNER = 1e-4
NUM_ITERS = 1000
REG_LAMBDA = 0.1

MEASUREMENT_SPACING = "equal"
NUM_MEASUREMENTS = 0.1

In [11]:
START_NOISE_LEVEL = 0.1
END_NOISE_LEVEL = 0.001
NOISE_DECAY_FACTOR = (END_NOISE_LEVEL / START_NOISE_LEVEL)**(1 / NUM_ITERS) 

print("Noise decay factor: ", NOISE_DECAY_FACTOR)

Noise decay factor:  0.995405417351527


# Run the thing

In [None]:
_, _, _ = reptile(backbone=net_backbone, data_root=ROOT_PATH, device=device, 
                  measurement_spacing=MEASUREMENT_SPACING, num_measurements=NUM_MEASUREMENTS, 
                  num_epochs=NUM_EPOCHS, lr_outer=LR_OUTER, test_inds=TEST_SET, train_inds=TRAIN_SET, 
                  lr_inner=LR_INNER, num_iters_inner=NUM_ITERS, reg_lambda_inner=REG_LAMBDA, 
                  start_noise_inner=START_NOISE_LEVEL, noise_decay_inner=NOISE_DECAY_FACTOR,    
                  ngf=BASE_NGF, kernel_size=KERNEL_SIZE, causal=CAUSAL, passive=PASSIVE)

STARTING EPOCH 0

TESTING



 14%|█▍        | 1/7 [01:18<07:49, 78.18s/it]

CHIP 9 TEST MSE: 0.002094756346195936



 29%|██▊       | 2/7 [01:45<04:01, 48.30s/it]

CHIP 21 TEST MSE: 0.00013425626093521714



 43%|████▎     | 3/7 [03:01<04:04, 61.05s/it]

CHIP 29 TEST MSE: 0.000141730866744183



 57%|█████▋    | 4/7 [04:42<03:50, 76.83s/it]

CHIP 38 TEST MSE: 0.0018753366312012076



 71%|███████▏  | 5/7 [06:19<02:48, 84.13s/it]

CHIP 49 TEST MSE: 0.0007656976231373847



 86%|████████▌ | 6/7 [07:56<01:28, 88.41s/it]

CHIP 50 TEST MSE: 0.0006887433119118214



100%|██████████| 7/7 [08:29<00:00, 72.85s/it]


CHIP 61 TEST MSE: 0.0035398993641138077

TRAINING



  0%|          | 0/55 [00:00<?, ?it/s]

CHIP 61 TEST MSE: 0.0020003104582428932



  2%|▏         | 1/55 [00:55<49:31, 55.03s/it]

CHIP 61 TEST MSE: 0.0019496030872687697



  4%|▎         | 2/55 [01:54<50:41, 57.40s/it]

CHIP 61 TEST MSE: 0.001088720397092402



  5%|▌         | 3/55 [02:35<43:27, 50.14s/it]

CHIP 61 TEST MSE: 0.0008807232370600104



  7%|▋         | 4/55 [04:16<59:45, 70.31s/it]

CHIP 61 TEST MSE: 0.0011434133630245924



  9%|▉         | 5/55 [04:59<50:15, 60.31s/it]

CHIP 61 TEST MSE: 0.000721480930224061



 11%|█         | 6/55 [05:28<40:39, 49.78s/it]

CHIP 61 TEST MSE: 0.00039237196324393153



 13%|█▎        | 7/55 [07:21<56:17, 70.36s/it]

CHIP 61 TEST MSE: 0.00035913949250243604



 15%|█▍        | 8/55 [07:49<44:32, 56.87s/it]

CHIP 61 TEST MSE: 0.0014020254602655768



 16%|█▋        | 9/55 [08:27<39:08, 51.04s/it]

CHIP 61 TEST MSE: 0.0009148393291980028



 18%|█▊        | 10/55 [08:53<32:27, 43.28s/it]

CHIP 61 TEST MSE: 0.001427693641744554



 20%|██        | 11/55 [09:19<27:52, 38.01s/it]

CHIP 61 TEST MSE: 0.0013254453660920262



 22%|██▏       | 12/55 [10:00<27:47, 38.78s/it]

CHIP 61 TEST MSE: 0.002393728354945779



 24%|██▎       | 13/55 [10:34<26:08, 37.35s/it]

CHIP 61 TEST MSE: 0.002645677886903286



 25%|██▌       | 14/55 [12:06<36:45, 53.80s/it]

CHIP 61 TEST MSE: 0.001139677013270557



 27%|██▋       | 15/55 [12:46<33:09, 49.75s/it]

CHIP 61 TEST MSE: 0.001159651787020266



 29%|██▉       | 16/55 [13:26<30:28, 46.87s/it]

CHIP 61 TEST MSE: 0.0003509735397528857



 31%|███       | 17/55 [13:52<25:41, 40.58s/it]

CHIP 61 TEST MSE: 0.0011183989699929953



 33%|███▎      | 18/55 [15:44<38:14, 62.01s/it]

CHIP 61 TEST MSE: 0.0011442668037489057



 35%|███▍      | 19/55 [16:13<31:21, 52.25s/it]

CHIP 61 TEST MSE: 0.0012421810533851385



 36%|███▋      | 20/55 [16:56<28:41, 49.20s/it]