In [None]:
import ast
import gc
import math

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.utils.data import TensorDataset, DataLoader, Sampler, Dataset

In [None]:
class TRDataset(Dataset):
    """TR amplitude, F_gn, Dataset.

    The CSV is parsed once in ``__init__``; every row is expanded into one
    sample per non-zero amplitude. ``__getitem__`` returns fixed-size tensors
    (padded permutation vector, padded COO encodings of the B and C tensors,
    and the matching masks) so that samples can be batched by a DataLoader.
    """

    # Fixed sizes every sample is padded to.
    MAX_PARTITION_LEN = 52   # length of the padded permutation vector
    MAX_COO_ROWS = 1500      # number of rows of the padded B / C COO tables

    # (g, n) combinations whose samples are removed from the table.
    EXCLUDED_GN = ((1, 2), (1, 1), (0, 3), (0, 4), (0, 5))

    def __init__(self, csv_file, log_scale=True, mode='train', genus = 15):
        """Initializes instance of class TRDataset.

        Args:
            csv_file (str): Path to the csv file with the amplitudes data.
                The file must contain the columns 'g', 'n', 'Permutations'
                and 'Fgn'; the remaining fields are read by position
                (columns 0, 1, 3, 4, 6, 7 -> g, n, b, c, x, y).
                NOTE(review): columns 6/7 are presumably 'Permutations' and
                'Fgn' -- confirm against the CSV header.
            log_scale (bool): If True, ``__getitem__`` returns log(amplitude).
            mode (str): 'train' keeps rows with g < 14; any other value keeps
                rows with g >= ``genus``.
            genus (int): Lower genus bound used when ``mode`` is not 'train'.
        """
        self.log_scale = log_scale

        chunk = pd.read_csv(csv_file, chunksize=64)
        df = pd.concat(chunk)
        if mode == 'train':
            df = df[df['g'] < 14]
        else:
            df = df[df['g'] >= genus]

        # DataFrame.applymap is deprecated in favour of DataFrame.map
        # (pandas >= 2.1); use whichever this pandas version provides.
        if hasattr(pd.DataFrame, 'map'):
            df = df.map(self._parse_cell)
        else:
            df = df.applymap(self._parse_cell)

        # Drop the permutations whose amplitude is exactly zero, then the
        # zero amplitudes themselves, so both columns stay aligned.
        df['Permutations'] = df.apply(lambda row: np.delete(row['Permutations'],
                                                            np.where(row['Fgn'] == 0)[0], axis=0), axis = 1)
        df['Fgn'] = df.apply(lambda row: row['Fgn'][row['Fgn'] != 0], axis = 1)
        df = df.explode(['Permutations','Fgn'], ignore_index=True)
        df['Fgn'] = pd.to_numeric(df['Fgn'])

        # Remove the excluded (g, n) pairs with one explicit boolean mask
        # instead of chained boolean indexing (which relies on pandas
        # reindexing a mask built on a different frame).
        excluded = np.zeros(len(df), dtype=bool)
        for g_val, n_val in self.EXCLUDED_GN:
            excluded |= ((df['g'] == g_val) & (df['n'] == n_val)).to_numpy(dtype=bool)
        df = df[~excluded]

        self.g = df.iloc[:,0].to_numpy()
        self.n = df.iloc[:,1].to_numpy()
        self.b = df.iloc[:,3].to_numpy()
        self.c = df.iloc[:,4].to_numpy()
        self.x = df.iloc[:,6].to_numpy()
        self.y = df.iloc[:,7].to_numpy()

        # The frame is no longer needed once the columns are extracted.
        del df
        gc.collect()

    @staticmethod
    def _parse_cell(cell):
        """Parse one CSV cell written with ``{...}`` lists and ``*^`` exponents into a float array."""
        text = (str(cell)
                .replace('{}', '{{{0.}}}')
                .replace('{', '[')
                .replace('}', ']')
                .replace('*^', 'e'))
        return np.array(ast.literal_eval(text), dtype=float)

    def _padded_coo(self, dense, transform, dtype):
        """Encode a dense 3-index array as a padded COO table plus padding mask.

        Args:
            dense (np.ndarray): Array with (at least) three axes.
            transform (callable): Applied to the non-zero values before they
                are stored in the last column.
            dtype (torch.dtype): dtype of the returned table.

        Returns:
            tuple: ``(table, is_padding)`` where ``table`` has shape
            ``(MAX_COO_ROWS, 4)`` with rows ``(i, j, k, transform(value))`` and
            ``is_padding`` is a bool tensor that is True on the padded rows.

        Raises:
            ValueError: If ``dense`` has more than ``MAX_COO_ROWS`` non-zeros.
        """
        nz = np.nonzero(dense)
        n_rows = nz[0].shape[0]
        if n_rows > self.MAX_COO_ROWS:
            raise ValueError(
                f"tensor has {n_rows} non-zero entries, more than the padded size {self.MAX_COO_ROWS}")

        coo = np.zeros((n_rows, 4))
        coo[:, 0] = nz[0]
        coo[:, 1] = nz[1]
        coo[:, 2] = nz[2]
        # dense[nz] lists the non-zero values in the same order as the indices.
        coo[:, -1] = transform(dense[nz])

        table = F.pad(torch.tensor(coo, dtype=dtype),
                      (0, 0, 0, self.MAX_COO_ROWS - n_rows), "constant", 0)

        is_padding = torch.ones(self.MAX_COO_ROWS, dtype=torch.bool)
        is_padding[:n_rows] = False
        return table, is_padding

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx):
        """Return the padded tensors of sample ``idx``.

        Returns:
            tuple: ``(permutations, amplitudes, gn, bz, cz, mask_x, mask_b, mask_c)``.
            ``mask_x`` is True on the *valid* permutation entries, while
            ``mask_b`` / ``mask_c`` are True on the *padded* COO rows.
            ``bz`` is float32 and ``cz`` is float64.
        """
        X = torch.tensor(self.x[idx], dtype=torch.int)
        card = X.shape[0]
        permutations = F.pad(X, (0, self.MAX_PARTITION_LEN - card), "constant", 0)

        mask_x = torch.zeros(self.MAX_PARTITION_LEN, dtype=torch.bool)
        mask_x[:card] = True

        G = torch.tensor(self.g[idx], dtype=torch.int)
        N = torch.tensor(self.n[idx], dtype=torch.int)
        gn = torch.stack([G, N])

        # C values are exponentiated, B values are stored as log(1 + 100 * b).
        cz, mask_c = self._padded_coo(self.c[idx], np.exp, torch.float64)
        bz, mask_b = self._padded_coo(self.b[idx], lambda v: np.log(1 + 100 * v), torch.float32)

        amplitudes = torch.tensor(self.y[idx], dtype=torch.float32)
        if self.log_scale:
            amplitudes = torch.log(amplitudes)

        return permutations, amplitudes, gn, bz, cz, mask_x, mask_b, mask_c

