In [None]:
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.transforms as T
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader

from src.dataset import RobotGraph

In [None]:
# Run on the CUDA device when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:

class MLP(nn.Module):
    """Feed-forward network applied row-wise to ``data.x`` (graph structure is ignored).

    Architecture: ``num_blocks`` repetitions of
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` followed by a final
    ``Linear(hidden_dim, latent_dim)``. With the default arguments the layer
    order is identical to the original hand-written four-block network, so
    ``state_dict`` keys (``mlp.0.weight`` ... ``mlp.16.bias``) are unchanged.

    Args:
        input_dim: Number of input features per row of ``data.x``.
        hidden_dim: Width of every hidden layer.
        latent_dim: Number of output features per row.
        dropout: Dropout probability applied after each hidden ReLU.
        num_blocks: Number of hidden ``Linear/BatchNorm1d/ReLU/Dropout`` blocks (>= 1).

    Raises:
        ValueError: If ``num_blocks`` is smaller than 1.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, dropout=0.5, num_blocks=4):
        super().__init__()
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

        layers = []
        in_features = input_dim
        for _ in range(num_blocks):
            layers += [
                nn.Linear(in_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            in_features = hidden_dim
        layers.append(nn.Linear(hidden_dim, latent_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, data):
        """Map ``data.x`` to predictions.

        Args:
            data: Any object with an ``x`` attribute; ``x`` is expected to be a
                2-D float tensor of shape ``(num_rows, input_dim)`` (BatchNorm1d
                normalizes over the first dimension).

        Returns:
            Tensor of shape ``(num_rows, latent_dim)``.
        """
        return self.mlp(data.x)



In [None]:

def train(train_loader, test_loader, writer, model, optimizer, num_epochs=50, l1_lambda=0.01, eval_every=1):
    """Fit ``model`` with a Huber loss plus an L1 penalty on weight tensors.

    Every epoch makes one pass over ``train_loader`` and logs the mean
    penalized training loss to ``writer`` under ``"loss"`` and the mean
    un-penalized Huber loss under ``"train_huber"``. Every ``eval_every``
    epochs the model is evaluated with ``test(test_loader, model)``; the
    result is printed and logged under ``"test_loss"``. An
    ``ExponentialLR(gamma=0.975)`` scheduler is stepped on every epoch index
    divisible by 20.

    Args:
        train_loader: Iterable of batches exposing ``.to()``, ``.y`` and
            ``.num_graphs``; ``batch.y`` must hold as many elements as
            ``model(batch)`` returns.
        test_loader: Loader passed to ``test()`` for evaluation.
        writer: Object with an ``add_scalar(tag, value, step)`` method
            (e.g. a TensorBoard ``SummaryWriter``).
        model: Module to train; moved to the module-level ``device``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        l1_lambda: Weight of the L1 penalty on parameters whose name
            contains ``"weight"``.
        eval_every: Evaluate and print every this many epochs (must be >= 1).

    Returns:
        The same ``model`` object after training.
    """
    print("Device:", device)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.975)
    model.to(device)

    for epoch in range(num_epochs):
        # test() switches the model to eval mode, so training mode (active
        # dropout, BatchNorm batch statistics) must be re-enabled every epoch.
        model.train()
        total_loss = 0.0   # penalized (Huber + L1) loss, weighted by graphs per batch
        total_huber = 0.0  # un-penalized Huber loss, weighted the same way
        num_graphs_seen = 0

        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            pred = model(batch)
            # Reshape the target to the prediction's shape (as test() does) so
            # a (N, 1) target cannot silently broadcast against a (N,) prediction.
            target = batch.y.reshape(pred.shape)
            huber = F.huber_loss(pred, target)
            # L1 penalty over every "weight" parameter (this includes BatchNorm scales).
            l1_reg = sum(
                param.abs().sum()
                for name, param in model.named_parameters()
                if "weight" in name
            )
            loss = huber + l1_lambda * l1_reg
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch.num_graphs
            total_huber += huber.item() * batch.num_graphs
            num_graphs_seen += batch.num_graphs

        # Normalize by the graphs actually seen: with drop_last=True this can
        # be fewer than len(train_loader.dataset).
        if num_graphs_seen:
            total_loss /= num_graphs_seen
            total_huber /= num_graphs_seen
        else:
            total_loss = total_huber = float("nan")
        writer.add_scalar("loss", total_loss, epoch)
        writer.add_scalar("train_huber", total_huber, epoch)

        if epoch % eval_every == 0:
            test_loss = test(test_loader, model)
            print(
                f"Epoch {epoch}. Train loss (Huber + L1): {total_loss:.4f}. "
                f"Train Huber loss: {total_huber:.4f}. Test loss: {test_loss:.4f}"
            )
            writer.add_scalar("test_loss", test_loss, epoch)

        # NOTE(review): this also fires at epoch 0, so the LR decays right after
        # the first epoch; use (epoch + 1) % 20 if decay after every 20 full
        # epochs was intended.
        if epoch % 20 == 0:
            scheduler.step()

    return model


In [None]:

def test(loader, model):
    model.eval()
    total_loss = 0.0
    total_samples = 0

    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data)
        target = data.y.reshape(pred.shape[0], pred.shape[1])

        # loss = F.mse_loss(pred.squeeze(), target).to(device)
        loss = F.huber_loss(pred, target).to(device)
        total_loss += loss.item() * data.num_graphs
        total_samples += data.num_graphs

        avg_loss = total_loss / total_samples
    return avg_loss

In [None]:

def main():
    """Train the MLP baseline on the ``RobotGraph`` dataset and save it.

    Splits the dataset 80/20 into train/test subsets (``random_state=42``),
    trains ``MLP(8, 128, 1)`` with Adam via ``train()``, logs to the
    TensorBoard directory ``../runs/test`` and saves the whole trained model
    object to ``models/model.pt`` (creating ``models/`` if needed).
    """
    import os

    seed = 42
    # Seed torch as well as the split so weight init, dropout and loader
    # shuffling are reproducible across runs.
    torch.manual_seed(seed)

    writer = SummaryWriter("../runs/test")
    # NOTE(review): a T.NormalizeFeatures() transform used to be constructed here
    # but was never passed to the dataset. Pass it (e.g. as ``transform=`` if
    # RobotGraph accepts one; TODO confirm) if normalized node features are intended.
    dataset = RobotGraph("dataset")

    train_ratio = 0.8
    num_samples = len(dataset)
    num_train = int(train_ratio * num_samples)
    train_indices, test_indices = train_test_split(
        list(range(num_samples)), train_size=num_train, shuffle=True, random_state=seed
    )

    train_dataset = Subset(dataset, train_indices)
    test_dataset = Subset(dataset, test_indices)

    train_loader = DataLoader(
        train_dataset, batch_size=512, shuffle=True, drop_last=True, num_workers=4
    )
    # Evaluate on every test sample in a fixed order. With drop_last=True up to
    # 511 test graphs were ignored, and a split smaller than 512 yielded no batches.
    test_loader = DataLoader(
        test_dataset, batch_size=512, shuffle=False, drop_last=False, num_workers=4
    )

    model = MLP(8, 128, 1)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)

    try:
        trained_model = train(train_loader, test_loader, writer, model, optimizer)
    finally:
        # Flush buffered TensorBoard events even if training is interrupted.
        writer.close()
    print("Training complete.")

    # Save the whole model object (not only its state_dict), as before;
    # make sure the target directory exists first.
    os.makedirs("models", exist_ok=True)
    torch.save(trained_model, "models/model.pt")


# In a Jupyter kernel __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()
