# BART_fine-tune

This is the code to fine-tune the [**facebook/bart-base**](https://huggingface.co/facebook/bart-base) pre-trained language model on the [**CLOTH**](https://www.cs.cmu.edu/~glai1/data/cloth/) or [**DGen**](https://github.com/DRSY/DGen) dataset.

* Paper: "CDGP: Automatic Cloze Distractor Generation based on Pre-trained Language Model"
* Author: AndyChiangSH
* Time: 2022/10/15
* GitHub: https://github.com/AndyChiangSH/CDGP

## Download datasets

### CLOTH

In [None]:
!wget https://github.com/AndyChiangSH/CDGP/raw/main/datasets/CLOTH.zip

In [None]:
!unzip ./CLOTH.zip -d ./CLOTH

### DGen

In [None]:
!wget https://github.com/AndyChiangSH/CDGP/raw/main/datasets/DGen.zip

In [None]:
!unzip ./DGen.zip -d ./DGen

## Data preprocessing

### CLOTH

In [None]:
import json

# Load the cleaned CLOTH training file. JSON text is UTF-8, so the encoding is
# pinned explicitly instead of relying on the platform's default encoding.
with open("./CLOTH/CLOTH_train_cleaned.json", "r", encoding="utf-8") as file:
    dataset = json.load(file)

# Sanity check: number of entries and the first entry.
print(len(dataset))
print(dataset[0])

### DGen

In [None]:
import json

# Load the cleaned DGen training file. JSON text is UTF-8, so the encoding is
# pinned explicitly instead of relying on the platform's default encoding.
# NOTE: this assigns the same `dataset` name as the CLOTH cell; the
# "Data masking" cell below works on whichever dataset was loaded last.
with open("./DGen/DGen_train_cleaned.json", "r", encoding="utf-8") as file:
    dataset = json.load(file)

# Sanity check: number of entries and the first entry.
print(len(dataset))
print(dataset[0])

### Data masking

In [None]:
from tqdm.notebook import tqdm
import os

# Parallel lists: input_list[i] is the masked sentence followed by
# " </s> <answer>", label_list[i] is the same text with the mask replaced by
# one distractor. Each question contributes one pair per distractor.
input_list = []
label_list = []

for record in tqdm(dataset):
    answer, distractors, sentence = (
        record["answer"],
        record["distractors"],
        record["sentence"],
    )
    # Swap the dataset's blank marker for the mask token and append the answer.
    masked = sentence.replace("**blank**", "<mask>") + " </s> " + answer
    for wrong_option in distractors:
        input_list.append(masked)
        label_list.append(masked.replace("<mask>", wrong_option))

In [None]:
# Sanity check: number of model inputs and the first ten of them.
print("input_list:", len(input_list))
print(input_list[:10])

In [None]:
# Sanity check: number of target sequences and the first ten of them.
print("label_list:", len(label_list))
print(label_list[:10])

## Fine-tune BART

In [None]:
!pip install transformers datasets

In [None]:
# Settings read by the fine-tuning and testing cells below.
PLM = "facebook/bart-base"  # checkpoint passed to from_pretrained()
BATCH_SIZE = 64             # DataLoader batch size
EPOCH = 1                   # number of passes over the dataloader
LR = 1e-4                   # Adam learning rate
MAX_LENGTH = 64             # tokenizer padding / truncation length

### Setup the Dataset

In [None]:
# Column name -> list of strings, in the layout Dataset.from_dict() receives below.
data_dic = dict(input=input_list, label=label_list)

In [None]:
from datasets import Dataset

# Build a `datasets.Dataset` from the "input"/"label" columns.
# NOTE: this rebinds `dataset`, replacing the raw JSON data loaded earlier.
dataset = Dataset.from_dict(data_dic)

In [None]:
# Number of (input, label) pairs in the Dataset.
print(len(dataset))

### Setup the DataLoader

In [None]:
from torch.utils.data import DataLoader

# Iterate the Dataset in shuffled batches of BATCH_SIZE examples.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Number of batches per epoch.
print(len(dataloader))

### Fine-tune the model

In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

# Tokenizer and model both start from the pre-trained PLM checkpoint;
# Adam optimises every model parameter with learning rate LR.
tokenizer = BartTokenizer.from_pretrained(PLM)
model = BartForConditionalGeneration.from_pretrained(PLM)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Use the first CUDA device when one is available, otherwise the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device)
print(device)

