In [5]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score

import panther.nn as pnn


In [6]:
class SimpleModel(nn.Module):
    """
    Single-layer classifier: one linear-style layer followed by an activation.

    Parameters:
    - input_dim: Number of input features.
    - output_dim: Number of outputs; 1 selects a Sigmoid head, anything else
      a Softmax over dim 1.
    - use_pnn: If True the layer is `pnn.SKLinear`, otherwise `nn.Linear`.
    - k, l, mode: Forwarded to `pnn.SKLinear` as `low_rank`, `num_terms` and
      `mode`; unused when `use_pnn` is False.
    """

    def __init__(self, input_dim, output_dim, use_pnn=True, k=5, l=5, mode=0):
        super().__init__()
        # `layer` is registered before `activation` so the submodule order
        # (and therefore print(model) / state_dict layout) stays stable.
        self.layer = (
            pnn.SKLinear(input_dim, output_dim, low_rank=k, num_terms=l, mode=mode)
            if use_pnn
            else nn.Linear(input_dim, output_dim)
        )
        self.activation = nn.Sigmoid() if output_dim == 1 else nn.Softmax(dim=1)

    def forward(self, x):
        """Applies the layer, then the activation chosen at construction time."""
        projected = self.layer(x)
        return self.activation(projected)

In [7]:
def test_model_f1score(model, X_test, y_true, device="cuda", average="weighted"):
    model.eval()
    with torch.no_grad():
        outputs = model(X_test)
        if outputs.shape[1] > 1:  # Multi-class
            preds = torch.argmax(outputs, dim=1)
        else:  # Binary (sigmoid)
            preds = (outputs > 0.5).long().squeeze()

    y_pred = preds.cpu().numpy()
    return f1_score(y_true, y_pred, average=average)


def train_model(model, X_train, y_train, epochs=10, lr=0.001, device="cuda"):
    """
    Trains a PyTorch model on the given data with full-batch Adam.

    Parameters:
    - model: A PyTorch model.
    - X_train: Input features (numpy or tensor).
    - y_train: Labels (numpy or tensor).
    - epochs: Number of training epochs.
    - lr: Learning rate.
    - device: 'cpu' or 'cuda'.

    Returns:
    - Trained model.

    The loss is chosen from the model's output shape: a 1-D output or a single
    output column is treated as binary (BCELoss on float labels), anything
    else as multi-class (CrossEntropyLoss on integer labels).

    NOTE(review): nn.CrossEntropyLoss expects raw logits. A model that already
    ends in Softmax (as SimpleModel in this notebook does) is effectively
    normalised twice; consider returning logits from such models.
    """
    model.to(device)
    model.train()

    # Convert to tensors if not already
    if not isinstance(X_train, torch.Tensor):
        X_train = torch.tensor(X_train, dtype=torch.float32)
    if not isinstance(y_train, torch.Tensor):
        y_train = torch.tensor(y_train, dtype=torch.long)

    X_train, y_train = X_train.to(device), y_train.to(device)

    # Probe the output shape once, without building an autograd graph.
    with torch.no_grad():
        probe = model(X_train[:1])
    is_binary = probe.dim() == 1 or probe.shape[1] == 1

    if is_binary:
        criterion = nn.BCELoss()
        # BCELoss needs float targets shaped like the flattened outputs;
        # cast once here instead of on every epoch.
        y_train = y_train.float().reshape(-1)
    else:
        criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Report roughly five times per run; max() guards against epochs < 5,
    # where epochs // 5 would be zero and the modulo would divide by zero.
    log_every = max(1, epochs // 5)

    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(X_train)

        if is_binary:
            # reshape(-1) keeps a 1-D tensor even when the batch has one row.
            outputs = outputs.reshape(-1)

        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()

        if epoch % log_every == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

    return model


def time_forward_pass(model, X, device="cuda"):
    """
    Measures the wall-clock time of one inference forward pass.

    Parameters:
    - model: A PyTorch model. It is moved to `device` and put in eval mode.
    - X: Input features (array-like or tensor).
    - device: 'cpu' or 'cuda' (device strings such as 'cuda:0' and
      torch.device objects are accepted as well).

    Returns:
    - Elapsed time in seconds (float) for a single forward pass, measured
      after one untimed warm-up pass.
    """
    model.to(device)
    model.eval()
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X, dtype=torch.float32)
    X = X.to(device)

    # Compare the device *type* rather than the raw argument so that
    # "cuda:0" or torch.device("cuda") also trigger synchronisation.
    on_cuda = torch.device(device).type == "cuda"

    with torch.no_grad():
        # Warm-up (optional, helps with accurate timing especially on GPU)
        _ = model(X)

        # CUDA kernels launch asynchronously: sync before starting the clock
        # and again before stopping it so the timed span covers the real work.
        if on_cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()

        _ = model(X)

        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

    return elapsed


def train_evaluate_time(model, X, y):
    """
    Scores `model` on (X, y) and times one forward pass over X.

    NOTE(review): despite the name, no training happens here. The call to
    train_model is currently disabled, so the model is evaluated exactly as
    it was passed in. To re-enable it, run `model = train_model(model, X, y)`
    before scoring.

    Returns:
    - (f1, elapsed_time): the value from test_model_f1score and the value
      from time_forward_pass, both called with their default device.
    """
    score = test_model_f1score(model, X, y)
    print("Model evaluated.")

    latency = time_forward_pass(model, X)
    print("Forward pass time measured.")

    return (score, latency)


In [None]:
# Smoke test: build the plain nn.Linear baseline (use_pnn=False) and run the
# evaluate-and-time pipeline on synthetic data.
torch.manual_seed(0)  # fixed seed so weights and synthetic data are reproducible

input_dim = 10
output_dim = 2
num_samples = 100
# k, l and mode are only consumed when use_pnn=True; they are kept here so the
# flag can be flipped without touching the rest of the call.
model = SimpleModel(input_dim, output_dim, use_pnn=False, k=5, l=5, mode=0)
print(model)
X_train = torch.randn(num_samples, input_dim)  # Example training data
y_train = torch.randint(0, output_dim, (num_samples,))
print("Training and evaluating model...")
train_evaluate_time(model, X_train, y_train)


RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
