<a href="https://colab.research.google.com/github/Kacper-W-Kozdon/interviewProgram/blob/main/LLM_merging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install utils

Collecting utils
  Downloading utils-1.0.2.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25l- done
[?25hBuilding wheels for collected packages: utils
  Building wheel for utils (setup.py) ... [?25l- \ done
[?25h  Created wheel for utils: filename=utils-1.0.2-py2.py3-none-any.whl size=13905 sha256=1bad37a5e045572a06666c2150be54dc3c1b728e7fe2ccb54e42ecac15dd0d7a
  Stored in directory: /root/.cache/pip/wheels/b8/39/f5/9d0ca31dba85773ececf0a7f5469f18810e1c8a8ed9da28ca7
Successfully built utils
Installing collected packages: utils
Successfully installed utils-1.0.2


In [2]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

In [3]:
class MyTrainDataset(Dataset):
    """Synthetic dataset of ``size`` random ``(features, target)`` pairs.

    Every sample is generated eagerly in ``__init__``: ``features`` is a
    20-element tensor and ``target`` a 1-element tensor, both drawn with
    ``torch.rand`` (uniform on [0, 1)). The raw list is exposed as ``.data``.
    """

    def __init__(self, size):
        self.size = size
        samples = []
        for _ in range(size):
            # Draw features before the target to keep the RNG stream order stable.
            features = torch.rand(20)
            target = torch.rand(1)
            samples.append((features, target))
        self.data = samples

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        sample = self.data[index]
        return sample

In [4]:
# CUDA device indices visible to this kernel; the launcher spawns one DDP process per index.
[*range(torch.cuda.device_count())]

[0, 1]

In [5]:
def ddp_setup(
    rank: int,
    world_size: int,
    master_addr: str = "localhost",
    master_port: str = "12355",
    backend: str = "nccl",
):
    """Initialise the default process group for one DDP worker process.

    Args:
        rank: Unique identifier of each process; also used as its CUDA device index.
        world_size: Total number of processes.
        master_addr: Host of the rendezvous, written to ``MASTER_ADDR``.
        master_port: TCP port of the rendezvous, written to ``MASTER_PORT``;
            override it when the default port is already in use.
        backend: ``torch.distributed`` backend name (``"nccl"`` for GPU training).
    """
    # The default ``env://`` init method reads the rendezvous from these variables.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    # Bind this process to its own GPU before the process group is created.
    torch.cuda.set_device(rank)
    init_process_group(backend=backend, rank=rank, world_size=world_size)

In [6]:
class Trainer:
    """Per-process worker of a DistributedDataParallel (DDP) training job.

    Moves the model to its device, wraps it in DDP and runs the epoch/batch
    loop; rank 0 periodically writes a checkpoint.

    Args:
        model: The unwrapped module to train.
        train_data: Loader yielding ``(source, targets)`` batches. If its
            sampler has ``set_epoch`` (e.g. ``DistributedSampler``) it is
            re-seeded at the start of every epoch.
        optimizer: Optimizer already bound to ``model.parameters()``.
        gpu_id: Device index for this process (``main`` passes its rank).
        save_every: Checkpoint every ``save_every`` epochs on rank 0; a value
            <= 0 disables checkpointing.
        loss_fn: Callable ``loss_fn(output, targets) -> scalar tensor``.
            Defaults to ``F.mse_loss`` because the targets are continuous;
            ``F.cross_entropy`` on a single-logit output is identically zero
            (log-softmax over one class is 0), so it produces no gradients.
        checkpoint_path: File that ``_save_checkpoint`` writes to.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
        loss_fn=F.mse_loss,
        checkpoint_path: str = "checkpoint.pt",
    ) -> None:
        self.gpu_id = gpu_id
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.loss_fn = loss_fn
        self.checkpoint_path = checkpoint_path
        # Move to the device first, then wrap the moved module in DDP.
        self.model = DDP(model.to(gpu_id), device_ids=[gpu_id])

    def _run_batch(self, source, targets):
        """Run one optimisation step on a batch and return the detached loss."""
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = self.loss_fn(output, targets)
        loss.backward()
        self.optimizer.step()
        return loss.detach()

    def _run_epoch(self, epoch):
        """Iterate once over ``train_data``, stepping the optimiser per batch."""
        sampler = self.train_data.sampler
        # set_epoch must run before the loader iterator is created, otherwise
        # DistributedSampler reuses the same shuffle order every epoch.
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        # Read the configured batch size instead of loading a throwaway batch.
        b_sz = self.train_data.batch_size
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        """Save the wrapped module's state dict to ``self.checkpoint_path``."""
        # Saving ``.module`` keeps the DDP ``module.`` prefix out of the keys.
        ckp = self.model.module.state_dict()
        torch.save(ckp, self.checkpoint_path)
        print(f"Epoch {epoch} | Training checkpoint saved at {self.checkpoint_path}")

    def train(self, max_epochs: int):
        """Train for ``max_epochs`` epochs, checkpointing from rank 0 only."""
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            # Guarding save_every > 0 avoids ZeroDivisionError and lets callers disable saving.
            if self.gpu_id == 0 and self.save_every > 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


def load_train_objs():
    """Create the toy training objects used by ``main``.

    Returns:
        ``(train_set, model, optimizer)``: a 2048-sample ``MyTrainDataset``, a
        ``Linear(20, 1)`` model and plain SGD (lr=1e-3) over that model.
    """
    # Keep dataset-before-model order: both draw from the global RNG.
    train_set = MyTrainDataset(2048)  # swap in your own dataset here
    model = torch.nn.Linear(20, 1)  # swap in your own model here
    return train_set, model, torch.optim.SGD(model.parameters(), lr=1e-3)


def prepare_dataloader(dataset: Dataset, batch_size: int):
    """Wrap ``dataset`` in a DataLoader that reads only this rank's shard.

    Sharding is delegated to ``DistributedSampler``. ``shuffle`` therefore
    stays False, because DataLoader rejects ``shuffle=True`` together with a
    sampler. The sampler queries the default process group, so ``ddp_setup``
    must have run first.
    """
    shard_sampler = DistributedSampler(dataset)
    loader_options = dict(batch_size=batch_size, pin_memory=True, shuffle=False)
    return DataLoader(dataset, sampler=shard_sampler, **loader_options)


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    """Per-process entry point, launched once per rank (e.g. via ``mp.spawn``).

    Args:
        rank: Index of this process; also its GPU id.
        world_size: Total number of processes.
        save_every: Checkpoint frequency in epochs (rank 0 only).
        total_epochs: Number of epochs to train.
        batch_size: Per-process batch size.
    """
    ddp_setup(rank, world_size)
    try:
        dataset, model, optimizer = load_train_objs()
        train_data = prepare_dataloader(dataset, batch_size)
        trainer = Trainer(model, train_data, optimizer, rank, save_every)
        trainer.train(total_epochs)
    finally:
        # Tear the process group down even when training raises, so it is not
        # left alive in this process.
        destroy_process_group()


# if __name__ == "__main__":
#     import argparse
#     parser = argparse.ArgumentParser(description='simple distributed training job')
#     parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
#     parser.add_argument('save_every', type=int, help='How often to save a snapshot')
#     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
#     args = parser.parse_args()

#     world_size = torch.cuda.device_count()
#     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)

In [7]:
class SimpleCustomBatch:
    """Custom batch type holding stacked inputs (``inp``) and targets (``tgt``).

    Built from a sequence of ``(input, target, ...)`` samples; any fields past
    the second are ignored. Defining ``pin_memory`` lets a DataLoader created
    with ``pin_memory=True`` pin this custom type.
    """

    def __init__(self, data):
        # [(x0, y0), (x1, y1)] -> ((x0, x1), (y0, y1))
        columns = tuple(zip(*data))
        self.inp = torch.stack(columns[0], dim=0)
        self.tgt = torch.stack(columns[1], dim=0)

    def pin_memory(self):
        """Replace both tensors with pinned (page-locked) copies; returns ``self``."""
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    """DataLoader ``collate_fn`` that packs a list of samples into a SimpleCustomBatch."""
    custom_batch = SimpleCustomBatch(batch)
    return custom_batch

# NOTE(review): `inps` / `tgts` are not used by this cell; kept in case later
# cells reference them — TODO confirm and delete if not.
inps = torch.arange(50, dtype=torch.float32).reshape(10, 5)
tgts = torch.arange(50, dtype=torch.float32).reshape(10, 5)
dataset = MyTrainDataset(32).data  # plain list of (input, target) tuples

# collate_wrapper yields SimpleCustomBatch objects; with pin_memory=True the
# loader calls their pin_memory() method.
loader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=collate_wrapper,
    pin_memory=True,
)

for batch_ndx, sample in enumerate(loader):
    # Both flags should print True, confirming the custom batch was pinned.
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
    print(sample.inp, sample.tgt)

True
True
tensor([[0.8261, 0.3359, 0.6790, 0.6315, 0.5465, 0.3958, 0.8076, 0.0372, 0.6756,
         0.5862, 0.3407, 0.0682, 0.6702, 0.9586, 0.0720, 0.6250, 0.4441, 0.5433,
         0.2718, 0.0866],
        [0.5937, 0.2942, 0.8188, 0.5548, 0.9166, 0.8544, 0.2476, 0.8193, 0.4493,
         0.5483, 0.9565, 0.2460, 0.3784, 0.4024, 0.0301, 0.1619, 0.5343, 0.4323,
         0.9850, 0.1822]]) tensor([[0.2914],
        [0.8392]])
True
True
tensor([[0.9389, 0.0214, 0.1297, 0.8555, 0.0221, 0.3420, 0.4026, 0.7390, 0.5596,
         0.5306, 0.8684, 0.8148, 0.3416, 0.8415, 0.5406, 0.7363, 0.9450, 0.6004,
         0.0932, 0.9337],
        [0.3664, 0.7897, 0.8044, 0.2066, 0.6655, 0.5046, 0.5194, 0.8938, 0.2501,
         0.3697, 0.2161, 0.1771, 0.3987, 0.8839, 0.0355, 0.0212, 0.3009, 0.7936,
         0.5514, 0.6659]]) tensor([[0.2191],
        [0.7072]])
True
True
tensor([[0.3946, 0.9855, 0.0017, 0.8524, 0.2259, 0.0308, 0.2991, 0.9131, 0.4785,
         0.5628, 0.3593, 0.7614, 0.6699, 0.9026, 0.6441, 0.72