In [None]:
from transformers import BartTokenizer, BartForSequenceClassification
import pandas as pd
import torch

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and all CUDA devices) so that any freshly initialised
# weights and the shuffled training DataLoader are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

# Both CSVs have no header row: column 0 is the caption text, column 1 its label.
train_data_full = pd.read_csv('captions_train - Sheet1-3.csv', header=None, names=["caption", "result"])
test_data = pd.read_csv('captions_test - Sheet1-4.csv', header=None, names=["caption", "result"])

In [None]:
# Inspect the raw training data as loaded from CSV.
print(str(train_data_full))

In [None]:
# Train on every row of the training file (this is an alias, not a copy).
train_data = train_data_full
print("\nTraining on {} examples\n".format(len(train_data)))

# Dump all rows grouped by label to eyeball the class balance.
rows_by_result = train_data.sort_values(by="result")
print(rows_by_result.to_string())

In [None]:
# One classifier output per distinct label in the training data, so num_labels
# always matches the number of one-hot columns pd.get_dummies() produces.
num_results = train_data['result'].nunique()
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=num_results)
# nn.Module.to() moves the parameters and returns the same module.
model = model.to(device)

In [None]:
# Convert result column to one-hot encoding.
# The test labels are encoded with the *training* columns so that column i
# denotes the same class in both splits (and the same model output index),
# even when the test split happens to be missing a class.
one_hot_train = pd.get_dummies(train_data['result'], dtype=float)
unseen_labels = set(test_data['result'].dropna()) - set(one_hot_train.columns)
assert not unseen_labels, f"test labels not present in training data: {unseen_labels}"
one_hot_test = (
    pd.get_dummies(test_data['result'], dtype=float)
    .reindex(columns=one_hot_train.columns, fill_value=0.0)
)

# Tokenize captions and convert to PyTorch datasets.
# truncation=True keeps every sequence within the tokenizer's maximum length.
inputs_train = tokenizer(list(train_data['caption']), return_tensors='pt', padding=True, truncation=True)
labels_train = torch.tensor(one_hot_train.values, dtype=torch.float32)
dataset_train = torch.utils.data.TensorDataset(inputs_train['input_ids'], inputs_train['attention_mask'], labels_train)
inputs_test = tokenizer(list(test_data['caption']), return_tensors='pt', padding=True, truncation=True)
labels_test = torch.tensor(one_hot_test.values, dtype=torch.float32)
dataset_test = torch.utils.data.TensorDataset(inputs_test['input_ids'], inputs_test['attention_mask'], labels_test)

In [None]:
import matplotlib.pyplot as plt
def graphLoss(epoch_counter, train_loss_hist, test_loss_hist, loss_name="Loss", start = 0):
  """Plot train and test loss curves against epoch on a new figure.

  Args:
    epoch_counter: x values (epoch indices), one per recorded point.
    train_loss_hist: train-loss values aligned with epoch_counter (blue line).
    test_loss_hist: test-loss values aligned with epoch_counter (red line).
    loss_name: label for the y axis.
    start: index of the first point to plot; earlier points are skipped.

  Returns:
    The matplotlib Figure that was drawn on.
  """
  fig, ax = plt.subplots()
  ax.plot(epoch_counter[start:], train_loss_hist[start:], color='blue', label='Train Loss')
  ax.plot(epoch_counter[start:], test_loss_hist[start:], color='red', label='Test Loss')
  ax.legend(loc='upper right')
  ax.set_xlabel('#Epochs')
  ax.set_ylabel(loss_name)
  return fig

In [None]:
def logResults(epoch, num_epochs, train_loss, train_loss_history, test_loss, test_loss_history, epoch_counter, print_interval=1000):
  """Record one epoch's losses and print a progress line every `print_interval` epochs.

  train_loss, test_loss and epoch are appended (in that order) to
  train_loss_history, test_loss_history and epoch_counter, which are mutated
  in place. The printed epoch number is 1-based; the stored one is not.
  """
  if epoch % print_interval == 0:
    message = 'Epoch [%d/%d], Train Loss: %.4f, Test Loss: %.4f' % (epoch + 1, num_epochs, train_loss, test_loss)
    print(message)
  for history, value in ((train_loss_history, train_loss),
                         (test_loss_history, test_loss),
                         (epoch_counter, epoch)):
    history.append(value)

In [None]:
# Define training parameters
epochs = 10
batch_size = 16
learning_rate = 2e-5

# Each caption has exactly one result, so train with softmax cross-entropy on
# class indices. Passing the float one-hot rows as `labels` would instead make
# transformers treat the task as multi-label and use BCEWithLogitsLoss.
model.config.problem_type = "single_label_classification"

