In [1]:
'''All of the imports live in this code cell.'''

import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import scipy.io as sio
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
from torchmetrics.regression import MeanAbsolutePercentageError
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import ast
import re
from torch.cuda.amp import autocast, GradScaler

# Pick the compute device in order of preference: CUDA GPU, then Apple MPS, then plain CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

'cpu'

In [None]:
class Regression_Net(nn.Module):
    """Fully connected regression network with two batch-normalisation stages.

    Data flow in ``forward``:
        flatten -> BatchNorm1d(in_features)
                -> Linear/ReLU/Linear/ReLU/Linear   (``linear_relu_stack``)
                -> BatchNorm1d(n)
                -> Linear/ReLU/Linear               (``linear_relu_stack1``)
    """

    def __init__(self, in_features, out_features=3, n=128):
        """
        in_features: Number of input features (per sample, after flattening)
        out_features: Number of output features
        n: Number of hidden units in every hidden layer
        """
        super().__init__()

        # Keep the sizes around so they can be inspected on the instance later.
        self.n, self.in_features, self.out_features = n, in_features, out_features

        # Sub-modules are registered in the same order as before so parameter
        # ordering and state_dict keys stay stable.
        self.flatten = nn.Flatten()
        self.bn = nn.BatchNorm1d(in_features)   # normalises the raw (flattened) inputs
        self.bn1 = nn.BatchNorm1d(n)            # normalises the output of the first stack

        # First stack: maps the inputs into the n-wide hidden space.
        first_stack_layers = [
            nn.Linear(in_features, n),
            nn.ReLU(),
            nn.Linear(n, n),
            nn.ReLU(),
            nn.Linear(n, n),
        ]
        self.linear_relu_stack = nn.Sequential(*first_stack_layers)

        # Second stack: maps the hidden representation to the regression outputs.
        second_stack_layers = [
            nn.Linear(n, n),
            nn.ReLU(),
            nn.Linear(n, out_features),
        ]
        self.linear_relu_stack1 = nn.Sequential(*second_stack_layers)

    def forward(self, x):
        """Run a batch through the network and return the ``out_features`` predictions per sample."""
        flat = self.flatten(x)
        hidden = self.linear_relu_stack(self.bn(flat))
        return self.linear_relu_stack1(self.bn1(hidden))


In [None]:
# Define a PyTorch Dataset
class CrackDataset(torch.utils.data.Dataset):
    """Dataset that converts one DataFrame row into an ``(inputs, targets)`` tensor pair.

    Columns read from each row: ``crack_start``, ``crack_tip``, ``load_location``,
    ``load_condition`` and ``bottom_surface_deflections``. Every cell is parsed
    with the notebook's ``safe_eval`` helper and converted to a float32 tensor.

    inputs  = [load_location x, load_condition y, *bottom_surface_deflections]
    targets = [crack_start x, crack_tip y]
    """

    def __init__(self, dataframe):
        # The frame is stored as-is; parsing happens lazily in __getitem__.
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]

        def parse(column):
            # Shared parsing path: raw cell -> list (via safe_eval) -> float32 tensor.
            return torch.tensor(safe_eval(record[column]), dtype=torch.float32)

        # --- Targets -------------------------------------------------------
        # Only the x-component of the crack start is kept.
        start_x = parse("crack_start")[0].unsqueeze(0)
        # Only the y-component of the crack tip is kept.
        tip_y = parse("crack_tip")[1].unsqueeze(0)

        # --- Load description ----------------------------------------------
        location = parse("load_location")
        condition = parse("load_condition")
        # The y-coordinate of the load location is dropped (noted by the author as always 1).
        location_x = location[0].unsqueeze(0)
        # The x-component of the load condition is dropped; only y is used.
        condition_y = condition[1].unsqueeze(0)

        # --- Deflection field ----------------------------------------------
        deflections = parse("bottom_surface_deflections")
        if deflections.dim() == 0:
            # Promote a scalar to a 1-element vector so it can be concatenated.
            deflections = deflections.unsqueeze(0)

        inputs = torch.cat((location_x, condition_y, deflections))
        targets = torch.cat((start_x, tip_y))
        return inputs, targets


In [None]:
'''This code cell provides a safe-evaluation helper function for parsing values read from the CSVs. It still contains a holdover from an earlier version of the data generation process, which inserted a displacement field of [0] whenever an error occurred; that fallback has since been removed from data generation, so the handling for it is no longer required.'''

def safe_eval(val):
    """Parse a value read from the CSV into a Python list.

    Supported inputs:
      * list / tuple: copied into a new list.
      * string wrapped in parentheses, e.g. "(0.5, 1)": parsed with
        ``ast.literal_eval`` (no arbitrary code execution). A parenthesised
        scalar such as "(5)" yields a one-element list.
      * any other string, e.g. "[0.1 0.2 0.3]" or "[0.1, 0.2, 0.3]": square
        brackets are removed and the remainder is split on whitespace and/or
        commas; every token is converted to ``float``. An empty string yields [].
      * int / float / NumPy integer or floating scalar: wrapped in a
        one-element list.

    Returns:
        list: the parsed values.

    Raises:
        ValueError: if an array-style string contains a token that is not a
            number, or if ``val`` has an unsupported type.
        Errors raised by ``ast.literal_eval`` for malformed parenthesised
        strings propagate unchanged.
    """
    if isinstance(val, (list, tuple)):
        return list(val)  # Already a sequence; return a fresh list

    if isinstance(val, str):
        text = val.strip()

        # Tuple case (comma-delimited, wrapped in parentheses)
        if text.startswith("(") and text.endswith(")"):
            parsed = ast.literal_eval(text)
            # "(5)" evaluates to a bare number rather than a tuple; calling
            # list() on it would raise TypeError, so wrap it instead.
            if isinstance(parsed, (int, float)):
                return [parsed]
            return list(parsed)

        # Array case: drop the square brackets that come from array/list reprs
        cleaned = re.sub(r"[\[\]]", "", text).strip()
        try:
            # Commas are treated like whitespace so that "[1.0, 2.0]" parses
            # the same way as the space-delimited "[1.0 2.0]".
            return [float(token) for token in cleaned.replace(",", " ").split()]
        except ValueError as err:
            raise ValueError(f"Error parsing space-separated values: {cleaned}") from err

    if isinstance(val, (int, float)):
        return [val]  # Wrap single numbers in a list

    # NumPy scalars such as np.int64 / np.float32 are not subclasses of the
    # built-in int / float, so they need their own branch.
    if isinstance(val, (np.integer, np.floating)):
        return [val.item()]

    raise ValueError(f"Unexpected data type: {type(val)} - {val}")