In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import grad
import einops

import numpy as np
import matplotlib.pyplot as plt
import copy
import pickle
import os
import glob
from symm_loss_defs import *

Using cuda device


In [55]:
class SymmLoss_pT_eta_phi(nn.Module):

    def __init__(self,model, gens_list = ["Lx", "Ly", "Lz", "Kx", "Ky", "Kz"],device = devicef):
        super(SymmLoss_pT_eta_phi, self).__init__()
        
        self.model = model.to(device)
        self.device = device
        
        # Initialize generators (in future add different reps for inputs?)
        GenList_names = []
        Lorentz_names = ["Lx", "Ly", "Lz", "Kx", "Ky", "Kz"]
        for gen in gens_list:
            if gen in Lorentz_names:
                GenList_names.append(gen)
            else:
                print(f"generator \n {gen} needs to be one of: {Lorentz_names}") #This is for now. Later will add a part that deals with calculating the transforamtion for a given generator. 
                
                # self.generators = einops.rearrange(gens_list, 'n w h -> n w h')
                # self.generators = self.generators.to(device)
        self.generators = GenList_names
        
        

    def forward(self, input, model_rep='scalar',norm = "none",nfeatures = "",eta_linlog = "lin",phi_linlog = "lin"):
        
        input = input.clone().detach().requires_grad_(True)
        input = input.to(self.device)
        if nfeatures!="":
            dim = nfeatures
        else:
            dim = 4 #self.generators.shape[-1]
        #Assuming input is shape [B,d*N] d is the number of features, N is the number of particles
        input_reshaped = einops.rearrange(input, '... (N d) -> ... N d',d = dim)
        
        E = input[:,0::dim] #assuming input features are ordered as (E,pT,eta,phi)
        
        pT = input[:,1::dim]
        
        eta = input[:,2::dim]
        if eta_linlog == "log":
            eta = torch.exp(eta)
        
        phi = input[:,3::dim]
        if phi_linlog == "log":
            phi = torch.exp(phi)
        
        
        GenList = self.generators  
        
        #dvar/dp L p, 
        ngen = len(self.generators)
        dE = torch.zeros_like(E).to(self.device)
        dpT = torch.zeros_like(pT).to(self.device)
        deta = torch.zeros_like(eta).to(self.device)
        dphi = torch.zeros_like(phi).to(self.device)
        
        
        #Here for all the Lorentz generators. Later can add options for only some of them.
        dE   = {"Lx": torch.zeros_like(E),              "Ly": torch.zeros_like(E),                "Lz":  torch.zeros_like(E),  "Kx":pT*torch.cos(phi),                    "Ky":pT*torch.sin(phi),                    "Kz":pT*torch.sinh(eta)}
        dpT  = {"Lx": pT*torch.sin(phi)*torch.sinh(eta),"Ly": -pT*torch.cos(phi)*torch.sinh(eta), "Lz":  torch.zeros_like(pT), "Kx":E*torch.cos(phi),                     "Ky":E*torch.sin(phi),                     "Kz":torch.zeros_like(pT)}
        deta = {"Lx": -1*torch.sin(phi)*torch.cosh(eta),  "Ly": torch.cos(phi)*torch.cosh(eta),     "Lz":  torch.zeros_like(eta),"Kx":-E*torch.cos(phi)*torch.tanh(eta)/pT, "Ky":-E*torch.sin(phi)*torch.tanh(eta)/pT, "Kz":E/(pT*torch.cosh(eta))}
        dphi = {"Lx":  torch.cos(phi)*torch.sinh(eta),  "Ly": torch.sin(phi)*torch.sinh(eta),     "Lz":-1*torch.ones_like(phi),"Kx":-E*torch.sin(phi)/pT,                 "Ky":E*torch.cos(phi)/pT,                  "Kz":torch.zeros_like(phi)}
        
        
        varsE = torch.empty(ngen,E.shape[0],E.shape[1]).to(self.device)
        varspT = torch.empty(ngen,E.shape[0],E.shape[1]).to(self.device)
        varseta = torch.empty(ngen,E.shape[0],E.shape[1]).to(self.device)
        varsphi = torch.empty(ngen,E.shape[0],E.shape[1]).to(self.device)
            
        for i,gen in enumerate(GenList):
            varsE[i] = dE[GenList[i]]
            varspT[i] = dpT[GenList[i]]
            varseta[i] = deta[GenList[i]]/eta if eta_linlog == "log" else deta[GenList[i]]
            varsphi[i] = dphi[GenList[i]]/phi if phi_linlog == "log" else dphi[GenList[i]]
        
        varsSymm = torch.stack((varsE,varspT,varseta,varsphi), dim = -1) #[n,B,N,d]
        #print(varsSymm.shape)
            
        # Compute model output, shape [B]
        output = self.model(input_reshaped)

        # Compute gradients with respect to input, shape [B, d*N], B is the batch size, d is the input irrep dimension, N is the number of particles
        grads_input, = torch.autograd.grad(outputs=output, inputs=input, grad_outputs=torch.ones_like(output, device=self.device), create_graph=True)
        
        # Reshape grads to [B, N, d] 
        grads_input = einops.rearrange(grads_input, '... (N d) -> ... N d',d = dim)

            
        # Dot with input [n ,B]
        differential_trans = torch.einsum('n ... N, ... N -> n ...', varsSymm, grads_input)
        
        scalar_loss = (differential_trans ** 2).mean()
            
            #add norm part here?
     
            
        return scalar_loss