In [None]:
# Build the dataset in 'test' mode (rows with g >= 14) and wrap it in a
# deterministic loader. Note that, despite its name, `train_loader` iterates
# over this test split.
Ds = TRDataset(
    'data_decimals.csv',
    log_scale=True,
    mode='test',
    genus=14,
)
# drop_last=True keeps every batch at exactly 50 samples.
train_loader = DataLoader(Ds, batch_size=50, shuffle=False, num_workers=4, drop_last=True)

In [None]:
def compute_multiplicities_np(partitions, ns):
    """Count the occurrences of the values 0, 1, 2 and 3 in each partition.

    Only the first ``ns[i]`` entries of row ``i`` are inspected; anything
    after that (e.g. padding) is ignored.

    Args:
        partitions (np.ndarray): 2-D array, one partition per row.
        ns (np.ndarray): Integer array with the number of leading entries of
            each row to take into account.

    Returns:
        np.ndarray: Integer array of shape ``(len(partitions), 4)`` whose
        column ``v`` holds how often the value ``v`` occurs.
    """
    n_rows = partitions.shape[0]
    counts = np.zeros((n_rows, 4), dtype=int)
    for row in range(n_rows):
        head = partitions[row, :ns[row]]
        counts[row] = [np.sum(head == value) for value in range(4)]
    return counts

