In [51]:
import math
import os

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch_optimizer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.utils.data import IterableDataset, TensorDataset


In [52]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sin/cos positional encoding to its input, then applies dropout.

    The encoding table is precomputed for ``max_len`` positions and registered as a
    buffer, so it moves with the module on ``.to(device)`` and is part of the state dict.

    Args:
        d_model: Embedding dimension of the input (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Maximum sequence length the table covers.
        batch_first: If True, ``forward`` expects ``[batch_size, seq_len, d_model]``;
            otherwise (the default) ``[seq_len, batch_size, d_model]``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000,
                 batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd-indexed column than even-indexed
        # column, so the cosine half must use one fewer frequency (no-op when even).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim], or
               [batch_size, seq_len, embedding_dim] when ``batch_first`` is True.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.batch_first:
            # pe is [max_len, 1, d_model]; reshape the slice to [1, seq_len, d_model]
            # so it broadcasts over the batch axis and is indexed by position.
            return self.dropout(x + self.pe[:x.size(1)].transpose(0, 1))
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


class Encoder(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) projecting input features to ``d_model``.

    Args:
        d_model: Output dimension.
        feature_dim: Size of the input's last dimension (default 106, the value that
            used to be hard-coded here).
        hidden_dim: Width of the hidden layer (default 64).
    """

    def __init__(self, d_model, feature_dim: int = 106, hidden_dim: int = 64):
        super().__init__()
        # Attribute name ``l1`` is kept so previously saved state dicts still load.
        self.l1 = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, d_model),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.l1(x)


class Decoder(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping ``d_model`` to ``action_dim``.

    Args:
        d_model: Input dimension.
        action_dim: Output dimension.
        hidden_dim: Width of the hidden layer (default 64, the value that used to be
            hard-coded here).
    """

    def __init__(self, d_model, action_dim, hidden_dim: int = 64):
        super().__init__()
        # Attribute name ``l1`` is kept so previously saved state dicts still load.
        self.l1 = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.l1(x)

