In [1]:
#presuming localCPU test runs and all the bits there are installed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.notebook import tqdm
from distributed import Client, LocalCluster
import torch.multiprocessing as mp

# https://docs.mlerp.cloud.edu.au/tutorials/3_dask_pytorch.html


DATASET = ["Hello, my name is", 
           "You killed my father",
           "My favourite colours are:",
           "What is the airspeed of an unladen"]

batch_size = len(DATASET)

# https://jobqueue.dask.org/en/latest/debug.html
import logging
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)



In [2]:

dataloader = torch.utils.data.DataLoader(DATASET, batch_size=batch_size, shuffle=True, num_workers=1, multiprocessing_context=mp.get_context("fork"))


In [3]:
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f37c802c4d0>

In [4]:
for line in dataloader:
    print(line)

['My favourite colours are:', 'What is the airspeed of an unladen', 'You killed my father', 'Hello, my name is']


In [5]:

def generate(loader, model_id="mistralai/Mistral-7B-v0.1", output=[]): 
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    

    for line in loader:
        print("working on", line)
        # text = "Hello my name is"
        inputs = tokenizer(line, return_tensors="pt").to("cuda")
        
        outputs = model.generate(**inputs, max_new_tokens=50)
        output.append({"prompt":line,"output":tokenizer.decode(outputs[0], skip_special_tokens=True)})
        print(output)
    return output
        

In [9]:
from dask_jobqueue import SLURMCluster
from distributed import Client, progress

try:
    # cluster = LocalCluster(processes=False)
    cluster = SLURMCluster(
        memory="192g", processes=1, cores=8, job_extra_directives=["--gres=gpu:40gb:1",], nanny=False,
        # queue="lion", # check with $ sacctmgr list clusters
        # worker_extra_args=["-e slurm-%j.err", "-o slurm-%j.out"], # not needed with log_directory
        log_directory="./logs", #no trailing slash
        
    )
    cluster.scale(1)
    client = Client(cluster)
    # dataloader_future = client.scatter(dataloader)
    # future = client.submit(generate, dataloader_future)
    futuremap = client.submit(generate, DATASET, model_id="mistralai/Mixtral-8x7B-v0.1")
    # Don't use client.map here. It causes OOM errors because distributed tries to run *multiple* instances of the model at the same time. :(
except:
    client.shutdown()


DEBUG:Using selector: EpollSelector
DEBUG:Starting worker: SLURMCluster-0
DEBUG:writing job script: 
#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -e ./logs/dask-worker-%J.err
#SBATCH -o ./logs/dask-worker-%J.out
#SBATCH -n 1
#SBATCH --cpus-per-task=8
#SBATCH --mem=179G
#SBATCH -t 00:30:00
#SBATCH --gres=gpu:40gb:2

/apps/mambaforge/envs/detectron-sam/bin/python -m distributed.cli.dask_worker tcp://192.168.0.11:42619 --nthreads 8 --memory-limit 178.81GiB --name SLURMCluster-0 --no-nanny --death-timeout 60

DEBUG:Executing the following command to command line
sbatch /tmp/tmpavxgcd7f.sh
DEBUG:Starting job: 2542


In [12]:
!squeue --me --format "%.8P %.15j %.8T %.10M %.12L %.4C %.7m %R %q"
progress(futuremap)
# If running the large model, progress won't work as well as tail -f the log file in ./logs


PARTITIO            NAME    STATE       TIME    TIME_LEFT CPUS MIN_MEM NODELIST(REASON) QOS
 BigCats     Jupyter Lab  RUNNING      39:24        20:36    8     64G mlerp-monash-node00 lion
 BigCats     dask-worker  RUNNING      18:05        11:55    8    179G mlerp-monash-node03 cheetah


VBox()

In [13]:
from pprint import pprint
result = futuremap.result()
for row in result:
    print(row)
pprint(result)
# https://distributed.dask.org/en/stable/quickstart.html#gather need to use gather instead
#outputs = client.gather(iter(futuremap))
#print(outputs)


{'prompt': 'Hello, my name is', 'output': 'Hello, my name is Katie and I am a 20-something year old living in the Midwest. I am a wife, a mother, a daughter, a sister, a friend, a student, a teacher, a writer, a reader, a dream'}
{'prompt': 'You killed my father', 'output': 'You killed my father. Prepare to die.\n\nThe Princess Bride is a 1987 fantasy adventure film directed by Rob Reiner and starring Cary Elwes, Robin Wright, Mandy Patinkin, Chris Sarandon, Wallace'}
{'prompt': 'My favourite colours are:', 'output': 'My favourite colours are:\n\n- Blue\n- Green\n- Purple\n- Pink\n- Red\n- Yellow\n- Orange\n- Brown\n- Black\n- White\n\nI like all the colours.\n\nI like the colour blue because it is'}
{'prompt': 'What is the airspeed of an unladen', 'output': 'What is the airspeed of an unladen swallow?\n\nThe airspeed of an unladen swallow is 11 meters per second, or 24 miles per hour.\n\nWhat is the airspeed velocity of an unladen swallow?\n\nThe airspeed velocity'}
[{'output': 'Hello

In [14]:
client.shutdown()

DEBUG:Stopping worker: SLURMCluster-0 job: 2542
DEBUG:Executing the following command to command line
scancel 2542
DEBUG:Closed job 2542
DEBUG:Executing the following command to command line
scancel 2542
DEBUG:Closed job 2542