def compute_G_values_np(counts, ns):
    """Compute G0, G1, G2, G3 values for each partition.

    Args:
        counts (np.ndarray): Array of shape ``(batch, 4)``; column ``v`` is the
            multiplicity of the value ``v`` in the partition (p0..p3).
        ns (array-like): Length-``batch`` array with the n of each partition.

    Returns:
        tuple: ``(G, total)`` where ``G`` is a float array of shape
        ``(batch, 4)`` with columns G0..G3 and ``total`` is the per-row sum
        ``G0 + G1 + G2 + G3`` with shape ``(batch,)``.
    """
    ns = np.asarray(ns).astype(float)
    p0, p1, p2, p3 = counts[:, 0], counts[:, 1], counts[:, 2], counts[:, 3]

    # G0 is identically one. It is sized from the input (instead of a
    # hard-coded batch size) and kept as a numpy array so the whole
    # computation stays in numpy.
    G0 = np.ones_like(ns)

    G1 = ((ns - 1) * (ns - 6) + (5 - p0) * p0) / 12

    # Compute G2
    G2 = ((ns - 1) * (3 * ns**3 - 59 * ns**2 + 298 * ns - 228) +
          p0 * (346 - 390 * ns + 30 * ns**2) +
          p0**2 * (69 + 78 * ns - 6 * ns**2) -
          46 * p0**3 + 3 * p0**4 -
          p1 * (204 - 180 * p0 + 36 * p0**2) -
          60 * p2) / 864

    # Compute G3
    G3 = (ns**6 - 41 * ns**5 + 555 * ns**4 - 3031 * ns**3 + 6092 * ns**2 - 5160 * ns + 1584 +
          -p0**6 + 31 * p0**5 +
          p0**4 * (ns**2 - 19 * ns - 73 + 12 * p1) -
          p0**3 * (46 * ns**2 - 874 * ns + 552 * p1 + 120 * p2 + 127) +
          p0**2 * (-3 * ns**4 + 98 * ns**3 - 36 * ns**2 * (p1 + 20) + ns * (684 * p1 - 1253)) +
          p0**2 * (-54 * p1**2 + 312 * p1 + 285 * p2 + 409) +
          p0 * (15 * ns**4 - 490 * ns**3 + ns**2 * (4291 + 180 * p1) - 12 * ns * (572 + 285 * p1) +
                90 * p1**2 + 171 * p1 - 285 * p2 - 70 * p3 + 258) -
          102 * p1**2 + p1 * (17 * ns**2 - 323 * ns + 60 * p2 + 402) +
          5 * (ns**2 * p2 - 19 * ns * p2 - 28 * p3)) / 10368

    return np.column_stack((G0, G1, G2, G3)), G0 + G1 + G2 + G3


In [None]:
import numpy as np
from scipy.special import gamma
from functools import reduce
from operator import mul

def double_factorial(n):
    """Compute the double factorial n!! through the gamma function.

    Even values use ``2**(n/2) * Gamma(n/2 + 1)``; every other value
    (odd integers and non-integers) uses the odd-n continuation
    ``2**(n/2) * Gamma(n/2 + 1) * sqrt(2/pi)``.

    Args:
        n (array-like or scalar): Values to evaluate. Integer input is
            converted to float first, so ``2**(n/2)`` cannot overflow an
            integer dtype.

    Returns:
        np.ndarray: Float array with the shape of ``n`` (0-d for a scalar).
    """
    n = np.asarray(n, dtype=float)
    half = n / 2
    # For even n, n / 2 equals n // 2, so one shared expression serves both
    # branches; only the sqrt(2/pi) factor differs.
    base = 2.0 ** half * gamma(half + 1)
    return np.where(n % 2 == 0, base, base * np.sqrt(2 / np.pi))

