<a href="https://colab.research.google.com/github/Brackly/dask_n_pytorch/blob/main/daskTestRun.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import json
import logging
import os
import time

import dask
import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
import torch
from distributed import Client, LocalCluster
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm import tqdm

In [2]:
# Driver-side logger; the training loop reports the average per-batch time through it.
logger = logging.getLogger("client")
logger.setLevel(logging.INFO)

# Training CSV. Defaults to the Colab upload location; set the DATA_PATH
# environment variable to point the notebook at a different file.
data_path = os.environ.get("DATA_PATH", '/content/sample_train.csv')

In [3]:
def split_batch(batch,NUM_WORKERS):
    """Split an ``(x, y)`` batch into per-worker ``(x_shard, y_shard)`` pairs.

    Both tensors are cut along dim 0 with ``Tensor.chunk``, which may yield fewer
    than ``NUM_WORKERS`` pieces when the batch is small.

    Returns:
        A list of ``(x_shard, y_shard)`` tuples.
    """
    features, labels = batch
    feature_shards = features.chunk(NUM_WORKERS)
    label_shards = labels.chunk(NUM_WORKERS)
    return [(f_part, l_part) for f_part, l_part in zip(feature_shards, label_shards)]

def update_worker_model(state_dict):
    """Load ``state_dict`` into this process's module-level ``model`` global.

    Meant to be run on each worker via ``client.run``; relies on ``model`` having
    been installed beforehand (``init_worker`` does that).
    """
    global model
    model.load_state_dict(state_dict)

def init_worker(main_model, main_criterion):
    """Install the model and loss function as module-level globals on a worker.

    Intended for ``client.run``. The model is moved to CUDA when available and to
    CPU otherwise, then stored in the ``model`` global; the loss function is stored
    in ``criterion``.

    Args:
        main_model: the ``torch.nn.Module`` to train (``.to(device)`` is applied).
        main_criterion: the loss callable ``train`` will use.
    """
    import logging
    worker_log = logging.getLogger("dask_worker_log")
    worker_log.setLevel(logging.INFO)

    global model, criterion
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    model = main_model.to(device)
    criterion = main_criterion

    worker_log.info(f"Successfully initialized model, and loss fn on worker withe device: {device}")

def all_reduce(results: list, model: torch.nn.Module):
    """Average per-shard gradients onto ``model`` and total the shard losses.

    Args:
        results: one ``(loss, grads)`` pair per shard, as returned by ``train``;
            ``grads`` is a list of tensors aligned with ``model.parameters()``.
        model: the driver-side model whose ``.grad`` fields are overwritten.

    Returns:
        ``(model, total_loss)`` where ``total_loss`` is the *sum* of the shard
        losses (not their mean).

    Raises:
        ValueError: if ``results`` is empty, or a shard's gradient list does not
            have exactly one entry per model parameter.

    The reduction runs locally. ``results`` has already been gathered to the
    driver. ``dask.compute`` uses the active distributed Client as its scheduler
    when one exists, so routing the sums through it would only ship every
    gradient to the cluster and back.
    """
    if not results:
        raise ValueError("all_reduce needs at least one (loss, grads) result")

    params = list(model.parameters())
    losses = [loss for loss, _ in results]
    grad_lists = [grads for _, grads in results]
    if any(len(grads) != len(params) for grads in grad_lists):
        raise ValueError("every shard must return one gradient per model parameter")

    total_loss = sum(losses)
    num_shards = len(results)
    # NOTE(review): a plain mean over shards equals the full-batch gradient only
    # when every shard has the same size; torch.chunk can make the last one smaller.
    with torch.no_grad():
        for param, shard_grads in zip(params, zip(*grad_lists)):
            param.grad = sum(shard_grads) / num_shards

    return model, total_loss

