# Building Dataset 
We build a Pytorch Dataset Object to store and structure the graph data: 

In [1]:
# import libraries : 
import gc 
import os 
import random 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
from typing import Optional, Tuple
import seaborn as sns 

# set random seed : 
np.random.seed( 41 )
random.seed( 41 )

import torch
import psutil
import torch.nn as nn 
from torch import cdist
from torch import Tensor 
import torch.nn.functional as F 
import torch.utils.data as data 

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import  remove_self_loops , scatter 

In [2]:
# check if CPU is available for training : 

device = 'gpu'
if torch.cuda.is_available(): 
    device = 'cuda'
elif torch.mps.is_available(): 
    device = 'mps'

device = torch.device( device )
device 

device(type='mps')

In [3]:
class EventData(data.Dataset): 
    
    # initialaize the event dataset 
    def __init__(self,path:str,device,threshold_dist:float=20)->None:
        '''
        Inputs : 
            path: path to the folder where the csv file was contained.  
        '''
        super(EventData,self).__init__()
        self.events = [code[:-9] for code in os.listdir(path) if code.endswith('-hits.csv')]
        self.num_events = len(self.events)
        self.threshold_dist  = threshold_dist
        self.path = path
        self.device = device  
        
    # function returns graph type represntation of the event dataset 
    def GraphData(self,idx:int) -> Data :
        eventid = self.events[idx] 
        
        # read the required csv files : 
        hits = pd.read_csv(self.path+eventid+'-hits.csv')
        truth = pd.read_csv(self.path+eventid+'-truth.csv')
        cells = pd.read_csv(self.path+eventid+'-cells.csv')
        particles = pd.read_csv( self.path+eventid+'-particles.csv')
        particle_ids  = truth.particle_id   
        # get charges corrsponding to the hits : 
        charge = Tensor([ particles.loc[ particles.particle_id.index[ particles.particle_id == ids  ] , 'q' ].tolist()[0] if ids != 0 else 0  for ids in particle_ids ])
        del particles , particle_ids 
        gc.collect()
        # find the charges left on the hit ( q = +- 1 )
        
        
        # total number of hits : these form the NODES of our graph. 
        nhits = hits.shape[0] 
        # x , y , z spatial featuers of the hits:  
        hits_spatial = hits.to_numpy()[: , 1:4 ]
        # Add a new feature vector : the number of cells that detect the hit : 
        node_fets = np.concatenate(
            (
                hits_spatial ,
                cells.hit_id.value_counts().get( hits.hit_id , 0 ).to_numpy().reshape((-1,1))
            ), 
            axis = 1 
        )
        del cells 
        gc.collect()
        # id's related to the hits 
        # this will help to initialize the graph structure : 
        hit_ids = Tensor(hits.hit_id.to_numpy( )).int()
        volume_id = Tensor(hits.volume_id.to_numpy( ))
        layer_id = Tensor(hits.layer_id.to_numpy( ))
        
        # get the particle true hit position and momentum, we add this to the node feat matrix : 
        node_fets = np.concatenate(
            (
                node_fets , 
                truth[['tx' , 'ty' , 'tz'  ]].to_numpy() - hits_spatial , 
                truth[['tpx' , 'tpy' ,'tpz']].to_numpy()
            ), 
            axis = 1 
        )
        node_fets = Tensor( node_fets )
        hits_spatial = Tensor( hits_spatial )
        
        # here we create edge_index's for the graph skeleton : 
        batch_size = 10000  # Process 10K nodes at a time
        edges = []
        
        
        for i in range(0, nhits, batch_size):
            # set distance threshold : 
            mask = cdist( hits_spatial[i : i + batch_size , : ] , hits_spatial , p = 2 ) < self.threshold_dist
            # mask2 ensures hits are either conect to another hit iff the volume_id of the dst > volume if of src or 
            # layer id of dst > layer id of src in case they have the same volume id 
            mask2 = volume_id.unsqueeze(0) -  volume_id[i:i+batch_size].unsqueeze(1) >= 0  
            mask2 = mask2 | ((volume_id.unsqueeze(0) == volume_id[i:i+batch_size].unsqueeze(1) ) & (layer_id.unsqueeze(0) - layer_id[i:i+batch_size].unsqueeze(1) >= 0 ))
            mask = mask & ( charge.unsqueeze(0) - charge[i:i+batch_size].unsqueeze(1) == 0 )
            # ensure both conditions are satisfied 
            mask = mask & mask2 
            del mask2 
            gc.collect()
            src, dst = torch.where(mask)  # Get valid edges
            del mask 
            gc.collect()
            edges.append(torch.stack([src + i, dst]))  # Offset indices
            del src , dst 
            gc.collect()

        del volume_id , layer_id , hits_spatial 
        gc.collect()
        edge_index = torch.cat(edges, dim=1)
        del edges 
        gc.collect()
        # remove self loops from the edge_index thus generated : 
        edge_index, _ = remove_self_loops(edge_index)
        row , col = edge_index 
        
        # number of edges : 
        num_edges = edge_index.shape[1]
        
        # create edge labels and edge attributes : 
        # Lables : 
            # label == 0 if the two nodes are not part of a traj 
            # label == 1 otherwise 
        edge_labels = ((truth.particle_id.to_numpy()[row] == truth.particle_id.to_numpy()[col]) & ( truth.particle_id.to_numpy()[row] != 0 ))
        edge_labels = Tensor( edge_labels ).float()
        
        # Attributes : 
            # Angle: between the momentum vector of the particle and the displacement vector between the hits. 
            # Distance: euclidean distance between the two hits. 
        pVector = Tensor( truth[['tpx' , 'tpy' ,'tpz']].to_numpy()[row] )
        pVector = pVector/torch.linalg.norm( pVector , ord = 2 , dim = 1 , keepdim= True )
        disp = Tensor( hits[['x','y','z']].to_numpy()[row] -  hits[['x','y','z']].to_numpy()[col] )
        del row , col 
        gc.collect()
        dist = torch.linalg.norm( disp , ord = 2 , dim = 1 , keepdim=True )
        angle = torch.sum( pVector*(disp/dist) , dim = 1 , keepdim=True )
        angle[torch.isnan(angle)] = 0.
        del pVector , disp  
        gc.collect()
        edge_attr = torch.cat([angle , dist] , dim = 1 )
        del angle , dist , hits , truth 
        gc.collect()
        
        # define graph data : 
        graph_data = Data(
            x = node_fets , 
            edge_index=edge_index , 
            edge_attr = edge_attr , 
            label = edge_labels.unsqueeze(1) , 
            num_nodes = nhits , 
            num_edges = num_edges ,
            hit_ids = hit_ids 
        )
        
        return graph_data 
    
    def __len__(self)->int: 
        return self.num_events 
    
    def __getitem__(self,index:int)->Data:
        return self.GraphData(index)