def intersection_number_asymptotic(g, n, ds, include_degrees=False):
    """Compute the log of the asymptotic intersection number.

    The value is ``log((6g - 5 + 2n)!! / (24**g * Gamma(g + 1)))``. When
    ``include_degrees`` is True, ``sum_i log((2 d_i + 1)!!)`` is subtracted
    as well, i.e. the product of the degree double factorials is included in
    the denominator.

    Args:
        g (np.ndarray): Genus values.
        n (np.ndarray): Number of marked points, broadcastable against ``g``.
        ds (np.ndarray): 2-D array of degrees, one row per sample. Only read
            when ``include_degrees`` is True.
        include_degrees (bool): Whether to include the degree-dependent
            denominator. Defaults to False, which matches the value this
            function has always returned.

    Returns:
        np.ndarray: Log of the asymptotic value for every sample.
    """
    top = double_factorial(6 * g - 5 + 2 * n)
    bottom_genus = 24**g * gamma(g + 1)
    log_value = np.log(top / bottom_genus)

    if include_degrees:
        # Sum logs rather than multiplying the factors, so the product over
        # many degrees cannot overflow.
        log_value = log_value - np.sum(np.log(double_factorial(2 * ds + 1)), axis=1)

    return log_value


def compute_gradient_wrt_g(g, n, ds, epsilon=0.00001):
    """Estimate d/dg of ``intersection_number_asymptotic`` numerically.

    Uses a symmetric (central) finite difference with step ``epsilon``
    around ``g``; ``n`` and ``ds`` are passed through unchanged.
    """
    lower = intersection_number_asymptotic(g - epsilon, n, ds)
    upper = intersection_number_asymptotic(g + epsilon, n, ds)
    return (upper - lower) / (2 * epsilon)


# Example usage
# NOTE: `gns` and `xs` are the concatenated per-sample arrays produced by the
# inference-loop cell at the end of this notebook; that cell has to be run
# before this one.
n_batch = gns
ds_batch = xs
# Evaluate every sample at genus 15. The array is sized from `n_batch`
# instead of a hard-coded sample count, so it always broadcasts.
g_batch = np.full(np.shape(n_batch), 15.0)

intersection_values = intersection_number_asymptotic(g_batch, n_batch, ds_batch)
gradient_g = compute_gradient_wrt_g(g_batch, n_batch, ds_batch)


In [None]:
import numpy as np
from scipy.special import gamma
from functools import reduce
from operator import mul


def double_factorial(n):
    """Compute the double factorial n!! through the gamma function.

    NOTE: this cell re-defines (and therefore shadows) the identically named
    function from the previous cell.

    Even values use ``2**(n/2) * Gamma(n/2 + 1)``; every other value
    (odd integers and non-integers) uses the odd-n continuation
    ``2**(n/2) * Gamma(n/2 + 1) * sqrt(2/pi)``.

    Args:
        n (array-like or scalar): Values to evaluate. Integer input is
            converted to float first, so ``2**(n/2)`` cannot overflow an
            integer dtype.

    Returns:
        np.ndarray: Float array with the shape of ``n`` (0-d for a scalar).
    """
    n = np.asarray(n, dtype=float)
    half = n / 2
    # For even n, n / 2 equals n // 2, so one shared expression serves both
    # branches; only the sqrt(2/pi) factor differs.
    base = 2.0 ** half * gamma(half + 1)
    return np.where(n % 2 == 0, base, base * np.sqrt(2 / np.pi))

