In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from  torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import GATv2Conv
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import random_split

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

seed_everything(1024, workers=True)

Global seed set to 1024


1024

### TODO

[ ] CPU -> GPU 

[ ] add trainer of pytorch-lightning

[ ] Black and White assymentry

[ ] Train and Valid set

In [2]:
###  Load data
# edge_index        (2, E) -> (2, e)  batch
# edge_attr         (E, 3) -> (e, 3)  batch
# node_features     (N, X) -> (e, 2X) will be sampled at ChessModel

data = torch.load('./data/graph_data.pt')
edge_index = data['edge_index']
edge_attr = data['edge_attr']
num_node = data['num_node']

print(edge_index)
print(edge_attr)
print(num_node)

tensor([[    0,     1,     2,  ...,  7410,   884,  6958],
        [   56,   171,    65,  ...,  1125, 29994, 23078]])
tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 2.],
        ...,
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.]])
37393


In [3]:
class GraphDataModule(pl.LightningDataModule):
    def __init__(self, data_dir:str="path/to/dir", proportion:list=[0.8, 0.1, 0.1], batch_size:int=32, num_workers:int=32):
        super().__init__()
        self.data_dir = data_dir
        
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        self.train_portion, self.val_portion, self.test_portion = proportion


    def setup(self, stage=None):
        data = torch.load(self.data_dir)
        edge_index = data['edge_index']
        edge_attr = data['edge_attr']

        # size
        dataset_size = len(edge_attr)
        train_size = int(self.train_portion * dataset_size)
        val_size = int(self.val_portion * dataset_size)
        test_size = dataset_size - train_size - val_size
        print(dataset_size, train_size, val_size)

        # preprocessing
        epsilon = 1e-8
        edge_attr = edge_attr / (edge_attr.sum(dim=1, keepdim=True) + epsilon)
        edge_index = edge_index.T
        
        # random split
        self.data = TensorDataset(edge_index, edge_attr)
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(self.data, [train_size, val_size, test_size])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)


In [4]:
edge_index_doubled = edge_index.repeat(1, 2)

ones = torch.ones_like(edge_attr[:,[0]])
zeros = torch.zeros_like(edge_attr[:,[0]])
edge_attr_WB = torch.concat([edge_attr, zeros], dim=1)
edge_attr_BW = torch.concat([torch.flip(edge_attr, dims=[1]), ones], dim=1)
edge_attr_doubled = torch.concat([edge_attr_WB, edge_attr_BW], dim=0)

print(edge_index_doubled.shape)
print(edge_attr_doubled.shape)

torch.Size([2, 1956716])
torch.Size([1956716, 4])


In [5]:
myData = GraphDataModule(data_dir='./data/graph_data.pt', batch_size=32, num_workers=32)
myData.prepare_data()
myData.setup(stage="fit")
train_dataloader = myData.train_dataloader()
val_dataloader = myData.val_dataloader()
test_dataloader = myData.test_dataloader()

978358 782686 97835


In [6]:
class Encoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_node, num_layers, heads=1):
        super().__init__()
        
        x = torch.randn(num_node, in_channels)
        self.register_buffer('x', x)
        
        self.num_node = num_node
        self.in_channels =in_channels
        self.convs = nn.ModuleList()
        self.convs.append(GATv2Conv(in_channels, hidden_channels, heads=heads, edge_dim=4))
        for _ in range(num_layers - 2):
            self.convs.append(GATv2Conv(hidden_channels, hidden_channels, heads=heads, edge_dim=4))
        self.convs.append(GATv2Conv(hidden_channels * heads, out_channels, heads=1, edge_dim=4))


    def forward(self, edge_index, edge_attr):
        x = self.x
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index, edge_attr)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x

In [7]:
class Decoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels=3):
        super().__init__()
        hidden1, hidden2 = hidden_channels
        self.Fc1 = nn.Linear(in_channels, hidden1, bias=False)
        self.Fc2 = nn.Linear(hidden1, hidden2, bias=False)
        self.Fc3 = nn.Linear(hidden2, out_channels, bias=False)

    def forward(self, x):
        h = F.relu(self.Fc1(x))
        h = F.relu(self.Fc2(h))
        output = self.Fc3(h)
        return output

