In [None]:
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler


from pathlib import Path
from os import environ
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import numpy as np

In [None]:
# === Data Handler ===
class DataHandler:
    """Loads benchmark datasets and wraps them in sequential DataLoaders.

    Args:
        batch_size: Number of samples per batch produced by
            ``create_dataloader``.
    """

    def __init__(self, batch_size=32):
        self.batch_size = batch_size

    def load_dataset(self, dataset_name, scaler_transform):
        """Load ``dataset_name`` through ``buildings_bench`` and materialize it.

        Args:
            dataset_name: Name forwarded to ``load_torch_dataset``.
            scaler_transform: Forwarded as ``apply_scaler_transform``.

        Returns:
            A list holding every item yielded by ``load_torch_dataset``.

        Raises:
            KeyError: If the ``TRANSFORM_PATH`` environment variable is unset.
        """
        # Imported lazily so the class can be defined without buildings_bench installed.
        from buildings_bench import load_torch_dataset
        return list(load_torch_dataset(
            dataset_name,
            apply_scaler_transform=scaler_transform,
            scaler_transform_path=Path(environ["TRANSFORM_PATH"])
        ))

    def create_dataloader(self, dataset):
        """Return a non-shuffling DataLoader over ``dataset`` using ``self.batch_size``."""
        # The bare name `DataLoader` is never imported in this file; go through
        # the already-imported `torch` package instead of raising NameError.
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=False
        )

In [None]:
# === Base Model and Subclasses ===

class Model(nn.Module):
    """Shared base for the forecasting models: embeddings plus an activation.

    The four embedding outputs (lat 32, lon 32, building 32, power 64) are
    concatenated along dim 2, giving 160 features per time step.
    """

    DEFAULT_CONTEXT_LEN = 168
    DEFAULT_PRED_LEN = 24

    def __init__(self, activation):
        super().__init__()
        self.context_len = self.DEFAULT_CONTEXT_LEN
        self.pred_len = self.DEFAULT_PRED_LEN
        self.activation = self._get_activation(activation)
        self.embeddings = self._create_embeddings()

    def _create_embeddings(self):
        """Build the per-feature embedding layers, keyed by feature name."""
        table = nn.ModuleDict()
        table['power'] = nn.Linear(1, 64)
        table['building'] = nn.Embedding(2, 32)
        table['lat'] = nn.Linear(1, 32)
        table['lon'] = nn.Linear(1, 32)
        return table

    def _get_activation(self, name):
        """Return the activation module for ``name`` (case-insensitive).

        Names that are not recognised fall back to ``nn.ReLU``.
        """
        factories = {
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "gelu": nn.GELU,
            "leaky_relu": nn.LeakyReLU,
        }
        factory = factories.get(name.lower(), nn.ReLU)
        return factory()

    def _data_pre_process(self, x):
        """Embed each input feature and concatenate them along dim 2.

        Reads the keys ``latitude``, ``longitude``, ``building_type`` and
        ``load`` from ``x``; the output order is lat, lon, building, load.
        """
        embed = self.embeddings
        pieces = (
            embed['lat'](x['latitude']),
            embed['lon'](x['longitude']),
            # The trailing singleton axis is dropped before the index lookup.
            embed['building'](x['building_type'].squeeze(-1)),
            embed['power'](x['load']),
        )
        return torch.cat(pieces, dim=2)