In [None]:
# progress bar: one tick per optimisation step
num_training_steps = EPOCH * len(dataloader)
progress_bar = tqdm(range(num_training_steps))

# from_pretrained() returns the model in evaluation mode (dropout disabled),
# so switch to training mode explicitly before fine-tuning.
model.train()

# start training
loss_history = []  # loss of every batch, plotted in the next section
for epoch in range(EPOCH):
    for batch in dataloader:
        inputs = tokenizer.batch_encode_plus(batch["input"], truncation=True, padding="max_length", max_length=MAX_LENGTH, return_tensors="pt")
        labels = tokenizer.batch_encode_plus(batch["label"], truncation=True, padding="max_length", max_length=MAX_LENGTH, return_tensors="pt")["input_ids"]
        # Sequences are padded to MAX_LENGTH; label positions set to -100 are
        # ignored by the model's loss, so padding does not dominate training.
        labels[labels == tokenizer.pad_token_id] = -100

        optimizer.zero_grad()
        output = model(**inputs.to(device), labels=labels.to(device))
        loss = output.loss
        loss_history.append(loss.item())
        loss.backward()
        optimizer.step()
        progress_bar.update(1)

    # Loss of the last batch processed in this epoch.
    print(f"[epoch {epoch+1}] loss: {loss.item()}")

### Show the loss line chart

In [None]:
# Per-batch losses recorded during training, and how many were recorded.
print(loss_history)
print(len(loss_history))

In [None]:
# paint training loss graph
import matplotlib.pyplot as plt

plt.plot(loss_history)
plt.title('Training loss')
plt.ylabel('loss')
plt.xlabel('batch')
plt.legend(['loss'], loc='upper right')
plt.show()

### Save the model

In [None]:
# If the model exposes a `.module` attribute, save that inner object instead
# (presumably for parallel wrappers; the model is not wrapped anywhere in the
# visible code).
# NOTE(review): the output directory name ends in "dgen" regardless of which
# dataset was loaded above; rename it when training on CLOTH.
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained("./cdgp-csg-bart-dgen")

### Delete the model

In [None]:
import gc

del model
del model_to_save
# Deleting the model names alone does not release GPU memory: the optimizer
# still references every parameter (plus its Adam state), and the variables
# left over from the last training step still reference GPU tensors. Drop
# those too; pop() tolerates names that were never created.
for _name in ("optimizer", "output", "loss", "logits", "inputs", "labels"):
    globals().pop(_name, None)
# Collect anything kept alive only by reference cycles, then hand the cached
# blocks back to the CUDA driver.
gc.collect()
torch.cuda.empty_cache()

## Testing

### Testing data