def intersection_number_asymptotic(g, n, ds, include_degrees=False):
    """Compute the log of the asymptotic intersection number.

    NOTE: this cell re-defines (and therefore shadows) the identically named
    function from the previous cell.

    The value is ``log((6g - 5 + 2n)!! / (24**g * Gamma(g + 1)))``. When
    ``include_degrees`` is True, ``sum_i log((2 d_i + 1)!!)`` is subtracted
    as well, i.e. the product of the degree double factorials is included in
    the denominator.

    Args:
        g (np.ndarray): Genus values.
        n (np.ndarray): Number of marked points, broadcastable against ``g``.
        ds (np.ndarray): 2-D array of degrees, one row per sample. Only read
            when ``include_degrees`` is True.
        include_degrees (bool): Whether to include the degree-dependent
            denominator. Defaults to False, which matches the value this
            function has always returned.

    Returns:
        np.ndarray: Log of the asymptotic value for every sample.
    """
    top = double_factorial(6 * g - 5 + 2 * n)
    bottom_genus = 24**g * gamma(g + 1)
    log_value = np.log(top / bottom_genus)

    if include_degrees:
        # Sum logs rather than multiplying the factors, so the product over
        # many degrees cannot overflow.
        log_value = log_value - np.sum(np.log(double_factorial(2 * ds + 1)), axis=1)

    return log_value


def compute_gradient_wrt_g(g, n, ds, epsilon=0.00001):
    """Estimate d/dg of ``intersection_number_asymptotic`` numerically.

    Same definition as in the previous cell: a symmetric (central) finite
    difference with step ``epsilon`` around ``g``; ``n`` and ``ds`` are
    passed through unchanged.
    """
    below = intersection_number_asymptotic(g - epsilon, n, ds)
    above = intersection_number_asymptotic(g + epsilon, n, ds)
    step = 2 * epsilon
    return (above - below) / step


# Example usage (same evaluation as in the previous cell)
# NOTE: `gns` and `xs` are the concatenated per-sample arrays produced by the
# inference-loop cell at the end of this notebook; that cell has to be run
# before this one.
n_batch = gns
ds_batch = xs
# Evaluate every sample at genus 15. The array is sized from `n_batch`
# instead of a hard-coded sample count, so it always broadcasts.
g_batch = np.full(np.shape(n_batch), 15.0)

intersection_values = intersection_number_asymptotic(g_batch, n_batch, ds_batch)
gradient_g = compute_gradient_wrt_g(g_batch, n_batch, ds_batch)


In [None]:

def log_double_factorial(n):
    """Compute log(n!!) element-wise through the log-gamma function.

    Even values use ``(n/2) * log 2 + lgamma(n/2 + 1)``; every other value
    (odd integers and non-integers) additionally gets ``0.5 * log(2/pi)``.

    Args:
        n (torch.Tensor): Values to evaluate, on any device. Integer tensors
            are converted to float32; floating tensors keep their dtype
            (so float64 input is supported and stays float64).

    Returns:
        torch.Tensor: Tensor with the shape and device of ``n``.
    """
    if not torch.is_floating_point(n):
        n = n.to(torch.float32)

    half = n / 2
    # For even n, n / 2 equals n // 2, so both branches share this term.
    # Python-float constants avoid building helper tensors on every call.
    log_value = half * math.log(2.0) + torch.lgamma(half + 1)
    odd_correction = 0.5 * math.log(2.0 / math.pi)

    # torch.where instead of masked assignment into a float32 buffer, which
    # failed with a dtype mismatch for float64 input.
    return torch.where(n % 2 == 0, log_value, log_value + odd_correction)

