<a href="https://colab.research.google.com/github/JasonLaux/nlp/blob/main/bert_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The model configuration follows 10-bert.ipython as a reference.
This notebook trains the classification model and saves it for the analysis in Task 2. The trained model is saved as sstcls.pth. The notebook was originally run on Google Colab.

In [None]:
from google.colab import drive
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import torch.optim as optim
import time
from torch.utils.data import Dataset, DataLoader
import json
from collections import OrderedDict, deque
import datetime
import gc
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
!pip install torch torchvision transformers

In [None]:
# Every dataset file sits in the same Google Drive folder, './gdrive/MyDrive/nlp'.
_DRIVE_DATA_DIR = './gdrive/MyDrive/nlp'
DATA_TRAIN_PATH = _DRIVE_DATA_DIR + '/train.data.jsonl'
LABEL_TRAIN_PATH = _DRIVE_DATA_DIR + '/train.label.json'
DATA_DEV_PATH = _DRIVE_DATA_DIR + '/dev.data.jsonl'
LABEL_DEV_PATH = _DRIVE_DATA_DIR + '/dev.label.json'
DATA_TEST_PATH = _DRIVE_DATA_DIR + '/test.data.jsonl'

In [None]:
# # Use this path configuration instead if running locally...
# DATA_TRAIN_PATH = './data/train.data.jsonl'
# LABEL_TRAIN_PATH = './data/train.label.json'
# DATA_DEV_PATH = './data/dev.data.jsonl'
# LABEL_DEV_PATH = './data/dev.label.json'
# DATA_TEST_PATH = './data/test.data.jsonl'

In [None]:
'''
Define the graph structure used to sort replies in a reasonable order. The
mechanism is explained in the report.
'''
class DAG(object):
    """Directed acyclic graph stored as an insertion-ordered adjacency mapping.

    ``self.graph`` maps every node to the collection of its children. Nodes
    must be hashable; ``sort_children`` additionally expects each node to be
    indexable, because it orders children by ``node[1]``.
    """

    def __init__(self):
        """Construct a new DAG with no nodes or edges."""
        self.node_depth = []
        self.graph = OrderedDict()

    def add_node(self, node: tuple):
        """Register ``node`` with an empty child set.

        Raises:
            KeyError: if ``node`` is already present in the graph.
        """
        if node in self.graph:
            # Wrap the node in a 1-tuple so that a tuple node is formatted as
            # a single value instead of being unpacked by the % operator.
            raise KeyError('node %s already exists' % (node,))
        self.graph[node] = set()

    def add_edge(self, start_node: tuple, end_node: tuple):
        """Add a directed edge from ``start_node`` to ``end_node``.

        Edges must be added before ``sort_children`` is called, because that
        method replaces the child sets with lists.

        Raises:
            KeyError: if either endpoint has not been added as a node.
        """
        if start_node not in self.graph or end_node not in self.graph:
            raise KeyError("Node is not existed in the graph.")

        self.graph[start_node].add(end_node)

    def sort_children(self):
        """Replace each child collection with a list sorted by ``node[1]``, ascending."""
        for key in self.graph:
            self.graph[key] = sorted(self.graph[key], key=lambda item: item[1])

    def topological_sort(self):
        """Return the nodes in a topological order.

        Nodes without incoming edges are visited first, in insertion order.
        The remaining nodes are processed first-in-first-out as their last
        incoming edge is consumed, following the iteration order of each
        node's children.

        Raises:
            ValueError: if the graph contains a cycle.
        """
        in_degree = {node: 0 for node in self.graph}
        for node in self.graph:
            for child in self.graph[node]:
                in_degree[child] += 1

        # FIFO queue seeded with every root (node without a parent).
        queue = deque(node for node in in_degree if in_degree[node] == 0)

        ordering = []
        while queue:
            node = queue.popleft()
            ordering.append(node)
            for child in self.graph[node]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    queue.append(child)

        # Nodes that sit on a cycle never reach in-degree zero, so they are
        # missing from the ordering.
        if len(ordering) != len(self.graph):
            raise ValueError('graph is not acyclic')
        return ordering

