In [1]:
!pip install petals
!pip install transformers
!pip install datasets



In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from datasets import load_dataset, concatenate_datasets


In [3]:
dataset = load_dataset("sciq")



  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
dataset=concatenate_datasets([dataset['train'],dataset['validation'],dataset['test']])
dataset=dataset.select(range(500))

In [5]:
dataset=dataset.remove_columns(set(dataset.column_names)-{'question','correct_answer'})
dataset.column_names

['question', 'correct_answer']

In [6]:
data=pd.DataFrame(dataset)
data.columns=['prompt','response']

In [7]:
data.isna().sum()

prompt      0
response    0
dtype: int64

In [8]:
data.shape

(500, 2)

In [9]:
data['prompt']=data['prompt'].map(lambda x: "Give me a concise and specific answer to the following question: "+str(x))
data

Unnamed: 0,prompt,response
0,Give me a concise and specific answer to the f...,mesophilic organisms
1,Give me a concise and specific answer to the f...,coriolis effect
2,Give me a concise and specific answer to the f...,exothermic
3,Give me a concise and specific answer to the f...,alpha decay
4,Give me a concise and specific answer to the f...,smoke and ash
...,...,...
495,Give me a concise and specific answer to the f...,ozone
496,Give me a concise and specific answer to the f...,two ions need to have opposite charges
497,Give me a concise and specific answer to the f...,comparative embryology
498,Give me a concise and specific answer to the f...,lungs


In [10]:
print(data.iloc[0]['prompt'])

Give me a concise and specific answer to the following question: What type of organism is commonly used in preparation of foods such as cheese and yogurt?


In [11]:
import torch
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

model_name="enoch/llama-65b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token


model = AutoDistributedModelForCausalLM.from_pretrained(model_name, tuning_mode='deep_ptune', pre_seq_len=3)
model = model.cuda()


class ScienceDataset(Dataset):
    def __init__(self, df, tokenizer, max_length):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        prompt = self.df.iloc[idx]['prompt']
        response = self.df.iloc[idx]['response']
        inputs = self.tokenizer(prompt, return_tensors='pt', padding='max_length', max_length=self.max_length, truncation=True)
        labels = self.tokenizer(response, return_tensors='pt', padding='max_length', max_length=self.max_length, truncation=True)['input_ids']
        return inputs['input_ids'].squeeze(0), labels.squeeze(0)


dataset = ScienceDataset(data, tokenizer,128)
dataloader = DataLoader(dataset, batch_size=64)

opt = torch.optim.Adam(model.parameters(), 3e-5)



for batch in dataloader:
    input_ids,  labels = batch
    input_ids = input_ids.cuda()
    labels = labels.cuda()

    loss = model(input_ids=input_ids, labels=labels).loss

    opt.zero_grad()
    loss.backward()
    opt.step()
    print("opt.step()")

    print(f"loss = {loss.item():.3f}")