def dispatch(client:Client,train,batch,model,criterion,optimizer):
    """Run one data-parallel training step across every Dask worker.

    Steps:
      1. broadcast the driver model and loss to all workers;
      2. shard ``batch`` into one piece per worker and run ``train`` on each;
      3. gather the per-shard results and average their grads into ``model``;
      4. apply and clear the optimizer step;
      5. push the updated weights back to the workers.

    Returns:
        The loss returned by ``all_reduce`` for this batch.
    """
    client.run(init_worker, model, criterion)

    worker_count = len(client.scheduler_info()['workers'])
    shards = split_batch(batch, worker_count)
    pending = client.map(train, shards)
    shard_results = client.gather(pending)

    model, total_loss = all_reduce(shard_results, model)
    optimizer.step()
    optimizer.zero_grad()

    # NOTE(review): the next dispatch call re-sends the full model via init_worker,
    # so this second broadcast per step looks redundant. Confirm before removing.
    client.run(update_worker_model, model.state_dict())
    return total_loss

In [4]:
BATCH_SIZE = 1000
NUMBER_OF_WORKERS = 4

# One thread per worker. Each worker keeps a single global `model`/`criterion`
# (installed by init_worker). If two `train` tasks ran concurrently in different
# threads of the same worker, they would zero and accumulate into the same .grad
# buffers and corrupt each other's gradients.
cluster = LocalCluster(n_workers=NUMBER_OF_WORKERS, threads_per_worker=1)
client = Client(cluster)

INFO:distributed.http.proxy:To route to workers diagnostics web server please install jupyter-server-proxy: python -m pip install jupyter-server-proxy
INFO:distributed.scheduler:State start
INFO:distributed.scheduler:  Scheduler at:     tcp://127.0.0.1:40549
INFO:distributed.scheduler:  dashboard at:  http://127.0.0.1:8787/status
INFO:distributed.scheduler:Registering Worker plugin shuffle
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:39737'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:39949'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:43589'
INFO:distributed.nanny:        Start Nanny at: 'tcp://127.0.0.1:38951'
INFO:distributed.scheduler:Register worker addr: tcp://127.0.0.1:42181 name: 0
INFO:distributed.scheduler:Starting worker compute stream, tcp://127.0.0.1:42181
INFO:distributed.core:Starting established connection to tcp://127.0.0.1:49844
INFO:distributed.scheduler:Register worker addr: tcp://127.0.0.1:34903 name: 2
INFO:

In [5]:
class Model(torch.nn.Module):
    """Three-layer MLP classifier: in_features -> hidden1 -> hidden2 -> num_classes.

    The defaults (384 -> 128 -> 64 -> 5) reproduce the original fixed architecture.
    Existing ``Model()`` calls and saved state_dicts (keys ``fc1``/``fc2``/``fc3``)
    therefore keep working.
    """

    def __init__(self, in_features=384, hidden1=128, hidden2=64, num_classes=5):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features, hidden1)
        self.fc2 = torch.nn.Linear(hidden1, hidden2)
        self.fc3 = torch.nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Return raw class logits (no softmax; CrossEntropyLoss expects logits)."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over a CSV with ``Embeddings`` and ``OpenStatus`` columns.

    Each ``Embeddings`` cell is a JSON-encoded list of numbers, decoded into a
    tensor on access. ``OpenStatus`` is returned unchanged as the label
    (presumably an integer class id for CrossEntropyLoss; verify against the CSV).
    """

    def __init__(self, path):
        self.df = self.fetch_data(path)

    def fetch_data(self, path):
        """Read the whole CSV at ``path`` into a DataFrame."""
        df = pd.read_csv(path)
        return df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for the row at *position* ``idx``.

        ``.iloc`` makes the lookup positional. That matches ``__len__`` and the
        ``0..len-1`` indices a DataLoader sampler produces. Plain ``Series[idx]`` is
        label-based and picks the wrong row (or raises) once the frame's index is
        not a 0-based RangeIndex.
        """
        x = torch.tensor(json.loads(self.df["Embeddings"].iloc[idx]))
        y = self.df["OpenStatus"].iloc[idx]
        return x, y

In [6]:
# Seed torch's global RNG so the weight initialisation below is reproducible.
SEED = 42
LEARNING_RATE = 0.001
torch.manual_seed(SEED)

model = Model()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
dataset = CustomDataset(path=data_path)
# shuffle is left at its default (False), so every epoch sees batches in CSV order.
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=0)