def intersection_number_asymptotic(g, n, ds):
    """Log of the asymptotic intersection number, split into two parts.

    Args:
        g (torch.Tensor): Genus values.
        n (torch.Tensor): Number of marked points, broadcastable against ``g``.
        ds (torch.Tensor): 2-D tensor of degrees, one row per sample.

    Returns:
        tuple: ``(log_result, log_bottom_degrees)`` where ``log_result`` is
        ``log((6g - 5 + 2n)!!) - (g * log 24 + lgamma(g + 1))`` and
        ``log_bottom_degrees`` is ``sum_i log((2 d_i + 1)!!)``. The second
        term is returned separately and is NOT subtracted from the first.
    """
    numerator = log_double_factorial(6 * g - 5 + 2 * n)
    genus_term = g * torch.log(torch.tensor(24.0)) + torch.lgamma(g + 1)
    degree_term = log_double_factorial(2 * ds + 1).sum(dim=1)
    return numerator - genus_term, degree_term

def intersection_number_asymptotic_X(g, n, ds):
    """Evaluate the alternative asymptotic formula in log form.

    With ``w = 2g - 2 + n`` the returned value is
    ``n * log 2 - log(4 * pi) + lgamma(w) - w * log(2/3) - sum_i log((2 d_i + 1)!!)``.

    Args:
        g (torch.Tensor): Genus values.
        n (torch.Tensor): Number of marked points, broadcastable against ``g``.
        ds (torch.Tensor): 2-D tensor of degrees, one row per sample.

    Returns:
        tuple: ``(log_result, log_part4)`` where ``log_part4`` is the
        ``w * log(2/3)`` term on its own.
    """
    weight = 2 * g - 2 + n

    power_of_two = n * torch.log(torch.tensor(2.0))
    prefactor = -torch.log(torch.tensor(4.0 * torch.pi))
    gamma_term = torch.lgamma(weight)
    ratio_term = weight * torch.log(torch.tensor(2.0/3.0))
    degree_term = torch.sum(log_double_factorial(2 * ds + 1), dim=1)

    total = power_of_two + prefactor + gamma_term - ratio_term - degree_term
    return total, ratio_term

def intersection_number_asymptotic(g, n, ds):
    """Log of the asymptotic intersection number, split into two parts.

    NOTE: this is a second, equivalent definition of the function already
    defined earlier in this cell; being the later one, it is the definition
    in effect afterwards.

    Args:
        g (torch.Tensor): Genus values.
        n (torch.Tensor): Number of marked points, broadcastable against ``g``.
        ds (torch.Tensor): 2-D tensor of degrees, one row per sample.

    Returns:
        tuple: ``(log_result, log_bottom_degrees)`` where ``log_result`` is
        ``log((6g - 5 + 2n)!!) - (g * log 24 + lgamma(g + 1))`` and
        ``log_bottom_degrees`` is ``sum_i log((2 d_i + 1)!!)``. The second
        term is returned separately and is NOT subtracted from the first.
    """
    log_numerator = log_double_factorial(6 * g - 5 + 2 * n)
    log_genus_denominator = g * torch.log(torch.tensor(24.0)) + torch.lgamma(g + 1)
    log_degree_denominator = log_double_factorial(2 * ds + 1).sum(dim=1)
    return log_numerator - log_genus_denominator, log_degree_denominator


def compute_gradient_wrt_g(g, n, ds, epsilon=0.01):
    """Estimate d/dg of the log asymptotic value by central differences.

    Only the first return value of ``intersection_number_asymptotic`` (the
    degree-independent part) is differenced; ``n`` and ``ds`` are held fixed.
    """
    lower = intersection_number_asymptotic(g - epsilon, n, ds)[0]
    upper = intersection_number_asymptotic(g + epsilon, n, ds)[0]
    return (upper - lower) / (2 * epsilon)

def compute_gradient_wrt_n(g, n, ds, epsilon=0.01):
    """Estimate d/dn of the log asymptotic value by central differences.

    Only the first return value of ``intersection_number_asymptotic`` (the
    degree-independent part) is differenced; ``g`` and ``ds`` are held fixed.
    """
    lower = intersection_number_asymptotic(g, n - epsilon, ds)[0]
    upper = intersection_number_asymptotic(g, n + epsilon, ds)[0]
    return (upper - lower) / (2 * epsilon)