In [37]:
class inv_model_pT_eta_phi(nn.Module):

    def __init__(self,dinput = 4, doutput = 1,init = "rand"):
        super(inv_model_pT_eta_phi,self).__init__()
        

        bi_tensor = torch.randn(dinput,dinput)

        if init=="eta":
            diag = torch.ones(dinput)*(-1.00)
            diag[0]=1.00
            bi_tensor = torch.diag(diag)

        elif init=="delta":
            bi_tensor = torch.diag(torch.ones(dinput)*1.00)
        
        
        bi_tensor = ((bi_tensor+torch.transpose(bi_tensor,0,1))*0.5).to(devicef)
        self.bi_tensor = torch.nn.Parameter(bi_tensor)
        self.bi_tensor.requires_grad_()

    def forward(self,x, sig = "euc", d = 3 ):
        #y = x @ (self.bi_tensor @ x.T)
        x = x.to(devicef)
        E = x[:,0::4]
        #print(E)
        pT = x[:,1::4]
        eta = x[:,2::4]
        phi = x[:,3::4]
        px = pT*torch.cos(phi)
        py = pT*torch.sin(phi)
        pz = pT*torch.tanh(eta)/torch.sqrt(1-torch.tanh(eta)**2)
        
        #z = torch.transpose(torch.stack((torch.transpose(E),torch.transpose(pT),torch.transpose(eta),torch.transpose(phi)),dim = 1))
        z = torch.cat((E,px,py,pz),dim = 1).to(devicef)
        print(z)
        y = torch.einsum("...i,ij,...j-> ...",z,self.bi_tensor,z)
       
        return y


In [49]:
class inv_model_pT_eta_phi_dim(nn.Module):

    def __init__(self,dinput = 4, doutput = 1,init = "rand"):
        super(inv_model_pT_eta_phi_dim,self).__init__()
        

        bi_tensor = torch.randn(dinput,dinput)

        if init=="eta":
            diag = torch.ones(dinput)*(-1.00)
            diag[0]=1.00
            bi_tensor = torch.diag(diag)

        elif init=="delta":
            bi_tensor = torch.diag(torch.ones(dinput)*1.00)
        
        
        bi_tensor = ((bi_tensor+torch.transpose(bi_tensor,0,1))*0.5).to(devicef)
        self.bi_tensor = torch.nn.Parameter(bi_tensor)
        self.bi_tensor.requires_grad_()

    def forward(self,x, sig = "euc", d = 3 ):
        #y = x @ (self.bi_tensor @ x.T)
        x = x.to(devicef)
        print(x)
        E = x[:,:,:,0]
        
        pT = x[:,:,:,1]
        eta = x[:,:,:,2]
        phi = x[:,:,:,3]
        px = pT*torch.cos(phi)
        py = pT*torch.sin(phi)
        pz = pT*torch.tanh(eta)/torch.sqrt(1-torch.tanh(eta)**2)
        
        #z = torch.transpose(torch.stack((torch.transpose(E),torch.transpose(pT),torch.transpose(eta),torch.transpose(phi)),dim = 1))
        ps = torch.cat((E,px,py,pz),dim = 1).to(devicef)
        z = ps.sum(dim=1)
        print(z)
        y = torch.einsum("...i,ij,...j-> ...",z,self.bi_tensor,z)
       
        return y


In [271]:
dinput = 4
N = 100
norm = 1

train_data = (torch.rand(N,dinput)-0.5)*norm
train_dataset = TensorDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=N, shuffle=True)

In [272]:
E = train_data[:,0::4]
px =train_data[:,1::4]
py = train_data[:,2::4]
pz = train_data[:,3::4]
pT = torch.sqrt(px**2+py**2)
eta = torch.arctanh(pz/torch.sqrt(px**2+py**2+pz**2))
phi = torch.arctan(py/px)

In [273]:
train_data_prime = torch.stack((E,pT,eta,phi),dim=1).squeeze()


In [274]:
mymodelLorentz = inv_model_pT_eta_phi(dinput = 4, init = "eta")

In [315]:
lossLorentz = SymmLoss_pT_eta_phi(model = mymodelLorentz)

In [316]:
loss_res = lossLorentz(input = train_data_prime)

In [317]:
print(loss_res)

