# Fine-tuning com PyTorch Lightning

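Demonstra um fine-tuning rápido do `bert-base-uncased` para classificação de sentimento no IMDB (usando apenas ~1% do conjunto de treino), com PyTorch Lightning.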

In [ ]:
!pip install -q pytorch-lightning transformers datasets

In [ ]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pytorch_lightning as pl
import torch

# IMDB's 'train' split is stored ordered by label (negative reviews first), so
# a head slice such as 'train[:1%]' would contain a single class and the model
# would only ever see one label. Shuffle with a fixed seed (reproducible), then
# keep ~1% of the examples so both labels are represented.
train_ds = load_dataset('imdb', split='train').shuffle(seed=42)
train_ds = train_ds.select(range(len(train_ds) // 100))
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def collate(batch):
    """Build one padded model batch from a list of dataset examples.

    Each element of ``batch`` is indexed by ``'text'`` and ``'label'``. All
    texts are tokenized together (padded to the longest one, truncated,
    returned as PyTorch tensors), and the labels are stacked into a tensor
    stored under ``'labels'`` so the model returns a loss in ``training_step``.
    """
    texts = [example['text'] for example in batch]
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    targets = torch.tensor([example['label'] for example in batch])
    merged = dict(encoded)
    merged['labels'] = targets
    return merged

# shuffle=True makes the DataLoader draw examples in a new random order each
# epoch instead of always iterating in the dataset's stored order.
loader = torch.utils.data.DataLoader(
    train_ds, batch_size=8, shuffle=True, collate_fn=collate
)
# BERT with a 2-class sequence-classification head (num_labels=2).
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
class LitModel(pl.LightningModule):
    """LightningModule wrapping a Hugging Face sequence-classification model.

    Args:
        hf_model: model to train. Defaults to the notebook-level ``model`` so
            that ``LitModel()`` behaves exactly as before.
        lr: learning rate for AdamW (default 5e-5).
    """

    def __init__(self, hf_model=None, lr=5e-5):
        super().__init__()
        # Fall back to the global model for backward compatibility.
        self.model = model if hf_model is None else hf_model
        self.lr = lr

    def forward(self, **x):
        """Pass the keyword tensors straight through to the wrapped model."""
        return self.model(**x)

    def training_step(self, batch, batch_idx):
        """Run the model on ``batch`` and return the loss it computes.

        Relies on ``batch`` containing ``labels`` (as ``collate`` provides) so
        that ``out.loss`` is populated.
        """
        out = self(**batch)
        loss = out.loss
        # The Trainer is created with logger=False, so surface the loss on the
        # progress bar; otherwise the logged value is never displayed.
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at ``self.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

# One training epoch; no logger is attached and no checkpoints are saved.
trainer = pl.Trainer(
    enable_checkpointing=False,
    logger=False,
    max_epochs=1,
)
trainer.fit(model=LitModel(), train_dataloaders=loader)