In [1]:
import torch
import torch.nn as nn
from transformers import BartTokenizer, BartConfig

class CustomBartEncoder(nn.Module):
    def __init__(self, config):
        super(CustomBartEncoder, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder_layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.d_model, nhead=config.encoder_attention_heads
            ),
            num_layers=config.encoder_layers
        )

    def forward(self, input_ids, attention_mask=None):
        embeddings = self.embedding(input_ids)
        if attention_mask is not None:
            # Convert to boolean tensor
            attention_mask = attention_mask.bool()
        output = self.encoder_layers(
            embeddings.transpose(0, 1), src_key_padding_mask=attention_mask
        )
        return output.transpose(0, 1)


class CustomBartDecoder(nn.Module):
    def __init__(self, config):
        super(CustomBartDecoder, self).__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.decoder_layers = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.d_model, nhead=config.decoder_attention_heads
            ),
            num_layers=config.decoder_layers
        )

    def forward(self, input_ids, encoder_output, attention_mask=None, encoder_attention_mask=None):
        embeddings = self.embedding(input_ids)
        if attention_mask is not None:
            # Convert decoder attention mask to boolean
            attention_mask = attention_mask.bool()
        if encoder_attention_mask is not None:
            # Convert encoder attention mask to boolean
            encoder_attention_mask = encoder_attention_mask.bool()
        output = self.decoder_layers(
            embeddings.transpose(0, 1),
            encoder_output.transpose(0, 1),
            tgt_key_padding_mask=attention_mask,
            memory_key_padding_mask=encoder_attention_mask
        )
        return output.transpose(0, 1)


class CustomBartModel(nn.Module):
    def __init__(self, config):
        super(CustomBartModel, self).__init__()
        self.encoder = CustomBartEncoder(config)
        self.decoder = CustomBartDecoder(config)
        self.linear = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, input_ids, decoder_input_ids, attention_mask=None, decoder_attention_mask=None):
        encoder_output = self.encoder(input_ids, attention_mask=attention_mask)
        decoder_output = self.decoder(
            decoder_input_ids,
            encoder_output,
            attention_mask=decoder_attention_mask,
            encoder_attention_mask=attention_mask
        )
        logits = self.linear(decoder_output)
        return logits




# if __name__ == "__main__":
#     custom_config = BartConfig(
#       vocab_size=50265,          # Vocabulary size of the model
#       d_model=768,               # Dimensionality of the encoder/decoder layers
#       encoder_layers=4,          # Number of encoder layers
#       decoder_layers=4,          # Number of decoder layers
#       encoder_attention_heads=12, # Number of attention heads in encoder
#       decoder_attention_heads=12, # Number of attention heads in decoder
#       encoder_ffn_dim=3072,      # Feed-forward layer size in encoder
#       decoder_ffn_dim=3072,      # Feed-forward layer size in decoder
#       activation_function='gelu', # Activation function
#       max_position_embeddings=1024, # Maximum sequence length
#       dropout=0.1,               # Dropout rate
#       attention_dropout=0.1,     # Dropout rate for attention weights
#       use_cache=True             # Use cache during inference
#   )
#     run_inference(custom_config)

In [2]:
from torch.utils.data import DataLoader, Dataset
class ConversationDataset(Dataset):
    def __init__(self, dialogues, summaries, tokenizer, max_input_length=512, max_target_length=150):
        self.dialogues = dialogues
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dialogues)

    def __getitem__(self, idx):
        # Get the dialogue and summary for the given index
        dialogue = self.dialogues[idx]
        summary = self.summaries[idx]

        # Tokenize the dialogue and summary
        input_encodings = self.tokenizer(
            dialogue,
            max_length=150,  # Adjusted to match `max_position_embeddings`
            truncation=True,
            padding='max_length',
            return_tensors="pt"
        )

        target_encodings = self.tokenizer(
            summary,
            max_length=self.max_target_length,
            truncation=True,
            padding='max_length',
            return_tensors="pt"
        )

        return {
            'input_ids': input_encodings['input_ids'].squeeze(0),  # Remove the batch dimension
            'attention_mask': input_encodings['attention_mask'].squeeze(0),
            'labels': target_encodings['input_ids'].squeeze(0)  # Ensure correct shape
        }