def create_graph(items):
    """Build the reply graph for one thread of tweets.

    Each element of ``items`` is read through its "created_at", "id_str" and
    "in_reply_to_status_id_str" keys. Every tweet becomes a node
    ``(position, seconds_since_first_tweet)``, and an edge is added from a
    tweet to each tweet that replies to it, provided the parent is part of
    ``items``. Children are sorted by their time offset before the graph is
    returned.
    """
    dag = DAG()
    position_by_id = {}
    first_created_at = ""

    for position, tweet in enumerate(items):
        created_at = tweet["created_at"]
        if position == 0:
            # The first tweet is the reference point for all time offsets.
            first_created_at = created_at
        position_by_id[tweet["id_str"]] = position  # Assuming every tweet is unique
        dag.add_node((position, calc_time_diff(first_created_at, created_at)))

    # Nodes in insertion order, so nodes[i] is the node created for items[i].
    nodes = list(dag.graph)

    for position, tweet in enumerate(items):
        parent_position = position_by_id.get(str(tweet["in_reply_to_status_id_str"]))
        if parent_position is None:
            # Not a reply, or a reply to a tweet outside this thread.
            continue
        dag.add_edge(nodes[parent_position], nodes[position])

    dag.sort_children()
    return dag


def calc_time_diff(start_time: str, end_time: str):
    """Return the seconds elapsed from ``start_time`` to ``end_time``.

    Both timestamps are parsed with the format "%a %b %d %H:%M:%S %z %Y",
    e.g. "Wed Jan 07 11:06:08 +0000 2015". The result is a float and is
    negative when ``end_time`` precedes ``start_time``.
    """
    stamp_format = "%a %b %d %H:%M:%S %z %Y"
    began, finished = (
        datetime.datetime.strptime(stamp, stamp_format)
        for stamp in (start_time, end_time)
    )
    elapsed = finished - began
    return elapsed.total_seconds()

In [None]:
class TweetDataset(Dataset):
    """Dataset of tweet threads encoded as fixed-length BERT inputs.

    Each non-blank line of ``fn_data`` is a JSON array of tweet objects that
    form one thread. The optional ``fn_label`` file is a JSON object mapping
    the ``id_str`` of a thread's first tweet to its label.
    """

    def __init__(self, fn_data, fn_label=None, maxlen=256):
        """Load the raw thread lines, the optional labels and the tokenizer.

        Args:
            fn_data: path of the JSON-lines file with one thread per line.
            fn_label: path of the JSON label file, or None for unlabeled data.
            maxlen: exact number of tokens in every encoded sample.
        """
        # Keep the raw JSON lines; they are parsed lazily in __getitem__.
        # Blank lines are dropped so they cannot break json.loads later.
        with open(fn_data, encoding="utf-8") as data_file:
            self.data = [line for line in data_file if line.strip()]

        if fn_label is not None:
            with open(fn_label, encoding="utf-8") as label_file:
                self.label_dict = json.load(label_file)
        else:
            self.label_dict = None

        # Initialize the BERT tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.maxlen = maxlen  # the max length of the sentence in the corpus

    def __len__(self):
        """Return the number of threads."""
        return len(self.data)

    def __getitem__(self, index):
        """Encode the thread at ``index``.

        Returns:
            A tuple ``(token_ids, attention_mask, label, id_str)``. The first
            two are tensors of length ``maxlen``. ``label`` is a scalar
            tensor: 1 for "non-rumour", 0 for any other label, and -1 when no
            labels are loaded. ``id_str`` is the id of the thread's first
            tweet.
        """
        items = json.loads(self.data[index])
        id_str = items[0]["id_str"]  # Tweet index

        # Order the tweets of the thread by the topological order of the
        # reply graph.
        idx_list = [pair[0] for pair in create_graph(items).topological_sort()]
        tokens_list = []
        for idx in idx_list:
            username = items[idx]["user"]["name"]
            text = items[idx]["text"]
            current_sentence = username + ":" + text
            tokens_list.append(self.tokenizer.tokenize(current_sentence))

        tokens_concat = ['[CLS]'] + [token for item in tokens_list for token in item] + ['[SEP]']
        if len(tokens_concat) < self.maxlen:
            padded_tokens = tokens_concat + ['[PAD]'] * (self.maxlen - len(tokens_concat))
        else:
            # Truncate, but keep a closing [SEP] as the last token.
            padded_tokens = tokens_concat[:self.maxlen - 1] + ['[SEP]']

        tokens_ids = self.tokenizer.convert_tokens_to_ids(padded_tokens)
        # Attend to every real token and ignore the padding.
        attn_mask = [1 if token != '[PAD]' else 0 for token in padded_tokens]
        tokens_ids_tensor = torch.tensor(tokens_ids)
        attn_masks_tensor = torch.tensor(attn_mask)

        label_idx = torch.tensor(-1)
        if self.label_dict:
            label = self.label_dict[id_str]  # Tweet label
            if label == "non-rumour":
                label_idx = torch.tensor(1)
            else:
                label_idx = torch.tensor(0)
        return tokens_ids_tensor, attn_masks_tensor, label_idx, id_str