You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
Jul 20 08:06:21.575 [[1m[34mINFO[0m] Make sure you follow the LLaMA's terms of use: https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1
Jul 20 08:06:21.577 [[1m[34mINFO[0m] Using DHT prefix: llama-65b-hf
Jul 20 08:08:02.807 [[1m[38;5;208mWARN[0m] [[1mpetals.client.sequential_autograd.sequential_forward:99[0m] Caught exception when running forward via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (12D3KooWQMAxYKSiSJG2v8UA9CNbXnF1QqcTFNvSttz4oDWUfU7K)>, start=60, end=80, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=4548.864555031427, public_name=None, version='2.0.0.post1', network_rps=4548.864555031427, forward_rps=4172

opt.step()
loss = 7.860


Jul 20 08:14:03.566 [[1m[38;5;208mWARN[0m] [[1mpetals.client.sequential_autograd.sequential_forward:99[0m] Caught exception when running forward via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (12D3KooWEEvDW234YEPrqqixvtAheUcX1XUB7qbUG4pDBHjF4cGd)>, start=60, end=80, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=3199.588996772061, public_name=None, version='2.0.0.post1', network_rps=3199.588996772061, forward_rps=1705405.2247371103, inference_rps=1445.9328046309631, adapters=('timdettmers/guanaco-65b',), torch_dtype='float16', quant_type='nf4', using_relay=True, cache_tokens_left=188416, next_pings={'12D3KooWDsAtwTpuFcTxNEqQ6gPS6Co94X3tAVBz8ouk1WwWraTm': 0.20764170682216934, '12D3KooWADE41EXTAPu3fKT4QQ9Bfyb2oKgNy17hPR7atv1MJvh5': inf, '12D3KooWMtENSTvyATqWQr5qT98QPznejak5tpts9NuFk8R5LubV': inf, '12D3KooWREsBjWgF9q6uRDkLYuHqaESYudDFU5GjmFCgzhD1Ewvv': 0.18585919878458657, '12D3KooWNC2tcQduyiEkfuVncqMHhe9aBDFtWU5QJ65sXjEycmof': 0.21729072028642993, '12D3KooWKvy8sj3vh

opt.step()
loss = 6.906


Jul 20 08:18:21.549 [[1m[38;5;208mWARN[0m] [[1mpetals.client.sequential_autograd.sequential_forward:99[0m] Caught exception when running forward via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (12D3KooWEEvDW234YEPrqqixvtAheUcX1XUB7qbUG4pDBHjF4cGd)>, start=40, end=80, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=3199.588996772061, public_name=None, version='2.0.0.post1', network_rps=3199.588996772061, forward_rps=1705405.2247371103, inference_rps=1445.9328046309631, adapters=('timdettmers/guanaco-65b',), torch_dtype='float16', quant_type='nf4', using_relay=True, cache_tokens_left=188416, next_pings={'12D3KooWDsAtwTpuFcTxNEqQ6gPS6Co94X3tAVBz8ouk1WwWraTm': 0.20257639445788692, '12D3KooWADE41EXTAPu3fKT4QQ9Bfyb2oKgNy17hPR7atv1MJvh5': inf, '12D3KooWMtENSTvyATqWQr5qT98QPznejak5tpts9NuFk8R5LubV': inf, '12D3KooWREsBjWgF9q6uRDkLYuHqaESYudDFU5GjmFCgzhD1Ewvv': 0.1857747113418206, '12D3KooWKvy8sj3vhPT8Y6LWYoDXP3CnENFqESozJnCD65Hn9RWf': 0.4621906740774412, '12D3KooWAZpDJGfcNKy

opt.step()
loss = 7.087


Jul 20 08:23:26.528 [[1m[38;5;208mWARN[0m] [[1mpetals.client.sequential_autograd.sequential_forward:99[0m] Caught exception when running forward via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (12D3KooWEEvDW234YEPrqqixvtAheUcX1XUB7qbUG4pDBHjF4cGd)>, start=40, end=80, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=3199.588996772061, public_name=None, version='2.0.0.post1', network_rps=3199.588996772061, forward_rps=1705405.2247371103, inference_rps=1445.9328046309631, adapters=('timdettmers/guanaco-65b',), torch_dtype='float16', quant_type='nf4', using_relay=True, cache_tokens_left=188416, next_pings={'12D3KooWDsAtwTpuFcTxNEqQ6gPS6Co94X3tAVBz8ouk1WwWraTm': 0.19895373229309005, '12D3KooWKvy8sj3vhPT8Y6LWYoDXP3CnENFqESozJnCD65Hn9RWf': 0.42608014706185643, '12D3KooWAZpDJGfcNKyWLJUmuaRfTB54Mf51a8DCiujYTcNMC929': 0.1869493333637822, '12D3KooWQMAxYKSiSJG2v8UA9CNbXnF1QqcTFNvSttz4oDWUfU7K': inf, '12D3KooWEroZqPrvnaAGKCJvsPA3RC69wZ5vh6ktjN8ospYezmMc': 0.5996591610300006, '12D

opt.step()
loss = 7.070


Jul 20 08:28:13.473 [[1m[38;5;208mWARN[0m] [[1mpetals.client.sequential_autograd.sequential_forward:99[0m] Caught exception when running forward via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (12D3KooWEroZqPrvnaAGKCJvsPA3RC69wZ5vh6ktjN8ospYezmMc)>, start=60, end=80, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=709.5161473280316, public_name=None, version='1.2.0.dev3', network_rps=709.5161473280316, forward_rps=37312.611700620575, inference_rps=419.00907164767057, adapters=('timdettmers/guanaco-65b',), torch_dtype='float16', quant_type='nf4', using_relay=False, cache_tokens_left=163840, next_pings={'12D3KooWEEvDW234YEPrqqixvtAheUcX1XUB7qbUG4pDBHjF4cGd': inf, '12D3KooWDsAtwTpuFcTxNEqQ6gPS6Co94X3tAVBz8ouk1WwWraTm': 0.25590827291194373, '12D3KooWKJvVBGa2wkzyvmCD7WeGynexPCho8emgcd7HPk4HQeGd': inf, '12D3KooWG36YZfwnAkiXcu5xBY86EHyD1S2Y2uBfNf9obK1GwPar': 0.27018776880856965, '12D3KooWE2jTj6L2tjhYcZLfAAkER9ns2JT2F5mp81b4Kn33a5gy': 0.03899554543601991, '12D3KooWSxsRrX3e5

opt.step()
loss = 6.907


Jul 20 08:31:14.160 [[1m[38;5;208mWARN[0m] [[1mpetals.client.sequential_autograd.sequential_forward:99[0m] Caught exception when running forward via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (12D3KooWQMAxYKSiSJG2v8UA9CNbXnF1QqcTFNvSttz4oDWUfU7K)>, start=60, end=80, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=4548.864555031427, public_name=None, version='2.0.0.post1', network_rps=4548.864555031427, forward_rps=417297.91126015485, inference_rps=207.13909703531036, adapters=('timdettmers/guanaco-65b',), torch_dtype='float16', quant_type='nf4', using_relay=True, cache_tokens_left=114688, next_pings={'12D3KooWG36YZfwnAkiXcu5xBY86EHyD1S2Y2uBfNf9obK1GwPar': 0.07016981866347391, '12D3KooWMtENSTvyATqWQr5qT98QPznejak5tpts9NuFk8R5LubV': inf, '12D3KooWKvy8sj3vhPT8Y6LWYoDXP3CnENFqESozJnCD65Hn9RWf': 0.1408073555669488, '12D3KooWEEvDW234YEPrqqixvtAheUcX1XUB7qbUG4pDBHjF4cGd': inf, '12D3KooWADE41EXTAPu3fKT4QQ9Bfyb2oKgNy17hPR7atv1MJvh5': inf, '12D3KooWE2jTj6L2tjhYcZLfAAkER9ns2J

opt.step()
loss = 7.053


Jul 20 08:40:09.271 [[1m[38;5;208mWARN[0m] [[1mpetals.client.sequential_autograd.sequential_backward:176[0m] Caught exception when running backward via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (12D3KooWKJvVBGa2wkzyvmCD7WeGynexPCho8emgcd7HPk4HQeGd)>, start=60, end=80, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=343.26418438019596, public_name='http://michaelkammes.com', version='2.0.0.post1', network_rps=343.26418438019596, forward_rps=324313.0493776356, inference_rps=435.8476915955684, adapters=('timdettmers/guanaco-65b',), torch_dtype='float16', quant_type='nf4', using_relay=True, cache_tokens_left=90112, next_pings={'12D3KooWQMAxYKSiSJG2v8UA9CNbXnF1QqcTFNvSttz4oDWUfU7K': inf, '12D3KooWE2jTj6L2tjhYcZLfAAkER9ns2JT2F5mp81b4Kn33a5gy': 0.33770354577532874, '12D3KooWADE41EXTAPu3fKT4QQ9Bfyb2oKgNy17hPR7atv1MJvh5': inf, '12D3KooWG36YZfwnAkiXcu5xBY86EHyD1S2Y2uBfNf9obK1GwPar': 0.13452703817900927, '12D3KooWDsAtwTpuFcTxNEqQ6gPS6Co94X3tAVBz8ouk1WwWraTm': 0.1049744585746

opt.step()
loss = 6.658


Jul 20 08:44:50.128 [[1m[38;5;208mWARN[0m] [[1mpetals.client.sequential_autograd.sequential_forward:99[0m] Caught exception when running forward via RemoteSpanInfo(peer_id=<libp2p.peer.id.ID (12D3KooWEEvDW234YEPrqqixvtAheUcX1XUB7qbUG4pDBHjF4cGd)>, start=40, end=80, server_info=ServerInfo(state=<ServerState.ONLINE: 2>, throughput=3199.588996772061, public_name=None, version='2.0.0.post1', network_rps=3199.588996772061, forward_rps=1705405.2247371103, inference_rps=1445.9328046309631, adapters=('timdettmers/guanaco-65b',), torch_dtype='float16', quant_type='nf4', using_relay=True, cache_tokens_left=188416, next_pings={'12D3KooWKJvVBGa2wkzyvmCD7WeGynexPCho8emgcd7HPk4HQeGd': inf, '12D3KooWE2jTj6L2tjhYcZLfAAkER9ns2JT2F5mp81b4Kn33a5gy': 0.43045059517291423, '12D3KooWA8rYqs2SbC5yAHZBr4PqZR4dFgg1idUj54VqGeShGe7d': 1.826019463800185, '12D3KooWMVzqkaDLqroKhVawrAhTxwmJGWP6FPFF9Gu2P1NNZJCY': inf, '12D3KooWG36YZfwnAkiXcu5xBY86EHyD1S2Y2uBfNf9obK1GwPar': 0.3270611594983661, '12D3KooWHUXmh37ZYu8K

opt.step()
loss = 6.592


In [17]:
inputs = tokenizer("How do airplanes fly?", return_tensors="pt")["input_ids"].cuda()
outputs = model.generate(inputs, max_new_tokens=15)
print("generated:", tokenizer.decode(outputs[0]))

Jul 20 08:55:03.891 [[1m[34mINFO[0m] Route found: 0:40 via …D1Ewvv => 40:80 via …NMC929


generated: <s> How do airplanes fly?
Airplanes fly by using the lift generated by their wings.


In [23]:
inputs = tokenizer("Why is the sky blue? Be consise and give straight answer", return_tensors="pt")["input_ids"].cuda()
outputs = model.generate(inputs, max_new_tokens=50)
print("generated:", tokenizer.decode(outputs[0]))

Jul 20 08:58:42.721 [[1m[34mINFO[0m] Route found: 0:40 via …D1Ewvv => 40:80 via …NMC929


generated: <s> Why is the sky blue? Be consise and give straight answer.
Asked by: Ajay
The sky is blue because of the way the atmosphere scatters light. The atmosphere is made up of molecules of gas, and light hitting these molecules can be scattered in various directions.


In [24]:
inputs = tokenizer("What protects the lungs?  Be consise and give straight answer", return_tensors="pt")["input_ids"].cuda()
outputs = model.generate(inputs, max_new_tokens=50)
print("generated:", tokenizer.decode(outputs[0]))

Jul 20 08:59:22.698 [[1m[34mINFO[0m] Route found: 0:40 via …D1Ewvv => 40:80 via …NMC929


generated: <s> What protects the lungs?  Be consise and give straight answer.
The lungs are protected by the rib cage.
The lungs are protected by the rib cage. The rib cage is a bony structure that protects the lungs.</s>


In [12]:
# import torch
# from transformers import AutoTokenizer
# from petals import AutoDistributedModelForCausalLM
# from torch.utils.data import Dataset, DataLoader
# import torch.nn as nn
# import torch.nn.functional as F

# model_name="enoch/llama-65b-hf"

# tokenizer = AutoTokenizer.from_pretrained(model_name)

# tokenizer.pad_token = tokenizer.eos_token


# model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
# model = model.cuda()

# class LLMBasedGenerator(nn.Module):
#     def __init__(self, model):
#         super().__init__()
#         self.distributed_layers = model.transformer.h
#         self.adapter = nn.Sequential(nn.Linear(model.config.hidden_size, 64), nn.Linear(64, model.config.hidden_size))
#         self.head = model.lm_head

#     def forward(self, embeddings):

#         mid_block = len(self.distributed_layers) // 2
#         hidden_states = self.distributed_layers[:mid_block](embeddings)

#         hidden_states = hidden_states.to(self.adapter[0].weight.dtype)


#         hidden_states = self.adapter(hidden_states)
#         hidden_states = self.distributed_layers[mid_block:](hidden_states)
#         return self.head(hidden_states)

# generator = LLMBasedGenerator(model).cuda()

# class ScienceDataset(Dataset):
#     def __init__(self, df, tokenizer, max_length):
#         self.df = df
#         self.tokenizer = tokenizer
#         self.max_length = max_length

#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):
#         prompt = self.df.iloc[idx]['prompt']
#         response = self.df.iloc[idx]['response']
#         inputs = self.tokenizer(prompt, return_tensors='pt', padding='max_length', max_length=self.max_length, truncation=True)
#         labels = self.tokenizer(response, return_tensors='pt', padding='max_length', max_length=self.max_length, truncation=True)['input_ids']
#         return inputs['input_ids'].squeeze(0), labels.squeeze(0)


# dataset = ScienceDataset(data, tokenizer,256)
# dataloader = DataLoader(dataset, batch_size=8)

# opt = torch.optim.Adam(model.parameters(), 3e-5)

# criterion = nn.CrossEntropyLoss()
# for epoch in range(3):
#     for batch in dataloader:
#         input_ids,  labels = batch
#         input_ids = input_ids.cuda()
#         labels = labels.cuda()

#         input_ids=model.transformer.word_embeddings(input_ids)
#         outputs = generator(input_ids)
#         loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
#         print(f"loss = {loss.item():.3f}")
#         opt.zero_grad()
#         loss.backward()
#         opt.step()

# # print('predicted:', generator(inputs).argmax(-1))  #