In [3]:
%pip install datasets

from datasets import load_dataset

# Load the Gigaword dataset with custom code execution enabled
dataset = load_dataset("gigaword", trust_remote_code=True)
print(dataset)

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.2.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


DatasetDict({
    train: Dataset({
        features: ['document', 'summary'],
        num_rows: 3803957
    })
    validation: Dataset({
        features: ['document', 'summary'],
        num_rows: 189651
    })
    test: Dataset({
        features: ['document', 'summary'],
        num_rows: 1951
    })
})


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BartTokenizer, BartConfig
from torch.optim import AdamW
from tqdm import tqdm

# Dataset for sequence-to-sequence tasks
class Seq2SeqDataset(Dataset):
    def __init__(self, tokenizer, input_texts, target_texts, max_length):
        self.tokenizer = tokenizer
        self.input_texts = input_texts
        self.target_texts = target_texts
        self.max_length = max_length

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, idx):
        input_text = self.input_texts[idx]
        target_text = self.target_texts[idx]

        input_enc = self.tokenizer(
            input_text, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt"
        )
        target_enc = self.tokenizer(
            target_text, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt"
        )

        return {
            "input_ids": input_enc.input_ids.squeeze(0),
            "attention_mask": input_enc.attention_mask.squeeze(0),
            "labels": target_enc.input_ids.squeeze(0)
        }

# Training function
def train_model(model, dataloader, optimizer, tokenizer, num_epochs, device):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Shift decoder input ids for teacher forcing
            decoder_input_ids = labels[:, :-1]
            labels = labels[:, 1:].contiguous()

            # Forward pass
            optimizer.zero_grad()
            logits = model(input_ids, decoder_input_ids, attention_mask)

            # Compute loss
            logits = logits.view(-1, logits.size(-1))
            labels = labels.view(-1)
            loss = criterion(logits, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch {epoch+1} Loss: {epoch_loss/len(dataloader):.4f}")

# Script for training
if __name__ == "__main__":
    # Load dataset
    from datasets import load_dataset
    # raw_dataset = load_dataset("cnn_dailymail", "3.0.0")

    # Extract dialogues (articles) and summaries
    # dialogues = raw_dataset['train']['article'][:200]
    # summaries = raw_dataset['train']['highlights'][:200]
    # Extract dialogues (news articles) and summaries
    dialogues = dataset['train']['document'][:500000]  # The news articles

    # print(dialogues)
    summaries = dataset['train']['summary'][:500000] # The summaries
    # Load the tokenizer

    # Configuration and tokenizer
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
    # custom_config = BartConfig(
    #     vocab_size=50265,
    #     d_model=768,
    #     encoder_layers=4,
    #     decoder_layers=4,
    #     encoder_attention_heads=6,
    #     decoder_attention_heads=6,
    #     encoder_ffn_dim=1024,
    #     decoder_ffn_dim=1024,
    #     activation_function='gelu',
    #     max_position_embeddings=1024,
    #     use_cache=True
    # )

    custom_config = BartConfig(
        vocab_size=50265,  # Adjust according to your tokenizer
        encoder_layers=4,  # Number of encoder layers
        decoder_layers=4,  # Number of decoder layers
        d_model=256,       # Dimensionality of the model
        decoder_ffn_dim=1024,  # FFN size
        encoder_ffn_dim=1024,
        max_position_embeddings=512
    )

    # Custom BART model
    model = CustomBartModel(custom_config)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    # Dataset parameters
    max_input_length = 512
    max_target_length = 150
    dataset =  ConversationDataset(dialogues, summaries, tokenizer, max_input_length, max_target_length)
    dataloader = DataLoader(dataset, batch_size=6, shuffle=True)

    # Optimizer
    optimizer = AdamW(model.parameters(), lr=0.0001)

    # Training
    device = torch.device(device)
    train_model(model, dataloader, optimizer, tokenizer, num_epochs=3, device=device)

    # Save the trained model
    torch.save(model.state_dict(), "custom_bart_model.pth")
    print("Model training complete and saved!")


Epoch 1/3: 100%|███████████████████████████████████████████████████████████████| 83334/83334 [3:18:37<00:00,  6.99it/s]


Epoch 1 Loss: 2.7325


Epoch 2/3: 100%|███████████████████████████████████████████████████████████████| 83334/83334 [3:18:01<00:00,  7.01it/s]


Epoch 2 Loss: 1.5333


Epoch 3/3: 100%|███████████████████████████████████████████████████████████████| 83334/83334 [3:18:03<00:00,  7.01it/s]


Epoch 3 Loss: 1.3008
Model training complete and saved!


In [5]:
def run_inference(custom_config, model_weights_path="custom_bart_model.pth"):
    # Load tokenizer and configuration
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

    # Initialize custom model
    model = CustomBartModel(custom_config)

    # Load model weights
    # model.load_state_dict(torch.load(model_weights_path))
    model.eval()

    # Input text
    input_text = "The quick brown fox jumps over the lazy dog. But he trips and lands on his face."
    input_ids = tokenizer(input_text, return_tensors='pt').input_ids

    # Prepare decoder input (start with BOS token)
    decoder_input_ids = torch.tensor([[tokenizer.bos_token_id]])

    # Iterative decoding
    max_length = 50  # Define a reasonable maximum length
    output_ids = decoder_input_ids
    for _ in range(max_length):
        with torch.no_grad():
            logits = model(input_ids, output_ids)
        next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)

        # Break if EOS token is generated
        if next_token_id.item() == tokenizer.eos_token_id:
            break

        # Append the new token to the sequence
        output_ids = torch.cat([output_ids, next_token_id], dim=1)

    # Decode the output tokens
    output_text = tokenizer.decode(output_ids.squeeze(0), skip_special_tokens=True)

    print(f"Input: {input_text}")
    print(f"Output: {output_text}")


