In [1]:
import json
import os
import re

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertTokenizer, BertTokenizerFast, BertForMaskedLM, BertModel, BertConfig,BertForSequenceClassification
from transformers import AutoModel, AutoTokenizer

## 0. Initialize params

In [2]:
# --- Device (gpu/cpu) ---
# Preferred GPU ordinal. Falls back to cuda:0 when this machine has fewer GPUs
# (a missing ordinal would otherwise raise an invalid-device error on first use),
# and to the CPU when CUDA is unavailable.
cuda_index = 2
if torch.cuda.is_available():
    gpu_index = cuda_index if cuda_index < torch.cuda.device_count() else 0
    device = torch.device('cuda:' + str(gpu_index))
else:
    device = torch.device('cpu')

# --- Modeling ---
batch_size = 64
shuffle = True

## 1. Load and explore the data

In [3]:
def load_json(path, file):
    """Load and return the parsed contents of the JSON file ``path/file``.

    The file is decoded as UTF-8 explicitly: the news articles contain
    non-ASCII characters (curly quotes, em dashes), and relying on the
    platform's default locale encoding can fail or mis-decode them.

    :param path: directory containing the file
    :param file: file name, joined onto ``path`` with ``os.path.join``
    :return: the deserialized JSON value
    """
    with open(os.path.join(path, file), 'r', encoding='utf-8') as json_file:
        return json.load(json_file)

In [4]:
'''
read config file
'''
# Run configuration: data directory, train/dev file names, model name.
# NOTE(review): `config` is rebound to a BertConfig in the modeling section,
# and the 'model_name' entry shown below is not used -- 'bert-base-uncased' is
# hard-coded further down. Confirm which model is intended.
config = load_json('./', 'config.json')
config

{'model_name': 'xlm-roberta-base',
 'train_file': 'Train_Numerical_Reasoning.json',
 'dev_file': 'Dev_Numerical_Reasoning.json',
 'path': '../NumEval - Task 3/'}

In [5]:
# Unpack the data location settings from the run configuration.
path, train_file, dev_file = (config[key] for key in ('path', 'train_file', 'dev_file'))

In [6]:
'''
read train and dev files
'''
# Each split is a list of records with 'news', 'masked headline',
# 'calculation' and 'ans' keys (see the sample records displayed below).
train_data = load_json(path, train_file)
dev_data = load_json(path, dev_file)

In [7]:
# Peek at the first training record.
train_data[0]

