In [None]:
!pip3 install transformers -q
!pip3 install datasets evaluate scikit-learn sentencepiece langchain torch_xla[tpuvm] -q
!pip uninstall tensorflow -y #that's the meme part

In [None]:
import os
import pandas as pd
import numpy as np
import datasets
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
from langchain.prompts import PromptTemplate
from datasets import Dataset
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from tqdm import tqdm
from transformers import logging as hf_logging

# Remove the TPU environment variables if they are set. A default is passed to
# pop() so the cell neither fails with KeyError when a variable is absent nor
# breaks when it is executed a second time in the same kernel.
# NOTE(review): presumably these are cleared so torch_xla does not pick up a
# multi-host TPU configuration -- confirm.
os.environ.pop('TPU_PROCESS_ADDRESSES', None)
os.environ.pop('CLOUD_TPU_TASK_ID', None)

# Only show errors from the transformers logger.
hf_logging.set_verbosity_error()

# NOTE(review): MAX_INPUT is not read anywhere in the visible cells, and
# FLAGS['MAX_INPUT'] is set to a different value (64) -- confirm which applies.
MAX_INPUT=128
# Local path of the Llama-2-7b checkpoint.
MODEL = "/kaggle/input/llama2-7b-hf/Llama2-7b-hf"

In [None]:
# Instruction-style prompt for a five-option multiple-choice question. The
# {context}, {prompt} and {a}..{e} placeholders are filled in per example.
# The string is not raw, so every "\n" escape below is a real newline in
# addition to the line break that follows it.
# The template stops at "### Response:" and does not contain the answer: during
# first tests the answers were provided to this template, which caused a data leak.
template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction: You will be given question with 5 possible answers. Answer the following multiple choice question by giving the most appropriate response. Answer should be one among [A, B, C, D, E]

### Context: {context}\n

### Question: {prompt}\n
A) {a}\n
B) {b}\n
C) {c}\n
D) {d}\n
E) {e}\n