In [4]:
# test event data code : 
dataset = EventData(path='../data/train_100_events/',device=device)
size = len( dataset )
size 

100

In [5]:
# rnum = np.random.choice(np.arange(size))
# random_event = dataset[rnum]
# random_event

In [6]:
# make a function that builds a MultiLayerPercerptron : 

def buildMLP( insize:int, outsize:int, features:list, add_bnorm:bool = False, add_activation=None): 
    layers = [] 
    layers.append(nn.Linear( insize , features[0]))
    layers.append( nn.ReLU() )
    for i in range( 1 , len( features ) ): 
        if add_bnorm: 
            layers.append( nn.BatchNorm1d( features[i-1]) )
        layers.append( nn.Linear( features[i-1] , features[i] ) )
        layers.append( nn.ReLU() )
    layers.append(nn.Linear(features[-1],outsize))
    if add_activation != None: 
        layers.append( add_activation )
    return nn.Sequential(*layers)


In [7]:
# set the dataloder for the loading the graph data in batches of 5 
dataset_loder = DataLoader(dataset=dataset, batch_size=2, shuffle = True)

In [None]:
batch_data = next(iter(dataset))
batch_data

In [None]:
type(batch_data.num_edges)

int

In [None]:
## Create the meta layer class 
class MetaLayer( torch.nn.Module ): 
    def __init__( self,
                 edge_model: Optional[torch.nn.Module] = None ,
                 node_model: Optional[torch.nn.Module] = None ,
                 global_model: Optional[torch.nn.Module] = None ):
        
        super(MetaLayer, self).__init__()
        self.edge_model = edge_model
        self.node_model = node_model
        self.global_model = global_model
    
        # self.reset_parameters()
    
    def reset_parameters(self):
        """Resets all learnable parameters of the module."""
        for item in [self.node_model, self.edge_model, self.global_model]:
            if hasattr(item, 'reset_parameters'):
                item.reset_parameters()
                
    def forward(
        self,
        x:Tensor, edge_index:Tensor, 
        edge_attr: Optional[Tensor] = None, 
        u : Optional[Tensor] = None, 
        batch : Optional[Tensor] = None
    ) -> Tuple[ Tensor , Optional[Tensor] , Optional[Tensor] ] : 
        
        row , col = edge_index[0] , edge_index[1] 
        
        y =  batch if batch is None else batch[row] 
        # print( x.shape )
        # Edge level step 
        if self.edge_model is not None: 
            edge_attr = self.edge_model( x[row] , x[col] , edge_attr, u ,   y  ) 
        # Node level Step 
        if self.node_model is not None: 
            x = self.node_model(x,edge_index,edge_attr,u,batch) 
        # Graph Level Step 
        if self.global_model is not None: 
            u = self.global_model(x,edge_index,edge_attr,u,batch)  
        
        return x , edge_attr , u 
    
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(\n'
                f'  edge_model={self.edge_model},\n'
                f'  node_model={self.node_model},\n'
                f'  global_model={self.global_model}\n'
                f')')

