In [1]:
import numpy as np
import matplotlib.pyplot as plt
from model import *
import sys 
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import StepLR
from loss_function import *

In [2]:
# Build the data source for the rest of the notebook. `diffusion_equation` is
# not defined in this notebook; it arrives through one of the star imports
# above (presumably `from model import *` — TODO confirm).
data_model = diffusion_equation()
# NOTE(review): the meaning of the positional arguments (500, 10) and of
# `dlt_t` is set by `generate_training_data`, which is not visible here —
# presumably sample/step counts and a time-step size; confirm against the
# helper. Later cells unpack each sample as `(x, y, lace_1, lace_2)`.
dataset = data_model.generate_training_data(500, 10, dlt_t = 0.001)

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam

# Shuffled loader over `dataset` using PyTorch's default collation.
# NOTE(review): no later visible cell reads `data_loader`; training uses the
# loader returned by `create_orthogonalized_data_loader` instead — confirm
# whether this one is still needed.
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Define a simple model with a single trainable parameter `k`
class SimpleModel(nn.Module):
    """
    Linear two-output map controlled by a single learnable float64 value `k`.

    Parameters:
    k_value (optional): Initial value for `k`; anything `torch.tensor` accepts.
        When omitted, `k` is a scalar drawn from a standard normal distribution.
    """

    def __init__(self, k_value=None):
        super().__init__()
        # Pick the starting tensor first, then register it once as a parameter.
        initial_k = (
            torch.randn((), dtype=torch.float64)
            if k_value is None
            else torch.tensor(k_value, dtype=torch.float64)
        )
        self.k = nn.Parameter(initial_k)

    def forward(self, lace_1, lace_2):
        """
        Produce the pair `(g, h)`.

        Returns:
        tuple: `g` is `lace_1` passed through unchanged; `h` is
        `k * lace_1 + lace_2`.
        """
        return lace_1, self.k * lace_1 + lace_2

# No `k_value` is passed, so `k` starts from a random draw.
# NOTE(review): `model` is rebound by later cells before any training happens,
# so this instance appears to be unused — confirm.
model = SimpleModel()

In [11]:
from loss_function import *
# Create a custom DataLoader with a collate function for batch processing
def create_orthogonalized_data_loader(dataset, model, batch_size):
    """
    Build a shuffling DataLoader whose batches also carry an orthogonalized `g`.

    Parameters:
    dataset: Dataset whose samples unpack into four tensors
        `(x, y, lace_1, lace_2)`.
    model (nn.Module): Called as `model(lace_1, lace_2)` on each stacked batch;
        only the first of its two outputs is used here.
    batch_size (int): Number of samples per batch.

    Returns:
    DataLoader: Yields `(x, y, orthogonal_g, lace_1, lace_2)` per batch.
    """
    def _stack_and_orthogonalize(samples):
        # Turn the list of per-sample 4-tuples into four per-field tuples and
        # stack each one along a new leading batch dimension.
        x, y, lace_1, lace_2 = map(torch.stack, zip(*samples))

        # `g` is only preprocessing input, so it is computed without autograd.
        with torch.no_grad():
            g, _unused_h = model(lace_1, lace_2)

        # `orthogonalize_columns` is provided by the `loss_function` star import.
        return x, y, orthogonalize_columns(g), lace_1, lace_2

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=_stack_and_orthogonalize,
    )

# Example usage
# NOTE(review): the training cell further down rebuilds `model`, `batch_size`
# and `orthogonalized_data_loader` itself, so these bindings only matter if
# that cell is skipped — confirm whether this example is still needed.
model = SimpleModel()
batch_size = 64
# Only constructs the loader; the custom collate function runs when the loader
# is iterated.
orthogonalized_data_loader = create_orthogonalized_data_loader(dataset, model, batch_size)

In [12]:
from tqdm import tqdm

def train_one_epoch(model, optimizer, data_loader, epoch):
    """
    Train the model for one epoch using the provided data loader with a progress bar.

    Parameters:
    model (nn.Module): The model with the trainable parameter `k`; called as
        `model(lace_1, lace_2)` and expected to return a pair whose second
        element is `h`.
    optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
    data_loader (DataLoader): Yields `(x, y, orthogonal_g, lace_1, lace_2)` batches.
    epoch (int): Zero-based epoch number (displayed as `epoch + 1`).

    Returns:
    float: The average of the per-batch losses for this epoch, or NaN when the
    data loader yields no batches.
    """
    model.train()  # Set the model to training mode
    total_loss = 0.0
    num_batches = 0

    # Using the bar as a context manager closes it even when the loop is
    # interrupted or raises, so an aborted epoch does not leave a dangling bar.
    with tqdm(data_loader, desc=f"Epoch {epoch + 1}", leave=True) as progress_bar:
        # `x` and `y` are part of every batch but are not needed for this loss.
        for _x, _y, orthogonal_g, lace_1, lace_2 in progress_bar:
            optimizer.zero_grad()

            # Only `h` depends on the trainable parameter; `g` was precomputed
            # (and orthogonalized) by the data loader.
            _, h = model(lace_1, lace_2)
            loss = loss_function_orth(orthogonal_g, h)

            # Backpropagate the loss and update model parameters
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for both the running total and
            # the progress-bar display.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix({"Batch Loss": batch_loss})

    # An empty loader has no meaningful average; return NaN instead of raising
    # ZeroDivisionError.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches


In [13]:
def train(model, optimizer, data_loader, epochs):
    """
    Run `train_one_epoch` repeatedly and collect the per-epoch average losses.

    Parameters:
    model (nn.Module): The model with the trainable parameter `k`.
    optimizer (torch.optim.Optimizer): The optimizer for updating model parameters.
    data_loader (DataLoader): The data loader passed unchanged to every epoch.
    epochs (int): How many epochs to run; zero or fewer runs none.

    Returns:
    list: One average loss per completed epoch, in order.
    """
    epoch_losses = []

    for epoch_index in range(epochs):
        # The helper receives the zero-based index; the printed summary is one-based.
        mean_loss = train_one_epoch(model, optimizer, data_loader, epoch_index)
        epoch_losses.append(mean_loss)

        print(f"Epoch {epoch_index + 1}/{epochs}, Average Loss: {mean_loss:.6f}")

    return epoch_losses


In [14]:
if __name__ == "__main__":
    # Seed torch's global RNG so the random initial `k` and the shuffled batch
    # order are reproducible across Restart & Run All.
    seed = 0
    torch.manual_seed(seed)

    # Fresh model and optimizer so that re-running this cell restarts training
    # from scratch instead of continuing from an earlier run.
    model = SimpleModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    batch_size = 64

    # The loader is built around the same `model` instance that is trained.
    orthogonalized_data_loader = create_orthogonalized_data_loader(dataset, model, batch_size)

    # Train the model with progress bar
    epochs = 10
    loss_history = train(model, optimizer, orthogonalized_data_loader, epochs)

    # Plot the training loss history with the explicit Axes interface.
    # `plt` comes from the notebook's first import cell. The x-axis is derived
    # from the recorded history so the two sequences always have equal length.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(range(1, len(loss_history) + 1), loss_history, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Average Loss")
    ax.set_title("Training Loss History")
    ax.legend()
    ax.grid()
    plt.show()


Epoch 1:  20%|█▉        | 5635/28204 [00:23<01:33, 241.50it/s, Batch Loss=0.00196]  


KeyboardInterrupt: 

In [None]:
# Smoke test: run `loss_function_orth` on random inputs and check that
# backpropagation fills in a gradient for H.
d, m, n = 10, 5, 7
# G and H share their first dimension `d`. G is left as a constant
# (requires_grad is off by default); H is the leaf tensor being differentiated.
G = torch.randn(d, n, dtype=torch.float64)
H = torch.randn(d, m, dtype=torch.float64).requires_grad_()

# Forward pass, then backward pass straight away.
loss = loss_function_orth(G, H)
loss.backward()

# Show the scalar loss and the gradient accumulated on H.
print("Loss:", loss.item())
print("H.grad:", H.grad)


Loss: 7.142012963312397
H.grad: tensor([[ 2.9855e-01,  3.9371e-01, -2.7795e-01, -5.0511e-02,  1.4365e+00],
        [-6.4217e-01, -8.4684e-01,  5.9784e-01,  1.0864e-01, -3.0898e+00],
        [-4.9156e-03, -6.4823e-03,  4.5762e-03,  8.3164e-04, -2.3651e-02],
        [-1.0693e-01, -1.4101e-01,  9.9546e-02,  1.8090e-02, -5.1448e-01],
        [-5.2043e-01, -6.8630e-01,  4.8450e-01,  8.8048e-02, -2.5040e+00],
        [ 2.5260e-01,  3.3310e-01, -2.3516e-01, -4.2735e-02,  1.2154e+00],
        [ 6.2842e-01,  8.2871e-01, -5.8504e-01, -1.0632e-01,  3.0236e+00],
        [-5.4632e-01, -7.2044e-01,  5.0860e-01,  9.2428e-02, -2.6286e+00],
        [ 3.1809e-01,  4.1947e-01, -2.9613e-01, -5.3816e-02,  1.5305e+00],
        [ 1.1084e-01,  1.4617e-01, -1.0319e-01, -1.8753e-02,  5.3332e-01]],
       dtype=torch.float64)
