In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.parallel
import random
import matplotlib.pyplot as plt
import parameters as var #Configuration and coarsening parameters
var.init() #initializes parameters
import utils as ut #Some utility functions 
import loss_function as lf #Custom loss function

import operators_torch as op #Interpolator and prolongator given a set of test vectors
from opendataset import ConfsDataset #class for opening gauge confs
import model as mod #import machine learning model
from train import train, evaluate

#--------Most likely I will need some of these in the future-------#
#import torch.optim as optim
#import torchvision.transforms as transforms
#import torchvision.utils as vutils
var.print_parameters()
# Report CUDA only if GPUs were requested (NGPU > 0) *and* CUDA is actually available,
# so the message never claims a device the machine does not have.
# NOTE: the data-loading cell later replaces `device` with var.DEVICE, which is what
# the rest of the notebook uses.
device = torch.device("cuda" if var.NGPU > 0 and torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

In [None]:
"""
Loading the configurations and the near-kernel test vectors
We split train and test set
"""
dataset = ConfsDataset()
total_len = len(dataset)
train_len = int(var.TRAIN_PROP * total_len)
test_len = total_len - train_len
# Fail early with a clear message: an empty split would otherwise surface later as an
# obscure error (training on nothing, or next(iter(test_loader)) raising StopIteration).
if train_len == 0 or test_len == 0:
    raise ValueError(
        f"TRAIN_PROP={var.TRAIN_PROP} with {total_len} configurations gives "
        f"{train_len} training / {test_len} test examples; both must be non-empty"
    )

SEED = 42
# Seeding the global torch RNG right before random_split fixes the split. It also makes
# everything drawing from the global RNG afterwards reproducible (the weight
# initialization and the training-loader shuffling).
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

train_set, test_set = torch.utils.data.random_split(
    dataset,
    [train_len, test_len],  # lengths in the same order as the returned subsets
)

workers = 8      # DataLoader worker processes
batch_size = 50  # examples per batch, for both loaders
# Pinned host memory only speeds up host-to-GPU copies; without CUDA it is useless
# and PyTorch emits a warning, so enable it only when CUDA is available.
pin_memory = torch.cuda.is_available()

# Training loader is reshuffled every epoch; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=pin_memory,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=pin_memory,
)

# Each batch is [confs, test_vectors, conf_ids] with shapes
# [B, 4, Nt, Nx] (real), [B, Nv, 2, Nt, Nx] (complex) and [B] (configuration indices),
# where B == batch_size except possibly for the last batch.
# To inspect one: first_batch = next(iter(train_loader))

# NOTE: this replaces the `device` chosen in the first cell; everything below uses var.DEVICE.
device = var.DEVICE