In [None]:
# Build the edge network : 

class EdgeNet(nn.Module): 
    def __init__(
        self, in_edge:int, 
        out_edge:int, node_dim:int,
        features:list  , global_dim:Optional[int] = None 
    ):
        super( EdgeNet , self ).__init__()
        if global_dim is None : 
            global_dim = 0 
        self.edge_mlp =  buildMLP(
            insize= 2*node_dim + global_dim +  in_edge , 
            outsize= out_edge , features= features
        )
    
    def forward(
        self,src:Tensor,dst:Tensor, 
        edge_attr:Tensor, 
        u:Optional[Tensor]=None, 
        edge_batch:Optional[Tensor]=None
    ): 
        if (u is not None) and (edge_batch is None) : 
            raise ValueError('Must Pass edge_batch if global data is present' )
        
        out = torch.cat([src,dst,edge_attr],dim=1)
        
        if u is not None : 
            out = torch.cat([out,u[edge_batch]],dim=1)
            
        return self.edge_mlp(out)

In [None]:
# let us check the model's vaidity : 

# number of node and edge featuers : 
num_node_feat = batch_data.x.shape[1] 
out_node_feat = 2*num_node_feat 
num_edge_feat = batch_data.edge_attr.shape[1]
out_edge_feat = 2*num_edge_feat 

# define the edge_network : 
edge_network =  EdgeNet( in_edge= num_edge_feat ,out_edge=out_edge_feat, node_dim=num_node_feat , features=[3 , 4 , 2 ] )

In [None]:
# get the source and destination indesis : 
srcidx , dstidx = batch_data.edge_index 
print( 'srcidx shape ==' ,  srcidx.shape )

src , dst = batch_data.x[srcidx] , batch_data.x[dstidx]


updated_edge = edge_network( src , dst , batch_data.edge_attr )
print('Shape of Edge Attributes before update' , batch_data.edge_attr.shape )
print('Shape of Edge Attributes after update' , updated_edge.shape )

srcidx shape == torch.Size([1995932])
Shape of Edge Attributes before update torch.Size([1995932, 2])
Shape of Edge Attributes after update torch.Size([1995932, 4])


In [None]:
# Node update block : 

