In [1]:
import numpy as np
import matplotlib.pyplot as plt 
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
import json
import os
from chop import Chop
import json    
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim


In [2]:
### torch network
class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
### load data
obs_ls, act_ls = torch.load('dataset100k.pt')

print("obs_ls shape: ", obs_ls.shape)
print("act_ls shape: ", act_ls.shape)
dataset = TensorDataset(obs_ls, act_ls)

# Set split sizes 
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size

# Split the dataset
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create DataLoaders
batch_size = int(2*4096)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


  obs_ls, act_ls = torch.load('dataset100k.pt')


obs_ls shape:  torch.Size([6910796, 8])
act_ls shape:  torch.Size([6910796, 4])


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)


Using device:  cuda


In [11]:
model = Net(input_dim=obs_ls.shape[1], output_dim=act_ls.shape[1])
model = nn.DataParallel(model) 
model = model.to(device)
criterion = nn.MSELoss()  # Or CrossEntropyLoss, etc.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
num_epochs = 30
epoch_losses = []

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    batch_count = 0

    for batch_X, batch_y in train_loader:
        batch_X = batch_X.to(device).float()
        batch_y = batch_y.to(device).float()

        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_count += 1

    avg_loss = total_loss / batch_count
    epoch_losses.append(avg_loss)
    print(f"Epoch {epoch+1}: Avg Loss = {avg_loss:.4f}")


  return F.linear(input, self.weight, self.bias)


Epoch 1: Avg Loss = 3013.9707
Epoch 2: Avg Loss = 2968.3005
Epoch 3: Avg Loss = 2932.9305
Epoch 4: Avg Loss = 2897.6089
Epoch 5: Avg Loss = 2862.8923
Epoch 6: Avg Loss = 2826.9570
Epoch 7: Avg Loss = 2789.0576
Epoch 8: Avg Loss = 2750.1678
Epoch 9: Avg Loss = 2708.4259
Epoch 10: Avg Loss = 2668.4770
Epoch 11: Avg Loss = 2628.0162
Epoch 12: Avg Loss = 2587.8928
Epoch 13: Avg Loss = 2552.1216
Epoch 14: Avg Loss = 2519.7610
Epoch 15: Avg Loss = 2490.6486
Epoch 16: Avg Loss = 2465.1013
Epoch 17: Avg Loss = 2442.2747
Epoch 18: Avg Loss = 2420.2334
Epoch 19: Avg Loss = 2401.0850
Epoch 20: Avg Loss = 2382.4844
Epoch 21: Avg Loss = 2363.7763
Epoch 22: Avg Loss = 2349.7869
Epoch 23: Avg Loss = 2335.6760
Epoch 24: Avg Loss = 2322.1997


In [None]:
plt.plot(epoch_losses)
plt.title('Loss over epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# plt.savefig('loss.png')
plt.show()

In [None]:
torch.save(model.state_dict(), "model_weights.pth")