In [7]:
def train(batch):
    """Compute the loss and gradients for one batch shard on this process.

    Uses the module-level ``model`` and ``criterion`` globals (installed on
    workers by ``init_worker``).

    Args:
        batch: ``(x, y)`` pair of tensors; moved to CUDA if available, else CPU.

    Returns:
        ``(loss, grads)``. ``loss`` is a Python float. ``grads`` is a list of CPU
        tensors, one per ``model.parameters()`` entry, in the same order.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = model.to(device)
    x, y = batch
    x, y = x.to(device), y.to(device)
    # Zero the grads on the model itself. A driver-side optimizer may not wrap this
    # process's copy of the parameters; if it doesn't, grads accumulate across steps.
    net.zero_grad()
    loss = criterion(net(x), y)
    loss.backward()
    # Copy each grad to CPU (copy=True, so the result never aliases the live .grad
    # buffer). Parameters that got no gradient report zeros, which keeps the list
    # aligned with model.parameters().
    grads = [
        torch.zeros_like(p, device='cpu') if p.grad is None
        else p.grad.detach().to('cpu', copy=True)
        for p in net.parameters()
    ]
    return loss.item(), grads

In [8]:
NUM_EPOCHS = 10

print(f"Training started at :{client.dashboard_link}")
total_time = []
for epoch in tqdm(range(NUM_EPOCHS)):
    epoch_loss = 0
    num_batches = 0
    for batch in train_dataloader:
        start = time.perf_counter()
        total_loss = dispatch(client, train, batch, model, criterion, optimizer)
        stop = time.perf_counter()
        total_time.append(stop - start)
        epoch_loss += total_loss
        num_batches += 1
    # NOTE: all_reduce sums the per-shard losses, so this figure scales with the
    # number of shards. It is not a per-sample mean loss.
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    # tqdm.write prints above the progress bar instead of tearing through it.
    tqdm.write(f"Epoch {epoch+1} Loss: {mean_loss:.4f}")

if total_time:
    logger.info(f"Took:{sum(total_time)/len(total_time)}")

Training started at :http://127.0.0.1:8787/status


 10%|█         | 1/10 [00:11<01:43, 11.55s/it]

Epoch 1 Loss: 6.6157


 20%|██        | 2/10 [00:12<00:42,  5.37s/it]

Epoch 2 Loss: 6.4944


 30%|███       | 3/10 [00:13<00:22,  3.27s/it]

Epoch 3 Loss: 6.3749


 40%|████      | 4/10 [00:13<00:13,  2.21s/it]

Epoch 4 Loss: 6.2497


 50%|█████     | 5/10 [00:14<00:08,  1.63s/it]

Epoch 5 Loss: 6.1127


 60%|██████    | 6/10 [00:15<00:05,  1.27s/it]

Epoch 6 Loss: 5.9598


 70%|███████   | 7/10 [00:15<00:03,  1.05s/it]

Epoch 7 Loss: 5.7866


 80%|████████  | 8/10 [00:16<00:01,  1.11it/s]

Epoch 8 Loss: 5.5893


 90%|█████████ | 9/10 [00:16<00:00,  1.24it/s]

Epoch 9 Loss: 5.3655


100%|██████████| 10/10 [00:17<00:00,  1.75s/it]
INFO:client:Took:1.4135619163513184


Epoch 10 Loss: 5.1137


In [9]:
# Cleanup: disconnect the client first, then shut the cluster down.
for resource in (client, cluster):
    resource.close()

INFO:distributed.scheduler:Remove client Client-5396450a-1d3c-11f0-80be-0242ac1c000c
INFO:distributed.core:Received 'close-stream' from tcp://127.0.0.1:49850; closing.
INFO:distributed.scheduler:Remove client Client-5396450a-1d3c-11f0-80be-0242ac1c000c
INFO:distributed.scheduler:Close client connection: Client-5396450a-1d3c-11f0-80be-0242ac1c000c
INFO:distributed.scheduler:Retire worker addresses (stimulus_id='retire-workers-1745080562.0935268') (0, 1, 2, 3)
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:39737'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:39949'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:43589'. Reason: nanny-close
INFO:distributed.nanny:Nanny asking worker to close. Reason: nanny-close
INFO:distributed.nanny:Closing Nanny at 'tcp://127.0.0.1:38951'