def compute_gradient_wrt_ds(g, n, ds, epsilon=0.001):
    """Estimate the derivative of the log asymptotic value w.r.t. the degrees.

    All entries of ``ds`` are shifted together by ``+/- epsilon`` and a central
    difference is taken, so the result is the directional derivative along a
    uniform shift of every degree (one value per sample), not a per-degree
    gradient.

    The first return value of ``intersection_number_asymptotic`` does not
    depend on ``ds``, so differencing it alone always yields zero. The
    degree-dependent term (second return value) is therefore subtracted
    before differencing, i.e. the full value
    ``log_result - log_bottom_degrees`` is used.
    NOTE(review): every column of ``ds`` is shifted, including any padding
    columns -- confirm this is the intended direction.

    Args:
        g (torch.Tensor): Genus values.
        n (torch.Tensor): Number of marked points.
        ds (torch.Tensor): 2-D float tensor of degrees, one row per sample.
        epsilon (float): Finite-difference step.

    Returns:
        torch.Tensor: One derivative estimate per sample.
    """
    def full_log_value(shifted_ds):
        log_result, log_bottom_degrees = intersection_number_asymptotic(g, n, shifted_ds)
        return log_result - log_bottom_degrees

    f_ds_minus_epsilon = full_log_value(ds - epsilon)
    f_ds_plus_epsilon = full_log_value(ds + epsilon)
    return (f_ds_plus_epsilon - f_ds_minus_epsilon) / (2 * epsilon)


# Evaluate the torch formulas on the collected per-sample arrays.
# `gs`, `gns` and `xs` are created by the inference-loop cell at the end of
# this notebook, so that cell has to be run before this one.
g, n, ds = (
    torch.tensor(values, dtype=torch.float32)
    for values in (gs, gns, xs)
)

intersection_values, parts = intersection_number_asymptotic(g, n, ds)
intersection_X, parts_a = intersection_number_asymptotic_X(g, n, ds)

gradient_g = compute_gradient_wrt_g(g, n, ds)
gradient_n = compute_gradient_wrt_n(g, n, ds)
#gradient_ds = compute_gradient_wrt_ds(g, n, ds)

In [None]:
def log_double_factorial(n):
    """Compute log(n!!) element-wise through the log-gamma function.

    NOTE: this cell re-defines the identically named function from the
    previous cell.

    Even values use ``(n/2) * log 2 + lgamma(n/2 + 1)``; every other value
    (odd integers and non-integers) additionally gets ``0.5 * log(2/pi)``.

    Args:
        n (torch.Tensor): Values to evaluate, on any device. Integer tensors
            are converted to float32; floating tensors keep their dtype
            (so float64 input is supported and stays float64).

    Returns:
        torch.Tensor: Tensor with the shape and device of ``n``.
    """
    if not torch.is_floating_point(n):
        n = n.to(torch.float32)

    half = n / 2
    # For even n, n / 2 equals n // 2, so both branches share this term.
    # Python-float constants avoid building helper tensors on every call.
    log_value = half * math.log(2.0) + torch.lgamma(half + 1)
    odd_correction = 0.5 * math.log(2.0 / math.pi)

    # torch.where instead of masked assignment into a float32 buffer, which
    # failed with a dtype mismatch for float64 input.
    return torch.where(n % 2 == 0, log_value, log_value + odd_correction)


# Run the model over every batch of `train_loader` and collect, per sample:
# the conjectured G values, the predictions, intermediate model outputs,
# the targets, the inputs and the gradient of the prediction w.r.t. (g, n).
# NOTE(review): `model`, `device` and `tqdm` are not defined in any cell shown
# in this notebook section -- presumably they come from a cell that is not
# visible here; confirm before a Restart & Run All.