# Example usage
if __name__ == "__main__":
    custom_config = BartConfig(
        vocab_size=50265,  # Adjust according to your tokenizer
        encoder_layers=4,  # Number of encoder layers
        decoder_layers=4,  # Number of decoder layers
        d_model=256,       # Dimensionality of the model
        decoder_ffn_dim=1024,  # FFN size
        encoder_ffn_dim=1024,
        max_position_embeddings=512
    )


    run_inference(custom_config, model_weights_path="custom_bart_model.pth")




Input: The quick brown fox jumps over the lazy dog. But he trips and lands on his face.
Output:  peril toxic ginmone804 account diplomacy Register TRmonearious mantraFaithorks crawled804 account diplomacy Register stabilization dismantled upkeepBILL reduce TireAIDSarious mantra Santana faux Diagn Resist sky stabilization dismantled diplomacy faux Diagn Resist sky stabilization dismantled diplomacy faux Diagn Resist compuls TR 388BILL


In [6]:
import json
import pandas as pd
with open('train.json', 'r', encoding='utf-8', errors='ignore') as file:
    train = json.load(file)
# Since it's too computationally expensive to run this model I won't be doing any validation testing.
# with open('val.json', 'r', encoding='utf-8', errors='ignore') as file:
#     val = json.load(file)

with open('test.json', 'r', encoding='utf-8', errors='ignore') as file:
    test = json.load(file)

# Convert to DataFrame
df_train = pd.DataFrame(train)
# df_val = pd.DataFrame(val)
df_test = pd.DataFrame(test)

In [7]:
# Fine tuning
train_dataset = ConversationDataset(df_train['dialogue'].tolist(), df_train['summary'].tolist(), tokenizer)
dataloader = DataLoader(train_dataset, batch_size=6, shuffle=True)
train_model(model, dataloader, optimizer, tokenizer, num_epochs=3, device=device)