{'news': "(Oct 7, 2014  12:40 PM CDT) As of Jan. 1, Walmart will no longer offer 30,000 of its employees health insurance. Bloomberg notes that's about 2% of its workforce. The move comes as a reaction to the company's rising health care costs as far more of its employees and their families enrolled in its health care plans than it had expected following the ObamaCare rollout. The AP reports those costs will surge $500 million this fiscal year, $170 million more than had been estimated. Those affected are employees who average fewer than 30 hours of work per week; the Wall Street Journal explains they were grandfathered in when Walmart in 2012 stopped offering insurance to new hires who didn't exceed the 30-hour threshold. A benefits expert says Walmart is actually late to the game in terms of cutting insurance to some part-time workers; Target, the Home Depot, and others have already done so. Meanwhile, Walmart's full time workers will see their premiums rise in 2015. Premiums for the

In [8]:
# A 'Copy(...)' example: the answer appears verbatim in the article.
train_data[1]

{'news': "(Oct 29, 2013  8:15 AM CDT) Dax Shepard and Kristen Bell got married at the Beverly Hills courthouse, in a ceremony about as different from Kim Kardashian's last wedding extravaganza as it is possible to be. As Shepard revealed last night on Jimmy Kimmel Live, the whole thing—including the fuel it took to get to the courthouse—cost $142.  It was just Kristen and I at this lonely courthouse,  he said, so friends showed up afterward with a cake reading, in icing,  The World's Worst Wedding.   How many people can say they threw the world's worst wedding?  Shepard asked.",
 'masked headline': 'Dax Shepard: Wedding to Kristen Bell Cost $____',
 'calculation': 'Copy(142)',
 'ans': '142'}

In [18]:
def create_gt(seqs):
    """Build ground-truth artifacts from raw records.

    :param seqs: iterable of records, each a dict with 'ans', 'calculation'
        and 'masked headline' keys
    :return: tuple ``(target, number_type, number_gt)`` of parallel lists:
        - target: the headline with its '____' placeholder filled by the answer
        - number_type: 1 if the calculation mentions "copy" (case-insensitive), else 0
        - number_gt: the answer as a string
    """
    target = []
    number_type = []
    number_gt = []

    # Iterate the records directly (the old enumerate index was unused); tqdm
    # can then read len(seqs) and render a proper progress bar.
    for data in tqdm(seqs):
        ans = str(data['ans'])
        target.append(data['masked headline'].replace('____', ans))
        number_gt.append(ans)
        number_type.append(1 if "copy" in data['calculation'].lower() else 0)
    return target, number_type, number_gt

In [19]:
# Filled headlines, copy/non-copy flags and answers for the training split.
train_target, train_number_type, train_number_gt = create_gt(train_data)

21157it [00:00, 676358.33it/s]


In [20]:
# Filled headlines, copy/non-copy flags and answers for the dev split.
dev_target, dev_number_type, dev_number_gt = create_gt(dev_data)

2572it [00:00, 513018.35it/s]


In [23]:
# File-name suffixes for the ground-truth artifacts; write() prefixes each
# with the split name, e.g. 'train_target.txt'.
target_file_path = 'target.txt'
type_path = 'number_type.txt'
gt_path = 'number_gt.txt'

def write(tgt, numtype, numgt, trainval):
    """Write one split's ground-truth lists to text files, one item per line.

    Produces ``<trainval>_target.txt``, ``<trainval>_number_type.txt`` and
    ``<trainval>_number_gt.txt`` in the current working directory.

    :param tgt: filled-in headlines
    :param numtype: copy (1) / non-copy (0) flags
    :param numgt: answers
    :param trainval: split name used as the file-name prefix (e.g. 'train', 'dev')
    """
    outputs = (
        (target_file_path, tgt),
        (type_path, numtype),
        (gt_path, numgt),
    )
    for suffix, items in outputs:
        # UTF-8 explicitly: headlines can contain non-ASCII characters.
        with open(trainval + '_' + suffix, 'w', encoding='utf-8') as file:
            file.writelines(f"{item}\n" for item in items)
            
# Persist ground truth for both splits (train_*.txt / dev_*.txt in the working directory).
write(train_target, train_number_type, train_number_gt, 'train')
write(dev_target, dev_number_type, dev_number_gt, 'dev')

## 2. Data processing and tokenization

In [9]:
# Fast tokenizer for the 'bert-base-uncased' checkpoint used throughout this notebook.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [10]:
# The tokenizer's mask string ('[MASK]' here, per the output below).
MASK_TOKEN = tokenizer.mask_token
MASK_TOKEN

'[MASK]'

In [11]:
def process_data(sample, replace_token='mask', task='train'):
    '''
    Turn one raw record into a model-ready prompt.

    Only training prompts name the calculation (e.g. "perform Copy(142)"):
    the reasoning hint is teacher-forced at train time and withheld otherwise.

    :param sample: dict with 'news', 'masked headline', 'calculation', 'ans'
    :param replace_token: 'mask' fills the headline blank with MASK_TOKEN;
        any other value fills it with the answer itself
    :param task: 'train' includes the calculation in the instruction; any
        other value omits it
    :return: dict with 'input_prompt', 'ans' (str) and 'reasoning'
        (1 when the calculation mentions "copy", case-insensitive, else 0)
    '''
    news = sample['news']
    headline_template = sample['masked headline']
    calculation = sample['calculation']
    ans = str(sample['ans'])

    fill = MASK_TOKEN if replace_token == 'mask' else ans

    if task == 'train':
        instruction = "Given the news article, perform " + calculation + " to fill in the mask token : "
    else:
        instruction = "Given the news article, fill in the mask token : "
    input_prompt = instruction + "\n" + news + " " + headline_template.replace('____', fill)

    reasoning = int("copy" in calculation.lower())
    return {"input_prompt": input_prompt, "ans": ans, "reasoning": reasoning}

def tokenize(sentence, max_length=512):
    """Tokenize one prompt into fixed-length PyTorch tensors with a batch axis of 1.

    Uses the tokenizer's ``__call__`` interface, because ``encode_plus`` is
    deprecated in Hugging Face tokenizers. For a single string, ``__call__``
    returns the same encoding. Prompts longer than ``max_length`` are truncated
    and shorter ones are padded. Truncation drops the END of the prompt, which
    is where the headline sits.

    :param sentence: prompt text
    :param max_length: sequence length after padding/truncation (default 512,
        the previously hard-coded value)
    :return: encoding with 'input_ids' and 'attention_mask' tensors of shape (1, max_length)
    """
    return tokenizer(sentence,
                     max_length=max_length,
                     padding='max_length',
                     truncation=True,
                     return_tensors='pt',
                     return_attention_mask=True)

In [12]:
# Both splits put the real answer into the headline (replace_token='ans');
# replace_with_mask_ids later overwrites those answer tokens with the mask id.
# Only the train prompts name the calculation (task defaults to 'train').
train_processed = [process_data(sample, replace_token='ans') for sample in tqdm(train_data)]
dev_processed = [process_data(sample, replace_token='ans', task='dev') for sample in tqdm(dev_data)]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21157/21157 [00:00<00:00, 263657.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2572/2572 [00:00<00:00, 265904.61it/s]


In [13]:
# A processed training prompt: the calculation appears after "perform".
train_processed[20]

{'input_prompt': "Given the news article, perform Copy(5) to fill in the mask token : \n(Jul 27, 2010  9:52 AM CDT) Previous reports estimated that Chelsea Clinton’s wedding would cost a cool $2 million, but the real number is probably more like $3 million to $5 million. Wedding experts run down the costs for the New York Daily News, from $600,000 air-conditioned tents to the $150-a-pop invitations and $100 place settings for each of the 500 guests. At the more conservative $3 million estimate, the total cost comes to $6,000 per guest. But it’s not all designer dresses, fancy food, and $15,000 port-a-potties (yes, $15,000 for  outhouses  that are much nicer than your bathroom at home—TMZ has pictures). Because of the high-profile nature of the event, security will probably run at least $200,000 (even though the White House confirms President Obama won't attend)—or more if they opt to shut down air space or pay police to monitor traffic. Overcome with Clinton wedding fever? Click here f

In [14]:
# Dev prompts omit the calculation hint.
dev_processed[4]

{'input_prompt': "Given the news article, fill in the mask token : \n(Oct 16, 2014  3:02 AM CDT) Tristen Kurilla, the Pennsylvania 10-year-old who confessed to killing a 90-year-old woman over the weekend, is still in an adult prison and for now, his family thinks that could be the best place for him. His attorney withdrew a bail request yesterday saying the  family just doesn't feel comfortable for numerous reasons,  including concern for the family of Helen Novak, whom Tristen allegedly killed with a walking stick, reports WBRE. The lawyer also cited the family's worries over  the supervision, of work, of everything.  The district attorney says the boy is not in the  general population. He is not in solitary confinement and he is not in isolation. He is in a cell which I believe is next to the infirmary where he can come and go from the cell to the infirmary.  The nearest juvenile detention facility is 80 miles away. The boy has been provided with coloring books and other forms of en

In [15]:
# Fixed-length (512) encodings of every prompt; long articles are truncated.
train_tokenized = [tokenize(sample['input_prompt']) for sample in tqdm(train_processed)]
dev_tokenized = [tokenize(sample['input_prompt']) for sample in tqdm(dev_processed)]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21157/21157 [00:17<00:00, 1235.63it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2572/2572 [00:01<00:00, 1383.59it/s]


In [16]:
'''
get input ids for all the masked token
'''
# Token ids of each answer string on its own. [1:-1] drops the first and last
# ids, presumably the [CLS]/[SEP] special tokens.
# NOTE(review): assumes the answer tokenizes the same standalone as inside the
# headline -- confirm for answers glued to punctuation or units.
masked_train_input_ids = [tokenizer(psamples['ans'])['input_ids'][1:-1] for psamples in tqdm(train_processed)]
masked_dev_input_ids = [tokenizer(psamples['ans'])['input_ids'][1:-1] for psamples in tqdm(dev_processed)]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21157/21157 [00:00<00:00, 21170.36it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2572/2572 [00:00<00:00, 19548.02it/s]


In [17]:
'''
MLM labels
'''
# Copies of the UNMASKED input ids. They must be taken before
# replace_with_mask_ids overwrites the answer positions in place, so the labels
# keep the original answer tokens.
train_ans_label = [inp['input_ids'].clone().detach() for inp in train_tokenized]
dev_ans_label = [inp['input_ids'].clone().detach() for inp in dev_tokenized]

In [18]:
def find_last_index_of_continuous_subset(lst, subset):
    """Return the start index of the LAST contiguous occurrence of ``subset`` in ``lst``.

    Candidate start positions are scanned from right to left. Returns -1 when
    ``subset`` does not occur; an empty ``subset`` matches at ``len(lst)``.
    """
    width = len(subset)
    start = len(lst) - width
    while start >= 0:
        if lst[start:start + width] == subset:
            return start
        start -= 1
    return -1

In [19]:
# Vocabulary id of the mask token (103 for bert-base-uncased, per the output below).
mask_tok_id = tokenizer.mask_token_id
mask_tok_id

103

In [45]:
def replace_with_mask_ids(tokenized_seq, to_mask_ids):
    """Overwrite each sample's answer tokens with ``mask_tok_id``, in place.

    For sample ``i``, the LAST contiguous occurrence of ``to_mask_ids[i]`` in
    its input ids is masked. The last occurrence is used because the headline
    containing the answer is appended at the end of the prompt.

    When the answer tokens are not found, nothing is masked for that sample and
    its position list is empty. This happens, for example, when truncation to
    the max length cut the headline off. A count of such samples is printed.

    :param tokenized_seq: list of encodings whose 'input_ids' is a (1, seq_len)
        tensor; mutated in place
    :param to_mask_ids: list of token-id lists, one per sample
    :return: ``(tokenized_seq, all_masked_inds)`` -- the same (mutated) list and,
        per sample, the list of masked positions (possibly empty)
    """
    all_masked_inds = []
    n_missing = 0

    for i, token in enumerate(tqdm(tokenized_seq)):
        total_ls = token['input_ids'][0].tolist()
        sublist = to_mask_ids[i]

        start = find_last_index_of_continuous_subset(total_ls, sublist)
        if start < 0:
            # Never use the -1 sentinel as a start index: range(-1, ...) would
            # mask the final token and the start of the sequence instead.
            masked_inds = []
            n_missing += 1
        else:
            masked_inds = list(range(start, start + len(sublist)))
        all_masked_inds.append(masked_inds)

        if masked_inds:
            tokenized_seq[i]['input_ids'][0, masked_inds] = mask_tok_id

    if n_missing:
        print(f"replace_with_mask_ids: answer tokens not found in {n_missing} "
              f"of {len(tokenized_seq)} samples; left unmasked.")
    return tokenized_seq, all_masked_inds

In [22]:
'''
Model inputs
'''
# Mutates train_tokenized / dev_tokenized in place (train_inp is the same list
# object), which is why the MLM labels were cloned beforehand.
train_inp, train_masked_inds = replace_with_mask_ids(train_tokenized, masked_train_input_ids)
dev_inp, dev_masked_inds = replace_with_mask_ids(dev_tokenized, masked_dev_input_ids)

21157it [00:01, 13482.64it/s]
2572it [00:00, 12309.54it/s]


In [34]:
'''
Reasoning labels
'''
# 1 = calculation mentions "copy", 0 = any other operation (set in process_data).
train_reasoning_label = [sample['reasoning'] for sample in tqdm(train_processed)]
dev_reasoning_label = [sample['reasoning'] for sample in tqdm(dev_processed)]

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21157/21157 [00:00<00:00, 1388519.45it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2572/2572 [00:00<00:00, 2276376.85it/s]


In [46]:
class NumDataset(Dataset):
    """Map-style dataset that zips per-sample model inputs and labels.

    Item ``idx`` is the 5-tuple
    ``(input_ids, attention_mask, mask_label, reasoning_label, masked_inds)``
    taken from the parallel sequences given to the constructor. The dataset
    length is ``len(input_ids)``.

    NOTE(review): ``masked_inds`` entries differ in length across samples, so
    batching them with the default collate function is expected to fail; use a
    custom ``collate_fn``.
    """

    def __init__(self, input_ids, attn_masks, mask_label, reasoning_label, masked_inds):
        """Store the parallel per-sample sequences as-is (no copying)."""
        self.input_ids = input_ids
        self.attention_mask = attn_masks
        self.mask_labels = mask_label
        self.reasoning_label = reasoning_label
        self.masked_inds = masked_inds

    def __len__(self):
        """Number of samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a 5-tuple (see the class docstring)."""
        columns = (self.input_ids, self.attention_mask, self.mask_labels,
                   self.reasoning_label, self.masked_inds)
        return tuple(column[idx] for column in columns)

def extract_ids_and_masks(inputs):
    """Stack per-sample encodings into batch tensors on ``device``.

    :param inputs: list of encodings whose 'input_ids' / 'attention_mask' are
        (1, seq_len) tensors
    :return: ``(input_ids, attention_mask)``, each (num_samples, seq_len), on ``device``
    """
    def _gather(key):
        # (N, 1, L) -> (N, L): drop the singleton batch axis each encoding carries.
        return torch.stack([encoding[key] for encoding in inputs]).squeeze(1)

    ids, masks = _gather('input_ids'), _gather('attention_mask')
    return ids.to(device), masks.to(device)

In [41]:
# NOTE(review): this moves the whole split onto `device` up front; consider
# moving per batch instead if GPU memory is tight.
train_inp_ids, train_attn_masks = extract_ids_and_masks(train_inp)
dev_inp_ids, dev_attn_masks = extract_ids_and_masks(dev_inp)

In [42]:
params = {
    'batch_size': batch_size,
    'shuffle': shuffle
}


def collate_num_batch(batch):
    """Collate NumDataset samples into batch tensors.

    The per-sample masked-position lists differ in length, so they cannot be
    stacked by the default collate function. They are turned into a boolean
    token mask with the same shape as the stacked ``input_ids`` instead.

    :return: ``(input_ids, attention_mask, mask_labels, reasoning_labels, token_mask)``
    """
    input_ids, attn_masks, mask_labels, reasoning_labels, masked_inds = zip(*batch)
    input_ids = torch.stack(input_ids)
    token_mask = torch.zeros(input_ids.shape, dtype=torch.bool, device=input_ids.device)
    for row, positions in enumerate(masked_inds):
        token_mask[row, list(positions)] = True
    return (input_ids,
            torch.stack(attn_masks),
            torch.stack(mask_labels),
            torch.tensor(reasoning_labels),
            token_mask)


train_set = NumDataset(train_inp_ids, train_attn_masks, train_ans_label, train_reasoning_label, train_masked_inds)
training_generator = torch.utils.data.DataLoader(train_set, collate_fn=collate_num_batch, **params)

# Dev inputs must be paired with the DEV reasoning labels (the train labels were
# used before). The dev set is not shuffled, so evaluation passes are reproducible.
dev_set = NumDataset(dev_inp_ids, dev_attn_masks, dev_ans_label, dev_reasoning_label, dev_masked_inds)
dev_generator = torch.utils.data.DataLoader(dev_set, batch_size=batch_size, shuffle=False,
                                            collate_fn=collate_num_batch)

## 3. Modeling

In [38]:
# NOTE(review): rebinds `config` (previously the JSON run config) to the BERT
# architecture config. output_hidden_states=True is not needed by NumGenModel,
# which only reads pooler_output and last_hidden_state.
config = BertConfig.from_pretrained("bert-base-uncased", output_hidden_states=True)

In [59]:
class BinaryReasoningClassifier(nn.Module):
    """
    2-class classification head: copy (1) vs. not_copy (0).

    Maps a pooled representation to log-probabilities over the two classes,
    for use with ``nn.NLLLoss``.
    """

    def __init__(self, hidden):
        """
        :param hidden: size of the input representation (BERT hidden size)
        """
        super(BinaryReasoningClassifier, self).__init__()
        # Size the first layer from `hidden`. It previously read the global
        # `config`, ignoring this argument; that global name is rebound in this
        # notebook (JSON dict vs. BertConfig).
        self.binary_classification_head = nn.Sequential(
            nn.Linear(hidden, 128),
            nn.ReLU(),
            nn.Dropout(0.01),
            nn.Linear(128, 2),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x):
        """Map a (..., hidden) tensor to (..., 2) log-probabilities."""
        return self.binary_classification_head(x)

class MaskedPredictor(torch.nn.Module):
    """
    Token-prediction head for masked positions: a vocabulary-sized linear
    projection followed by log-softmax, i.e. an n-way classifier with
    n = vocab_size.
    """

    def __init__(self, d_model, vocab_size):
        """
        :param d_model: hidden size of the encoder output
        :param vocab_size: number of output classes (vocabulary size)
        """
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for every position of ``x``."""
        logits = self.linear(x)
        return self.softmax(logits)
    
class NumGenModel(nn.Module):
    
    def __init__(self, config):
        
        super(NumGenModel, self).__init__()
        self.bert = BertModel(config)
        
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        
        self.classifier = BinaryReasoningClassifier(hidden=self.hidden_size)
        self.mask_lm = MaskedPredictor(d_model=self.hidden_size, vocab_size=self.vocab_size)

    def forward(self, input_ids, attention_mask):
        
        # Get the BERT model outputs
        outputs = self.bert(input_ids, attention_mask=attention_mask)
    
        # Extract the pooled output (CLS token) for binary classification
        pooled_output = outputs.pooler_output
        last_hidden = output.last_hidden_state
        
        reasoning_preds = self.classifier(pooled_output)
        masked_token_preds = self.mask_lm(last_hidden)
        
        return reasoning_preds, masked_token_preds

In [47]:
class ScheduledOptim():
    '''Learning-rate warm-up wrapper around a torch optimizer.

    Before every optimizer step, the learning rate of all parameter groups is set to
    ``d_model ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)``.
    This is a linear ramp over ``n_warmup_steps`` steps, followed by
    inverse-square-root decay. Any ``lr`` the wrapped optimizer was created
    with is therefore overridden.
    '''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        # Plain Python arithmetic: numpy (`np`) is never imported in this notebook.
        self.init_lr = d_model ** -0.5

    def step_and_update_lr(self):
        "Advance the schedule, then step the inner optimizer."
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        step = self.n_current_steps
        if step <= 0:
            return 0.0  # no step taken yet
        if self.n_warmup_steps <= 0:
            return step ** -0.5  # no warm-up phase: pure inverse-sqrt decay
        return min(step ** -0.5, step * self.n_warmup_steps ** -1.5)

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

            
class Trainer:
    def __init__(
        self, 
        model, 
        train_dataloader, 
        dev_dataloader, 
        lr=1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=10000,
        log_freq=10
        ):

        self.model = model
        self.train_data = train_dataloader
        self.dev_data = dev_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps)

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
    
    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.dev_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        
        avg_loss = 0.0
        total_correct = 0
        total_element = 0
        
        mode = "train" if train else "test"

        # progress bar
        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        for i, batch in data_iter:

            inp_ids, attention_mask, mask_label, reasoning_label, masked_inds = batch
            
            reasoning_pred, token_pred = self.model(input_ids=inp_ids, attention_mask=attention_mask)
            
            reasoning_pred = reasoning_pred.argmax(axis=1)
            
            next_loss = self.criterion(reasoning_pred, reasoning_label)
            
            # transpose to (m, vocab_size, seq_len) vs (m, seq_len)
            # criterion(mask_lm_output.view(-1, mask_lm_output.size(-1)), data["bert_label"].view(-1))
            masked_token_preds = token_pred[masked_inds]
            mask_label = mask_label[masked_inds]
            
            mask_loss = self.criterion(masked_token_preds, mask_label)
            # mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

            # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
            loss = next_loss + mask_loss

            # 3. backward and optimization only in train
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy
            correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["is_next"].nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))
        print(
            f"EP{epoch}, {mode}: \
            avg_loss={avg_loss / len(data_iter)}, \
            total_acc={total_correct * 100.0 / total_element}"
        ) 

In [61]:
# The input tensors were placed on `device` by extract_ids_and_masks, so the
# model must live there too; otherwise the forward pass hits a device mismatch.
# NOTE(review): NumGenModel(config) builds a randomly initialized BERT encoder,
# not the pretrained checkpoint.
model = NumGenModel(
    config=config
).to(device)

In [None]:
# Wire the model to the train/dev loaders (Adam + warm-up schedule are built inside Trainer).
trainer = Trainer(model, training_generator, dev_generator)

In [None]:
epochs = 1

for epoch in range(epochs):
    # `trainer` is the Trainer built above; `bert_trainer` was never defined.
    trainer.train(epoch)
    # Evaluate on the dev split after each epoch.
    trainer.test(epoch)

In [21]:
class NumericalUnderstandingModel(nn.Module):
    def __init__(self, model_name, num_labels_head1=1, num_labels_head2=2):
        super(NumericalUnderstandingModel, self).__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Freeze the pre-trained transformer parameters
        for param in self.transformer.parameters():
            param.requires_grad = False

        # Classification head 1: Fill in the mask token
        self.head1 = nn.Linear(self.transformer.config.hidden_size, num_labels_head1)

        # Classification head 2: Copy vs. Operation
        self.head2 = nn.Linear(self.transformer.config.hidden_size, num_labels_head2)

    def forward(self, input_text):
        # Tokenize input text and get transformer outputs
        input_ids = self.tokenizer(input_text, return_tensors='pt')['input_ids']
        outputs = self.transformer(input_ids)

        # Get the representation of the [MASK] token (or other relevant token)
        mask_token_index = input_text.index('[MASK]')  # Replace with the actual token used
        mask_token_representation = outputs.last_hidden_state[:, mask_token_index, :]

        # Classification head 1: Fill in the mask token
        head1_output = self.head1(mask_token_representation)

        # Classification head 2: Copy vs. Operation
        pooled_output = outputs.pooler_output
        head2_output = self.head2(pooled_output)

        return head1_output, head2_output

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Define your numerical understanding dataset and DataLoader
class NumericalDataset(Dataset):
    """Pairs raw texts with labels; each item is a ``{'text': ..., 'label': ...}`` dict."""

    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Number of samples (the length of ``texts``)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with 'text' and 'label' keys."""
        return dict(text=self.texts[idx], label=self.labels[idx])

# FIXME: placeholders. `[...]` is a one-element list holding Ellipsis, so the
# training loop below only works once real texts/labels are supplied here.
numerical_texts = [...]  # Replace with your numerical dataset texts
numerical_labels = [...]  # Replace with your numerical dataset labels
numerical_dataset = NumericalDataset(numerical_texts, numerical_labels)
numerical_dataloader = DataLoader(numerical_dataset, batch_size=8, shuffle=True)

# Define your headline generation dataset and DataLoader
class HeadlineDataset(Dataset):
    """Pairs source texts with target headlines; each item is ``{'text': ..., 'headline': ...}``."""

    def __init__(self, texts, headlines):
        self.texts = texts
        self.headlines = headlines

    def __len__(self):
        """Number of samples (the length of ``texts``)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with 'text' and 'headline' keys."""
        return dict(text=self.texts[idx], headline=self.headlines[idx])

# FIXME: placeholders. `[...]` is a one-element list holding Ellipsis; supply
# real texts/headlines before running the headline training loop.
headline_texts = [...]  # Replace with your headline dataset texts
headline_headlines = [...]  # Replace with your headline dataset headlines
headline_dataset = HeadlineDataset(headline_texts, headline_headlines)
headline_dataloader = DataLoader(headline_dataset, batch_size=8, shuffle=True)

# Fine-tuning Phase 1: Numerical Understanding Task
def fine_tune_numerical_task(model, dataloader, num_epochs=3, lr=1e-5, text_tokenizer=None):
    """Fine-tune a sequence-classification model on ``{'text', 'label'}`` batches.

    The model must accept the tokenizer's outputs plus ``labels=`` and return an
    object with a ``.loss`` attribute; it computes its own loss, so no separate
    criterion is used here.

    :param model: model to train (switched to train mode)
    :param dataloader: iterable of batches supporting ``len``
    :param num_epochs: number of passes over the data
    :param lr: AdamW learning rate
    :param text_tokenizer: tokenizer for the texts; defaults to the
        notebook-level ``tokenizer`` when omitted
    """
    encode = tokenizer if text_tokenizer is None else text_tokenizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    device = next(model.parameters()).device

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in tqdm(dataloader, desc=f'Epoch {epoch + 1}'):
            inputs = encode(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
            # Default collation already yields a tensor; as_tensor avoids the
            # copy-construct warning that torch.tensor() emits for tensor input.
            labels = torch.as_tensor(batch['label'], device=device)

            optimizer.zero_grad()
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()

            loss.backward()
            optimizer.step()

        avg_loss = total_loss / max(len(dataloader), 1)
        print(f'Epoch {epoch + 1}, Average Loss: {avg_loss}')

# Load pre-trained encoder for numerical understanding
encoder_model_name = 'bert-base-uncased'  # Replace with your desired encoder model
# NOTE(review): rebinds the notebook-wide `tokenizer` (the BertTokenizerFast used earlier).
tokenizer = AutoTokenizer.from_pretrained(encoder_model_name)
encoder_model = AutoModelForSequenceClassification.from_pretrained(encoder_model_name, num_labels=2)

# Fine-tune the encoder on the numerical understanding task
fine_tune_numerical_task(encoder_model, numerical_dataloader)

# Fine-tuning Phase 2: Headline Generation Task
# Load pre-trained decoder for headline generation
decoder_model_name = 't5-small'  # Replace with your desired decoder model
# FIXME: t5-small is an encoder-decoder (seq2seq) checkpoint, so
# AutoModelForCausalLM is expected to reject it; AutoModelForSeq2SeqLM is the
# matching auto class.
decoder_model = AutoModelForCausalLM.from_pretrained(decoder_model_name)

# Combine the fine-tuned encoder with the pre-trained decoder
# FIXME: nn.Sequential feeds one positional value from module to module. It
# cannot accept the keyword inputs (`**inputs, labels=...`) used by the training
# loop, nor connect a classifier's output to a decoder. A real encoder-decoder
# design is needed here.
combined_model = nn.Sequential(encoder_model, decoder_model)

# Fine-tune the combined model on the headline generation task
# Note: Fine-tuning details may vary based on the decoder model; adjust as needed
def fine_tune_headline_generation(model, dataloader, num_epochs=3, lr=1e-5,
                                  text_tokenizer=None, target_tokenizer=None):
    """Fine-tune a generation model on ``{'text', 'headline'}`` batches.

    Target headlines are tokenized into ``labels``. Padding positions are set
    to -100, the index Hugging Face models exclude from their loss. The model
    must accept the tokenizer's outputs plus ``labels=`` and return an object
    with a ``.loss`` attribute.

    :param model: model to train (switched to train mode)
    :param dataloader: iterable of batches supporting ``len``
    :param num_epochs: number of passes over the data
    :param lr: AdamW learning rate
    :param text_tokenizer: tokenizer for the source texts; defaults to the
        notebook-level ``tokenizer``
    :param target_tokenizer: tokenizer for the headlines; defaults to the
        source tokenizer (it should match the decoder's vocabulary)
    """
    source_tok = tokenizer if text_tokenizer is None else text_tokenizer
    target_tok = source_tok if target_tokenizer is None else target_tokenizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    device = next(model.parameters()).device

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in tqdm(dataloader, desc=f'Epoch {epoch + 1}'):
            inputs = source_tok(batch['text'], return_tensors='pt', padding=True, truncation=True)
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
            targets = target_tok(batch['headline'], return_tensors='pt', padding=True, truncation=True)
            label_ids = targets['input_ids']
            if 'attention_mask' in targets:
                # Keep padded target positions out of the loss.
                label_ids = label_ids.masked_fill(targets['attention_mask'] == 0, -100)
            labels = label_ids.to(device)

            optimizer.zero_grad()
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()

            loss.backward()
            optimizer.step()

        avg_loss = total_loss / max(len(dataloader), 1)
        print(f'Epoch {epoch + 1}, Average Loss: {avg_loss}')

# Fine-tune the combined model on the headline generation task
# FIXME: combined_model is an nn.Sequential, which does not accept the keyword
# arguments (`**inputs, labels=...`) this training loop passes, so this call is
# expected to fail until a proper encoder-decoder model is substituted.
fine_tune_headline_generation(combined_model, headline_dataloader)
