In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.parallel
import random
import parameters as var #Configuration and coarsening parameters
var.init() #initializes parameters
import utils as ut #Some utility functions 
import loss_function as lf #Custom loss function

import operators_torch as op #Interpolator and prolongator given a set of test vectors
from opendataset import ConfsDataset #class for opening gauge confs
import model as mod #import machine learning model

#--------Most likely I will need some of these in the future-------#
#import matplotlib.pyplot as plt
#import torch.optim as optim
#import torch.utils.data
#import torchvision.datasets as dset
#import torchvision.transforms as transforms
#import torchvision.utils as vutils
# Call the parameters module's own report routine (defined outside this
# notebook; presumably it echoes the current configuration).
var.print_parameters()

# Pick a backend name by preference: CUDA first, then MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Seed every random number generator used by the notebook so that weight
# initialisation and DataLoader shuffling repeat across
# "Restart Kernel -> Run All". Change the value to get a different run.
manual_seed = 999
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)

# Batch size during training
batch_size = 32
# Learning rate for optimizers
lr = 0.1
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.9
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 0

In [None]:
# Load the gauge configurations and the near-kernel test vectors.
workers = 4  # number of DataLoader worker processes
dataset = ConfsDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Fetch one batch for inspection. Per the author's note each batch is a pair
#   [ [batch_size, 4, Nx, Nt], [batch_size, Nv, 2, Nx, Nt] ]
# where the first entry is real and the second one complex.
# TODO(review): the training cell reshapes with (NT, NX) as the trailing axes;
# confirm which axis order ConfsDataset actually produces.
first_batch = next(iter(dataloader))

# Effective training device. This rebinds `device` (previously a plain string)
# to a torch.device, and only selects CUDA when ngpu > 0.
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0" if use_cuda else "cpu")

# Handy one-off checks on the first batch:
#first_batch[0][0].shape
#print("Re(U0)",first_batch[0][0][0,0,0])
#print("Re(U1)",first_batch[0][0][1,0,0])
#print("Im(U0)",first_batch[0][0][2,0,0])
#print("Im(U1)",first_batch[0][0][3,0,0])

In [None]:
"""
Custom weights initialization
"""  
def weights_init(m):
    """Initialise the parameters of a single module in place.

    Intended for ``model.apply(weights_init)``, which calls this once for
    every submodule. The module kind is decided by a substring match on the
    class name:

    * names containing ``Conv``: weight drawn from N(mean=0.0, std=0.02);
    * names containing ``BatchNorm``: weight drawn from N(mean=1.0, std=0.02)
      and bias set to 0.

    Modules of any other kind are left untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # The name match also hits modules that own no parameters themselves
    # (e.g. a container class called "ConvBlock", or BatchNorm layers built
    # with affine=False, whose weight/bias are None), so look the parameters
    # up defensively instead of dereferencing them blindly.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)

In [None]:
# Instantiate the TvGenerator network and move it onto the chosen device.
model = mod.TvGenerator(ngpu)
model = model.to(device)

# When more than one GPU was requested on a CUDA device, wrap the model so
# that it is replicated over GPUs 0 .. ngpu-1.
multi_gpu = device.type == 'cuda' and ngpu > 1
if multi_gpu:
    model = nn.DataParallel(model, device_ids=list(range(ngpu)))

# Run ``weights_init`` on every submodule to replace the default
# initialisation of the matching layers.
model.apply(weights_init)

# Show the resulting architecture.
print(model)

In [None]:
# Adam over all model parameters; only beta1 is tuned, beta2 is fixed at 0.999.
adam_betas = (beta1, 0.999)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=adam_betas)
# Alternative optimiser kept for reference:
#optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,momentum=0.9)

In [None]:
def train(dataloader, model, optimizer):
    """Run one training epoch of ``model`` over ``dataloader``.

    For every batch the real-valued model output is regrouped into complex
    test vectors, scored against the reference near-kernel vectors with
    ``lf.CustomLossTorch`` and followed by one optimizer step. One progress
    line is printed per batch.

    Args:
        dataloader: Iterable yielding ``(confs, near_kernel)`` pairs. Both
            tensors are moved to the device the model lives on.
        model: Network whose output for a batch of ``B`` configurations can be
            reshaped to ``(B, var.NV, 4, var.NT, var.NX)``.
        optimizer: Optimizer updating ``model``'s parameters.

    Returns:
        float: Mean of the per-batch loss values over the epoch, or ``nan``
        when the dataloader yielded no batches.
    """
    model.train()
    criterion = lf.CustomLossTorch()  # instantiate once, reuse for all batches

    # Follow the model's own device rather than a notebook-level global, so
    # the function keeps working regardless of cell execution order.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    samples_seen = 0   # running count; stays correct for a smaller last batch
    loss_total = 0.0
    batch_count = 0

    for batch in dataloader:
        # ---- Load the data ------------------------------------------------
        confs_batch = batch[0].to(run_device)
        near_kernel = batch[1].to(run_device)   # reference test vectors

        # ---- Forward pass -------------------------------------------------
        # Author's note: the model returns a real tensor [B, 4*NV, NT, NX].
        pred = model(confs_batch)
        B = pred.shape[0]
        # reshape (not view) so a non-contiguous model output is accepted too.
        pred = pred.reshape(B, var.NV, 4, var.NT, var.NX)

        # ---- Real channels -> complex test vectors (B, NV, 2, NT, NX) -----
        #   channel 0 -> real part of component 0
        #   channel 1 -> real part of component 1
        #   channel 2 -> imag part of component 0
        #   channel 3 -> imag part of component 1
        real = torch.stack([pred[:, :, 0], pred[:, :, 1]], dim=2)
        imag = torch.stack([pred[:, :, 2], pred[:, :, 3]], dim=2)
        pred_complex = torch.complex(real, imag)

        # ---- Loss ---------------------------------------------------------
        # Argument order (reference first, prediction second) was deliberately
        # swapped by the author on 18/11/25; keep it unless
        # CustomLossTorch's signature says otherwise.
        loss = criterion(near_kernel, pred_complex)

        # ---- Back-propagation ---------------------------------------------
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ---- Logging ------------------------------------------------------
        loss_val = loss.item()
        samples_seen += B
        loss_total += loss_val
        batch_count += 1
        print(f"loss: {loss_val:>7f}  [{samples_seen:>5d}/{var.NO_CONFS:>5d}]")

    return loss_total / batch_count if batch_count else float("nan")

In [None]:
# Number of full passes over the dataset.
epochs = 30

for t in range(epochs):
    # Banner first, then one complete pass through the dataloader.
    epoch_banner = f"Epoch {t+1}\n-------------------------------"
    print(epoch_banner)
    train(dataloader, model, optimizer)

print("Done!")