class NN(Model):
    """Fully connected forecaster over the flattened context window."""

    def __init__(self, activation):
        super().__init__(activation)
        self.model = self._build_model()

    def _build_model(self):
        """Stack three hidden Linear layers and a ``pred_len``-wide output layer."""
        # 160 is the per-step width of the concatenated embeddings.
        widths = [self.context_len * 160, 512, 256, 128]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            # The same activation module instance follows every hidden layer.
            layers.append(self.activation)
        layers.append(nn.Linear(widths[-1], self.pred_len))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Predict ``pred_len`` values from the first ``context_len`` steps.

        Returns the network output with a trailing singleton axis appended.
        """
        batch_size = x['load'].shape[0]
        context = self._data_pre_process(x)[:, :self.context_len, :]
        flattened = context.reshape(batch_size, -1)
        return self.model(flattened).unsqueeze(-1)


class RNN(Model):
    """Two stacked vanilla RNN layers followed by a linear forecasting head.

    Args:
        activation: Name of the activation applied to the last hidden state
            before the output layer.
    """

    def __init__(self, activation="relu"):
        super().__init__(activation)
        self.rnn1, self.rnn2, self.output_layer = self._build_model()

    def _build_model(self):
        """Create the two recurrent layers and the ``pred_len``-wide output layer."""
        # 160 is the per-step width of the concatenated embeddings.
        rnn1 = nn.RNN(160, 128, batch_first=True)
        rnn2 = nn.RNN(128, 128, batch_first=True)
        output_layer = nn.Linear(128, self.pred_len)
        return rnn1, rnn2, output_layer

    def forward(self, x):
        """Predict ``pred_len`` values from the first ``context_len`` steps.

        Returns the head output with a trailing singleton axis appended.
        """
        # Encode only the context window. The trainer takes its targets from
        # load[:, context_len:], so feeding the whole sequence would leak the
        # values being predicted into the recurrent state (NN already slices).
        ts_embed = self._data_pre_process(x)[:, :self.context_len, :]
        out1, _ = self.rnn1(ts_embed)
        out2, _ = self.rnn2(out1)
        last_hidden = self.activation(out2[:, -1, :])
        return self.output_layer(last_hidden).unsqueeze(-1)


class LSTM(Model):
    """Two stacked LSTM layers followed by a linear forecasting head.

    Args:
        activation: Name of the activation applied to the last hidden state
            before the output layer.
    """

    def __init__(self, activation="relu"):
        super().__init__(activation)
        self.lstm1, self.lstm2, self.output_layer = self._build_model()

    def _build_model(self):
        """Create the two LSTM layers and the ``pred_len``-wide output layer."""
        # 160 is the per-step width of the concatenated embeddings.
        lstm1 = nn.LSTM(160, 128, batch_first=True)
        lstm2 = nn.LSTM(128, 128, batch_first=True)
        output_layer = nn.Linear(128, self.pred_len)
        return lstm1, lstm2, output_layer

    def forward(self, x):
        """Predict ``pred_len`` values from the first ``context_len`` steps.

        Returns the head output with a trailing singleton axis appended.
        """
        # Encode only the context window. The trainer takes its targets from
        # load[:, context_len:], so feeding the whole sequence would leak the
        # values being predicted into the recurrent state (NN already slices).
        ts_embed = self._data_pre_process(x)[:, :self.context_len, :]
        out1, _ = self.lstm1(ts_embed)
        out2, _ = self.lstm2(out1)
        last_hidden = self.activation(out2[:, -1, :])
        return self.output_layer(last_hidden).unsqueeze(-1)


class GRU(Model):
    """Two stacked GRU layers followed by a linear forecasting head.

    Args:
        activation: Name of the activation applied to the last hidden state
            before the output layer.
    """

    def __init__(self, activation="relu"):
        super().__init__(activation)
        self.gru1, self.gru2, self.output_layer = self._build_model()

    def _build_model(self):
        """Create the two GRU layers and the ``pred_len``-wide output layer."""
        # 160 is the per-step width of the concatenated embeddings.
        gru1 = nn.GRU(160, 128, batch_first=True)
        gru2 = nn.GRU(128, 128, batch_first=True)
        output_layer = nn.Linear(128, self.pred_len)
        return gru1, gru2, output_layer

    def forward(self, x):
        """Predict ``pred_len`` values from the first ``context_len`` steps.

        Returns the head output with a trailing singleton axis appended.
        """
        # Encode only the context window. The trainer takes its targets from
        # load[:, context_len:], so feeding the whole sequence would leak the
        # values being predicted into the recurrent state (NN already slices).
        ts_embed = self._data_pre_process(x)[:, :self.context_len, :]
        out1, _ = self.gru1(ts_embed)
        out2, _ = self.gru2(out1)
        last_hidden = self.activation(out2[:, -1, :])
        return self.output_layer(last_hidden).unsqueeze(-1)

In [None]:
# === DDP-Compatible Trainer Class ===

class Trainer:
    """Builds, trains and evaluates one DDP-wrapped forecasting model.

    Args:
        model_name: One of ``'NN'``, ``'RNN'``, ``'LSTM'``, ``'GRU'``.
        device: Device the model and every batch are moved to.
        scaler_transform: Stored on the instance; not otherwise used here.
        rank: Rank of this process in the process group.
        world_size: Total number of processes in the process group.
        activation: Activation name forwarded to the model constructor.
        optimizer_name: ``'adam'``, ``'sgd'`` or ``'adamw'`` (case-insensitive);
            any other value falls back to Adam.
        lr: Learning rate handed to the optimizer.

    Note:
        The constructor reads ``config["batch_size"]`` from a module-level
        ``config`` mapping that must exist before a Trainer is created.
    """

    def __init__(self, model_name, device, scaler_transform, rank, world_size,
                 activation='relu', optimizer_name='adam', lr=1e-3):
        self.model_name = model_name
        self.device = device
        self.scaler_transform = scaler_transform
        self.activation = activation
        self.optimizer_name = optimizer_name
        self.lr = lr
        self.rank = rank
        self.world_size = world_size

        self._setup_distributed()
        self.model = self._load_model()
        self.optimizer = self._get_optimizer()
        self.loss_fn = nn.MSELoss()
        self.handler = DataHandler(batch_size=config["batch_size"])

    def _setup_distributed(self):
        """Initialise the default NCCL process group once per process."""
        # Several trainers are created in the same process (one per model);
        # initialising the default group a second time raises, so only do it
        # when no group exists yet.
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl", rank=self.rank, world_size=self.world_size)

    def _load_model(self):
        """Instantiate the model named by ``model_name`` and wrap it in DDP.

        Raises:
            KeyError: If ``model_name`` is not a known model.
        """
        model_map = {'NN': NN, 'RNN': RNN, 'LSTM': LSTM, 'GRU': GRU}
        model = model_map[self.model_name](activation=self.activation).to(self.device)
        # Bind DDP to the device the model actually lives on rather than to the
        # global rank, which is not a valid device index on multi-node jobs.
        device = torch.device(self.device)
        if device.type == "cuda":
            index = device.index if device.index is not None else torch.cuda.current_device()
            return DDP(model, device_ids=[index])
        return DDP(model)

    def _get_optimizer(self):
        """Build the optimizer selected by ``optimizer_name`` over the model parameters."""
        opt_map = {
            'adam': torch.optim.Adam,
            'sgd': torch.optim.SGD,
            'adamw': torch.optim.AdamW
        }
        optimizer_cls = opt_map.get(self.optimizer_name.lower(), torch.optim.Adam)
        return optimizer_cls(self.model.parameters(), lr=self.lr)

    def _get_ddp_dataloader(self, dataset, shuffle=True):
        """Return a DataLoader that yields this rank's shard of ``dataset``.

        Args:
            dataset: Map-style dataset to shard across ranks.
            shuffle: Whether the distributed sampler shuffles indices.
        """
        sampler = DistributedSampler(
            dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
        )
        # Use the configured batch size instead of a hard-coded 32.
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.handler.batch_size, sampler=sampler
        )

    def _to_device(self, batch):
        """Return a copy of ``batch`` with every value moved to ``self.device``."""
        return {key: value.to(self.device) for key, value in batch.items()}

    def train(self, train_buildings, epochs=5):
        """Train on every building for ``epochs`` passes, then switch to eval mode.

        Args:
            train_buildings: Iterable of ``(building_id, dataset)`` pairs.
            epochs: Number of passes over all buildings.
        """
        self.model.train()
        context_len = self.model.module.context_len
        for epoch in range(epochs):
            total_loss = 0.0
            for _, building_dataset in train_buildings:
                dataloader = self._get_ddp_dataloader(building_dataset)
                # Without set_epoch the sampler reuses the same permutation
                # every epoch, so the data would never be reshuffled.
                dataloader.sampler.set_epoch(epoch)
                for batch in dataloader:
                    batch = self._to_device(batch)
                    self.optimizer.zero_grad()
                    predictions = self.model(batch)
                    # Targets are the load values that follow the context window.
                    targets = batch['load'][:, context_len:, 0]
                    loss = self.loss_fn(predictions[:, :, 0], targets)
                    loss.backward()
                    self.optimizer.step()
                    total_loss += loss.item()
            if self.rank == 0:
                print(f"[{self.model_name}] Epoch {epoch + 1}: Loss = {total_loss:.4f}")
        self.model.eval()

    def evaluate(self, test_buildings):
        """Compute per-building MAE, RMSE and R² and average them over buildings.

        Args:
            test_buildings: Iterable of ``(building_id, dataset)`` pairs.

        Returns:
            ``(results, mae, rmse, r2)``, where ``results`` maps each building id
            to its ``(predictions, targets)`` CPU tensors and the metrics are
            means over the evaluated buildings. The three metrics are NaN when
            no building produced any batch.
        """
        self.model.eval()
        # Run inference on the wrapped module: evaluation may be invoked on a
        # single rank, and the DDP wrapper is only needed for gradient sync.
        net = getattr(self.model, "module", self.model)
        results = {}
        mae_total = 0.0
        rmse_total = 0.0
        r2_total = 0.0
        count = 0
        for building_id, building_dataset in test_buildings:
            inverse_transform = building_dataset.datasets[0].load_transform.undo_transform
            # Keep dataset order during evaluation so stored predictions are reproducible.
            dataloader = self._get_ddp_dataloader(building_dataset, shuffle=False)
            target_list = []
            prediction_list = []
            with torch.no_grad():
                for batch in dataloader:
                    batch = self._to_device(batch)
                    predictions = net(batch)
                    targets = batch['load'][:, net.context_len:]
                    targets = inverse_transform(targets)
                    predictions = inverse_transform(predictions)
                    prediction_list.append(predictions.detach().cpu())
                    target_list.append(targets.detach().cpu())
            if not prediction_list:
                # Nothing to score for this building; torch.cat([]) would raise.
                continue
            predictions_all = torch.cat(prediction_list)
            targets_all = torch.cat(target_list)
            errors = predictions_all - targets_all
            squared_errors = errors ** 2
            mae = errors.abs().mean().item()
            rmse = torch.sqrt(squared_errors.mean()).item()
            r2 = 1 - (squared_errors.sum() / ((targets_all - targets_all.mean()) ** 2).sum()).item()
            mae_total += mae
            rmse_total += rmse
            r2_total += r2
            count += 1
            results[building_id] = (predictions_all, targets_all)
        if count == 0:
            # Avoid ZeroDivisionError when there is nothing to average.
            nan = float("nan")
            return results, nan, nan, nan
        return results, mae_total / count, rmse_total / count, r2_total / count

In [None]:
# === Main (SLURM-Aware) DDP Launcher ===

def run_ddp(rank, world_size):
    """Train and evaluate every model architecture inside one DDP process.

    Reads the module-level ``config`` mapping (keys ``PATH``, ``batch_size``,
    ``dataset_name``, ``scaler_transform``, ``activation``, ``optimizer_name``,
    ``lr`` and ``epochs``).

    Args:
        rank: Global rank of this process.
        world_size: Total number of processes.
    """
    # Only `environ` is imported from `os` in this file, so use it directly;
    # `os.environ` would raise NameError.
    environ["MASTER_ADDR"] = environ.get("SLURM_LAUNCH_NODE_IPADDR", "127.0.0.1")
    environ["MASTER_PORT"] = "12355"

    # Pick the GPU by node-local task id when SLURM provides it: the global
    # rank is not a valid device index on any node but the first.
    local_rank = int(environ.get("SLURM_LOCALID", rank))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    # Set environment variables for dataset access
    # NOTE(review): this replaces the process's executable search PATH with the
    # data root; kept as-is, but confirm that this is intended.
    environ["PATH"] = config["PATH"]
    environ["REPO_PATH"] = f"{config['PATH']}/BuildingsBenchTutorial/BuildingsBench/"
    environ["BUILDINGS_BENCH"] = f"{config['PATH']}/Dataset"
    environ["TRANSFORM_PATH"] = f"{config['PATH']}/Dataset/metadata/transforms"

    # Dataset loading: the first 80% of buildings train, the remainder test.
    handler = DataHandler(batch_size=config["batch_size"])
    all_buildings = handler.load_dataset(config["dataset_name"], config["scaler_transform"])
    split_index = int(0.8 * len(all_buildings))
    train_buildings = all_buildings[:split_index]
    test_buildings = all_buildings[split_index:]

    try:
        for model_class in [NN, RNN, LSTM, GRU]:
            if rank == 0:
                print(f"\n--- Training {model_class.__name__} ---")
            trainer = Trainer(
                model_name=model_class.__name__,
                device=device,
                scaler_transform=config["scaler_transform"],
                activation=config["activation"],
                optimizer_name=config["optimizer_name"],
                lr=config["lr"],
                rank=rank,
                world_size=world_size
            )
            trainer.train(train_buildings, epochs=config["epochs"])
            # Metrics are computed and reported by rank 0 only.
            if rank == 0:
                _, mae, rmse, r2 = trainer.evaluate(test_buildings)
                print(f"[{model_class.__name__}] MAE: {mae:.4f}, RMSE: {rmse:.4f}, R²: {r2:.4f}")
    finally:
        # Tear the process group down even when training fails part-way.
        if dist.is_initialized():
            dist.destroy_process_group()

def main():
    """Read this task's rank and the task count from SLURM and start ``run_ddp``.

    Raises:
        KeyError: If ``SLURM_NTASKS`` or ``SLURM_PROCID`` is not set.
    """
    # Only `environ` is imported from `os` in this file; `os.environ` would
    # raise NameError.
    world_size = int(environ["SLURM_NTASKS"])
    rank = int(environ["SLURM_PROCID"])
    run_ddp(rank, world_size)

if __name__ == "__main__":
    main()