<a href="https://colab.research.google.com/github/Kacper-W-Kozdon/interviewProgram/blob/main/LLM_merging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install utils

Collecting utils
  Downloading utils-1.0.2.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25l- done
[?25hBuilding wheels for collected packages: utils
  Building wheel for utils (setup.py) ... [?25l- \ done
[?25h  Created wheel for utils: filename=utils-1.0.2-py2.py3-none-any.whl size=13905 sha256=bc1c4ae97b33e9b1224ea35e4f31a2e7818e25855b37cc98601aaa57a05c4c94
  Stored in directory: /root/.cache/pip/wheels/b8/39/f5/9d0ca31dba85773ececf0a7f5469f18810e1c8a8ed9da28ca7
Successfully built utils
Installing collected packages: utils
Successfully installed utils-1.0.2


In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

In [3]:
class MyTrainDataset(Dataset):
    """Toy in-memory dataset of random ``(features, target)`` pairs.

    Every sample is a tuple holding a 20-element feature tensor and a
    1-element target tensor, both drawn with ``torch.rand`` at construction.
    """

    def __init__(self, size):
        """Generate ``size`` random samples and keep them in ``self.data``."""
        self.size = size
        samples = []
        for _ in range(size):
            # Draw the features before the target so the RNG stream is
            # consumed in the same order for every sample.
            features = torch.rand(20)
            target = torch.rand(1)
            samples.append((features, target))
        self.data = samples

    def __len__(self):
        """Return the number of samples requested at construction."""
        return self.size

    def __getitem__(self, index):
        """Return the ``(features, target)`` tuple stored at ``index``."""
        return self.data[index]

In [4]:
# Show the indices of the CUDA devices that torch reports for this runtime.
list(range(torch.cuda.device_count()))

[0, 1]

In [5]:
def ddp_setup(rank: int, world_size: int, master_addr: str = "localhost", master_port: str = "12355"):
  """Bind this process to its GPU and join the NCCL process group.

  Args:
    rank: Unique identifier of each process; also used as the CUDA device index.
    world_size: Total number of processes
    master_addr: Value written to the ``MASTER_ADDR`` environment variable.
      Defaults to ``"localhost"``.
    master_port: Value written to the ``MASTER_PORT`` environment variable.
      An int is accepted and converted to ``str``. Defaults to ``"12355"``;
      pass another value when that port is already in use.
  """
  # os.environ only stores strings, so coerce to allow e.g. master_port=29500.
  os.environ["MASTER_ADDR"] = str(master_addr)
  os.environ["MASTER_PORT"] = str(master_port)
  # Select the device before creating the process group.
  torch.cuda.set_device(rank)
  init_process_group(backend="nccl", rank=rank, world_size=world_size)

In [6]:
class SimpleCustomBatch:
    """Batch container with stacked inputs (``inp``) and targets (``tgt``).

    Intended to be produced by ``collate_wrapper`` when it is used as a
    ``DataLoader`` ``collate_fn``.
    """

    def __init__(self, data):
        """Stack a sequence of ``(input, target)`` samples along a new dim 0."""
        # Transpose [(inp, tgt), ...] into ((inp, ...), (tgt, ...)).
        columns = tuple(zip(*data))
        inputs, targets = columns[0], columns[1]
        self.inp = torch.stack(inputs, dim=0)
        self.tgt = torch.stack(targets, dim=0)

    def pin_memory(self):
        """Replace both tensors with pinned copies and return this batch."""
        self.inp, self.tgt = self.inp.pin_memory(), self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    """Collate function: wrap a list of samples in a ``SimpleCustomBatch``."""
    wrapped = SimpleCustomBatch(batch)
    return wrapped

# Two (10, 5) float32 example tensors; the loader below is built from
# `dataset`, not from these.
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
# Use the raw list of 32 (input, target) tuples rather than the Dataset object.
dataset = MyTrainDataset(32).data

# Each batch of 2 samples is collated into a SimpleCustomBatch.
# NOTE(review): with pin_memory=True the loader presumably calls the batch's
# own pin_memory() method for custom types — confirm against the DataLoader docs.
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)


# Print the pinned status and the contents of every batch (16 batches of 2).
for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
    print(sample.inp, sample.tgt)

True
True
tensor([[0.0630, 0.7225, 0.7376, 0.9414, 0.6833, 0.9758, 0.1730, 0.0893, 0.6997,
         0.7725, 0.7812, 0.1246, 0.3104, 0.6186, 0.0075, 0.6901, 0.6650, 0.7725,
         0.6367, 0.6176],
        [0.6100, 0.8490, 0.8544, 0.4706, 0.5169, 0.1687, 0.1511, 0.0489, 0.0256,
         0.1157, 0.6253, 0.1889, 0.0032, 0.5133, 0.2120, 0.8789, 0.2811, 0.4613,
         0.2359, 0.3205]]) tensor([[0.7146],
        [0.4443]])
True
True
tensor([[0.7196, 0.8437, 0.4099, 0.7524, 0.9245, 0.5918, 0.1414, 0.2609, 0.5399,
         0.0932, 0.4517, 0.4344, 0.8079, 0.0501, 0.0413, 0.1512, 0.7931, 0.8622,
         0.0071, 0.4703],
        [0.5848, 0.2744, 0.9096, 0.7694, 0.2574, 0.7331, 0.3456, 0.9453, 0.4236,
         0.3252, 0.5383, 0.9695, 0.2510, 0.3706, 0.0055, 0.6465, 0.0366, 0.7954,
         0.8668, 0.0833]]) tensor([[0.7093],
        [0.6776]])
True
True
tensor([[0.2392, 0.0197, 0.7501, 0.7139, 0.7834, 0.6853, 0.3981, 0.3518, 0.9969,
         0.1109, 0.9491, 0.5322, 0.9431, 0.5542, 0.4769, 0.72

In [7]:
class Trainer:
    """Per-process training loop that wraps a model in DistributedDataParallel.

    One instance is created in each worker process; ``gpu_id`` is used both as
    the device the model and batches are moved to and as the DDP device id.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
        loss_fn=F.cross_entropy,
    ) -> None:
        """Store the training objects and wrap ``model`` in DDP.

        Args:
            model: Module to train; moved to ``gpu_id`` and wrapped in DDP.
            train_data: Loader whose ``sampler`` exposes ``set_epoch``.
            optimizer: Optimizer already bound to ``model``'s parameters.
            gpu_id: Device index for this process.
            save_every: Save a checkpoint every ``save_every`` epochs (>= 1).
            loss_fn: Callable ``loss_fn(output, targets)`` returning a scalar
                tensor. Defaults to ``F.cross_entropy``, which treats the
                model output as class logits; pass e.g. ``F.mse_loss`` for
                regression targets.

        Raises:
            ValueError: If ``save_every`` is not positive.
        """
        # Fail on every rank up front. Otherwise `epoch % 0` would raise only
        # on rank 0, and only after the first epoch had already run.
        if save_every <= 0:
            raise ValueError(f"save_every must be a positive integer, got {save_every!r}")
        self.gpu_id = gpu_id
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.loss_fn = loss_fn
        self.model = DDP(model.to(gpu_id), device_ids=[gpu_id])

    def _run_batch(self, source, targets):
        """Run one optimisation step: forward, loss, backward, update."""
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = self.loss_fn(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        """Train on every batch of ``train_data`` once.

        The progress line is printed when the first batch arrives, so no extra
        loader iterator is created just to measure the batch size. An empty
        loader simply prints nothing.
        """
        # Let the sampler know the epoch before the loader is iterated.
        self.train_data.sampler.set_epoch(epoch)
        for step, (source, targets) in enumerate(self.train_data):
            if step == 0:
                b_sz = len(source)
                print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        """Write the wrapped module's ``state_dict`` to ``checkpoint.pt``."""
        # `.module` is the model inside the DDP wrapper, so the saved keys do
        # not carry the wrapper's prefix.
        ckp = self.model.module.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs: int):
        """Run ``max_epochs`` epochs, checkpointing from the ``gpu_id == 0`` process."""
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            # Only one process writes the file, every `save_every` epochs.
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


def load_train_objs():
    """Build the demo training objects.

    Returns:
        A ``(dataset, model, optimizer)`` tuple: a 2048-sample
        ``MyTrainDataset``, a ``Linear(20, 1)`` model and an SGD optimizer
        (lr=1e-3) over that model's parameters.
    """
    # Dataset first, then model, so random draws happen in a fixed order.
    dataset = MyTrainDataset(2048)
    net = torch.nn.Linear(20, 1)
    sgd = torch.optim.SGD(net.parameters(), lr=1e-3)
    return dataset, net, sgd


def prepare_dataloader(dataset: Dataset, batch_size: int):
    """Create the training ``DataLoader`` driven by a ``DistributedSampler``.

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` with pinned memory enabled and loader-level
        shuffling disabled; sample order is decided by the sampler.
    """
    shard_sampler = DistributedSampler(dataset)
    loader_options = dict(
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,  # ordering is delegated to the sampler
        sampler=shard_sampler,
    )
    return DataLoader(dataset, **loader_options)


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    """Entry point for one worker process of the distributed training job.

    Args:
        rank: Index of this process; forwarded to ``ddp_setup`` and used as
            the trainer's GPU id.
        world_size: Total number of processes.
        save_every: Checkpoint interval in epochs.
        total_epochs: Number of epochs to train.
        batch_size: Per-process batch size.
    """
    ddp_setup(rank, world_size)
    # The process group exists from here on; always tear it down, even when
    # building the objects or training raises.
    try:
        dataset, model, optimizer = load_train_objs()
        train_data = prepare_dataloader(dataset, batch_size)
        trainer = Trainer(model, train_data, optimizer, rank, save_every)
        trainer.train(total_epochs)
    finally:
        destroy_process_group()


# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser(description='simple distributed training job')
#     parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
#     parser.add_argument('save_every', type=int, help='How often to save a snapshot')
#     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
#     args = parser.parse_args()

#     world_size = torch.cuda.device_count()
#     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)