In [3]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datautils import MyTrainDataset

class Trainer:
    """Minimal single-device training loop with periodic checkpointing.

    Args:
        model: Module to train; it is moved to ``gpu_id`` on construction.
        train_data: Iterable of ``(source, targets)`` batches that also
            supports ``len()`` (number of steps per epoch).
        optimizer: Optimizer that updates the model's parameters.
        gpu_id: Device handed to ``.to()`` for the model and for every batch.
        save_every: Checkpoint interval in epochs; ``0`` disables
            checkpointing.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every

    def _run_batch(self, source, targets):
        """Run one optimization step: forward, cross-entropy loss, backward, update."""
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        """Train on every batch of ``train_data`` once.

        The epoch summary line (device, batch size, number of steps) is
        printed when the first batch arrives. Taking the batch size from that
        batch avoids building a second loader iterator and loading a
        throwaway batch just for the log line, and it means an empty loader
        is a silent no-op instead of raising ``StopIteration``.
        """
        steps = len(self.train_data)
        b_sz = None
        for source, targets in self.train_data:
            if b_sz is None:
                b_sz = len(source)
                print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {steps}")
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        """Write the model's ``state_dict`` to ``checkpoint.pt``.

        The same path is used on every call, so each checkpoint overwrites
        the previous one.
        """
        ckp = self.model.state_dict()
        path = "checkpoint.pt"
        torch.save(ckp, path)
        print(f"Epoch {epoch} | Training checkpoint saved at {path}")

    def train(self, max_epochs: int):
        """Run ``max_epochs`` epochs, checkpointing every ``save_every`` epochs.

        A checkpoint is written after each epoch whose index is a multiple of
        ``save_every`` (so epoch 0 is always saved). ``save_every == 0`` turns
        checkpointing off rather than failing with ``ZeroDivisionError``.
        """
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            if self.save_every and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


def load_train_objs():
    """Create the objects a training run needs.

    Returns:
        A ``(train_set, model, optimizer)`` tuple: a ``MyTrainDataset(2048)``
        instance, a ``Linear(20, 1)`` model, and an SGD optimizer
        (``lr=1e-3``) over that model's parameters.
    """
    # NOTE(review): 2048 is presumably the number of samples; confirm in datautils.
    dataset = MyTrainDataset(2048)
    net = torch.nn.Linear(20, 1)
    sgd = torch.optim.SGD(net.parameters(), lr=1e-3)
    return dataset, net, sgd


def prepare_dataloader(dataset: Dataset, batch_size: int):
    """Wrap ``dataset`` in a shuffling ``DataLoader`` with pinned memory.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` over ``dataset`` with ``shuffle=True`` and
        ``pin_memory=True``.
    """
    loader_options = {
        "batch_size": batch_size,
        "pin_memory": True,
        "shuffle": True,
    }
    loader = DataLoader(dataset, **loader_options)
    return loader

def main(device, total_epochs, save_every, batch_size):
    """Build the training objects and run the full training loop.

    Args:
        device: Device passed to ``Trainer`` as its ``gpu_id``.
        total_epochs: Number of epochs to train for.
        save_every: Checkpoint interval (in epochs) passed to ``Trainer``.
        batch_size: Batch size used for the data loader.
    """
    train_set, net, sgd = load_train_objs()
    loader = prepare_dataloader(train_set, batch_size)
    Trainer(net, loader, sgd, device, save_every).train(total_epochs)


if __name__ == "__main__":
    import argparse  # only used by the disabled command-line variant below

    # Command-line variant, currently disabled in favour of the hard-coded
    # settings that follow:
    #   parser = argparse.ArgumentParser(description='simple distributed training job')
    #   parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    #   parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    #   parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    #   args = parser.parse_args()
    #   main(device, args.total_epochs, args.save_every, args.batch_size)

    # Hard-coded run configuration.
    device = 0  # shorthand for cuda:0
    total_epochs, save_every, batch_size = 10, 2, 8

    main(device, total_epochs, save_every, batch_size)

[GPU0] Epoch 0 | Batchsize: 8 | Steps: 256
Epoch 0 | Training checkpoint saved at checkpoint.pt
[GPU0] Epoch 1 | Batchsize: 8 | Steps: 256
[GPU0] Epoch 2 | Batchsize: 8 | Steps: 256
Epoch 2 | Training checkpoint saved at checkpoint.pt
[GPU0] Epoch 3 | Batchsize: 8 | Steps: 256
[GPU0] Epoch 4 | Batchsize: 8 | Steps: 256
Epoch 4 | Training checkpoint saved at checkpoint.pt
[GPU0] Epoch 5 | Batchsize: 8 | Steps: 256
[GPU0] Epoch 6 | Batchsize: 8 | Steps: 256
Epoch 6 | Training checkpoint saved at checkpoint.pt
[GPU0] Epoch 7 | Batchsize: 8 | Steps: 256
[GPU0] Epoch 8 | Batchsize: 8 | Steps: 256
Epoch 8 | Training checkpoint saved at checkpoint.pt
[GPU0] Epoch 9 | Batchsize: 8 | Steps: 256
