In [1]:
import sys
 
# setting path
sys.path.append('../')
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from citylearn.citylearn import CityLearnEnv
from citylearn_wrapper import CityLearnEnvWrapper
import pandas as pd
import numpy as np
import csv

In [2]:
# Run configuration shared by the cells below.

# Environment settings: path of the CityLearn schema file; this mapping is
# passed as-is to CityLearnEnvWrapper.
env_config = dict(schema="../data/citylearn_challenge_2022_phase_1/schema.json")

# Network settings: hidden_size is the width of every hidden layer, and
# batch_norm selects BatchNorm1d (True) or Identity (False) inside get_unit.
hidden_size, batch_norm = 128, True

# Episode count setting (not referenced by the cells shown in this section).
num_episodes = 50

In [3]:
# Build the wrapped CityLearn environment from the config above, then read the
# length of the first axis of its observation and action spaces; these two
# numbers size the planning model's input and output layers.
env = CityLearnEnvWrapper(env_config)
obs_size, act_size = env.observation_space.shape[0], env.action_space.shape[0]


In [4]:
class PlanningDataset(Dataset):
    """Samples for the planning model, loaded from a ``|``-delimited CSV file.

    Columns are read by position. Columns 0 and 2 each hold a list literal;
    the two lists are concatenated to form one input row of ``X``. Column 1
    holds the list literal used as the regression target row of ``y``. Rows
    with any missing cell are dropped before parsing.

    Attributes:
        filename: Path the data was read from.
        X: ``numpy`` array of inputs, one row per sample.
        y: ``numpy`` array of targets, one row per sample.
        X_mean, y_mean: Per-column means over all samples.
        X_std: Per-column standard deviation of ``X``, with exact zeros
            replaced by 1.0 so it is always safe to divide by.
        y_std: Per-column standard deviation of ``y``.
    """

    def __init__(self, filename) -> None:
        """Read ``filename`` and compute the normalisation statistics.

        Args:
            filename: Path (or anything ``pandas.read_csv`` accepts) of the
                ``|``-delimited file.

        Raises:
            ValueError, SyntaxError: If a cell is not a plain Python literal.
        """
        # Imported here so the class does not depend on an import cell that
        # may not list ``ast``.
        import ast

        super().__init__()
        self.filename = filename
        data = pd.read_csv(filename, delimiter="|").dropna()

        # SECURITY: the cells used to be parsed with eval(), which executes
        # any expression found in the file. literal_eval accepts only Python
        # literals (lists, numbers, ...) and raises on everything else.
        parse = ast.literal_eval
        first_cells, target_cells, second_cells = data.iloc[:, 0], data.iloc[:, 1], data.iloc[:, 2]

        self.X = np.array([parse(a) + parse(b) for a, b in zip(first_cells, second_cells)])
        self.X_mean = np.mean(self.X, axis=0)
        # A constant input column has zero spread; dividing by it during
        # standardisation yields 0/0 = NaN, so use a unit scale instead.
        x_std = np.std(self.X, axis=0)
        self.X_std = np.where(x_std == 0, 1.0, x_std)

        self.y = np.array([parse(cell) for cell in target_cells])
        self.y_mean = np.mean(self.y, axis=0)
        # Left as computed: the model multiplies by y_std, so a zero here
        # cannot cause a division by zero.
        self.y_std = np.std(self.y, axis=0)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at ``index`` as float16 tensors."""
        return torch.tensor(self.X[index], dtype=torch.float16), torch.tensor(self.y[index], dtype=torch.float16)

In [5]:
from multiprocessing.spawn import prepare
from turtle import forward
from typing import Any

def get_unit(in_size, out_size):
    """Build one hidden unit: Linear -> Tanh -> (BatchNorm1d or Identity) -> Dropout.

    Args:
        in_size: Number of input features of the linear layer.
        out_size: Number of output features of the linear layer (and of the
            batch-norm layer when it is used).

    Returns:
        An ``nn.Sequential`` holding the four stages in that order. The third
        stage is ``nn.BatchNorm1d(out_size)`` when the module-level
        ``batch_norm`` flag is truthy at call time, otherwise ``nn.Identity()``.
        Dropout is created with its default arguments.
    """
    stages = [nn.Linear(in_size, out_size), nn.Tanh()]
    stages.append(nn.BatchNorm1d(out_size) if batch_norm else nn.Identity())
    stages.append(nn.Dropout())
    return nn.Sequential(*stages)

class LitPlanningModel(pl.LightningModule):
    """Residual MLP mapping an ``obs_size + act_size`` input to an ``obs_size`` output.

    Inputs are standardised with ``X_mean`` / ``X_std`` before the network and
    outputs are rescaled with ``y_std`` / ``y_mean`` afterwards, so the loss is
    computed on the rescaled outputs against the raw targets.

    The four statistics are registered as buffers. They therefore follow the
    module across devices and are part of ``state_dict()``, which means a
    checkpoint restores the normalisation together with the weights.
    Checkpoints written while the statistics were plain attributes do not
    contain these four keys: load those with ``strict=False`` and pass the
    statistics to the constructor.

    Args:
        obs_size: Width of the observation part of the input, and of the output.
        act_size: Width of the action part of the input.
        hidden_size: Width of every hidden layer.
        lr: Learning rate given to Adam.
        X_mean, X_std: Input mean and standard deviation (scalar or array-like).
        y_mean, y_std: Target mean and standard deviation (scalar or array-like).
    """

    def __init__(self, obs_size, act_size, hidden_size, lr = 0.0001, X_mean = 0, X_std = 1, y_mean = 0, y_std = 1) -> None:
        super().__init__()
        self.lr = lr
        # Buffers (not plain attributes) so that .to()/.cuda() moves them and
        # checkpoints include them. Still reachable as self.X_mean etc.
        for name, value in (("X_mean", X_mean), ("X_std", X_std), ("y_mean", y_mean), ("y_std", y_std)):
            self.register_buffer(name, torch.tensor(value, dtype=torch.float16))
        self.layers = nn.ModuleList([
            get_unit(obs_size + act_size, hidden_size),
            get_unit(hidden_size, hidden_size),
            get_unit(hidden_size, hidden_size),
            get_unit(hidden_size, hidden_size),
            nn.Linear(hidden_size, obs_size)
        ])

    def preprocess(self, x):
        """Standardise a batch of inputs: ``(x - X_mean) / X_std``."""
        return (x - self.X_mean) / self.X_std

    def postprocess(self, y):
        """Rescale network outputs to target units: ``y * y_std + y_mean``."""
        return y * self.y_std + self.y_mean

    def forward(self, x):
        """Standardise ``x``, run it through the layers, and rescale the result."""
        x = self.preprocess(x)
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            # Residual connection on every layer except the first and the
            # last, whose input and output widths differ.
            x = layer(x) + x if 0 < i < last else layer(x)
        return self.postprocess(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def shared_step(self, batch, batch_idx, return_mae=False):
        """Compute the loss for one ``(x, y)`` batch.

        Args:
            batch: Pair ``(x, y)`` of inputs and targets.
            batch_idx: Index of the batch (unused).
            return_mae: When true, also return the mean absolute error.

        Returns:
            The MSE loss, or ``(mse, mae)`` when ``return_mae`` is true.
        """
        x, y = batch
        # Flatten everything after the batch axis into one feature vector.
        x = x.view(x.size(0), -1)
        y_hat = self.forward(x)
        loss = F.mse_loss(y_hat, y)
        if return_mae:
            mae_loss = torch.mean(torch.abs(y_hat - y))
            return loss, mae_loss
        return loss

    def training_step(self, train_batch, batch_idx):
        """Log and return the MSE loss of a training batch."""
        loss = self.shared_step(train_batch, batch_idx)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log the MSE (``val_loss``) and MAE (``val_mae``) of a validation batch."""
        loss, mae_loss = self.shared_step(val_batch, batch_idx, return_mae = True)
        self.log("val_loss", loss)
        self.log("val_mae", mae_loss)