### Response:
"""

# Prompt object with the seven placeholders declared explicitly; it is later
# used through prompt.format(...).
prompt = PromptTemplate(template=template, input_variables=['prompt', 'a', 'b', 'c', 'd', 'e', 'context'])
# --- Training split ---------------------------------------------------------
# Read the CSV and cap every context at 200 characters (done on train_df itself),
# then derive `df` without the leftover index column and without rows that
# contain missing values, renumbered from zero.
train_df = pd.read_csv('/kaggle/input/corrected-context-ds/complete_context_dataset_corrected.csv')
train_df['context'] = train_df['context'].str.slice(stop=200)
df = (
    train_df
    .drop(columns=['Unnamed: 0'])
    .dropna()
    .reset_index(drop=True)
)
data = Dataset.from_pandas(df)

# --- Validation split -------------------------------------------------------
# Same 200-character cap on the context; no rows or columns are dropped here.
df_val = pd.read_csv("/kaggle/input/60k-data-with-context-v2/train_with_context2.csv")
df_val['context'] = df_val['context'].str.slice(stop=200)
data_val = Dataset.from_pandas(df_val)

def plot_sequence_lengths(data, max_length=1024): #filter abnormally long samples
    """Return the positions of the examples that are short enough to keep.

    Despite its name this function draws nothing. Each example is measured by
    the combined character count of its 'prompt' and 'context' fields, and the
    positions whose count is strictly below `max_length` are returned.

    Args:
        data: Iterable of mappings that each have 'prompt' and 'context' entries.
        max_length: Exclusive upper bound on len(prompt) + len(context),
            counted in characters (not tokens).

    Returns:
        List of integer positions, in iteration order, of the examples to keep.
    """
    def _char_count(value):
        # A missing field (None) is counted as empty instead of raising
        # TypeError from len(None); such a row cannot be "too long".
        return 0 if value is None else len(value)

    return [
        position
        for position, example in enumerate(data)
        if _char_count(example['prompt']) + _char_count(example['context']) < max_length
    ]

# Work out, for each split, which rows pass the length filter
# (default limit of plot_sequence_lengths).
keep_indices_train = plot_sequence_lengths(data)
keep_indices_val = plot_sequence_lengths(data_val)

# Narrow both datasets down to the rows that passed.
data = data.select(keep_indices_train)
data_val = data_val.select(keep_indices_val)

# Show the training DataFrame as the cell output (the pandas frame, not the
# filtered Dataset).
df

In [None]:
def format_text(example):
    """Render one dataset row through the module-level `prompt` template.

    Args:
        example: Mapping with the keys 'prompt', 'A', 'B', 'C', 'D', 'E'
            and 'context'.

    Returns:
        A dict with a single 'text' entry holding the filled-in prompt.
    """
    # Template variables are the lower-cased option letters plus prompt/context.
    fields = {'prompt': example['prompt']}
    fields.update((letter.lower(), example[letter]) for letter in 'ABCDE')
    fields['context'] = example['context']
    return {"text": prompt.format(**fields)}

# Add the rendered "text" column to both splits.
data = data.map(format_text, num_proc = 56)
data_val = data_val.map(format_text, num_proc = 4)

# Load the tokenizer from the MODEL constant instead of repeating the path
# literal, so the tokenizer checkpoint has a single source of truth.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
# Use the end-of-sequence token for padding.
# NOTE(review): presumably the checkpoint's tokenizer defines no pad token -- confirm.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Print the first rendered prompt as a sanity check. Indexing the row first
# avoids loading the whole "text" column just to look at one element.
print(data[0]['text'])

In [None]:
# Training configuration handed to every worker process.
FLAGS = {'MAX_INPUT': 64,       # NOTE(review): not read by the visible cells; tokenization hard-codes its own length
         'LOGGING_STEPS': 10,   # log the loss every this many steps
         'NUM_EPOCHS': 2,
         'BATCH_SIZE': 8,       # per-process DataLoader batch size
         # Number of training examples (despite the key's name). len(data) gives
         # the row count without materialising the whole "text" column.
         'NUM_STEPS': len(data)}

In [None]:
def preprocess_function(example, max_length=64):
    """Tokenize one formatted example into fixed-length causal-LM features.

    The prompt and its answer are joined into ONE sequence (prompt tokens, then
    answer tokens, then EOS), because a causal LM scores `label` position by
    position against `input_ids`. Tokenizing the answer on its own and padding
    it, as was done before, produced labels unrelated to the inputs and also
    put loss on the padding.

    Label positions belonging to the prompt or to padding are set to -100, the
    value the loss ignores, so only the answer (and EOS) is learned. The
    attention mask is returned as well so padding is not attended to.

    When the prompt does not fit, its BEGINNING is dropped (including any
    leading special token): the end of the prompt holds the answer options and
    the "### Response:" cue.
    NOTE(review): with the default 64-token budget the fixed instruction header
    of the template probably does not fit either way -- consider a larger
    `max_length`.

    Args:
        example: Mapping with a 'text' entry (formatted prompt) and an 'answer'
            entry, both strings.
        max_length: Exact length of every returned list, in tokens.

    Returns:
        Dict with 'input_ids', 'attention_mask' and 'label', each a list of
        `max_length` ints.
    """
    # Answer without automatically added special tokens, then an explicit EOS.
    answer_ids = tokenizer(example["answer"], add_special_tokens=False).input_ids
    answer_ids = (answer_ids + [tokenizer.eos_token_id])[:max_length]

    # Prompt with the tokenizer's default special tokens, cut from the left so
    # that the answer always fits behind it.
    prompt_ids = tokenizer(example["text"]).input_ids
    room = max_length - len(answer_ids)
    prompt_ids = prompt_ids[-room:] if room > 0 else []

    input_ids = prompt_ids + answer_ids
    labels = [-100] * len(prompt_ids) + answer_ids
    attention_mask = [1] * len(input_ids)

    # Right-pad everything to the fixed length.
    pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [tokenizer.pad_token_id] * pad,
        "attention_mask": attention_mask + [0] * pad,
        "label": labels + [-100] * pad,
    }

# Tokenize every training example one at a time, then drop the raw columns so
# that only the features produced by preprocess_function remain.
data_train = (
    data
    .map(preprocess_function, num_proc=56, batched=False)
    .remove_columns(['prompt', 'context', 'A', 'B', 'C', 'D', 'E', 'answer', 'text'])
)

In [None]:
# "!export ..." runs in a throwaway subshell, so the variable never reached the
# kernel. Setting it in os.environ makes it visible to this process and to any
# worker processes forked from it.
os.environ['XLA_USE_BF16'] = '1'
def train(index, FLAGS):
    """Fine-tune the last parameters of the Llama checkpoint in one worker process.

    Intended as the target of xmp.spawn. It reads the module-level `data_train`,
    `tokenizer` and `MODEL`, trains for FLAGS['NUM_EPOCHS'] epochs and finally
    saves the state dict to "tpu-llama.bin".

    Args:
        index: Process index passed in by the launcher; only used in a log line.
        FLAGS: Dict providing 'BATCH_SIZE', 'NUM_EPOCHS' and 'LOGGING_STEPS'.
    """
    device = xm.xla_device()
    # Ask the runtime how many replicas take part instead of assuming 8.
    num_replicas = xm.xrt_world_size()

    # this guy is responsible for distributing data across the cores
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        data_train,
        num_replicas=num_replicas,
        rank=xm.get_ordinal(),
        shuffle=True,
    )
    training_loader = torch.utils.data.DataLoader(
        data_train,
        batch_size=FLAGS['BATCH_SIZE'],
        collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
        sampler=train_sampler,
    )
    xla_train_loader = pl.MpDeviceLoader(training_loader, device)

    model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)

    # Freezing most of the layers: parameter tensors are counted from 1 in
    # model.parameters() order and only those from this position on are trained.
    first_trainable = 285
    for position, param in enumerate(model.parameters(), start=1):
        param.requires_grad = position >= first_trainable

    model.train()

    # Hand the optimizer only the tensors that are actually being trained.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=1e-5)

    # Steps this process really performs per epoch (the loader's own length).
    num_iterations = len(training_loader)
    # scheduler.step() runs once per optimizer step, so the decay horizon is the
    # total number of steps of the run. The earlier NUM_STEPS * BATCH_SIZE was
    # far larger than that and left the learning rate practically constant.
    total_steps = max(1, num_iterations * FLAGS['NUM_EPOCHS'])
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)
    print(f"Initiliazed model and datasets on core {index}")

    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        # Without set_epoch the sampler reuses the same shuffling order every epoch.
        train_sampler.set_epoch(epoch)
        for step, batch in enumerate(xla_train_loader):
            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            xm.optimizer_step(optimizer)
            if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                xm.master_print(f'Loss: {loss.item()}, {step + 1} steps out of {num_iterations}, LR: {optimizer.param_groups[0]["lr"]}')
            scheduler.step()
        xm.master_print(f"Trained for {epoch} epochs out of {FLAGS['NUM_EPOCHS']}")

    xm.master_print("Waiting for all processes across cores to finish")
    xm.rendezvous('init')
    xm.master_print("Saving the model")
    xm.save(model.state_dict(), "tpu-llama.bin")
# Launch `train` in forked worker processes; FLAGS is passed as the extra
# positional argument after the process index.
xmp.spawn(
    train,
    args=(FLAGS,),
    start_method='fork',
)