In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import grad
import einops

import numpy as np
import matplotlib.pyplot as plt
import copy
import pickle
import os
import glob
from symm_loss_defs import *

Using cuda device


In [19]:
class SymmLoss_norm(nn.Module):

    def __init__(self, gens_list,model,device = devicef):
        super(SymmLoss_norm, self).__init__()
        
        self.model = model.to(device)
        self.device = device
        # Initialize generators (in future add different reps for inputs?)
        self.generators = einops.rearrange(gens_list, 'n w h -> n w h')
        self.generators = self.generators.to(device)
        

    
    def forward(self, input, model_rep='scalar',norm = "none"):
        
        input = input.clone().detach().requires_grad_(True)
        input = input.to(self.device)
        # Compute model output, shape [B]
        output = self.model(input)

        # Compute gradients with respect to input, shape [B, d*N], B is the batch size, d is the input irrep dimension, N is the number of particles
        grads, = torch.autograd.grad(outputs=output, inputs=input, grad_outputs=torch.ones_like(output, device=self.device), create_graph=True)
        
        grads_norm = torch.einsum('... N, ... N -> ...', grads, grads)
        print(grads_norm.mean())
        
        # Reshape grads to [B, N, d] 
        grads = einops.rearrange(grads, '... (N d) -> ... N d',d = self.generators.shape[-1])

        # Contract grads with generators, shape [n (generators), B, N, d]
        gen_grads = torch.einsum('n h d, ... N h->  n ... N d ',self.generators, grads)
        # Reshape to [n, B, (d N)]
        gen_grads = einops.rearrange(gen_grads, 'n ... N d -> n ... (N d)')

        # Dot with input [n ,B]
        differential_trans = torch.einsum('n ... N, ... N -> n ...', gen_grads, input)
        
       
        
        scalar_loss = (differential_trans ** 2).mean()
        print(f"symm loss = {scalar_loss}")
        
        scalar_loss_norm = (1/len(self.generators))*(torch.sum(differential_trans**2,dim = 0)/grads_norm).mean()
     
            
        return scalar_loss_norm



In [4]:
dinput = 4
N = 100
norm = 1

train_data = (torch.rand(N,dinput)-0.5)*norm
train_dataset = TensorDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=N, shuffle=True)

In [5]:
mymodelLorentz = inv_model(dinput = 4, init = "eta")

In [6]:
lossLorentz = SymmLoss(gens_list=gens_Lorentz, model = mymodelLorentz)

In [7]:
loss_res = lossLorentz(input = train_data)

In [8]:
print(loss_res)

tensor(5.5106e-18, device='cuda:0', grad_fn=<MeanBackward0>)


In [15]:
lossLorentzNorm = SymmLoss_norm(gens_list=gens_Lorentz, model = mymodelLorentz)

In [16]:
loss_res_norm = lossLorentzNorm(input = train_data)

symm loss = 5.510590386863952e-18


In [18]:
print(loss_res_norm)

tensor(3.3546e-18, device='cuda:0', grad_fn=<MulBackward0>)