In [None]:
"""
Custom weights initialization
"""
def weights_init(m):
    """Re-initialize the parameters of one module; intended for ``model.apply``.

    Modules whose class name contains ``'Conv'`` get their weight drawn from N(0, 1).
    Modules whose class name contains ``'BatchNorm'`` get their weight drawn from
    N(0, 1) and their bias set to 0. Every other module is left untouched.

    Matching is by class name, so custom containers (e.g. a ``ConvBlock``) match too.
    Modules that own no ``weight``/``bias`` tensor are skipped instead of raising.
    Examples are such containers and BatchNorm layers built with ``affine=False``.

    Args:
        m: the ``nn.Module`` currently visited by ``apply``.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    has_weight = isinstance(weight, torch.Tensor)
    has_bias = isinstance(bias, torch.Tensor)

    if "Conv" in classname:
        if has_weight:
            # std=0.02 was used previously
            nn.init.normal_(weight, 0.0, 1.0)
    elif "BatchNorm" in classname:
        if has_weight:
            nn.init.normal_(weight, 0.0, 1.0)
        if has_bias:
            nn.init.constant_(bias, 0.0)

In [None]:
"""
Declare the model
"""
model = mod.TvGenerator(var.NGPU, batch_size).to(device)
# `device` comes from var.DEVICE; torch.device(...) accepts either a string or a
# torch.device, so reading `.type` works in both cases.
if torch.device(device).type == "cuda" and var.NGPU > 1:
    # NOTE(review): DataParallel splits every batch across the GPUs; if TvGenerator
    # relies on `batch_size` internally, the per-GPU chunks will be smaller. Confirm.
    model = nn.DataParallel(model, list(range(var.NGPU)))

# print(model)  # uncomment to inspect the architecture

# Re-initialize Conv and BatchNorm parameters with the custom `weights_init` (see its
# definition for the distributions used). The trailing semicolon stops Jupyter from
# echoing the module returned by `apply`.
model.apply(weights_init);

In [None]:
# Adam optimizer hyperparameters
lr = 0.1     # learning rate. NOTE(review): much larger than Adam's default (1e-3); confirm intended
beta1 = 0.9  # decay rate of the first-moment estimate; beta2 stays at 0.999
optimizer = optim.Adam(
    model.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=0,
)
# Alternative tried before: torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

In [None]:
# Train for a fixed number of epochs. `losses` is handed to train() on every epoch
# and is plotted in the next cell.
version = 1  # forwarded unchanged to train() (and later to evaluate())
epochs = 5
losses = []
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train(train_loader, model, optimizer, losses, version)
print("Done!")

In [None]:
# Training-loss curve: one point per entry of `losses`, indexed from 0.
fig, ax = plt.subplots(dpi=150)
ax.plot(losses)  # implicit x-axis 0..len(losses)-1
ax.set_title(f"Batch size = {batch_size}, No. of examples = {train_len}")
ax.set_ylabel("Loss (training)")
ax.set_xlabel("Epoch per batch")
plt.show()

# Check loss on test set

In [None]:
# Average loss on the held-out set with its uncertainty; the third return value is
# kept in `test_batch_losses`.
test_loss, dtest_loss, test_batch_losses = evaluate(test_loader, model, device, version)
print("Test average loss: {:.6f} +- {:.6f}".format(test_loss, dtest_loss))

In [None]:
#fig = plt.figure(dpi=150)
#plt.plot(np.arange(len(test_batch_losses)),test_batch_losses,linestyle='',marker='o',markersize=10)
#plt.title("Batch size = {0}, No. of examples = {1}".format(batch_size,test_len))
#plt.ylabel("Loss (test)")
#plt.xlabel("ID")
#plt.show()

In [None]:
# Presumably writes the model's predictions for every test-set batch to disk.
# See utils.SavePredictions for the output location and format.
ut.SavePredictions(test_loader, model, device)

In [None]:
# Predict the test vectors for the first test batch and normalize each one.
# eval() switches layers such as BatchNorm/Dropout (if present) to inference behaviour,
# so this cell does not update BatchNorm running statistics. The previous mode is
# restored afterwards so later training cells are unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        batch = next(iter(test_loader))
        confs_batch = batch[0].to(device)          # (B, …)
        pred = model(confs_batch)                  # (B, 4*NV, NT, NX)
        confsID = batch[2]
        # Use the real batch size: a batch can be smaller than `batch_size`
        # (e.g. when the test set holds fewer than `batch_size` configurations).
        B = pred.shape[0]
        pred = pred.view(B, var.NV, 4, var.NT, var.NX)   # (B,NV,4,NT,NX)
        # Channels 0,1 hold the real parts and 2,3 the imaginary parts of the two components.
        pred_complex = torch.complex(pred[:, :, 0:2], pred[:, :, 2:4])  # (B,NV,2,NT,NX)
        # Normalize every test vector over its (component, t, x) entries.
        norms = torch.linalg.vector_norm(pred_complex, dim=(-3, -2, -1), keepdim=True)  # (B,NV,1,1,1)
        pred_complex_normalized = (pred_complex / norms).cpu().numpy()
finally:
    model.train(was_training)

In [None]:
# Compare one SAP test vector with the corresponding model prediction in the complex
# plane (Re on x, Im on y). Each vector is divided by its own norm so both have unit norm.
tv = 0
conf = 5    # position of the configuration inside the test batch (not its configuration ID)
factor = 1  # multiplies the predicted vector before its imaginary part is plotted
sap_tv = np.asarray(batch[1][conf][tv]).reshape(-1)
fake_tv = pred_complex_normalized[conf, tv].reshape(-1)
norm = np.linalg.norm(sap_tv)
norm_pred = np.linalg.norm(fake_tv)
print("Norm", norm, norm_pred)

fig, ax = plt.subplots()
# Show the actual configuration ID (batch[2]), not the position inside the batch.
ax.set_title("Components distribution for conf {0} and test vector {1}".format(int(confsID[conf]), tv))
ax.scatter(np.real(sap_tv) / norm, np.imag(sap_tv) / norm, marker="*", label="SAP tv")
ax.scatter(np.real(fake_tv) / norm_pred, np.imag(factor * fake_tv) / norm_pred, label="Fake tv")
ax.set_xlabel("Re")
ax.set_ylabel("Im")
ax.legend()
plt.show()

In [None]:
# Predict the test vectors for one (shuffled) training batch and normalize each one.
# eval() switches layers such as BatchNorm/Dropout (if present) to inference behaviour,
# so this cell does not update BatchNorm running statistics. The previous mode is
# restored afterwards so later training cells are unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        batch = next(iter(train_loader))
        confs_batch = batch[0].to(device)          # (B, …)
        pred = model(confs_batch)                  # (B, 4*NV, NT, NX)
        confsID = batch[2]
        # Use the real batch size: a batch can be smaller than `batch_size`.
        B = pred.shape[0]
        pred = pred.view(B, var.NV, 4, var.NT, var.NX)   # (B,NV,4,NT,NX)
        # Channels 0,1 hold the real parts and 2,3 the imaginary parts of the two components.
        pred_complex = torch.complex(pred[:, :, 0:2], pred[:, :, 2:4])  # (B,NV,2,NT,NX)
        # Normalize every test vector over its (component, t, x) entries.
        norms = torch.linalg.vector_norm(pred_complex, dim=(-3, -2, -1), keepdim=True)  # (B,NV,1,1,1)
        pred_complex_normalized = (pred_complex / norms).cpu().numpy()
finally:
    model.train(was_training)

In [None]:
# Compare one SAP test vector with the corresponding model prediction for the training
# batch, in the complex plane. Each vector is divided by its own norm.
tv = 10
conf = 1  # position of the configuration inside the training batch (not its configuration ID)
sap_tv = np.asarray(batch[1][conf][tv]).reshape(-1)
fake_tv = pred_complex_normalized[conf, tv].reshape(-1)
norm = np.linalg.norm(sap_tv)
norm_pred = np.linalg.norm(fake_tv)
print("Norm", norm, norm_pred)

fig, ax = plt.subplots()
# Show the actual configuration ID (batch[2]), not the position inside the batch.
ax.set_title("Components distribution for conf {0} and test vector {1}".format(int(confsID[conf]), tv))
ax.scatter(np.real(sap_tv) / norm, np.imag(sap_tv) / norm, marker="*", label="SAP tv")
ax.scatter(np.real(fake_tv) / norm_pred, np.imag(fake_tv) / norm_pred, label="Fake tv")
ax.set_xlabel("Re")
ax.set_ylabel("Im")
ax.legend()
plt.show()