In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv
import numpy as np
import networkx as nx
from collections import OrderedDict

# Run on the first CUDA GPU when one is present; otherwise stay on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if _use_cuda else 'cpu')

# Legacy tensor-type aliases matching `device` (usable with `tensor.type(...)`).
if not _use_cuda:
    dtype, dtype_long = torch.FloatTensor, torch.LongTensor
else:
    dtype, dtype_long = torch.cuda.FloatTensor, torch.cuda.LongTensor

In [2]:
class Quadratic_block(nn.Module):
    """Stack of ``depth`` quadratic layers.

    Each layer multiplies two bias-free linear projections of its input
    element-wise. Every layer except the last maps ``input_dim -> input_dim``
    and passes both projections through ``tanh`` before the product. The last
    layer maps ``input_dim -> out_dim`` and is left linear before the product.
    With ``depth == 0`` the block returns its input unchanged.
    """

    def __init__(self, input_dim, out_dim, depth):
        super().__init__()
        self.modlist1 = nn.ModuleList()
        self.modlist2 = nn.ModuleList()
        self.depth = depth
        for layer_idx in range(depth):
            width_out = out_dim if layer_idx == depth - 1 else input_dim
            # Create the modlist1 layer before its modlist2 partner so weight
            # initialisation draws from the RNG in a fixed order.
            self.modlist1.append(torch.nn.Linear(input_dim, width_out, bias=False))
            self.modlist2.append(torch.nn.Linear(input_dim, width_out, bias=False))

    def forward(self, x):
        layer_pairs = zip(self.modlist1, self.modlist2)
        for layer_no, (proj_a, proj_b) in enumerate(layer_pairs, start=1):
            left, right = proj_a(x), proj_b(x)
            if layer_no < self.depth:
                # Hidden layers are squashed; the final layer stays linear.
                left, right = torch.tanh(left), torch.tanh(right)
            x = left * right
        return x


class GCN_block(nn.Module):
    """``layer`` stacked ``GCNConv`` layers with ReLU between consecutive layers.

    The first layer maps ``in_channels -> out_channels`` and every later layer
    maps ``out_channels -> out_channels``. No activation follows the final layer.
    ``improved``, ``cached`` and ``add_self_loops`` are forwarded to every
    ``GCNConv``.
    """

    def __init__(self, in_channels, out_channels, layer,
                 improved=False, cached=False, add_self_loops=True):
        super().__init__()
        self.modlist = nn.ModuleList()
        self.layer = layer
        for idx in range(layer):
            conv_in = in_channels if idx == 0 else out_channels
            self.modlist.append(GCNConv(conv_in, out_channels,
                                        improved=improved,
                                        cached=cached,
                                        add_self_loops=add_self_loops))

    def forward(self, x, edge_index):
        last_idx = self.layer - 1
        for idx, conv in enumerate(self.modlist):
            x = conv(x, edge_index)
            if idx < last_idx:
                x = torch.relu(x)
        return x

