In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import os
import json
import wandb
from tqdm import tqdm
from datetime import datetime

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# wandb.init(project="nlp_project_cls", name="bert base uncased test run")

# PATHS
# All data files live under one directory; set NLP_DATA_DIR to relocate them
# without editing every path below.
DATA_DIR = os.environ.get(
    "NLP_DATA_DIR",
    "/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/data",
)
DEV_CLAIMS_BASELINE_JSON_PATH = os.path.join(DATA_DIR, "dev-claims-baseline.json")
DEV_CLAIMS_JSON_PATH = os.path.join(DATA_DIR, "dev-claims.json")
EVIDENCE_JSON_PATH = os.path.join(DATA_DIR, "evidence.json")
SMALL_EVIDENCE_JSON_PATH = os.path.join(DATA_DIR, "small_evidence.json")
TINY_EVIDENCE_JSON_PATH = os.path.join(DATA_DIR, "tiny_evidence.json")
CODE_DEV_EVIDENCE_JSON_PATH = os.path.join(DATA_DIR, "code_dev_evidence.json")
TEST_CLAIMS_UNLABELLED_JSON_PATH = os.path.join(DATA_DIR, "test-claims-unlabelled.json")
TEST_CLAIMS_LABELLED_JSON_PATH = os.path.join(DATA_DIR, "test-claims-labelled.json")
TRAIN_CLAIMS_JSON_PATH = os.path.join(DATA_DIR, "train-claims.json")
RETRIEVAL_TEST_CLAIMS_JSON_PATH = os.path.join(DATA_DIR, "retrieval-test-claims.json")

# ARGS
BATCH_SIZE = 4
EPOCH = 1
MODEL_NAME = "distilbert-base-uncased"
MAX_LR = 2e-5
MAX_LENGTH = 256  # max tokens per (claim + evidences) example; longer inputs are truncated
RETRIEVAL_NUM = 3
K = 4

# Prefer CUDA, then Apple MPS, then CPU. getattr guards torch builds that
# predate the torch.backends.mps module.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using {device.type}")


LABEL_TO_INT = {
    'NOT_ENOUGH_INFO': 0,
    'DISPUTED': 1,
    'REFUTES': 2,
    'SUPPORTS': 3,
}

# Derived from LABEL_TO_INT so the two mappings can never disagree.
INT_TO_LABEL = {index: label for label, index in LABEL_TO_INT.items()}

class NNClassifier(nn.Module):
    """Pretrained transformer encoder followed by a small MLP classification head.

    The head reads the encoder's hidden state at position 0 (the [CLS] position
    for BERT-style tokenizers) and maps it to ``num_labels`` logits.

    Args:
        pre_encoder: Hugging Face model name or path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output classes (default 4, matching LABEL_TO_INT).
        hidden_dim: Width of the head's hidden layer (default 256).
    """

    def __init__(self, pre_encoder, num_labels=4, hidden_dim=256):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(pre_encoder)
        # Submodule names/order are unchanged so existing checkpoints still load.
        self.cls = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(self, batch):
        """Return unnormalised class logits for a collated batch.

        ``batch`` must contain "batched_input_ids" and "batched_attention_mask"
        (as produced by ``collate_fn``).
        """
        encoder_output = self.encoder(
            input_ids=batch["batched_input_ids"],
            attention_mask=batch["batched_attention_mask"],
        )
        first_token = encoder_output.last_hidden_state[:, 0, :]
        return self.cls(first_token)

# bert model
# model / loss_function / optimizer / scheduler / tokenizer are module-level
# globals shared by train(), validate() and predict().
model = NNClassifier(MODEL_NAME)
model = model.to(device)  # Move model to device
loss_function = nn.CrossEntropyLoss(label_smoothing=0.1)

# Instantiate the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=MAX_LR, weight_decay=1e-4)
optimizer.zero_grad()

# Instantiate scheduler: the LR is multiplied by 0.1 every 100 scheduler.step()
# calls (train() steps it once per batch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

CHECKPOINT_PATH = "/Users/taylortang/Life-at-UniMelb/Semester_3/COMP90042_NLP/Project_2/code/rtv/model/cls_15_05_2023/cls_checkpoint.bin"
load = True
if load:
    # map_location remaps tensors saved on another device (e.g. CUDA) onto the
    # active one; without it loading fails on machines lacking that device.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def process(sentence):
    """Lower-case ``sentence`` and keep only alphanumeric and whitespace characters."""
    lowered = sentence.lower()
    return "".join(filter(lambda ch: ch.isalnum() or ch.isspace(), lowered))

