<a href="https://colab.research.google.com/github/GuyRobot/AINotesBook/blob/main/MiniBARTModeling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

# Define a simple toy dataset for text summarization
class ToyDataset(Dataset):
    """Synthetic ``(source, summary)`` pairs made of random token ids.

    All samples are generated eagerly in the constructor. Each sample is a
    tuple ``(src, tgt)`` of 1-D integer tensors:

    * ``src`` holds between ``MIN_SRC_LEN`` and ``max_seq_len`` tokens;
    * ``tgt`` holds between ``MIN_TGT_LEN`` and ``len(src) // 2`` tokens.

    Token ids are drawn uniformly from ``[1, vocab_size)``, so id 0 is never
    produced by this dataset.

    Args:
        num_samples: Number of pairs to generate.
        max_seq_len: Largest allowed source length (inclusive).
        vocab_size: Exclusive upper bound for generated token ids.

    Raises:
        ValueError: If samples are requested and ``max_seq_len`` is smaller
            than ``MIN_SRC_LEN`` or ``vocab_size`` is smaller than 2.
    """

    # Shortest source that still leaves room for a MIN_TGT_LEN-token target
    # (the target may use at most half of the source length).
    MIN_SRC_LEN = 5
    MIN_TGT_LEN = 2

    def __init__(self, num_samples, max_seq_len, vocab_size=100):
        if num_samples > 0:
            if max_seq_len < self.MIN_SRC_LEN:
                raise ValueError(
                    f"max_seq_len must be at least {self.MIN_SRC_LEN}, got {max_seq_len}"
                )
            if vocab_size < 2:
                raise ValueError(f"vocab_size must be at least 2, got {vocab_size}")

        self.data = []
        for _ in range(num_samples):
            src_len = random.randint(self.MIN_SRC_LEN, max_seq_len)
            src = torch.randint(1, vocab_size, (src_len,))  # Random source sequence
            tgt_len = random.randint(self.MIN_TGT_LEN, src_len // 2)
            tgt = torch.randint(1, vocab_size, (tgt_len,))  # Random target summary
            self.data.append((src, tgt))

    def __len__(self):
        """Return the number of generated pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(src, tgt)`` tensor pair stored at position ``idx``."""
        return self.data[idx]

# Padding id used when batching. ToyDataset draws its token ids from 1 upward,
# so 0 cannot collide with a real token.
PAD_TOKEN_ID = 0


def pad_collate(batch):
    """Collate variable-length ``(src, tgt)`` pairs into two padded tensors.

    ToyDataset samples have random lengths, so they cannot be stacked by the
    DataLoader's default collate function. Each side is right-padded with
    ``PAD_TOKEN_ID`` to the longest sequence in the batch.

    Args:
        batch: List of ``(src, tgt)`` tuples of 1-D tensors.

    Returns:
        Tuple ``(src, tgt)`` of tensors shaped ``(max_len, batch_size)``
        (sequence-first, the default layout of ``pad_sequence``).
    """
    sources, targets = zip(*batch)
    padded_src = pad_sequence(list(sources), padding_value=PAD_TOKEN_ID)
    padded_tgt = pad_sequence(list(targets), padding_value=PAD_TOKEN_ID)
    return padded_src, padded_tgt


# Create a toy dataset
toy_dataset = ToyDataset(num_samples=100, max_seq_len=50)
train_dataloader = DataLoader(
    toy_dataset, batch_size=8, shuffle=True, collate_fn=pad_collate)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to sequence-first embeddings.

    A table of shape ``(max_len, 1, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold sines and
    odd columns hold cosines of the position scaled by geometrically spaced
    frequencies.

    Args:
        d_model: Embedding width; may be even or odd.
        max_len: Number of positions in the precomputed table.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer odd column than even columns, so the
        # frequency vector is trimmed to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Singleton middle axis lets the table broadcast over the batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus its positional encoding, followed by dropout.

        Args:
            x: Tensor whose first axis is the position, i.e.
                ``(seq_len, batch, d_model)``; ``seq_len`` must not exceed
                ``max_len``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Define a basic Transformer Encoder
class TransformerEncoder(nn.Module):
    """Thin wrapper around a stack of ``nn.TransformerEncoderLayer`` modules.

    Args:
        d_model: Feature width of the layer inputs and outputs.
        nhead: Number of attention heads per layer.
        num_encoder_layers: How many copies of the layer to stack.
    """

    def __init__(self, d_model, nhead, num_encoder_layers):
        super().__init__()
        layer_template = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer_encoder = nn.TransformerEncoder(
            layer_template, num_layers=num_encoder_layers)

    def forward(self, src):
        """Run ``src`` through the encoder stack; no masks are applied.

        NOTE(review): the layers are built with PyTorch defaults, so ``src``
        is presumably ``(seq_len, batch, d_model)`` - confirm against callers.
        """
        encoded = self.transformer_encoder(src)
        return encoded

# Define a basic Transformer Decoder
class TransformerDecoder(nn.Module):
    """Stack of ``nn.TransformerDecoderLayer`` modules, causal by default.

    Args:
        d_model: Feature width of the layer inputs and outputs.
        nhead: Number of attention heads per layer.
        num_decoder_layers: How many copies of the layer to stack.
    """

    def __init__(self, d_model, nhead, num_decoder_layers):
        super().__init__()
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead),
            num_decoder_layers)

    def forward(self, tgt, memory, tgt_mask=None):
        """Decode ``tgt`` while attending to the encoder output ``memory``.

        Args:
            tgt: Target embeddings with the position on axis 0
                (the layers are built without ``batch_first``).
            memory: Encoder output to cross-attend to.
            tgt_mask: Optional additive self-attention mask of shape
                ``(tgt_len, tgt_len)``. When omitted, a causal mask is used
                so position ``i`` only attends to positions ``<= i``.
                Pass an all-zero mask to attend to every position.

        Returns:
            The output of the wrapped ``nn.TransformerDecoder``.
        """
        if tgt_mask is None:
            # Without this mask every position could attend to later target
            # tokens, i.e. to the very tokens it is trained to predict.
            tgt_len = tgt.size(0)
            blocked = torch.full(
                (tgt_len, tgt_len), float('-inf'),
                dtype=tgt.dtype, device=tgt.device)
            tgt_mask = torch.triu(blocked, diagonal=1)
        return self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)