In [3]:
class AdaptedGNN_LMSC_cell(nn.Module):
    """One recurrent update step of a graph cell.

    For every node, the input features ``X`` are concatenated with the hidden
    state ``H``. The result is optionally embedded by two independent one-layer
    GCN blocks, passed through a shared ``Quadratic_block``, and mapped to a
    rate ``alpha`` (positive, via ``exp``) and a target ``beta`` (in (-1, 1),
    via ``tanh``). The new hidden state relaxes from ``H`` towards ``beta``::

        h = exp(-alpha) * (H - beta) + beta

    Args:
        in_channels: number of input features per node.
        out_channels: unused; kept for interface compatibility.
        units: size of the hidden state.
        gcn_type: ``'GCNConv'`` enables the GCN embedding. Any other value
            feeds the concatenated input straight to the quadratic block.
        batch_size: unused; kept for interface compatibility.
        width: output width of the quadratic block.
        depth: number of quadratic layers (0 makes that block an identity).
    """

    def __init__(self, in_channels: int, out_channels: int, units: int, gcn_type: str,
                 batch_size: int, width=125, depth=4):
        super(AdaptedGNN_LMSC_cell, self).__init__()

        self.in_channels = in_channels
        self.units = units
        self.gcn_type = gcn_type
        self.depth = depth
        self.width = width

        # Width of [X, H] after concatenation.
        start_dim = units + in_channels
        self.qb = Quadratic_block(start_dim, width, depth)

        if gcn_type == 'GCNConv':
            # The GCN embeddings keep the feature width unchanged.
            self.gconv1 = GCN_block(start_dim, start_dim, layer=1)
            self.gconv2 = GCN_block(start_dim, start_dim, layer=1)

        # The quadratic block emits `width` features. With depth == 0 it passes
        # its `start_dim`-wide input through unchanged, so the heads must match.
        head_dim = width if depth > 0 else start_dim
        self.fc_alpha = nn.Linear(head_dim, units)
        self.fc_beta = nn.Linear(head_dim, units)

    def forward(self, X, edge_index, edge_weight=None, H=None):
        """Advance the hidden state by one step.

        Args:
            X: node features, last dimension ``in_channels``.
            edge_index: graph connectivity passed to the GCN blocks.
            edge_weight: currently unused.
            H: previous hidden state with the same leading dimensions as ``X``
                and last dimension ``units``. Zeros are used when ``None``.

        Returns:
            The new hidden state, same shape as ``H``.
        """
        if H is None:
            H = X.new_zeros(*X.shape[:-1], self.units)

        # Node input followed by hidden state: last dim = in_channels + units.
        cat_input = torch.cat([X, H], dim=-1)

        # Graph embedding (two independent branches for alpha and beta).
        if self.gcn_type == 'GCNConv':
            G_input1 = F.relu(self.gconv1(cat_input, edge_index))
            G_input2 = F.relu(self.gconv2(cat_input, edge_index))
        else:
            G_input1 = cat_input
            G_input2 = cat_input

        G_input1 = torch.tanh(self.qb(G_input1))
        G_input2 = torch.tanh(self.qb(G_input2))

        alpha = torch.exp(self.fc_alpha(G_input1))
        beta = torch.tanh(self.fc_beta(G_input2))

        # Single-step prediction: the relaxation increment is fixed at 1.
        pseudo_strain_norm = 1.0
        exp_f = torch.exp(-alpha * pseudo_strain_norm)

        return exp_f * (H - beta) + beta

In [4]:
class SingleStepGrainPredictor(nn.Module):
    """Predicts the next-step state of a target grain from its subgraph.

    The target grain is the first node (local index 0) of each subgraph. Node
    features are advanced by one ``AdaptedGNN_LMSC_cell`` step, and the target
    node's hidden state is decoded to ``output_dim`` values.

    ``args`` must provide ``hidden_dim``, ``output_dim``, ``batch_size``,
    ``input_dim`` and ``GCN_type``.
    """

    def __init__(self, args):
        super(SingleStepGrainPredictor, self).__init__()
        self.hidden_dim = args.hidden_dim
        self.output_dim = args.output_dim
        self.batch_size = args.batch_size
        self.input_dim = args.input_dim

        # One recurrent cell applied once (no sequence unrolling).
        self.lmsc = AdaptedGNN_LMSC_cell(
            self.input_dim,
            self.hidden_dim,
            units=self.hidden_dim,
            gcn_type=args.GCN_type,
            batch_size=args.batch_size
        )

        # Maps the target node's hidden state to the predicted next state.
        self.decoder = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, data):
        """Predict the next state of the target grain of every graph in ``data``.

        ``data`` is either a single graph or a PyG mini-batch. The graphs of a
        mini-batch are processed together as one disconnected graph (PyG has
        already offset ``edge_index``), so any batch size works.

        Reads ``x`` (one row per node), ``edge_index`` and ``init_ori`` (one row
        per graph, written into the leading hidden units). For mini-batches it
        also reads ``batch``, ``ptr`` and ``num_graphs`` when present.

        Returns:
            Tensor of shape [num_graphs, output_dim].
        """
        # Follow the model's own parameters, so a model loaded onto one device
        # still works when the notebook's default device differs.
        param = next(self.parameters())
        x = data.x.to(device=param.device, dtype=param.dtype)
        edge_index = data.edge_index.to(x.device)

        num_graphs = data.num_graphs if hasattr(data, 'num_graphs') else 1
        num_nodes = x.size(0)

        # Graph id of every node; a lone graph has no `batch` vector.
        node_graph = getattr(data, 'batch', None)
        if node_graph is None:
            node_graph = torch.zeros(num_nodes, dtype=torch.long, device=x.device)
        else:
            node_graph = node_graph.to(x.device)

        # Global index of each graph's local node 0 (the target grain).
        ptr = getattr(data, 'ptr', None)
        if ptr is None:
            counts = torch.bincount(node_graph, minlength=num_graphs)
            target_idx = torch.cumsum(counts, dim=0) - counts
        else:
            target_idx = ptr[:-1].to(x.device)

        # All nodes form one disconnected graph: [1, num_nodes, input_dim].
        x = x.view(1, num_nodes, self.input_dim)

        # Hidden state starts at zero; the leading units hold each graph's init_ori.
        init_ori = data.init_ori.to(x.device).reshape(num_graphs, -1)
        h0 = x.new_zeros(1, num_nodes, self.hidden_dim)
        h0[0, :, :init_ori.size(-1)] = init_ori[node_graph]

        h_next = self.lmsc(x, edge_index, H=h0)  # [1, num_nodes, hidden_dim]

        target_hidden = h_next[0, target_idx, :]  # [num_graphs, hidden_dim]
        return self.decoder(target_hidden)  # [num_graphs, output_dim]