# small util
def tokenization(text):
    """Tokenize ``text`` with the global tokenizer into PyTorch tensors.

    Sequences are padded (``padding=True``) and truncated to ``MAX_LENGTH`` tokens.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=MAX_LENGTH,
        truncation=True,
        padding=True,
    )
    return encoded

# small util
def makedir(sub_dir):
    """Create (if needed) and return ``./model/<sub_dir>_<DD_MM_YYYY>`` for today's date."""
    target_dir = f"./model/{sub_dir}_{datetime.now():%d_%m_%Y}"
    os.makedirs(target_dir, exist_ok=True)
    return target_dir

def concat_texts(claim, evidences):
    """Return ``[processed claim text, processed evidence 1, processed evidence 2, ...]``.

    ``claim`` needs a "claim_text" entry; when it also has "evidences", each listed
    id is looked up in the ``evidences`` mapping (missing ids raise KeyError).
    """
    evidence_ids = claim["evidences"] if "evidences" in claim else []
    return [process(claim["claim_text"])] + [process(evidences[eid]) for eid in evidence_ids]


def collate_fn(batch):
    """Collate dataset samples ``[texts, label, claim, claim_id]`` into one model batch.

    Returns a dict with tokenized "batched_input_ids" / "batched_attention_mask",
    the raw claim dicts under "claims", their ids under "claim_ids", and a long
    tensor of integer labels under "label" (missing labels become -1).
    """
    claim_texts, labels, claims, claim_ids = zip(*batch)

    # Each sample's text is a list: the claim followed by its evidence passages.
    # The tokenizer does not accept a list of passages per example (without
    # is_split_into_words it reads a nested list as a text/text-pair), so join
    # the passages into one string, separated by the model's sep token.
    separator = f" {tokenizer.sep_token} " if tokenizer.sep_token else " "
    texts = [text if isinstance(text, str) else separator.join(text) for text in claim_texts]
    tokens = tokenization(texts)

    # CrossEntropyLoss needs a tensor target; -1 marks unlabelled examples.
    label_tensor = torch.tensor(
        [-1 if label is None else label for label in labels], dtype=torch.long
    )

    return {
        "batched_input_ids": tokens.input_ids,
        "batched_attention_mask": tokens.attention_mask,
        "claims": claims,
        "claim_ids": claim_ids,
        "label": label_tensor,
    }

class _ClaimsDataset(Dataset):
    """Shared implementation for the claim-classification datasets.

    Loads a JSON dict of claims keyed by claim id plus the evidence corpus
    (evidence id -> text). Each item is ``[texts, label, claim, claim_id]`` where
    ``texts`` comes from ``concat_texts`` and ``label`` is an int from LABEL_TO_INT.
    """

    def __init__(self, claims_path, evidence_path):
        with open(claims_path, "r", encoding="utf-8") as file:
            self.data = json.load(file)

        with open(evidence_path, "r", encoding="utf-8") as file:
            self.evidences = json.load(file)

        self.claim_ids = list(self.data.keys())

    def __len__(self):
        return len(self.claim_ids)

    def __getitem__(self, index):
        claim_id = self.claim_ids[index]
        claim = self.data[claim_id]
        claim_text = concat_texts(claim, self.evidences)
        label = self._label(claim)

        return [claim_text, label, claim, claim_id]

    def _label(self, claim):
        """Map the claim's string label to its integer id (KeyError if absent)."""
        return LABEL_TO_INT[claim["claim_label"]]


class ClassificationTrainDataset(_ClaimsDataset):
    """Training claims, with evidence text looked up in the full evidence corpus."""

    def __init__(self):
        super().__init__(TRAIN_CLAIMS_JSON_PATH, EVIDENCE_JSON_PATH)


class ClassificationValidationDataset(_ClaimsDataset):
    """Dev claims, with evidence text looked up in the full evidence corpus."""

    def __init__(self):
        super().__init__(DEV_CLAIMS_JSON_PATH, EVIDENCE_JSON_PATH)


