In [1]:
# Assumes the local-CPU test run has already been done and all of its dependencies are installed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.notebook import tqdm
from distributed import Client, LocalCluster
import torch.multiprocessing as mp

# https://docs.mlerp.cloud.edu.au/tutorials/3_dask_pytorch.html


# Prompts the language model is asked to complete.
DATASET = [
    "Hello, my name is",
    "You killed my father",
    "My favourite colours are:",
    "What is the airspeed of an unladen",
]

# One batch holds every prompt in DATASET.
batch_size = len(DATASET)

# https://jobqueue.dask.org/en/latest/debug.html
# Log at DEBUG level so that the job script written by dask-jobqueue and the
# sbatch/scancel commands it executes show up in the cell output.
import logging
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)



In [2]:

# Wrap the prompts in a DataLoader: shuffled, one batch of `batch_size` items,
# loaded by a single worker process started with the "fork" context.
dataloader = torch.utils.data.DataLoader(
    DATASET,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    multiprocessing_context=mp.get_context("fork"),
)


In [3]:
# Echo the object to confirm the DataLoader was constructed.
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f875b77e510>

In [4]:
# batch_size equals len(DATASET), so each iteration yields one shuffled batch
# (a list) containing all of the prompts rather than a single prompt.
for line in dataloader:
    print(line)

['You killed my father', 'My favourite colours are:', 'Hello, my name is', 'What is the airspeed of an unladen']


# Here is the `generate` function

This function runs on the external `cheetah` nodes and uses the LLM to generate a completion for each prompt in the dataset.

In [15]:

def generate(loader, model_id="mistralai/Mistral-7B-v0.1", output=None):  # alternative model: "mistralai/Mixtral-8x7B-v0.1"
    """Generate a completion for every prompt yielded by ``loader``.

    Args:
        loader: Iterable of prompts. Only the first generated sequence of each
            item is decoded, so items are expected to be single prompt strings
            (as when ``DATASET`` is passed directly).
        model_id: Hugging Face model identifier used for both the tokenizer
            and the causal language model.
        output: Optional list to append results to. When omitted, a fresh list
            is created for this call, so results never leak between calls.

    Returns:
        The ``output`` list, holding one ``{"prompt": ..., "output": ...}``
        dict per item of ``loader``, in iteration order.
    """
    # A mutable default ([]) would be shared by every call made in the same
    # process, accumulating results from earlier invocations.
    if output is None:
        output = []

    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    for line in loader:
        print("working on", line)
        # Send the inputs to wherever device_map="auto" placed the model
        # instead of assuming a CUDA device is present.
        inputs = tokenizer(line, return_tensors="pt").to(model.device)

        outputs = model.generate(**inputs, max_new_tokens=50)
        output.append({"prompt":line,"output":tokenizer.decode(outputs[0], skip_special_tokens=True)})
        print(output)
    return output
        

In [18]:
from dask_jobqueue import SLURMCluster
from distributed import Client, progress

# Track what has been created so the error path only tears down live objects.
cluster = None
client = None
try:
    # For a local test run, use instead: cluster = LocalCluster(processes=False)
    # Each SLURM job runs one single-process worker with 8 cores, 128 GB of
    # memory and one 40 GB GPU.
    cluster = SLURMCluster(
        memory="128g",
        processes=1,
        cores=8,
        job_extra_directives=["--gres=gpu:40gb:1"],
        nanny=False,
        # queue="lion",  # check with $ sacctmgr list clusters
        log_directory="./logs",  # no trailing slash
    )
    cluster.scale(1)
    client = Client(cluster)
    # DATASET is passed directly (not the DataLoader); generate iterates over it.
    futuremap = client.submit(generate, DATASET)
except BaseException:
    # Release whatever was started, then re-raise so the real failure is
    # visible instead of being swallowed (or masked by a NameError on `client`).
    if client is not None:
        client.shutdown()
    elif cluster is not None:
        cluster.close()
    raise


DEBUG:Using selector: EpollSelector
DEBUG:Starting worker: SLURMCluster-0
DEBUG:writing job script: 
#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -e ./logs/dask-worker-%J.err
#SBATCH -o ./logs/dask-worker-%J.out
#SBATCH -n 1
#SBATCH --cpus-per-task=8
#SBATCH --mem=120G
#SBATCH -t 00:30:00
#SBATCH --gres=gpu:40gb:1

/apps/mambaforge/envs/detectron-sam/bin/python -m distributed.cli.dask_worker tcp://192.168.0.11:36209 --nthreads 8 --memory-limit 119.21GiB --name SLURMCluster-0 --no-nanny --death-timeout 60

DEBUG:Executing the following command to command line
sbatch /tmp/tmpawa1_ytz.sh
DEBUG:Starting job: 2537


In [19]:
# List this user's SLURM jobs to confirm the dask-worker job has been scheduled.
!squeue --me --format "%.8P %.15j %.8T %.10M %.12L %.4C %.7m %R %q"
# Show a progress widget for the submitted `generate` task.
progress(futuremap)


PARTITIO            NAME    STATE       TIME    TIME_LEFT CPUS MIN_MEM NODELIST(REASON) QOS
 BigCats     dask-worker  RUNNING       0:03        29:57    8    120G mlerp-monash-node03 cheetah
 BigCats     Jupyter Lab  RUNNING      12:56        47:04    8     64G mlerp-monash-node00 lion


VBox()

In [20]:
from pprint import pprint
# Fetch the value returned by the submitted `generate` call: a list of
# {"prompt", "output"} dicts (presumably this waits for the task to finish).
result = futuremap.result()
# Print one record per line...
for row in result:
    print(row)
# ...then the whole list again, pretty-printed.
pprint(result)
# Untested alternative: the dask quickstart suggests client.gather for collecting futures
# (https://distributed.dask.org/en/stable/quickstart.html#gather).
# NOTE(review): futuremap is a single future from client.submit — confirm it can be iterated before enabling this.
#outputs = client.gather(iter(futuremap))
#print(outputs)


{'prompt': 'Hello, my name is', 'output': 'Hello, my name is Katie and I am a 20-something year old living in the beautiful city of Vancouver, BC. I am a recent graduate of the University of British Columbia with a Bachelor of Arts in Psychology and a minor in Sociology. I'}
{'prompt': 'You killed my father', 'output': 'You killed my father, prepare to die.\n\nThe first time I saw this movie was in the theater. I was 12 years old and I was with my dad. I remember being so excited to see it. I had seen the trailer and I was'}
{'prompt': 'My favourite colours are:', 'output': 'My favourite colours are:\n\n- Blue\n- Green\n- Purple\n- Red\n- Yellow\n\nI like blue because it is a calming colour. It is also the colour of the sky and the sea.\n\nI like green because it is'}
{'prompt': 'What is the airspeed of an unladen', 'output': 'What is the airspeed of an unladen swallow?\n\nThe answer to this question is 11 m/s, or 24.2 mph.\n\nThe question is from the Monty Python sketch “The Philosoph

In [21]:
# Shut down the dask client; the log output below shows the SLURM worker job being cancelled with scancel.
client.shutdown()

DEBUG:Stopping worker: SLURMCluster-0 job: 2537
DEBUG:Executing the following command to command line
scancel 2537
DEBUG:Closed job 2537
DEBUG:Executing the following command to command line
scancel 2537
DEBUG:Closed job 2537
