In [9]:
import torch
import csv
import pandas
import transformers
import torch.nn as nn
from sklearn.model_selection import train_test_split

In [11]:
# Read the lyrics dataset once; `data` and `df` are two names for the same frame.
df = pandas.read_csv('song_lyrics.csv')
data = df
# Convenience handles for the `lyrics` and `label` columns.
lyrics, labels = df['lyrics'], df['label']

In [12]:
# Hold out 20% of the rows for testing; the fixed seed keeps the split repeatable.
train_df, test_df = train_test_split(
    df,
    random_state=42,
    test_size=0.2,
)

In [13]:
class BERTClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Args:
        bert_path: Model name or path handed to
            ``transformers.BertModel.from_pretrained``.
        num_classes: Number of logits produced by the linear head.
        dropout_prob: Dropout probability applied to the pooled encoder
            output. Defaults to 0.1, the value that used to be hard-coded.
    """

    def __init__(self, bert_path, num_classes, dropout_prob=0.1):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(bert_path)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Compute classification logits for a batch.

        Args:
            input_ids: Token-id tensor forwarded to the encoder.
            attention_mask: Attention-mask tensor forwarded to the encoder.

        Returns:
            The output of the linear head applied to the (dropped-out)
            pooled encoder output.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Index rather than tuple-unpack: recent transformers versions return a
        # dict-like ModelOutput whose iteration yields field *names*, so
        # `_, pooled = outputs` would bind strings instead of tensors.
        # Integer indexing works for both the ModelOutput and legacy tuple forms;
        # position 1 is the pooled output.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)

In [15]:
# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model and the batches fed to it must live on the same device.
model = BERTClassifier('bert-base-uncased', num_classes=2).to(device)
# Build the optimizer after the move so it tracks the on-device parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
def train(model, optimizer, loss_fn, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            expected to return logits.
        optimizer: Optimizer stepping the model's parameters.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
        train_loader: Iterable of batches; each batch is a mapping with the
            keys ``'input_ids'``, ``'attention_mask'`` and ``'label'``.

    Returns:
        float: Mean loss per sample over the pass, or ``0.0`` when the loader
        yields no samples.
    """
    model.train()

    # Take the target device from the model itself instead of relying on a
    # notebook-global `device`, so the function works on a fresh kernel and
    # batches always land where the weights are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss, total_seen = 0.0, 0
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean
        # (assumes loss_fn returns a per-batch mean -- TODO confirm for custom losses).
        batch_size = len(labels)
        total_loss += loss.item() * batch_size
        total_seen += batch_size

    return total_loss / total_seen if total_seen else 0.0

In [None]:
def evaluate(model, loss_fn, test_loader):
    """Compute mean loss and accuracy of ``model`` over ``test_loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            expected to return logits with classes along dimension 1.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
        test_loader: Iterable of batches; each batch is a mapping with the
            keys ``'input_ids'``, ``'attention_mask'`` and ``'label'``.

    Returns:
        tuple[float, float]: ``(mean_loss, accuracy)`` averaged over the
        samples actually yielded by ``test_loader``; ``(0.0, 0.0)`` when the
        loader yields no samples.
    """
    model.eval()

    # Take the target device from the model itself instead of relying on a
    # notebook-global `device`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask)
            loss = loss_fn(logits, labels)

            batch_size = len(labels)
            # Undo the per-batch averaging so unequal batch sizes are weighted
            # correctly (assumes loss_fn returns a per-batch mean -- TODO confirm).
            total_loss += loss.item() * batch_size
            preds = torch.argmax(logits, dim=1)
            total_correct += torch.sum(preds == labels).item()
            total_seen += batch_size

    # Normalise by the samples seen, not by the size of some global frame:
    # the loader may cover a different split or drop incomplete batches.
    if total_seen == 0:
        return 0.0, 0.0
    return total_loss / total_seen, total_correct / total_seen