class NodeNet( torch.nn.Module ): 
    def __init__(
        self, innode:int, outnode:int, 
        inedge:Optional[int] = 0 , 
        inglobal:Optional[int] = 0 ,  
        features:Optional[list] = 0 
    ): 
        
        super( NodeNet , self ).__init__()
        self.node_mlp = buildMLP(
            insize=innode+inedge+inglobal, 
            outsize=outnode , 
            features=features
        )
        
    def forward(
        self , x:Tensor , 
        edge_index:Tensor, 
        edge_attr:Optional[Tensor]=None, 
        u:Optional[Tensor]=None, 
        batch:Optional[Tensor]=None 
    ): 
        if u is not None and batch is None : 
            raise ValueError('Must Pass edge_batch if global data is present' )
        
        _ , col = edge_index 
        
        out = x 
        
        if edge_attr is not None : 
            y = scatter(
                edge_attr , col , dim = 0 , 
                dim_size=x.size(0) , reduce='mean'
            )
            out = torch.cat( [ out , y  ] , dim = 1 )
            del y 
        
        if u is not None : 
            out = torch.cat( out , u[batch] , dim = 1 )
             
        return self.node_mlp( out )
    

In [None]:
# here we test our node update part of the model : 

# define the node network : 
node_network = NodeNet(
    inedge = out_edge_feat, 
    innode = num_node_feat , 
    outnode = out_node_feat, inglobal=0, 
    features=[12 , 14 , 17 ] 
)

# get the updated node features : 
updated_node = node_network(
    x = batch_data.x , 
    edge_index = batch_data.edge_index , 
    edge_attr = updated_edge 
)

print('Node featuers shape before update :' , batch_data.x.shape ) 
print('Node featuers after update:' , updated_node.shape )

Node featuers shape before update : torch.Size([115584, 10])
Node featuers after update: torch.Size([115584, 20])


In [None]:
# Global Update Block : 

class GlobalNet( torch.nn.Module ): 
    def __init__(
        self, 
        inglobal:int, outglobal:int, 
        features:list, innode:int , 
        inedge:Optional[int]=0 
    ): 
        
        super( GlobalNet , self ).__init__() 
        self.global_mlp = buildMLP(
            insize=inedge+innode+inglobal, 
            outsize=outglobal, 
            features=features
        )
        
    def forward(
        self, x:Tensor, 
        edge_index:Tensor,
        u:Tensor, batch:Tensor , 
        edge_attr:Optional[Tensor]=None
    ): 
        
        src_idx , _ = edge_index 
        
        out = torch.cat([u, scatter(x,batch,dim=0,reduce='mean')] , dim = 1 )
        if edge_attr is not None : 
            out = torch.cat([u,scatter(edge_attr,batch[src_idx],dim=0,reduce='mean')],dim=1)
        
        return self.global_mlp( out )

In [None]:
# Now we will see a full meta layer in action : 

# note for our purpose we only need node and edge networks 
# and do not need a global update block :
example_meta_layer = MetaLayer(
    edge_model=edge_network , 
    node_model=node_network
)

post_meta_layer_nodes , post_meta_layer_edges , _ = example_meta_layer(
    x = batch_data.x , 
    edge_index = batch_data.edge_index , 
    edge_attr = batch_data.edge_attr 
)

print( 'Post meta layer node featuers shape' , post_meta_layer_nodes.shape )
print( 'Post meta layer edge featuers shape' , post_meta_layer_edges.shape )

Post meta layer node featuers shape torch.Size([115584, 20])
Post meta layer edge featuers shape torch.Size([1995932, 4])


In [None]:
# now we define the full GNN Model using metalayers: 