In [8]:
class ChessModel(pl.LightningModule):
  def __init__(self, node_feature_dim, num_layers, heads, num_node, edge_index_doubled, edge_attr_doubled, learning_rate=1e-6):
    
    super().__init__()
    self.edge_index_doubled = edge_index_doubled.to(self.device)
    self.edge_attr_doubled = edge_attr_doubled.to(self.device)
    self.save_hyperparameters("learning_rate")

    # init layers
    self.Encoder = Encoder(node_feature_dim, node_feature_dim, node_feature_dim, num_node, num_layers, heads)
    self.Decoder = Decoder(2 * node_feature_dim, [100, 100], 3)
    
    # init parameters reccursively
    self.apply(self._init_parameters)

  def _init_parameters(self, module):
    if isinstance(module, nn.Linear) or isinstance(module, nn.LayerNorm):
      nn.init.xavier_normal_(module.weight)
      if module.bias is not None:
        nn.init.zeros_(module.bias)

  #### forward pass ###################################################
  def _shared_step(self, batch, edge_index_doubled, edge_attr_doubled):
    # 0. unpack batch
    batch_edge_index, batch_edge_attr = batch
    batch_edge_index = batch_edge_index.T

    edge_index_doubled = edge_index_doubled.to(self.device)
    edge_attr_doubled = edge_attr_doubled.to(self.device)

    # 1. Forward pass through Encoder
    node_features = self.Encoder(edge_index_doubled, edge_attr_doubled)
   
    # 2. Sample node features
    sampled_node_features = torch.cat([node_features[batch_edge_index[0]], node_features[batch_edge_index[1]]], dim=-1)

    # 3. Forward pass through Decoder
    outputs = self.Decoder(sampled_node_features)
    
    # 4. Set targets
    targets = batch_edge_attr

    loss = F.mse_loss(outputs, targets)
    return loss

  def training_step(self, batch, batch_idx):
    loss = self._shared_step(batch, self.edge_index_doubled, self.edge_attr_doubled)
    self.log('train_loss', loss)
    return loss

  def validation_step(self, batch, batch_idx):
    loss = self._shared_step(batch, self.edge_index_doubled, self.edge_attr_doubled)
    self.log('val_loss', loss)

  def predict_step(self, batch, batch_idx):
    loss = self._shared_step(batch, self.edge_index_doubled, self.edge_attr_doubled)
    self.log('pred_loss', loss)
  #####################################################################

  #### configure optimizer
  def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate)
    return {
      "optimizer": optimizer,
      "lr_scheduler": {
        "scheduler": ReduceLROnPlateau(optimizer, mode='min'),
        "monitor": "val_loss",
      },
    }

In [9]:
# # module
# myChess = ChessModel(node_feature_dim=10, num_layers=2, heads=3, 
#                      num_node=num_node, 
#                      edge_index_doubled=edge_index_doubled, 
#                      edge_attr_doubled=edge_attr_doubled)

# # trainer
# logger = TensorBoardLogger("tb_logs", name="test5")
# callbacks = [ModelCheckpoint(monitor="val_loss", mode="min"),
#              EarlyStopping(monitor="val_loss", mode="min", patience=3)]
# trainer = Trainer(logger=logger, callbacks=[], val_check_interval=0.1, max_epochs=1000)
# trainer.fit(myChess, train_dataloader, val_dataloader)
# trainer.predict(myChess, test_dataloader)

In [12]:
from elo import ELO

test_edge_index, test_edge_attr= myData.test_dataset.dataset.tensors
print(test_edge_index.shape)
print(test_edge_attr.shape)
elo = ELO("./data/temp.csv")
elo.train(test_edge_index)
outputs = elo.predict(test_edge_index)
print(outputs.shape)
loss = F.mse_loss(outputs, test_edge_attr)
print(loss)

torch.Size([978358, 2])
torch.Size([978358, 3])


KeyboardInterrupt: 