In [6]:
# Load the samples; PlanningDataset also computes the normalisation statistics
# over every row it reads.
dataset = PlanningDataset("planning_model_data.csv")

# 80 % of the samples for training, the remainder for validation.
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size

# Seed the split with its own generator. Without it the partition comes from
# the unseeded global RNG and changes on every kernel restart, so validation
# numbers from different runs would be measured on different samples.
SPLIT_SEED = 42
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training batches are reshuffled each epoch; validation order stays fixed.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=3)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=3)

In [7]:
# Instantiate the planning model, handing it the mean/std statistics that the
# dataset computed so it can standardise inputs and rescale outputs itself.
model = LitPlanningModel(
    obs_size,
    act_size,
    hidden_size,
    X_mean=dataset.X_mean,
    X_std=dataset.X_std,
    y_mean=dataset.y_mean,
    y_std=dataset.y_std,
)

In [8]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
# Experiment tracking through Weights & Biases.
wandb_logger = WandbLogger(
    project="active-rl-planning-model",
    entity="social-game-rl",
    log_model="all",
)

# Checkpointing is driven by the "val_loss" metric logged in validation_step.
checkpoint_callback = ModelCheckpoint(monitor="val_loss")

# Single-GPU, 16-bit precision training for at most 200 epochs.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    auto_lr_find=False,
    max_epochs=200,
)

# Learning-rate search is switched off (auto_lr_find=False); the tuning call
# is kept here, disabled, for reference:
#trainer.tune(model, train_dataloaders = train_loader, val_dataloaders = val_loader)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

  return LooseVersion(v) >= LooseVersion(check)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdoseok[0m ([33msocial-game-rl[0m). Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import display, HTML  # type: ignore


  rank_zero_deprecation(
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | layers | ModuleList | 66.7 K
--------------------------------------
66.7 K    Trainable params
0         Non-trainable params
66.7 K    Total params
0.133     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  value = torch.tensor(value, device=self.device)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]