In [None]:
# Hand-written cloze questions for trying out the fine-tuned model.
# Each entry has:
#   "sentence":    the sentence with a <mask> token, followed by " </s> " and
#                  the answer (same layout as the training inputs built in the
#                  "Data masking" cell)
#   "answer":      the correct word for the blank
#   "distractors": a list of distractor words (empty for q21-q23)
# Only the "sentence" field is read by the cells visible in this notebook.
questions = {
    "q1": {
        "sentence": "To make Jane live a <mask> life, Mother was very careful about spending money. </s> happy",
        "answer": "happy",
        "distractors": ["poor", "busy", "sad"]
    },
    "q2": {
        "sentence": "<mask> , Jane didn't understand her. </s> However",
        "answer": "However",
        "distractors": ["Though", "Although", "Or"]
    },
    "q3": {
        "sentence": "Every day Mother was busy with her <mask> while Jane was studying at school, so they had little time to enjoy themselves. </s> work",
        "answer": "work",
        "distractors": ["writing", "housework", "research"]
    },
    "q4": {
        "sentence": "One day, Mother realized Jane was unhappy and even <mask> to her. </s> unfriendly",
        "answer": "unfriendly",
        "distractors": ["loyal", "kind", "cruel"]
    },
    "q5": {
        "sentence": "The old man was waiting for a ride across the <mask> . </s> river",
        "answer": "river",
        "distractors": ["town", "country", "island"]
    },
    "q6": {
        "sentence": "I felt uncomfortable and out of place as the professor carefully <mask> what she expected us to learn. </s> explained",
        "answer": "explained",
        "distractors": ["showed", "designed", "offered"]
    },
    "q7": {
        "sentence": "As I listened, I couldn't help but <mask> of my own oldest daughter. </s> think",
        "answer": "think",
        "distractors": ["speak", "talk", "hear"]
    },
    "q8": {
        "sentence": "As we were <mask> on the third floor for old people with Alzheimer, most of them stared off at the walls or floor. </s> singing",
        "answer": "singing",
        "distractors": ["meeting", "gathering", "dancing"]
    },
    "q9": {
        "sentence": "As we got <mask> with each song, she did as well. </s> louder",
        "answer": "louder",
        "distractors": ["higher", "nearer", "faster"]
    },
    "q10": {
        "sentence": "Mr. Petri, <mask> injured in the fire, was rushed to hospital. </s> seriously",
        "answer": "seriously",
        "distractors": ["blindly", "hardly", "slightly"]
    },
    "q11": {
        "sentence": "If an object is attracted to a magnet, the object is most likely made of <mask>. </s> metal",
        "answer": "metal",
        "distractors": ["wood", "plastic", "cardboard"]
    },
    "q12": {
        "sentence": "the main organs of the respiratory system are <mask>. </s> lungs",
        "answer": "lungs",
        "distractors": ["ovaries", "intestines", "kidneys"]
    },
    "q13": {
        "sentence": "The products of photosynthesis are glucose and <mask> else. </s> oxygen",
        "answer": "oxygen",
        "distractors": ["carbon", "hydrogen", "nitrogen"]
    },
    "q14": {
        "sentence": "frogs have <mask> eyelid membranes. </s> three",
        "answer": "three",
        "distractors": ["two", "four", "one"]
    },
    "q15": {
        "sentence": "the only known planet with large amounts of water is <mask>. </s> earth",
        "answer": "earth",
        "distractors": ["saturn", "jupiter", "mars"]
    },
    "q16": {
        "sentence": "<mask> is responsible for erosion by flowing water and glaciers. </s> gravity",
        "answer": "gravity",
        "distractors": ["kinetic", "electromagnetic", "weight"],
    },
    "q17": {
        "sentence": "Common among mammals and insects , pheromones are often related to <mask> type of behavior. </s> reproductive",
        "answer": "reproductive",
        "distractors": ["aggressive", "immune", "cardiac"]
    },
    "q18": {
        "sentence": "<mask> can reproduce by infecting the cell of a living host. </s> virus",
        "answer": "virus",
        "distractors": ["bacteria", "mucus", "carcinogens"]
    },
    "q19": {
        "sentence": "proteins are encoded by <mask>. </s> genes",
        "answer": "genes",
        "distractors": ["DNA", "RNA", "codons"]
    },
    "q20": {
        "sentence": "Producers at the base of ecological food webs are also known as <mask>. </s> autotrophic",
        "answer": "autotrophic",
        "distractors": ["endoscopic", "symbiotic", "mutualistic"],
    },
    "q21": {
        "sentence": "Today morning, I saw a <mask> sitting on the wall. </s> cat",
        "answer": "cat",
        "distractors": [],
    },
    "q22": {
        "sentence": "Ukrainian presidential adviser says situation is ' <mask> control' in suburbs and outskirts of Kyiv. </s> under",
        "answer": "under",
        "distractors": [],
    },
    "q23": {
        "sentence": "I don't think that after what is <mask> now, Ukraine has weak positions. </s> happening",
        "answer": "happening",
        "distractors": [],
    },
}

### Load the model

In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration

# The tokenizer is reloaded from the base PLM checkpoint (the save cell above
# only calls save_pretrained() on the model); the model weights come from the
# fine-tuned directory.
tokenizer = BartTokenizer.from_pretrained(PLM)
model = BartForConditionalGeneration.from_pretrained("./cdgp-csg-bart-dgen")
# Put the model in evaluation mode for inference.
model.eval()

### Generate distractors

In [None]:
from transformers import pipeline

# Fill-mask pipeline around the fine-tuned model, configured with top_k=10.
unmasker = pipeline("fill-mask", tokenizer=tokenizer, model=model, top_k=10)

In [None]:
# Run the pipeline on the first test question's masked sentence.
unmasker(questions["q1"]["sentence"])