In [2]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

class WavefunctionDataset(Dataset):
    """Map-style dataset of wavefunctions stored one per ``.npy`` file.

    Item ``idx`` is read from ``f"{file_prefix}{start_n + idx}.npy"``, so the
    dataset covers file numbers ``start_n`` (inclusive) up to ``end_n``
    (exclusive). Files are loaded lazily, on each ``__getitem__`` call.
    """

    def __init__(self, file_prefix, start_n, end_n):
        """Store the path prefix and the half-open file-number range.

        Args:
            file_prefix: Path prefix placed before the file number.
            start_n: Number of the first file (inclusive).
            end_n: Number one past the last file (exclusive).

        Raises:
            ValueError: If ``end_n`` is smaller than ``start_n``.
        """
        # Fail at construction instead of at the first ``len()`` call, where
        # Python would reject the negative length with a less obvious error.
        if end_n < start_n:
            raise ValueError(
                f"end_n ({end_n}) must not be smaller than start_n ({start_n})"
            )
        self.file_prefix = file_prefix
        self.start_n = start_n
        self.end_n = end_n

    def __len__(self):
        """Return the number of files in the configured range."""
        return self.end_n - self.start_n

    def __getitem__(self, idx):
        """Load item ``idx`` and return it as a ``float32`` tensor.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        size = len(self)
        # IndexError (rather than a FileNotFoundError from np.load) is what
        # the sequence protocol expects; it lets ``for x in dataset`` stop.
        if not -size <= idx < size:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {size}"
            )
        if idx < 0:
            idx += size
        file_path = f'{self.file_prefix}{self.start_n + idx}.npy'
        wavefunction = np.load(file_path)
        return torch.from_numpy(wavefunction).float()

# Set the file prefix and range.
# NOTE(review): end_n is exclusive in WavefunctionDataset, so with these
# values delta_0.npy .. delta_798.npy are used — confirm delta_799.npy is
# meant to be left out.
file_prefix = '/user/as6154/dissert/half_l12_data/delta_'
start_n = 0
end_n = 799

# Create the dataset
dataset = WavefunctionDataset(file_prefix, start_n, end_n)

# Create the DataLoader
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Load the MPT model and tokenizer.
# The remote modeling code pulled in by trust_remote_code=True needs the
# `flash_attn` package; without it from_pretrained raises ImportError.
model_name = "mosaicml/mpt-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # padding=True needs a pad token; reuse EOS when the tokenizer has none.
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

# Run on the GPU when one is available; inputs are moved to match below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Fine-tune the MPT model on the wavefunction data (a single pass)
model.train()
for batch in dataloader:
    # Tokenize the wavefunction data: each tensor is rendered with str().
    # NOTE(review): str(tensor) yields "tensor([...])" and, under torch's
    # default print options, summarizes tensors of more than 1000 elements
    # with "..." — confirm this text form is what the model should see.
    inputs = tokenizer(
        [str(item) for item in batch],
        return_tensors="pt",
        padding=True,
        truncation=True,  # keep sequences within the tokenizer's max length
    ).to(device)

    # Use the input ids as targets, but mark padded positions with -100 so
    # the loss ignores them instead of training on padding.
    labels = inputs["input_ids"].masked_fill(inputs["attention_mask"] == 0, -100)

    # Forward pass
    outputs = model(**inputs, labels=labels)

    # Backward pass and optimization step
    loss = outputs.loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Save the fine-tuned model, plus the tokenizer so the output directory can
# be reloaded on its own.
model.save_pretrained("fine_tuned_mpt")
tokenizer.save_pretrained("fine_tuned_mpt")

configuration_mpt.py:   0%|          | 0.00/16.4k [00:00<?, ?B/s]

attention.py:   0%|          | 0.00/24.6k [00:00<?, ?B/s]

flash_attn_triton.py:   0%|          | 0.00/28.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- flash_attn_triton.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


norm.py:   0%|          | 0.00/3.12k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- norm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


fc.py:   0%|          | 0.00/167 [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- fc.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- attention.py
- flash_attn_triton.py
- norm.py
- fc.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


ffn.py:   0%|          | 0.00/5.22k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- ffn.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.




A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


blocks.py:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- blocks.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- configuration_mpt.py
- attention.py
- ffn.py
- blocks.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_mpt.py:   0%|          | 0.00/32.4k [00:00<?, ?B/s]

ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn. Run `pip install flash_attn`