tensor(9.8683e-16, device='cuda:0', grad_fn=<MeanBackward0>)


In [291]:
mymodelLorentz_orig = inv_model(dinput = 4, init = "eta")

In [292]:
lossLorentz_orig = SymmLoss(gens_list=gens_Lorentz, model = mymodelLorentz_orig)

In [252]:
mymodelLorentz_orig(train_data.to(devicef))

tensor([-0.4956, -0.2114, -0.4526, -0.6944, -0.2188, -0.4268, -0.6869, -3.2348,
        -2.0263, -2.5780, -2.1771, -2.4627, -0.7361, -1.7357, -0.3375, -2.2630,
        -1.4821, -3.2321, -1.8447, -2.6270, -1.3080, -0.2570, -0.9055, -1.1938,
        -0.9377, -0.0368, -1.5147, -0.0163, -0.9473, -0.8020, -1.4284, -1.0142,
        -0.2353, -0.3589, -0.9659, -1.0337, -1.4175, -2.6716, -0.5331, -0.9565,
        -0.0633, -1.2289, -3.7621, -0.4080, -7.5633, -3.6914, -3.0746, -1.8080,
        -3.3444, -0.3049, -0.9497, -1.2709, -0.9815, -1.4976, -0.9070, -4.1121,
        -2.6181, -3.1166, -1.8257, -0.6951, -1.0641, -0.5842, -0.7990, -2.7878,
        -0.3988, -2.3080, -0.5472, -3.2336, -0.6079, -2.0571, -1.3465, -0.4643,
        -3.0727, -1.2150, -1.0939, -0.9017, -1.7119, -1.8762, -1.6534, -2.9313,
        -0.7373, -1.0086, -2.7413, -2.0608, -0.2969, -3.9602, -0.8981, -1.1807,
        -0.5664, -1.3759, -0.3659, -0.7868, -3.5042, -1.0727, -1.1628, -0.8147,
        -1.0165, -2.7074, -1.5103, -1.16

In [255]:
mymodelLorentz(train_data_prime.to(devicef))

tensor([-0.4956, -0.2114, -0.4526, -0.6944, -0.2188, -0.4268, -0.6869, -3.2348,
        -2.0263, -2.5780, -2.1771, -2.4627, -0.7361, -1.7357, -0.3375, -2.2630,
        -1.4821, -3.2321, -1.8447, -2.6270, -1.3080, -0.2570, -0.9055, -1.1938,
        -0.9377, -0.0368, -1.5147, -0.0163, -0.9473, -0.8020, -1.4284, -1.0142,
        -0.2353, -0.3589, -0.9659, -1.0337, -1.4175, -2.6716, -0.5331, -0.9565,
        -0.0633, -1.2289, -3.7621, -0.4080, -7.5633, -3.6914, -3.0746, -1.8080,
        -3.3444, -0.3049, -0.9497, -1.2709, -0.9815, -1.4976, -0.9070, -4.1121,
        -2.6181, -3.1166, -1.8257, -0.6951, -1.0641, -0.5842, -0.7990, -2.7878,
        -0.3988, -2.3080, -0.5472, -3.2336, -0.6079, -2.0571, -1.3465, -0.4643,
        -3.0727, -1.2150, -1.0939, -0.9017, -1.7119, -1.8762, -1.6534, -2.9313,
        -0.7373, -1.0086, -2.7413, -2.0608, -0.2969, -3.9602, -0.8981, -1.1807,
        -0.5664, -1.3759, -0.3659, -0.7868, -3.5042, -1.0727, -1.1628, -0.8147,
        -1.0165, -2.7074, -1.5103, -1.16

In [293]:
loss_res_orig = lossLorentz_orig(input = train_data)

In [294]:
print(loss_res_orig)

tensor(6.2248e-18, device='cuda:0', grad_fn=<MeanBackward0>)


In [9]:
p = torch.randn(100,10,3)
E = p.norm(dim=-1)
P = torch.cat([E.unsqueeze(-1),p],dim=-1).unsqueeze(0)
train_data = P.clone().to(devicef)
train_dataset = TensorDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)

In [40]:
E = train_data[:,:,:,0]
px =train_data[:,:,:,1]
py = train_data[:,:,:,2]
pz = train_data[:,:,:,3]
pT = torch.sqrt(px**2+py**2)
eta = torch.arctanh(pz/torch.sqrt(px**2+py**2+pz**2))
phi = torch.arctan(py/px)

In [53]:
train_data_prime = torch.stack((E,pT,eta,phi),dim=-1).squeeze().flatten(dim=1)

TypeError: flatten() got an unexpected keyword argument 'dim'

In [54]:
mymodelLorentz = inv_model_pT_eta_phi_dim(dinput = 4, init = "eta")
lossLorentz = SymmLoss_pT_eta_phi(model = mymodelLorentz)
loss_res = lossLorentz(input = train_data_prime)
print(loss_res)

RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1