In [22]:
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import random
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

# Geração de labels aleatórios
def get_random_label(description, animals):
    words = description.split()
    mentioned_animals = [animal for animal in animals if animal in words]
    if not mentioned_animals:
        return None
    chosen_animal = random.choice(mentioned_animals)
    label = label_encoder.transform([chosen_animal])[0]
    return label


# Dataset personalizado
class AnimalDataset(Dataset):
    def __init__(self, descriptions, labels, tokenizer, max_len=128):
        self.descriptions = descriptions
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.descriptions)

    def __getitem__(self, index):
        description = self.descriptions[index].strip()
        label = self.labels[index]
        inputs = self.tokenizer(
            description,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [16]:
# Load data
with open('data/animal_descriptions.txt', 'r') as file:
    descriptions = file.readlines()

# Lista de animais
animals = ["dog", "horse", "elephant", "butterfly", "chicken", "cat", "cow", "sheep", "spider", "squirrel"]

# Inicializar o LabelEncoder
label_encoder = LabelEncoder()
label_encoder.fit(animals)

# Aplicar a função para gerar labels aleatórios
random_labels = [get_random_label(desc, animals) for desc in descriptions]
filtered_descriptions = [desc for desc, label in zip(descriptions, random_labels) if label is not None]
filtered_labels = [label for label in random_labels if label is not None]


# Configuração do modelo BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=len(animals)  # Número de classes
)

# Configuração do DataLoader e Otimizador
dataset = AnimalDataset(filtered_descriptions, filtered_labels, tokenizer)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
optimizer = AdamW(model.parameters(), lr=2e-5)

# Definir a função de perda
loss_fn = torch.nn.CrossEntropyLoss()

# Definir o dispositivo (GPU ou CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

loss_arr = []

# Loop de Treinamento
model.train()
for epoch in range(10):  # Número de épocas
    for batch in tqdm(dataloader):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(outputs.logits, labels)

        loss.backward()
        optimizer.step()
        loss_arr.append(loss.item())

    print(f"Epoch {epoch + 1} completed with loss {loss.item()}")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 63/63 [00:10<00:00,  5.79it/s]


Epoch 1 completed with loss 2.0102224349975586


100%|██████████| 63/63 [00:10<00:00,  5.75it/s]


Epoch 2 completed with loss 1.1343129873275757


100%|██████████| 63/63 [00:10<00:00,  5.80it/s]


Epoch 3 completed with loss 1.1851214170455933


100%|██████████| 63/63 [00:10<00:00,  5.86it/s]


Epoch 4 completed with loss 1.1384773254394531


100%|██████████| 63/63 [00:10<00:00,  5.88it/s]


Epoch 5 completed with loss 1.3271236419677734


100%|██████████| 63/63 [00:10<00:00,  5.89it/s]


Epoch 6 completed with loss 1.40555739402771


100%|██████████| 63/63 [00:10<00:00,  5.90it/s]


Epoch 7 completed with loss 0.9634329080581665


100%|██████████| 63/63 [00:10<00:00,  5.88it/s]


Epoch 8 completed with loss 0.9415532350540161


100%|██████████| 63/63 [00:10<00:00,  5.87it/s]


Epoch 9 completed with loss 0.6611301898956299


100%|██████████| 63/63 [00:10<00:00,  5.84it/s]

Epoch 10 completed with loss 1.4666528701782227





In [17]:
def print_animal_probabilities(model, description, tokenizer, label_encoder):
    model.eval()
    inputs = tokenizer(description, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = F.softmax(logits, dim=-1)  # Aplicar softmax para obter as probabilidades

    # Converter as probabilidades para um array no CPU
    probabilities = probabilities.cpu().numpy().flatten()

    # Obter os nomes das classes usando o LabelEncoder
    class_names = label_encoder.inverse_transform(range(len(probabilities)))

    # Associar probabilidades com os nomes das classes corretamente
    animal_probabilities = {class_name: f"{prob*100:.2f}%" for class_name, prob in zip(class_names, probabilities)}

    # Imprimir as probabilidades para cada animal
    print(f"Description: {description}")
    print("Probabilities:")
    for animal, prob in animal_probabilities.items():
        print(f"{animal}: {prob}")

    return animal_probabilities

In [25]:
# Exemplo de frase para teste
test_description = "A butterfly flying over a dog"
animal_probabilities = print_animal_probabilities(model, test_description, tokenizer, label_encoder)

Description: A butterfly flying over a dog
Probabilities:
butterfly: 52.07%
cat: 0.94%
chicken: 1.56%
cow: 0.72%
dog: 35.96%
elephant: 2.40%
horse: 2.27%
sheep: 1.22%
spider: 1.27%
squirrel: 1.60%


In [24]:
torch.save(model.state_dict(), 'model/model.pt')