In [16]:
# 1. Importing dataset

# Command to install all packages for the project (the Trainer also needs accelerate>=0.26.0)
# pip install transformers datasets pandas torch "accelerate>=0.26.0"

import re

import numpy as np
import pandas as pd
from datasets import load_dataset

ds = load_dataset("SoccerNet/SN-echoes", "whisper_v1")

# en_dataset is used purely to inspect the structure of the data.
# Dataset.to_pandas() converts the underlying Arrow table directly, instead of
# having pandas iterate over the Dataset one row (dict) at a time.
en_dataset = ds['en'].to_pandas()

# Last expression -> rich (HTML) display of the first rows in the notebook
en_dataset.head()

   segment_index  start_time   end_time  \
0              0        0.00   3.000000   
1              1        3.00   8.240000   
2              2        8.24   9.680000   
3              3        9.68  11.400000   
4              4       11.40  17.879999   

                                                text  \
0  The duel has already started, Barley handles t...   
1  It must be said that they also faced each othe...   
2                                   Chelsea won 1-3.   
3                           The Barley came forward.   
4  You have to remember that he got ahead in that...   

                                                game  
0  england_epl/2014-2015/2015-02-21 - 18-00 Chels...  
1  england_epl/2014-2015/2015-02-21 - 18-00 Chels...  
2  england_epl/2014-2015/2015-02-21 - 18-00 Chels...  
3  england_epl/2014-2015/2015-02-21 - 18-00 Chels...  
4  england_epl/2014-2015/2015-02-21 - 18-00 Chels...  


In [17]:
# 2. Define Label Categories
# Map each event category to its integer class id. Ids follow the order of the
# names in the tuple: goal=0, foul=1, offside=2, substitution=3,
# yellow_card=4, red_card=5.
labels = {
    name: class_id
    for class_id, name in enumerate(
        ('goal', 'foul', 'offside', 'substitution', 'yellow_card', 'red_card')
    )
}

In [18]:
# 3. Tokenize text using BERT
from transformers import BertTokenizer

# Tokenizer for the same 'bert-base-uncased' checkpoint that is fine-tuned later
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

# Tokenize the text in your dataset
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True)

# Tokenize the whole English split in batches; with batched=True each call to
# tokenize_function receives a dict of column lists.
tokenized_dataset = ds['en'].map(function=tokenize_function, batched=True)

In [19]:
# 4. Creating Input Features
# Weak labeling: derive a class id for each segment from keywords in its 'text'
# Keyword rules, tried in order; the first matching rule decides the label.
# 'goal' and 'foul' come first so segments mentioning them keep that precedence;
# 'red_card' is checked before 'yellow_card' so "second yellow ... red card"
# resolves to the card that was actually shown last.
# Word boundaries (\b) stop compounds such as "goalkeeper" or "goalless" from
# being counted as a goal, while prefix patterns still match inflections
# ("fouled", "offsides", "substitute").
_LABEL_PATTERNS = [
    ('goal', re.compile(r"\bgoals?\b")),
    ('foul', re.compile(r"\bfoul")),
    ('offside', re.compile(r"\boffside")),
    ('substitution', re.compile(r"\bsubstitut")),
    ('red_card', re.compile(r"\bred card|\bsent off\b")),
    ('yellow_card', re.compile(r"\byellow card")),
]


def add_labels(example):
    """Assign a class id to a single dataset row from keywords in its text.

    Rules in ``_LABEL_PATTERNS`` are tried in order and the first match wins,
    so a segment mentioning both a goal and a foul is labeled 'goal'.

    Args:
        example: one row (dict-like) with a string 'text' field; it is updated
            in place.

    Returns:
        The same row with an integer 'label': the id from the global ``labels``
        mapping, or -1 when no rule matches. Rows labeled -1 must be removed
        (or remapped) before training: -1 is not a valid target for the
        model's cross-entropy loss.
    """
    text = example['text'].lower()  # lowercase once, not once per rule
    for label_name, pattern in _LABEL_PATTERNS:
        if pattern.search(text):
            example['label'] = labels[label_name]
            break
    else:
        example['label'] = -1  # no rule matched: unknown
    return example

# Attach a 'label' value to every tokenized row using add_labels
labeled_dataset = tokenized_dataset.map(function=add_labels)

In [20]:
# 5. Split the labeled data into training and validation sets
from datasets import DatasetDict

# Drop segments that no labeling rule matched (label == -1). The classifier's
# cross-entropy loss only accepts class ids in [0, num_labels) (or its ignore
# index -100), so a -1 target would make training fail.
labeled_known_dataset = labeled_dataset.filter(lambda example: example['label'] != -1)

# Split 80% training / 20% validation; the fixed seed keeps the split
# identical across kernel restarts.
split_dataset = labeled_known_dataset.train_test_split(test_size=0.2, seed=42)

# split_dataset is a DatasetDict; take its two Dataset splits
train_dataset = split_dataset['train']
val_dataset = split_dataset['test']

In [21]:
# 6. Fine-tuning the pre-trained BERT model
from transformers import BertForSequenceClassification, Trainer, TrainingArguments

# Load pre-trained BERT with a newly initialised classification head (one
# output per entry in `labels`). id2label/label2id store the class names in the
# model config so predictions and saved checkpoints map back to readable labels.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=len(labels),
    id2label={class_id: name for name, class_id in labels.items()},
    label2id=labels,
)


def compute_accuracy(eval_pred):
    """Fraction of evaluation examples whose highest-scoring class is the gold label.

    Args:
        eval_pred: the EvalPrediction the Trainer passes to ``compute_metrics``;
            ``predictions`` holds the model scores per class and ``label_ids``
            the gold class ids.

    Returns:
        ``{'accuracy': float}`` (the Trainer reports it as ``eval_accuracy``).
    """
    predicted_ids = np.argmax(eval_pred.predictions, axis=-1)
    return {'accuracy': float(np.mean(predicted_ids == eval_pred.label_ids))}


# Set up training arguments
training_args = TrainingArguments(
    output_dir='./results',
    # NOTE: transformers 4.41+ deprecates this name in favour of `eval_strategy`;
    # switch once the installed version is known to support the new name.
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Initialize the Trainer; without compute_metrics evaluation reports only the loss
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_accuracy,
)

# Train the model
trainer.train()


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.26.0`: Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`

In [None]:
# Metrics on the validation set (keys prefixed with "eval_", e.g. eval_loss)
eval_results = trainer.evaluate()
print(eval_results)

# Per-example predictions on the validation set. Printing the whole
# PredictionOutput would dump the raw score array, so show its metrics and the
# first few predicted class names instead.
predictions = trainer.predict(val_dataset)
id_to_label = {class_id: name for name, class_id in labels.items()}
# presumably predictions.predictions is (n_examples, num_labels) logits — argmax per row
predicted_ids = predictions.predictions.argmax(axis=-1)
print(predictions.metrics)
print([id_to_label[int(class_id)] for class_id in predicted_ids[:10]])