In [None]:
class RumourClassifier(nn.Module):
    """Pretrained 'bert-base-uncased' encoder followed by one linear output unit."""

    def __init__(self):
        super().__init__()
        # Pretrained BERT encoder.
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')

        # Binary classification head: the 768-dimensional [CLS] embedding is
        # mapped to a single logit.
        self.cls_layer = nn.Linear(in_features=768, out_features=1)

    def forward(self, tokens_ids, attn_masks):
        '''
        Inputs:
            -tokens_ids : Tensor of shape [B, T] containing token ids of sequences
            -attn_masks : Tensor of shape [B, T] containing attention masks, used to
                          avoid contribution of PAD tokens
        Returns:
            The output of the classification layer applied to the [CLS] representation.
        '''
        # Contextualized representations of every token.
        encoded = self.bert_layer(input_ids=tokens_ids, attention_mask=attn_masks)

        # [CLS] is the first token of each sequence.
        first_token_state = encoded.last_hidden_state[:, 0]

        return self.cls_layer(first_token_state)


def get_accuracy_from_logits(logits, labels):
    """Return the fraction of predictions that match ``labels`` as a scalar tensor.

    A prediction is 1 when the sigmoid of the logit is strictly above 0.5,
    and 0 otherwise.
    """
    predicted = torch.sigmoid(logits.unsqueeze(-1)).gt(0.5).long().squeeze()
    matches = predicted.eq(labels)
    return matches.float().mean()


def evaluate(net, criterion, dataloader, gpu):
    """Compute the mean per-batch accuracy and loss of ``net`` on ``dataloader``.

    The network is switched to evaluation mode and gradients are disabled.
    Every batch is moved to CUDA device ``gpu`` before the forward pass.

    Returns:
        A pair ``(mean_accuracy, mean_loss)``, each averaged over the number
        of batches.

    Raises:
        KeyError: if the dataloader yields no batches.
    """
    net.eval()

    acc_total = 0
    loss_total = 0
    num_batches = 0

    with torch.no_grad():
        for token_ids, masks, targets, _ in dataloader:
            token_ids = token_ids.cuda(gpu)
            masks = masks.cuda(gpu)
            targets = targets.cuda(gpu)

            batch_logits = net(token_ids, masks)
            batch_loss = criterion(batch_logits.squeeze(-1), targets.float())

            loss_total += batch_loss.item()
            acc_total += get_accuracy_from_logits(batch_logits, targets)
            num_batches += 1

    # Guard clause: an empty loader would otherwise divide by zero.
    if num_batches == 0:
        raise KeyError("Dataloader is loaded incorrectly!")
    return acc_total / num_batches, loss_total / num_batches


def train(net, criterion, opti, train_loader, dev_loader, max_eps, gpu):
    """Fine-tune ``net`` and checkpoint the weights with the best dev accuracy.

    Args:
        net: the model; called as ``net(seq, attn_masks)``.
        criterion: loss applied to ``logits.squeeze(-1)`` and float labels.
        opti: optimizer over the parameters of ``net``.
        train_loader: yields ``(seq, attn_masks, labels, ids)`` training batches.
        dev_loader: development batches, passed to ``evaluate`` after every epoch.
        max_eps: number of epochs to run.
        gpu: CUDA device index the batches are moved to.

    Returns:
        ``net`` as it is after the final epoch. The weights of the epoch with
        the best development accuracy are saved to 'sstcls.pth'.
    """
    best_acc = 0
    st = time.time()
    for ep in range(max_eps):
        # Switch to training mode at the start of every epoch: evaluate()
        # leaves the network in eval mode, and a model loaded with
        # from_pretrained starts in eval mode, so without this call dropout
        # would stay disabled during fine-tuning.
        net.train()

        for it, (seq, attn_masks, labels, _) in enumerate(train_loader):
            # Clear gradients
            opti.zero_grad()
            # Converting these to cuda tensors
            seq, attn_masks, labels = seq.cuda(gpu), attn_masks.cuda(gpu), labels.cuda(gpu)

            # Obtaining the logits from the model
            logits = net(seq, attn_masks)

            # Computing loss
            loss = criterion(logits.squeeze(-1), labels.float())

            # Backpropagating the gradients
            loss.backward()

            # Optimization step
            opti.step()

            if it % 100 == 0:
                acc = get_accuracy_from_logits(logits, labels)
                print("Iteration {} of epoch {} complete. Loss: {}; Accuracy: {}; Time taken (s): {}"
                      .format(it, ep, loss.item(), acc, (time.time() - st)))
                # Restart the timer so the next report covers only the next 100 iterations.
                st = time.time()

        dev_acc, dev_loss = evaluate(net, criterion, dev_loader, gpu)
        print("Epoch {} complete! Development Accuracy: {}; Development Loss: {}".format(ep, dev_acc, dev_loss))
        if dev_acc > best_acc:
            print("Best development accuracy improved from {} to {}, saving model...".format(best_acc, dev_acc))
            best_acc = dev_acc
            torch.save(net.state_dict(), 'sstcls.pth')
    return net