In [6]:
from sklearn.model_selection import train_test_split

# all_data_points = torch.load('grain_data.pt', weights_only=False)
# The file presumably holds pickled PyG `Data` objects rather than plain tensors,
# so full unpickling is required (`weights_only` defaults to True from
# PyTorch 2.6). Only load files you trust: unpickling can execute arbitrary code.
all_data_points = torch.load('grain_data_normalised.pt', weights_only=False)

# 70 / 15 / 15 train / validation / test split, reproducible via random_state.
train_data, temp_data = train_test_split(all_data_points, test_size=0.3, random_state=42)
val_data, test_data   = train_test_split(temp_data, test_size=0.5, random_state=42)

train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
val_loader   = DataLoader(val_data, batch_size=1, shuffle=False)
test_loader  = DataLoader(test_data, batch_size=1, shuffle=False)

Training: model without the strain-step input feature

In [None]:
import torch
from torch_geometric.loader import DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

# 1. Define your model and configuration
class Args:
    """Model and data-loader configuration consumed by SingleStepGrainPredictor."""
    hidden_dim = 250
    # 7 per-grain features: [size, avg_shear, phase_id, orient_0, orient_1,
    # orient_2, orient_3]; orientation occupies columns 3:7.
    input_dim = 7
    output_dim = 7   # predicted next-step features, same column layout
    batch_size = 1
    GCN_type = 'GCNConv'

args = Args()

# Seed PyTorch so weight initialisation and loader shuffling are repeatable.
SEED = 42
torch.manual_seed(SEED)
model = SingleStepGrainPredictor(args).to(device)

# NOTE(review): presumably log-variance task weights for an uncertainty-weighted
# multi-task loss. compute_losses does not use them and they are not passed to
# the optimizer, so they stay at 0.
s_size = torch.nn.Parameter(torch.tensor(0.0, device=device))
s_shear = torch.nn.Parameter(torch.tensor(0.0, device=device))
s_orient = torch.nn.Parameter(torch.tensor(0.0, device=device))
s_phase = torch.nn.Parameter(torch.tensor(0.0, device=device))

loss_fn_size = nn.MSELoss()
loss_fn_shear = nn.MSELoss()
# Phase is a binary target (ids 1/2 shifted to 0/1), predicted as a logit.
loss_fn_phase = nn.BCEWithLogitsLoss()
loss_fn_orient = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

# Rebuild the loaders so they honour args.batch_size.
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
val_loader   = DataLoader(val_data, batch_size=args.batch_size, shuffle=False)
test_loader  = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)

# --- 2. Define helper: loss computation ---
def compute_losses(predictions, batch):
    """Multi-task loss for one batch of next-state predictions.

    Column layout of both ``predictions`` and ``batch.y``: 0 = size,
    1 = shear, 2 = phase, 3:7 = orientation. The target phase column holds
    phase ids 1/2, which are shifted to 0/1. The predicted phase column is
    treated as a logit.

    Args:
        predictions: model output with (at least) 7 columns, one row per graph.
        batch: PyG batch whose ``y`` uses the same column layout.

    Returns:
        ``(total, (loss_size, loss_shear, loss_phase, loss_orient))``, where
        ``total`` is the unweighted sum of the four losses.
    """
    target = batch.y
    target_phase_binary = target[:, 2] - 1.0  # phase ids 1/2 -> 0/1

    # Do not squeeze the prediction: with one sample per batch that would turn
    # shape [1] into a 0-d tensor. MSELoss then warns about mismatched
    # input/target sizes and silently broadcasts.
    loss_size   = loss_fn_size(predictions[:, 0], target[:, 0])
    loss_shear  = loss_fn_shear(predictions[:, 1], target[:, 1])
    loss_orient = loss_fn_orient(predictions[:, 3:7], target[:, 3:7])
    loss_phase  = loss_fn_phase(predictions[:, 2], target_phase_binary)

    total = loss_size + loss_shear + loss_orient + loss_phase

    return total, (loss_size, loss_shear, loss_phase, loss_orient)