# Define the BART model
class BART(nn.Module):
    """Small encoder-decoder model producing per-position vocabulary logits.

    Source and target token ids share a single embedding table and a single
    positional-encoding module. The encoder output is handed to the decoder
    as memory, and a linear layer maps decoder features to ``vocab_size``
    scores.

    Args:
        vocab_size: Size of the embedding table and of the output layer.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads.
        num_encoder_layers: Depth of the encoder stack.
        num_decoder_layers: Depth of the decoder stack.
    """

    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers):
        super().__init__()
        # Submodules are created in a fixed order so seeded initialisation
        # and state_dict layout stay stable.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = TransformerEncoder(d_model, nhead, num_encoder_layers)
        self.decoder = TransformerDecoder(d_model, nhead, num_decoder_layers)
        self.fc = nn.Linear(in_features=d_model, out_features=vocab_size)

    def _embed(self, token_ids):
        """Look up token embeddings and add positional information."""
        return self.positional_encoding(self.embedding(token_ids))

    def forward(self, src, tgt):
        """Return vocabulary logits for every position of ``tgt``.

        Args:
            src: Integer tensor of source token ids.
            tgt: Integer tensor of target token ids fed to the decoder.

        Returns:
            Float tensor shaped like ``tgt`` with a trailing ``vocab_size`` axis.
        """
        # Both sides are embedded before the encoder runs.
        src_features = self._embed(src)
        tgt_features = self._embed(tgt)
        memory = self.encoder(src_features)
        decoded = self.decoder(tgt_features, memory)
        return self.fc(decoded)

# Example usage
vocab_size = 10000  # rows of the embedding table and width of the output layer
d_model = 512  # embedding / hidden width
nhead = 8  # attention heads per layer
num_encoder_layers = num_decoder_layers = 6  # same depth on both sides

model = BART(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
)

In [None]:
# Define loss function and optimizer
# Token id 0 is excluded from the loss: ToyDataset only emits ids >= 1, so a 0
# in a target batch can only be padding, never a token the model should predict.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for src, tgt in train_dataloader:
        optimizer.zero_grad()

        # Teacher forcing: the decoder reads tokens [0, T-1) and is scored
        # against tokens [1, T). Slicing axis 0 assumes sequence-first batches
        # of shape (seq_len, batch), matching how PositionalEncoding indexes
        # positions - TODO confirm the dataloader yields that layout.
        decoder_input = tgt[:-1]
        expected = tgt[1:]

        # Forward pass
        logits = model(src, decoder_input)

        # Compute loss; reshape (unlike view) also accepts non-contiguous slices
        loss = criterion(logits.reshape(-1, logits.size(-1)), expected.reshape(-1))

        # Backpropagation
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError
    average_loss = total_loss / max(len(train_dataloader), 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "bart_mini_dataset.pth")