In [None]:
def predict(net, dataloader, gpu):
    """Label every sample of ``dataloader`` and save the result as JSON.

    Works for any batch size: each id in a batch is paired with its own
    prediction.

    Args:
        net: the trained model; called as ``net(seq, attn_masks)``.
        dataloader: yields ``(seq, attn_masks, labels, ids)`` batches; the
            labels are ignored.
        gpu: CUDA device index the batches are moved to.

    Returns:
        A dict mapping each id to "rumour" (predicted class 0) or
        "non-rumour" (predicted class 1). The same dict is written to
        "test_label.json" in the current working directory.
    """
    net.eval()
    dict_pred = {}
    with torch.no_grad():
        for seq, attn_masks, _, id_str in dataloader:
            seq, attn_masks = seq.cuda(gpu), attn_masks.cuda(gpu)
            logits = net(seq, attn_masks)
            probs = torch.sigmoid(logits)
            # One 0/1 prediction per sample, flattened to a plain Python list.
            y_pred = (probs > 0.5).long().view(-1).tolist()
            for tweet_id, pred in zip(id_str, y_pred):
                dict_pred[tweet_id] = "rumour" if pred == 0 else "non-rumour"
    with open("test_label.json", "w") as f:
        json.dump(dict_pred, f)
    return dict_pred

In [None]:
# Index of the CUDA device used for the model and every batch.
gpu = 0
print("Creating the classifier, initialised with pretrained BERT-BASE parameters...")
net = RumourClassifier()
net.cuda(gpu)  # Enable gpu support for the model
print("Done creating the classifier.")

# Define loss function based on binary cross-entropy.
criterion = nn.BCEWithLogitsLoss()
opti = optim.Adam(net.parameters(), lr=2e-5)
num_epoch = 3

# The maxlen is limited to 256 due to the out-of-memory issue
train_dataset = TweetDataset(DATA_TRAIN_PATH, LABEL_TRAIN_PATH, maxlen=256)
dev_dataset = TweetDataset(DATA_DEV_PATH, LABEL_DEV_PATH, maxlen=256)
# Shuffle the training samples every epoch so that the batches do not follow
# the order of the data file. The dev and test loaders keep the file order.
train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=10)
test_dataset = TweetDataset(DATA_TEST_PATH, maxlen=256)
test_dataloader = DataLoader(test_dataset, batch_size=1)


net_trained = train(net, criterion, opti, train_dataloader, dev_dataloader, num_epoch, gpu)
print(predict(net_trained, test_dataloader, gpu))

Creating the classifier, initialised with pretrained BERT-BASE parameters...
Done creating the classifier.
Iteration 0 of epoch 0 complete. Loss: 0.9059423804283142; Accuracy: 0.30000001192092896; Time taken (s): 1.0903050899505615
Iteration 100 of epoch 0 complete. Loss: 0.7962539196014404; Accuracy: 0.4000000059604645; Time taken (s): 93.4114236831665
Iteration 200 of epoch 0 complete. Loss: 0.292822927236557; Accuracy: 0.9000000357627869; Time taken (s): 93.8758008480072
Iteration 300 of epoch 0 complete. Loss: 0.5758542418479919; Accuracy: 0.800000011920929; Time taken (s): 94.28265810012817
Iteration 400 of epoch 0 complete. Loss: 0.6749157905578613; Accuracy: 0.800000011920929; Time taken (s): 93.87975454330444
Epoch 0 complete! Development Accuracy: 0.8758621215820312; Development Loss: 0.29929105385110294
Best development accuracy improved from 0 to 0.8758621215820312, saving model...
Iteration 0 of epoch 1 complete. Loss: 0.38689208030700684; Accuracy: 0.800000011920929; Time 