In [1]:
from datasets import load_dataset

# Load the CNN/DailyMail dataset (config "3.0.0") and print its splits;
# later cells read the 'article' and 'highlights' fields of each record.
dataset = load_dataset("abisee/cnn_dailymail", "3.0.0")
print(dataset)

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})


In [2]:
from transformers import AutoTokenizer
# Pretrained ALBERT tokenizer; the notebook uses it for encoding text and for
# its vocab_size and bos/eos/pad token ids.
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

In [5]:
# Start- and end-of-sequence token ids as configured on the tokenizer.
sos_idx = tokenizer.bos_token_id
eos_idx = tokenizer.eos_token_id

In [None]:
import torch
from torch.utils.data import Dataset,DataLoader
# Token budget for an encoded article, token budget for a summary
# (decoder input / labels), and the number of examples per batch.
max_len, summary_len, batch_size = 512, 128, 16
class SummarizationDataset(Dataset):
    """Turns article/summary records into fixed-length tensors for seq2seq training.

    Args:
        data: indexable collection whose items expose ``'article'`` and
            ``'highlights'`` entries.
        tokenizer: callable tokenizer that also provides ``bos_token_id``,
            ``eos_token_id`` and ``pad_token_id``. All encoding goes through
            this instance rather than any module-level tokenizer.
        max_len: length every encoded article is padded / truncated to.
        summary_len: length of the decoder-input and label sequences.
    """

    def __init__(self,data,tokenizer,max_len,summary_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.summary_len = summary_len

    def __len__(self):
        """Number of records in the wrapped collection."""
        return len(self.data)

    def __getitem__(self,index):
        """Encode one record.

        Returns:
            dict with
              * ``encoder_input_ids`` / ``encoder_attention_mask``: flattened
                tokenizer output for the article, padded to ``max_len``;
              * ``decoder_input_ids``: BOS + summary ids, padded to ``summary_len``;
              * ``labels``: summary ids + EOS, padded to ``summary_len``
                (i.e. the decoder input shifted left by one position).
        """
        item = self.data[index]
        article = str(item['article'])
        summary = str(item['highlights'])
        tok = self.tokenizer

        encoder_inputs = tok(
            article,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # Special tokens are added by hand below, so keep the tokenizer from
        # inserting its own and reserve two slots (BOS and EOS) in the budget.
        summary_ids = tok(
            summary,
            add_special_tokens=False,
            max_length=self.summary_len - 2,
            truncation=True,
        )["input_ids"]

        decoder_input = [tok.bos_token_id] + summary_ids
        labels = summary_ids + [tok.eos_token_id]

        # Right-pad both sequences to exactly `summary_len` with the pad id.
        pad_id = tok.pad_token_id
        decoder_input = decoder_input + [pad_id] * (self.summary_len - len(decoder_input))
        labels = labels + [pad_id] * (self.summary_len - len(labels))

        return {
            "encoder_input_ids": encoder_inputs['input_ids'].flatten(),
            "encoder_attention_mask": encoder_inputs['attention_mask'].flatten(),
            "decoder_input_ids": torch.tensor(decoder_input),
            "labels": torch.tensor(labels)
        }

# Training examples are reshuffled every epoch.
train_dataset = SummarizationDataset(dataset['train'], tokenizer, max_len=max_len, summary_len=summary_len)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation keeps a fixed order. The reported validation loss is an average of
# per-batch mean losses, so random batching would make it vary between runs for
# the same weights and add noise to the early-stopping decision.
valid_dataset = SummarizationDataset(dataset['validation'], tokenizer, max_len=max_len, summary_len=summary_len)
validation_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import torch
import torch.nn as nn
# Model hyperparameters shared by every layer defined below.
d_model = 256                       # width of embeddings and of every block's output
n_heads = 4                         # attention heads per attention layer
head_size = d_model // n_heads      # per-head width
n = 4                               # number of encoder blocks and of decoder blocks
max_len = 512                       # size of the positional table and of the causal mask
dropout_rate = 0.2
vocab_size = tokenizer.vocab_size   # size of the token table and of the output layer

class Embedding(nn.Module):
    """Token embedding plus a learned absolute positional embedding.

    Args:
        vocab_size: number of rows in the token lookup table.
        d_model: width of every embedding vector.
        max_len: number of rows in the positional table, i.e. the longest
            sequence that can be embedded.
    """

    def __init__(self, vocab_size, d_model, max_len=512):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Embed a 2-D tensor of token ids ``x`` (batch, sequence)."""
        _, seq_len = x.shape
        # Positions 0..seq_len-1, created on the same device as the ids.
        positions = torch.arange(seq_len, device=x.device)
        # The positional rows broadcast across the batch dimension.
        return self.token_embedding(x) + self.positional_embedding(positions)

In [8]:
class SingleHeadSelfAttention(nn.Module):
    """One scaled dot-product self-attention head without causal masking.

    Queries, keys and values are bias-free linear projections of the same
    input from ``d_model`` down to ``head_size`` features.
    """

    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.d_model = d_model
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: input whose last dimension is ``d_model``.
            mask: optional tensor; positions where it equals 0 are excluded
                from the softmax. A new axis is inserted at dim 1 so that one
                mask row per batch element applies to every query position.
        """
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        scale = self.head_size ** 0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        if mask is not None:
            blocked = mask.unsqueeze(1) == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, values)

In [9]:
class MultiHeadAttention(nn.Module):
    """Runs ``n_heads`` independent self-attention heads and mixes their outputs.

    The head outputs are concatenated on the last dimension and passed through
    a ``d_model -> d_model`` linear layer, so ``n_heads * head_size`` has to
    equal ``d_model``.
    """

    def __init__(self, n_heads, d_model, head_size):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        heads = [SingleHeadSelfAttention(d_model, head_size) for _ in range(n_heads)]
        self.multiheads = nn.ModuleList(heads)
        self.projection_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply every head to ``x`` (forwarding ``mask``) and project the result."""
        per_head = [head(x, mask=mask) for head in self.multiheads]
        concatenated = torch.cat(per_head, dim=-1)
        return self.projection_layer(concatenated)

In [10]:
class MaskedSingleHeadAttention(nn.Module):
    """One causal self-attention head: a position never attends to later ones.

    A ``max_len x max_len`` lower-triangular matrix of ones is registered as
    the buffer ``tril`` and sliced to the current sequence length on each call.
    """

    def __init__(self, d_model, head_size, max_len):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)
        # Built once; as a buffer it follows the module across devices.
        self.register_buffer('tril', torch.tril(torch.ones(max_len, max_len)))

    def forward(self, x):
        """Attend causally over a 3-D input ``x`` whose middle axis is the sequence."""
        _, seq_len, _ = x.shape
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_size ** 0.5)
        # Zeros of the sliced triangle mark "future" positions.
        future = self.tril[:seq_len, :seq_len] == 0
        weights = torch.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
        return weights @ values

In [11]:
class MaskedMultiHeadAttention(nn.Module):
    """Multi-head causal self-attention built from ``MaskedSingleHeadAttention``.

    Head outputs are concatenated on the last dimension and mixed by a
    ``d_model -> d_model`` linear layer.
    """

    def __init__(self, n_heads, d_model, head_size, max_len):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        heads = [MaskedSingleHeadAttention(d_model, head_size, max_len) for _ in range(n_heads)]
        self.multiheads = nn.ModuleList(heads)
        self.projection_layer = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply every causal head to ``x`` and project the concatenation."""
        per_head = [head(x) for head in self.multiheads]
        concatenated = torch.cat(per_head, dim=-1)
        return self.projection_layer(concatenated)

In [12]:
import torch.nn.functional as F
class CrossAttention(nn.Module):
    """One encoder-decoder attention head.

    Queries are projected from the decoder-side input; keys and values are
    projected from the encoder output.
    """

    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)

    def forward(self, encoder_output, masked_attention, mask=None):
        """Let ``masked_attention`` (decoder side) attend over ``encoder_output``.

        Args:
            encoder_output: source of keys and values.
            masked_attention: source of queries.
            mask: optional tensor; encoder positions where it equals 0 are
                excluded. A new axis is inserted at dim 1 so the same mask row
                applies to every query position.
        """
        queries = self.query(masked_attention)
        keys = self.key(encoder_output)
        values = self.value(encoder_output)

        scores = queries @ keys.transpose(-2, -1) / (self.head_size ** 0.5)
        if mask is not None:
            blocked = mask.unsqueeze(1) == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        return weights @ values

