In [1]:
# This is the initial dataset which is used in the first initial step of training after this the model should be able to complete text

In [2]:
import os

from datasets import load_dataset

from torch.utils.data import Dataset, DataLoader

from preprocess.sequencing import create_sequences
from preprocess.tokenizer import BPETokenizer

from transformer.DecoderLayer import DecoderLayer

from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
TEXT_COMPLETION_PATH = os.path.join("data", "text_completion.json")

train_set = load_dataset("abisee/cnn_dailymail", "3.0.0", split="train[:2%]")

# Load 5% of the validation set
valid_set = load_dataset("abisee/cnn_dailymail", "3.0.0", split="validation[:2%]")

print(f"Training set size: {len(train_set)}")
print(f"Validation set size: {len(valid_set)}")

Training set size: 5742
Validation set size: 267


In [4]:
train_articles = train_set["article"]
train_highlights = train_set["highlights"]


tokenizer = BPETokenizer(
    vocab_size=30000, min_frequency=2, special_tokens=SPECIAL_TOKENS
)

if not os.path.exists(TEXT_COMPLETION_PATH):
    tokenizer.fit(
        train_articles + train_highlights,
    )
    tokenizer.save(TEXT_COMPLETION_PATH)
else:
    tokenizer.load(TEXT_COMPLETION_PATH)

In [5]:
train_articles = [item["article"] for item in tqdm(train_set, desc="Extracting Train Articles") if item["article"] is not None]
valid_articles = [item["article"] for item in tqdm(valid_set, desc="Extracting Valid Articles") if item["article"] is not None]

def encode_article(article):
    return tokenizer.encode(article)

def parallel_encode(articles, desc):
    encoded_articles = []
    with ProcessPoolExecutor() as executor:
        futures = {executor.submit(encode_article, article): article for article in articles}
        for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
            encoded_articles.append(future.result())
    return encoded_articles

train_set_encoded = parallel_encode(train_articles, "Encoding Train Set")
valid_set_encoded = parallel_encode(valid_articles, "Encoding Valid Set")

Extracting Train Articles: 100%|██████████| 5742/5742 [00:00<00:00, 11547.67it/s]
Extracting Valid Articles: 100%|██████████| 267/267 [00:00<00:00, 10474.19it/s]
Encoding Train Set: 100%|██████████| 5742/5742 [00:08<00:00, 666.79it/s]
Encoding Valid Set: 100%|██████████| 267/267 [00:00<00:00, 536.23it/s]


In [6]:
def extract_token_ids(encoded_data):
    """
    Convert each Encoding object into its list of token IDs and flatten them into a single list,
    with a progress bar showing the extraction progress.
    """
    flattened_ids = []
    for encoding in tqdm(encoded_data, desc="Extracting Token IDs"):
        flattened_ids.extend(encoding.ids)
    return flattened_ids

# Extract token IDs with progress bars for training and validation sets
train_token_ids = extract_token_ids(train_set_encoded)
valid_token_ids = extract_token_ids(valid_set_encoded)


Extracting Token IDs: 100%|██████████| 5742/5742 [00:00<00:00, 18991.07it/s]
Extracting Token IDs: 100%|██████████| 267/267 [00:00<00:00, 33581.60it/s]


In [48]:
CONTEXT_ELN = 50 # N
TARGET_ELN = 1

train_seq = create_sequences(
    tokenized_data=train_token_ids, 
    max_context_length=CONTEXT_ELN,
    max_target_length=TARGET_ELN,
)

valid_seq = create_sequences(
    tokenized_data=valid_token_ids,
    max_context_length=CONTEXT_ELN,
    max_target_length=TARGET_ELN,
)

In [49]:
print(len(train_seq))
for i, (context, target) in enumerate(train_seq):
    print(f"Context: {context[:10]}... (Total: {len(context)} tokens)") 
    print(f"Target: {target} (Total: {len(target)} token)") 
    print(f"Decoded: ...{tokenizer.decode(context)[-10:]}")
    print(f"Decoded: {tokenizer.decode(target)}")
    if i == 5:
        break


4551174
Context: [20604, 16, 1354, 467, 9296, 8931, 13, 596, 3949, 11831]... (Total: 50 tokens)
Target: [306] (Total: 1 token)
Decoded: ... Radcliffe
Decoded:  as
Context: [16, 1354, 467, 9296, 8931, 13, 596, 3949, 11831, 874]... (Total: 50 tokens)
Target: [3949] (Total: 1 token)
Decoded: ...dcliffe as
Decoded:  Harry
Context: [1354, 467, 9296, 8931, 13, 596, 3949, 11831, 874, 3812]... (Total: 50 tokens)
Target: [11831] (Total: 1 token)
Decoded: ...e as Harry
Decoded:  Potter
Context: [467, 9296, 8931, 13, 596, 3949, 11831, 874, 3812, 24211]... (Total: 50 tokens)
Target: [231] (Total: 1 token)
Decoded: ...rry Potter
Decoded:  in
Context: [9296, 8931, 13, 596, 3949, 11831, 874, 3812, 24211, 13676]... (Total: 50 tokens)
Target: [366] (Total: 1 token)
Decoded: ... Potter in
Decoded:  "
Context: [8931, 13, 596, 3949, 11831, 874, 3812, 24211, 13676, 2623]... (Total: 50 tokens)
Target: [22955] (Total: 1 token)
Decoded: ...otter in "
Decoded: Harry


