# Exploring transformer.Trainer API

In [1]:
from transformers import Trainer, TrainingArguments


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# train a simple mlp using the Trainer API

import torch
from torch import nn

# MLP for classification of 3D points
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(3, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
# create a simple dataset
import numpy as np
from torch.utils.data import Dataset, DataLoader

class PointCloudDataset(Dataset):
    def __init__(self, num_points):
        self.points = np.random.rand(num_points, 3)
        self.labels = np.random.randint(0, 2, num_points)
        
    def __len__(self):
        return len(self.points)
    
    def __getitem__(self, idx):

        return {'x': torch.tensor(self.points[idx], dtype=torch.float), 'y': torch.tensor(self.labels[idx], dtype=torch.float)}
    
# create the dataset and dataloader
dataset = PointCloudDataset(1000)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# create the model, loss function, and optimizer
model = MLP()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# train the model using the Trainer API
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=10,              # total number of training epochs
    per_device_train_batch_size=32,  # batch size per device during training
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=dataset,         # training dataset
)

trainer.train()



Step,Training Loss


TrainOutput(global_step=160, training_loss=-0.3117362976074219, metrics={'train_runtime': 2.8663, 'train_samples_per_second': 3488.846, 'train_steps_per_second': 55.822, 'total_flos': 0.0, 'train_loss': -0.3117362976074219, 'epoch': 10.0})