In [1]:
import os
CURRENT_PATH = os.getcwd()

import numpy as np

import sklearn
import torch
from torch.nn import functional as F
import torch.optim as optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

data_path = os.path.join(CURRENT_PATH, './data/wrestling')
demo_obs = np.load(os.path.join(data_path, 'obs_10000.npy')).astype(np.float32)
demo_actions = np.load(os.path.join(data_path, 'actions_10000.npy')).astype(np.float32)
#demo_actions = ((demo_actions + [100.0, 30.0]) / [300.0, 60.0]).astype(np.float32)
demo_obs = np.expand_dims(demo_obs, axis=1)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(f'obs shape = {demo_obs.shape}, action shape = {demo_actions.shape}') 

obs shape = (10000, 1, 40, 40), action shape = (10000, 2)


In [2]:
class DiabetesDataset(Dataset):
    def __init__(self, data, target):
        self.len = data.shape[0]                          #确定数据的个数
        self.X_data = torch.from_numpy(data)
        self.y_data = torch.from_numpy(target)
        print(self.X_data.shape, self.y_data.shape)

    def __getitem__(self, index):
        return self.X_data[index], self.y_data[index]

    def __len__(self):
        return self.len

train_dataset = DiabetesDataset(demo_obs, demo_actions)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=4)


torch.Size([10000, 1, 40, 40]) torch.Size([10000, 2])


In [10]:
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.Conv = torch.nn.Conv2d(in_channels=1, out_channels=1 ,kernel_size=3, padding=1)
        self.BN = torch.nn.BatchNorm2d(num_features=1)
        self.Pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.NN1 = torch.nn.Linear(1600, 400)
        self.NN2 = torch.nn.Linear(400, 128)
        self.NN3 = torch.nn.Linear(128, 64)
        self.NN4 = torch.nn.Linear(64, 2)


    def forward(self, input):
        x = self.Conv(input)
        x = self.BN(x)
        x = F.relu(x)
        x = self.Pool(x)
        x = F.dropout(x, p=0.2)
        x = torch.flatten(input, start_dim=1)

        x = self.NN1(x)
        x = F.relu(x)
        x = self.NN2(x)
        x = F.relu(x)
        x = self.NN3(x)
        x = F.relu(x)
        x = self.NN4(x)

        x[:, 0] = torch.tanh(x[:, 0])*150 + 50
        x[:, 1] = torch.tanh(x[:, 1])*30

        return x


model = CNN()
model.to(device)

test_data = torch.randn(3, 1, 40, 40).to(device)
out = model(test_data)
print(out)



tensor([[36.2897,  1.0834],
        [41.4800,  1.7945],
        [47.3131,  2.6361]], device='cuda:0', grad_fn=<CopySlices>)


In [11]:
criterion = torch.nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.0000001)

def train(epoch, input_data):
    running_loss = 0.0
    for batch_idx, data in enumerate(input_data):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, target)
        
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if batch_idx % 100 == 99:
            print('[%d, %5d] loss:%.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0

In [13]:
for epoch in range(2000):
    train(epoch, train_loader)

[1,   100] loss:1.023
[1,   200] loss:0.972
[1,   300] loss:1.075
[1,   400] loss:0.937
[1,   500] loss:1.246
[1,   600] loss:1.187
[2,   100] loss:1.131
[2,   200] loss:1.066
[2,   300] loss:1.040
[2,   400] loss:1.112
[2,   500] loss:1.054
[2,   600] loss:1.055
[3,   100] loss:1.165
[3,   200] loss:1.190
[3,   300] loss:1.054
[3,   400] loss:1.049
[3,   500] loss:1.020
[3,   600] loss:0.989
[4,   100] loss:1.047
[4,   200] loss:1.055
[4,   300] loss:1.173
[4,   400] loss:0.933
[4,   500] loss:1.168
[4,   600] loss:1.104
[5,   100] loss:1.091
[5,   200] loss:1.164
[5,   300] loss:1.040
[5,   400] loss:0.999
[5,   500] loss:1.100
[5,   600] loss:1.087
[6,   100] loss:1.239
[6,   200] loss:1.007
[6,   300] loss:1.014
[6,   400] loss:1.029
[6,   500] loss:1.095
[6,   600] loss:1.101
[7,   100] loss:1.204
[7,   200] loss:1.056
[7,   300] loss:0.967
[7,   400] loss:1.010
[7,   500] loss:0.998
[7,   600] loss:1.154
[8,   100] loss:1.096
[8,   200] loss:1.034
[8,   300] loss:0.995
[8,   400]

KeyboardInterrupt: 

In [15]:
torch.save(model.state_dict(), './paramet.pt')