Epoch 1/3: 100%|███████████████████████████████████████████████████████████████████| 2456/2456 [05:56<00:00,  6.89it/s]


Epoch 1 Loss: 4.5846


Epoch 2/3: 100%|███████████████████████████████████████████████████████████████████| 2456/2456 [05:55<00:00,  6.91it/s]


Epoch 2 Loss: 3.2949


Epoch 3/3: 100%|███████████████████████████████████████████████████████████████████| 2456/2456 [05:55<00:00,  6.91it/s]

Epoch 3 Loss: 2.7698





In [9]:
def run_inference(custom_config, model_weights_path="custom_bart_model.pth"):
    # Load tokenizer and configuration
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

    # Initialize custom model
    model = CustomBartModel(custom_config)

    # Load model weights
    # model.load_state_dict(torch.load(model_weights_path))
    model.eval()

    # Input text
    input_text = "A: Hi Tom, are you busy tomorrow’s afternoon?\r\nB: I’m pretty sure I am. What’s up?\r\nA: Can you go with me to the animal shelter?.\r\nB: What do you want to do?\r\nA: I want to get a puppy for my son.\r\nB: That will make him so happy.\r\nA: Yeah, we’ve discussed it many times. I think he’s ready now.\r\nB: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \r\nA: I'll get him one of those little dogs.\r\nB: One that won't grow up too big;-)\r\nA: And eat too much;-))\r\nB: Do you know which one he would like?\r\nA: Oh, yes, I took him there last Monday. He showed me one that he really liked.\r\nB: I bet you had to drag him away.\r\nA: He wanted to take it home right away ;-).\r\nB: I wonder what he'll name it.\r\nA: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))"
    input_ids = tokenizer(input_text, return_tensors='pt').input_ids

    # Prepare decoder input (start with BOS token)
    decoder_input_ids = torch.tensor([[tokenizer.bos_token_id]])

    # Iterative decoding
    max_length = 50  # Define a reasonable maximum length
    output_ids = decoder_input_ids
    for _ in range(max_length):
        with torch.no_grad():
            logits = model(input_ids, output_ids)
        next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)

        # Break if EOS token is generated
        if next_token_id.item() == tokenizer.eos_token_id:
            break

        # Append the new token to the sequence
        output_ids = torch.cat([output_ids, next_token_id], dim=1)

    # Decode the output tokens
    output_text = tokenizer.decode(output_ids.squeeze(0), skip_special_tokens=True)

    print(f"Input: {input_text}")
    print(f"Output: {output_text}")


# Example usage
if __name__ == "__main__":
    custom_config = BartConfig(
        vocab_size=50265,  # Adjust according to your tokenizer
        encoder_layers=4,  # Number of encoder layers
        decoder_layers=4,  # Number of decoder layers
        d_model=256,       # Dimensionality of the model
        decoder_ffn_dim=1024,  # FFN size
        encoder_ffn_dim=1024,
        max_position_embeddings=512
    )


    run_inference(custom_config, model_weights_path="custom_bart_model.pth")

Input: A: Hi Tom, are you busy tomorrow’s afternoon?
B: I’m pretty sure I am. What’s up?
A: Can you go with me to the animal shelter?.
B: What do you want to do?
A: I want to get a puppy for my son.
B: That will make him so happy.
A: Yeah, we’ve discussed it many times. I think he’s ready now.
B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) 
A: I'll get him one of those little dogs.
B: One that won't grow up too big;-)
A: And eat too much;-))
B: Do you know which one he would like?
A: Oh, yes, I took him there last Monday. He showed me one that he really liked.
B: I bet you had to drag him away.
A: He wanted to take it home right away ;-).
B: I wonder what he'll name it.
A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))
Output: IAL decency Sharing Britons vile Tribcrimcategory Fear veterin provocation convertingCatal invites Pierre undergoneivation extrContactIAL Tata stakeholders Fearicasagi dependencies convertingCatal 