class ClassificationTestDataset(_ClaimsDataset):
    """Unlabelled test claims; items carry label -1 when "claim_label" is absent.

    NOTE(review): claims without an "evidences" key are classified on claim text
    alone — confirm whether RETRIEVAL_TEST_CLAIMS_JSON_PATH should be used instead.
    """

    def __init__(self):
        super().__init__(TEST_CLAIMS_UNLABELLED_JSON_PATH, EVIDENCE_JSON_PATH)

    def _label(self, claim):
        label = claim.get("claim_label")
        # -1 is a placeholder; predict() never uses it as a training target.
        return LABEL_TO_INT[label] if label is not None else -1

def predict(output_path=None):
    """Classify every test claim and write the labelled claims to JSON.

    Args:
        output_path: Destination file for ``{claim_id: claim_dict}``, where each
            claim dict gains a "claim_label" string. Defaults to
            ``TEST_CLAIMS_LABELLED_JSON_PATH`` so the unlabelled input file is
            not overwritten.

    Returns:
        The written ``{claim_id: claim_dict}`` mapping.
    """
    if output_path is None:
        output_path = TEST_CLAIMS_LABELLED_JSON_PATH

    test_set = ClassificationTestDataset()
    dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    model.eval()

    prediction = {}
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
            predicted_labels = model(batch).argmax(-1).tolist()

            # collate_fn stores the raw claim dicts under "claims".
            for claim, claim_id, predicted_label in zip(
                batch["claims"], batch["claim_ids"], predicted_labels
            ):
                claim["claim_label"] = INT_TO_LABEL[predicted_label]
                prediction[claim_id] = claim

    with open(output_path, "w", encoding="utf-8") as file:
        json.dump(prediction, file)

    return prediction
    
def validate(dataloader, model):
    """Return the classification accuracy (a float in [0, 1]) of ``model`` on ``dataloader``.

    Runs in eval mode without gradient tracking and always switches the model
    back to training mode afterwards. Returns 0.0 for an empty dataloader.
    """
    model.eval()  # switch model to the evaluation mode

    total_examples = 0
    total_correct = 0

    try:
        # no_grad: evaluation must not build autograd graphs (wasted memory/time).
        with torch.no_grad():
            for data_batch in tqdm(dataloader):
                data_batch = {key: value.to(device) if torch.is_tensor(value) else value
                              for key, value in data_batch.items()}

                # Accept labels either as a tensor or as a plain sequence of ints.
                labels = torch.as_tensor(data_batch["label"], dtype=torch.long, device=device)
                predicted_labels = model(data_batch).argmax(-1)

                total_correct += (predicted_labels == labels).sum().item()
                total_examples += predicted_labels.size(0)
    finally:
        model.train()  # switch model back to training mode

    return total_correct / total_examples if total_examples else 0.0

def train(train_dataloader, val_dataloader, save_dir):
    """Train the global ``model`` for EPOCH epochs, checkpointing the best validation accuracy.

    After each epoch the model is evaluated with ``validate``. Whenever accuracy
    improves, the state dict is saved to ``<save_dir>/cls_checkpoint.bin``.

    Returns:
        The best validation accuracy seen.
    """
    model.train()
    maximum_accuracy = 0

    for _ in range(EPOCH):
        for batch in tqdm(train_dataloader):
            batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
            # CrossEntropyLoss needs a long tensor target; labels may arrive as a tuple.
            labels = torch.as_tensor(batch["label"], dtype=torch.long, device=device)

            logits = model(batch)
            # Default reduction="mean" already averages over the batch, so no
            # extra division by BATCH_SIZE.
            loss = loss_function(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        accuracy = validate(val_dataloader, model)
        if accuracy > maximum_accuracy:
            maximum_accuracy = accuracy
            torch.save(model.state_dict(), os.path.join(save_dir, "cls_checkpoint.bin"))
            print("maximum_accuracy", accuracy)

    return maximum_accuracy

In [None]:
save_dir = makedir("cls")

train_set = ClassificationTrainDataset()
val_set = ClassificationValidationDataset()

# Shuffle training batches so the model does not see claims in a fixed order.
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)

train(train_dataloader, val_dataloader, save_dir)

# Distinct flag name: assigning to `predict` would shadow the predict() function
# and make `predict()` raise TypeError ('bool' object is not callable).
run_predict = False
if run_predict:
    predict()