class TransformerModel(nn.Module):
    """Causally-masked Transformer encoder producing per-step outputs.

    Input is batch-first ``[batch_size, max_seq_length, feature_dim]``. Output is
    ``[batch_size, max_seq_length, action_dim]``: raw outputs of a final
    ``nn.Linear``, with no softmax applied. An upper-triangular ``-inf`` mask is
    passed to the encoder stack.

    ``max_seq_length`` must not exceed the positional table length (5000, the
    ``PositionalEncoding`` default used here).
    """

    def __init__(self, max_seq_length:int, feature_dim:int, action_dim:int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.1):
        super().__init__()
        self.model_type = 'Transformer'
        self.max_seq_length = max_seq_length
        # PositionalEncoding expects [seq_len, batch, d_model]; forward() transposes around it.
        self.pos_encoder = PositionalEncoding(d_model, dropout=0)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Sequential(nn.Linear(feature_dim, d_model), nn.Sigmoid())
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, action_dim)
        # Non-persistent buffer: follows the module across .to()/.cuda() without
        # being written to (or expected in) the state dict.
        self.register_buffer(
            'src_mask', generate_square_subsequent_mask(self.max_seq_length), persistent=False
        )

    def init_weights(self) -> None:
        """Re-draw input-projection and output-head weights from U(-0.1, 0.1).

        Works whether ``encoder``/``decoder`` are bare ``nn.Linear`` layers or
        containers such as ``nn.Sequential``. Decoder biases are zeroed; encoder biases
        are left unchanged. Not called from ``__init__``; call it explicitly to use it.
        """
        initrange = 0.1
        for module in self.encoder.modules():
            if isinstance(module, nn.Linear):
                nn.init.uniform_(module.weight, -initrange, initrange)
        for module in self.decoder.modules():
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                nn.init.uniform_(module.weight, -initrange, initrange)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        Args:
            src: Tensor, shape [batch_size, max_seq_length, feature_dim]

        Returns:
            output Tensor of shape [batch_size, max_seq_length, action_dim]
        """
        assert src.shape[1] == self.max_seq_length, f"{src.shape[1]} != {self.max_seq_length}"
        src = self.encoder(src) * math.sqrt(self.d_model)

        # The encoder stack is batch-first, but PositionalEncoding indexes its table by
        # dim 0. Without these transposes every position in a sequence would get the
        # same encoding, chosen by the sample's batch index.
        src = self.pos_encoder(src.transpose(0, 1)).transpose(0, 1)
        output = self.transformer_encoder(src, self.src_mask)
        return self.decoder(output)

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    """Return an ``[sz, sz]`` float mask: ``-inf`` strictly above the diagonal, ``0`` on and below it."""
    blocked = torch.full((sz, sz), float('-inf'))
    return blocked.triu(diagonal=1)

In [53]:
class LitModule(pl.LightningModule):
    """Lightning wrapper that trains a ``TransformerModel`` with per-step cross-entropy.

    Args:
        encoder: Registered as a submodule but not used by ``forward``; kept so
            existing call sites keep working.
        decoder: Same as ``encoder``.
        max_seq_length, feature_dim, action_dim, d_model, nhead, d_hid, nlayers:
            Forwarded to ``TransformerModel``. The defaults reproduce the configuration
            this module used to hard-code.
        lr: Learning rate for ``torch_optimizer.Shampoo``.
    """

    def __init__(self, encoder, decoder, max_seq_length: int = 32, feature_dim: int = 106,
                 action_dim: int = 50, d_model: int = 128, nhead: int = 2, d_hid: int = 64,
                 nlayers: int = 3, lr: float = 1e-1):
        super().__init__()
        # NOTE(review): encoder/decoder are never called, but because they are
        # registered their parameters appear in self.parameters() (and hence in the
        # optimizer) and in checkpoints.
        self.encoder = encoder
        self.decoder = decoder
        self.action_dim = action_dim
        self.lr = lr
        self.model = TransformerModel(
            max_seq_length=max_seq_length, feature_dim=feature_dim, action_dim=action_dim,
            d_model=d_model, nhead=nhead, d_hid=d_hid, nlayers=nlayers,
        )

    def forward(self, batch, batch_idx):
        """Compute the mean cross-entropy loss for one ``(features, targets)`` batch.

        Args:
            batch: Indexable whose item 0 is the feature tensor
                ``[batch_size, max_seq_length, feature_dim]`` and whose item 1 holds
                class targets with ``batch_size * max_seq_length`` elements. Targets
                equal to ``-1`` are excluded from the loss.
            batch_idx: Only used in the error message below.

        Returns:
            Scalar loss tensor.
        """
        features = batch[0]
        targets = batch[1].long()
        batch_size = features.shape[0]
        seq_len = self.model.max_seq_length

        logits = self.model(features).reshape(-1, self.action_dim)
        targets = targets.reshape(-1)

        assert targets.shape == (seq_len * batch_size,), f"{targets.shape}"
        assert logits.shape == (seq_len * batch_size, self.action_dim), f"{logits.shape}"

        loss = F.cross_entropy(logits, targets, ignore_index=-1)

        # With mean reduction, a batch whose targets are all -1 yields 0/0 = NaN;
        # divergent training can also produce non-finite values.
        assert torch.isfinite(loss).all(), f"non-finite loss {loss.item()} at batch {batch_idx}"
        return loss

    def training_step(self, batch, batch_idx):
        """Return the batch loss under the ``"loss"`` key Lightning optimizes."""
        return {"loss": self.forward(batch, batch_idx)}

    def validation_step(self, batch, batch_idx):
        """Compute the batch loss and log it as ``val_loss``."""
        loss = self.forward(batch, batch_idx)
        self.log("val_loss", loss)
        return {"val_loss": loss}

    def configure_optimizers(self):
        """Shampoo over every registered parameter, at learning rate ``self.lr``."""
        return torch_optimizer.Shampoo(self.parameters(), lr=self.lr)

In [54]:
def make_loader(path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
    """Build a DataLoader over the ``(train_data, train_output)`` tensors saved at ``path``.

    Both the training and the validation files store their tensors under the
    ``train_data`` / ``train_output`` keys.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    saved = torch.load(path)
    dataset = TensorDataset(saved["train_data"], saved["train_output"])
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Reshuffle training batches every epoch; validation keeps file order.
train_loader = make_loader("../training_data.pt", batch_size=16, shuffle=True)
val_loader = make_loader("../validation_data.pt", batch_size=128)




In [55]:
SEED = 42
# Seed the Python, NumPy and torch RNGs before building the model so weight
# initialization and the training-data shuffle order repeat across runs.
pl.seed_everything(SEED)

# model -- the name `autoencoder` is historical: LitModule wraps TransformerModel,
# and its Encoder/Decoder arguments are not used in the forward pass.
autoencoder = LitModule(Encoder(3), Decoder(3, 50))

# train model; accelerator="auto" picks the GPU when one is available and falls
# back to CPU instead of raising.
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
)
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Encoder          | 7.0 K 
1 | decoder | Decoder          | 3.5 K 
2 | model   | TransformerModel | 269 K 
---------------------------------------------
280 K     Trainable params
0         Non-trainable params
280 K     Total params
1.120     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Traceback (most recent call last):
  File "/home/pawn/miniconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/pawn/miniconda3/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/pawn/miniconda3/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/pawn/miniconda3/lib/python3.9/site-packages/traitlets/config/application.py", line 972, in launch_instance
    app.start()
  File "/home/pawn/miniconda3/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 712, in start
    self.io_loop.start()
  File "/home/pawn/miniconda3/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 199, in start
    self.asyncio_loop.run_forever()
  File "/home/pawn/miniconda3/lib/python3.9/asyncio/base_events.py", line 596, in run_forever
    self._run_once()
  File "/home/pawn/miniconda3/lib/python3.9/asyncio/base_events.py", l