In [4]:
import torch
from torch import nn
from transformers import Trainer, TrainingArguments, set_seed

# Hugging Face Trainer configuration for the CNN smoke test further down.
training_args = TrainingArguments(
    # NOTE(review): "outout" looks like a typo for "output"; kept unchanged so
    # existing checkpoints stay where they already are.
    output_dir="outout",
    per_device_train_batch_size=8,
    num_train_epochs=1,  # number of training epochs
    # fp16 mixed precision needs a CUDA device; enabling it unconditionally makes
    # TrainingArguments/Trainer raise on CPU-only machines. On a GPU this is True,
    # exactly as before.
    fp16=torch.cuda.is_available(),
    # NOTE(review): with the 60-sample dataset below, batch size 8 and 1 epoch on a
    # single device there are only ~8 optimizer steps, so neither logging (every
    # 50 steps) nor checkpointing (every 200 steps) will ever trigger — confirm.
    save_steps=200,
    logging_steps=50,
    learning_rate=1e-5,
    weight_decay=1e-4,
    save_total_limit=2,
    # (Important) must be False: keep every dataset key instead of letting the
    # Trainer filter samples against the model's forward signature.
    remove_unused_columns=False,
)


class CNN(nn.Module):
    """Small two-stage convolutional network usable with the Hugging Face ``Trainer``.

    Layout: [conv3x3(32) -> ReLU -> maxpool2] -> [conv3x3(16) -> ReLU -> maxpool2]
    -> flatten -> Linear(2048) -> ReLU -> Linear(num_labels).

    The loss selection mirrors ``BertForSequenceClassification``: mean-squared
    error when ``num_labels == 1`` (regression), cross-entropy otherwise.

    Args:
        in_channels: number of channels of the input images.
        num_labels: number of output logits (1 means regression).
        image_size: spatial size of the inputs, either an ``int`` for square
            images or a ``(height, width)`` pair. It determines the input width of
            ``fc_1``; the default 224 reproduces the previous ``16*56*56``.

    Raises:
        ValueError: if either spatial dimension is smaller than 4, which would
            leave no features after the two 2x poolings.
    """

    def __init__(self, in_channels, num_labels=1000, image_size=224):
        super().__init__()
        self.num_labels = num_labels
        self.conv1 = nn.Sequential(
            # stride 1 with padding = (kernel_size - 1) / 2 = 1 keeps H and W unchanged
            nn.Conv2d(in_channels=in_channels, out_channels=32,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=16,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        if height < 4 or width < 4:
            raise ValueError(f"image_size must be at least 4x4, got {image_size!r}")
        # Each MaxPool2d(2) floors the spatial size by half, so after two of them
        # the feature map is (size // 4) per side (224 -> 56).
        flat_features = 16 * (height // 4) * (width // 4)

        # Fully-connected head
        self.fc_1 = nn.Linear(flat_features, 2048)
        self.fc_2 = nn.Linear(2048, num_labels)

    def forward(self, inputs, labels=None):
        """Run the network and, when ``labels`` is given, compute the loss.

        Args:
            inputs: image batch, expected as (batch, in_channels, H, W) with H, W
                matching ``image_size`` (otherwise ``fc_1`` gets the wrong width).
            labels: optional targets; any shape that flattens to one value per
                sample, e.g. (batch,) or (batch, 1). Class indices for
                classification, numeric targets for regression.

        Returns:
            dict with ``loss`` (None when ``labels`` is None), ``logits`` of shape
            (batch, num_labels) and ``hidden_states`` (always None).
        """
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.flatten(x)
        x = self.relu(self.fc_1(x))
        logits = self.fc_2(x)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression: MSELoss needs floating-point targets, so integer
                # labels are cast to the logits' dtype.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.reshape(-1), labels.reshape(-1).to(logits.dtype))
            else:
                # Classification: CrossEntropyLoss needs int64 class indices;
                # e.g. numpy int32 labels would otherwise be rejected.
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.reshape(-1, self.num_labels), labels.reshape(-1).long())

        return {'loss': loss, 'logits': logits, 'hidden_states': None}

In [10]:
import numpy as np
# Synthetic smoke-test dataset: NUM_SAMPLES independent examples, each a dict whose
# keys match CNN.forward's parameters (`inputs`, `labels`). Seeded for reproducibility.
SEED = 0
NUM_SAMPLES = 60
rng = np.random.default_rng(SEED)

# Build a fresh dict per sample: `[{...}] * NUM_SAMPLES` would repeat ONE dict object
# (the same image and label) NUM_SAMPLES times.
data = [
    {
        'inputs': rng.standard_normal((3, 224, 224), dtype=np.float32),
        # length-1 int64 array of a class index in [0, 5).
        # NOTE(review): the model below is built with num_labels=6, so class 5 never
        # appears — confirm the intended number of classes.
        'labels': rng.integers(0, 5, size=1),
    }
    for _ in range(NUM_SAMPLES)
]

In [11]:
# Seed before building the model: the Trainer seeds (args.seed) only inside its own
# constructor, i.e. after the CNN weights below would already be randomly initialised.
set_seed(training_args.seed)
model = CNN(3, num_labels=6)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data,
    # NOTE(review): no evaluation strategy is set in training_args, so this dataset
    # is presumably never evaluated during train(); call trainer.evaluate() if needed.
    eval_dataset=data,
)
trainer.train()

Using cuda_amp half precision backend
