# 在Colab使用TPU Fine-tuning GPT模型

## 環境準備

In [None]:
# Install the Cloud TPU client package.
pip install cloud-tpu-client
# Download the PyTorch/XLA environment setup script and run it for version 1.12
# with the listed apt packages.
curl -O https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py
python env-setup.py --version 1.12 --apt-packages libomp5 libopenblas-dev
# NOTE(review): this install is unpinned and may replace the torch build set up
# by env-setup.py above — confirm the resulting torch/torch-xla versions match.
pip install torch torchvision
# NOTE(review): cloud-tpu-client was already installed by the first command;
# only torch-xla is new on this line.
pip install cloud-tpu-client torch-xla

## Fine-tuning 過程

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# Checkpoint identifier passed to both from_pretrained() calls below.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token. MyDataset tokenizes
# with padding="max_length", which needs a pad token to be defined
# (presumably this tokenizer ships without one — confirm).
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

class MyDataset(Dataset):
    """Pre-tokenized text samples for causal language-model fine-tuning.

    Every text is tokenized once in ``__init__``, truncated/padded to
    ``max_length``, and stored as three parallel lists of tensors.

    Args:
        txt_list: Iterable of raw text strings.
        tokenizer: Callable invoked as
            ``tokenizer(txt, truncation=True, max_length=..., padding="max_length")``
            that returns a mapping with ``'input_ids'`` and
            ``'attention_mask'`` entries.
        max_length: Length every sample is truncated or padded to.
    """

    # Label value that the loss skips, so padding positions contribute nothing.
    IGNORE_INDEX = -100

    def __init__(self, txt_list, tokenizer, max_length=512):
        self.input_ids = []
        self.attn_masks = []
        self.labels = []

        for txt in txt_list:
            encodings_dict = tokenizer(txt, truncation=True, max_length=max_length, padding="max_length")
            input_ids = torch.tensor(encodings_dict['input_ids'])
            attn_mask = torch.tensor(encodings_dict['attention_mask'])

            # Labels start as a copy of the inputs; positions the attention
            # mask marks as padding are replaced with IGNORE_INDEX so the
            # model is not trained to predict padding tokens.
            labels = input_ids.clone()
            labels[attn_mask == 0] = self.IGNORE_INDEX

            self.input_ids.append(input_ids)
            self.attn_masks.append(attn_mask)
            self.labels.append(labels)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of three tensors."""
        return {'input_ids': self.input_ids[idx], 'attention_mask': self.attn_masks[idx], 'labels': self.labels[idx]}

# Placeholder corpus: swap these sentences for the real training texts.
data_list = [
    "This is the first sample from Kenny.",
    "This is the second sample from Jessica.",
    "This is the third sample from Mason.",
]

# Tokenize the corpus once, then serve it in shuffled batches. With only three
# samples and batch_size=32, each epoch yields a single batch.
my_dataset = MyDataset(data_list, tokenizer)
train_loader = DataLoader(
    my_dataset,
    batch_size=32,
    shuffle=True,
)

def train_model():
    """Fine-tune the module-level ``model`` on this process's XLA device.

    Runs three epochs over the module-level ``train_loader`` using Adam
    (lr=1e-5) and prints the loss after every step. Once training is done, the
    master ordinal moves the model to the CPU and writes it to
    ``/content/drive/MyDrive/model_output`` with ``save_pretrained``; all
    other processes simply return.
    """
    xla_device = xm.xla_device()
    model.to(xla_device)
    model.train()
    adam = torch.optim.Adam(model.parameters(), lr=1e-5)

    batch_keys = ('input_ids', 'attention_mask', 'labels')
    for epoch in range(3):
        # A fresh per-device loader is wrapped around train_loader each epoch.
        device_batches = pl.ParallelLoader(train_loader, [xla_device]).per_device_loader(xla_device)
        for batch in device_batches:
            input_ids, attention_mask, labels = (
                batch[key].to(xla_device) for key in batch_keys
            )

            # Passing labels makes the model return the loss alongside its outputs.
            step_loss = model(input_ids, attention_mask=attention_mask, labels=labels).loss

            adam.zero_grad()
            step_loss.backward()

            xm.optimizer_step(adam)

            print(f"Epoch {epoch} | Loss: {step_loss.item()}")

            xm.mark_step()

    # Only the master ordinal persists the fine-tuned weights.
    if not xm.is_master_ordinal():
        return

    model.to('cpu')
    model.save_pretrained('/content/drive/MyDrive/model_output')

def _mp_fn(rank, flags):
    """Per-process entry point handed to ``xmp.spawn``.

    Args:
        rank: First argument supplied by the spawner; unused here
            (presumably the index of the spawned process — confirm).
        flags: Extra argument forwarded from the ``args`` tuple given to
            ``xmp.spawn``; unused here.
    """
    # Set the default tensor type to the CPU float tensor before training.
    torch.set_default_tensor_type('torch.FloatTensor')
    train_model()

# Forwarded to every worker as the second argument of _mp_fn; empty here.
FLAGS = {}
# Launch 8 worker processes with the 'fork' start method, each running _mp_fn.
# NOTE(review): 8 presumably matches the TPU core count of the runtime — confirm.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

## 實測Fine-tuning過後的模型

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Same directory that train_model() writes to with save_pretrained().
model_path = '/content/drive/MyDrive/model_output'
model = AutoModelForCausalLM.from_pretrained(model_path)

# The training code shown saves only the model, so the tokenizer is loaded
# from the stock "gpt2" checkpoint rather than from model_path.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Switch the model to evaluation mode before generating text.
model.eval()

In [None]:
# Prompt for the fine-tuned model to continue.
input_text = '''What is the capital of France?'''

# Calling the tokenizer directly (instead of .encode) also returns the
# attention mask that generate() expects.
encoded = tokenizer(input_text, return_tensors='pt')
input_ids = encoded['input_ids']

# The tokenizer loaded in this cell has no pad token configured, so give
# generate() the attention mask and an explicit pad_token_id rather than
# leaving it to infer both.
output = model.generate(
    input_ids,
    attention_mask=encoded['attention_mask'],
    max_length=150,
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode the single returned sequence, dropping special tokens such as EOS.
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)