# Train model
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size)

test_loss_history = []
train_loss_history = []
epoch_counter = []

print(f"\nTraining on {len(train_data)} examples\n")
print("Num. Parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

for epoch in range(epochs):
    # from_pretrained() returns the model in eval mode and the evaluation pass
    # below leaves it there, so re-enable dropout at the start of every epoch.
    model.train()
    running_loss = 0.0
    for step, batch in enumerate(data_loader_train):
        input_ids, attention_mask, labels = (t.to(device) for t in batch)
        targets = labels.argmax(dim=1)  # one-hot rows -> integer class indices
        outputs = model(input_ids, attention_mask=attention_mask, labels=targets)
        loss = outputs.loss
        running_loss += loss.item()
        if step % 100 == 0:
            print(f"Step {step}/{len(data_loader_train)} Loss {loss.item()} Avg Train Loss {running_loss / (step + 1)}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Mean of the per-batch training losses for this epoch.
    train_loss = running_loss / len(data_loader_train)
    print(f"Epoch {epoch+1} Train Loss {train_loss}")

    # Evaluate on the test set: dropout off, no autograd graph.
    model.eval()
    test_loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader_test:
            input_ids, attention_mask, labels = (t.to(device) for t in batch)
            targets = labels.argmax(dim=1)
            outputs = model(input_ids, attention_mask=attention_mask, labels=targets)
            test_loss_sum += outputs.loss.item()
            predicted = torch.argmax(outputs.logits, dim=1)
            total += labels.size(0)
            correct += (predicted == targets).sum().item()
    test_loss = test_loss_sum / len(data_loader_test)
    print(f"Epoch {epoch+1} Test Loss {test_loss}")
    print(f"Test Accuracy {100*correct/total}%\n")
    logResults(epoch, epochs, train_loss, train_loss_history, test_loss, test_loss_history, epoch_counter, 1)

graphLoss(epoch_counter, train_loss_history, test_loss_history)
# Save model
model.save_pretrained('fine-tuned-bart_captions')

In [None]:
# Final test-set accuracy (fraction of captions whose predicted class matches).
model.eval()
correct = 0
total = 0
# Inference only: don't build an autograd graph (saves memory and time).
with torch.no_grad():
    for batch in data_loader_test:
        input_ids, attention_mask, labels = (t.to(device) for t in batch)
        outputs = model(input_ids, attention_mask=attention_mask)
        predicted = torch.argmax(outputs.logits, dim=1)
        total += labels.size(0)
        # labels are one-hot rows, so argmax recovers the true class index.
        correct += (predicted == torch.argmax(labels, dim=1)).sum().item()

print(f"Accuracy {correct/total}")

In [None]:
# Classify a few hand-written captions with the fine-tuned model.
new_names = ["Unfortunate outcome.", "HUGE COMEBACK TONIGHT", "Final in Dallas."]
inputs = tokenizer(new_names, return_tensors='pt', padding=True, truncation=True)
# Set eval mode explicitly so this cell gives the same answer however it is reached.
model.eval()
with torch.no_grad():
    outputs = model(inputs['input_ids'].to(device), attention_mask=inputs['attention_mask'].to(device))
predicted = torch.argmax(outputs.logits, dim=1)
# Output index i corresponds to the i-th one-hot column built from the training labels.
for name, label_index in zip(new_names, predicted.tolist()):
    print(f"{name}: {one_hot_train.columns[label_index]}")

In [None]:
# Make confusion matrix: rows are actual labels, columns are predicted labels.
# Size it by the training label columns: they define the model's output
# indices and are also what the heatmap axes are labelled with below.
class_names = one_hot_train.columns
num_classes = len(class_names)
confusion_matrix = torch.zeros(num_classes, num_classes)
model.eval()
with torch.no_grad():
    for batch in data_loader_test:
        input_ids, attention_mask, labels = (t.to(device) for t in batch)
        outputs = model(input_ids, attention_mask=attention_mask)
        predicted = torch.argmax(outputs.logits, dim=1)
        actual = torch.argmax(labels, dim=1)
        for actual_idx, predicted_idx in zip(actual.tolist(), predicted.tolist()):
            confusion_matrix[actual_idx, predicted_idx] += 1

print(confusion_matrix)

# Plot confusion matrix
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
plt.figure(figsize=(10,10))
sns.heatmap(confusion_matrix.numpy(), annot=True, fmt=".0f", linewidths=.5, square = True, cmap = 'Blues_r', xticklabels=class_names, yticklabels=class_names)
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
plt.show()