class GNN_MetaLayer_Model(torch.nn.Module): 
    
    def __init__(
        self , nmeta_layers:int, 
        node_feats:list, 
        inter_node_feats:list, 
        edge_feats:Optional[list]=None, 
        inter_edge_feats:Optional[list]=None, 
        global_feats:Optional[list]=None, 
        inter_global_feats:Optional[list]=None 
    ):
        super(GNN_MetaLayer_Model , self ).__init__()
        
        self.meta_layers = nn.ModuleList([])
        
        if inter_node_feats == None : 
            raise ValueError('Inter Node feats must also be supplied along with node_feats')
        if len( node_feats ) != nmeta_layers + 1 : 
            raise ValueError('The length of \'node_feats\' must be equal to nmeta_layers + 1 ' )
        if len( inter_node_feats ) != nmeta_layers : 
            raise ValueError('The length of \'inter_node_feats\' must be equal to nmeta_layers ' )
        
        if edge_feats != None : 
            if inter_edge_feats == None : 
                raise ValueError('Inter edge feats must also be supplied along with edge_feats')
            if len( edge_feats ) != nmeta_layers + 1 : 
                raise ValueError('The length of \'edge_feats\' must be equal to nmeta_layers + 1 ' )
            if len( inter_edge_feats ) != nmeta_layers  : 
                raise ValueError('The length of \'inter_edge_feats\' must be equal to nmeta_layers ' )
            
        if global_feats != None : 
            if inter_global_feats == None : 
                raise ValueError('Inter global feats must also be supplied along with global_feats')
            if len( global_feats ) != nmeta_layers + 1 : 
                raise ValueError('The length of \'global_feats\' must be equal to nmeta_layers + 1 ' )
            if len( inter_global_feats ) != nmeta_layers : 
                raise ValueError('The length of \'inter_global_feats\' must be equal to nmeta_layers ' )
        
        if (node_feats is None ) and (edge_feats is None ) and (global_feats is None ) : 
            raise ValueError('All Type of Netorks are Null')
        
        self.nmeta = nmeta_layers
        
        edge_part , node_part , global_part = None , None , None 
        for i in range( nmeta_layers ): 
            
            if global_feats is not None : 
                current_global_feat = global_feats[i]
            else : 
                current_global_feat = 0 
    
            if edge_feats is not None : 
                edge_part = EdgeNet(
                    in_edge=edge_feats[i],
                    out_edge=edge_feats[i+1] , 
                    node_dim=node_feats[i] , 
                    global_dim=current_global_feat, 
                    features=inter_node_feats[i]
                )
                current_edge_feat = edge_feats[i+1]
            else : 
                current_edge_feat = 0
                
            if node_feats is not None : 
                node_part = NodeNet(
                    innode=node_feats[i] , 
                    outnode=node_feats[i+1],
                    features=inter_node_feats[i],
                    inedge=current_edge_feat,
                    inglobal=current_global_feat 
                )
            if global_feats is not None : 
                global_part = GlobalNet(
                    inedge=current_edge_feat,
                    innode=node_feats[i+1],
                    inglobal=global_feats[i] , 
                    outglobal=global_feats[i+1] , 
                    features=inter_global_feats[i]
                )
            self.meta_layers.append(
                MetaLayer(
                    edge_model=edge_part,
                    node_model=node_part,
                    global_model=global_part
                )
            )
            
    def forward( self , data ):
        x , edge_attr , global_data = data.x , None , None 
        
        if 'edge_attr' in data : 
            edge_attr = data.edge_attr  
            
        if 'global_data' in data : 
            global_data = data.global_data 
        
        for i in range(self.nmeta) : 
            x , edge_attr , global_data = self.meta_layers[i](
                x = x , edge_index = data.edge_index , 
                edge_attr = edge_attr , u = global_data , batch = data.batch     
            )
            
        # return final updated  featuers ! 
        return x , edge_attr , global_data 

In [None]:
# now we test the model code : 

model = GNN_MetaLayer_Model(
    nmeta_layers=2,
    node_feats=[num_node_feat,4,1] , 
    inter_node_feats=[[8,6],[3,2]] , 
    edge_feats=[num_edge_feat,7,1] , 
    inter_edge_feats=[[3,5],[5,3]] , 
)


updated_node , updated_edge , _ = model(batch_data)
print( batch_data.x.shape , updated_node.shape )
print( batch_data.edge_attr.shape , updated_edge.shape )
print( batch_data.label )