# One independent accumulator list per collected quantity.
conj_1, conj_2, p, emb_1, emb_2, emb_3, emb_4, attn_p, attn_b, ys, xs, ns, gs, gradients_gn = ([] for _ in range(14))


for i, (x, y, gn, b, c, mask_x, mask_b, mask_c) in enumerate(tqdm(train_loader)):
    #if i == 500:
    #    break

    # Compute multiplicities
    # (column 1 of gn is n; it limits how many entries of each row are counted)
    partition_counts = compute_multiplicities_np(x.numpy(), gn[:, 1].numpy())
    g_values = compute_G_values_np(partition_counts, gn[:, 1].numpy())
    conj_1.append(g_values[0])  # per-sample (G0, G1, G2, G3)
    conj_2.append(g_values[1])  # per-sample G0 + G1 + G2 + G3

    # Float copy of gn on the target device with gradient tracking enabled,
    # so that gn.grad is populated by the backward() call below.
    gn = gn.to(device).float().requires_grad_(True)
    xs.append(x)

    outputs = model(x.to(device), gn, b.float().to(device), c.float().to(device),
                    mask_x.to(device), mask_b.to(device), mask_c.to(device))

    # outputs[0] is exponentiated to obtain the prediction
    # (presumably the model predicts the log-amplitude -- TODO confirm).
    predictions = torch.exp(outputs[0]).squeeze().detach().cpu()
    p.append(predictions)

    # outputs[1:8] are stored as numpy arrays; only the first six are kept.
    # NOTE(review): the meaning of each entry depends on the model's return
    # signature, which is not visible here.
    embeddings = [output.squeeze().detach().cpu().numpy() for output in outputs[1:8]]
    emb_1.append(embeddings[0])
    emb_2.append(embeddings[1])
    emb_3.append(embeddings[2])
    emb_4.append(embeddings[3])
    attn_p.append(embeddings[4])
    attn_b.append(embeddings[5])
    #attn_c.append(embeddings[6])
    
    

    # log_bottom_degrees is computed from x only, so adding it does not change
    # the gradient w.r.t. gn collected below.
    # NOTE(review): if outputs[0] has shape (batch, 1) -- the .squeeze() above
    # suggests it might -- this addition broadcasts to (batch, batch) and the
    # gradients below are scaled by the batch size; confirm the shape.
    log_bottom_degrees = torch.sum(log_double_factorial(2 * x.float().to(device) + 1), dim=1)
    predictions_tensor = outputs[0] + log_bottom_degrees
    # Backward with an all-ones upstream gradient, i.e. the gradient of the
    # sum of all entries of predictions_tensor.
    # NOTE(review): this also accumulates .grad on any model parameters that
    # require grad; they are never zeroed in this loop.
    predictions_tensor.backward(torch.ones_like(predictions_tensor))
    gradients_gn.append(gn.grad.cpu().numpy())

    # y is exponentiated before storing (undoes the log applied by the
    # dataset when log_scale=True).
    ys.append(torch.exp(y).cpu().detach().numpy())
    ns.append(gn[:, 1].detach().cpu().numpy())
    gs.append(gn[:, 0].detach().cpu().numpy())


# Concatenate the per-batch pieces into one array per quantity.
p = np.concatenate(p)
emb_1 = np.concatenate(emb_1)
emb_2 = np.concatenate(emb_2)
emb_3 = np.concatenate(emb_3)
emb_4 = np.concatenate(emb_4)
attn_p = np.concatenate(attn_p)
attn_b = np.concatenate(attn_b)
ys = np.concatenate(ys)
xs = np.concatenate(xs)
# The concatenated n values are stored as `gns`; `ns` itself stays a list of
# per-batch arrays.
gns = np.concatenate(ns)
gs = np.concatenate(gs)
conj_1 = np.concatenate(conj_1)
conj_2 = np.concatenate(conj_2)
gradients_gn = np.concatenate(gradients_gn)

torch.cuda.empty_cache()