In [50]:
class TextCompletionDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        context, target = self.sequences[idx]
        # Convert context and target to tensors
        context = torch.tensor(context, dtype=torch.long)
        target = torch.tensor(target, dtype=torch.long)
        return context, target

    
train_dataset = TextCompletionDataset(train_seq)
valid_dataset = TextCompletionDataset(valid_seq)

In [51]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

In [52]:
EPOCHS = 10
VOC_SIZE = tokenizer.get_vocab_size()
MAX_LEN = CONTEXT_ELN
D_MODEL = 512
FFN_HIDDEN = 2048
N_HEAD = 8
N_LAYERS = 6
DROP_PROB = 0.1

In [53]:
class TransformerModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_layers,
        num_heads,
        ff_dim,
        max_len=5000,
        dropout=0.1,
    ):
        super().__init__()
        self.decoder = DecoderLayer(
            vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_len, dropout
        )

    def forward(self, x, mask=None):
        return self.decoder(x, mask)


In [54]:
model = TransformerModel(
    vocab_size=VOC_SIZE,
    embed_dim=D_MODEL,
    num_layers=N_LAYERS,
    num_heads=N_HEAD,
    ff_dim=FFN_HIDDEN,
    max_len=MAX_LEN,
    dropout=DROP_PROB
).to(DEVICE)

print("Parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
print(model)

Parameters:  49664304
TransformerModel(
  (decoder): DecoderLayer(
    (embedding): InputEmbeddings(
      (embed): Embedding(30000, 512)
    )
    (positional_encoding): PositionalEncoding()
    (layers): ModuleList(
      (0-5): 6 x DecoderBlock(
        (attention): MultiHeadAttention(
          (qkv_proj): Linear(in_features=512, out_features=1536, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): FeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (fc_out): Linear(in_features=512, out_features=30000, bias=True)
  )
)


In [56]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)

In [62]:
def look_ahead_mask(seq_len):
    mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
    return mask

In [66]:
def train_epoch(model, loader, criterion, optimizer):
    model.train()
    total_loss = 0
    for context, target in tqdm(loader, desc="Training"):
        context = context.to(DEVICE)
        target = target.to(DEVICE)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Create the look-ahead mask for the context sequence
        mask = look_ahead_mask(context.size(1)).to(DEVICE)
        
        # Forward pass: model should predict the next token based on context
        # Ensure that the model is only outputting one prediction per context
        output = model(context, mask)  # Expect shape: (batch_size, 1, vocab_size)

        # Extract only the last token's prediction: output should be (batch_size, vocab_size)
        output = output[:, -1, :]  # Get the last prediction along the sequence dimension
        
        # Ensure target is correctly shaped: should be (batch_size,)
        target = target.squeeze(-1)  # Remove the extra dimension from the target
        
        # Debug shapes
        print(f"Output shape: {output.shape}")  # Should be (batch_size, vocab_size)
        print(f"Target shape: {target.shape}")  # Should be (batch_size,)
        
        # Calculate loss
        loss = criterion(output, target)
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    return total_loss / len(loader)

# Execute the training step
train_epoch(model, train_loader, criterion, optimizer)


Training:   0%|          | 0/142225 [00:00<?, ?it/s]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 1/142225 [00:03<147:11:58,  3.73s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 2/142225 [00:06<136:20:29,  3.45s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 3/142225 [00:09<121:14:45,  3.07s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 4/142225 [00:12<113:04:39,  2.86s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 5/142225 [00:15<125:31:49,  3.18s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 6/142225 [00:20<144:02:27,  3.65s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 7/142225 [00:29<215:05:54,  5.44s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 8/142225 [00:32<188:52:02,  4.78s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 9/142225 [00:36<169:41:36,  4.30s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 10/142225 [00:39<157:47:26,  3.99s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 11/142225 [00:42<148:52:50,  3.77s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 12/142225 [00:46<151:08:41,  3.83s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 13/142225 [00:50<145:43:08,  3.69s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 14/142225 [00:53<143:21:21,  3.63s/it]

Output shape: torch.Size([32, 30000])
Target shape: torch.Size([32])


Training:   0%|          | 15/142225 [00:57<151:43:50,  3.84s/it]


KeyboardInterrupt: 