In [13]:
class MultiHeadCrossAttention(nn.Module):
    """Multi-head encoder-decoder attention built from ``CrossAttention`` heads.

    Head outputs are concatenated on the last dimension and mixed by a
    ``d_model -> d_model`` linear layer.
    """

    def __init__(self, n_heads, d_model, head_size):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        heads = [CrossAttention(d_model, head_size) for _ in range(n_heads)]
        self.multiheads = nn.ModuleList(heads)
        self.projection_layer = nn.Linear(d_model, d_model)

    def forward(self, encoder_output, masked_attention, mask=None):
        """Run every head with the same arguments and project the concatenation."""
        per_head = [
            head(encoder_output, masked_attention, mask=mask)
            for head in self.multiheads
        ]
        concatenated = torch.cat(per_head, dim=-1)
        return self.projection_layer(concatenated)

In [14]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP: expand to ``4 * d_model``, ReLU, project back to ``d_model``."""

    def __init__(self, d_model):
        super().__init__()
        hidden = 4 * d_model
        expand = nn.Linear(d_model, hidden)
        contract = nn.Linear(hidden, d_model)
        self.feed_forward = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.feed_forward(x)

In [15]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``; the
    same dropout module serves both sub-layers.
    """

    def __init__(self, n_heads, d_model, head_size, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        self.multi_head_self_att = MultiHeadAttention(n_heads, d_model, head_size)
        self.ffd = PositionWiseFeedForward(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the self-attention."""
        attended = self.multi_head_self_att(x, mask=mask)
        x = self.ln1(x + self.dropout(attended))
        transformed = self.ffd(x)
        return self.ln2(x + self.dropout(transformed))

In [16]:
class Encoder(nn.Module):
    """Stack of ``n`` ``EncoderBlock`` layers applied in sequence."""

    def __init__(self, n, n_heads, d_model, head_size, dropout_rate):
        super().__init__()
        self.n = n
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        layers = [
            EncoderBlock(n_heads, d_model, head_size, dropout_rate)
            for _ in range(n)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, mask=None):
        """Feed ``x`` through every block, passing the same ``mask`` to each."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, mask=mask)
        return hidden

In [17]:
class DecoderBlock(nn.Module):
    """One decoder layer: causal self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``; the
    same dropout module serves all three sub-layers.
    """

    def __init__(self, n_heads, d_model, head_size, max_len, dropout_rate):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_size = head_size
        self.max_len = max_len
        self.dropout_rate = dropout_rate
        self.masked_att = MaskedMultiHeadAttention(n_heads, d_model, head_size, max_len)
        self.ffd = PositionWiseFeedForward(d_model)
        self.cross_att = MultiHeadCrossAttention(n_heads, d_model, head_size)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, encoder_output, decoder_input, mask=None):
        """Run the block.

        Args:
            encoder_output: encoder states attended to by the cross-attention.
            decoder_input: decoder-side hidden states.
            mask: forwarded to the cross-attention only; the self-attention
                applies its own causal mask.
        """
        self_attended = self.masked_att(decoder_input)
        hidden = self.ln1(decoder_input + self.dropout(self_attended))

        cross_attended = self.cross_att(encoder_output, hidden, mask=mask)
        hidden = self.ln2(hidden + self.dropout(cross_attended))

        fed_forward = self.ffd(hidden)
        return self.ln3(hidden + self.dropout(fed_forward))

In [18]:
class Decoder(nn.Module):
    """Stack of ``n`` ``DecoderBlock`` layers applied in sequence."""

    def __init__(self, n, n_heads, d_model, head_size, max_len, dropout_rate):
        super().__init__()
        self.n = n
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_size = head_size
        self.max_len = max_len
        self.dropout_rate = dropout_rate
        layers = [
            DecoderBlock(n_heads, d_model, head_size, max_len, dropout_rate)
            for _ in range(n)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, encoder_output, decoder_input, mask=None):
        """Feed ``decoder_input`` through every block.

        Every block receives the same ``encoder_output`` and ``mask``.
        """
        hidden = decoder_input
        for layer in self.blocks:
            hidden = layer(encoder_output, hidden, mask=mask)
        return hidden

In [19]:
# Sequence-boundary token ids as configured on the tokenizer.
bos_idx = tokenizer.bos_token_id  # BOS stands for "Beginning Of Sequence"
eos_idx = tokenizer.eos_token_id  # EOS stands for "End Of Sequence"
class Transformer(nn.Module):
    """Encoder-decoder transformer producing vocabulary logits.

    A single ``Embedding`` module is shared by the encoder and decoder inputs;
    a final linear layer maps decoder states to ``vocab_size`` logits.
    """

    def __init__(self, n, n_heads, d_model, head_size, max_len, dropout_rate, vocab_size):
        super().__init__()
        self.n = n
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_size = head_size
        self.max_len = max_len
        self.dropout_rate = dropout_rate
        self.embedding = Embedding(vocab_size, d_model, max_len)
        self.encoder = Encoder(n, n_heads, d_model, head_size, dropout_rate)
        self.decoder = Decoder(n, n_heads, d_model, head_size, max_len, dropout_rate)
        self.final_layer = nn.Linear(d_model, vocab_size)

    def forward(self, dialogue, summary, mask=None):
        """Compute logits for every position of ``summary``.

        Args:
            dialogue: token ids of the source text (encoder input).
            summary: token ids fed to the decoder.
            mask: optional mask for the source tokens; it is handed both to
                the encoder and to the decoder (where the blocks use it for
                cross-attention).
        """
        source_embedded = self.embedding(dialogue)
        target_embedded = self.embedding(summary)
        memory = self.encoder(source_embedded, mask=mask)
        decoded = self.decoder(memory, target_embedded, mask=mask)
        return self.final_layer(decoded)


In [20]:
# Build the model from the hyperparameters defined in the earlier cell.
model = Transformer(n,n_heads,d_model,head_size,max_len,dropout_rate,vocab_size)

In [21]:
# Total number of scalar values across all model parameters.
total_params = sum(p.numel() for p in model.parameters())
print(total_params)

22884656


In [None]:
import os
import torch
import torch.nn as nn
from torch.optim import AdamW
import time
import glob
import numpy as np


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Checkpoint from outside this session; used only when the session directory
# below does not contain a checkpoint yet.
INITIAL_CHECKPOINT_PATH = "/kaggle/input/transformer-20-1/ckpt_epoch20.pt"

# Per-epoch checkpoints and the best-model weights are written here.
SAVE_CHECKPOINT_DIR = "/kaggle/working/checkpoints"
os.makedirs(SAVE_CHECKPOINT_DIR, exist_ok=True)

TOTAL_EPOCHS = 40
LEARNING_RATE = 1e-4

# Stop once validation loss has failed to improve this many epochs in a row.
EARLY_STOPPING_PATIENCE = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

pad_id = tokenizer.pad_token_id
# Positions whose label is the pad id do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=0.1)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# ---------------------------------------------------------------------------
# Resume state
# ---------------------------------------------------------------------------
start_epoch = 0
latest_checkpoint_path = None
start_message = ""

# Early-stopping state; overwritten below when a checkpoint carries it.
best_val_loss = float('inf')
epochs_no_improve = 0
best_model_path = os.path.join(SAVE_CHECKPOINT_DIR, "best_model.pt")

# A checkpoint written by this session wins over the initial one. File names
# use a zero-padded two-digit epoch, so the lexicographic sort puts the newest
# checkpoint last.
session_checkpoints = sorted(glob.glob(os.path.join(SAVE_CHECKPOINT_DIR, "ckpt_epoch*.pt")))
resumed_from_session = bool(session_checkpoints)
if session_checkpoints:
    latest_checkpoint_path = session_checkpoints[-1]
elif os.path.exists(INITIAL_CHECKPOINT_PATH):
    latest_checkpoint_path = INITIAL_CHECKPOINT_PATH

if latest_checkpoint_path:
    print(f"Loading checkpoint from: {latest_checkpoint_path}")
    checkpoint = torch.load(latest_checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Checkpoints store the index of the NEXT epoch to run.
    start_epoch = checkpoint['epoch']

    # Carry the early-stopping state across restarts so that a resumed run
    # cannot overwrite best_model.pt with a worse model. Only trust the stored
    # best loss when the weights that achieved it are still on disk; older
    # checkpoints without these keys fall back to the fresh defaults.
    if os.path.exists(best_model_path):
        best_val_loss = checkpoint.get('best_val_loss', best_val_loss)
        epochs_no_improve = checkpoint.get('epochs_no_improve', epochs_no_improve)

    if resumed_from_session:
        start_message = f"--- Resuming training from epoch {start_epoch} (loaded from session checkpoint) "
    else:
        start_message = f"--- Starting training from epoch {start_epoch} (loaded from initial checkpoint)"
else:
    start_message = "--- No checkpoint found. Starting training from scratch (epoch 0)"

# Print the clear, dynamic starting message
print(start_message)

# ---------------------------------------------------------------------------
# Train / validate
# ---------------------------------------------------------------------------
for epoch in range(start_epoch, TOTAL_EPOCHS):
    start_time = time.time()

    # --- Training phase ---
    model.train()
    total_train_loss = 0
    for batch in train_loader:
        dialogue = batch['encoder_input_ids'].to(device)
        summary = batch['decoder_input_ids'].to(device)
        labels = batch['labels'].to(device)
        dialogue_mask = batch['encoder_attention_mask'].to(device)

        optimizer.zero_grad()
        logits = model(dialogue, summary, mask=dialogue_mask)
        # Flatten all leading dimensions so each position is one classification row.
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # --- Validation phase ---
    model.eval()
    total_valid_loss = 0
    with torch.no_grad():
        for batch in validation_loader:
            dialogue = batch['encoder_input_ids'].to(device)
            summary = batch['decoder_input_ids'].to(device)
            labels = batch['labels'].to(device)
            dialogue_mask = batch['encoder_attention_mask'].to(device)

            logits = model(dialogue, summary, mask=dialogue_mask)
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_valid_loss += loss.item()

    # Epoch losses are averages of the per-batch losses.
    avg_train_loss = total_train_loss / len(train_loader)
    avg_valid_loss = total_valid_loss / len(validation_loader)
    epoch_duration = time.time() - start_time

    print("-" * 60)
    print(f"Epoch: {epoch + 1:02}/{TOTAL_EPOCHS} | Time: {epoch_duration:.2f}s")
    print(f"\tTrain Loss: {avg_train_loss:.4f}")
    print(f"\t Val. Loss: {avg_valid_loss:.4f}")

    # --- Best-model tracking ---
    if avg_valid_loss < best_val_loss:
        best_val_loss = avg_valid_loss
        epochs_no_improve = 0
        # Save a copy of the best model's weights
        torch.save(model.state_dict(), best_model_path)
        print(f" Validation loss improved to {best_val_loss:.4f}. Best model saved to {best_model_path}")
    else:
        epochs_no_improve += 1
        print(f" Val. loss did not improve. Count: {epochs_no_improve}/{EARLY_STOPPING_PATIENCE}")

    # --- Epoch checkpoint ---
    # Written before the early-stopping check so the work of the final epoch
    # is never discarded.
    ckpt = {
        "epoch": epoch + 1,  # Save the *next* epoch number
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_val_loss": best_val_loss,
        "epochs_no_improve": epochs_no_improve,
    }
    ckpt_path = os.path.join(SAVE_CHECKPOINT_DIR, f"ckpt_epoch{epoch+1:02}.pt")
    torch.save(ckpt, ckpt_path)
    print(f" Saved epoch checkpoint: {ckpt_path}")

    if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
        print(f"--- Early stopping triggered after {EARLY_STOPPING_PATIENCE} epochs with no improvement. ")
        break

print("-" * 60)
print(f"Training complete or stopped early. Best validation loss: {best_val_loss:.4f} ")

✅ Loading checkpoint from: /kaggle/working/checkpoints/ckpt_epoch36.pt
--- Resuming training from epoch 36 (loaded from session checkpoint) ---
------------------------------------------------------------
Epoch: 37/40 | Time: 4611.58s
	Train Loss: 3.6329
	 Val. Loss: 3.5688
✅ Validation loss improved to 3.5688. Best model saved to /kaggle/working/checkpoints/best_model.pt
🔖 Saved epoch checkpoint: /kaggle/working/checkpoints/ckpt_epoch37.pt
------------------------------------------------------------
Epoch: 38/40 | Time: 4596.61s
	Train Loss: 3.6082
	 Val. Loss: 3.5610
✅ Validation loss improved to 3.5610. Best model saved to /kaggle/working/checkpoints/best_model.pt
🔖 Saved epoch checkpoint: /kaggle/working/checkpoints/ckpt_epoch38.pt
------------------------------------------------------------
Epoch: 39/40 | Time: 4601.42s
	Train Loss: 3.5881
	 Val. Loss: 3.5490
✅ Validation loss improved to 3.5490. Best model saved to /kaggle/working/checkpoints/best_model.pt
🔖 Saved epoch checkpoin

In [23]:
# Shell command: archive the epoch-40 checkpoint into /kaggle/working/checkpoint_40.zip.
!zip -r /kaggle/working/checkpoint_40.zip /kaggle/working/checkpoints/ckpt_epoch40.pt


  adding: kaggle/working/checkpoints/ckpt_epoch40.pt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


 (deflated 13%)


In [None]:
import os, torch

CHECKPOINT = "/kaggle/working/checkpoints/ckpt_epoch40.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1) Instantiate the model with the same hyperparameters used for training.
model = Transformer(n,n_heads,d_model,head_size,max_len,dropout_rate,vocab_size)

# 2) Load the checkpoint. It may be a full training checkpoint (weights under
#    "model_state_dict") or a bare state dict.
ckpt = torch.load(CHECKPOINT, map_location=device)
state = ckpt.get("model_state_dict", ckpt)

# Strip only a LEADING "module." (the prefix nn.DataParallel adds to every
# key). A plain str.replace would also remove the substring from the middle of
# a key and corrupt parameter names that happen to contain it.
_dp_prefix = "module."
state = {
    (k[len(_dp_prefix):] if k.startswith(_dp_prefix) else k): v
    for k, v in state.items()
}
model.load_state_dict(state)

# 3) Move to the target device and switch to inference mode (disables dropout).
model.to(device).eval()

Transformer(
  (embedding): Embedding(
    (token_embedding): Embedding(30000, 256)
    (positional_embedding): Embedding(512, 256)
  )
  (encoder): Encoder(
    (blocks): ModuleList(
      (0-3): 4 x EncoderBlock(
        (multi_head_self_att): MultiHeadAttention(
          (multiheads): ModuleList(
            (0-3): 4 x SingleHeadSelfAttention(
              (query): Linear(in_features=256, out_features=64, bias=False)
              (key): Linear(in_features=256, out_features=64, bias=False)
              (value): Linear(in_features=256, out_features=64, bias=False)
            )
          )
          (projection_layer): Linear(in_features=256, out_features=256, bias=True)
        )
        (ffd): PositionWiseFeedForward(
          (feed_forward): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): ReLU()
            (2): Linear(in_features=1024, out_features=256, bias=True)
          )
        )
        (ln1): LayerNorm((256,), eps=1e

In [None]:
@torch.no_grad()
def generate(model: Transformer, dialogue: str, max_len: int = 200) -> str:
    """Sample a summary for ``dialogue`` one token at a time.

    Args:
        model: trained ``Transformer``; it should already be in eval mode,
            because this function does not change the module's mode.
        dialogue: source text to summarize.
        max_len: maximum number of tokens to sample.

    Returns:
        The decoded text of the sampled tokens with special tokens removed.

    Tokens are drawn with ``torch.multinomial`` from the softmax distribution,
    so repeated calls can return different summaries.
    """
    # Longest sequence the model's positional table / causal mask can index.
    position_limit = model.max_len

    # Truncate the source: an untruncated article longer than the positional
    # table would index past its end inside the embedding layer.
    tok_out = tokenizer(
        dialogue,
        return_tensors="pt",
        truncation=True,
        max_length=position_limit,
    ).to(device)
    encoder_ids = tok_out["input_ids"]
    encoder_mask = tok_out["attention_mask"]

    eos_id = tokenizer.eos_token_id
    output_ids = torch.tensor([[tokenizer.bos_token_id]], device=device)

    # On step i the decoder receives i + 1 tokens, so at most `position_limit`
    # steps can run without exceeding the positional table.
    steps = min(max_len, position_limit)
    for _ in range(steps):
        # Same call pattern as training: the source mask is passed explicitly.
        logits = model(encoder_ids, output_ids, mask=encoder_mask)
        next_logits = logits[:, -1, :]
        probs = nn.functional.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        output_ids = torch.cat([output_ids, next_id], dim=-1)
        if next_id.item() == eos_id:
            break

    return tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)

# later, after reload…
# Pick the article explicitly so this cell works on a fresh kernel instead of
# relying on a `test_article` left over from an earlier, unsaved cell.
test_article = dataset["test"][0]["article"]
hightlights = generate(model, test_article)
print(hightlights)
