In [32]:
import os
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer
)

In [33]:
# Pretrained checkpoint name passed to from_pretrained for both model and tokenizer.
MODEL_NAME = 'distilbert-base-uncased'
# Number of outputs of the sequence-classification head.
NUM_LABELS = 6
# Token length used for truncation/padding when tokenizing inputs.
MAX_LENGTH = 128
# Directory that load_model reads the saved checkpoint from.
CHECKPOINT_DIR = './checkpoints'

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [44]:
# Build the tokenizer and the base model with a NUM_LABELS-way classification head.
# The head starts out freshly initialized (see the warning in the output); the
# checkpoint loaded later in this cell overwrites these weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
)


def load_model(model, optimizer, load_dir, device, filename='checkpoint_epoch_3.pt'):
    """Restore model (and optionally optimizer) state from a saved checkpoint.

    The checkpoint file must be a dict written by ``torch.save`` containing the
    keys ``'model_state_dict'``, ``'epoch'`` and ``'loss'``, plus
    ``'optimizer_state_dict'`` when ``optimizer`` is given.

    Args:
        model: ``torch.nn.Module`` whose weights are overwritten in place and
            which is then moved to ``device``.
        optimizer: optimizer built over ``model.parameters()`` whose state is
            restored, or ``None`` to restore the model only (e.g. inference).
        load_dir: directory holding the checkpoint file.
        device: target device, either a ``torch.device`` or a string such as
            ``'cpu'`` / ``'cuda'``.
        filename: checkpoint file name inside ``load_dir``; defaults to the
            epoch-3 checkpoint this notebook uses.

    Returns:
        Tuple ``(model, optimizer, epoch, loss)``; ``epoch`` and ``loss`` are
        read from the checkpoint.
    """
    # Normalize so that `.type` is available below whether a string or a
    # torch.device was passed in.
    if not isinstance(device, torch.device):
        device = torch.device(device)
    model_path = os.path.join(load_dir, filename)

    # map_location places the stored tensors on the target device while loading.
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Compare the device *type*: a torch.device never compares equal to a
        # plain string such as 'cpu', so comparing the object itself is useless.
        if device.type != 'cpu':
            # Make sure every tensor in the optimizer state lives on the target device.
            for state in optimizer.state.values():
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[key] = value.to(device)

    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    print(f"Model loaded from epoch {epoch} with loss: {loss:.4f}")
    return model, optimizer, epoch, loss
    
# load_model restores optimizer state too; here only the restored model is kept.
optimizer = torch.optim.AdamW(params=model.parameters())
model, _, _, _ = load_model(model=model, optimizer=optimizer, load_dir=CHECKPOINT_DIR, device=device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded from epoch 3 with loss: 0.1300


In [49]:
# Example predictions
# Maps a classifier output index (argmax over the logits) to its emotion name.
decode_labels = {
    0: 'sadness',
    1: 'joy',
    2: 'love',
    3: 'anger',
    4: 'fear',
    5: 'surprise',
}
assert len(decode_labels) == NUM_LABELS, 'label map must cover every classifier output'

sentences = [
    'Let\'s go play soccer!',
    'I hate you!',
    "You are hopeless!",
    "Weather is terrible, and I can't finish my homework. I'm so tired.",
    "I'm tired.",
    "Love you, my dear!"
]
tokens = tokenizer(sentences, padding='max_length', max_length=MAX_LENGTH, truncation=True, return_tensors='pt')
tokens = {k: v.to(device) for k, v in tokens.items()}

# Inference only: make sure dropout is off and skip building the autograd graph,
# so this cell behaves the same even if it is re-run after a training cell.
model.eval()
with torch.no_grad():
    outputs = model(**tokens)
preds = outputs.logits.argmax(dim=1).cpu()

for sentence, pred in zip(sentences, preds.tolist()):
    print(f"{sentence}: {decode_labels[pred]}")

Let's go play soccer!: joy
I hate you!: anger
You are hopeless!: sadness
Weather is terrible, and I can't finish my homework. I'm so tired.: sadness
I'm tired.: sadness
Love you, my dear!: joy