# --- 3. Training + validation loop with early stopping ---
def train_val_model(model, train_loader, val_loader, optimizer, num_epochs=100, patience=10):
    """Train with per-epoch validation and early stopping.

    Each epoch runs one optimisation pass over ``train_loader``, followed by a
    no-grad pass over ``val_loader``, both scored with ``compute_losses``.

    The weights with the lowest validation loss are snapshotted and restored
    into ``model`` before returning. This happens both on early stopping
    (``patience`` epochs without an improvement of more than 1e-6) and when
    ``num_epochs`` runs out.

    Returns:
        ``(model, avg_losses, size_losses, shear_losses, phase_losses,
        orient_losses, val_losses)``: per-epoch mean training losses (total and
        per task) and mean validation losses.

    Raises:
        ValueError: if either loader is empty.
    """
    import copy

    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must both be non-empty")

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    avg_losses, size_losses, shear_losses, phase_losses, orient_losses, val_losses = [], [], [], [], [], []

    for epoch in range(num_epochs):
        # --- Training ---
        model.train()
        total_loss = 0.0
        total_loss_size = 0.0
        total_loss_shear = 0.0
        total_loss_phase = 0.0
        total_loss_orient = 0.0

        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            predictions = model(batch)
            loss, (l_size, l_shear, l_phase, l_orient) = compute_losses(predictions, batch)
            loss.backward()
            optimizer.step()

            # Per-task losses are logged unweighted, as they enter `loss`.
            total_loss += loss.item()
            total_loss_size += l_size.item()
            total_loss_shear += l_shear.item()
            total_loss_phase += l_phase.item()
            total_loss_orient += l_orient.item()

        n_train = len(train_loader)
        avg_loss = total_loss / n_train
        avg_loss_size = total_loss_size / n_train
        avg_loss_shear = total_loss_shear / n_train
        avg_loss_phase = total_loss_phase / n_train
        avg_loss_orient = total_loss_orient / n_train

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                predictions = model(batch)
                loss, _ = compute_losses(predictions, batch)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        # --- Store metrics ---
        avg_losses.append(avg_loss)
        size_losses.append(avg_loss_size)
        shear_losses.append(avg_loss_shear)
        phase_losses.append(avg_loss_phase)
        orient_losses.append(avg_loss_orient)
        val_losses.append(val_loss)

        # --- Print progress ---
        print(f'Epoch {epoch:3d}, Train Loss: {avg_loss:.6f}, Val Loss: {val_loss:.6f}')
        print(f'  Size Loss: {avg_loss_size:.6f}')
        print(f'  Shear Loss: {avg_loss_shear:.6f}')
        print(f'  Phase Loss: {avg_loss_phase:.6f}')
        print(f'  Orient Loss: {avg_loss_orient:.6f}')
        print('-' * 60)

        # --- Early stopping check ---
        if val_loss < best_val_loss - 1e-6:  # small tolerance
            best_val_loss = val_loss
            patience_counter = 0
            # state_dict() returns references to the live tensors, which later
            # optimizer steps update in place, so keep an independent copy.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}! Best Val Loss: {best_val_loss:.6f}")
                break

    # Restore the best validation weights, also when num_epochs ran out first.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, avg_losses, size_losses, shear_losses, phase_losses, orient_losses, val_losses

# --- 5. Run training ---
print("Starting training with validation and early stopping...")
(trained_model, avg_losses, size_losses, shear_losses,
 phase_losses, orient_losses, val_losses) = train_val_model(
    model, train_loader, val_loader, optimizer,
    num_epochs=20, patience=8,
)
print("Training completed!")


Starting training with validation and early stopping...


  return F.mse_loss(input, target, reduction=self.reduction)


In [None]:
# Save model: this pickles the entire nn.Module object (not only its
# state_dict), so the class definitions must be importable when it is loaded.
torch.save(obj=trained_model, f="gnn_model.pth")

In [6]:
# Load model
# model = torch.load("gnn_model.pth", map_location=device, weights_only=False)
# The checkpoint is a whole pickled nn.Module, so full unpickling is required
# (`weights_only` defaults to True from PyTorch 2.6 and would reject it).
# Only load checkpoints you trust: unpickling can execute arbitrary code.
# map_location=device puts the model where the evaluation cells send each batch.
model = torch.load("gnn_model_normalised.pth", map_location=device, weights_only=False)
model.eval()