torch.Size([115584, 10]) torch.Size([115584, 1])
torch.Size([1995932, 2]) torch.Size([1995932, 1])
tensor([[0.],
        [0.],
        [0.],
        ...,
        [1.],
        [1.],
        [1.]])


In [None]:
batch_data

Data(x=[115584, 10], edge_index=[2, 1995932], edge_attr=[1995932, 2], label=[1995932, 1], num_nodes=115584, num_edges=1995932, hit_ids=[115584])

In [None]:
# we write train and test functions : 

def train(model, train_loader, optimizer)->float:
    
    train_loss_ep , data_pts = 0. , 0 
    
    model.train()
    for _ , data in enumerate(train_loader):
        # data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        _ , edge_attr , _  = model(data)
        loss = F.binary_cross_entropy_with_logits(edge_attr, data.label , reduction='sum' )
        loss.backward()
        optimizer.step()
        print( loss )
        train_loss_ep += loss.item() 
        print( train_loss_ep )
        data_pts += data.num_edges 
        
    return train_loss_ep/data_pts


def test(model, test_loader)->float:
    
    test_loss_ep  , data_pts = 0. , 0 
    
    model.eval()
    for _ , data in enumerate(test_loader):
        # data, target = data.to(device), target.to(device)
         
        _ , edge_attr , _  = model(data)
        loss = F.binary_cross_entropy_with_logits(edge_attr, data.label , reduction='sum' )
        
        test_loss_ep += loss.item() 
        data_pts += data.num_edges 
        
    return test_loss_ep/data_pts 

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler

# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 2
# percentage of training set to use as validation
train_size, valid_size = 0.1 , 0.1 
# # convert data to torch.FloatTensor
# transform = transforms.ToTensor()
# choose the training and testing datasets
# obtain training indices that will be used for validation
dataset = EventData(path='../data/train_100_events/',device=device)
num_train = len(dataset)
indices = list(range(num_train))
np.random.shuffle(indices)
train_split = int(np.floor(train_size * num_train))
valid_split = int(np.floor(valid_size * num_train))

print(train_split, valid_split)
train_index, valid_index, test_index = indices[0:train_split], indices[train_split:train_split + valid_split], indices[train_split + valid_split:]
# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_index)
valid_sampler = SubsetRandomSampler(valid_index)
test_sampler = SubsetRandomSampler(test_index)
# prepare data loaders
train_loader = DataLoader(dataset=dataset, batch_size = batch_size, 
                                           num_workers = num_workers, sampler = train_sampler  )
valid_loader = DataLoader(dataset=dataset, batch_size = batch_size,
                                          num_workers = num_workers,  sampler = valid_sampler  )
test_loader = DataLoader(dataset=dataset, batch_size = batch_size,
                                          num_workers = num_workers,  sampler = test_sampler )

10 10


In [None]:
# - Define the optimizer --- #
optimizer = torch.optim.SGD(model.parameters(),lr = 0.01)

In [None]:
# number of epochs to train the model
n_epochs = 5
# initialize tracker for minimum validation loss
valid_loss_min = np.inf  # set initial "min" to infinity
for epoch in range(n_epochs):
    # monitor losses
    
    train_loss = train(model, train_loader, optimizer)
    
    valid_loss = test(model, valid_loader)
    
    # print training/validation statistics 
    
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch+1, 
        train_loss,
        valid_loss
        ))
    
    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        valid_loss_min,
        valid_loss))
        torch.save(model.state_dict(), '../data/models_/event_trained_model.pt')
        valid_loss_min = valid_loss

tensor(1720788.7500, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
1720788.75
tensor(nan, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
nan
tensor(nan, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
nan
tensor(nan, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
nan


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x1075077f0>>
Traceback (most recent call last):
  File "/Users/ashmitbathla/Documents/UGP/TrackML/VirtualEnv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


In [None]:
print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch+1, 
        train_loss,
        valid_loss
        ))


TypeError: unsupported format string passed to Tensor.__format__

In [None]:
train_loss

tensor([47.9358, 48.7042])