SingleStepGrainPredictor(
  (lmsc): AdaptedGNN_LMSC_cell(
    (qb): Quadratic_block(
      (modlist1): ModuleList(
        (0): Linear(in_features=257, out_features=257, bias=False)
        (1): Linear(in_features=257, out_features=257, bias=False)
        (2): Linear(in_features=257, out_features=257, bias=False)
        (3): Linear(in_features=257, out_features=125, bias=False)
      )
      (modlist2): ModuleList(
        (0): Linear(in_features=257, out_features=257, bias=False)
        (1): Linear(in_features=257, out_features=257, bias=False)
        (2): Linear(in_features=257, out_features=257, bias=False)
        (3): Linear(in_features=257, out_features=125, bias=False)
      )
    )
    (gconv1): GCN_block(
      (modlist): ModuleList(
        (0): GCNConv(257, 257)
      )
    )
    (gconv2): GCN_block(
      (modlist): ModuleList(
        (0): GCNConv(257, 257)
      )
    )
    (fc_alpha): Linear(in_features=125, out_features=250, bias=True)
    (fc_beta): Linear(in_featu

In [16]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, r2_score
import torch
import numpy as np

def evaluate_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print a metrics report.

    Regression metrics (MSE, MAE, R²) are computed over all test samples for
    size, shear and orientation. The orientation components are pooled into
    one flat vector.

    Phase is decoded from its logit with a 0.5 probability threshold into
    labels 1/2. It is scored with accuracy / precision / recall / F1, treating
    label 2 as the positive class.

    Returns:
        dict with the test loss and the per-target metrics (also printed).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    test_loss = 0.0
    num_batches = 0

    # Per-batch numpy chunks, concatenated once after the loop.
    true_size, pred_size = [], []
    true_shear, pred_shear = [], []
    true_orient, pred_orient = [], []
    true_phase, pred_phase = [], []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            predictions = model(batch)
            loss, _ = compute_losses(predictions, batch)
            test_loss += loss.item()
            num_batches += 1

            y = batch.y
            true_size.append(y[:, 0].cpu().numpy())
            true_shear.append(y[:, 1].cpu().numpy())
            true_phase.append(y[:, 2].cpu().numpy())  # labels 1 or 2
            true_orient.append(y[:, 3:7].cpu().numpy().ravel())

            pred_size.append(predictions[:, 0].cpu().numpy())
            pred_shear.append(predictions[:, 1].cpu().numpy())
            pred_orient.append(predictions[:, 3:7].cpu().numpy().ravel())
            # Phase logit -> probability -> label 1 or 2.
            phase_prob = torch.sigmoid(predictions[:, 2])
            pred_phase.append(((phase_prob > 0.5).float() + 1).cpu().numpy())

    if num_batches == 0:
        raise ValueError("test_loader yielded no batches")

    true_size, pred_size = np.concatenate(true_size), np.concatenate(pred_size)
    true_shear, pred_shear = np.concatenate(true_shear), np.concatenate(pred_shear)
    true_orient, pred_orient = np.concatenate(true_orient), np.concatenate(pred_orient)

    test_loss /= num_batches

    # Errors are averaged over all samples, so a short final batch is not
    # over-weighted. This equals per-batch averaging when batch_size == 1.
    mse_size = np.mean((pred_size - true_size) ** 2)
    mae_size = np.mean(np.abs(pred_size - true_size))
    mse_shear = np.mean((pred_shear - true_shear) ** 2)
    mae_shear = np.mean(np.abs(pred_shear - true_shear))
    mse_orient = np.mean((pred_orient - true_orient) ** 2)
    mae_orient = np.mean(np.abs(pred_orient - true_orient))

    # --- Compute R² scores (orientation components pooled) ---
    r2_size = r2_score(true_size, pred_size)
    r2_shear = r2_score(true_shear, pred_shear)
    r2_orient = r2_score(true_orient, pred_orient)

    # --- Classification scores (still using 1/2 labels) ---
    all_true_phase = np.concatenate(true_phase).astype(int)
    all_pred_phase = np.concatenate(pred_phase).astype(int)

    # Sanity prints
    print("Phase label counts (true):", dict(zip(*np.unique(all_true_phase, return_counts=True))))
    print("Phase label counts (pred):", dict(zip(*np.unique(all_pred_phase, return_counts=True))))

    phase_acc = accuracy_score(all_true_phase, all_pred_phase)
    phase_prec = precision_score(all_true_phase, all_pred_phase, pos_label=2, zero_division=0)
    phase_rec = recall_score(all_true_phase, all_pred_phase, pos_label=2, zero_division=0)
    phase_f1 = f1_score(all_true_phase, all_pred_phase, pos_label=2, zero_division=0)

    # --- Print summary ---
    print("\n==== Test Evaluation Results ====")
    print(f"Total Weighted Test Loss: {test_loss:.6f}")
    print(f"Size  → MSE: {mse_size:.6f}, MAE: {mae_size:.6f}, R²: {r2_size:.4f}")
    print(f"Shear → MSE: {mse_shear:.6f}, MAE: {mae_shear:.6f}, R²: {r2_shear:.4f}")
    print(f"Orient→ MSE: {mse_orient:.6f}, MAE: {mae_orient:.6f}, R²: {r2_orient:.4f}")
    print("\nPhase Classification Metrics (labels 1/2):")
    print(f"  Accuracy:  {phase_acc:.4f}")
    print(f"  Precision: {phase_prec:.4f}")
    print(f"  Recall:    {phase_rec:.4f}")
    print(f"  F1 Score:  {phase_f1:.4f}")
    print('=' * 60)

    return {
        "test_loss": test_loss,
        "size": {"mse": mse_size, "mae": mae_size, "r2": r2_size},
        "shear": {"mse": mse_shear, "mae": mae_shear, "r2": r2_shear},
        "orient": {"mse": mse_orient, "mae": mae_orient, "r2": r2_orient},
        "phase": {"accuracy": phase_acc, "precision": phase_prec,
                  "recall": phase_rec, "f1": phase_f1},
    }

# Pass `trained_model` instead to score the network trained in this session.
test_metrics = evaluate_model(model, test_loader)

  return F.mse_loss(input, target, reduction=self.reduction)


Phase label counts (true): {1: 2998, 2: 710}
Phase label counts (pred): {1: 2824, 2: 884}

==== Test Evaluation Results ====
Total Weighted Test Loss: 0.199760
Size  → MSE: 0.008899, MAE: 0.044323, R²: 0.0211
Shear → MSE: 0.011764, MAE: 0.084537, R²: 0.0220
Orient→ MSE: 0.056961, MAE: 0.116118, R²: 0.7396

Phase Classification Metrics (labels 1/2):
  Accuracy:  0.9531
  Precision: 0.8032
  Recall:    1.0000
  F1 Score:  0.8908


In [17]:
import matplotlib.pyplot as plt

# --- Plot per-epoch training losses (total and per task) and validation loss ---
# Needs the loss histories returned by the training cell in this kernel session.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(avg_losses, label='Total Loss', linewidth=2)
ax.plot(size_losses, label='Size Loss', linestyle='--')
ax.plot(shear_losses, label='Shear Loss', linestyle='--')
ax.plot(phase_losses, label='Phase Loss', linestyle='--')
ax.plot(orient_losses, label='Orientation Loss', linestyle='--')
# Validation loss from the same run, for spotting over-fitting.
ax.plot(val_losses, label='Validation Loss', linewidth=2, linestyle=':')

ax.set_title("Training and Validation Losses Across Epochs", fontsize=14, fontweight='bold')
ax.set_xlabel("Epoch", fontsize=12, fontweight='bold')
ax.set_ylabel("Loss", fontsize=12, fontweight='bold')

ax.legend(fontsize=10, frameon=True)
ax.grid(True, linestyle='--', alpha=0.6)
fig.tight_layout()
plt.show()


NameError: name 'avg_losses' is not defined

<Figure size 1000x600 with 0 Axes>

In [7]:
import pickle
import networkx as nx
import numpy as np

from pathlib import Path

# Paths are built with pathlib so they work on any OS and need no backslash
# escaping inside string literals.
DATA_ROOT = Path("..")
GRAIN_DATA_DIR = DATA_ROOT / "All Grain Data"


def _load_pickle(path):
    """Unpickle ``path``. Only use this on trusted files: unpickling can run arbitrary code."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


all_grains = _load_pickle(GRAIN_DATA_DIR / "all_grains.pkl")
grain_property = _load_pickle(GRAIN_DATA_DIR / "grain_property.pkl")
linked_grain_graph = _load_pickle(DATA_ROOT / "grain_tracking_graph_7steps.pkl")
strain_grains = _load_pickle(DATA_ROOT / "strain_grains.pkl")

Compare the input grain, the model's prediction and the true next-step grain

In [None]:
def get_data(step_and_id: tuple):
    """Look up one grain's stored positions and property record.

    Args:
        step_and_id: ``(step, grain_id)`` pair.

    Returns:
        ``(positions, property)`` taken from ``all_grains[step][grain_id]`` and
        ``grain_property[step][grain_id]``.
    """
    step, grain_id = step_and_id
    return all_grains[step][grain_id], grain_property[step][grain_id]


def compare_predictions(index, model, size_range=(10, 5016)):
    """Print a sample's input grain, its true next-step grain and the model prediction.

    Ground-truth values come from the raw grain tables (``get_data`` and
    ``strain_grains``). The prediction comes from ``model(all_data_points[index])``.

    Args:
        index: position of the sample in ``all_data_points``.
        model: trained predictor, called on the sample without gradients.
        size_range: ``(min_size, max_size)`` used to undo the min-max
            normalisation of the predicted size. The defaults match the min/max
            grain sizes found in the dataset. ``None`` prints the normalised value.
    """
    sample = all_data_points[index].to(device)

    input_step = sample.current_step.item()
    input_gid = sample.current_gid.item()
    next_step = sample.next_step.item()
    next_gid = sample.next_gid.item()

    # --- Ground truth from the raw grain tables ---
    input_pos, input_prop = get_data((input_step, input_gid))
    input_shear = strain_grains[input_step]['strains'][input_gid]
    print(f'Input Sample: {input_step, input_gid}')
    print(f'Input Shear: {input_shear}')
    print(f'Input Size: {len(input_pos)}')
    print(f"Input Phase: {input_prop['phase_id']}")
    print(f"Input Orientation: {input_prop['average_orientation']}")
    print('\n')

    true_pos, true_prop = get_data((next_step, next_gid))
    true_shear = strain_grains[next_step]['strains'][next_gid]
    print(f'Next step: {next_step, next_gid}')
    print(f'True Shear: {true_shear}')
    print(f'True Size: {len(true_pos)}')
    print(f"True Phase: {true_prop['phase_id']}")
    print(f"True Orientation: {true_prop['average_orientation']}")
    print('\n')

    # --- Model prediction ---
    with torch.no_grad():  # inference only, no autograd graph
        pred = model(sample)

    pred_tensor = pred[0]  # a single sample gives a one-row prediction

    # Column layout: 0 size, 1 shear, 2 phase logit, 3:7 orientation.
    grain_size = pred_tensor[0]
    if size_range is not None:
        size_min, size_max = size_range
        # Invert the min-max normalisation back to a raw grain size.
        grain_size = grain_size * (size_max - size_min) + size_min
    grain_shear = pred_tensor[1].item()
    phase_prob = torch.sigmoid(pred_tensor[2])
    phase_value = ((phase_prob > 0.5).float() + 1).item()  # label 1 or 2
    orientations_list = pred_tensor[3:7].tolist()

    print(f"Prediction Grain size: {grain_size.item()}")
    print(f"Prediction Grain shear: {grain_shear}")
    print(f"Prediction Phase: {phase_value}")
    print(f"Prediction Orientations: {orientations_list}")

compare_predictions(index=254, model=model)

Input Sample: (4, 2409)
Input Shear: 0.02307494208188692
Input Size: 429
Input Phase: 1
nput Phase: [ 0.9309755  -0.14824287  0.06433457 -0.32736789]


Next step: (5, 2448)
True Shear: 0.02212937857689785
True Size: 436
True -hase: 1
True Phase: [ 0.93124301 -0.14806403  0.06096297 -0.32733319]


Prediction Grain size: 246.9119873046875
Prediction Grain shear: 0.03876620531082153
Prediction Phase: 1.0
Prediction Orientations: [0.8718612790107727, -0.17640642821788788, 0.0202375166118145, -0.12630008161067963]


In [8]:
def get_data(step_and_id: tuple):
    """Return ``(positions, property)`` for the grain ``(step, grain_id)``.

    NOTE(review): same behaviour as the ``get_data`` defined in an earlier
    cell; presumably re-defined so this cell can run on its own.
    """
    step, grain_id = step_and_id
    positions = all_grains[step][grain_id]
    grain_props = grain_property[step][grain_id]
    return positions, grain_props


def compare_predictions(index, model, size_range=None):
    """Print a sample's input grain, its true next-step grain and the model prediction.

    Ground-truth values come from the raw grain tables (``get_data`` and
    ``strain_grains``). The prediction comes from ``model(all_data_points[index])``.

    Args:
        index: position of the sample in ``all_data_points``.
        model: trained predictor, called on the sample without gradients.
        size_range: optional ``(min_size, max_size)`` used to undo the min-max
            normalisation of the predicted size. ``None`` (default) prints the
            normalised model output as-is.
    """
    sample = all_data_points[index].to(device)

    input_step = sample.current_step.item()
    input_gid = sample.current_gid.item()
    next_step = sample.next_step.item()
    next_gid = sample.next_gid.item()

    # --- Ground truth from the raw grain tables ---
    input_pos, input_prop = get_data((input_step, input_gid))
    input_shear = strain_grains[input_step]['strains'][input_gid]
    print(f'Input Sample: {input_step, input_gid}')
    print(f'Input Shear: {input_shear}')
    print(f'Input Size: {len(input_pos)}')
    print(f"Input Phase: {input_prop['phase_id']}")
    print(f"Input Orientation: {input_prop['average_orientation']}")
    print('\n')

    true_pos, true_prop = get_data((next_step, next_gid))
    true_shear = strain_grains[next_step]['strains'][next_gid]
    print(f'Next step: {next_step, next_gid}')
    print(f'True Shear: {true_shear}')
    print(f'True Size: {len(true_pos)}')
    print(f"True Phase: {true_prop['phase_id']}")
    print(f"True Orientation: {true_prop['average_orientation']}")
    print('\n')

    # --- Model prediction ---
    with torch.no_grad():  # inference only, no autograd graph
        pred = model(sample)

    pred_tensor = pred[0]  # a single sample gives a one-row prediction

    # Column layout: 0 size, 1 shear, 2 phase logit, 3:7 orientation.
    grain_size = pred_tensor[0]
    if size_range is not None:
        size_min, size_max = size_range
        # Invert the min-max normalisation back to a raw grain size.
        grain_size = grain_size * (size_max - size_min) + size_min
    grain_shear = pred_tensor[1].item()
    phase_prob = torch.sigmoid(pred_tensor[2])
    phase_value = ((phase_prob > 0.5).float() + 1).item()  # label 1 or 2
    orientations_list = pred_tensor[3:7].tolist()

    print(f"Prediction Grain size: {grain_size.item()}")
    print(f"Prediction Grain shear: {grain_shear}")
    print(f"Prediction Phase: {phase_value}")
    print(f"Prediction Orientations: {orientations_list}")

compare_predictions(index=254, model=model)

Input Sample: (4, 2409)
Input Shear: 0.02307494208188692
Input Size: 429
Input Phase: 1
nput Phase: [ 0.9309755  -0.14824287  0.06433457 -0.32736789]


Next step: (5, 2448)
True Shear: 0.02212937857689785
True Size: 436
True -hase: 1
True Phase: [ 0.93124301 -0.14806403  0.06096297 -0.32733319]


Prediction Grain size: 0.05725782364606857
Prediction Grain shear: 0.22193476557731628
Prediction Phase: 1.0
Prediction Orientations: [0.9125292301177979, -0.13527855277061462, 0.09346653521060944, -0.08168549090623856]


In [10]:
def compare_predictions(index, model=None):
    """Print a sample's target-grain feature rows and the model prediction.

    This version prints the processed feature tensors stored in
    ``all_data_points`` (input row ``x[0]`` and target row ``y[0]``) rather
    than the raw grain tables.

    Args:
        index: position of the sample in ``all_data_points``.
        model: predictor to query. ``None`` falls back to the notebook-level
            ``model``. Both ``compare_predictions(i)`` and
            ``compare_predictions(i, model)`` therefore work, so this definition
            no longer breaks two-argument calls when it shadows another variant.
    """
    if model is None:
        model = globals()['model']

    sample = all_data_points[index]

    print(f"Input {sample['x'][0]}")
    print(f"True: {sample['y'][0]}")

    with torch.no_grad():  # inference only, no autograd graph
        pred = model(sample)

    pred_tensor = pred[0]  # a single sample gives a one-row prediction

    # Column layout: 0 size, 1 shear, 2 phase logit, 3:7 orientation.
    grain_size_value = pred_tensor[0].item()
    grain_shear = pred_tensor[1].item()
    phase_prob = torch.sigmoid(pred_tensor[2])
    phase_value = ((phase_prob > 0.5).float() + 1).item()  # label 1 or 2
    orientations_list = pred_tensor[3:7].tolist()

    print(f"Prediction: {[grain_size_value, grain_shear, phase_value, orientations_list]}")

compare_predictions(index=100)

Input tensor([ 0.0534,  0.1087,  1.0000,  0.1189, -0.0442, -0.0522, -0.9905])
True: tensor([ 0.0554,  0.1668,  1.0000,  0.1204, -0.0466, -0.0507, -0.9903])
Prediction: [0.05481639504432678, 0.21847331523895264, 1.0, [0.18501469492912292, -0.029368892312049866, -0.00817003846168518, 0.022139888256788254]]


In [32]:
# Scan the current and next grain of every sample to find the global grain-size
# range (presumably the bounds used for the size normalisation).
grain_sizes = []
for sample in all_data_points:
    for step_attr, gid_attr in (('current_step', 'current_gid'), ('next_step', 'next_gid')):
        step_and_id = (getattr(sample, step_attr).item(), getattr(sample, gid_attr).item())
        positions, _ = get_data(step_and_id)
        grain_sizes.append(len(positions))

# Same fallbacks as an empty scan: +inf for the minimum, 0 for the maximum.
min_size = min(grain_sizes, default=float('inf'))
max_size = max(grain_sizes, default=0)

print("Min size:", min_size)
print("Max size:", max_size)


Min size: 10
Max size: 5016
